In [26]:
# Number of top genes (largest SHAP value) taken per disease column for each cell type;
# also determines the output directory name gene_subsets_{N_GENES}.
N_GENES: int = 20

In [27]:
import os, sys
from pyprojroot.here import here
from tqdm.auto import trange, tqdm

import pandas as pd
import anndata as ad
import numpy as np

import pickle

from collections import defaultdict

import re

import matplotlib.pyplot as plt

# Import functions
sys.path.insert(1, str(here('bin')))
from customPythonFunctions import generate_shap_data
from customPalette import shap_cell_types as CELL_TYPES

In [28]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [29]:
# Reference only -- CELL_TYPES itself is imported from customPalette.shap_cell_types:
# CELL_TYPES= ["Mono", "T_CD4_Naive", "T_CD4_NonNaive", "T_CD8_Naive", "T_CD8_NonNaive", "B", "Plasma", "UTC", "ILC", "pDC", "DC"]
# Diseases whose SHAP columns are kept from each cell type's disease-classifier SHAP table.
EXT_DISEASES = ['healthy', 'sepsis', 'CD', 'SLE', 'HIV', 'cirrhosis', 'RA', 'COVID']
# Reference only (unused): full list of disease labels.
# diseases = np.array(['BRCA', 'CD', 'COPD', 'COVID', 'CRC', 'HBV', 'HIV', 'HNSCC', 'MS', 'NPC', 'PS', 'PSA', 'RA', 'SLE', 'UC', 'asthma', 'cirrhosis', 'flu', 'healthy', 'sepsis'])

### Top relevant genes per *disease*

In [30]:
# cell type -> SHAP table (genes x EXT_DISEASES) of the disease classifier
shap_disease_per_cell_type = {}

In [31]:
# Load each cell type's disease-classifier SHAP summary (stat='mean_abs') and keep only
# the EXT_DISEASES columns. Paths that do not depend on the cell type are resolved once.
gene_symbol_pkl = here('03_downstream_analysis/02_gene_universe_definition/results/04_selected_gene_list.pkl')
expressed_genes_csv = here('03_downstream_analysis/08_gene_importance/results/genes_expressing_cells.csv')

for cell_type in tqdm(CELL_TYPES):
    shap_ct_df = generate_shap_data(
        cell_type=cell_type,
        shap_stats_path=here(f'03_downstream_analysis/08_gene_importance/results/targetY_disease/shap/shap_vals/total_run1_{cell_type}_shap_stats.npz'),
        adata_path=here(f'03_downstream_analysis/08_gene_importance/data/{cell_type}_adataMerged_SPECTRAgenes.log1p.h5ad'),
        gene_symbol_df_path=gene_symbol_pkl,
        stat='mean_abs',
        category_col='disease',
        expressed_gene_cellTypes_path=expressed_genes_csv,
    )
    shap_disease_per_cell_type[cell_type] = shap_ct_df.loc[:, EXT_DISEASES]

  0%|          | 0/11 [00:00<?, ?it/s]

In [32]:
# Gene universe: indexed by ensembl_gene_id, with hgnc_id / symbol / locus_group / HUGO_status columns
symbol2ENSdf = pd.read_pickle(
    here('03_downstream_analysis/02_gene_universe_definition/results/04_selected_gene_list.pkl')
)
symbol2ENSdf.head(n=5)

Unnamed: 0_level_0,hgnc_id,symbol,locus_group,HUGO_status
ensembl_gene_id,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1
ENSG00000000003,HGNC:11858,TSPAN6,protein_coding,official
ENSG00000000457,HGNC:19285,SCYL3,protein_coding,official
ENSG00000000938,HGNC:3697,FGR,protein_coding,official
ENSG00000000971,HGNC:4883,CFH,protein_coding,official
ENSG00000001036,HGNC:4008,FUCA2,protein_coding,official


In [33]:
# For each cell type and disease, keep the N_GENES genes with the largest SHAP value.
#   top_shap_values_per_cell_type[ct][disease] -> one-column DataFrame of those values
#   top_shap_per_cell_type[ct]                 -> DataFrame whose columns hold the gene labels per disease
top_shap_per_cell_type = {}
top_shap_values_per_cell_type = {}
for ct in CELL_TYPES:
    disease_shap = shap_disease_per_cell_type[ct]
    top_values = {
        disease: disease_shap.loc[:, disease].nlargest(N_GENES).to_frame()
        for disease in EXT_DISEASES
    }
    top_shap_values_per_cell_type[ct] = top_values
    top_shap_per_cell_type[ct] = pd.DataFrame(
        {disease: frame.index for disease, frame in top_values.items()}
    )

In [34]:
# Union (sorted, de-duplicated) of the per-disease top genes for each cell type
selected_genes_per_cell_type = {
    ct: np.unique(top_shap_per_cell_type[ct].to_numpy().ravel())
    for ct in CELL_TYPES
}

In [35]:
# Number of distinct top-SHAP genes per cell type; later used as the target size of the
# "samesize" gene sets.
shap_gene_set_sizes = dict(
    zip(selected_genes_per_cell_type.keys(), map(len, selected_genes_per_cell_type.values()))
)
shap_gene_set_sizes

{'Mono': 114,
 'T_CD4_Naive': 101,
 'T_CD4_NonNaive': 102,
 'T_CD8_Naive': 110,
 'T_CD8_NonNaive': 104,
 'B': 100,
 'Plasma': 107,
 'UTC': 115,
 'ILC': 100,
 'pDC': 112,
 'DC': 118}

In [36]:
# Save the ENSEMBL IDs of each cell type's per-disease top-SHAP genes.
out_dir = f'gene_subsets_{N_GENES}'
os.makedirs(out_dir, exist_ok=True)  # np.save does not create missing directories
# symbol -> ensembl_gene_id lookup, built once instead of once per cell type
ens_by_symbol = symbol2ENSdf.reset_index().set_index('symbol')['ensembl_gene_id']
for cell_type in CELL_TYPES:
    selection = ens_by_symbol.loc[selected_genes_per_cell_type[cell_type]].values
    np.save(f'{out_dir}/{cell_type}_shap', selection)

### Top relevant genes per *study*

In [37]:
# Unique (disease, studyID) pairs from the sample metadata, and the studyIDs (as str)
# that include at least one of the EXT_DISEASES.
disease_studyID_df = (
    pd.read_pickle(here('01_data_processing/results/01_INFLAMMATION_main_sampleMetadata.pkl'))
    .loc[:, ['disease', 'studyID']]
    .drop_duplicates()
)
is_ext_disease = disease_studyID_df['disease'].isin(EXT_DISEASES)
EXT_STUDIES = disease_studyID_df.loc[is_ext_disease, 'studyID'].astype(str).unique()

In [38]:
# studyIDs with at least one sample from a disease in EXT_DISEASES
EXT_STUDIES

array(['SCGT00', 'SCGT01', 'SCGT02', 'SCGT04', 'Reyes2020', 'Cillo2020',
       'Zhang2023', 'Schafflick2020', 'Terekhova2023', 'Perez2022',
       'Wang2020', 'COMBAT2022', 'Ren2021'], dtype=object)

In [39]:
# cell type -> SHAP table (genes x EXT_STUDIES) of the studyID classifier
shap_studies_per_cell_type = {}

In [40]:
# Load each cell type's studyID-classifier SHAP summary (stat='mean_abs') and keep only
# the EXT_STUDIES columns. Paths that do not depend on the cell type are resolved once.
gene_symbol_pkl = here('03_downstream_analysis/02_gene_universe_definition/results/04_selected_gene_list.pkl')
expressed_genes_csv = here('03_downstream_analysis/08_gene_importance/results/genes_expressing_cells.csv')

for cell_type in tqdm(CELL_TYPES):
    shap_ct_df = generate_shap_data(
        cell_type=cell_type,
        shap_stats_path=here(f'03_downstream_analysis/08_gene_importance/results/targetY_studyID/shap/shap_vals/total_studyID_{cell_type}_shap_stats.npz'),
        adata_path=here(f'03_downstream_analysis/08_gene_importance/data/{cell_type}_adataMerged_SPECTRAgenes.log1p.h5ad'),
        gene_symbol_df_path=gene_symbol_pkl,
        stat='mean_abs',
        category_col='studyID',
        expressed_gene_cellTypes_path=expressed_genes_csv,
    )
    shap_studies_per_cell_type[cell_type] = shap_ct_df.loc[:, EXT_STUDIES]

  0%|          | 0/11 [00:00<?, ?it/s]

In [41]:
def get_gene_set(shap_df, gene_set_size, additional_mask=None):
    """Collect the union of the top-ranked genes across all columns of a SHAP table.

    Genes are ranked independently within each column by decreasing SHAP value. Rank by
    rank (best first), the gene at that rank in every column is added to the set until it
    holds at least ``gene_set_size`` genes. Because a whole rank is added at once, the
    result can overshoot ``gene_set_size``.

    Parameters
    ----------
    shap_df : pandas.DataFrame
        Genes (index) x categories (columns) of numeric SHAP values.
    gene_set_size : int
        Minimum number of distinct genes to collect.
    additional_mask : pandas.DataFrame of bool, optional
        Extra eligibility mask for the same genes/columns as ``shap_df``. Only entries that
        are True here AND have a SHAP value > 0 can be selected. If its index/columns differ
        from ``shap_df``'s, it is aligned by label (missing entries count as False).

    Returns
    -------
    set
        Selected labels from ``shap_df.index``. Masked-out genes are never returned; if
        fewer eligible genes exist than requested, all of them are returned with a warning.
    """
    import warnings

    values = shap_df.to_numpy()
    mask = values > 0

    if additional_mask is not None:
        if isinstance(additional_mask, pd.DataFrame) and not (
            additional_mask.index.equals(shap_df.index)
            and additional_mask.columns.equals(shap_df.columns)
        ):
            # Align by label rather than position so a differently ordered mask is not
            # silently applied to the wrong genes/columns.
            additional_mask = additional_mask.reindex(
                index=shap_df.index, columns=shap_df.columns, fill_value=False
            )
        mask &= np.asarray(additional_mask, dtype=bool)

    # Ineligible entries become NaN; argsort puts NaNs last, so eligible genes come first.
    masked_values = np.where(mask, values, np.nan)
    gene_rank = (-masked_values).argsort(axis=0)
    n_eligible = mask.sum(axis=0)  # number of eligible genes in each column

    gene_set = set()
    for r in range(int(n_eligible.max(initial=0))):
        if len(gene_set) >= gene_set_size:
            break
        # Only columns that still have an eligible gene at rank r contribute; past that
        # point the ranking points at masked (NaN) entries, which must not be selected.
        cols = np.flatnonzero(n_eligible > r)
        gene_set.update(shap_df.index[gene_rank[r, cols]])

    if len(gene_set) < gene_set_size:
        warnings.warn(
            f'get_gene_set: only {len(gene_set)} eligible genes available, '
            f'{gene_set_size} requested'
        )
    return gene_set

In [42]:
# Same-size gene sets ranked by the studyID classifier. get_gene_set returns a Python set:
# np.array(set) would give a 0-d object array (not one gene per element), and set order is
# not reproducible, so sort it before converting.
selected_genes_studyID_per_cell_type = {}
for cell_type in CELL_TYPES:
    study_shap = shap_studies_per_cell_type[cell_type].loc[:, EXT_STUDIES]
    gene_set = get_gene_set(study_shap, shap_gene_set_sizes[cell_type])
    selected_genes_studyID_per_cell_type[cell_type] = np.array(sorted(gene_set))

In [43]:
# Save the ENSEMBL IDs of the same-size studyID gene sets for each cell type.
out_dir = f'gene_subsets_{N_GENES}'
os.makedirs(out_dir, exist_ok=True)  # np.save does not create missing directories
# symbol -> ensembl_gene_id lookup, built once instead of once per cell type
ens_by_symbol = symbol2ENSdf.reset_index().set_index('symbol')['ensembl_gene_id']
for cell_type in CELL_TYPES:
    selection = ens_by_symbol.loc[selected_genes_studyID_per_cell_type[cell_type]].values
    np.save(f'{out_dir}/{cell_type}_shap_studyID_samesize', selection)

In [44]:
# Overlap between the same-size studyID gene sets and the `_shap_studyID` sets.
# NOTE(review): the `_shap_studyID.npy` files are not written by the cells shown here --
# presumably produced by an earlier run/notebook; confirm before Restart & Run All.
overlap_rows = []
for cell_type in CELL_TYPES:
    subset = np.load(f'gene_subsets_{N_GENES}/{cell_type}_shap_studyID_samesize.npy', allow_pickle=True)
    orig = np.load(f'gene_subsets_{N_GENES}/{cell_type}_shap_studyID.npy', allow_pickle=True)
    overlap_rows.append({
        'cell_type': cell_type,
        'n_samesize': len(subset),
        'n_orig': len(orig),
        'n_shared': len(np.intersect1d(subset, orig)),
    })
pd.DataFrame(overlap_rows).set_index('cell_type')

117 141 117
105 140 105
105 159 104
111 166 109
104 163 96
104 150 98
108 151 108
115 138 115
104 150 104
115 176 115
122 146 122


### Random gene selection

In [45]:
# Random baselines: for each cell type, draw N_RANDOM_SEEDS gene sets with the same size as
# the per-disease SHAP selection, from well-expressed genes NOT in that selection.
MIN_PCT_CELLS = 5      # "well expressed" = detected in more than this % of the cell type's cells
N_RANDOM_SEEDS = 100   # one saved random subset per seed (seeds 0..N_RANDOM_SEEDS-1)

# Loop-invariant inputs: read once instead of once per cell type.
symbols_df = pd.read_pickle(here('03_downstream_analysis/02_gene_universe_definition/results/04_selected_gene_list.pkl'))
ens_by_symbol = symbols_df.reset_index().set_index('symbol')['ensembl_gene_id']
perc_gene_expr = pd.read_csv(here('03_downstream_analysis/08_gene_importance/results/genes_expressing_cells.csv'))
perc_gene_expr = perc_gene_expr.astype({'symbol': 'str', '% cells': 'float', 'CellType': 'category'})
out_dir = f'gene_subsets_{N_GENES}'
os.makedirs(out_dir, exist_ok=True)  # np.save does not create missing directories

for cell_type in CELL_TYPES:
    # Only var_names are needed: open in backed mode and close the HDF5 handle right away.
    adata = ad.read_h5ad(here(f'03_downstream_analysis/08_gene_importance/data/{cell_type}_adataMerged_SPECTRAgenes.log1p.h5ad'), backed='r')
    try:
        var_names = adata.var_names.copy()
    finally:
        adata.file.close()

    # Keep a pandas Series (not .values) so .isin works whatever the 'symbol' dtype is;
    # order follows adata.var_names, which keeps the seeded draws reproducible.
    candidate_symbols = symbols_df.loc[var_names, 'symbol']
    well_expressed_symbols = perc_gene_expr.query('`% cells` > @MIN_PCT_CELLS & CellType == @cell_type').symbol
    candidate_symbols = candidate_symbols[candidate_symbols.isin(well_expressed_symbols)]
    n_well_expressed = len(candidate_symbols)

    # Exclude the SHAP-selected genes so the random sets are disjoint from them.
    shap_selection = selected_genes_per_cell_type[cell_type]
    target_sel_size = len(shap_selection)
    candidate_symbols = candidate_symbols[~candidate_symbols.isin(shap_selection)].to_numpy()
    print(f'{cell_type}: {n_well_expressed} well-expressed genes, '
          f'{len(candidate_symbols)} candidates after excluding SHAP-selected genes')

    for seed in range(N_RANDOM_SEEDS):
        rng = np.random.default_rng(seed)
        rand_sel = rng.choice(candidate_symbols, size=target_sel_size, replace=False)
        rand_sel = ens_by_symbol.loc[rand_sel].values
        # Guards against symbols mapping to zero or several ENSEMBL IDs
        assert len(np.unique(rand_sel)) == target_sel_size
        np.save(f'{out_dir}/{cell_type}_{seed}', rand_sel)

(612,)
(498,)
(436,)
(335,)
(492,)
(390,)
(431,)
(321,)
(520,)
(416,)
(466,)
(366,)
(592,)
(485,)
(496,)
(381,)
(506,)
(406,)
(522,)
(410,)
(649,)
(531,)


### Intersection between most relevant studyID and disease genes

Here we disentangle genes that are relevant for classifying diseases from those that are relevant for classifying studies (i.e., related to batch effects).

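**First, we filter the genes relevant for classifying diseases, removing those that are also important for classifying studies. For each disease, the per-gene study-classifier SHAP values (mean absolute SHAP) are summed over the studies in which that disease is present, and only genes in the bottom quartile of this sum (≤ 25th percentile) remain eligible.**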

In [46]:
# One row per disease in EXT_DISEASES, with the list of studyIDs in which it occurs
disease_disease_studyID_filt_df = (
    disease_studyID_df[disease_studyID_df['disease'].isin(EXT_DISEASES)]
    .groupby('disease', observed=True)
    .agg(list)
)

In [47]:
# For each cell type, select genes relevant for disease classification but NOT for study
# classification: per disease, a gene stays eligible only if its study-classifier SHAP,
# summed over the studies in which that disease is present, is in the bottom quantile.
LOW_SHAP_QUANTILE = 0.25
os.makedirs(f'gene_subsets_{N_GENES}', exist_ok=True)  # np.save does not create missing directories

for cell_type in CELL_TYPES:
    disease_shap = shap_disease_per_cell_type[cell_type]
    study_shap = shap_studies_per_cell_type[cell_type]

    relevant_disease_only = dict()
    for d, study_list in disease_disease_studyID_filt_df.iterrows():
        shap_relevant_studies = study_shap[study_list.studyID].sum(axis=1)
        relevant_disease_only[d] = shap_relevant_studies <= np.quantile(shap_relevant_studies, LOW_SHAP_QUANTILE)

    # The mask is indexed by the study table's genes: order its columns like the disease
    # table and align its rows by gene label, so the mask cannot be applied to the wrong gene.
    relevant_disease_only = (
        pd.DataFrame.from_dict(relevant_disease_only)
        .loc[:, disease_shap.columns]
        .reindex(index=disease_shap.index, fill_value=False)
    )

    geneset = get_gene_set(disease_shap, shap_gene_set_sizes[cell_type], relevant_disease_only)

    TOP_genes_disease_postFilter_ENSid = symbol2ENSdf[symbol2ENSdf.symbol.isin(geneset)].index.to_numpy()
    np.save(f'gene_subsets_{N_GENES}/{cell_type}_shap_disease_NOstudy_samesize', TOP_genes_disease_postFilter_ENSid)

In [48]:
# Overlap between the same-size disease-minus-study gene sets and the `_shap_disease_NOstudy` sets.
# NOTE(review): the `_shap_disease_NOstudy.npy` files are not written by the cells shown here --
# presumably produced by an earlier run/notebook; confirm before Restart & Run All.
overlap_rows = []
for cell_type in CELL_TYPES:
    subset = np.load(f'gene_subsets_{N_GENES}/{cell_type}_shap_disease_NOstudy_samesize.npy', allow_pickle=True)
    orig = np.load(f'gene_subsets_{N_GENES}/{cell_type}_shap_disease_NOstudy.npy', allow_pickle=True)
    overlap_rows.append({
        'cell_type': cell_type,
        'n_samesize': len(subset),
        'n_orig': len(orig),
        'n_shared': len(np.intersect1d(subset, orig)),
    })
pd.DataFrame(overlap_rows).set_index('cell_type')

119 126 119
101 123 101
103 118 103
113 116 107
109 119 109
103 114 103
108 136 108
119 130 119
102 126 102
115 120 115
121 131 121


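**Second, we filter the genes relevant for classifying studies, removing those that are also important for classifying diseases. For each study, the per-gene disease-classifier SHAP values (mean absolute SHAP) are summed over the diseases included in that study, and only genes in the bottom quartile of this sum (≤ 25th percentile) remain eligible.**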

In [49]:
# One row per studyID (restricted to EXT_DISEASES rows), with the list of diseases it includes
disease_studyID_disease_filt_df = (
    disease_studyID_df[disease_studyID_df['disease'].isin(EXT_DISEASES)]
    .groupby('studyID', observed=True)
    .agg(list)
)

In [50]:
# For each cell type, select genes relevant for study classification but NOT for disease
# classification: per study, a gene stays eligible only if its disease-classifier SHAP,
# summed over the diseases included in that study, is in the bottom quantile.
LOW_SHAP_QUANTILE = 0.25
os.makedirs(f'gene_subsets_{N_GENES}', exist_ok=True)  # np.save does not create missing directories

for cell_type in CELL_TYPES:
    study_shap = shap_studies_per_cell_type[cell_type]
    disease_shap = shap_disease_per_cell_type[cell_type]

    relevant_study_only = dict()
    for s, disease_list in disease_studyID_disease_filt_df.iterrows():
        shap_relevant_disease = disease_shap[disease_list.disease].sum(axis=1)
        relevant_study_only[s] = shap_relevant_disease <= np.quantile(shap_relevant_disease, LOW_SHAP_QUANTILE)

    # The mask is indexed by the disease table's genes: order its columns like the study
    # table and align its rows by gene label, so the mask cannot be applied to the wrong gene.
    relevant_study_only = (
        pd.DataFrame.from_dict(relevant_study_only)
        .loc[:, study_shap.columns]
        .reindex(index=study_shap.index, fill_value=False)
    )

    geneset = get_gene_set(study_shap, shap_gene_set_sizes[cell_type], relevant_study_only)

    TOP_genes_studyID_postFilter_ENSid = symbol2ENSdf[symbol2ENSdf.symbol.isin(geneset)].index.to_numpy()
    np.save(f'gene_subsets_{N_GENES}/{cell_type}_shap_studyID_NOdisease_samesize', TOP_genes_studyID_postFilter_ENSid)

In [51]:
# Overlap between the same-size study-minus-disease gene sets and the `_shap_studyID_NOdisease` sets.
# NOTE(review): the `_shap_studyID_NOdisease.npy` files are not written by the cells shown here --
# presumably produced by an earlier run/notebook; confirm before Restart & Run All.
overlap_rows = []
for cell_type in CELL_TYPES:
    subset = np.load(f'gene_subsets_{N_GENES}/{cell_type}_shap_studyID_NOdisease_samesize.npy', allow_pickle=True)
    orig = np.load(f'gene_subsets_{N_GENES}/{cell_type}_shap_studyID_NOdisease.npy', allow_pickle=True)
    overlap_rows.append({
        'cell_type': cell_type,
        'n_samesize': len(subset),
        'n_orig': len(orig),
        'n_shared': len(np.intersect1d(subset, orig)),
    })
pd.DataFrame(overlap_rows).set_index('cell_type')

118 162 111
102 144 101
104 96 81
110 80 62
104 83 48
101 108 83
114 142 101
118 121 103
104 114 91
113 122 104
123 159 118
