In [1]:
# Parameters
# Cell-type label: selects the SHAP batch files and the .h5ad dataset to load,
# and is embedded in the names of the CSV files written at the end.
cell_type = 'B'


In [2]:
import sys
import os

import anndata as ad

import pandas as pd
import numpy as np

from pyprojroot import here

import re

from tqdm import tqdm

In [3]:
def load_sorted_shap_values_fnames(
    cell_type: str = '',
    run_name: str = 'run1',
    dirpath=None):
    """Return the SHAP batch file paths for one cell type, ordered by batch number.

    A file is selected when its name starts with
    ``{run_name}_{cell_type}_shap_values_<N>``, where ``<N>`` is an integer
    batch number. Files are ordered numerically by ``<N>``, so batch 10 comes
    after batch 9 rather than after batch 1. Ties are broken by file name so
    the order never depends on ``os.listdir`` ordering.

    Parameters
    ----------
    cell_type : str
        Cell-type label embedded in the file names; matched literally.
    run_name : str
        Run prefix embedded in the file names; matched literally.
    dirpath : str or path-like, optional
        Directory to scan. When None (default), uses the project directory
        ``inflammabucket_bkp/03_downstream_analysis/05_SHAP/results/04_shap/shap_vals``
        resolved with pyprojroot's ``here``.

    Returns
    -------
    list of str
        Full paths of the matching files, in ascending batch-number order.
    """
    if dirpath is None:
        dirpath = here("inflammabucket_bkp/03_downstream_analysis/05_SHAP/results/04_shap/shap_vals")

    # Escape the caller-supplied labels so characters such as '+', '.' or '('
    # are matched literally instead of being interpreted as regex syntax.
    pattern = re.compile(
        f'^{re.escape(run_name)}_{re.escape(cell_type)}_shap_values_' + r'(\d+)')

    batches = []
    for fname in os.listdir(dirpath):
        match = pattern.match(fname)  # single regex evaluation per file
        if match:
            batches.append((int(match.group(1)), fname))

    # Sort by (batch number, file name): numeric order, deterministic ties.
    batches.sort()

    return [os.path.join(dirpath, fname) for _, fname in batches]

In [4]:
def _load_shap_values(fname):
    """Read the 'shap_values' array from one .npz batch file, closing the file afterwards."""
    # np.load on an .npz archive returns an NpzFile that keeps the file open;
    # the context manager releases the handle once the array has been read.
    with np.load(fname) as npz:
        return npz['shap_values']


# One SHAP array per batch file, in ascending batch-number order.
shap_val_list = [_load_shap_values(fname) for fname in tqdm(load_sorted_shap_values_fnames(cell_type))]
# Fail early with a clear message instead of an opaque np.concatenate error later.
assert shap_val_list, f"no SHAP batch files found for cell_type={cell_type!r}"


  0%|          | 0/36 [00:00<?, ?it/s]

  3%|▎         | 1/36 [00:03<02:03,  3.52s/it]

  6%|▌         | 2/36 [00:07<01:59,  3.51s/it]

  8%|▊         | 3/36 [00:10<01:55,  3.51s/it]

 11%|█         | 4/36 [00:14<01:53,  3.54s/it]

 14%|█▍        | 5/36 [00:17<01:50,  3.55s/it]

 17%|█▋        | 6/36 [00:21<01:46,  3.55s/it]

 19%|█▉        | 7/36 [00:24<01:43,  3.56s/it]

 22%|██▏       | 8/36 [00:28<01:39,  3.57s/it]

 25%|██▌       | 9/36 [00:31<01:36,  3.57s/it]

 28%|██▊       | 10/36 [00:35<01:33,  3.59s/it]

 31%|███       | 11/36 [00:39<01:31,  3.66s/it]

 33%|███▎      | 12/36 [00:43<01:27,  3.63s/it]

 36%|███▌      | 13/36 [00:46<01:22,  3.60s/it]

 39%|███▉      | 14/36 [00:50<01:19,  3.61s/it]

 42%|████▏     | 15/36 [00:53<01:15,  3.59s/it]

 44%|████▍     | 16/36 [00:57<01:11,  3.59s/it]

 47%|████▋     | 17/36 [01:01<01:09,  3.65s/it]

 50%|█████     | 18/36 [01:04<01:05,  3.61s/it]

 53%|█████▎    | 19/36 [01:08<01:00,  3.58s/it]

 56%|█████▌    | 20/36 [01:11<00:56,  3.56s/it]

 58%|█████▊    | 21/36 [01:15<00:53,  3.55s/it]

 61%|██████    | 22/36 [01:18<00:49,  3.53s/it]

 64%|██████▍   | 23/36 [01:22<00:45,  3.52s/it]

 67%|██████▋   | 24/36 [01:25<00:42,  3.52s/it]

 69%|██████▉   | 25/36 [01:29<00:39,  3.55s/it]

 72%|███████▏  | 26/36 [01:32<00:35,  3.54s/it]

 75%|███████▌  | 27/36 [01:36<00:31,  3.51s/it]

 78%|███████▊  | 28/36 [01:39<00:28,  3.52s/it]

 81%|████████  | 29/36 [01:43<00:24,  3.51s/it]

 83%|████████▎ | 30/36 [01:46<00:21,  3.53s/it]

 86%|████████▌ | 31/36 [01:50<00:17,  3.53s/it]

 89%|████████▉ | 32/36 [01:54<00:14,  3.56s/it]

 92%|█████████▏| 33/36 [01:57<00:10,  3.56s/it]

 94%|█████████▍| 34/36 [02:01<00:07,  3.55s/it]

 97%|█████████▋| 35/36 [02:04<00:03,  3.54s/it]

100%|██████████| 36/36 [02:07<00:00,  3.23s/it]

100%|██████████| 36/36 [02:07<00:00,  3.53s/it]




In [5]:
# Merged dataset for this cell type. backed='r' opens it read-only in backed
# mode; later cells read only .var_names (gene order) and .obs['sampleID'].
adata_path = here(f'inflammabucket_bkp/03_downstream_analysis/05_SHAP/data/{cell_type}_adataMerged_SPECTRAgenes.log1p.h5ad')
adata = ad.read_h5ad(adata_path, backed='r')

In [6]:
## Gene symbols and Disease labels
# Look up the 'symbol' column for every entry of adata.var_names, keeping
# adata's var order; these become the row labels of the per-sample SHAP tables.
# NOTE: pd.read_pickle can execute arbitrary code on load - only use it on
# trusted project files.
symbols_df = pd.read_pickle(here('inflammabucket_bkp/03_downstream_analysis/04_selected_gene_list.pkl'))
symbols_sorted = symbols_df.loc[adata.var_names, 'symbol'].values

# Assigned positionally as column labels of the per-sample SHAP means, so this
# order is assumed to match the model's output/class order - TODO confirm.
DISEASES = ['BRCA', 'CD', 'COPD', 'COVID', 'CRC', 'HBV', 'HIV', 'HNSCC', 'MS', 'NPC', 'PS', 'PSA', 'RA', 'SLE', 'UC', 'asthma', 'cirrhosis', 'flu', 'healthy', 'sepsis']

# disease -> list of one-column (gene x sample) DataFrames, filled per sample.
diseaseDict = {disease: [] for disease in DISEASES}

In [7]:
# Stack the per-batch SHAP arrays along the first (cell) axis. Batches are in
# ascending batch-number order, and row i is assumed to correspond to row i of
# adata.obs - TODO confirm the batches were produced in obs order.
shape_values_matrix = np.concatenate(shap_val_list)
# The per-sample grouping indexes this matrix with positional obs indices, so a
# missing or extra batch file would silently misalign (or fail late) without this.
assert shape_values_matrix.shape[0] == adata.n_obs, (
    f"SHAP rows ({shape_values_matrix.shape[0]}) != cells in adata ({adata.n_obs}); "
    "check for missing/extra SHAP batch files")

In [8]:
# Rebuild the per-disease lists here so this cell is idempotent: re-running it
# must not append a second copy of every sample column.
diseaseDict = {d: [] for d in DISEASES}

# groupby(...).indices maps each sampleID to the positional row indices of its cells.
for sample_id, row_positions in tqdm(adata.obs.groupby('sampleID', observed=True).indices.items()):
    # Mean SHAP value over this sample's cells, labelled gene (rows) x disease (columns).
    sample_means = pd.DataFrame(
        shape_values_matrix[row_positions].mean(0),
        index=symbols_sorted,
        columns=DISEASES,
    )
    for d in DISEASES:
        # One-column frame named after the sample; concatenated per disease later.
        diseaseDict[d].append(pd.DataFrame({sample_id: sample_means[d]}))

  0%|          | 0/814 [00:00<?, ?it/s]

  2%|▏         | 16/814 [00:00<00:05, 153.12it/s]

  4%|▍         | 35/814 [00:00<00:04, 172.82it/s]

  7%|▋         | 53/814 [00:00<00:04, 168.40it/s]

  9%|▊         | 71/814 [00:00<00:04, 167.55it/s]

 11%|█         | 88/814 [00:00<00:04, 154.89it/s]

 15%|█▌        | 124/814 [00:00<00:03, 215.98it/s]

 18%|█▊        | 147/814 [00:00<00:03, 171.11it/s]

 20%|██        | 166/814 [00:01<00:04, 150.27it/s]

 22%|██▏       | 183/814 [00:01<00:05, 119.77it/s]

 24%|██▍       | 197/814 [00:01<00:05, 110.07it/s]

 26%|██▌       | 210/814 [00:01<00:05, 110.47it/s]

 27%|██▋       | 222/814 [00:01<00:05, 109.14it/s]

 29%|██▉       | 235/814 [00:01<00:05, 113.82it/s]

 30%|███       | 247/814 [00:01<00:04, 115.31it/s]

 32%|███▏      | 263/814 [00:01<00:04, 125.84it/s]

 34%|███▍      | 276/814 [00:02<00:04, 123.96it/s]

 36%|███▌      | 292/814 [00:02<00:03, 132.92it/s]

 38%|███▊      | 306/814 [00:02<00:05, 100.15it/s]

 39%|███▉      | 318/814 [00:02<00:04, 102.69it/s]

 41%|████      | 330/814 [00:02<00:04, 103.90it/s]

 42%|████▏     | 342/814 [00:02<00:04, 105.04it/s]

 44%|████▎     | 355/814 [00:02<00:04, 110.37it/s]

 45%|████▌     | 369/814 [00:02<00:04, 110.13it/s]

 47%|████▋     | 381/814 [00:03<00:04, 86.89it/s] 

 49%|████▉     | 401/814 [00:03<00:03, 112.04it/s]

 51%|█████     | 414/814 [00:03<00:03, 112.36it/s]

 53%|█████▎    | 433/814 [00:03<00:02, 131.07it/s]

 55%|█████▌    | 451/814 [00:03<00:02, 142.34it/s]

 57%|█████▋    | 467/814 [00:03<00:02, 140.27it/s]

 59%|█████▉    | 482/814 [00:03<00:02, 142.27it/s]

 61%|██████    | 497/814 [00:03<00:02, 139.29it/s]

 63%|██████▎   | 512/814 [00:04<00:02, 142.20it/s]

 65%|██████▍   | 527/814 [00:04<00:02, 140.90it/s]

 67%|██████▋   | 542/814 [00:04<00:02, 135.67it/s]

 68%|██████▊   | 556/814 [00:04<00:02, 125.34it/s]

 70%|██████▉   | 569/814 [00:04<00:01, 123.69it/s]

 71%|███████▏  | 582/814 [00:04<00:01, 124.19it/s]

 74%|███████▍  | 601/814 [00:04<00:01, 141.76it/s]

 76%|███████▌  | 620/814 [00:04<00:01, 154.92it/s]

 78%|███████▊  | 636/814 [00:04<00:01, 129.00it/s]

 80%|███████▉  | 650/814 [00:05<00:01, 122.29it/s]

 81%|████████▏ | 663/814 [00:05<00:01, 107.43it/s]

 84%|████████▍ | 683/814 [00:05<00:01, 128.57it/s]

 86%|████████▋ | 703/814 [00:05<00:00, 145.27it/s]

 89%|████████▊ | 722/814 [00:05<00:00, 156.90it/s]

 92%|█████████▏| 745/814 [00:05<00:00, 176.10it/s]

 94%|█████████▍| 764/814 [00:05<00:00, 176.26it/s]

 96%|█████████▋| 785/814 [00:05<00:00, 184.22it/s]

 99%|█████████▉| 804/814 [00:06<00:00, 147.63it/s]

100%|██████████| 814/814 [00:06<00:00, 131.24it/s]




In [9]:
# One CSV per disease: rows are gene symbols, columns are sample IDs, values are
# the per-sample mean SHAP values.
out_dir = here('03_downstream_analysis/08_gene_importance/new_shap_plots/results/SHAP_AVGsamples')
# Create the output folder up front so a fresh checkout does not crash here
# after all the expensive loading and aggregation has already run.
os.makedirs(out_dir, exist_ok=True)
for d in tqdm(DISEASES):
    pd.concat(diseaseDict[d], axis=1).to_csv(os.path.join(out_dir, f'SHAP_AVGsample_{cell_type}_{d}.csv'))

  0%|          | 0/20 [00:00<?, ?it/s]

  5%|▌         | 1/20 [00:00<00:10,  1.80it/s]

 10%|█         | 2/20 [00:01<00:10,  1.72it/s]

 15%|█▌        | 3/20 [00:01<00:10,  1.61it/s]

 20%|██        | 4/20 [00:02<00:11,  1.39it/s]

 25%|██▌       | 5/20 [00:03<00:09,  1.55it/s]

 30%|███       | 6/20 [00:03<00:08,  1.60it/s]

 35%|███▌      | 7/20 [00:04<00:07,  1.70it/s]

 40%|████      | 8/20 [00:04<00:07,  1.65it/s]

 45%|████▌     | 9/20 [00:05<00:06,  1.80it/s]

 50%|█████     | 10/20 [00:05<00:05,  1.85it/s]

 55%|█████▌    | 11/20 [00:06<00:05,  1.66it/s]

 60%|██████    | 12/20 [00:07<00:04,  1.64it/s]

 65%|██████▌   | 13/20 [00:07<00:04,  1.58it/s]

 70%|███████   | 14/20 [00:08<00:04,  1.39it/s]

 75%|███████▌  | 15/20 [00:09<00:03,  1.43it/s]

 80%|████████  | 16/20 [00:10<00:02,  1.43it/s]

 85%|████████▌ | 17/20 [00:10<00:01,  1.54it/s]

 90%|█████████ | 18/20 [00:11<00:01,  1.40it/s]

 95%|█████████▌| 19/20 [00:12<00:00,  1.35it/s]

100%|██████████| 20/20 [00:13<00:00,  1.13it/s]

100%|██████████| 20/20 [00:13<00:00,  1.47it/s]


