### NOTE: This notebook requires GEARS set up from this repository: https://github.com/yhr91/GEARS_misc (the code below adds `../../GEARS_misc/` to `sys.path`).

In [15]:
import argparse
import sys
sys.path.append('../../GEARS_misc/')
import pickle

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt 
import seaborn as sns
from scipy.stats import pearsonr

from gears import PertData, GEARS

# Run configuration.
# Prediction CSVs from the training loop are written under this folder.
results_folder = './results/'
# Sub-folder of `data_path` holding the processed dataset loaded by load_data.
dataset = 'parpi'
# Root folder passed to GEARS' PertData.
data_path = './'
# NOTE: `model` is later rebound to the GEARS torch module (`gears_model.model`).
model = 'gears'
# Device batches are moved to in training/evaluation.
device = 'cuda'
# Split seed for the initial load_data call; the training loop rebinds `seed`.
seed = 1

# Reload edited modules (e.g. the GEARS sources on sys.path) without restarting the kernel.
%load_ext autoreload
%autoreload 2

def load_data(seed):
    """Load the processed `dataset` with GEARS and compute per-condition mean expression.

    Reads the notebook-level `data_path` and `dataset` settings.

    Args:
        seed: seed for GEARS' 'simulation' train/val/test split.

    Returns:
        (pert_data, ctrl_mean, mean_df): the PertData object with its split and
        dataloaders prepared, the mean expression of the 'ctrl' condition, and the
        table of mean expression per condition (one row per condition).
    """
    dataset_dir = data_path + dataset  # processed data lives in <data_path><dataset>
    data = PertData(data_path, gi_go=True)
    data.load(data_path=dataset_dir)
    data.prepare_split(split='simulation', seed=seed)
    data.get_dataloader(batch_size=32, test_batch_size=32)

    # Cells x genes expression table, tagged with each cell's condition.
    expr = data.adata.to_df()
    expr['condition'] = data.adata.obs['condition']
    per_condition_mean = expr.groupby('condition').mean()

    return data, per_condition_mean.loc['ctrl'], per_condition_mean

def load_model(pert_data, hidden_size=64, device='cuda'):
    """Build a freshly initialised GEARS model on top of `pert_data`.

    Args:
        pert_data: prepared GEARS PertData (split and dataloaders already set up).
        hidden_size: hidden dimension passed to `model_initialize`
            (default 64, the value previously hard-coded).
        device: device string handed to GEARS (default 'cuda', the value
            previously hard-coded). Pass the notebook-level `device` to keep model
            placement consistent with where batches are moved.

    Returns:
        The GEARS wrapper; its torch module is `gears_model.model`.
    """
    gears_model = GEARS(pert_data, device=device,
                        weight_bias_track=False,
                        proj_name='gears',
                        exp_name='gears',
                        gi_predict=True)
    gears_model.model_initialize(hidden_size=hidden_size)
    return gears_model

    

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [16]:
# Load the processed dataset (simulation split for `seed`) and build a GEARS model on it.
pert_data, ctrl_mean, mean_df = load_data(seed)
gears_model = load_model(pert_data)

Found local copy...


These perturbations are not in the GO graph and is thus not able to make prediction for...
['AARS2+ABCB1' 'AARS2+ADPRM' 'AARS2+AGO4' ... 'ZRANB3+non-targeting'
 'ctrl+non-targeting' 'non-targeting+non-targeting']
Local copy of pyg dataset is detected. Loading...
Done!
Local copy of split is detected. Loading...
Simulation split test composition:
combo_seen0:7501
combo_seen1:44401
combo_seen2:16601
unseen_single:122
Done!
Creating dataloaders....
Done!


In [17]:
# Keep a handle on the GEARS torch module and turn on its cell-fitness flag.
# NOTE(review): presumably makes forward() return a fitness prediction as its second
# output (used as `y_pred` in training/evaluation) -- confirm in GEARS_misc.
# This rebinds `model`, which held the string 'gears' in the config cell.
model = gears_model.model
model.cell_fitness_pred = True

In [18]:
# AnnData held by the GEARS wrapper, and a view restricted to control ('ctrl') cells;
# control cells are sampled below as the inputs of every perturbation graph.
adata = gears_model.adata
ctrl_adata = adata[adata.obs['condition'] == 'ctrl']

In [19]:
# Number of perturbation names GEARS can index.
len(gears_model.pert_list)

9853

In [20]:
# Used to map perturbation gene names to GEARS perturbation indices when building cell graphs.
all_pert_list_options = gears_model.pert_list

In [21]:
import pickle
# Load the pairwise fitness data dict.
# NOTE: pickle.load can execute arbitrary code -- only load files from a trusted source.
with open('../../data/parpi/GI_gears_data_parpi.pkl', 'rb') as f:
    fitness_data = pickle.load(f)

In [22]:
# Inspect which entries the fitness pickle provides.
fitness_data.keys()

dict_keys(['good_phen', 'fitness_mapper_gene', 'good_genes'])

In [23]:
# Genes indexing the pairwise fitness matrix; used to add single-gene ('<gene>+ctrl') perturbations.
unique_pert_genes = fitness_data['good_phen'].index.values

In [24]:
from itertools import combinations
# Every unordered pair of fitness-matrix genes as a GEARS-style 'A+B' condition name,
# in matrix index order. (The training loop below rebuilds this list per run.)
pert_list = ['+'.join(gene_pair) for gene_pair in combinations(fitness_data['good_phen'].index.values, 2)]

In [25]:
def convert(i):
    """Collapse a self-pair key 'GENE_GENE' to 'GENE'; return any other key unchanged.

    Only the first two '_'-separated fields are compared, so e.g. 'A_A_B'
    also collapses to 'A'.
    """
    fields = i.split('_')
    is_self_pair = len(fields) > 1 and fields[0] == fields[1]
    return fields[0] if is_self_pair else i

In [26]:
import pandas as pd
import numpy as np
import pickle

from sklearn.preprocessing import PolynomialFeatures
from sklearn.linear_model import LinearRegression
from sklearn.pipeline import make_pipeline

class GI_expt():
    """Sample entries of a pairwise fitness matrix and derive train/test pairs.

    The pickle loaded at construction must provide 'good_phen' (gene x gene
    fitness DataFrame), 'fitness_mapper_gene' (name -> value mapping used as
    regression features) and 'good_genes' (genes scored by the expectation model).
    """
    # TODO get rid of hold out genes
    def __init__(self, raw=False, pick_genes=None, set_train=None,
                 data_file='../../data/parpi/GI_gears_data_parpi.pkl'):
        """
        Args:
            raw: stored as `self.raw`; not read by `pre_process`.
            pick_genes: optional positional indices of genes; when given, the
                sampled pairs are the strict upper triangle of that sub-matrix.
            set_train: optional explicit collection of (gene1, gene2) pairs used
                as the sampled pairs; takes precedence over `pick_genes`.
            data_file: path of the fitness pickle (default: the PARPi file that
                was previously hard-coded).
                NOTE: pickle.load can execute arbitrary code -- trusted files only.
        """
        with open(data_file, 'rb') as f:
            self.data = pickle.load(f)

        # Normalise mapper keys so self-pairs such as 'GENE_GENE' are looked up as 'GENE'.
        self.data['fitness_mapper_gene'] = {
            convert(k): v for k, v in self.data['fitness_mapper_gene'].items()}
        self.raw = raw
        self.pick_genes = pick_genes
        self.set_train = set_train

    def pre_process(self, sample, seed):
        """Choose the observed pairs and build the masked matrices.

        Args:
            sample: percentage (0-100) of upper-triangle entries (diagonal
                included) sampled at random; ignored when `set_train` or
                `pick_genes` is set.
            seed: random_state used for that sampling.

        Returns:
            dict with keys:
              'sample', 'seed': the arguments;
              'sampled_data': the selected (gene1, gene2) pairs;
              'delta', 'test', 'train', 'delta_tot', 'X_tot', 'y_tot',
              'transformer', 'raws': the outputs of get_masked_delta_matrix --
              deviations on sampled pairs, held-out and observed strict-upper
              pairs, deviations for all pairs, full features/targets, the fitted
              expectation model, and raw fitness values on sampled pairs.
        """
        itr_data = {'sample': sample, 'seed': seed}

        if self.set_train is not None:
            itr_data['sampled_data'] = self.set_train
        elif self.pick_genes is None:
            # Random sample of upper-triangle entries (diagonal included, k=0).
            itr_data['sampled_data'] = upper_triangle(self.data['good_phen'], k=0).sample(
                frac=float(sample) / 100., replace=False, random_state=seed).index
        else:
            # All strict-upper-triangle pairs among the picked genes.
            itr_data['sampled_data'] = upper_triangle(
                self.data['good_phen'].iloc[self.pick_genes, self.pick_genes]).index

        # Mask the fitness matrix to the sampled pairs and fit the expectation model.
        (itr_data['delta'], itr_data['test'], itr_data['train'], itr_data['delta_tot'],
         itr_data['X_tot'], itr_data['y_tot'], itr_data['transformer'],
         itr_data['raws']) = get_masked_delta_matrix(
            self.data['good_phen'], itr_data['sampled_data'],
            self.data['fitness_mapper_gene'], self.data['good_genes'])

        return itr_data


def get_masked_delta_matrix(data, sampling, phen_mapper, good_genes):
    """Mask `data` to the sampled pairs, fit the expectation model, and split pairs.

    Returns:
        (delta, predicted, given, delta_tot, X_tot, y_tot, transformer, raws):
        `given` / `predicted` are the strict-upper-triangle pairs whose masked
        value is non-zero / zero; `delta_tot` is the deviation matrix from
        get_deltas; `delta` and `raws` are the deviation and raw fitness
        matrices with every non-sampled pair set to 0; `transformer` is the
        fitted expectation model.
    """
    masked, _ = get_masked_data(data, sampling)

    upper = upper_triangle(masked)
    is_zero = upper == 0
    predicted = upper[is_zero].index
    given = upper[~is_zero].index

    delta_tot, _X, X_tot, _y, y_tot, transformer, raws_full = get_deltas(
        data, masked, phen_mapper, good_genes)
    delta = get_masked_data(delta_tot, sampling)[0]
    raws = get_masked_data(raws_full, sampling)[0]

    return delta, predicted, given, delta_tot, X_tot, y_tot, transformer, raws

def get_deltas(source_data, sampled_data, phen_mapper, good_genes):
    """Fit a quadratic expectation model on observed pairs and return deviations.

    Each pair (g1, g2) is featurised as [phen_mapper[g1], phen_mapper[g2]].
    A degree-2 polynomial regression is trained on the non-zero (observed)
    entries of `sampled_data`. It then scores every pair among `good_genes` in
    `source_data`.

    Returns:
        (delta, X, X_tot, y, y_tot, model, raws):
        delta  -- matrix of y_tot minus the model prediction;
        X, y   -- training features / targets (observed pairs);
        X_tot, y_tot -- features / targets for all `good_genes` pairs;
        model  -- the fitted sklearn pipeline;
        raws   -- the raw `good_genes` fitness matrix.
    Raises:
        KeyError if a gene is missing from `phen_mapper`.
    """
    def _pair_features(pairs):
        # One row per pair: [mapper value of gene 1, mapper value of gene 2].
        # Explicit lookup (not Index.map(dict)) so an unmapped gene raises KeyError.
        first = pairs.get_level_values(0).map(lambda g: phen_mapper[g]).values
        second = pairs.get_level_values(1).map(lambda g: phen_mapper[g]).values
        return np.column_stack([first, second])

    # Observed (non-zero) entries of the masked matrix are the training targets.
    y = sampled_data.stack()
    y = y[y != 0]
    X = _pair_features(y.index)

    # Every pair among `good_genes` is scored against the fitted expectation.
    y_tot = source_data.loc[good_genes, good_genes].stack()
    X_tot = _pair_features(y_tot.index)

    # PolynomialFeatures defaults to degree 2 -> quadratic expectation model.
    model = make_pipeline(PolynomialFeatures(), LinearRegression())
    model.fit(X, y)

    delta = pd.Series(y_tot - model.predict(X_tot), index=y_tot.index).unstack()
    raws = pd.Series(y_tot, index=y_tot.index).unstack()

    return delta, X, X_tot, y, y_tot, model, raws

def upper_triangle(M, k=1):
    """ Copyright (C) 2019  Thomas Norman
    Return the upper triangular part of a matrix in stacked format (i.e. as a vector)

    Entries on or above the k-th diagonal are kept (k=1 excludes the main
    diagonal, k=0 includes it), in row-major order, with NaN entries retained.
    """
    in_upper = np.triu(np.ones(M.shape, dtype=bool), k=k).ravel()
    stacked = M.stack(dropna=False)
    return stacked.loc[in_upper]

def get_masked_data(df, ind, mean_normalize=False):
    """Zero every entry of `df` except the listed pairs and their transposes.

    Args:
        df: square DataFrame whose index and columns carry the same labels.
        ind: iterable of (row_label, col_label) pairs to keep.
        mean_normalize: if True, subtract the mean of the non-zero kept entries
            from those entries (the offset is printed).

    Returns:
        (masked_df, keep): DataFrame shaped like `df` with unlisted entries set
        to 0, and the boolean numpy array of kept positions.
    """
    keep = pd.DataFrame(0, index=df.index, columns=df.columns)
    for row_label, col_label in ind:
        keep.loc[row_label, col_label] = 1

    # Treat the matrix as symmetric: keeping (a, b) also keeps (b, a).
    keep = ((keep + keep.T) != 0).values

    values = df.copy().values
    values[~keep] = 0
    masked_df = pd.DataFrame(values, index=df.index, columns=df.columns)

    # Optionally centre the observed (non-zero) entries so their mean is 0.
    if mean_normalize:
        stacked = masked_df.stack()
        nonzero = stacked != 0
        offset = stacked[nonzero].mean()
        stacked[nonzero] = stacked[nonzero] - offset
        print(offset)
        masked_df = stacked.unstack()

    return masked_df, keep

In [27]:
# Per-cell metadata of the processed dataset (condition, dose, control flag, condition name).
pert_data.adata.obs

Unnamed: 0,condition,dose_val,control,condition_name
0,AARS2+AARS2,1+1,1,K562_AARS2+AARS2_1+1
1,AARS2+AATF,1+1,1,K562_AARS2+AATF_1+1
3,AARS2+ABL1,1+1,1,K562_AARS2+ABL1_1+1
7,AARS2+ALDOA,1+1,1,K562_AARS2+ALDOA_1+1
8,AARS2+ALG1,1+1,1,K562_AARS2+ALG1_1+1
...,...,...,...,...
1482105,ctrl,1+1,1,K562_ctrl_1+1
1482106,ctrl,1+1,1,K562_ctrl_1+1
1482107,ctrl,1+1,1,K562_ctrl_1+1
1482108,ctrl,1+1,1,K562_ctrl_1+1


In [28]:
# Raw mapper keyed by '<gene>_<gene>' strings; GI_expt normalises the keys with convert().
fitness_data['fitness_mapper_gene']

{'AARS2_AARS2': -0.21110421783207559,
 'AARS2_AATF': -0.19928091580549664,
 'AARS2_ABCB1': -0.16457389164352698,
 'AARS2_ABL1': -0.1728702643023251,
 'AARS2_ADPRM': -0.2157649750682031,
 'AARS2_AGO4': -0.1627904080603406,
 'AARS2_AKT3': -0.15613644766089865,
 'AARS2_ALDOA': -0.2145913171668197,
 'AARS2_ALG1': -0.16394647804314388,
 'AARS2_ANAPC2': -0.17062152880479992,
 'AARS2_ARID1A': -0.1307264842487982,
 'AARS2_ASCC3': -0.2715723757679105,
 'AARS2_ATM': -0.18045198983752875,
 'AARS2_ATMIN': -0.15580137515579637,
 'AARS2_ATP5J2': -0.18731769433154413,
 'AARS2_ATP6V0C': -0.13793162861618397,
 'AARS2_ATP6V1A': -0.2019625842612026,
 'AARS2_ATR': -0.2456256377087611,
 'AARS2_ATRIP': -0.17353456892355446,
 'AARS2_ATXN10': -0.324133404589846,
 'AARS2_AUNIP': -0.14997010186888565,
 'AARS2_AURKA': -0.19192602248204443,
 'AARS2_BABAM1': -0.16960604328682916,
 'AARS2_BAP1': -0.19862459790163253,
 'AARS2_BARD1': -0.26663343060967093,
 'AARS2_BIN1': -0.16043326971789168,
 'AARS2_BLM': -0.2017621

In [15]:
# Run prediction
from torch_geometric.data import DataLoader

import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from torch.optim.lr_scheduler import StepLR
from copy import deepcopy

def evaluate(model, loader):
    """Run `model` over `loader` without gradient tracking and collect outputs.

    Uses the notebook globals `device` (where each batch is moved) and
    `loss_fct` (called as loss_fct(target, prediction) per batch).

    Returns:
        (y_all, loss_all, pert_all, y_true_all): per-batch predictions as numpy
        arrays, per-batch loss values, per-batch `pert` attributes (lists of
        perturbed genes), and per-batch target tensors.
    """
    # Local imports: do not depend on these names being bound by a later cell.
    import torch
    from tqdm import tqdm

    model.eval()
    y_all, loss_all, pert_all, y_true_all = [], [], [], []
    # Evaluation never backpropagates, so skip building autograd graphs.
    with torch.no_grad():
        for batch in tqdm(loader):
            batch = batch.to(device)
            _, y_pred = model(batch)
            # Cast targets with .float() exactly as the training loop does.
            loss = loss_fct(batch.y.float().reshape(-1), y_pred.reshape(-1))
            loss_all.append(loss.item())
            y_all.append(y_pred.detach().cpu().numpy())
            pert_all.append(batch.pert)
            y_true_all.append(batch.y)
    return y_all, loss_all, pert_all, y_true_all

def get_dataloader(set2cond, batch_size, test_batch_size = None, cell_graphs_by_pert = None):
    """Build train/val/test PyG DataLoaders from per-perturbation cell graphs.

    Args:
        set2cond: dict with keys 'train', 'val', 'test', each an iterable of
            perturbation names.
        batch_size: batch size of the train and val loaders.
        test_batch_size: batch size of the test loader (defaults to `batch_size`).
        cell_graphs_by_pert: mapping perturbation name -> list of graphs;
            defaults to the notebook-global `all_cell_graphs`.

    Perturbations absent from the mapping are skipped.

    Returns:
        dict with 'train_loader' (shuffled, last partial batch dropped),
        'val_loader' (shuffled) and 'test_loader' (original order).
    """
    if test_batch_size is None:
        test_batch_size = batch_size
    if cell_graphs_by_pert is None:
        cell_graphs_by_pert = all_cell_graphs

    cell_graphs = {
        split: [graph
                for pert in set2cond[split] if pert in cell_graphs_by_pert
                for graph in cell_graphs_by_pert[pert]]
        for split in ('train', 'val', 'test')
    }

    train_loader = DataLoader(cell_graphs['train'],
                              batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(cell_graphs['val'],
                            batch_size=batch_size, shuffle=True)
    # Honour test_batch_size (previously computed but not passed here).
    test_loader = DataLoader(cell_graphs['test'],
                             batch_size=test_batch_size, shuffle=False)
    return {'train_loader': train_loader,
            'val_loader': val_loader,
            'test_loader': test_loader}

# Train a freshly initialised GEARS fitness model for every (seed, n_sample) setting
# and write its per-perturbation test predictions to `results_folder`.
from sklearn.model_selection import train_test_split
from torch_geometric.data import Data
import torch
from tqdm import tqdm

# Presumably the number of upper-triangle entries (diagonal included) of the fitness
# matrix -- equals len(pert_list) below; converts a pair count into a sampling percentage.
n_train = 148240
n_samples = [100, 500, 1000, 2000, 3000, 4000, 5000, 6000]
num_samples = 20  # control cells drawn per perturbation to build its input graphs

# Training loss; `evaluate` reads it as a global.
loss_fct = F.mse_loss

# GEARS index of each perturbable gene (first occurrence, as np.where(...)[0][0] gave),
# built once so each lookup is O(1) instead of a scan of the whole list.
pert2idx = {}
for idx, gene in enumerate(all_pert_list_options):
    pert2idx.setdefault(gene, idx)


def _build_cell_graphs(perts, pert2y, ctrl_cells, rng):
    """Return ({pert: [Data, ...]}, skipped): `num_samples` control-cell graphs per pert.

    Perturbations containing a gene unknown to GEARS are listed in `skipped`.
    """
    graphs, skipped = {}, []
    for pert_name in tqdm(perts):
        genes = [g for g in pert_name.split('+') if g != 'ctrl']
        if not all(g in pert2idx for g in genes):
            skipped.append(pert_name)
            continue
        pert_idx = [pert2idx[g] for g in genes]
        Xs = ctrl_cells[rng.integers(0, len(ctrl_cells), num_samples), :].X.toarray()
        graphs[pert_name] = [Data(x=torch.Tensor(X).T, pert_idx=pert_idx, pert=genes,
                                  y=pert2y[pert_name]) for X in Xs]
    return graphs, skipped


def _summarize_predictions(y_pred_batches, pert_batches, y_true_batches):
    """Average cell-level targets/predictions per perturbation -> columns test_pert, truth, pred."""
    perts = [p[0] + '+' + p[1] if len(p) == 2 else p[0] + '+ctrl'
             for batch in pert_batches for p in batch]
    y_pred = np.concatenate([np.asarray(b).reshape(-1) for b in y_pred_batches])
    y_true = np.concatenate([t.detach().cpu().numpy().reshape(-1) for t in y_true_batches])
    cell_level = pd.DataFrame({'test_pert': perts, 'truth': y_true, 'pred': y_pred})
    return cell_level.groupby('test_pert')[['truth', 'pred']].mean().reset_index()


for seed in [618, 748, 808]:  # earlier runs also used seeds 123 and 847
    for n_sample in n_samples:
        for label in ['raw_y']:  # 'delta_tot' is the other supported target
            sample_frac = float(n_sample) / float(n_train) * 100

            # Seed model init + DataLoader shuffling (torch) and control-cell sampling (rng).
            torch.manual_seed(seed)
            rng = np.random.default_rng(seed)

            GI = GI_expt(pick_genes=None)
            itr_data = GI.pre_process(sample=sample_frac, seed=seed)

            itr_data['raw_y'] = fitness_data['good_phen']
            y = itr_data[label].values
            genes = itr_data[label].index.values
            gene2idx = {gene: i for i, gene in enumerate(genes)}

            # All gene pairs plus single-gene ('<gene>+ctrl') perturbations.
            pert_list = [a + '+' + b for a, b in combinations(genes, 2)]
            pert_list += [g + '+ctrl' for g in unique_pert_genes]

            # Target: the matrix entry for a pair, the diagonal entry for a single gene.
            pert_list2y_list = {}
            for pert_name in pert_list:
                parts = pert_name.split('+')
                first = parts[0]
                second = first if parts[1] == 'ctrl' else parts[1]
                pert_list2y_list[pert_name] = y[gene2idx[first], gene2idx[second]]

            train = [a + '+' + b for a, b in itr_data['train']]
            test = [a + '+' + b for a, b in itr_data['test']]
            # 10% of the observed pairs held out as validation (unused: no early stopping).
            train, val = train_test_split(train, test_size=0.1, random_state=42)
            # Single-gene perturbations are always part of training.
            train = train + [g + '+ctrl' for g in unique_pert_genes]
            set2cond = {'train': train, 'val': val, 'test': test}

            # Module-level name: get_dataloader reads `all_cell_graphs` by default.
            all_cell_graphs, pert_na = _build_cell_graphs(pert_list, pert_list2y_list, ctrl_adata, rng)
            if pert_na:
                print(f'Skipped {len(pert_na)} perturbations with genes missing from the GEARS perturbation list')

            loaders = get_dataloader(set2cond, batch_size = 32)
            train_loader = loaders['train_loader']
            val_loader = loaders['val_loader']
            test_loader = loaders['test_loader']

            # Re-initialise for every run. Reusing the module trained in the previous
            # iteration would carry weights fit on pairs that may be in this run's
            # test set (leakage), making results depend on loop order.
            gears_model.model_initialize(hidden_size = 64)
            model = gears_model.model
            model.cell_fitness_pred = True

            optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay = 5e-4)
            loss_history = {'loss': []}

            print('Start Training...')
            for epoch in range(1):
                model.train()
                for step, batch in enumerate(tqdm(train_loader)):
                    optimizer.zero_grad()
                    batch = batch.to(device)
                    out, y_pred = model(batch)
                    loss = loss_fct(batch.y.float().reshape(-1), y_pred.reshape(-1))
                    loss.backward()
                    optimizer.step()
                    loss_history['loss'].append(loss.item())

                    if (step % 200 == 0) and (step >= 200):
                        log = "Epoch {} Step {} Train Loss: {:.4f}"
                        print(log.format(epoch + 1, step + 1, loss.item()))

            # No early stopping: evaluate the model as it stands after the last epoch.
            best_model = deepcopy(model)
            test_y_all, test_loss_all, test_pert_all, test_y_true_all = evaluate(best_model, test_loader)
            print(np.mean(test_loss_all))

            pred = _summarize_predictions(test_y_all, test_pert_all, test_y_true_all)
            pred.to_csv(results_folder + label + '_pred_no_pretrain_seed' + str(seed) + '_nsample' + str(n_sample) + 'parpi_gi.csv', index = False)

100%|██████████| 148240/148240 [16:08<00:00, 153.09it/s] 


Start Training...


  4%|▍         | 206/4906 [00:06<01:46, 44.28it/s]

Epoch 1 Step 201 Train Loss: 0.5348


  8%|▊         | 406/4906 [00:10<01:41, 44.19it/s]

Epoch 1 Step 401 Train Loss: 0.3379


 12%|█▏        | 606/4906 [00:15<01:35, 45.15it/s]

Epoch 1 Step 601 Train Loss: 0.1730


 16%|█▋        | 806/4906 [00:19<01:31, 44.75it/s]

Epoch 1 Step 801 Train Loss: 0.0821


 21%|██        | 1006/4906 [00:24<01:26, 45.02it/s]

Epoch 1 Step 1001 Train Loss: 0.0401


 25%|██▍       | 1206/4906 [00:28<01:22, 44.88it/s]

Epoch 1 Step 1201 Train Loss: 0.0159


 29%|██▊       | 1406/4906 [00:32<01:18, 44.69it/s]

Epoch 1 Step 1401 Train Loss: 0.0105


 33%|███▎      | 1606/4906 [00:37<01:14, 44.58it/s]

Epoch 1 Step 1601 Train Loss: 0.0028


 37%|███▋      | 1806/4906 [00:42<01:10, 44.19it/s]

Epoch 1 Step 1801 Train Loss: 0.0033


 41%|████      | 2006/4906 [00:46<01:06, 43.47it/s]

Epoch 1 Step 2001 Train Loss: 0.0048


 45%|████▍     | 2206/4906 [00:51<01:02, 43.49it/s]

Epoch 1 Step 2201 Train Loss: 0.0017


 49%|████▉     | 2406/4906 [00:55<00:57, 43.67it/s]

Epoch 1 Step 2401 Train Loss: 0.0018


 53%|█████▎    | 2606/4906 [01:00<00:52, 44.15it/s]

Epoch 1 Step 2601 Train Loss: 0.0017


 57%|█████▋    | 2806/4906 [01:04<00:47, 44.28it/s]

Epoch 1 Step 2801 Train Loss: 0.0016


 61%|██████▏   | 3006/4906 [01:09<00:42, 44.43it/s]

Epoch 1 Step 3001 Train Loss: 0.0020


 65%|██████▌   | 3206/4906 [01:13<00:38, 44.16it/s]

Epoch 1 Step 3201 Train Loss: 0.0021


 69%|██████▉   | 3406/4906 [01:18<00:33, 44.64it/s]

Epoch 1 Step 3401 Train Loss: 0.0025


 74%|███████▎  | 3606/4906 [01:22<00:29, 44.01it/s]

Epoch 1 Step 3601 Train Loss: 0.0012


 78%|███████▊  | 3806/4906 [01:27<00:25, 43.75it/s]

Epoch 1 Step 3801 Train Loss: 0.0027


 82%|████████▏ | 4006/4906 [01:32<00:20, 43.78it/s]

Epoch 1 Step 4001 Train Loss: 0.0025


 86%|████████▌ | 4206/4906 [01:36<00:16, 43.37it/s]

Epoch 1 Step 4201 Train Loss: 0.0013


 90%|████████▉ | 4406/4906 [01:41<00:11, 43.57it/s]

Epoch 1 Step 4401 Train Loss: 0.0019


 94%|█████████▍| 4606/4906 [01:45<00:06, 43.88it/s]

Epoch 1 Step 4601 Train Loss: 0.0018


 98%|█████████▊| 4806/4906 [01:50<00:02, 43.60it/s]

Epoch 1 Step 4801 Train Loss: 0.0016


100%|██████████| 4906/4906 [01:52<00:00, 43.56it/s]
100%|██████████| 364525/364525 [1:06:03<00:00, 91.97it/s]


0.0015829641314039332


 89%|████████▉ | 131804/148240 [13:59<01:44, 157.59it/s]  

In [16]:
# Squared Pearson correlation between per-perturbation true and predicted fitness.
# NOTE: `pred` holds only the most recent (seed, n_sample, label) run of the training loop.
r_2 = pearsonr(pred['truth'], pred['pred'])[0]**2
r_2

0.7759029964347355