In [1]:
# Parameters
# Cell-type label used to locate the SHAP batches and the merged AnnData file below
# (likely overridden when the notebook is executed per cell type — confirm with the runner).
cell_type = 'T_CD8_Naive'


In [2]:
import sys
import os

import anndata as ad

import pandas as pd
import numpy as np

from pyprojroot import here

import re

from tqdm import tqdm

In [3]:
def load_sorted_shap_values_fnames(
    cell_type: str = '',
    run_name: str = 'run1',
    dirpath=None):
    """Return the paths of the SHAP-value batch files for one cell type, in batch order.

    A file is selected when its name starts with
    ``{run_name}_{cell_type}_shap_values_<N>``, where ``N`` is a batch number.
    Anything may follow ``N`` (e.g. a file extension). Files are sorted
    numerically by ``N``, so batch 10 comes after batch 2.

    Parameters
    ----------
    cell_type : str
        Cell-type label embedded in the file names. Matched literally.
    run_name : str
        Run prefix of the file names. Matched literally.
    dirpath : str or path-like, optional
        Directory to search. When omitted, the project directory
        ``inflammabucket_bkp/03_downstream_analysis/05_SHAP/results/04_shap/shap_vals``
        (resolved with ``here``) is used, as before.

    Returns
    -------
    list of str
        Full paths of the matching files, ordered by ascending batch number.
        Empty if nothing matches.
    """
    if dirpath is None:
        dirpath = here("inflammabucket_bkp/03_downstream_analysis/05_SHAP/results/04_shap/shap_vals")

    # The labels are literal text, not regex syntax: without re.escape a label such
    # as 'B+' would also match 'run1_BB_...' and miss 'run1_B+_...'.
    pattern = re.compile(
        '^' + re.escape(run_name) + '_' + re.escape(cell_type) + r'_shap_values_(\d+)')

    batches = []
    for fname in os.listdir(dirpath):
        match = pattern.search(fname)  # evaluate the regex once per file
        if match:
            batches.append((int(match.group(1)), fname))

    # Numeric sort on the batch index (stable, so ties keep listdir order).
    batches.sort(key=lambda item: item[0])

    return [os.path.join(dirpath, fname) for _, fname in batches]

In [4]:
def _load_shap_batch(fname):
    """Return the 'shap_values' array stored in the .npz archive at ``fname``.

    The archive is opened in a ``with`` block so its file handle is closed as soon
    as the array has been read; a bare ``np.load(...)['key']`` leaves the NpzFile open.
    """
    with np.load(fname) as npz:
        return npz['shap_values']


# Load every SHAP batch for this cell type, in ascending batch order.
shap_val_list = [_load_shap_batch(fname) for fname in tqdm(load_sorted_shap_values_fnames(cell_type))]


  0%|          | 0/22 [00:00<?, ?it/s]

  5%|▍         | 1/22 [00:02<00:48,  2.32s/it]

  9%|▉         | 2/22 [00:04<00:46,  2.34s/it]

 14%|█▎        | 3/22 [00:06<00:44,  2.32s/it]

 18%|█▊        | 4/22 [00:09<00:41,  2.33s/it]

 23%|██▎       | 5/22 [00:11<00:39,  2.32s/it]

 27%|██▋       | 6/22 [00:13<00:37,  2.33s/it]

 32%|███▏      | 7/22 [00:16<00:34,  2.32s/it]

 36%|███▋      | 8/22 [00:18<00:32,  2.33s/it]

 41%|████      | 9/22 [00:20<00:30,  2.34s/it]

 45%|████▌     | 10/22 [00:23<00:27,  2.32s/it]

 50%|█████     | 11/22 [00:25<00:25,  2.31s/it]

 55%|█████▍    | 12/22 [00:27<00:23,  2.31s/it]

 59%|█████▉    | 13/22 [00:30<00:20,  2.30s/it]

 64%|██████▎   | 14/22 [00:32<00:18,  2.29s/it]

 68%|██████▊   | 15/22 [00:34<00:15,  2.28s/it]

 73%|███████▎  | 16/22 [00:36<00:13,  2.28s/it]

 77%|███████▋  | 17/22 [00:39<00:11,  2.28s/it]

 82%|████████▏ | 18/22 [00:41<00:09,  2.28s/it]

 86%|████████▋ | 19/22 [00:43<00:06,  2.28s/it]

 91%|█████████ | 20/22 [00:46<00:04,  2.28s/it]

 95%|█████████▌| 21/22 [00:48<00:02,  2.28s/it]

100%|██████████| 22/22 [00:49<00:00,  1.88s/it]

100%|██████████| 22/22 [00:49<00:00,  2.24s/it]




In [5]:
# Open the merged AnnData for this cell type read-only in backed mode
# (backed='r': the data matrix is not loaded into memory up front).
adata = ad.read_h5ad(
    here(f'inflammabucket_bkp/03_downstream_analysis/05_SHAP/data/{cell_type}_adataMerged_SPECTRAgenes.log1p.h5ad'),
    backed='r',
)

In [6]:
## Gene symbols and Disease labels
# Look up the 'symbol' column for every adata.var_name, keeping adata.var order,
# so the gene axis of the SHAP arrays can be labelled with symbols.
# NOTE: read_pickle can execute arbitrary code — only load trusted project files.
symbols_df = pd.read_pickle(here('inflammabucket_bkp/03_downstream_analysis/04_selected_gene_list.pkl'))
symbols_sorted = symbols_df.loc[adata.var_names, 'symbol'].values

# Disease labels, used as the column names of the SHAP disease axis.
# NOTE(review): assumed to be in the same order as the model's output classes — confirm.
DISEASES = ['BRCA', 'CD', 'COPD', 'COVID', 'CRC', 'HBV', 'HIV', 'HNSCC', 'MS', 'NPC', 'PS', 'PSA', 'RA', 'SLE', 'UC', 'asthma', 'cirrhosis', 'flu', 'healthy', 'sepsis']

# Accumulator: one (initially empty) list per disease label.
diseaseDict = {d: [] for d in DISEASES}

In [7]:
# Stack the per-batch SHAP arrays along the first axis; rows follow the sorted batch order.
# NOTE(review): rows are assumed to line up one-to-one with adata.obs — the
# per-sample averaging indexes them positionally.
shape_values_matrix = np.concatenate(shap_val_list, axis=0)

In [8]:
# Average the per-cell SHAP values within each sample, then split the resulting
# (gene symbol x disease) table into one single-column frame per disease.
#
# groupby(...).indices gives *positional* row indices into adata.obs, which are used
# directly on shape_values_matrix — so both must describe the same cells in the same order.
assert shape_values_matrix.shape[0] == adata.n_obs, (
    f"{shape_values_matrix.shape[0]} SHAP rows vs {adata.n_obs} cells in adata"
)

# Re-initialise the accumulator here so re-running this cell does not append
# every sample a second time (which would duplicate columns in the exported CSVs).
diseaseDict = {d: [] for d in DISEASES}

for sample_id, cell_positions in tqdm(adata.obs.groupby('sampleID', observed=True).indices.items()):
    geneXdisease_sample_i = pd.DataFrame(
        shape_values_matrix[cell_positions].mean(0),
        index=symbols_sorted,
        columns=DISEASES,
    )
    for d in geneXdisease_sample_i.columns:
        diseaseDict[d].append(pd.DataFrame.from_dict({sample_id: geneXdisease_sample_i[d]}))

  0%|          | 0/805 [00:00<?, ?it/s]

  2%|▏         | 13/805 [00:00<00:06, 125.39it/s]

  4%|▍         | 35/805 [00:00<00:04, 177.28it/s]

  8%|▊         | 64/805 [00:00<00:03, 225.95it/s]

 11%|█         | 89/805 [00:00<00:03, 234.74it/s]

 16%|█▌        | 128/805 [00:00<00:02, 285.96it/s]

 20%|█▉        | 157/805 [00:00<00:02, 220.30it/s]

 23%|██▎       | 182/805 [00:00<00:03, 184.07it/s]

 25%|██▌       | 203/805 [00:00<00:03, 187.42it/s]

 28%|██▊       | 224/805 [00:01<00:03, 181.47it/s]

 31%|███       | 248/805 [00:01<00:02, 196.03it/s]

 33%|███▎      | 269/805 [00:01<00:02, 192.07it/s]

 36%|███▋      | 293/805 [00:01<00:02, 199.38it/s]

 39%|███▉      | 314/805 [00:01<00:03, 144.54it/s]

 41%|████      | 331/805 [00:01<00:03, 146.48it/s]

 43%|████▎     | 349/805 [00:01<00:02, 153.03it/s]

 47%|████▋     | 376/805 [00:02<00:02, 181.77it/s]

 51%|█████     | 409/805 [00:02<00:01, 220.09it/s]

 54%|█████▍    | 433/805 [00:02<00:02, 184.59it/s]

 56%|█████▋    | 454/805 [00:02<00:02, 171.63it/s]

 59%|█████▉    | 473/805 [00:02<00:02, 160.99it/s]

 62%|██████▏   | 498/805 [00:02<00:01, 178.48it/s]

 65%|██████▍   | 520/805 [00:02<00:01, 187.36it/s]

 67%|██████▋   | 540/805 [00:02<00:01, 178.40it/s]

 69%|██████▉   | 559/805 [00:03<00:01, 172.58it/s]

 72%|███████▏  | 580/805 [00:03<00:01, 181.96it/s]

 75%|███████▍  | 601/805 [00:03<00:01, 188.98it/s]

 77%|███████▋  | 621/805 [00:03<00:01, 181.08it/s]

 80%|████████  | 645/805 [00:03<00:00, 195.67it/s]

 83%|████████▎ | 665/805 [00:03<00:01, 111.40it/s]

 85%|████████▍ | 681/805 [00:04<00:01, 107.12it/s]

 86%|████████▋ | 695/805 [00:04<00:01, 105.76it/s]

 88%|████████▊ | 709/805 [00:04<00:00, 108.14it/s]

 90%|█████████ | 726/805 [00:04<00:00, 120.03it/s]

 93%|█████████▎| 752/805 [00:04<00:00, 149.87it/s]

 96%|█████████▌| 769/805 [00:04<00:00, 136.91it/s]

 98%|█████████▊| 785/805 [00:04<00:00, 141.61it/s]

100%|█████████▉| 801/805 [00:04<00:00, 142.78it/s]

100%|██████████| 805/805 [00:04<00:00, 165.45it/s]




In [9]:
# Write one CSV per disease: rows are gene symbols, columns are samples
# (each column holds that sample's mean SHAP value for the disease).
# NOTE(review): outputs go under the project root, not inflammabucket_bkp/ like the inputs — confirm intended.
shap_avg_out_dir = here('03_downstream_analysis/08_gene_importance/new_shap_plots/results/SHAP_AVGsamples')
# to_csv does not create parent directories; ensure the target exists on a fresh checkout.
os.makedirs(shap_avg_out_dir, exist_ok=True)

for d in tqdm(DISEASES):
    pd.concat(diseaseDict[d], axis=1).to_csv(
        os.path.join(shap_avg_out_dir, f'SHAP_AVGsample_{cell_type}_{d}.csv')
    )

  0%|          | 0/20 [00:00<?, ?it/s]

  5%|▌         | 1/20 [00:00<00:10,  1.81it/s]

 10%|█         | 2/20 [00:00<00:08,  2.04it/s]

 15%|█▌        | 3/20 [00:01<00:08,  2.01it/s]

 20%|██        | 4/20 [00:02<00:08,  1.93it/s]

 25%|██▌       | 5/20 [00:02<00:07,  1.90it/s]

 30%|███       | 6/20 [00:03<00:06,  2.04it/s]

 35%|███▌      | 7/20 [00:03<00:06,  2.12it/s]

 40%|████      | 8/20 [00:03<00:05,  2.02it/s]

 45%|████▌     | 9/20 [00:04<00:05,  2.16it/s]

 50%|█████     | 10/20 [00:04<00:04,  2.26it/s]

 55%|█████▌    | 11/20 [00:05<00:04,  2.07it/s]

 60%|██████    | 12/20 [00:05<00:03,  2.15it/s]

 65%|██████▌   | 13/20 [00:06<00:03,  2.18it/s]

 70%|███████   | 14/20 [00:06<00:03,  1.85it/s]

 75%|███████▌  | 15/20 [00:07<00:02,  1.79it/s]

 80%|████████  | 16/20 [00:07<00:02,  1.92it/s]

 85%|████████▌ | 17/20 [00:08<00:01,  2.07it/s]

 90%|█████████ | 18/20 [00:08<00:01,  1.94it/s]

 95%|█████████▌| 19/20 [00:09<00:00,  1.68it/s]

100%|██████████| 20/20 [00:10<00:00,  1.77it/s]

100%|██████████| 20/20 [00:10<00:00,  1.95it/s]


