In [1]:
import scanpy as sc
import numpy as np
import random

In [2]:
random.seed(42)

# Scenario 1

# Read data

In [3]:
adata = sc.read_h5ad(
    '../../preprocessed_datasets/norman.h5ad'
)

all_perts = np.unique(adata.obs['condition']).tolist()

# Read double perts

In [4]:
with open(
    '../../preprocessed_datasets/norman_double_perts.txt',
    'r'
    ) as f:
        double_perts = f.readlines()

double_perts = [db.replace('\n', '') for db in double_perts]
n_perts = len(double_perts)
print('Number of double perts', n_perts)

Number of double perts 128


# Create train_val_ood splits

In [5]:
split_dict_1 = {}

for i_e, ood_index in enumerate(range(0, n_perts, 10)):
    ood_labels = double_perts[ood_index: ood_index+10]

    train = [x for x in all_perts if not x in ood_labels]
    double_perts_train = [x for x in double_perts if not x in ood_labels]
    val_size = int(0.1 * len(double_perts_train))

    val = random.sample(double_perts_train, val_size)
    train = [x for x in train if not x in val]

    split_dict_1[i_e] = {
        'train': train,
        'val': val,
        'ood': ood_labels
    }

    print(i_e, len(ood_labels))

0 10
1 10
2 10
3 10
4 10
5 10
6 10
7 10
8 10
9 10
10 10
11 10
12 8


# Save

In [6]:
import json

In [7]:
with open('../../preprocessed_datasets/norman_splits_1.json', 'w') as f:
    json.dump(split_dict_1, f)

In [8]:
# Sanity check that all of them are doubly

# Scenario 2

In [9]:
conditions = np.unique(adata.obs['condition']).tolist()

conds = [
    'CBL+CNN1', 'DUSP9+ETS2', 'DUSP9+MAPK1', 
    'ETS2+MAPK1', 'ETS2+CEBPE', 'CNN1+MAPK1', 
    'CEBPB+CEBPA', 'CBL+PTPN12', 'AHR+FEV', 
]

In [10]:
split_dict_2 = {}

for i_e, cond in enumerate(conds):
    ood1, ood2 = cond.split('+')
    comb1 = [x for x in conditions if f'{ood1}+' in x or f'+{ood1}' in x]
    comb2 = [x for x in conditions if f'{ood2}+' in x or f'+{ood2}' in x]
    ood_labels = comb1 + comb2

    train = [x for x in all_perts if not x in ood_labels]
    double_perts_train = [x for x in double_perts if not x in ood_labels]
    val_size = int(0.1 * len(double_perts_train))

    val = random.sample(double_perts_train, val_size)
    train = [x for x in train if not x in val]

    split_dict_2[i_e] = {
        'train': train,
        'val': val,
        'ood': ood_labels
    }

    print(i_e, len(ood_labels))

0 12
1 15
2 14
3 17
4 20
5 14
6 13
7 18
8 6


# Save

In [11]:
with open('../../preprocessed_datasets/norman_splits_2.json', 'w') as f:
    json.dump(split_dict_2, f)