In [1]:
import scanpy as sc
import numpy as np
from tqdm import tqdm

import sys
sys.path.append('../')

import preprocessing_tools as pt

# Read adata

In [2]:
# Location of the filtered NormanWeissman2019 dataset, loaded into an AnnData.
data_path = '../../original_datasets/Norman/NormanWeissman2019_filtered.h5ad'
adata = sc.read_h5ad(filename=data_path)

In [3]:
# Inspect the freshly loaded AnnData: shape and available obs/var columns.
adata

AnnData object with n_obs × n_vars = 111445 × 33694
    obs: 'guide_id', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'tissue_type', 'cell_line', 'cancer', 'disease', 'perturbation_type', 'celltype', 'organism', 'perturbation', 'nperts', 'ngenes', 'ncounts', 'percent_mito', 'percent_ribo'
    var: 'ensemble_id', 'ncounts', 'ncells'

# Rename columns

In [4]:
# Harmonise obs column names with the names used in the rest of this notebook.
# NOTE(review): of these keys only 'cell_line' and 'perturbation' appear in the
# obs listing above; the others are no-ops here. Note that obs also already has
# a separate 'celltype' column.
adata.obs = adata.obs.rename(columns={
    'nCount_RNA': 'ncounts',
    'nFeature_RNA': 'ngenes',
    'percent.mt': 'percent_mito',
    'cell_line': 'cell_type',
    'perturbation': 'condition',
})

# Normalise condition labels: 'A_B' -> 'A+B' and 'control' -> 'ctrl'.
# The category conversion must come last: `.str.replace` returns an object
# column, so converting to 'category' before the second replace (as done
# previously) was silently undone.
adata.obs['condition'] = (
    adata.obs['condition']
    .str.replace('_', '+')
    .str.replace('control', 'ctrl')
    .astype('category')
)
# Keep an untouched copy of the harmonised labels.
adata.obs['cond_harm'] = adata.obs['condition'].copy()

# Unique perturbations

In [5]:
# Non-control conditions, then every individual gene perturbed in them:
# a double 'A+B' contributes both A and B; a single contributes itself.
unique_conditions = np.unique(adata.obs['condition']).tolist()
unique_conditions.remove('ctrl')

unique_single_perts = set()
for pert in unique_conditions:
    genes_in_pert = pert.split('+')
    if len(genes_in_pert) > 1:
        unique_single_perts.update(genes_in_pert[:2])
    else:
        unique_single_perts.add(pert)

# Remove perturbations whose target genes are missing from the gene names or cannot be added to the GEARS GO graph

In [6]:
# Perturbed genes that are not present in the expression matrix.
perts_absent_in_genes = [x for x in unique_single_perts if x not in adata.var_names]
absent_genes = set(perts_absent_in_genes)

# Drop every condition that targets an absent gene. Whole gene names (the
# '+'-separated components) are compared, not substrings: the previous
# `a_p in u_c` test would also drop e.g. 'ABC1+XYZ' when only 'ABC' is absent.
conditions_to_remove = {
    u_c for u_c in unique_conditions
    if absent_genes.intersection(u_c.split('+'))
}

# Hard-coded conditions that cannot be added to the GEARS GO graph.
# NOTE(review): combos are matched by exact label, so only the orderings listed
# here are removed; confirm no other combo containing these genes remains.
GO_absent = {'RHOXF2BB', 'LYL1+IER5L', 'IER5L', 'KIAA1804', 'RHOXF2BB+ZBTB25', 'RHOXF2BB+SET'}
conditions_to_remove.update(GO_absent)

In [7]:
# Number of non-control conditions before removing unusable perturbations.
len(unique_conditions)

236

In [8]:
# Keep only cells whose condition is usable. `.copy()` materialises the subset
# instead of keeping an AnnData view, so the in-place scanpy filters below do
# not hit anndata's implicit view-to-copy conversion (ImplicitModificationWarning).
adata = adata[~adata.obs['condition'].isin(conditions_to_remove)].copy()

# Updated conditions

In [9]:
# Recompute the non-control conditions after filtering and partition them
# into single-gene and double ('geneA+geneB') perturbations, keeping order.
unique_conditions = np.unique(adata.obs['condition']).tolist()
unique_conditions.remove('ctrl')  # raises ValueError if no control cells remain

single_perturbations, double_perturbations = [], []
for cond_label in unique_conditions:
    target = double_perturbations if '+' in cond_label else single_perturbations
    target.append(cond_label)

print(f'Number of single perturbations {len(single_perturbations)}. Number of double perts: {len(double_perturbations)}')

Number of single perturbations 101. Number of double perts: 128


In [10]:
# Number of non-control conditions (singles + doubles) after removal.
len(unique_conditions)

229

# Standard filtering

In [11]:
# Basic QC, all in place on adata: drop cells with fewer than 100 total counts,
# drop genes with fewer than 5 total counts, then add scanpy QC metrics to
# obs/var.
sc.pp.filter_cells(adata, min_counts=100)
sc.pp.filter_genes(adata, min_counts=5)
sc.pp.calculate_qc_metrics(adata, inplace=True)

  adata.obs["n_counts"] = number


# Save counts before normalization

In [12]:
# Range of X before normalisation (expected to be raw counts).
adata.X.max(), adata.X.min()

(3718.0, 0.0)

In [13]:
# Keep an unnormalised copy of X; it is restored into X later, before saving.
adata.layers["counts"] = adata.X.copy()

In [14]:
# Confirm X itself is untouched after copying it into layers['counts'].
adata.X.max(), adata.X.min()

(3718.0, 0.0)

# Normalization

In [15]:
# Scale each cell to a total of 1e4 (size factors computed without very highly
# expressed genes), then log1p-transform X in place.
sc.pp.normalize_total(adata, target_sum=1e4, exclude_highly_expressed=True)
sc.pp.log1p(adata)

In [16]:
# Median before HVG selection
adata_subset_full = adata[adata.obs['condition'].isin(single_perturbations + ['ctrl'])]
_sums_full = adata_subset_full.X.toarray().sum(axis=1, keepdims=True)
data_median_full = np.median(_sums_full)

adata.uns['single_perts_median_full'] = data_median_full

In [17]:
# Display the stored pre-HVG median.
adata.uns['single_perts_median_full']

3191.5535

# Compute top 5K HVG

In [18]:
# Flag the top 5000 highly variable genes in adata.var['highly_variable'].
# subset=False keeps every gene; the flag is used later to choose genes to keep.
sc.pp.highly_variable_genes(
    adata,
    n_top_genes=5000,
    subset=False
    )

In [19]:
# Inspect the object after QC, normalisation and HVG flagging.
adata

AnnData object with n_obs × n_vars = 108619 × 21591
    obs: 'guide_id', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'tissue_type', 'cell_type', 'cancer', 'disease', 'perturbation_type', 'celltype', 'organism', 'condition', 'nperts', 'ngenes', 'ncounts', 'percent_mito', 'percent_ribo', 'cond_harm', 'n_counts', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes'
    var: 'ensemble_id', 'ncounts', 'ncells', 'n_counts', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'log1p', 'single_perts_median_full', 'hvg'
    layers: 'counts'

# Compute DEGs

## Ctrl versus pert

In [20]:
import gc

In [21]:
# DEGs for every non-control condition (single or double) against control
# cells. pt.compute_degs returns results keyed by cell type; presumably a ranked
# gene list per key (TODO confirm in preprocessing_tools). Only the 'K562'
# entry is kept.
rank_genes_groups = {}
for condition_label in tqdm(unique_conditions):
    degs_by_cell_type = pt.compute_degs(
        adata,
        cov_key='cell_type',
        cond_key='condition',
        stim_name=condition_label,
        control_name='ctrl',
        condition_names=['ctrl', condition_label],
        method='wilcoxon',
    )
    rank_genes_groups[condition_label] = degs_by_cell_type['K562']
    gc.collect()  # release per-iteration intermediates in this long loop

100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 229/229 [2:05:22<00:00, 32.85s/it]


# ['pertA+B'] versus ['pertA', 'pertB']

In [22]:
# For each combo 'A+B', run pt.compute_degs with synergy=True on the cells of
# the combo and of both single perturbations A and B. Presumably this ranks the
# combo against its singles (TODO confirm in preprocessing_tools). Only the
# 'K562' entry of the result is kept.
rank_genes_groups_double_vs_single = {}
print('Double perts DEGs')
for combo in tqdm(double_perturbations):
    gene_a, gene_b = combo.split('+')
    combo_and_singles = adata.obs['condition'].isin([combo, gene_a, gene_b])
    adata_subset = adata[combo_and_singles]

    degs_combo = pt.compute_degs(
        adata_subset,
        cov_key='cell_type',
        cond_key='condition',
        stim_name=combo,
        control_name='',
        synergy=True,
        condition_names=[combo, gene_a, gene_b],
        method='wilcoxon',
    )
    rank_genes_groups_double_vs_single[combo] = degs_combo['K562']
    gc.collect()

Double perts DEGs


100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 128/128 [04:31<00:00,  2.12s/it]


# Select genes to keep

In [23]:
# Union of the first N_TOP_DEGS genes of every DEG list, for both pert-vs-ctrl
# and combo-vs-singles comparisons. Assumes each list is ordered by rank.
N_TOP_DEGS = 20
deg_genes = set()
for ranked_genes in [*rank_genes_groups.values(), *rank_genes_groups_double_vs_single.values()]:
    deg_genes.update(ranked_genes[:N_TOP_DEGS])

In [24]:
# Number of selected DEGs and of single perturbations.
len(deg_genes), len(single_perturbations)

(892, 101)

In [25]:
# Genes to keep: HVGs, the perturbed genes of single perturbations, and top DEGs.
# The result follows adata.var_names order. Iterating a set of str depends on
# the per-process hash seed, so the previous `list(set(...))` produced a
# different gene (var) order, and thus a different saved file, on each kernel run.
highly_var_genes = adata.var_names[adata.var["highly_variable"]].tolist()
_keep = set(single_perturbations).union(highly_var_genes).union(deg_genes)
_missing = _keep.difference(adata.var_names)
# Same failure the later `adata[:, genes_to_keep]` would hit, but with a clear message.
assert not _missing, f'genes to keep missing from adata.var_names: {sorted(_missing)}'
genes_to_keep = [g for g in adata.var_names if g in _keep]

In [26]:
# Total number of genes retained.
len(genes_to_keep)

5446

# Filter: keep highly variable genes, perturbed genes, and the top 20 DEGs for each perturbation

In [27]:
# Restrict to the selected genes. `.copy()` materialises the subset so the
# following uns/layer assignments do not trigger anndata's implicit
# view-to-copy conversion (ImplicitModificationWarning).
adata = adata[:, genes_to_keep].copy()

In [28]:
# Store both DEG rankings (pert vs ctrl, and combo vs its single perts) in uns
# so they are written with the h5ad.
adata.uns['rank_genes_groups'] = rank_genes_groups
adata.uns['rank_genes_groups_combo_specific'] = rank_genes_groups_double_vs_single

  adata.uns['rank_genes_groups'] = rank_genes_groups


# Subset DEGs

In [29]:
# Drop genes from every stored DEG list that were removed by the gene subset,
# preserving each list's order. The stored dicts are updated in place.
kept_genes = set(adata.var_names)
for uns_key in ('rank_genes_groups', 'rank_genes_groups_combo_specific'):
    ranked_by_pert = adata.uns[uns_key]
    for pert_key, ranked in ranked_by_pert.items():
        ranked_by_pert[pert_key] = [g for g in ranked if g in kept_genes]

In [30]:
# Keep the normalised expression before X is reset to counts.
# NOTE: despite the name, these values are normalize_total(1e4) + log1p.
adata.layers['raw_norm_1e4'] = adata.X.copy()

# Set X to counts

In [31]:
# Put the unnormalised counts back into X; the saved file stores counts in X.
adata.X = adata.layers['counts'].copy()

In [32]:
# Range of the log1p-normalised layer.
adata.layers['raw_norm_1e4'].max(), adata.layers['raw_norm_1e4'].min()

(9.3838415, 0.0)

In [33]:
# Range of X after restoring counts; expected to match the pre-normalisation range.
adata.X.max(), adata.X.min()

(3718.0, 0.0)

# Save the median per-cell total counts of single-perturbation and control cells

In [34]:
# Number of single perturbations used for the median below.
len(single_perturbations)

101

In [35]:
# Median per-cell total of X (counts at this point) over single-perturbation
# and control cells, on the retained genes only.
# Summing the (sparse) matrix directly avoids densifying it via `.toarray()`;
# np.asarray + reshape also accepts a dense X and keeps the (n_cells, 1) shape.
adata_subset = adata[adata.obs['condition'].isin(single_perturbations + ['ctrl'])]
_sums = np.asarray(adata_subset.X.sum(axis=1)).reshape(-1, 1)
data_median = np.median(_sums)

adata.uns['single_perts_median'] = data_median

# Add sc cell ids

In [36]:
# Integer id per cell: 0..n_obs-1 in the current cell order.
adata.obs['sc_cell_ids'] = np.arange(adata.n_obs, dtype=np.int64)

In [37]:
#adata.write_h5ad(f'../../preprocessed_datasets/Norman.h5ad')

# Split the condition column into perturbation1 and perturbation2 (pertA, pertB) for scDisentangle

In [38]:
def align_perturbation_codes(adata, pert1_col='perturbation1', pert2_col='perturbation2'):
    """Give two ``.obs`` columns one shared, sorted category list.

    After the call, a label has the same integer code (``.cat.codes``) in both
    columns, so codes are directly comparable across the two columns.

    Parameters
    ----------
    adata : AnnData
        Object whose ``.obs`` holds both columns; modified in place.
    pert1_col, pert2_col : str
        Names of the two ``.obs`` columns to align.

    Returns
    -------
    AnnData
        The same ``adata`` object, returned for convenience.
    """
    # Missing values cannot be categories, and mixing them with strings makes
    # `sorted` raise TypeError, so only non-missing labels are collected.
    all_categories = sorted(
        set(adata.obs[pert1_col].dropna().unique())
        | set(adata.obs[pert2_col].dropna().unique())
    )

    # Same category list (same order) for both columns => identical codes.
    for col in (pert1_col, pert2_col):
        adata.obs[col] = adata.obs[col].astype('category').cat.set_categories(all_categories)

    return adata

In [39]:
def _split_condition(label):
    """Return (first, second) perturbed gene; singles and 'ctrl' get 'NOPERT' second."""
    parts = label.split('+')
    if len(parts) > 1:
        return parts[0], parts[1]
    return label, 'NOPERT'

perts = adata.obs['condition'].tolist()
pert1, pert2 = [], []
for first, second in map(_split_condition, perts):
    pert1.append(first)
    pert2.append(second)

adata.obs['perturbation1'] = pert1
adata.obs['perturbation2'] = pert2

# Share one category list between the two columns so their codes match.
adata = align_perturbation_codes(adata)

# Save adata

In [40]:
# Persist the preprocessed dataset (counts in X, log-normalised values in layers).
adata.write_h5ad(filename='../../preprocessed_datasets/norman.h5ad')

# Write the shuffled double perturbations to a text file

In [41]:
# Rebuild the list of double perturbations from unique_conditions (np.unique
# output, so sorted). The next cell shuffles this list in place.
double_perturbations = [x for x in unique_conditions if '+' in x]

In [43]:
import random

In [44]:
# Shuffle the double perturbations with a fixed seed and write one label per
# line. Sorting first makes the cell idempotent: re-running it no longer
# reshuffles an already shuffled list. The list built from np.unique output is
# already sorted, so the first-run order is unchanged.
rng = random.Random(42)

double_perturbations.sort()
rng.shuffle(double_perturbations)

with open('../../preprocessed_datasets/norman_double_perts.txt', 'w') as f:
    f.writelines(f'{d_p}\n' for d_p in double_perturbations)

In [45]:
# Number of double perturbations written to file.
len(double_perturbations)

128

In [46]:
# Read the saved double-perturbation labels back, one per line, without newlines.
with open('../../preprocessed_datasets/norman_double_perts.txt', 'r') as f:
    double_perts = [line.replace('\n', '') for line in f]

n_perts = len(double_perts)
print('Number of double perts', n_perts)

Number of double perts 128


# Sanity check

In [47]:
# Sanity check: both columns share one category list, so a gene must have the
# same integer code in perturbation1 and perturbation2.
# Prints "<gene> <code in perturbation1> <code in perturbation2, or None>".
codes1 = adata.obs['perturbation1'].values.codes
codes2 = adata.obs['perturbation2'].values.codes

# label -> code lookups built in one pass, instead of a full boolean mask per label.
code_of_pert1 = dict(zip(adata.obs['perturbation1'], codes1))
code_of_pert2 = dict(zip(adata.obs['perturbation2'], codes2))

unique_perts = np.unique(adata.obs['perturbation1'])
for p in unique_perts:
    c1 = code_of_pert1[p]
    c2 = code_of_pert2.get(p)  # None when p never appears as perturbation2

    print(p, c1, c2)
    if c2 is not None:
        assert c1 == c2, f'code mismatch for {p}: {c1} != {c2}'

AHR 0 None
ARID1A 1 None
ARRDC3 2 None
ATL1 3 None
BAK1 4 4
BCL2L11 5 None
BCORL1 6 None
BPGM 7 None
CBFA2T3 8 8
CBL 9 None
CDKN1A 10 10
CDKN1B 11 11
CDKN1C 12 None
CEBPA 13 13
CEBPB 14 14
CEBPE 15 15
CELF2 16 None
CITED1 17 None
CKS1B 18 None
CLDN6 19 19
CNN1 20 20
CNNM4 21 None
COL1A1 22 None
COL2A1 23 23
CSRNP1 24 None
DLX2 25 25
DUSP9 26 None
EGR1 27 None
ELMSAN1 28 28
ETS2 29 29
FEV 30 30
FOSB 31 None
FOXA1 32 32
FOXA3 33 None
FOXF1 34 34
FOXL2 35 35
FOXO4 36 None
GLB1L2 37 None
HES7 38 None
HK2 39 None
HNF4A 40 None
HOXA13 41 None
HOXB9 42 42
HOXC13 43 43
IGDCC3 44 44
IKZF3 45 45
IRF1 46 None
ISL2 47 47
JUN 48 None
KIF18B 49 None
KIF2C 50 50
KLF1 51 51
KMT2A 52 None
LHX1 53 None
LYL1 54 None
MAML2 55 None
MAP2K3 56 None
MAP2K6 57 57
MAP4K3 58 None
MAP4K5 59 None
MAP7D1 60 60
MAPK1 61 61
MEIS1 62 62
MIDN 63 None
NCL 64 None
NIT1 65 None
OSR2 67 67
PLK4 68 None
POU3F2 69 None
PRDM1 70 None
PRTG 71 71
PTPN1 72 None
PTPN12 73 73
PTPN13 74 None
PTPN9 75 75
RHOXF2 76 None
RREB1 77 None

In [48]:
# Final check of the range of X (counts) in the saved object.
adata.X.max(), adata.X.min()

(3718.0, 0.0)

In [49]:
# Final AnnData summary.
adata

AnnData object with n_obs × n_vars = 108619 × 5446
    obs: 'guide_id', 'read_count', 'UMI_count', 'coverage', 'gemgroup', 'good_coverage', 'number_of_cells', 'tissue_type', 'cell_type', 'cancer', 'disease', 'perturbation_type', 'celltype', 'organism', 'condition', 'nperts', 'ngenes', 'ncounts', 'percent_mito', 'percent_ribo', 'cond_harm', 'n_counts', 'n_genes_by_counts', 'log1p_n_genes_by_counts', 'total_counts', 'log1p_total_counts', 'pct_counts_in_top_50_genes', 'pct_counts_in_top_100_genes', 'pct_counts_in_top_200_genes', 'pct_counts_in_top_500_genes', 'sc_cell_ids', 'perturbation1', 'perturbation2'
    var: 'ensemble_id', 'ncounts', 'ncells', 'n_counts', 'n_cells_by_counts', 'mean_counts', 'log1p_mean_counts', 'pct_dropout_by_counts', 'total_counts', 'log1p_total_counts', 'highly_variable', 'means', 'dispersions', 'dispersions_norm'
    uns: 'log1p', 'single_perts_median_full', 'hvg', 'rank_genes_groups', 'rank_genes_groups_combo_specific', 'single_perts_median'
    layers: 'count