In [3]:
import os

import decoupler as dc
import numba
import pandas as pd
import scanpy as sc

numba.set_num_threads(8)

In [99]:
# adata_file = "data/scPerturbData/TianKampmann2021_CRISPRi.h5ad"
adata_file = "data/scPerturbData/AdamsonWeissman2016_GSM2406681_10X010.h5ad"
adata = sc.read_h5ad(adata_file)
adata

AnnData object with n_obs × n_vars = 65337 × 32738
    obs: 'perturbation', 'read count', 'UMI count', 'tissue_type', 'cell_line', 'cancer', 'disease', 'perturbation_type', 'celltype', 'organism', 'ncounts', 'ngenes', 'percent_mito', 'percent_ribo', 'nperts'
    var: 'ensembl_id', 'ncounts', 'ncells'

In [100]:
adata.obs

Unnamed: 0_level_0,perturbation,read count,UMI count,tissue_type,cell_line,cancer,disease,perturbation_type,celltype,organism,ncounts,ngenes,percent_mito,percent_ribo,nperts
cell_barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1
AAACATACAAGATG,63(mod)_pBA580,282.0,8.0,cell_line,K562,True,chronic myelogenous leukemia,CRISPR,lymphoblasts,human,8866.0,2914,4.917663,21.306112,2
AAACATACACCTAG,OST4_pDS353,331.0,7.0,cell_line,K562,True,chronic myelogenous leukemia,CRISPR,lymphoblasts,human,13785.0,3818,4.468626,19.492201,2
AAACATACTTCCCG,SEC61A1_pDS031,285.0,10.0,cell_line,K562,True,chronic myelogenous leukemia,CRISPR,lymphoblasts,human,7569.0,2616,5.060113,23.199894,2
AAACATTGAAACAG,EIF2B4_pDS491,1036.0,30.0,cell_line,K562,True,chronic myelogenous leukemia,CRISPR,lymphoblasts,human,13834.0,3488,5.052769,28.733555,2
AAACATTGCAGCTA,SRPR_pDS482,863.0,25.0,cell_line,K562,True,chronic myelogenous leukemia,CRISPR,lymphoblasts,human,15507.0,3620,4.514091,26.729864,2
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
TTTGCATGCTTTAC,STT3A_pDS011,476.0,17.0,cell_line,K562,True,chronic myelogenous leukemia,CRISPR,lymphoblasts,human,14524.0,3356,5.996971,22.679703,2
TTTGCATGGAGGAC,ARHGAP22_pDS458,539.0,19.0,cell_line,K562,True,chronic myelogenous leukemia,CRISPR,lymphoblasts,human,11685.0,2961,4.612751,26.983313,2
TTTGCATGTAGAGA,63(mod)_pBA580,647.0,35.0,cell_line,K562,True,chronic myelogenous leukemia,CRISPR,lymphoblasts,human,16610.0,3473,7.242625,26.207104,2
TTTGCATGTCAAGC,KCTD16_pDS096,98.0,4.0,cell_line,K562,True,chronic myelogenous leukemia,CRISPR,lymphoblasts,human,14473.0,3431,7.296345,20.755890,2


In [101]:
print(adata.var)

                   ensembl_id  ncounts  ncells
gene_symbol                                   
MIR1302-10    ENSG00000243485     11.0      11
FAM138A       ENSG00000237613      0.0       0
OR4F5         ENSG00000186092      0.0       0
RP11-34P13.7  ENSG00000238009      0.0       0
RP11-34P13.8  ENSG00000239945     43.0      43
...                       ...      ...     ...
AC145205.1    ENSG00000215635      0.0       0
BAGE5         ENSG00000268590      0.0       0
CU459201.1    ENSG00000251180      0.0       0
AC002321.2    ENSG00000215616      0.0       0
AC002321.1    ENSG00000215611      0.0       0

[32738 rows x 3 columns]


In [102]:
adata.obs.describe()

Unnamed: 0,read count,UMI count,ncounts,ngenes,percent_mito,percent_ribo,nperts
count,62724.0,62724.0,65337.0,65337.0,65337.0,65337.0,65337.0
mean,638.174734,24.956954,15915.297852,3639.789813,5.327863,24.905811,1.918469
std,567.115265,22.359175,6349.257812,803.38918,2.475392,3.947483,0.393537
min,1.0,1.0,3179.0,462.0,0.020105,2.047059,0.0
25%,305.0,12.0,11669.0,3148.0,4.290833,22.444645,2.0
50%,541.0,21.0,15355.0,3690.0,5.203821,24.729574,2.0
75%,858.0,32.0,19572.0,4192.0,6.187406,27.206667,2.0
max,46452.0,1557.0,67075.0,6859.0,82.174286,56.376625,2.0


In [10]:
# identify controls (heuristic: explicit "control" in perturbation or guide names or nperts==0)
ctrl_mask = (adata.obs['perturbation'].str.contains('control', case=False, na=False)) | \
            (adata.obs['guide_id'].str.contains('non-target', case=False, na=False)) | \
            (adata.obs['nperts'] == 0)
adata.obs['is_control'] = ctrl_mask

# create a per-cell label: target gene name or "control"
adata.obs['ptb_label'] = adata.obs['perturbation'].astype(str).replace('nan', 'unknown')
adata.obs.loc[adata.obs['is_control'], 'ptb_label'] = 'control'

# Option: simplify labels to only TFs you care about
# tf_list = [...]  # your TF list
# adata.obs['ptb_label'] = adata.obs['ptb_label'].where(adata.obs['ptb_label'].isin(tf_list), other='other')

# Save labels for evaluation
labels = adata.obs['ptb_label']


In [44]:
print("labels.shape: ", labels.shape)
labels

labels.shape:  (32300,)


AAACCCAAGGGTTGCA    CREBBP
AAACCCAAGGTAGCCA    SH3RF1
AAACCCACAAACTAGA      TAF1
AAACCCACAGAATTCC     UBTD2
AAACCCAGTAGCGTTT    FRMD4A
                     ...  
TTTGTTGCAGGTCTCG    FAM57B
TTTGTTGCAGTTGCGC     XRCC1
TTTGTTGGTCCCGGTA     RPL14
TTTGTTGGTTGATGTC     PMPCA
TTTGTTGTCTTACTGT       HTT
Name: ptb_label, Length: 32300, dtype: object

### Perturbation data filtering, cleanup for faster processing

In [None]:
adata_files = [
    "data/scPerturbData/TianKampmann2021_CRISPRi.h5ad",  # (32300,  33538)   After Filtering   (31407,  21903)
    "data/scPerturbData/TianKampmann2021_CRISPRa.h5ad",  # (21193,  33538)   After Filtering   (19758,  20404)
    "data/scPerturbData/ReplogleWeissman2022_rpe1.h5ad",  # (247914, 8749 )   After Filtering   (247914, 8749 )
    "data/scPerturbData/NormanWeissman2019_filtered.h5ad",  # (111445, 33694)   After Filtering   (111391, 20265)
    "data/scPerturbData/FrangiehIzar2021_RNA.h5ad",  # (218331, 23712)   After Filtering   (218027, 23593)
    ### "data/scPerturbData/ReplogleWeissman2022_K562_essential.h5ad",  # (310385, 8563) Too large to process here
    "data/scPerturbData/ShifrutMarson2018_kt_filtered.h5ad",  # () After Filtering (52236, 33694)
    "data/scPerturbData/AdamsonWeissman2016_GSM2406675_10X001.h5ad",  # (5768, 35635) After Filtering (3999, 12487)
    "data/scPerturbData/AdamsonWeissman2016_GSM2406681_10X010.h5ad",  # (65337, 32738) After Filtering (59171, 18276)
]

for adata_file in adata_files:
    print("Processing file: ", adata_file)
    adata = sc.read_h5ad(adata_file)

    print("Original Shape: ", adata.shape)
    n_cells, n_genes = adata.shape

    # Filter low-quality genes and cells
    sc.pp.filter_cells(adata, min_genes=500)
    sc.pp.filter_genes(adata, min_cells=10)
    print("Shape after filtering: ", adata.shape)

    # Mitochondrial/QC filter
    adata.var['mt'] = adata.var_names.str.startswith('MT-')  # assuming human data; use 'mt-' for mouse
    sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)
    adata = adata[adata.obs.pct_counts_mt < 20, :]  # filter cells with >20% mitochondrial genes
    print("Final Shape ", adata.shape)

    # Normalize and log-transform
    sc.pp.normalize_total(adata, target_sum=1e4)
    sc.pp.log1p(adata)
    print("After Normalization and Log1p: ", adata.shape)

    # Export adata after filtering
    output_filtered_file = adata_file.replace('.h5ad', '_kt_filtered.h5ad')
    adata.write_h5ad(output_filtered_file)
    print(f"Filtered data saved to {output_filtered_file}")

    print("*" * 50 + "\n")


Processing file:  data/scPerturbData/ReplogleWeissman2022_K562_essential.h5ad
Original Shape:  (310385, 8563)


## Run KALE on multiple datasets

In [4]:
# Datasets which has MYC perturbations
adata_files = [
    "data/scPerturbData/FrangiehIzar2021_RNA.h5ad",
    "data/scPerturbData/ReplogleWeissman2022_K562_essential.h5ad",
    "data/scPerturbData/ReplogleWeissman2022_rpe1.h5ad",
]

for adata_file in adata_files:
    if not os.path.exists(adata_file):
        print(f"\nFile {adata_file} does not exist. Skipping...\n")
        continue

    print("Running KALE for test case: ", adata_file)

    file_name = adata_file.split('/')[-1].replace('.h5ad', '')

    # run script
    print("Running KALE for test case without weights without ignore_zeros")
    !uv run src/kale.py --gene_exp_file {adata_file} --prior_file data/causal_priors.tsv --output_file data/_kale_scores_unweighted_{file_name}.tsv --pvalue_output_file data/_kale_pvalues_unweighted_{file_name}.tsv --ignore_zeros False --cores 8 --method rank_of_ranks --min_targets 0 --weighted False

    print("Running KALE for test case without weights ignoring zeros")
    !uv run src/kale.py --gene_exp_file {adata_file} --prior_file data/causal_priors.tsv --output_file data/_kale_scores_unweighted_ignore_zeros_{file_name}.tsv --pvalue_output_file data/_kale_pvalues_unweighted_ignore_zeros_{file_name}.tsv --ignore_zeros True --cores 8 --method rank_of_ranks --min_targets 1 --weighted False


    print("Running KALE for test case with weights without ignore_zeros")
    !uv run src/kale.py --gene_exp_file {adata_file} --prior_file data/causal_priors.tsv --output_file data/_kale_scores_weighted_{file_name}.tsv --pvalue_output_file data/_kale_pvalues_weighted_{file_name}.tsv --ignore_zeros False --cores 8 --method rank_of_ranks --min_targets 1 --weighted True --weighted_power_factor 1


    print("Running KALE for test case with weights ignoring zeros")
    !uv run src/kale.py --gene_exp_file {adata_file} --prior_file data/causal_priors.tsv --output_file data/_kale_scores_weighted_ignore_zeros_{file_name}.tsv --pvalue_output_file data/_kale_pvalues_weighted_ignore_zeros_{file_name}.tsv --ignore_zeros True --cores 8 --method rank_of_ranks --min_targets 1 --weighted True --weighted_power_factor 1

Running KALE for test case:  data/scPerturbData/AdamsonWeissman2016_GSM2406675_10X001_kt_filtered.h5ad
Running KALE for test case without weights without ignore_zeros
Preprocessing gene expression data...
Using raw gene expression as input for per-cell ranking...
Starting TF activity using 8 cores.
Running in parallel with CORES_USED=8.
Processing cells in parallel:   0%|                    | 0/3999 [00:00<?, ?it/s][Parallel(n_jobs=8)]: Using backend LokyBackend with 8 concurrent workers.
Processing cells in parallel:   1%|           | 32/3999 [00:01<01:53, 34.97it/s][Parallel(n_jobs=8)]: Done   34 out of 3999 | elapsed:    1.1s
Processing cells in parallel:  45%|███▌    | 1808/3999 [00:03<00:03, 650.34it/s][Parallel(n_jobs=8)]: Done 1680 out of 3999 | elapsed:    3.8s
Processing cells in parallel: 100%|████████| 3999/3999 [00:07<00:00, 556.32it/s]
[Parallel(n_jobs=8)]: Done 3999 out of 3999 | elapsed:    7.5s finished

Aggregating results...
kale completed
Kale TF activity

## Run Viper, MLM, ULM, 

In [5]:
# Datasets which has MYC perturbations
adata_files = [
    "data/scPerturbData/FrangiehIzar2021_RNA.h5ad",
    "data/scPerturbData/ReplogleWeissman2022_K562_essential.h5ad",
    "data/scPerturbData/ReplogleWeissman2022_rpe1.h5ad",
]

net_file = "data/causal_priors.tsv"
effect_map = {"upregulates-expression": 1, "downregulates-expression": -1}
net = pd.read_csv(
    net_file,
    sep="\t",
    names=["source", "weight", "target"],
    usecols=[0, 1, 2],
    converters={"weight": effect_map.get}
)[["source", "target", "weight"]]

# Run Decoupler Methods
# methods_to_run = ["viper", "mlm", "ulm"]
methods_to_run = ["viper"]

for adata_file in adata_files:
    print(f"Processing {adata_file}...")
    adata = sc.read_h5ad(adata_file)

    # Filter low-quality genes and cells
    sc.pp.filter_cells(adata, min_genes=500)
    sc.pp.filter_genes(adata, min_cells=10)
    print("Shape after filtering: ", adata.shape)

    # Mitochondrial/QC filter
    adata.var['mt'] = adata.var_names.str.startswith('MT-')  # assuming human data; use 'mt-' for mouse
    sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)
    adata = adata[adata.obs.pct_counts_mt < 20, :]  # filter cells with >20% mitochondrial genes
    print("Final Shape ", adata.shape)

    print("Normalizing and Log1p...")
    # 1. Normalize (Corrects sequencing depth)
    sc.pp.normalize_total(adata, target_sum=1e4)
    # 2. Log transform (Stabilizes variance / makes data Gaussian-like)
    sc.pp.log1p(adata)
    # 3. Scale (Optional for ULM/MLM, Highly Recommended for VIPER)
    sc.pp.scale(adata, max_value=10)

    dc.mt.decouple(adata, net, tmin=0, methods=methods_to_run)

    for method in methods_to_run:
        method_scores = adata.obsm[f"score_{method}"]
        method_scores_df = pd.DataFrame(
            method_scores,
            index=adata.obs_names,
            columns=[col for col in adata.var_names if col in net['source'].values]
        )
        method_scores_df.to_csv(f"data/_{method}_scores_{adata_file.split('/')[-1].replace('.h5ad', '')}.tsv", sep="\t")
        print(f"Saved {method} scores to data/_{method}_scores_{adata_file.split('/')[-1].replace('.h5ad', '')}.tsv")

Processing data/scPerturbData/AdamsonWeissman2016_GSM2406675_10X001_kt_filtered.h5ad...
Normalizing and Log1p...


  return dispatch(args[0].__class__)(*args, **kw)


Saved viper scores to data/_viper_scores_AdamsonWeissman2016_GSM2406675_10X001_kt_filtered.tsv


In [6]:
viper_scores = pd.read_csv("data/_viper_scores_AdamsonWeissman2016_GSM2406675_10X001_kt_filtered.tsv", sep="\t",
                           index_col=0)
# viper_scores = viper_scores.dropna(axis=1, how='all')
viper_scores = viper_scores.reindex(sorted(viper_scores.columns), axis=1)
viper_scores.head()

Unnamed: 0_level_0,A1BG,AATF,ABCA3,ABL1,ACACA,ACADSB,ACAT2,ACLY,ACTL6A,ADCY7,...,ZNF385A,ZNF420,ZNF444,ZNF467,ZNF521,ZNF638,ZNF76,ZNFX1,ZNRD1,ZYX
cell_barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACATACACCGAT,,-1.617859,-0.744422,-0.328945,-0.550234,0.078765,-0.550234,-1.178748,0.909472,-0.0498,...,0.673482,,,,,,-1.098185,-0.00562,0.715378,-0.628818
AAACATACAGAGAT,,1.312137,0.631046,-0.338817,-0.538363,0.089844,-0.538363,-1.153887,-2.116118,-0.038749,...,0.691471,,,,,,-1.102796,-0.014653,0.734434,-0.619064
AAACATACGTTGAC,,-0.679692,0.804137,1.475622,-0.635682,0.009033,-0.635682,-1.320915,-0.596403,-0.121333,...,0.593733,,,,,,-1.115631,0.065684,0.633717,1.847311
AAACCGTGCAGCTA,,2.461895,-0.817146,-0.616701,-0.585374,0.050001,-0.585374,0.662877,0.122142,-0.078161,...,0.639863,,,,,,-1.105961,0.023487,0.681056,-0.667702
AAACCGTGGAACTC,,0.187228,-0.363359,0.206227,-0.370617,0.260234,-0.370617,-0.902204,2.456166,0.127604,...,0.920516,,,,,,-1.138586,-0.18389,0.971382,-0.449006


In [7]:
kale_scores = pd.read_csv("data/_kale_scores_unweighted_AdamsonWeissman2016_GSM2406675_10X001_kt_filtered.tsv",
                          sep="\t", index_col=0)
# kale_scores = kale_scores.dropna(axis=1, how='all')
kale_scores = kale_scores.reindex(sorted(kale_scores.columns), axis=1)
kale_scores.head()

Unnamed: 0_level_0,A2M,AATF,ABCA1,ABCA3,ABCB1,ABCG1,ABCG5,ABCG8,ABL1,ACACA,...,ZNF24,ZNF300,ZNF350,ZNF382,ZNF383,ZNF385A,ZNF76,ZNFX1,ZNRD1,ZYX
cell_barcode,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1,Unnamed: 4_level_1,Unnamed: 5_level_1,Unnamed: 6_level_1,Unnamed: 7_level_1,Unnamed: 8_level_1,Unnamed: 9_level_1,Unnamed: 10_level_1,Unnamed: 11_level_1,Unnamed: 12_level_1,Unnamed: 13_level_1,Unnamed: 14_level_1,Unnamed: 15_level_1,Unnamed: 16_level_1,Unnamed: 17_level_1,Unnamed: 18_level_1,Unnamed: 19_level_1,Unnamed: 20_level_1,Unnamed: 21_level_1
AAACATACACCGAT,-1.831116,-1.509827,0.034511,-0.811318,1.106163,-0.895033,-0.895033,-0.895033,-0.216546,-0.719286,...,-0.784743,-1.645247,-0.137264,1.743215,0.510729,0.940569,-1.724501,-0.026916,1.030724,-0.830515
AAACATACAGAGAT,1.887936,0.668283,2.119007,0.6768,-0.408522,1.210134,1.210134,1.210134,-0.137432,-0.686871,...,-0.782145,-1.291756,-0.137973,1.198376,0.486617,0.969748,-1.726453,-0.043292,1.060918,-0.799975
AAACATACGTTGAC,1.664323,0.473649,-0.002047,0.654262,-0.9435,-1.03738,-1.03738,-1.03738,0.959447,-0.811102,...,1.092437,0.462682,-0.13567,1.703406,0.596243,0.835394,-1.712455,0.03307,0.922208,2.131584
AAACCGTGCAGCTA,1.777362,3.766162,-2.117577,-0.877021,-0.85284,-0.942169,-0.942169,-0.942169,-0.285025,-0.742611,...,-0.780199,0.531486,-2.927052,-0.475886,0.529959,0.901276,-1.720245,-0.005994,0.989651,-0.859558
AAACCGTGGAACTC,-1.60769,0.269451,0.156968,-0.331397,-0.108344,-0.529982,-0.529982,-0.529982,0.175601,-0.465175,...,-0.810883,0.219799,-0.151256,0.982213,0.276617,1.29145,-1.794669,-0.230872,1.392857,-0.573505


In [8]:
intersec = set(kale_scores.columns).intersection(set(viper_scores.columns))
print(f"Number of common TFs between KALE and VIPER: {len(intersec)}")

Number of common TFs between KALE and VIPER: 1070


In [9]:
# remove rows if net["target"] not in adata.var_names
print("net shape before filtering: ", net.shape)
net = net[net['target'].isin(adata.var_names)]
print("net shape after filtering: ", net.shape)
# net_grouped_targets = filtered_net.groupby('source')['target'].apply(list)
# net_grouped_targets = net_grouped_targets.reset_index()
# net_grouped_targets.columns = ['source', 'targets']
# net_grouped_targets = net_grouped_targets.set_index('source')
# net_grouped_targets

net shape before filtering:  (11839, 3)
net shape after filtering:  (6943, 3)


In [10]:
net_grouped = net.groupby('source').size().sort_values(ascending=False)
net_grouped = net_grouped.reset_index()
net_grouped.columns = ['source', 'target_count']
net_grouped = net_grouped.set_index('source')
print("Kale and network intersection:", len(set(kale_scores.columns).intersection(set(net_grouped.index))))
print("Viper and network intersection:", len(set(viper_scores.columns).intersection(set(net_grouped.index))))

Kale and network intersection: 1572
Viper and network intersection: 1070


In [11]:
all_genes = set(adata.var.index.tolist())
print("Total genes in dataset:", len(all_genes))
print("Total TFs in network:", len(net_grouped))
print("Total TFs in KALE scores:", len(kale_scores.columns))
print("Total TFs in VIPER scores:", len(viper_scores.columns))
print("Intersection of total genes(in datasets) with network TFs: ",
      len(all_genes.intersection(set(net_grouped.index.tolist()))))
print("Intersection of total genes(in datasets) with with KALE TFs: ",
      len(all_genes.intersection(set(kale_scores.columns.tolist()))))
print("Intersection of total genes(in datasets) with with VIPER TFs: ",
      len(all_genes.intersection(set(viper_scores.columns.tolist()))))

Total genes in dataset: 12487
Total TFs in network: 1572
Total TFs in KALE scores: 1572
Total TFs in VIPER scores: 1291
Intersection of total genes(in datasets) with network TFs:  1070
Intersection of total genes(in datasets) with with KALE TFs:  1070
Intersection of total genes(in datasets) with with VIPER TFs:  1291


In [12]:
# group net by making targets list
net_grouped_targets = net.groupby('source')['target'].apply(list)
net_grouped_targets = net_grouped_targets.reset_index()
net_grouped_targets.columns = ['source', 'targets']
net_grouped_targets = net_grouped_targets.set_index('source')
net_grouped_targets.head()

Unnamed: 0_level_0,targets
source,Unnamed: 1_level_1
A2M,[STAT3]
AATF,"[BAX, CTNNB1, MYC]"
ABCA1,"[NR1H2, NR1H3, SREBF2]"
ABCA3,"[SREBF1, GATA6, NFATC3, CEBPA]"
ABCB1,"[SP1, RELB, RELA, TCF7L2, FOXO1, EGR1, CEBPB, ..."
