In [1]:
import torch
import torch.nn as nn
import numpy as np
import datetime
import pytz
import pickle
import random

from torch.utils.data import Dataset
from torch.utils.data import DataLoader, TensorDataset

import numpy as np
import pandas as pd

# Make float64 the default floating-point dtype for tensors and parameters created after this point.
# NOTE(review): presumably chosen so the model matches the float64 arrays coming from pandas/NumPy — confirm.
torch.set_default_dtype(torch.float64)

In [2]:
class Generator(nn.Module):
    """Conditional generator: maps (expression vector, condition vector) to an expression vector.

    The network is a fully connected stack whose hidden widths go
    10240 -> 5120 -> 2560 -> 2560 -> 5120 -> 10240. Every hidden Linear layer
    except the first is followed by BatchNorm1d, and every Linear layer
    (including the output layer) is followed by LeakyReLU.

    Args:
        input_dim: Width of the concatenated ``(x, cond)`` input.
        output_dim: Width of the generated output vector.

    NOTE(review): ``momentum=.9`` is passed to BatchNorm1d. In PyTorch the
    momentum is the weight given to the *current* batch statistic when the
    running estimates are updated, so 0.9 makes the running statistics track
    the latest batch closely. Confirm this is intended and not carried over
    from a framework that uses the opposite convention.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        # Hidden layer widths, in order. The module order below must stay
        # exactly as is so saved checkpoints keep matching `generator_stack.<i>`.
        widths = (10240, 5120, 2560, 2560, 5120, 10240)

        # Input projection: no batch norm on the first layer.
        layers = [nn.Linear(input_dim, widths[0]), nn.LeakyReLU()]

        # Hidden blocks: Linear -> BatchNorm1d -> LeakyReLU.
        for n_in, n_out in zip(widths[:-1], widths[1:]):
            layers.extend([
                nn.Linear(n_in, n_out),
                nn.BatchNorm1d(n_out, momentum=.9),
                nn.LeakyReLU(),
            ])

        # Output projection, also passed through LeakyReLU.
        layers.extend([nn.Linear(widths[-1], output_dim), nn.LeakyReLU()])

        self.generator_stack = nn.Sequential(*layers)

    def forward(self, x, cond):
        """Concatenate ``x`` and ``cond`` along the last dimension and run the stack."""
        joined = torch.cat((x, cond), dim=-1)
        return self.generator_stack(joined)

In [3]:
# Length of the knockout-gene embedding vector (presumably the row width of gene_embeddings.csv — TODO confirm).
GENE_EMBED_DIM = 150
# Number of genes in one expression vector (the 15077-gene list).
GENE_EXPRESSION_VEC = 15077
# NOTE(review): in the visible cells BATCH_SIZE is only referenced by the commented-out DataLoader code.
BATCH_SIZE = 1

# Generator input width: an expression vector concatenated with a gene embedding (see Generator.forward).
# NOTE(review): input_dim/output_dim are not used by the visible cells, which load a saved model instead.
input_dim = GENE_EXPRESSION_VEC+GENE_EMBED_DIM
# Generator output width: one expression vector.
output_dim = GENE_EXPRESSION_VEC

In [4]:
#def get_data(train_loader):
   
    #for index, batch_inputs in enumerate(train_loader):
     #   KO_gene, batch_X, batch_y = batch_inputs
        
    #return KO_gene, batch_X, batch_y


#train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
#KO_gene, batch_X, batch_y = get_data(train_loader)


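Generate data:
- For every gene in the list of 15077 genes:
     - Get the gene embedding.
     - Choose `n_cells` random unperturbed cells (this outline originally assumed 1000; the notebook currently sets `n_cells = 100`).
- Feed (unperturbed cells, KO gene embedding) into the generator.
- Return `n_cells` generated transcripts of perturbed cells.
- Run the generated transcripts through the SVM classifier.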

In [7]:
def dataloader_mini(KO_gene, gene_emb, gene_emb_label, unperturbed_expr_df, n_cells):
    """Look up a knockout gene's embedding and sample unperturbed cells.

    Args:
        KO_gene: Gene name. It is upper-cased before the lookup in ``gene_emb_label``.
        gene_emb: DataFrame of embeddings, indexed positionally (``.iloc``).
        gene_emb_label: Mapping from upper-cased gene name to row position in ``gene_emb``.
        unperturbed_expr_df: DataFrame with one row per unperturbed cell.
        n_cells: Number of distinct rows to sample from ``unperturbed_expr_df``.

    Returns:
        Tuple ``(KO_gene, unpert_input, KO_gene_emb)`` where ``unpert_input`` is an
        array of the ``n_cells`` sampled rows and ``KO_gene_emb`` is the embedding
        row as an array.

    Raises:
        KeyError: If the upper-cased gene name is not in ``gene_emb_label``.
        ValueError: If ``n_cells`` exceeds the number of rows in ``unperturbed_expr_df``.
    """
    # Row position of this gene in the embedding table.
    emb_idx = gene_emb_label[KO_gene.upper()]
    KO_gene_emb = np.array(gene_emb.iloc[emb_idx])

    # Sample without replacement from *all* rows. The previous
    # range(0, len(df) - 1) silently excluded the last cell.
    rand_index = random.sample(range(len(unperturbed_expr_df)), n_cells)
    unpert_input = np.array(unperturbed_expr_df.iloc[rand_index])

    return KO_gene, unpert_input, KO_gene_emb

In [8]:
def generate_data(unpert_input, KO_gene_emb, model):
    """Run the generator in inference mode on a batch of cells.

    Args:
        unpert_input: Tensor passed to the model as its first argument.
        KO_gene_emb: Tensor passed to the model as its second (conditioning) argument.
        model: Module called as ``model(unpert_input, KO_gene_emb)``.

    Returns:
        The model output. It carries no autograd history.

    Side effects:
        Switches ``model`` to evaluation mode (``model.eval()``) and leaves it there.
    """
    model.eval()
    # Inference only: skip autograd bookkeeping so no graph is kept alive
    # for the (very wide) generator layers.
    with torch.no_grad():
        generated_data = model(unpert_input, KO_gene_emb)
    return generated_data

In [9]:
def classify_state_SVM(generated_data, gene_names=None, svm_gene_names=None, svm_model=None):
    """Predict a state label for each generated cell with the saved SVM.

    The generated tensor is labelled with gene names, reduced to the columns
    the SVM expects, and passed to ``svm_model.predict``.

    Args:
        generated_data: 2-D tensor with one row per generated cell and one
            column per entry of ``gene_names``.
        gene_names: Column labels for ``generated_data``. Defaults to the header
            of ``./data/unperturbed_gene_expression.csv``.
        svm_gene_names: Columns to feed to the SVM, in order. Defaults to the
            header of ``./data/unperturbed_filtered.csv``.
        svm_model: Object with a ``predict`` method. Defaults to the pickle at
            ``../saved_models/svc_model_unperturbed.sav``.

    Returns:
        Whatever ``svm_model.predict`` returns for the filtered frame.

    Pass the three optional arguments when calling this in a loop so the CSV
    headers and the pickled model are not reloaded from disk on every call.
    """
    if gene_names is None:
        # Header-only read: just the gene names for the generator's output columns.
        gene_names = pd.read_csv('./data/unperturbed_gene_expression.csv', nrows = 0).columns
    if svm_gene_names is None:
        # Header-only read: the feature columns the SVM expects.
        svm_gene_names = pd.read_csv('./data/unperturbed_filtered.csv', nrows = 0).columns
    if svm_model is None:
        # SECURITY: pickle.load can execute arbitrary code; only load trusted model files.
        # NOTE(review): this path is '../saved_models' while the generator checkpoint
        # elsewhere in the notebook is loaded from './saved_models' — confirm which is right.
        with open('../saved_models/svc_model_unperturbed.sav', 'rb') as model_file:
            svm_model = pickle.load(model_file)

    # Label the generated transcripts so they can be filtered by gene name.
    # .cpu() is a no-op for CPU tensors and makes .numpy() safe for device tensors.
    cell = pd.DataFrame(generated_data.detach().cpu().numpy(), columns = gene_names)

    # Keep only the genes the SVM uses to predict T cell state.
    SVM_input = cell[svm_gene_names]

    return svm_model.predict(SVM_input)

In [10]:
# Load the unperturbed expression matrix once and take the gene names from its
# header, instead of parsing the same CSV a second time just for the columns.
unperturbed_expr_df = pd.read_csv('./data/unperturbed_gene_expression.csv')
gene_list = unperturbed_expr_df.columns.values

# Embedding table, addressed by row position in dataloader_mini.
gene_emb = pd.read_csv('./data/gene_embeddings.csv')

# Mapping from gene name to row position in gene_emb, stored as a pickled dict
# inside a 0-d object array (hence .item()).
# SECURITY: allow_pickle=True can execute arbitrary code; only load trusted files.
gene_emb_label = np.load('./gene_embedding/entity2labeldict_TransE_nepochs_128.npy', allow_pickle = True).item()

In [11]:
# Number of unperturbed cells sampled, and therefore perturbed cells generated, per knockout gene.
n_cells = 100

In [21]:
KO_gene_list = []        # every gene attempted, in order
KO_gene_list_found = []  # genes that produced a prediction; row labels for pred_list
pred_list = []           # one array of predicted states per entry of KO_gene_list_found

# NOTE(review): torch.load unpickles the saved object (presumably a full Generator
# instance), so only load trusted checkpoints. The SVM is loaded from
# '../saved_models' while this uses './saved_models' — confirm which path is right.
model = torch.load("./saved_models/Model_G_300.pth") 

# NOTE(review): [80:88] looks like a small slice for a quick trial run — widen for the full gene list.
for gene in gene_list[80:88]:

    print(gene)
    KO_gene_list.append(gene)

    # Look up the embedding and sample cells. Only a failed lookup is treated
    # as "gene not available" and skipped. Any other error is a real failure
    # and must surface instead of being swallowed by a bare except.
    try:
        KO_gene, unpert_input, KO_gene_emb = dataloader_mini(
            KO_gene = gene,
            gene_emb = gene_emb,
            gene_emb_label = gene_emb_label,
            unperturbed_expr_df = unperturbed_expr_df,
            n_cells = n_cells,
        )
    except (KeyError, IndexError):
        # No embedding entry (or no matching embedding row) for this gene.
        continue

    # Repeat the gene embedding once per sampled cell; batch dimension comes first.
    KO_gene_tensor = torch.tensor(np.tile(KO_gene_emb, (n_cells, 1)))

    # Generate perturbed transcripts with the GAN generator.
    generated_data = generate_data(torch.tensor(unpert_input), KO_gene_tensor, model)

    # Classify the state of each generated cell with the SVM.
    pred = classify_state_SVM(generated_data)

    # Record the gene and its predictions together so the two lists can never
    # fall out of step.
    KO_gene_list_found.append(gene)
    pred_list.append(pred)
    
    

D430040D24Rik
Fam178b
Cox5b
Actr1b
4933424G06Rik
Gm33533
Zap70
Tmem131


In [24]:
# Every gene the loop attempted, whether or not a prediction was produced for it.
KO_gene_list #TODO: how to combine genes not found and genes found in embeddings.

['D430040D24Rik',
 'Fam178b',
 'Cox5b',
 'Actr1b',
 '4933424G06Rik',
 'Gm33533',
 'Zap70',
 'Tmem131']

In [22]:
# One row per gene in KO_gene_list_found, one column per generated cell, each value a predicted state.
# Relies on pred_list and KO_gene_list_found having the same length and order.
pd.DataFrame(pred_list, index = KO_gene_list_found)
#TODO: get cell state proportions (sum(t_cell_state)/n_cells)
#TODO: produce csv in correct format

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,90,91,92,93,94,95,96,97,98,99
Fam178b,effector,effector,effector,effector,effector,effector,effector,effector,effector,effector,...,effector,effector,effector,effector,effector,effector,effector,effector,effector,effector
Actr1b,effector,effector,effector,effector,effector,effector,effector,effector,effector,effector,...,effector,effector,effector,effector,effector,effector,effector,effector,effector,effector
Zap70,effector,effector,effector,effector,effector,effector,effector,effector,effector,effector,...,effector,effector,effector,effector,effector,effector,effector,effector,effector,effector
Tmem131,effector,effector,effector,effector,effector,effector,effector,effector,effector,effector,...,effector,effector,effector,effector,effector,effector,effector,effector,effector,effector


### SCRATCH WORK: Prepare generated data for input into SVM
Note that the SVM was trained on filtered data (low-variance genes were removed), so its input must contain exactly the same features (gene columns) it was trained on.

In [17]:
# Gene names for the 15077 generator output columns. Only the header is
# needed, so read zero data rows.
columns = pd.read_csv('./data/unperturbed_gene_expression.csv', nrows = 0).columns

# Gene names the SVM expects. Again header only, rather than parsing the whole
# filtered CSV just to get its column names.
SVM_filt = pd.read_csv('./data/unperturbed_filtered.csv', nrows = 0).columns

# Turn the generated transcripts into a DataFrame so they can be filtered by gene name.
# NOTE: uses `generated_data` left in the kernel by the generation loop (last gene processed).
cell = pd.DataFrame(generated_data.detach().numpy(), columns = columns)

# Keep just the genes that the SVM uses to predict T cell state.
SVM_input = cell[SVM_filt] 

In [18]:
# Load the pickled SVM with a context manager so the file handle is closed promptly.
# SECURITY: pickle.load can execute arbitrary code; only load trusted model files.
with open('../saved_models/svc_model_unperturbed.sav', 'rb') as svm_file:
    loaded_model = pickle.load(svm_file)

# Predict a state label for each row of the filtered generated data.
preds = loaded_model.predict(SVM_input)
#print('Input is a(n) ' + str(preds[0]) + ' cell')

In [19]:
# Display the labels predicted by the scratch SVM run.
preds

array(['effector', 'effector', 'effector', 'effector', 'effector',
       'effector', 'effector', 'effector', 'effector', 'effector'],
      dtype=object)