# Run SHAP

This file is part of the Verifying explainability of a deep learning tissue classifier trained on RNA-seq data project.

Verifying explainability of a deep learning tissue classifier trained on RNA-seq data project is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.


Verifying explainability of a deep learning tissue classifier trained on RNA-seq data project is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
GNU General Public License for more details.

You should have received a copy of the GNU General Public License
along with the Verifying explainability of a deep learning tissue classifier trained on RNA-seq data project.  If not, see <http://www.gnu.org/licenses/>.


### Objective:
> Load the trained model and calculate SHAP values for the test data set.

### Input files:
1. *gtex_filtered_tmm_intersect_{data_type}.pkl*
2. *gtex_filtered_tmm_intersect_test.pkl*
3. *{data_type}_model_topology.json*
4. *{data_type}_model_weights.hdf5*

### Output files:
1. *shap_scores_{data_type}.pkl* 
2. *{data_type}_ranks.pkl* 
3. *shap_genes.pkl* 


### Table of contents:
1. [Import Modules](#1.-Import-Modules)  
2. [Set static paths](#2.-Set-static-paths)  
3. [Load files](#3.-Load-files)  
    3.1 [Load RNAseq](#3.1-Load-RNAseq)  
    3.2 [Load model](#3.2-Load-model)  
4. [Run SHAP model](#4.-Run-SHAP-model)  
    4.1 [Run inference](#4.1-Run-inference)  
    4.2 [Fit SHAP](#4.2-Fit-SHAP)  
    4.3 [Get SHAP values](#4.3-Get-SHAP-values)  
    4.4 [Filter SHAP values](#4.4-Filter-SHAP-values)  
    4.5 [Rank SHAP values](#4.5-Rank-SHAP-values)  
    4.6 [Get unique genes](#4.6-Get-unique-genes)  
5. [Save out SHAP scores](#5.-Save-out-SHAP-scores) 

## 1. Import Modules

In [None]:
import os
# Make the project source directory the working directory. The relative
# paths used below ("../data/", "../models/") are resolved from here, and
# the project-local imports (constant, modelling, shap_utils) are presumably
# found because of this change — confirm against the kernel's sys.path.
util_path = '../src'
os.chdir(util_path)

In [None]:
import shap
import pickle
import numpy as np
import pandas as pd

from tqdm import tqdm
from keras import backend
from keras.models import model_from_json

from constant import map_dict, inv_map
from modelling.cnn import run_inference, prepare_x_y 
from shap_utils import filter_shap, get_rank_df

In [None]:
# Enumerate GPUs in PCI bus order and expose only device "0" to CUDA.
# NOTE(review): presumably this has to run before the first GPU use for the
# selection to take effect — confirm for the TensorFlow/Keras version in use.
os.environ.update(
    {
        "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
        "CUDA_VISIBLE_DEVICES": "0",
    }
)

In [None]:
# IPython magics: reload imported modules automatically before each cell is
# executed, so edits to the project code take effect without restarting the
# kernel.
%load_ext autoreload
%autoreload 2

## 2. Set static paths

In [None]:
# Tag of the training-set variant being explained. It is interpolated into
# the reference-data, model-topology, model-weights and output file names
# used further down.
data_type = "smote"
# Root of the data tree, relative to the working directory set above.
data_dir = "../data/"

In [None]:
# Locations derived from the data root; every value keeps its trailing slash
# because file names are appended to these strings directly.
input_dir = f"{data_dir}processed/"
model_dir = "../models/"
shap_dir = f"{data_dir}shap/"
gene_dir = f"{data_dir}gene_lists/"

## 3. Load files

#### 3.1 Load RNAseq

In [None]:
# Reference expression data for the selected training-set variant; it is
# later used to build the SHAP explainer.
# NOTE: unpickling executes arbitrary code — only load files you produced or
# otherwise trust.
ref_data = pd.read_pickle(f"{input_dir}gtex_filtered_tmm_intersect_{data_type}.pkl")

In [None]:
# Held-out test expression data whose predictions are explained below.
# NOTE: unpickling executes arbitrary code — only load files you produced or
# otherwise trust.
test_data = pd.read_pickle(f"{input_dir}gtex_filtered_tmm_intersect_test.pkl")

#### 3.2 Load model

In [None]:
# Rebuild the network architecture from its JSON topology file. Reading inside
# a context manager closes the file handle as soon as the JSON text has been
# consumed instead of leaving it open until garbage collection.
model_json_path = model_dir + f"{data_type}_model_topology.json"
with open(model_json_path, "r") as topology_file:
    model = model_from_json(topology_file.read())

# Load the trained weights into the freshly built model.
model_weights_path = model_dir + f"{data_type}_model_weights.hdf5"
model.load_weights(model_weights_path)

## 4. Run SHAP model

#### 4.1 Run inference

In [None]:
# Run the loaded model on the test data. The result is passed to filter_shap
# further down together with the SHAP array.
y_pred = run_inference(test_data, model)

In [None]:
# Build the model inputs for the reference and test sets. Only the first
# return value is kept; the second one is discarded.
# NOTE(review): "type" is presumably the label column of both frames —
# confirm against modelling.cnn.prepare_x_y.
X_ref, _ = prepare_x_y(ref_data, "type")
X_test, _ = prepare_x_y(test_data, "type")

#### 4.2 Fit SHAP

In [None]:
# Create the SHAP gradient explainer for the loaded model, using the
# reference inputs as the background data set.
explainer = shap.GradientExplainer(model, X_ref)

#### 4.3 Get SHAP values

In [None]:
# Explain the test set one sample at a time and collect the per-sample
# SHAP values.
out_list = []
num_samples = np.shape(X_test)[0]
# tqdm needs an iterable: wrapping the bare integer count raises
# "TypeError: 'int' object is not iterable", so iterate over range(...).
for sample in tqdm(range(num_samples)):
    # Slice (rather than index) so the explainer still receives a leading
    # axis of length one.
    shap_values = explainer.shap_values(X_test[sample : sample + 1])
    out_list.append(shap_values)
# Stack the per-sample results and drop every singleton axis.
# NOTE(review): np.squeeze also removes the sample axis when the test set
# holds a single sample — confirm filter_shap tolerates that shape.
shap_arr = np.squeeze(np.array(out_list))

#### 4.4 Filter SHAP values

In [None]:
# Combine the test data, the raw SHAP array and the model predictions into
# the filtered SHAP table that is ranked and saved below.
# NOTE(review): the exact filtering rule lives in shap_utils.filter_shap —
# check there for which samples/classes are kept.
shap_df = filter_shap(test_data, shap_arr, y_pred)

In [None]:
# Reset the Keras backend state; the model is not used again after this cell.
backend.clear_session()

#### 4.5 Rank SHAP values

In [None]:
# Derive the rank table from the filtered SHAP scores. The next cells walk
# its rows in order and treat the row values as genes.
rank_df = get_rank_df(shap_df)

#### 4.6 Get unique genes

In [None]:
# Walk the rank table row by row, accumulating every value seen so far, and
# stop at the first row where at most half of the accumulated entries are
# unique. `index` is left holding that row's label for the next cell.
gene_list = []
# Distinct values are tracked incrementally so each row costs time
# proportional to its own length, instead of re-running np.unique over the
# whole accumulated list on every iteration.
# NOTE(review): assumes the row values are hashable (e.g. gene-name strings).
seen_genes = set()
# iterrows() is a generator with no length; pass total so tqdm can show
# real progress.
for index, row in tqdm(rank_df.iterrows(), total=len(rank_df)):
    row_genes = list(row.values)
    gene_list.extend(row_genes)
    seen_genes.update(row_genes)
    val = len(seen_genes) / len(gene_list)
    if val <= 0.5:
        print(index)
        break
else:
    # The loop ran out of rows without reaching the threshold; say so rather
    # than silently leaving `index` at the last row (or unset when empty).
    print("unique-gene ratio never dropped to 0.5 or below")

In [None]:
# Collect the distinct values from the rank-table rows up to the stopping
# point found by the loop above (`index`).
# NOTE(review): with a default integer index this slice is positional and
# stops before row `index`, so the row that crossed the threshold is left
# out; with a label index the end point would be included — confirm which
# is intended.
shap_genes = np.unique(rank_df[:index].values.flatten())

## 5. Save out SHAP scores

In [None]:
output_file = shap_dir + f"shap_scores_{data_type}.pkl"

# Write inside a context manager so the file is flushed and closed
# deterministically instead of relying on garbage collection.
with open(output_file, "wb") as out_handle:
    pickle.dump(shap_df, out_handle)

In [None]:
output_file = shap_dir + f"{data_type}_ranks.pkl"

# Write inside a context manager so the file is flushed and closed
# deterministically instead of relying on garbage collection.
with open(output_file, "wb") as out_handle:
    pickle.dump(rank_df, out_handle)

In [None]:
# Plain literal: the file name has no placeholders, so no f-string is needed.
output_file = gene_dir + "shap_genes.pkl"

# Write inside a context manager so the file is flushed and closed
# deterministically instead of relying on garbage collection.
with open(output_file, "wb") as out_handle:
    pickle.dump(shap_genes, out_handle)