This is a temporary notebook for SHAP value analysis and plots.

In [71]:
import os
import sys

# Pick the PhenPred project root: use the first path when it exists on this machine,
# otherwise fall back to the second one (taken as-is, without an existence check).
proj_dir = "/home/scai/PhenPred"
if not os.path.exists(proj_dir):
    proj_dir = "/Users/emanuel/Projects/PhenPred"

# Make the project importable. Guarded so that re-running this cell does not keep
# appending duplicate entries to sys.path.
if proj_dir not in sys.path:
    sys.path.append(proj_dir)

import json
import PhenPred
import argparse
import pandas as pd
from PhenPred.vae import plot_folder
from PhenPred.vae.Hypers import Hypers
from PhenPred.vae.Train import CLinesTrain
from PhenPred.vae.DatasetDepMap23Q2 import CLinesDatasetDepMap23Q2
# Allow up to 100 rows and 100 columns when a DataFrame is rendered.
pd.options.display.max_rows = 100
pd.options.display.max_columns = 100

In [34]:
import shap
import pickle
from tqdm.notebook import tqdm

In [3]:
# Read the hyperparameters and build the DepMap 23Q2 dataset from them.
hyperparameters = Hypers.read_hyperparameters()

clines_db = CLinesDatasetDepMap23Q2(
    # The only argument whose keyword differs from its hyperparameter key.
    labels_names=hyperparameters["labels"],
    # Remaining arguments share their name with the hyperparameter key.
    **{
        option: hyperparameters[option]
        for option in (
            "datasets",
            "feature_miss_rate_thres",
            "standardize",
            "filter_features",
            "filtered_encoder_only",
        )
    },
)

All-NaN slice encountered
All-NaN slice encountered


DepMap23Q2 | Samples = 1,523 | Proteomics = 4,922 (0 masked) | Metabolomics = 225 (0 masked) | Drug response = 810 (0 masked) | CRISPR-Cas9 = 17,931 (12,718 masked) | Methylation = 14,608 (7,018 masked) | Transcriptomics = 15,278 (7,200 masked) | Copy number = 777 (0 masked) | Labels = 237


In [17]:
# Names of the CRISPR-Cas9 features whose mask entry is True.
clines_db.features_mask["crisprcas9"].loc[lambda mask: mask == True].index.values

array(['AAAS', 'AACS', 'AAMP', ..., 'ZWILCH', 'ZWINT', 'ZZZ3'],
      dtype=object)

In [33]:
# Load the SHAP values saved by an earlier run. The context manager closes the file
# handle deterministically instead of leaving it to garbage collection.
# NOTE(review): unpickling can execute arbitrary code; only load files produced by
# this project.
with open("./reports/vae/files/20230717_160108_shap_values.pkl", "rb") as shap_file:
    shap_values = pickle.load(shap_file)

In [74]:
# Show the dataset's view names; the cell that builds all_shap_df indexes this list
# by position to label each block of SHAP values.
clines_db.view_names

['proteomics',
 'metabolomics',
 'drugresponse',
 'crisprcas9',
 'methylation',
 'transcriptomics',
 'copynumber']

In [None]:
# Assemble one wide DataFrame of SHAP values: rows are samples (repeated once per
# latent dimension), columns are "<view>_<feature>" plus a leading `latent_dim` label.

# Feature names and their view-prefixed labels depend only on the dataset, so they
# are computed once up front rather than once per latent dimension.
view_feature_names = {}
view_column_labels = {}
for view_name in clines_db.view_names:
    view_mask = clines_db.features_mask[view_name]
    view_feature_names[view_name] = view_mask[view_mask == True].index.values
    view_column_labels[view_name] = [
        f"{view_name}_{feature}" for feature in view_feature_names[view_name]
    ]

per_latent_frames = []
for latent_dim in range(len(shap_values)):
    shap_latent = shap_values[latent_dim]

    view_frames = []
    for view_idx in range(len(shap_latent)):
        # Entry `view_idx` of shap_latent is labelled with the view at the same position.
        view_name = clines_db.view_names[view_idx]
        view_frame = pd.DataFrame(
            shap_latent[view_idx],
            columns=view_feature_names[view_name],
            index=clines_db.samples,
        )
        view_frame.columns = view_column_labels[view_name]
        view_frames.append(view_frame)

    latent_frame = pd.concat(view_frames, axis=1)
    # Put the latent-dimension label first so it leads the final column order.
    latent_frame.insert(0, "latent_dim", f"latent_dim_{latent_dim}")
    per_latent_frames.append(latent_frame)

all_shap_df = pd.concat(per_latent_frames, axis=0)
all_shap_df.index.name = "model_id"

In [42]:
# Persist the assembled SHAP DataFrame. The context manager guarantees the file is
# flushed and closed as soon as the dump finishes.
with open("./reports/vae/files/20230717_160108_shap_values_df.pkl", "wb") as shap_df_file:
    pickle.dump(all_shap_df, shap_df_file)

In [43]:
# Turn the `model_id` index into a regular column. Guarded so that re-running this
# cell does not add a spurious `index` column and shift the column layout.
if "model_id" not in all_shap_df.columns:
    all_shap_df = all_shap_df.reset_index()

In [49]:
# Preview the first rows of the combined SHAP frame.
all_shap_df.head()

Unnamed: 0,model_id,latent_dim,proteomics_AAAS,proteomics_AACS,proteomics_AAGAB,proteomics_AAK1,proteomics_AAMDC,proteomics_AAMP,proteomics_AARS1,proteomics_AARS2,...,copynumber_ZNF521,copynumber_ZNF626,copynumber_ZNF680,copynumber_ZNF721,copynumber_ZNF780A,copynumber_ZNF814,copynumber_ZNF93,copynumber_ZNRF3,copynumber_ZRSR2,copynumber_ZXDB
0,SIDM00979,latent_dim_0,0.000114,1.8e-05,3.5e-05,1e-06,7e-06,5.732462e-07,0.000211,0.00031,...,8.9e-05,5.4e-05,0.000174,0.0,9e-06,0.0,6.8e-05,0.000335,7e-06,1.20587e-06
1,SIDM01548,latent_dim_0,2.5e-05,6e-06,1.5e-05,4e-06,9e-06,2.642014e-05,4.9e-05,2.4e-05,...,4.4e-05,0.000133,7.5e-05,0.0,0.000573,0.0,0.000198,5e-06,7.5e-05,1.763348e-06
2,SIDM01461,latent_dim_0,2.4e-05,6e-06,4.6e-05,3e-06,6e-06,2.589496e-05,8e-06,7.6e-05,...,1.3e-05,5.7e-05,3.4e-05,0.0,2e-05,0.0,4.4e-05,7.7e-05,3.9e-05,4.379459e-06
3,SIDM01762,latent_dim_0,5e-06,0.00017,0.000166,5.1e-05,3e-05,0.0003533704,9.6e-05,4e-05,...,9e-06,2e-06,0.000387,0.0,5.2e-05,0.0,2e-06,0.000293,6.5e-05,1.414212e-06
4,SIDM01535,latent_dim_0,4e-06,1e-05,0.000156,1.2e-05,5.3e-05,2.288779e-05,1.2e-05,7.9e-05,...,1.6e-05,0.000332,0.000103,0.0,8.7e-05,0.0,0.000419,0.000215,3.8e-05,4.189702e-07


In [48]:
# Replace every SHAP value with its absolute value. The numeric columns are selected
# by dtype rather than by position, so the result does not depend on whether (or how
# many times) the index was reset; the string `model_id` / `latent_dim` columns are
# left untouched.
shap_value_columns = all_shap_df.select_dtypes(include="number").columns
all_shap_df[shap_value_columns] = all_shap_df[shap_value_columns].abs()

In [50]:
# Collapse the latent dimensions: drop the label column, then sum every remaining
# column within each model_id.
shap_latent_sum_df = (
    all_shap_df
    .drop("latent_dim", axis="columns")
    .groupby("model_id")
    .sum()
)

In [54]:
# Write the per-model summed SHAP table as a gzip-compressed CSV.
shap_latent_sum_df.to_csv(
    "./reports/vae/files/20230717_160108_shap_values_df_sum.csv.gz",
    compression="gzip",
)

In [62]:
# Display the summed SHAP table (one row per model_id, one column per feature).
shap_latent_sum_df

Unnamed: 0_level_0,proteomics_AAAS,proteomics_AACS,proteomics_AAGAB,proteomics_AAK1,proteomics_AAMDC,proteomics_AAMP,proteomics_AARS1,proteomics_AARS2,proteomics_AARSD1,proteomics_AASDHPPT,...,copynumber_ZNF521,copynumber_ZNF626,copynumber_ZNF680,copynumber_ZNF721,copynumber_ZNF780A,copynumber_ZNF814,copynumber_ZNF93,copynumber_ZNRF3,copynumber_ZRSR2,copynumber_ZXDB
model_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
SIDM00001,0.003697,0.001406,0.003262,0.005072,0.000789,0.002053,0.043771,0.032513,0.002895,0.004275,...,0.010261,0.001493,0.052997,0.0,0.007929,0.0,0.001717,0.025623,0.006783,0.000238
SIDM00003,0.034662,0.022374,0.002186,0.042153,0.000288,0.003861,0.025875,0.014182,0.035586,0.013104,...,0.015710,0.006088,0.009296,0.0,0.001322,0.0,0.006920,0.034975,0.058536,0.003467
SIDM00005,0.001372,0.001750,0.002804,0.002103,0.002325,0.001615,0.002041,0.006507,0.004368,0.002777,...,0.000424,0.002268,0.021009,0.0,0.002722,0.0,0.002500,0.002279,0.004432,0.000075
SIDM00006,0.020894,0.000914,0.002157,0.003778,0.004461,0.048202,0.018414,0.006017,0.007669,0.033665,...,0.001491,0.003719,0.048061,0.0,0.003779,0.0,0.004092,0.004947,0.008841,0.000225
SIDM00007,0.023611,0.000401,0.000548,0.001215,0.002196,0.102766,0.036635,0.002449,0.002661,0.024112,...,0.009166,0.009298,0.013176,0.0,0.004039,0.0,0.010539,0.025728,0.052882,0.000036
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
SIDM01979,0.039984,0.002751,0.003316,0.001383,0.000562,0.001402,0.037768,0.030038,0.001018,0.022441,...,0.000434,0.037422,0.021612,0.0,0.069393,0.0,0.042246,0.002689,0.006359,0.003348
SIDM01980,0.001849,0.001368,0.001304,0.002682,0.001356,0.002207,0.003816,0.001792,0.002646,0.003522,...,0.016046,0.006251,0.019593,0.0,0.009018,0.0,0.008165,0.042779,0.006722,0.000124
SIDM01981,0.044551,0.001731,0.000440,0.022347,0.001248,0.032245,0.014040,0.031690,0.008052,0.023979,...,0.001110,0.008014,0.004350,0.0,0.007467,0.0,0.007093,0.004312,0.003191,0.000160
SIDM01983,0.004391,0.004902,0.000934,0.002534,0.003574,0.001807,0.003184,0.008151,0.001771,0.001353,...,0.001824,0.004750,0.011958,0.0,0.006984,0.0,0.005253,0.004120,0.007217,0.000422


In [66]:
# Rank features by their mean summed SHAP value across models, highest first, and
# return a two-column table: `feature`, `importance`.
mean_importance = shap_latent_sum_df.mean().sort_values(ascending=False)
global_feature_importance_df = mean_importance.reset_index(name="importance").rename(
    columns={"index": "feature"}
)

In [68]:
# Save the global feature ranking as CSV, without the positional index.
global_feature_importance_df.to_csv(
    "./reports/vae/files/20230717_160108_shap_values_df_sum_global.csv",
    index=False,
)

In [81]:
# Show at most the first 100 ranked rows whose feature name contains "drugresponse".
global_feature_importance_df.loc[
    lambda ranking: ranking["feature"].str.contains("drugresponse")
].head(100)

Unnamed: 0,feature,importance
95,drugresponse_1372;Trametinib;GDSC2,0.136074
99,drugresponse_1015;CI-1040;GDSC1,0.132009
105,drugresponse_1564;SCH772984;GDSC2,0.129097
106,drugresponse_283;Omipalisib;GDSC1,0.12863
107,drugresponse_1060;PD0325901;GDSC2,0.126859
109,drugresponse_235;QL-XII-47;GDSC1,0.126645
123,drugresponse_1372;Trametinib;GDSC1,0.120806
133,drugresponse_1494;SN-38;GDSC1,0.116741
135,drugresponse_1558;Lapatinib;GDSC2,0.115164
144,drugresponse_282;Pelitinib;GDSC1,0.111958


In [73]:
# Show at most the first 100 ranked rows whose feature name contains "mutations".
# NOTE(review): the recorded output of this cell is empty, and "mutations" is not among
# the view names displayed earlier in this notebook — confirm whether this is expected.
global_feature_importance_df.loc[
    lambda ranking: ranking["feature"].str.contains("mutations")
].head(100)

Unnamed: 0,feature,importance
