In [1]:
import torch
import torch.nn as nn
import numpy as np
import datetime
import pytz
import pickle
import random

from torch.utils.data import Dataset
from torch.utils.data import DataLoader, TensorDataset

import numpy as np
import pandas as pd

torch.set_default_dtype(torch.float64)

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class Generator(nn.Module):
    """Fully connected conditional generator.

    ``forward`` concatenates an input batch with a conditioning batch along the
    last axis and pushes the result through ``generator_stack``: an MLP whose
    layer widths go input_dim -> 10240 -> 5120 -> 2560 -> 2560 -> 5120 ->
    10240 -> output_dim, with LeakyReLU after every linear layer and
    BatchNorm1d on the five inner hidden layers.

    Args:
        input_dim: Width of the concatenated ``(x, cond)`` vector.
        output_dim: Width of the generated vector.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()

        # (in_features, out_features) of the batch-normalised hidden blocks.
        hidden_shapes = [
            (10240, 5120),
            (5120, 2560),
            (2560, 2560),
            (2560, 5120),
            (5120, 10240),
        ]

        layers = [nn.Linear(input_dim, 10240), nn.LeakyReLU()]
        for fan_in, fan_out in hidden_shapes:
            layers += [
                nn.Linear(fan_in, fan_out),
                # NOTE(review): in PyTorch, momentum is the weight given to the
                # *new* batch statistic, so 0.9 makes the running stats follow
                # the latest batch closely -- confirm this is intended.
                nn.BatchNorm1d(fan_out, momentum=.9),
                nn.LeakyReLU(),
            ]
        layers += [nn.Linear(10240, output_dim), nn.LeakyReLU()]

        # Attribute name and layer order are kept as-is: saved checkpoints
        # address the layers by their position inside ``generator_stack``.
        self.generator_stack = nn.Sequential(*layers)

    def forward(self, x, cond):
        """Return the generator output for ``x`` conditioned on ``cond``.

        ``x`` and ``cond`` are joined along the last axis, so their combined
        width must equal the ``input_dim`` the network was built with.
        """
        joined = torch.cat((x, cond), dim=-1)
        return self.generator_stack(joined)

In [3]:
# Width of one gene-embedding vector (used as the generator's conditioning input).
GENE_EMBED_DIM = 150
# Number of genes in one expression vector.
GENE_EXPRESSION_VEC = 15077
# NOTE(review): not referenced by any other visible cell -- presumably left over
# from training; confirm before relying on it.
BATCH_SIZE = 10

# Expression vector concatenated with the gene embedding.
# Presumably the sizes the saved generator was built with -- TODO confirm.
input_dim = GENE_EXPRESSION_VEC+GENE_EMBED_DIM
output_dim = GENE_EXPRESSION_VEC

In [4]:
#def get_data(train_loader):
   
    #for index, batch_inputs in enumerate(train_loader):
     #   KO_gene, batch_X, batch_y = batch_inputs
        
    #return KO_gene, batch_X, batch_y


#train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
#KO_gene, batch_X, batch_y = get_data(train_loader)


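Generate data:
- For every gene in the list of 15077 genes:
     - Get the gene embedding of the knocked-out (KO) gene
     - Choose `n_cells` random unperturbed cells (originally planned as 1000; the notebook currently sets `n_cells = 100`)
- Feed (unperturbed cells, KO gene embedding) into the generator
- The generator returns one perturbed-cell transcript per input cell
- Run the generated transcripts through the SVM classifier to predict each cell's state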

In [5]:
def dataloader_mini(KO_gene, gene_emb, gene_emb_label, unperturbed_expr_df, n_cells):
    """Build one generator input: a KO-gene embedding plus random unperturbed cells.

    Args:
        KO_gene: Name of the knocked-out gene; looked up upper-cased in
            ``gene_emb_label``.
        gene_emb: DataFrame of embeddings, addressed by row position.
        gene_emb_label: Mapping from upper-cased gene name to the row position
            of that gene's embedding in ``gene_emb``.
        unperturbed_expr_df: DataFrame whose rows are sampled as cells.
        n_cells: Number of distinct rows to draw.

    Returns:
        ``(KO_gene, unpert_input, KO_gene_emb)`` where ``unpert_input`` is an
        array holding the ``n_cells`` sampled rows and ``KO_gene_emb`` is the
        embedding row as an array.

    Raises:
        KeyError: If the upper-cased gene name is not in ``gene_emb_label``.
        ValueError: If ``n_cells`` exceeds the number of available rows.
    """
    # Row position of this gene's embedding, then the embedding itself.
    emb_idx = gene_emb_label[KO_gene.upper()]
    KO_gene_emb = np.array(gene_emb.iloc[emb_idx])

    # Sample without replacement from *every* row. The previous upper bound of
    # len(df) - 1 silently excluded the last cell from ever being drawn.
    rand_index = random.sample(range(len(unperturbed_expr_df)), n_cells)
    unpert_input = np.array(unperturbed_expr_df.iloc[rand_index])

    return KO_gene, unpert_input, KO_gene_emb

In [6]:
def generate_data(unpert_input, KO_gene_emb, model):
    """Run the generator in inference mode.

    Args:
        unpert_input: Batch passed to the model as its first argument.
        KO_gene_emb: Batch passed to the model as the conditioning argument.
        model: Module called as ``model(unpert_input, KO_gene_emb)``. It is
            switched to eval mode as a side effect.

    Returns:
        The model output, detached from autograd (``requires_grad`` is False).
    """
    model.eval()
    # Pure inference: skip building the autograd graph, which would otherwise
    # keep every intermediate activation alive for no benefit.
    with torch.no_grad():
        return model(unpert_input, KO_gene_emb)

In [7]:
# Process-wide cache so the CSV headers and the pickled SVM are read from disk
# once instead of on every call (the notebook calls this once per gene).
_SVM_RESOURCES = {}


def _load_svm_resources():
    """Return the cached gene columns, SVM feature columns and SVM model."""
    if not _SVM_RESOURCES:
        # Header-only reads: just the gene names, no expression values.
        columns = pd.read_csv('./data/unperturbed_gene_expression.csv', nrows = 0).columns
        svm_features = pd.read_csv('./data/unperturbed_filtered.csv', nrows = 0).columns
        # NOTE(review): pickle.load executes arbitrary code from the file --
        # only load model files from a trusted source.
        with open('./saved_models/svc_model_unperturbed.sav', 'rb') as model_file:
            model = pickle.load(model_file)
        # Fill the cache only after everything loaded, so a failure above
        # cannot leave it half-populated.
        _SVM_RESOURCES.update(columns=columns, svm_features=svm_features, model=model)
    return _SVM_RESOURCES


def classify_state_SVM(generated_data):
    """Classify generated transcripts with the saved SVM.

    Args:
        generated_data: 2-D tensor with one row per cell and one column per
            gene, in the column order of
            ``./data/unperturbed_gene_expression.csv``.

    Returns:
        Whatever the loaded model's ``predict`` returns for the filtered input
        (one prediction per cell for the saved SVM -- TODO confirm type).
    """
    resources = _load_svm_resources()

    # Label the generated columns with gene names so they can be filtered by name.
    # .cpu() keeps this working when the generator ran on an accelerator.
    cell = pd.DataFrame(generated_data.detach().cpu().numpy(), columns = resources['columns'])
    # Keep only the genes the SVM expects, in the order it expects them.
    SVM_input = cell[resources['svm_features']]

    return resources['model'].predict(SVM_input)

In [8]:
# Expression table of unperturbed cells: rows are sampled as cells downstream,
# columns are gene names.
unperturbed_expr_df = pd.read_csv('./data/unperturbed_gene_expression.csv')
# Header-only read of the same file: the names of all genes, in column order.
gene_list = pd.read_csv('./data/unperturbed_gene_expression.csv', nrows = 0).columns.values

# Gene embeddings, addressed by row position through gene_emb_label.
gene_emb = pd.read_csv('./data/gene_embeddings.csv')
# Dict stored inside a 0-d object array; keys are looked up with upper-cased gene names.
# NOTE(review): allow_pickle=True unpickles the file contents -- only load
# trusted .npy files.
gene_emb_label = np.load('./gene_embedding/entity2labeldict_TransE_nepochs_128.npy', allow_pickle = True).item()

In [9]:
# Number of unperturbed cells sampled (and transcripts generated) per knocked-out gene.
n_cells = 100

In [93]:
# Seed the sampler so the randomly drawn unperturbed cells -- and therefore every
# prediction derived from them -- are reproducible under Restart & Run All.
SAMPLING_SEED = 0
random.seed(SAMPLING_SEED)

KO_gene_list = []            # every gene visited, in gene_list order
KO_gene_list_found = []      # genes that have an embedding (aligned with pred_list)
KO_gene_list_not_found = []  # genes without an embedding; skipped here
pred_list = []               # one classifier output per gene in KO_gene_list_found

# NOTE(review): torch.load unpickles the whole model object, so the Generator
# class must be defined and the checkpoint must come from a trusted source.
model = torch.load("./saved_models/Model_G_300.pth")

# Inference only: no autograd graph is needed for any of the forward passes.
with torch.no_grad():
    for gene in gene_list:
        KO_gene_list.append(gene)

        # Genes without an embedding cannot be fed to the generator.
        if gene.upper() not in gene_emb_label:
            KO_gene_list_not_found.append(gene)
            continue

        # Sample unperturbed cells and fetch the KO gene's embedding.
        KO_gene, unpert_input, KO_gene_emb = dataloader_mini(
            KO_gene = gene,
            gene_emb = gene_emb,
            gene_emb_label = gene_emb_label,
            unperturbed_expr_df = unperturbed_expr_df,
            n_cells = n_cells,
        )

        # Repeat the embedding once per cell so the batch dimension comes first.
        KO_gene_tensor = torch.tensor(np.tile(KO_gene_emb, (n_cells, 1)))
        generated_data = generate_data(torch.tensor(unpert_input), KO_gene_tensor, model)

        # Classify the generated transcripts with the SVM.
        pred = classify_state_SVM(generated_data)

        KO_gene_list_found.append(gene)
        pred_list.append(pred)
    
    

In [94]:
# Genes that had an embedding and were run through the generator and classifier.
# TODO: how to combine genes not found and genes found in embeddings.
KO_gene_list_found

['Mrpl15',
 'Lypla1',
 'Tcea1',
 'Atp6v1h',
 'Rb1cc1',
 'Pcmtd1',
 'Rrs1',
 'Adhfe1',
 'Mybl1',
 'Vcpip1',
 'Sgk3',
 'Mcmdc2',
 'Cops5',
 'Cspp1',
 'Arfgef1',
 'Slco5a1',
 'Ncoa2',
 'Tram1',
 'Lactb2',
 'Eya1',
 'Msc',
 'Trpa1',
 'Terf1',
 'Rpl7',
 'Rdh10',
 'Stau2',
 'Ube2w',
 'Eloc',
 'Tmem70',
 'Ly96',
 'Mcm3',
 'Paqr8',
 'Tram2',
 'Tmem14a',
 'Kcnq5',
 'Ogfrl1',
 'B3gat2',
 'Smap1',
 'Sdhaf4',
 'Fam135a',
 'Lmbrd1',
 'Phf3',
 'Ptp4a1',
 'Prim2',
 'Rab23',
 'Bag2',
 'Bend6',
 'Dst',
 'Ccdc115',
 'Imp4',
 'Ptpn18',
 'Fam168b',
 'Plekhb2',
 'Hs6st1',
 'Uggt1',
 'Neurl3',
 'Arid5a',
 'Kansl3',
 'Lman2l',
 'Cnnm4',
 'Cnnm3',
 'Ankrd23',
 'Ankrd39',
 'Sema4c',
 'Fam178b',
 'Actr1b',
 'Zap70',
 'Tmem131',
 'Vwa3b',
 'Inpp4a',
 'Coa5',
 'Unc50',
 'Mgat4a',
 'Tsga10',
 'Lipt1',
 'Mitd1',
 'Mrpl30',
 'Txndc9',
 'Eif5b',
 'Rev1',
 'Aff3',
 'Chst10',
 'Pdcl3',
 'Tbc1d8',
 'Cnot11',
 'Rnf149',
 'Creg2',
 'Map4k4',
 'Il1r2',
 'Il1rl2',
 'Il1rl1',
 'Il18r1',
 'Il18rap',
 'Mfsd9',
 'Mrps9',
 'Gpr4

In [95]:
# Genes whose upper-cased name is missing from gene_emb_label; the loop skipped them.
KO_gene_list_not_found

['4732440D04Rik',
 'Gm26901',
 '1700034P13Rik',
 'Snhg6',
 'Gm9947',
 '6720483E21Rik',
 'Ptp4a1.1',
 'Zfp451',
 'Gm37233',
 'Gm28306',
 'Gm28417',
 '4930568A12Rik',
 'Arhgef4',
 'Gm38336',
 'Gm33280',
 '4930403P22Rik',
 'D430040D24Rik',
 'Cox5b',
 '4933424G06Rik',
 'Gm33533',
 'Gm38115',
 '2010300C02Rik',
 '4930556I23Rik',
 'Rpl31',
 'Gm15832',
 'Gm16894',
 '8430432A02Rik',
 'AI597479',
 'Gm8251',
 'Tex30',
 'Poglut2',
 'Bivm',
 'Ercc5',
 'Gm17767',
 'Gm31812',
 'Gm29670',
 '1700019D03Rik',
 '4930444A19Rik',
 'Hspe1',
 '4930558J18Rik',
 '1700066M21Rik',
 'Sgo2a',
 'Gm15834',
 'Fam126b',
 'G730003C15Rik',
 'Gm11579',
 'Ino80dos',
 'Gm20342',
 '2810408I11Rik',
 'Rpe',
 'Gm29113',
 'Gm29114',
 'Gm29112',
 'Gm28112',
 'Gm29358',
 'Zfp142',
 'A630095N17Rik',
 'Tmem198',
 'Utp14b',
 'C430014B12Rik',
 'Gm28942',
 'A030005L19Rik',
 'C130026I21Rik',
 'A530032D15Rik',
 'Gm10553',
 'A530040E14Rik',
 'Sp110',
 'Sp140',
 'Gm10552',
 'Gm17017',
 'A630001G21Rik',
 '2810459M11Rik',
 'Gm16341',
 'C1300

In [96]:
def value_or_zero(data, state):
    """Return ``data[state]``, or 0.0 when ``state in data`` is false.

    Works with any container that supports ``in`` and item access, e.g. a dict
    or a pandas Series (membership is tested against the Series index).
    """
    return data[state] if state in data else 0.0

def tcell_state_proportions(data):
    """Turn per-gene state predictions into per-gene state proportions.

    Args:
        data: DataFrame with one row per gene; each row holds the predicted
            state labels for that gene.

    Returns:
        DataFrame indexed like ``data`` with one column per state in
        ``['progenitor', 'effector', 'terminal exhausted', 'cycling', 'other']``,
        holding the fraction of the row's counted values equal to that state
        (0.0 when the state never occurs).
    """
    cell_states = ['progenitor', 'effector', 'terminal exhausted', 'cycling', 'other']

    rows = []
    for gene in data.index:
        counts = data.loc[gene].value_counts()
        fractions = counts / counts.sum()
        # States absent from this row are reported as 0.0.
        rows.append({state: value_or_zero(fractions, state) for state in cell_states})

    return pd.DataFrame(rows, index = data.index)


In [97]:
# Genes without an embedding were never run through the generator, so all of
# their n_cells predictions are recorded as 'other'.
# A new list is built instead of appending to pred_list: mutating pred_list in
# place made this cell fail on a second run (more rows than index labels).
fallback_preds = [['other'] * n_cells for _ in KO_gene_list_not_found]
all_preds = list(pred_list) + fallback_preds

# Rows: genes with an embedding first, then the genes without one.
output = pd.DataFrame(all_preds, index = (KO_gene_list_found + KO_gene_list_not_found))

# Per-gene fraction of cells predicted in each T-cell state.
proportions = tcell_state_proportions(output)
proportions.to_csv("part_c_output.csv", header=False)

In [98]:
def enough_cycling(data, threshold=0.05):
    """Flag the rows whose 'cycling' value reaches ``threshold``.

    Args:
        data: DataFrame with one row per gene; the 'cycling' entry of each row
            is read (treated as 0.0 when the row has no such entry).
        threshold: Minimum 'cycling' value for a row to be flagged.
            Defaults to 0.05.

    Returns:
        A list with one int per row of ``data``, in index order: 1 when the
        row's 'cycling' value is >= ``threshold``, otherwise 0.
    """
    return [
        1 if value_or_zero(data.loc[gene], 'cycling') >= threshold else 0
        for gene in data.index
    ]

def l1_loss_part_a(row):
    """L1 distance between ``row`` and the all-progenitor target.

    The target is 1 for 'progenitor' and 0 for 'cycling',
    'terminal exhausted', 'effector' and 'other'. States missing from ``row``
    count as 0.0.
    """
    # (state, target value), kept in the original summation order.
    targets = (
        ('cycling', 0),
        ('terminal exhausted', 0),
        ('effector', 0),
        ('other', 0),
        ('progenitor', 1),
    )
    return sum(abs(value_or_zero(row, state) - goal) for state, goal in targets)

def loss_part_b(row):
    """Weighted objective over the state proportions in ``row``.

    Computes progenitor/0.0675 + effector/0.2097
    - 'terminal exhausted'/0.3134 + cycling/0.3921, with states missing from
    ``row`` counted as 0.0.
    """
    # (state, divisor, sign), kept in the original summation order.
    # NOTE(review): the divisors are presumably reference proportions of each
    # state -- confirm their origin.
    terms = (
        ('progenitor', 0.0675, 1),
        ('effector', 0.2097, 1),
        ('terminal exhausted', 0.3134, -1),
        ('cycling', 0.3921, 1),
    )
    return sum(sign * (value_or_zero(row, state) / scale) for state, scale, sign in terms)

In [99]:
# Add a 0/1 column: 1 where the gene's 'cycling' proportion is at least 0.05.
proportions['enough_cycling'] = enough_cycling(proportions)

In [100]:
# Rank genes by their L1 distance to the all-progenitor target, smallest first.
# The loss is computed per gene and sorted as a column: sort_index(key=...)
# calls its key once with the whole Index (not once per label), so the
# previous per-gene lambda only worked by accident.
part_a_loss = [l1_loss_part_a(proportions.loc[gene]) for gene in proportions.index]
(
    proportions
    .assign(part_a_loss=part_a_loss)   # temporary column; proportions itself is unchanged
    .sort_values('part_a_loss')
    [['progenitor', 'enough_cycling']]
    .to_csv("part_a_output.csv", header=False)
)

In [101]:
# Per-gene value of the part-B objective.
proportions['objective_b'] = [loss_part_b(proportions.loc[gene]) for gene in proportions.index]

# Rank genes by that objective, largest first. Sorting on the stored column
# replaces sort_index(key=...), whose key is called once with the whole Index
# (not once per label) and so recomputed the objective in an accidental,
# vectorised way.
(
    proportions
    .sort_values('objective_b', ascending=False)
    [['objective_b', 'enough_cycling']]
    .to_csv("part_b_output.csv", header=False)
)