In [None]:
import numpy as np
import pickle
import scanpy as sc
import pandas as pd
from anndata import AnnData
from tqdm.notebook import tqdm
from torch_geometric.data import Data
import torch

In [None]:
# Read the .h5ad file with scanpy; the path is relative to the current working directory.
adata = sc.read("data/sciplex_lincs.h5ad")

In [None]:
# Name of the `adata.obs` column compared against "train" / "valid" / "test" to pick each split.
split_key = 'split'

In [None]:
def _nonzero_dose_split(split_name):
    """Return the cells of `adata` labelled `split_name` whose dose is not 0.0."""
    # Slice in the same two steps as before: first by split label, then by dose.
    in_split = adata[adata.obs[split_key] == split_name]
    has_dose = in_split.obs['dose'].astype(float) != 0.0
    return in_split[has_dose]


train_data = _nonzero_dose_split("train")
valid_data = _nonzero_dose_split("valid")
test_data = _nonzero_dose_split("test")

In [None]:
# Every value of obs['condition'] as a NumPy array (one entry per row of adata.obs, duplicates kept).
drugs_names = np.array(adata.obs['condition'].values)
# Distinct condition labels.
drugs_names_unique = set(drugs_names)
# Distinct labels in sorted order; get_pert_idx returns positions into this array.
drugs_names_unique_sorted = np.array(sorted(drugs_names_unique))

In [None]:
# Variable-name index of `adata`; create_cell_graph_dataset calls `.isin` on it to locate DE entries by position.
var_names = adata.var_names
# Object stored under uns['rank_genes_groups_cov']; create_cell_graph_dataset indexes it with a cell's obs['cov_drug'] value.
de_genes = adata.uns['rank_genes_groups_cov']

In [None]:
def get_pert_idx(pert_category):
    """Translate a perturbation category into positions in ``drugs_names_unique_sorted``.

    ``pert_category`` is split on ``'*'``; every component other than
    ``'control'`` is looked up in the module-level ``drugs_names_unique_sorted``
    array and its first matching position is collected.

    Args:
        pert_category: Condition label, e.g. ``"drugA"`` or ``"drugA*drugB"``.

    Returns:
        A list with one position per non-control component (empty when every
        component is ``'control'``), or ``None`` when the lookup fails. On
        failure the offending ``pert_category`` is printed.
    """
    try:
        pert_idx = [np.where(p == drugs_names_unique_sorted)[0][0]
                    for p in pert_category.split('*')
                    if p != 'control']
    except (IndexError, AttributeError):
        # IndexError: a component has no match, so np.where(...)[0] is empty.
        # AttributeError: pert_category is not a string (has no .split).
        # A bare `except:` would also swallow KeyboardInterrupt / SystemExit,
        # making a long preprocessing run impossible to interrupt cleanly.
        print(pert_category)
        pert_idx = None

    return pert_idx

In [None]:
def create_cell_graph(X, y, de_idx, pert, dose, pert_idx=None):
    """Package one input/target pair and its metadata into a ``Data`` object.

    Args:
        X: Input values; converted with ``torch.Tensor`` and transposed to
            become the ``x`` attribute.
        y: Target values; converted with ``torch.Tensor`` and stored as ``y``.
        de_idx: Stored unchanged as ``de_idx``.
        pert: Stored unchanged as ``pert``.
        dose: Stored unchanged as ``dose``.
        pert_idx: Stored as ``pert_idx``; ``None`` is replaced by ``[-1]``.

    Returns:
        The constructed ``Data`` instance.
    """
    resolved_pert_idx = [-1] if pert_idx is None else pert_idx
    transposed_input = torch.Tensor(X).T
    target = torch.Tensor(y)
    return Data(
        x=transposed_input,
        pert_idx=resolved_pert_idx,
        y=target,
        de_idx=de_idx,
        dose=dose,
        pert=pert,
    )

In [None]:
def create_cell_graph_dataset(split_adata, pert_category, adata_all):
    """Build one cell graph per cell of ``split_adata`` with condition ``pert_category``.

    For ``pert_category == 'control'`` every selected cell is used as both
    input and target, ``de_idx`` is a placeholder list of ``-1`` and
    ``pert_idx`` is ``None`` (which ``create_cell_graph`` stores as ``[-1]``).

    For any other category the selected cells are processed per cell type:
    the input is the row of ``adata_all.X`` found by resolving
    ``obs['paired_control_index']`` against ``adata_all.obs_names``, the target
    is the cell's own row of ``X``, and ``de_idx`` holds the positions in the
    module-level ``var_names`` of the entries listed in
    ``de_genes[<cov_drug of the first cell of that cell type>]``.

    Args:
        split_adata: AnnData-like object holding the cells of one split.
        pert_category: Value of ``obs['condition']`` to select.
        adata_all: Full dataset in which paired control cells are looked up.

    Returns:
        List of objects produced by ``create_cell_graph``, one per selected cell.

    Raises:
        KeyError: If a ``paired_control_index`` value is not present in
            ``adata_all.obs_names``.
    """
    num_de_genes = 50
    adata_ = split_adata[split_adata.obs['condition'] == pert_category]

    if pert_category == 'control':
        # The control case never depended on the cell type: the old code rebuilt
        # the full list from `adata_` once per cell type and discarded the
        # previous copy each time. Building it once yields the same list.
        de_idx = [-1] * num_de_genes
        return [
            create_cell_graph(np.array(row).reshape(1, -1),
                              np.array(row).reshape(1, -1),
                              de_idx, pert_category, dose, None)
            for row, dose in zip(adata_.X, adata_.obs['dose'].values)
        ]

    cell_graphs = []
    for celltype in np.unique(adata_.obs['cell_type']):
        adata_celline = adata_[adata_.obs['cell_type'] == celltype]
        pert_idx = get_pert_idx(pert_category)

        # NOTE(review): the DE set is keyed by the FIRST cell's `cov_drug` and
        # reused for every cell of this cell type. If `cov_drug` can vary
        # within a (condition, cell type) group, e.g. by dose, confirm that
        # sharing one set is intended.
        cell_drug_dose_comb = adata_celline[0].obs['cov_drug'].values[0]
        bool_de = var_names.isin(np.array(de_genes[cell_drug_dose_comb]))
        indices = np.where(bool_de)[0]

        control_names = adata_celline.obs['paired_control_index'].values
        ctrl_index = adata_all.obs_names.get_indexer(control_names)
        # get_indexer marks unknown labels with -1, which would silently select
        # the LAST row of adata_all.X as the "control"; fail loudly instead.
        missing = ctrl_index < 0
        if missing.any():
            examples = list(np.asarray(control_names)[missing][:5])
            raise KeyError(
                "paired_control_index values not found in adata_all.obs_names "
                "for condition {!r}, cell type {!r} (showing up to 5): {}".format(
                    pert_category, celltype, examples)
            )

        rows = zip(adata_all.X[ctrl_index],
                   adata_celline.X,
                   adata_celline.obs['dose'].values)
        for X, y, dose in rows:
            cell_graphs.append(
                create_cell_graph(np.array(X).reshape(1, -1),
                                  np.array(y).reshape(1, -1),
                                  indices, pert_category, dose, pert_idx)
            )

    return cell_graphs

In [None]:
def _build_and_pickle_split(split_adata, out_path):
    """Build the per-condition graph lists for one split and pickle them.

    Args:
        split_adata: Cells of one split; its ``obs['condition']`` values are
            the dictionary keys.
        out_path: Destination of the pickle file.

    Returns:
        Dict mapping each condition to the list returned by
        ``create_cell_graph_dataset`` for that condition.
    """
    processed = {}
    for condition in tqdm(split_adata.obs['condition'].unique()):
        processed[condition] = create_cell_graph_dataset(split_adata, condition, adata)
    # The context manager closes the file as soon as the dump finishes;
    # previously the handle from a bare open() was never closed explicitly.
    with open(out_path, "wb") as handle:
        pickle.dump(processed, handle)
    return processed


train_dataset_processed = _build_and_pickle_split(
    train_data, '/data/data_pyg/sciplex_lincs/train_graph.pkl')
print("train_dataset Done!")

val_dataset_processed = _build_and_pickle_split(
    valid_data, '/data/data_pyg/sciplex_lincs/val_graph.pkl')
print("valid_dataset Done!")

# NOTE(review): this path does not follow the train/val pattern
# ('/data/data_pyg/sciplex_lincs/<split>_graph.pkl'). It is kept unchanged
# because downstream readers are not visible here; confirm it is intended.
test_dataset_processed = _build_and_pickle_split(
    test_data, '/data/data_pyg/data_celltype/K562.pkl')
print("test_dataset Done!")
