In [1]:
# Default run configuration (CELL_TYPE is reassigned by the "# Parameters" cell that follows).
# RUN_NAME: prefix used in model/output file names; a later cell appends "_" when it is non-empty.
RUN_NAME: str = "studyID"
# CELL_TYPE: selects the input .h5ad file and is embedded in every output file name.
CELL_TYPE: str = 'DC'
# BUCKET_DIRPATH: not referenced by any code visible in this notebook.
BUCKET_DIRPATH: str = ""
# OUTPUT: run mode checked by the final cell; it accepts "shap_vals_and_stats" or "shap_int".
OUTPUT: str = "shap_vals_and_stats"
# target_y: names the `targetY_{target_y}` results sub-directory used for the model and SHAP outputs.
target_y = 'studyID'

In [2]:
# Parameters
# Overrides the default CELL_TYPE assigned in the first cell.
# NOTE(review): the "# Parameters" header looks like a papermill-injected cell - confirm before editing by hand.
CELL_TYPE = "Plasma"


In [3]:
# Sizes used to pre-allocate the streaming-statistics accumulators in
# compute_shap_values_with_stats: (N_GENES, N_CLASSES) for SHAP values and
# (N_GENES, N_GENES, N_CLASSES) for SHAP interaction values.
# NOTE(review): hard-coded; presumably these must match the number of columns of
# adata.X and the number of classes of the loaded model - confirm when changing inputs.
N_GENES = 935
N_CLASSES = 16


In [4]:
# Turn a non-empty run name into a file-name prefix ("studyID" -> "studyID_").
# NOTE: not idempotent - re-running this cell appends another underscore, so
# re-run the configuration cell first when iterating interactively.
if RUN_NAME != "":
    RUN_NAME += "_"

In [5]:
import shap
import os
import joblib
import anndata as ad
import numpy as np
#from pyprojroot.here import here
from tqdm.auto import trange, tqdm
from more_itertools import roundrobin
from numba import njit, prange
import sparse as sp
from sklearn.metrics import f1_score

from pyprojroot import here

In [6]:
# Load the AnnData file for the selected CELL_TYPE from the project tree;
# its `.X` matrix is what the final cell passes to the SHAP routines.
adata = ad.read_h5ad(
    here(f'03_downstream_analysis/08_gene_importance/data/{CELL_TYPE}_adataMerged_SPECTRAgenes.log1p.h5ad'),
)

In [7]:
def custom_f1_score(y_true, y_pred):
    """Return the negated weighted F1 score of the arg-max predictions.

    Parameters
    ----------
    y_true : true labels, passed unchanged to ``f1_score``.
    y_pred : object whose ``argmax(1)`` yields one predicted label per row.

    Returns
    -------
    The weighted-average F1 score multiplied by -1.
    """
    predicted_labels = y_pred.argmax(1)
    weighted_f1 = f1_score(y_true, predicted_labels, average='weighted')
    return -weighted_f1

In [8]:
# Load the model artifact stored for this target / run name / cell type.
# NOTE: joblib.load unpickles the file (despite its ".json" suffix); unpickling can
# execute arbitrary code, so only load artifacts from a trusted location.
# NOTE(review): custom_f1_score is presumably defined beforehand because the pickled
# model references it - confirm.
xgb = joblib.load(here(f'03_downstream_analysis/08_gene_importance/results/targetY_{target_y}/03_xgboost/best_model/{RUN_NAME}{CELL_TYPE}_xgb.json'))
# Display the loaded model as the cell output.
xgb

In [9]:
@njit()
def find_nonzero(m):
    """Return ``np.nonzero(m)`` (one index array per dimension), compiled with numba."""
    return np.nonzero(m)
    
def to_sparse(m):
    """Convert a dense array into a ``sparse.COO`` array holding only its non-zero entries.

    Parameters
    ----------
    m : dense NumPy array.

    Returns
    -------
    ``sp.COO`` with the same shape as ``m``.
    """
    coords = find_nonzero(m)
    # Gather the non-zero values once; the input can be very large, so avoid a
    # second fancy-indexing pass over it.
    data = m[coords]
    return sp.COO(coords, data, shape=m.shape)

In [10]:
class Weldford:
    """Streaming element-wise mean / variance accumulator (Welford's online algorithm).

    The class keeps a running ``count``, ``mean`` and ``M2`` (sum of squared
    deviations from the running mean) for arrays of a fixed ``shape`` so that
    statistics can be computed without holding every observation in memory.

    NOTE: the class name is a misspelling of "Welford"; it is kept because
    other code refers to it by this name.
    """

    def __init__(self, shape):
        """Create an empty accumulator for observations of the given ``shape``."""
        self.count = 0
        self.mean = np.zeros(shape)
        self.M2 = np.zeros(shape)

    @staticmethod
    @njit(parallel=True)
    def _update(arr, count, mean, M2):
        """Fold one observation ``arr`` into the running statistics.

        ``mean`` and ``M2`` are modified in place and also returned.
        """
        count += 1
        delta = arr - mean
        mean += delta / count
        delta2 = arr - mean
        M2 += delta * delta2
        return count, mean, M2

    def update(self, arr):
        """Add a single observation (an array of the accumulator's shape); returns ``self``."""
        self.count, self.mean, self.M2 = self._update(arr, self.count, self.mean, self.M2)
        return self

    def update_all(self, arr):
        """Add every observation in ``arr`` (iterated along axis 0); returns ``self``."""
        self.count, self.mean, self.M2 = self._update_all(arr, self.count, self.mean, self.M2)
        return self

    @staticmethod
    @njit(parallel=True)
    def _update_all(arr, count, mean, M2):
        """Fold each ``arr[idx]`` into the running statistics, in order.

        ``mean`` and ``M2`` are modified in place and also returned.
        """
        for idx in range(arr.shape[0]):
            a = arr[idx]
            count += 1
            delta = a - mean
            mean += delta / count
            delta2 = a - mean
            M2 += delta * delta2
        return count, mean, M2

    def finalize(self):
        """Return ``(mean, variance)`` for everything seen so far.

        The variance is the population variance (``M2 / count``).  With fewer
        than two observations the statistics are reported as undefined: both
        elements are NaN-filled arrays of the accumulator's shape.  A 2-tuple
        is always returned so callers can unpack the result unconditionally
        (previously a bare scalar NaN was returned, which cannot be unpacked).
        """
        if self.count < 2:
            return np.full_like(self.mean, np.nan), np.full_like(self.M2, np.nan)
        return self.mean, self.M2 / self.count

In [11]:
def compute_shap_values_with_stats(xgb, X, outer_batch_size = int(10000), inner_batch_size = 25, target_y=''):
    """Compute SHAP values in batches and save them together with streaming statistics.

    For each outer batch of rows of ``X`` the SHAP interaction values are
    computed ``inner_batch_size`` rows at a time.  SHAP values are obtained by
    summing the interaction values over axis 1.  Running mean / variance of the
    raw and absolute values (for both SHAP values and interaction values) are
    accumulated across *all* rows processed so far.

    After every outer batch three ``.npz`` files are written under
    ``results/targetY_{target_y}/shap/shap_vals/``:

    * ``..._shap_values_{batch_idx}``: SHAP values of the rows in that batch;
    * ``..._shap_int_values_stats_{batch_idx}`` and
      ``..._shap_values_stats_{batch_idx}``: cumulative ``mean_raw``,
      ``var_raw``, ``mean_abs`` and ``var_abs`` up to and including that batch.

    Parameters
    ----------
    xgb : model passed to ``shap.explainers.TreeExplainer``.
    X : 2-D row-sliceable matrix (rows are instances).
    outer_batch_size : number of rows per saved file.
    inner_batch_size : number of rows per explainer call.
    target_y : name used in the ``targetY_{target_y}`` output directory.

    Notes
    -----
    The explainer output is expected to be an array of shape
    ``(n_rows, N_GENES, N_GENES, N_CLASSES)``; the accumulators are sized from
    the module-level ``N_GENES`` / ``N_CLASSES`` constants.
    """
    # "tree_path_dependent" is the spelling SHAP recognises for this option.
    explainer = shap.explainers.TreeExplainer(xgb, feature_perturbation='tree_path_dependent')

    def _out_path(stem):
        # All outputs of this function live in the same results directory.
        return here(f'03_downstream_analysis/08_gene_importance/results/targetY_{target_y}/shap/shap_vals/{stem}')

    # Interaction values first, then SHAP values: this is also the save order.
    stat_shapes = {
        'shap_int_values': (N_GENES, N_GENES, N_CLASSES),
        'shap_values': (N_GENES, N_CLASSES),
    }
    # Only the accumulators that are actually updated are allocated; the
    # interaction-value ones are large.
    stream_stats = {
        f"{shap_type}_{kind}": Weldford(shape=shape)
        for shap_type, shape in stat_shapes.items()
        for kind in ('raw', 'abs')
    }

    n_rows = X.shape[0]
    for batch_idx, outer_start in enumerate(tqdm(range(0, n_rows, outer_batch_size))):
        outer_stop = min(outer_start + outer_batch_size, n_rows)
        obatch_shap_vals = []
        for inner_start in tqdm(range(outer_start, outer_stop, inner_batch_size)):
            # Clamp to the outer batch so that rows are never processed twice when
            # outer_batch_size is not a multiple of inner_batch_size.
            inner_stop = min(inner_start + inner_batch_size, outer_stop)
            batch_shap_int_vals = explainer.shap_interaction_values(X[inner_start:inner_stop])

            stream_stats['shap_int_values_raw'].update_all(batch_shap_int_vals)
            stream_stats['shap_int_values_abs'].update_all(np.abs(batch_shap_int_vals))

            # Summing the interaction values over one gene axis yields the SHAP values.
            batch_shap_vals = batch_shap_int_vals.sum(1)
            stream_stats['shap_values_raw'].update_all(batch_shap_vals)
            stream_stats['shap_values_abs'].update_all(np.abs(batch_shap_vals))

            obatch_shap_vals.append(batch_shap_vals)

        # Stack the (dense) per-inner-batch SHAP values of this outer batch.
        obatch_shap_vals = np.concatenate(obatch_shap_vals)

        np.savez_compressed(
            _out_path(f'fix_{RUN_NAME}{CELL_TYPE}_shap_values_{batch_idx}'),
            shap_values=obatch_shap_vals)

        # Snapshot of the cumulative statistics after this outer batch.
        for shap_type in stat_shapes:
            mean_raw, var_raw = stream_stats[f'{shap_type}_raw'].finalize()
            mean_abs, var_abs = stream_stats[f'{shap_type}_abs'].finalize()
            np.savez_compressed(
                _out_path(f'fix_{RUN_NAME}{CELL_TYPE}_{shap_type}_stats_{batch_idx}'),
                mean_raw=mean_raw, var_raw=var_raw, mean_abs=mean_abs, var_abs=var_abs)

In [12]:
def compute_and_save_shap_interaction_values(xgb, X, sorted_idxs, outer_batch_size = int(500), inner_batch_size = 25, n_instances: int = 10000, target_y=target_y):
    """Compute SHAP interaction values for selected rows and save them as sparse batches.

    Rows of ``X`` are visited in the order given by ``sorted_idxs``.  For every
    outer batch the interaction values are computed ``inner_batch_size`` rows at
    a time, converted to ``sparse.COO``, concatenated and written to
    ``results/targetY_{target_y}/shap/{RUN_NAME}{CELL_TYPE}_shap_int_{batch}.npz``.

    Parameters
    ----------
    xgb : model passed to ``shap.explainers.TreeExplainer``.
    X : 2-D matrix indexable by a sequence of row indices.
    sorted_idxs : sequence of row indices defining the processing order.
    outer_batch_size : number of rows per saved file.
    inner_batch_size : number of rows per explainer call.
    n_instances : upper bound on the number of rows to process;
        ``n_instances // outer_batch_size`` outer batches are attempted.
    target_y : name used in the ``targetY_{target_y}`` output directory.

    Notes
    -----
    Processing stops early once ``sorted_idxs`` is exhausted, so the explainer
    is never called on an empty selection when ``n_instances`` exceeds the
    number of available rows.
    """
    # "tree_path_dependent" is the spelling SHAP recognises for this option.
    explainer = shap.explainers.TreeExplainer(xgb, feature_perturbation='tree_path_dependent')
    for obatch_idx in trange(int(n_instances // outer_batch_size)):
        # Position range of this outer batch within sorted_idxs
        oidx = int(outer_batch_size * obatch_idx)
        outer_stop = oidx + outer_batch_size

        shap_coo = []
        for inner_start in range(oidx, outer_stop, inner_batch_size):
            # Clamp to the outer batch so consecutive files never share rows when
            # outer_batch_size is not a multiple of inner_batch_size.
            inner_stop = min(inner_start + inner_batch_size, outer_stop)
            selected_idxs = sorted_idxs[inner_start:inner_stop]
            if len(selected_idxs) == 0:
                # Ran out of indices: nothing more to explain.
                break
            shap_coo.append(to_sparse(explainer.shap_interaction_values(X[selected_idxs])))

        if not shap_coo:
            # No rows left for this (or any later) outer batch.
            break

        # Concatenate sparse
        shap_coo = sp.concatenate(shap_coo)

        # Save batch
        sp.save_npz(here(f'03_downstream_analysis/08_gene_importance/results/targetY_{target_y}/shap/{RUN_NAME}{CELL_TYPE}_shap_int_{obatch_idx}.npz'), shap_coo)

        print(f"BATCH {obatch_idx} DONE")

In [13]:
# Debug aid: print the kernel's current working directory.
!pwd

/scratch_isilon/groups/singlecell/shared/projects/Inflammation-PBMCs-Atlas/03_downstream_analysis/08_gene_importance


In [14]:
# Make sure the output directory tree (.../shap/ and .../shap/shap_vals/) exists.
os.makedirs(here(f"03_downstream_analysis/08_gene_importance/results/targetY_{target_y}/shap/shap_vals/"), exist_ok=True)

if OUTPUT=="shap_vals_and_stats":
    compute_shap_values_with_stats(xgb, adata.X, target_y=target_y)
elif OUTPUT=="shap_int":
    # Interleave row indices across the 'sampleID' groups so that the first rows
    # processed draw from every sample in turn.
    patient_roundrobin = list(roundrobin(*adata.obs.groupby('sampleID').indices.values()))
    # Saved next to the interaction values, under the same `targetY_{target_y}`
    # directory that every other output uses (and that was created above).
    np.save(here(f'03_downstream_analysis/08_gene_importance/results/targetY_{target_y}/shap/{RUN_NAME}{CELL_TYPE}_patient_roundrobin.npy'), patient_roundrobin)
    compute_and_save_shap_interaction_values(xgb, adata.X, sorted_idxs=patient_roundrobin, target_y=target_y)
else:
    raise ValueError(f"Unknown OUTPUT mode {OUTPUT!r}; expected 'shap_vals_and_stats' or 'shap_int'")

  0%|          | 0/2 [00:00<?, ?it/s]

  0%|          | 0/400 [00:00<?, ?it/s]

  0%|          | 0/275 [00:00<?, ?it/s]