In [1]:
# Parameters
# cell_type: label of the cell-type subset analysed by this notebook. It selects
# the SHAP batch files and the AnnData file that are loaded, and is embedded in
# the output CSV file names. (Presumably overridden when the notebook is run
# with injected parameters, e.g. via papermill — confirm.)
cell_type: str = "T_CD8_NonNaive"


In [2]:
import sys
import os

import anndata as ad

import pandas as pd
import numpy as np

from pyprojroot import here

import re

from tqdm import tqdm

In [3]:
def load_sorted_shap_values_fnames(
    cell_type: str = '',
    run_name: str = 'run1',
    dirpath=None):
    """Return the SHAP batch files for one cell type / run, ordered by batch index.

    A file is selected when its name starts with
    ``<run_name>_<cell_type>_shap_values_<batch>`` where ``<batch>`` is an
    integer (anything may follow it, e.g. a file extension). Files are sorted
    numerically on ``<batch>``, so ``..._2`` comes before ``..._10``; files with
    equal batch numbers keep their ``os.listdir`` order.

    Parameters
    ----------
    cell_type : str
        Cell-type label embedded in the file names; matched literally.
    run_name : str
        Run prefix embedded in the file names; matched literally.
    dirpath : str or path-like, optional
        Directory to scan. Defaults to the project directory
        ``inflammabucket_bkp/03_downstream_analysis/05_SHAP/results/04_shap/shap_vals``.

    Returns
    -------
    list of str
        ``os.path.join(dirpath, fname)`` for each matching file, in batch order.
        Empty if nothing matches.
    """
    if dirpath is None:
        dirpath = here("inflammabucket_bkp/03_downstream_analysis/05_SHAP/results/04_shap/shap_vals")

    # Escape the caller-supplied parts so labels containing regex metacharacters
    # ('.', '+', '(' ...) are matched literally instead of as patterns.
    pattern = re.compile(
        rf'^{re.escape(run_name)}_{re.escape(cell_type)}_shap_values_(\d+)')

    results_batches = []
    for fname in os.listdir(dirpath):
        match = pattern.search(fname)  # one regex search per file
        if match:
            results_batches.append((fname, int(match.group(1))))

    sorted_files = sorted(results_batches, key=lambda x: x[1])

    return [os.path.join(dirpath, filename) for filename, _ in sorted_files]

In [4]:
def _load_shap_batch(fname):
    """Read the ``'shap_values'`` array from one ``.npz`` SHAP batch file.

    ``np.load`` on an ``.npz`` archive returns an ``NpzFile`` that keeps the
    file open; using it as a context manager closes the handle as soon as the
    array has been read into memory, instead of leaking one handle per batch.
    """
    with np.load(fname) as npz:
        return npz['shap_values']


shap_val_list = [
    _load_shap_batch(fname)
    for fname in tqdm(load_sorted_shap_values_fnames(cell_type))
]
# Fail here with a readable message rather than with np.concatenate's
# "need at least one array" error when no batch files were found.
assert shap_val_list, f"no SHAP batch files found for cell_type={cell_type!r}"


  0%|          | 0/53 [00:00<?, ?it/s]

  2%|▏         | 1/53 [00:03<03:01,  3.49s/it]

  4%|▍         | 2/53 [00:06<02:58,  3.49s/it]

  6%|▌         | 3/53 [00:10<02:52,  3.46s/it]

  8%|▊         | 4/53 [00:13<02:50,  3.47s/it]

  9%|▉         | 5/53 [00:17<02:45,  3.45s/it]

 11%|█▏        | 6/53 [00:20<02:43,  3.47s/it]

 13%|█▎        | 7/53 [00:24<02:39,  3.46s/it]

 15%|█▌        | 8/53 [00:27<02:35,  3.45s/it]

 17%|█▋        | 9/53 [00:31<02:33,  3.49s/it]

 19%|█▉        | 10/53 [00:34<02:30,  3.51s/it]

 21%|██        | 11/53 [00:38<02:26,  3.49s/it]

 23%|██▎       | 12/53 [00:41<02:22,  3.47s/it]

 25%|██▍       | 13/53 [00:45<02:18,  3.46s/it]

 26%|██▋       | 14/53 [00:48<02:15,  3.49s/it]

 28%|██▊       | 15/53 [00:52<02:12,  3.49s/it]

 30%|███       | 16/53 [00:55<02:09,  3.50s/it]

 32%|███▏      | 17/53 [00:59<02:05,  3.50s/it]

 34%|███▍      | 18/53 [01:02<02:03,  3.52s/it]

 36%|███▌      | 19/53 [01:06<01:59,  3.52s/it]

 38%|███▊      | 20/53 [01:09<01:55,  3.51s/it]

 40%|███▉      | 21/53 [01:13<01:51,  3.48s/it]

 42%|████▏     | 22/53 [01:16<01:47,  3.46s/it]

 43%|████▎     | 23/53 [01:20<01:44,  3.47s/it]

 45%|████▌     | 24/53 [01:23<01:41,  3.49s/it]

 47%|████▋     | 25/53 [01:27<01:42,  3.65s/it]

 49%|████▉     | 26/53 [01:31<01:37,  3.60s/it]

 51%|█████     | 27/53 [01:34<01:32,  3.56s/it]

 53%|█████▎    | 28/53 [01:38<01:29,  3.57s/it]

 55%|█████▍    | 29/53 [01:41<01:24,  3.54s/it]

 57%|█████▋    | 30/53 [01:45<01:21,  3.54s/it]

 58%|█████▊    | 31/53 [01:48<01:17,  3.54s/it]

 60%|██████    | 32/53 [01:52<01:14,  3.54s/it]

 62%|██████▏   | 33/53 [01:55<01:10,  3.52s/it]

 64%|██████▍   | 34/53 [01:59<01:06,  3.52s/it]

 66%|██████▌   | 35/53 [02:02<01:03,  3.50s/it]

 68%|██████▊   | 36/53 [02:06<00:59,  3.48s/it]

 70%|██████▉   | 37/53 [02:09<00:55,  3.48s/it]

 72%|███████▏  | 38/53 [02:13<00:52,  3.47s/it]

 74%|███████▎  | 39/53 [02:16<00:48,  3.47s/it]

 75%|███████▌  | 40/53 [02:20<00:45,  3.48s/it]

 77%|███████▋  | 41/53 [02:23<00:41,  3.48s/it]

 79%|███████▉  | 42/53 [02:26<00:38,  3.46s/it]

 81%|████████  | 43/53 [02:30<00:34,  3.44s/it]

 83%|████████▎ | 44/53 [02:33<00:31,  3.45s/it]

 85%|████████▍ | 45/53 [02:37<00:27,  3.46s/it]

 87%|████████▋ | 46/53 [02:40<00:24,  3.46s/it]

 89%|████████▊ | 47/53 [02:44<00:20,  3.44s/it]

 91%|█████████ | 48/53 [02:47<00:17,  3.45s/it]

 92%|█████████▏| 49/53 [02:51<00:14,  3.56s/it]

 94%|█████████▍| 50/53 [02:55<00:11,  3.69s/it]

 96%|█████████▌| 51/53 [02:59<00:07,  3.79s/it]

 98%|█████████▊| 52/53 [03:04<00:04,  4.12s/it]

100%|██████████| 53/53 [03:06<00:00,  3.43s/it]

100%|██████████| 53/53 [03:06<00:00,  3.51s/it]




In [5]:
# Open the merged AnnData for this cell type in read-only backed mode, so the
# data matrix is not loaded into memory; this notebook only reads .obs
# (sampleID grouping) and .var_names (gene order) from it.
adata = ad.read_h5ad(
    filename=here(f'inflammabucket_bkp/03_downstream_analysis/05_SHAP/data/{cell_type}_adataMerged_SPECTRAgenes.log1p.h5ad'),
    backed='r',
)

In [6]:
## Gene symbols and Disease labels
# Table indexed by the AnnData feature IDs with (presumably) one 'symbol' column.
# NOTE(review): read_pickle can execute arbitrary code — only load trusted files.
symbols_df = pd.read_pickle(
    here('inflammabucket_bkp/03_downstream_analysis/04_selected_gene_list.pkl')
)
# Gene symbols reordered to follow adata.var_names, so that entry i labels the
# i-th gene of the SHAP gene axis.
symbols_sorted = symbols_df.loc[adata.var_names, 'symbol'].values

# Disease labels used as column names of each per-sample (gene x disease) SHAP
# table. NOTE(review): assumed to match the order of the model's output classes.
DISEASES = [
    'BRCA', 'CD', 'COPD', 'COVID', 'CRC', 'HBV', 'HIV', 'HNSCC', 'MS', 'NPC',
    'PS', 'PSA', 'RA', 'SLE', 'UC', 'asthma', 'cirrhosis', 'flu', 'healthy',
    'sepsis',
]

# One independent, initially empty list per disease; each will collect one
# per-sample column of mean SHAP values.
diseaseDict = {disease: [] for disease in DISEASES}

In [7]:
# Stack the per-batch SHAP arrays along the first (cell) axis into one array.
# (Name kept as-is because the per-sample loop refers to `shape_values_matrix`.)
shape_values_matrix = np.concatenate(shap_val_list)

# The per-sample loop indexes these rows positionally with
# adata.obs.groupby(...).indices, so there must be exactly one row per cell in
# adata. Only the count can be checked here; a mismatch would otherwise raise
# an IndexError or silently attribute SHAP rows to the wrong samples.
assert shape_values_matrix.shape[0] == adata.n_obs, (
    f"SHAP rows ({shape_values_matrix.shape[0]}) != cells in adata ({adata.n_obs})"
)

In [8]:
# Average SHAP values over the cells of each sample and file the result per
# disease: diseaseDict[d] collects one gene-indexed column per sampleID.
# Re-initialised here so that re-running this cell does not append a second
# copy of every sample column to lists filled by a previous run.
diseaseDict = {d: [] for d in DISEASES}

for idx, values in tqdm(adata.obs.groupby('sampleID', observed=True).indices.items()):
    # `values` are the positional row indices of this sample's cells; the mean
    # over them is a (genes x diseases) table labelled by symbol and disease.
    geneXdisease_sample_i = pd.DataFrame(
        shape_values_matrix[values].mean(0),
        index=symbols_sorted,
        columns=DISEASES,
    )
    for d in DISEASES:
        # A Series named after the sample becomes that sample's column when the
        # per-disease list is later combined with pd.concat(..., axis=1).
        diseaseDict[d].append(geneXdisease_sample_i[d].rename(idx))

  0%|          | 0/817 [00:00<?, ?it/s]

  1%|          | 7/817 [00:00<00:12, 63.76it/s]

  2%|▏         | 14/817 [00:00<00:12, 66.52it/s]

  3%|▎         | 21/817 [00:00<00:14, 55.30it/s]

  4%|▍         | 31/817 [00:00<00:11, 67.92it/s]

  5%|▌         | 44/817 [00:00<00:09, 84.14it/s]

  7%|▋         | 55/817 [00:00<00:08, 88.26it/s]

  8%|▊         | 68/817 [00:00<00:07, 98.12it/s]

 10%|▉         | 81/817 [00:00<00:07, 103.93it/s]

 11%|█▏        | 92/817 [00:01<00:07, 102.79it/s]

 14%|█▍        | 115/817 [00:01<00:05, 138.43it/s]

 16%|█▌        | 130/817 [00:01<00:05, 120.04it/s]

 18%|█▊        | 143/817 [00:01<00:09, 73.20it/s] 

 19%|█▉        | 155/817 [00:01<00:08, 81.45it/s]

 20%|██        | 167/817 [00:01<00:08, 80.97it/s]

 22%|██▏       | 177/817 [00:02<00:08, 78.84it/s]

 23%|██▎       | 189/817 [00:02<00:07, 86.62it/s]

 25%|██▍       | 202/817 [00:02<00:06, 94.97it/s]

 26%|██▋       | 215/817 [00:02<00:05, 103.43it/s]

 28%|██▊       | 229/817 [00:02<00:05, 112.04it/s]

 29%|██▉       | 241/817 [00:02<00:05, 113.62it/s]

 31%|███       | 254/817 [00:02<00:04, 117.35it/s]

 33%|███▎      | 268/817 [00:02<00:04, 122.24it/s]

 35%|███▌      | 288/817 [00:02<00:03, 142.74it/s]

 37%|███▋      | 303/817 [00:02<00:03, 142.72it/s]

 39%|███▉      | 318/817 [00:03<00:03, 144.25it/s]

 41%|████      | 337/817 [00:03<00:03, 144.87it/s]

 43%|████▎     | 352/817 [00:03<00:04, 97.52it/s] 

 45%|████▍     | 364/817 [00:03<00:04, 94.20it/s]

 46%|████▌     | 375/817 [00:03<00:05, 85.26it/s]

 47%|████▋     | 385/817 [00:03<00:05, 76.75it/s]

 49%|████▉     | 400/817 [00:04<00:04, 91.57it/s]

 50%|█████     | 411/817 [00:04<00:04, 91.69it/s]

 53%|█████▎    | 430/817 [00:04<00:03, 114.66it/s]

 54%|█████▍    | 443/817 [00:04<00:03, 117.75it/s]

 56%|█████▌    | 456/817 [00:04<00:03, 117.38it/s]

 58%|█████▊    | 471/817 [00:04<00:02, 125.73it/s]

 60%|█████▉    | 490/817 [00:04<00:02, 140.02it/s]

 62%|██████▏   | 505/817 [00:04<00:02, 136.53it/s]

 64%|██████▍   | 524/817 [00:04<00:01, 149.23it/s]

 66%|██████▌   | 540/817 [00:05<00:01, 140.87it/s]

 68%|██████▊   | 555/817 [00:05<00:01, 142.75it/s]

 70%|██████▉   | 571/817 [00:05<00:01, 146.78it/s]

 72%|███████▏  | 586/817 [00:05<00:01, 135.89it/s]

 73%|███████▎  | 600/817 [00:05<00:01, 110.77it/s]

 75%|███████▍  | 612/817 [00:05<00:01, 110.00it/s]

 76%|███████▋  | 624/817 [00:05<00:01, 108.58it/s]

 78%|███████▊  | 636/817 [00:06<00:02, 89.22it/s] 

 79%|███████▉  | 648/817 [00:06<00:01, 93.35it/s]

 81%|████████  | 658/817 [00:06<00:01, 79.53it/s]

 82%|████████▏ | 669/817 [00:06<00:01, 86.15it/s]

 83%|████████▎ | 679/817 [00:06<00:01, 79.87it/s]

 84%|████████▍ | 689/817 [00:06<00:01, 83.83it/s]

 86%|████████▌ | 700/817 [00:06<00:01, 86.29it/s]

 87%|████████▋ | 709/817 [00:06<00:01, 87.07it/s]

 88%|████████▊ | 718/817 [00:07<00:01, 79.52it/s]

 89%|████████▉ | 727/817 [00:07<00:01, 70.07it/s]

 90%|████████▉ | 735/817 [00:07<00:01, 68.66it/s]

 91%|█████████ | 743/817 [00:07<00:01, 70.07it/s]

 92%|█████████▏| 752/817 [00:07<00:00, 72.73it/s]

 93%|█████████▎| 760/817 [00:07<00:00, 69.63it/s]

 94%|█████████▍| 768/817 [00:07<00:00, 72.23it/s]

 95%|█████████▍| 776/817 [00:07<00:00, 73.67it/s]

 96%|█████████▌| 784/817 [00:07<00:00, 74.09it/s]

 97%|█████████▋| 792/817 [00:08<00:00, 71.45it/s]

 98%|█████████▊| 800/817 [00:08<00:00, 66.16it/s]

 99%|█████████▉| 807/817 [00:08<00:00, 60.49it/s]

100%|█████████▉| 814/817 [00:08<00:00, 57.45it/s]

100%|██████████| 817/817 [00:08<00:00, 95.12it/s]




In [9]:
# Write one CSV per disease: rows are gene symbols, columns are sampleIDs, and
# values are per-sample mean SHAP values for that disease.
out_dir = here('03_downstream_analysis/08_gene_importance/new_shap_plots/results/SHAP_AVGsamples')
# to_csv does not create missing parent directories, so ensure it exists first.
os.makedirs(out_dir, exist_ok=True)

for d in tqdm(DISEASES):
    pd.concat(diseaseDict[d], axis=1).to_csv(
        os.path.join(out_dir, f'SHAP_AVGsample_{cell_type}_{d}.csv'))

  0%|          | 0/20 [00:00<?, ?it/s]

  5%|▌         | 1/20 [00:01<00:20,  1.08s/it]

 10%|█         | 2/20 [00:02<00:23,  1.31s/it]

 15%|█▌        | 3/20 [00:04<00:24,  1.47s/it]

 20%|██        | 4/20 [00:06<00:25,  1.62s/it]

 25%|██▌       | 5/20 [00:07<00:22,  1.51s/it]

 30%|███       | 6/20 [00:08<00:21,  1.52s/it]

 35%|███▌      | 7/20 [00:10<00:18,  1.39s/it]

 40%|████      | 8/20 [00:11<00:15,  1.32s/it]

 45%|████▌     | 9/20 [00:11<00:12,  1.10s/it]

 50%|█████     | 10/20 [00:12<00:11,  1.11s/it]

 55%|█████▌    | 11/20 [00:13<00:09,  1.05s/it]

 60%|██████    | 12/20 [00:15<00:08,  1.08s/it]

 65%|██████▌   | 13/20 [00:16<00:07,  1.07s/it]

 70%|███████   | 14/20 [00:17<00:06,  1.11s/it]

 75%|███████▌  | 15/20 [00:18<00:05,  1.10s/it]

 80%|████████  | 16/20 [00:19<00:04,  1.15s/it]

 85%|████████▌ | 17/20 [00:20<00:03,  1.02s/it]

 90%|█████████ | 18/20 [00:22<00:02,  1.30s/it]

 95%|█████████▌| 19/20 [00:24<00:01,  1.45s/it]

100%|██████████| 20/20 [00:25<00:00,  1.33s/it]

100%|██████████| 20/20 [00:25<00:00,  1.26s/it]


