In [7]:
import scanpy as sc
import random
import os

from scipy.sparse import csr_matrix

## Pull CPA version of Norman19
The dataset is taken from the CPA tutorial notebook: https://cpa-tools.readthedocs.io/en/latest/tutorials/Norman.html

In [None]:
data_cache_dir = 'perturbench_data/' ## Set this to your desired data cache directory
# os.path.join works whether or not data_cache_dir ends with a path separator
data_path = os.path.join(data_cache_dir, 'norman19_cpa_hvg_normalized.h5ad')

# Create the cache directory up front so the download below has somewhere to write
os.makedirs(os.path.dirname(data_path) or '.', exist_ok=True)

# Download only when the file is not cached yet. Checking the path explicitly
# avoids relying on which exception type the reader raises for a missing file.
if not os.path.exists(data_path):
    # Imported lazily so gdown is only needed for the first download
    import gdown
    gdown.download(
        'https://drive.google.com/uc?export=download&id=109G9MmL-8-uh7OSjnENeZ5vFbo62kI7j',
        output=data_path
    )

adata = sc.read(data_path)
adata

## Generate CPA Norman19 splits

In [3]:
# Build a `condition` column from `cond_harm`, relabelling 'ctrl' as 'control'.
cond_harm = adata.obs['cond_harm']
if cond_harm.dtype.name == 'category' and 'control' not in cond_harm.cat.categories:
    # Series.replace on categorical data is deprecated since pandas 2.2;
    # renaming the category is the supported way to relabel it.
    adata.obs['condition'] = cond_harm.cat.rename_categories({'ctrl': 'control'})
else:
    # Non-categorical column, or 'control' already exists as its own category
    # (rename_categories would then produce duplicate categories and raise).
    adata.obs['condition'] = cond_harm.replace({'ctrl': 'control'})
adata.obs['condition'].value_counts()

  adata.obs['condition'] = adata.obs['cond_harm'].copy().replace({'ctrl': 'control'})


condition
control          11855
KLF1              1960
BAK1              1457
CEBPE             1233
CEBPE+RUNX1T1     1219
                 ...  
FOSB+CEBPB          71
CBL+UBASH3A         64
CEBPB+CEBPA         64
JUN+CEBPB           59
JUN+CEBPA           54
Name: count, Length: 235, dtype: int64

In [4]:
# Build a `cell_type` column from `cell_line`, lower-casing the 'K562' label.
cell_line = adata.obs['cell_line']
if cell_line.dtype.name == 'category' and 'k562' not in cell_line.cat.categories:
    # Series.replace on categorical data is deprecated since pandas 2.2;
    # renaming the category is the supported way to relabel it.
    adata.obs['cell_type'] = cell_line.cat.rename_categories({'K562': 'k562'})
else:
    # Non-categorical column, or 'k562' already exists as its own category
    # (rename_categories would then produce duplicate categories and raise).
    adata.obs['cell_type'] = cell_line.replace({
        'K562': 'k562'
    })
adata.obs['cell_type'].value_counts()

  adata.obs['cell_type'] = adata.obs['cell_line'].copy().replace({


cell_type
k562    111122
Name: count, dtype: int64

In [5]:
# Gather the names of all obs columns that contain the substring 'split'
split_cols = list(filter(lambda col_name: 'split' in col_name, adata.obs.columns))
split_cols

['split_hardest',
 'split_1',
 'split_2',
 'split_3',
 'split_4',
 'split_5',
 'split_6']

In [8]:
# Folder that receives one CSV per split column. os.path.join avoids the doubled
# separator that an f-string produces when data_cache_dir already ends with '/'.
out_dir = os.path.join(data_cache_dir, 'norman19_cpa_hvg_normalized_splits')

# exist_ok=True replaces the check-then-create pattern and keeps the cell re-runnable
os.makedirs(out_dir, exist_ok=True)

In [None]:
# Number of train control cells moved into the test split of every split column
N_TEST_CONTROL_CELLS = 1500
# Dedicated seeded generator so re-running the notebook writes identical CSVs.
# NOTE(review): any files produced earlier with the unseeded sampler cannot be
# reproduced exactly; the seed value itself is arbitrary.
control_rng = random.Random(0)


def _single_perts(conditions):
    """Return the set of individual perturbations named in `conditions`.

    Each condition is a string of one or more perturbation names joined by '+'.
    """
    return {pert for condition in conditions for pert in condition.split('+')}


for split_col in split_cols:
    # Work on an object-dtype copy so relabelling and the later 'test'/None
    # assignments do not depend on which categories the stored column carries.
    split = adata.obs[split_col].astype(object).replace({
        'valid': 'val',
        'ood': 'test',
    })

    # Every single perturbation that appears in at least one training condition
    unique_train_perts = _single_perts(
        adata.obs.loc[split == 'train', 'condition'].unique()
    )

    # Val/test conditions that involve a perturbation never seen during training.
    # Each condition is listed once so the printed count is not inflated.
    heldout_conditions = adata.obs.loc[
        split.isin(['val', 'test']), 'condition'
    ].unique()
    perts_remove = sorted(
        condition for condition in heldout_conditions
        if not _single_perts([condition]) <= unique_train_perts
    )

    print(split_col, len(perts_remove))
    # Unassign every cell of those conditions (written as an empty field in the CSV)
    split.loc[adata.obs['condition'].isin(perts_remove)] = None

    n_train = adata.obs.loc[split == 'train', 'condition'].nunique()
    n_val = adata.obs.loc[split == 'val', 'condition'].nunique()
    n_test = adata.obs.loc[split == 'test', 'condition'].nunique()
    print(f'{split_col}: {n_train} train, {n_val} val, {n_test} test')

    # Move a random subset of the training control cells into the test split
    is_train_control = (
        (adata.obs[split_col] == 'train') &
        (adata.obs['condition'] == 'control')
    )
    train_control_cells = adata.obs.index[is_train_control.to_numpy()].tolist()
    # Never request more cells than exist; random.sample would raise ValueError
    n_sampled = min(N_TEST_CONTROL_CELLS, len(train_control_cells))
    test_control_cells = control_rng.sample(train_control_cells, n_sampled)
    if test_control_cells:
        split.loc[test_control_cells] = 'test'

    out_path = f'{out_dir}/{split_col}.csv'
    split.to_csv(out_path, header=False)
    print(f'Saved to {out_path}')
    
    

In [10]:
# Re-wrap the expression matrix as a CSR sparse matrix, then write the
# curated AnnData object next to the original download.
adata.X = csr_matrix(adata.X)
adata.write_h5ad(
    f'{data_cache_dir}/norman19_cpa_hvg_normalized_curated.h5ad'
)