In [2]:
import numpy as np
import seaborn as sns
from pyprojroot.here import here
import re
import os
import pandas as pd
import matplotlib.pyplot as plt
import anndata as ad
from tqdm.auto import tqdm

In [8]:
# Cell-type labels (11 populations), grouped by lineage for readability.
CELL_TYPES = [
    "Mono",
    "T_CD4_Naive", "T_CD4_NonNaive",
    "T_CD8_Naive", "T_CD8_NonNaive",
    "B", "Plasma",
    "UTC", "ILC",
    "pDC", "DC",
]

In [2]:
def load_shap_statistics(
    cell_type: str = '',
    interactions: bool = False,
    run_name: str = 'run1',
    dir_path: str = '',
):
    """Load precomputed SHAP summary statistics for one cell type.

    Reads ``total_{run_name}_{cell_type}_{shap_type}_stats.npz`` inside
    ``dir_path`` (resolved through ``here``), where ``shap_type`` is
    ``'shap_int'`` for interaction values and ``'shap'`` otherwise.

    Args:
        cell_type: Cell-type label embedded in the file name.
        interactions: If True, load the interaction-value statistics file.
        run_name: Run identifier embedded in the file name.
        dir_path: Directory (relative to the project root) holding the file.
            An empty string means the project root itself.

    Returns:
        Tuple ``(mean_raw, var_raw, mean_abs, var_abs)`` holding the arrays
        stored under those keys in the archive.
    """
    shap_type = 'shap_int' if interactions else 'shap'

    # os.path.join (rather than f"{dir_path}/...") avoids building "/total_..."
    # when dir_path is empty, which `here` would treat as an absolute path.
    fname = here(os.path.join(dir_path, f"total_{run_name}_{cell_type}_{shap_type}_stats.npz"))

    # np.load on an .npz returns an NpzFile holding an open file handle; the
    # with-block closes it. Indexing reads each array fully into memory, so
    # the returned arrays stay valid after the archive is closed.
    with np.load(fname) as shap_stats:
        return (
            shap_stats['mean_raw'],
            shap_stats['var_raw'],
            shap_stats['mean_abs'],
            shap_stats['var_abs'])

In [6]:
def load_sorted_shap_values_fnames(
    cell_type: str = '',
    run_name: str = 'run1',
    dir_path: str = '03_downstreamAnalyses/05_SHAP/results/04_shap/shap_vals'):
    """Return SHAP-value batch files for one run/cell type, ordered by batch number.

    A file qualifies when its name starts with
    ``{run_name}_{cell_type}_shap_values_<digits>``. The digits are parsed as
    the batch number used for ordering.

    Args:
        cell_type: Cell-type label embedded in the file names.
        run_name: Run identifier embedded in the file names.
        dir_path: Directory (relative to the project root, resolved via
            ``here``) that holds the batch files.

    Returns:
        List of full paths sorted by ascending integer batch number.
    """
    dirpath = here(dir_path)

    # Escape the interpolated parts so names containing regex metacharacters
    # (e.g. '.') only match literally.
    fname_pattern = re.compile(
        rf'^{re.escape(run_name)}_{re.escape(cell_type)}_shap_values_(\d+)')

    results_batches = []
    for fname in os.listdir(dirpath):
        match = fname_pattern.search(fname)
        if match:
            results_batches.append((fname, int(match.group(1))))

    # Numeric sort so batch 10 comes after batch 2 (lexical order would not).
    sorted_files = sorted(results_batches, key=lambda x: x[1])

    return [os.path.join(dirpath, filename) for filename, _ in sorted_files]

In [152]:
def define_cell_subset(cell_type, frac=0.02, max_cells=60000):
    """Pick a per-sample stratified subset of cell positions capped at ``max_cells``.

    Cells are drawn within each ``sampleID`` group, starting from a 100%
    sample. The sampling fraction is lowered by 0.025 per round until the
    total number of drawn cells is at most ``max_cells``.

    Args:
        cell_type: Cell-type label used to build the ``.h5ad`` file name.
        frac: Ignored. The search always starts from a fraction of 1.0; the
            parameter is kept only so existing calls keep working.
        max_cells: Upper bound on the number of returned cells (must be >= 0).

    Returns:
        Sorted numpy array of 0-based row positions into ``adata.obs``.

    Raises:
        ValueError: If ``max_cells`` is negative (the search could never stop).
    """
    if max_cells < 0:
        raise ValueError(f'max_cells must be >= 0, got {max_cells}')

    adata = ad.read_h5ad(here(f'03_downstreamAnalyses/05_SHAP/data/{cell_type}_adataMerged_SPECTRAgenes.log1p.h5ad'), backed='r')
    try:
        # reset_index() replaces the obs names with a RangeIndex, so the sampled
        # `.index` values below are row positions rather than cell names.
        # Done once here instead of on every loop iteration.
        obs = adata.obs.reset_index()
    finally:
        # Backed mode keeps the file open; only obs is needed from here on.
        adata.file.close()

    fraction = 1.0
    while True:
        cells_subset = (
            obs
            .groupby('sampleID', observed=True)
            .sample(frac=fraction, random_state=42)
            .index)
        if cells_subset.shape[0] <= max_cells:
            break
        # Clamp at 0: a fraction of 0 draws no cells, which always satisfies
        # max_cells >= 0, so the loop is guaranteed to terminate.
        fraction = max(round(fraction - 0.025, 3), 0.0)

    return np.sort(cells_subset)

In [161]:
def extract_shap_cell_subset(cell_type, run_name='run1'):
    """Gather the SHAP rows for a saved cell subset and write them to an .npz.

    Reads the subset positions from ``{cell_type}_idxs.npy`` and walks the
    batch files from ``load_sorted_shap_values_fnames`` in batch order. Batch
    rows are treated as consecutive global positions, and the rows whose
    positions are in the subset are kept. The result is saved as
    ``{cell_type}_shap_subset.npz`` under the key ``shap_values``.

    Args:
        cell_type: Cell-type label used in the input/output file names.
        run_name: Run identifier forwarded to ``load_sorted_shap_values_fnames``.

    Returns:
        None. The result is written to disk.

    Raises:
        ValueError: If no SHAP batch files are found.
        AssertionError: If the number of extracted rows differs from the subset
            size (e.g. out-of-range or duplicate subset positions).
    """
    cells_subset = np.load(here(f'03_downstreamAnalyses/05_SHAP/results/04_shap/shap_stripplot/{cell_type}_idxs.npy'))
    fnames = load_sorted_shap_values_fnames(cell_type, run_name)
    if not fnames:
        raise ValueError(
            f'No SHAP value batch files found for run {run_name!r}, cell type {cell_type!r}')

    start_idx = 0
    output_shap = []
    for fname in fnames:
        # The with-block closes each NpzFile so handles don't pile up across batches.
        with np.load(fname) as batch_file:
            batch_of_shap = batch_file['shap_values']
        n_samples = batch_of_shap.shape[0]
        # Global positions covered by this batch that are also in the subset.
        # intersect1d returns sorted unique values, so rows come out in
        # ascending position order.
        idxs = np.intersect1d(np.arange(start_idx, start_idx + n_samples), cells_subset)
        output_shap.append(batch_of_shap[idxs - start_idx])
        start_idx += n_samples

    output_shap = np.concatenate(output_shap)
    assert output_shap.shape[0] == cells_subset.shape[0], (
        f'extracted {output_shap.shape[0]} rows but subset has {cells_subset.shape[0]} positions')
    np.savez(
        here(f'03_downstreamAnalyses/05_SHAP/results/04_shap/shap_stripplot/{cell_type}_shap_subset'),
        shap_values=output_shap)
    return