In [1]:
import scanpy as sc
import pickle
import numpy as np
import pandas as pd
import anndata as ad
import os
import urllib.request

## Generating adamson dataset for biolord

To generate the adamson dataset for model `biolord`, we followed its author's [jupyter notebook](https://github.com/nitzanlab/biolord_reproducibility/blob/main/notebooks/perturbations/adamson/1_perturbations_adamson_preprocessing.ipynb) of preprocessing workflows. Since they uploaded the resulting preprocessed data in [`adamson_single_biolord.h5ad`](https://figshare.com/articles/dataset/perturbseq_adamson_single/22344445) and [`adamson_biolord.h5ad`](https://figshare.com/articles/dataset/perturbseq_adamson/22344214), we can just download them using CLI.

In [2]:
# File paths
file_1_path = "adamson/adamson_single_biolord.h5ad"
file_2_path = "adamson/adamson_biolord.h5ad"

# URLs
file_1_url = "https://figshare.com/ndownloader/files/39756736"
file_2_url = "https://figshare.com/ndownloader/files/39756439"

# Check if file 1 exists, if not, download it
if not os.path.exists(file_1_path):
    print(f"{file_1_path} not found, downloading...")
    urllib.request.urlretrieve(file_1_url, file_1_path)
else:
    print(f"{file_1_path} already exists, skipping download.")

# Check if file 2 exists, if not, download it
if not os.path.exists(file_2_path):
    print(f"{file_2_path} not found, downloading...")
    urllib.request.urlretrieve(file_2_url, file_2_path)
else:
    print(f"{file_2_path} already exists, skipping download.")

adamson/adamson_single_biolord.h5ad not found, downloading...
adamson/adamson_biolord.h5ad not found, downloading...


## Generating norman dataset for biolord

Similarly, to generate the norman dataset for model `biolord`, we followed its author's [jupyter notebook](https://github.com/nitzanlab/biolord_reproducibility/blob/main/notebooks/perturbations/norman/1_perturbations_norman_preprocessing.ipynb) of preprocessing workflows. Since they uploaded the resulting preprocessed data in [`norman2019_single_biolord.h5ad`](https://figshare.com/articles/dataset/pertrubseq_norman_single/22344427) and [`norman2019_biolord.h5ad`](https://figshare.com/articles/dataset/perturbseq_nornan/22344253), we can just download them using CLI.

In [3]:
# File paths
file_3_path = "norman/norman2019_single_biolord.h5ad"
file_4_path = "norman/norman2019_biolord.h5ad"

# URLs
file_3_url = "https://figshare.com/ndownloader/files/39756733"
file_4_url = "https://figshare.com/ndownloader/files/39756463"

# Check if file 3 exists, if not, download it
if not os.path.exists(file_3_path):
    print(f"{file_3_path} not found, downloading...")
    urllib.request.urlretrieve(file_3_url, file_3_path)
else:
    print(f"{file_3_path} already exists, skipping download.")

# Check if file 4 exists, if not, download it
if not os.path.exists(file_4_path):
    print(f"{file_4_path} not found, downloading...")
    urllib.request.urlretrieve(file_4_url, file_4_path)
else:
    print(f"{file_4_path} already exists, skipping download.")

norman/norman2019_single_biolord.h5ad not found, downloading...
norman/norman2019_biolord.h5ad not found, downloading...


## Generating dixit dataset for biolord

In their work of `biolord`, the datasets `dixit` is not considered. Hence, we follow the similar preprocessing workflows and create the `dixit` dataset.

In [4]:
adata = sc.read('../Data_GEARS/dixit/perturb_processed.h5ad')
for seed in range(1,11):
    with open(f"../Data_GEARS/dixit/splits/dixit_simulation_{seed}_0.9.pkl", "rb") as f:
        split_data = pickle.load(f)
        pert2set = {}
        for i,j in split_data.items():
            for x in j:
                pert2set[x] = i
        
        not_in = np.setxor1d(list(adata.obs.condition.unique()), list(pert2set.keys()))
        if len(not_in) > 0:
            for i in not_in:
                adata = adata[adata.obs.condition != i]
                
        subgroup = pickle.load(open(f"../Data_GEARS/dixit/splits/dixit_simulation_{seed}_0.9_subgroup.pkl", "rb"))
        adata.obs[f"split{seed}"] = [pert2set[i] for i in adata.obs["condition"].values]
        pert2subgroup = {}
        for i,j in subgroup["test_subgroup"].items():
            for x in j:
                pert2subgroup[x] = i
                
        adata.obs[f"subgroup{seed}"] = adata.obs["condition"].apply(lambda x: pert2subgroup[x] if x in pert2subgroup else 'Train/Val')
        rename = {
            'train': 'train',
             'test': 'ood',
             'val': 'test'
        }
        adata.obs[f'split{seed}'] = adata.obs[f'split{seed}'].apply(lambda x: rename[x])
        
adata.obs["perturbation"] = [cond.split("+")[0] for cond in adata.obs["condition"]]
adata.obs["perturbation"] = adata.obs["perturbation"].astype("category")

go_path = '../Data_GEARS/dixit/go.csv'
gene_path = '../Data_GEARS/essential_all_data_pert_genes.pkl'
df = pd.read_csv(go_path)
df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1, ['importance'])).reset_index(drop = True)
with open(gene_path, 'rb') as f:
    gene_list = pickle.load(f)
df = df[df["source"].isin(gene_list)]

def get_map(pert):
    tmp = pd.DataFrame(np.zeros(len(gene_list)), index=gene_list)
    tmp.loc[df[df.target == pert].source.values, :] = df[df.target == pert].importance.values[:, np.newaxis]
    return tmp.values.flatten()

pert2neighbor =  {i: get_map(i) for i in list(adata.obs["perturbation"].cat.categories)}
adata.uns["pert2neighbor"] = pert2neighbor

pert2neighbor = np.asarray([val for val in adata.uns["pert2neighbor"].values()])
keep_idx = pert2neighbor.sum(0) > 0

name_map = dict(adata.obs[["condition", "condition_name"]].drop_duplicates().values)
ctrl = np.asarray(adata[adata.obs["condition"].isin(["ctrl"])].X.mean(0)).flatten() 

df_perts_expression = pd.DataFrame(adata.X.toarray(), index=adata.obs_names, columns=adata.var_names)
df_perts_expression["condition"] = adata.obs["condition"]
df_perts_expression = df_perts_expression.groupby(["condition"]).mean()
df_perts_expression = df_perts_expression.reset_index()

single_perts_condition = []
single_pert_val = []
double_perts = []
for pert in adata.obs["condition"].cat.categories:
    if len(pert.split("+")) == 1:
        continue
    elif "ctrl" in pert:
        single_perts_condition.append(pert)
        p1, p2 = pert.split("+")
        if p2 == "ctrl":
            single_pert_val.append(p1)
        else:
            single_pert_val.append(p2)
single_perts_condition.append("ctrl")
single_pert_val.append("ctrl")

df_singleperts_expression = pd.DataFrame(df_perts_expression.set_index("condition").loc[single_perts_condition].values, index=single_pert_val)
df_singleperts_emb = np.asarray([adata.uns["pert2neighbor"][p1][keep_idx] for p1 in df_singleperts_expression.index])

df_singleperts_condition = pd.Index(single_perts_condition)
df_single_pert_val = pd.Index(single_pert_val)

adata_single = ad.AnnData(X=df_singleperts_expression.values, var=adata.var.copy(), dtype=df_singleperts_expression.values.dtype)
adata_single.obs_names = df_singleperts_condition
adata_single.obs["condition"] = df_singleperts_condition
adata_single.obs["perts_name"] = df_single_pert_val
adata_single.obsm["perturbation_neighbors"] = df_singleperts_emb

for split_seed in range(1,11):
    adata_single.obs[f"split{split_seed}"] = None
    adata_single.obs[f"subgroup{split_seed}"] = "Train/Val"
    for cat in ["train","test","ood"]:
        cat_idx = adata_single.obs["condition"].isin(adata[adata.obs[f"split{split_seed}"] == cat].obs["condition"].cat.categories)
        adata_single.obs.loc[cat_idx ,f"split{split_seed}"] = cat
        if cat == "ood":
            adata_single.obs.loc[cat_idx ,f"subgroup{split_seed}"] = "unseen_single"
            
adata_single.write("dixit/dixit_single_biolord.h5ad")
adata.write("dixit/dixit_biolord.h5ad")

  adata.obs[f"split{seed}"] = [pert2set[i] for i in adata.obs["condition"].values]
  df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1, ['importance'])).reset_index(drop = True)
  df_perts_expression = df_perts_expression.groupby(["condition"]).mean()


## Generating dixit dataset for Replogle K562

Similarly, we follow the the preprocessing workflows and create the `Replogle K562` dataset.

In [6]:
adata = sc.read('../Data_GEARS/replogle_k562_essential/perturb_processed.h5ad')
for seed in range(1,6):
    with open(f"../Data_GEARS/replogle_k562_essential/splits/replogle_k562_essential_simulation_{seed}_0.75.pkl", "rb") as f:
        split_data = pickle.load(f)
        pert2set = {}
        for i,j in split_data.items():
            for x in j:
                pert2set[x] = i
        
        not_in = np.setxor1d(list(adata.obs.condition.unique()), list(pert2set.keys()))
        if len(not_in) > 0:
            for i in not_in:
                adata = adata[adata.obs.condition != i]
                
        subgroup = pickle.load(open(f"../Data_GEARS/replogle_k562_essential/splits/replogle_k562_essential_simulation_{seed}_0.75_subgroup.pkl", "rb"))
        adata.obs[f"split{seed}"] = [pert2set[i] for i in adata.obs["condition"].values]
        pert2subgroup = {}
        for i,j in subgroup["test_subgroup"].items():
            for x in j:
                pert2subgroup[x] = i
                
        adata.obs[f"subgroup{seed}"] = adata.obs["condition"].apply(lambda x: pert2subgroup[x] if x in pert2subgroup else 'Train/Val')
        rename = {
            'train': 'train',
             'test': 'ood',
             'val': 'test'
        }
        adata.obs[f'split{seed}'] = adata.obs[f'split{seed}'].apply(lambda x: rename[x])
        
adata.obs["perturbation"] = [cond.split("+")[0] for cond in adata.obs["condition"]]
adata.obs["perturbation"] = adata.obs["perturbation"].astype("category")

go_path = '../Data_GEARS/go_essential_all/go_essential_all.csv'
gene_path = '../Data_GEARS/essential_all_data_pert_genes.pkl'
df = pd.read_csv(go_path)
df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1, ['importance'])).reset_index(drop = True)
with open(gene_path, 'rb') as f:
    gene_list = pickle.load(f)
df = df[df["source"].isin(gene_list)]

def get_map(pert):
    tmp = pd.DataFrame(np.zeros(len(gene_list)), index=gene_list)
    tmp.loc[df[df.target == pert].source.values, :] = df[df.target == pert].importance.values[:, np.newaxis]
    return tmp.values.flatten()

pert2neighbor =  {i: get_map(i) for i in list(adata.obs["perturbation"].cat.categories)}
adata.uns["pert2neighbor"] = pert2neighbor

pert2neighbor = np.asarray([val for val in adata.uns["pert2neighbor"].values()])
keep_idx = pert2neighbor.sum(0) > 0

name_map = dict(adata.obs[["condition", "condition_name"]].drop_duplicates().values)
ctrl = np.asarray(adata[adata.obs["condition"].isin(["ctrl"])].X.mean(0)).flatten() 

df_perts_expression = pd.DataFrame(adata.X.toarray(), index=adata.obs_names, columns=adata.var_names)
df_perts_expression["condition"] = adata.obs["condition"]
df_perts_expression = df_perts_expression.groupby(["condition"]).mean()
df_perts_expression = df_perts_expression.reset_index()

single_perts_condition = []
single_pert_val = []
double_perts = []
for pert in adata.obs["condition"].cat.categories:
    if len(pert.split("+")) == 1:
        continue
    elif "ctrl" in pert:
        single_perts_condition.append(pert)
        p1, p2 = pert.split("+")
        if p2 == "ctrl":
            single_pert_val.append(p1)
        else:
            single_pert_val.append(p2)
single_perts_condition.append("ctrl")
single_pert_val.append("ctrl")

df_singleperts_expression = pd.DataFrame(df_perts_expression.set_index("condition").loc[single_perts_condition].values, index=single_pert_val)
df_singleperts_emb = np.asarray([adata.uns["pert2neighbor"][p1][keep_idx] for p1 in df_singleperts_expression.index])

df_singleperts_condition = pd.Index(single_perts_condition)
df_single_pert_val = pd.Index(single_pert_val)

adata_single = ad.AnnData(X=df_singleperts_expression.values, var=adata.var.copy(), dtype=df_singleperts_expression.values.dtype)
adata_single.obs_names = df_singleperts_condition
adata_single.obs["condition"] = df_singleperts_condition
adata_single.obs["perts_name"] = df_single_pert_val
adata_single.obsm["perturbation_neighbors"] = df_singleperts_emb

for split_seed in range(1,6):
    adata_single.obs[f"split{split_seed}"] = None
    adata_single.obs[f"subgroup{split_seed}"] = "Train/Val"
    for cat in ["train","test","ood"]:
        cat_idx = adata_single.obs["condition"].isin(adata[adata.obs[f"split{split_seed}"] == cat].obs["condition"].cat.categories)
        adata_single.obs.loc[cat_idx ,f"split{split_seed}"] = cat
        if cat == "ood":
            adata_single.obs.loc[cat_idx ,f"subgroup{split_seed}"] = "unseen_single"
            
adata_single.write("replogle_k562_essential/k562_single_biolord.h5ad")
adata.write("replogle_k562_essential/k562_biolord.h5ad")

  adata.obs[f"split{seed}"] = [pert2set[i] for i in adata.obs["condition"].values]
  df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1, ['importance'])).reset_index(drop = True)
  df_perts_expression = df_perts_expression.groupby(["condition"]).mean()


## Generating dixit dataset for Replogle K562

Similarly, we follow the the preprocessing workflows and create the `Replogle RPE1` dataset.

In [8]:
adata = sc.read('../Data_GEARS/replogle_rpe1_essential/perturb_processed.h5ad')
for seed in range(1,6):
    with open(f"../Data_GEARS/replogle_rpe1_essential/splits/replogle_rpe1_essential_simulation_{seed}_0.75.pkl", "rb") as f:
        split_data = pickle.load(f)
        pert2set = {}
        for i,j in split_data.items():
            for x in j:
                pert2set[x] = i
        
        not_in = np.setxor1d(list(adata.obs.condition.unique()), list(pert2set.keys()))
        if len(not_in) > 0:
            for i in not_in:
                adata = adata[adata.obs.condition != i]
                
        subgroup = pickle.load(open(f"../Data_GEARS/replogle_rpe1_essential/splits/replogle_rpe1_essential_simulation_{seed}_0.75_subgroup.pkl", "rb"))
        adata.obs[f"split{seed}"] = [pert2set[i] for i in adata.obs["condition"].values]
        pert2subgroup = {}
        for i,j in subgroup["test_subgroup"].items():
            for x in j:
                pert2subgroup[x] = i
                
        adata.obs[f"subgroup{seed}"] = adata.obs["condition"].apply(lambda x: pert2subgroup[x] if x in pert2subgroup else 'Train/Val')
        rename = {
            'train': 'train',
             'test': 'ood',
             'val': 'test'
        }
        adata.obs[f'split{seed}'] = adata.obs[f'split{seed}'].apply(lambda x: rename[x])
        
adata.obs["perturbation"] = [cond.split("+")[0] for cond in adata.obs["condition"]]
adata.obs["perturbation"] = adata.obs["perturbation"].astype("category")

go_path = '../Data_GEARS/go_essential_all/go_essential_all.csv'
gene_path = '../Data_GEARS/essential_all_data_pert_genes.pkl'
df = pd.read_csv(go_path)
df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1, ['importance'])).reset_index(drop = True)
with open(gene_path, 'rb') as f:
    gene_list = pickle.load(f)
df = df[df["source"].isin(gene_list)]

def get_map(pert):
    tmp = pd.DataFrame(np.zeros(len(gene_list)), index=gene_list)
    tmp.loc[df[df.target == pert].source.values, :] = df[df.target == pert].importance.values[:, np.newaxis]
    return tmp.values.flatten()

pert2neighbor =  {i: get_map(i) for i in list(adata.obs["perturbation"].cat.categories)}
adata.uns["pert2neighbor"] = pert2neighbor

pert2neighbor = np.asarray([val for val in adata.uns["pert2neighbor"].values()])
keep_idx = pert2neighbor.sum(0) > 0

name_map = dict(adata.obs[["condition", "condition_name"]].drop_duplicates().values)
ctrl = np.asarray(adata[adata.obs["condition"].isin(["ctrl"])].X.mean(0)).flatten() 

df_perts_expression = pd.DataFrame(adata.X.toarray(), index=adata.obs_names, columns=adata.var_names)
df_perts_expression["condition"] = adata.obs["condition"]
df_perts_expression = df_perts_expression.groupby(["condition"]).mean()
df_perts_expression = df_perts_expression.reset_index()

single_perts_condition = []
single_pert_val = []
double_perts = []
for pert in adata.obs["condition"].cat.categories:
    if len(pert.split("+")) == 1:
        continue
    elif "ctrl" in pert:
        single_perts_condition.append(pert)
        p1, p2 = pert.split("+")
        if p2 == "ctrl":
            single_pert_val.append(p1)
        else:
            single_pert_val.append(p2)
single_perts_condition.append("ctrl")
single_pert_val.append("ctrl")

df_singleperts_expression = pd.DataFrame(df_perts_expression.set_index("condition").loc[single_perts_condition].values, index=single_pert_val)
df_singleperts_emb = np.asarray([adata.uns["pert2neighbor"][p1][keep_idx] for p1 in df_singleperts_expression.index])

df_singleperts_condition = pd.Index(single_perts_condition)
df_single_pert_val = pd.Index(single_pert_val)

adata_single = ad.AnnData(X=df_singleperts_expression.values, var=adata.var.copy(), dtype=df_singleperts_expression.values.dtype)
adata_single.obs_names = df_singleperts_condition
adata_single.obs["condition"] = df_singleperts_condition
adata_single.obs["perts_name"] = df_single_pert_val
adata_single.obsm["perturbation_neighbors"] = df_singleperts_emb

for split_seed in range(1,6):
    adata_single.obs[f"split{split_seed}"] = None
    adata_single.obs[f"subgroup{split_seed}"] = "Train/Val"
    for cat in ["train","test","ood"]:
        cat_idx = adata_single.obs["condition"].isin(adata[adata.obs[f"split{split_seed}"] == cat].obs["condition"].cat.categories)
        adata_single.obs.loc[cat_idx ,f"split{split_seed}"] = cat
        if cat == "ood":
            adata_single.obs.loc[cat_idx ,f"subgroup{split_seed}"] = "unseen_single"
            
adata_single.write("replogle_rpe1_essential/rpe1_single_biolord.h5ad")
adata.write("replogle_rpe1_essential/rpe1_biolord.h5ad")

  adata.obs[f"split{seed}"] = [pert2set[i] for i in adata.obs["condition"].values]
  df = df.groupby('target').apply(lambda x: x.nlargest(20 + 1, ['importance'])).reset_index(drop = True)
  df_perts_expression = df_perts_expression.groupby(["condition"]).mean()
