In [1]:
from torch.utils.data import Dataset
from mydataloader import dataset


import torch
import torch.nn as nn
import numpy as np
import datetime
import pytz

import numpy as np
import pandas as pd

In [2]:
from tqdm.auto import tqdm

In [3]:
from torch.utils.data import DataLoader, TensorDataset

In [4]:
torch.set_default_dtype(torch.float64)

In [5]:
GENE_EMBED_DIM = 150
GENE_EXPRESSION_VEC = 15077
BATCH_SIZE = 8

In [6]:
class Generator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Generator, self).__init__()
        self.generator_stack = nn.Sequential(
            nn.Linear(input_dim, 10240),
            nn.LeakyReLU(),
            nn.Linear(10240, 5120),
            nn.BatchNorm1d(5120, momentum=.9),
            nn.LeakyReLU(),
            nn.Linear(5120, 2560),
            nn.BatchNorm1d(2560, momentum=.9),
            nn.LeakyReLU(),
            nn.Linear(2560, 2560),
            nn.BatchNorm1d(2560, momentum=.9),
            nn.LeakyReLU(),
            nn.Linear(2560, 5120),
            nn.BatchNorm1d(5120, momentum=.9),
            nn.LeakyReLU(),
            nn.Linear(5120, 10240),
            nn.BatchNorm1d(10240, momentum=.9),
            nn.LeakyReLU(),
            nn.Linear(10240, output_dim),
            nn.LeakyReLU(),
        )

    def forward(self, x, cond):
        a = torch.cat((x, cond), axis=-1)
        #print(a.size()[0])
        return self.generator_stack(a)

In [7]:
class Discriminator(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(Discriminator, self).__init__()
        self.discriminator_stack = nn.Sequential(
            nn.Linear(input_dim, 10240),
            nn.LeakyReLU(),
            nn.Linear(10240, 5120),
            nn.BatchNorm1d(5120, momentum=.9),
            nn.LeakyReLU(),
            nn.Linear(5120, 2560),
            nn.BatchNorm1d(2560, momentum=.9),
            nn.LeakyReLU(),
            nn.Linear(2560, 1280),
            nn.BatchNorm1d(1280, momentum=.9),
            nn.LeakyReLU(),
            nn.Linear(1280, 512),
            nn.BatchNorm1d(512, momentum=.9),
            nn.LeakyReLU(),
            nn.Linear(512, 256),
            nn.BatchNorm1d(256, momentum=.9),
            nn.LeakyReLU(),
            nn.Linear(256, 64),
            nn.BatchNorm1d(64, momentum=.9),
            nn.LeakyReLU(),
            nn.Linear(64, output_dim),
            nn.Sigmoid()
        )

    def forward(self, x, cond):
        a = torch.cat((x, cond), axis=-1)
        return self.discriminator_stack(a)

In [8]:
def train(learning_rate, num_epochs, train_loader):

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    device = "cpu"
    
    # Create Models
    G = Generator(GENE_EXPRESSION_VEC+GENE_EMBED_DIM, GENE_EXPRESSION_VEC).to(device)
    D = Discriminator(GENE_EXPRESSION_VEC+GENE_EMBED_DIM, 1).to(device)

    # Define Losses
    lossG = nn.BCELoss()
    lossD = nn.BCELoss()
    
    alpha = 0.5

    # Create Optimizer
    param_list = list(G.parameters()) + list(D.parameters())

    optimizer_all = torch.optim.Adam(param_list, lr=learning_rate)

    # Train Models
    G.train()
    D.train()

    # training loop
    for ep in range(num_epochs):

        # Reset Loss List
        epoch_loss = {'loss_G': [], 'loss_D': []}

        for index, batch_inputs in enumerate(tqdm(train_loader)):

            # unpack train loader
            KO_gene, batch_X, batch_y = batch_inputs
            
            #print(KO_gene.shape, batch_X.shape, batch_y.shape)

            # zero all gradients
            optimizer_all.zero_grad()

            # generator input: embedding of knockout gene, unperturbed gene expression
            fake = G(batch_X.to(device), KO_gene.to(device))

            # get discriminator decision: real or fake
            decision_fake = D(fake, KO_gene.to(device))        
            decision_real = D(batch_y.to(device), KO_gene.to(device))
            
            # Get Losses
            loss_G = lossG(decision_fake, torch.ones_like(decision_fake).to(device))
            #loss_G += .5 * lossG(decision_fake, torch.zeros_like(decision_fake).to(device))
            loss_D = .5 * lossD(decision_real, torch.ones_like(decision_real).to(device))
            loss_D += .5 * lossD(decision_fake, torch.zeros_like(decision_fake).to(device))
            #print(loss_D)

            # Backward the Loss
            loss_total = alpha*loss_G + (1.0 - alpha)*loss_D
            loss_total.backward()

            # Step
            optimizer_all.step()

            # Add Losses to Epoch Dictionary
            epoch_loss["loss_G"].append(loss_G)
            epoch_loss["loss_D"].append(loss_D)
            #print(epoch_loss)
            #print(len(epoch_loss["loss_G"]))

        #Average Loss Every Epoch
        avg_ep_error_G = sum(epoch_loss["loss_G"]) / len(epoch_loss["loss_G"]) #why len
        avg_ep_error_D = sum(epoch_loss["loss_D"]) / len(epoch_loss["loss_G"])

        # Print losses every epoch
        print("######################################################")
        print("Generator_Loss: {}\t at epoch: {}".format(avg_ep_error_G, ep))
        print("Discriminator_Loss: {}\t at epoch: {}".format(avg_ep_error_D, ep))

    return G, D


In [None]:
if __name__ == '__main__':
    
    print("Starting...")
    train_loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
    G, D  = train(learning_rate = 1e-3, num_epochs = 200, train_loader = train_loader)

    #Save Model
    FILE_MODEL_G = "./saved_models/Model_G.pth"
    FILE_MODEL_D = "./saved_models/Model_D.pth"
    torch.save(G, FILE_MODEL_G)
    torch.save(D, FILE_MODEL_D)

Starting...


  0%|          | 0/8 [00:00<?, ?it/s]

######################################################
Generator_Loss: 0.5930319315812046	 at epoch: 0
Discriminator_Loss: 0.7113545219242756	 at epoch: 0


  0%|          | 0/8 [00:00<?, ?it/s]

######################################################
Generator_Loss: 0.4566224700262467	 at epoch: 1
Discriminator_Loss: 0.7328893322420704	 at epoch: 1


  0%|          | 0/8 [00:00<?, ?it/s]

######################################################
Generator_Loss: 0.370532462790347	 at epoch: 2
Discriminator_Loss: 0.7746943081221627	 at epoch: 2


  0%|          | 0/8 [00:00<?, ?it/s]

######################################################
Generator_Loss: 0.322116846677342	 at epoch: 3
Discriminator_Loss: 0.8072764525854539	 at epoch: 3


  0%|          | 0/8 [00:00<?, ?it/s]

######################################################
Generator_Loss: 0.3130870410821567	 at epoch: 4
Discriminator_Loss: 0.803014229281941	 at epoch: 4


  0%|          | 0/8 [00:00<?, ?it/s]

######################################################
Generator_Loss: 0.31910977881865304	 at epoch: 5
Discriminator_Loss: 0.7884582945976243	 at epoch: 5


  0%|          | 0/8 [00:00<?, ?it/s]

######################################################
Generator_Loss: 0.3167494341613124	 at epoch: 6
Discriminator_Loss: 0.7851859766439001	 at epoch: 6


  0%|          | 0/8 [00:00<?, ?it/s]

######################################################
Generator_Loss: 0.3109621450715182	 at epoch: 7
Discriminator_Loss: 0.7938459731778241	 at epoch: 7


  0%|          | 0/8 [00:00<?, ?it/s]

######################################################
Generator_Loss: 0.29995119417725136	 at epoch: 8
Discriminator_Loss: 0.8060298977568687	 at epoch: 8


  0%|          | 0/8 [00:00<?, ?it/s]

######################################################
Generator_Loss: 0.32485865508771633	 at epoch: 9
Discriminator_Loss: 0.7633189423913557	 at epoch: 9


  0%|          | 0/8 [00:00<?, ?it/s]

######################################################
Generator_Loss: 0.33156089668222505	 at epoch: 10
Discriminator_Loss: 0.7481162006712389	 at epoch: 10


  0%|          | 0/8 [00:00<?, ?it/s]

######################################################
Generator_Loss: 0.3206194749424868	 at epoch: 11
Discriminator_Loss: 0.7607810804383014	 at epoch: 11


  0%|          | 0/8 [00:00<?, ?it/s]

######################################################
Generator_Loss: 0.319517316691912	 at epoch: 12
Discriminator_Loss: 0.7570259831096874	 at epoch: 12


  0%|          | 0/8 [00:00<?, ?it/s]

######################################################
Generator_Loss: 0.31700400779058174	 at epoch: 13
Discriminator_Loss: 0.7681980027913025	 at epoch: 13


  0%|          | 0/8 [00:00<?, ?it/s]

######################################################
Generator_Loss: 0.31792276705185474	 at epoch: 14
Discriminator_Loss: 0.7590479402034609	 at epoch: 14


  0%|          | 0/8 [00:00<?, ?it/s]

In [None]:
!nvidia-smi

In [22]:
15109/1024

14.7548828125

In [None]:
import random
random.randint(0, len(dataset.perturb_seq) - 1)

In [None]:
dataset.perturb_seq - 1

In [21]:
pd.DataFrame(np.load("./gene_embedding/entities_kg_emb_TransE_nepochs_128.npy"))

Unnamed: 0,0,1,2,3,4,5,6,7,8,9,...,140,141,142,143,144,145,146,147,148,149
0,0.084076,0.066017,0.160443,0.045989,0.028019,0.028495,0.109682,0.028926,-0.119988,0.043822,...,0.055338,-0.025267,0.045266,0.107301,-0.083249,-0.053435,0.011428,-0.010621,-0.017017,0.035835
1,-0.097811,0.017677,0.024524,-0.034668,-0.077248,-0.040067,0.077137,-0.008848,-0.055085,0.091817,...,0.134737,0.031279,0.043491,0.094547,0.055724,0.009046,-0.022644,0.055887,0.121261,-0.056232
2,0.000243,0.078149,0.098819,0.069123,0.016560,-0.042909,0.053804,0.057282,-0.166803,-0.040503,...,0.121094,-0.043642,0.075727,0.015954,-0.046719,-0.029619,-0.026320,0.061115,0.157111,-0.075119
3,0.089756,-0.035879,-0.114247,0.033007,-0.089655,0.026436,0.019157,0.057533,-0.145426,-0.102638,...,0.073326,-0.040363,-0.015797,0.125135,-0.042652,0.115123,-0.082012,0.003727,-0.043642,0.080191
4,-0.068236,-0.014179,-0.116757,-0.049554,0.008640,0.076427,0.112634,-0.058167,-0.034452,-0.089467,...,0.086226,0.025905,-0.010350,0.102784,-0.099384,0.077953,0.021640,-0.012799,0.150379,0.069952
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
15705,-0.014159,-0.033720,0.009675,-0.064084,0.121911,-0.018223,-0.102532,-0.102793,0.138172,-0.058916,...,0.126625,-0.090644,0.055296,0.071481,-0.055293,0.029780,0.004009,-0.002020,-0.047531,0.036867
15706,0.062602,-0.054406,0.044205,0.061566,0.124266,-0.021329,-0.055662,-0.036168,0.113291,-0.095599,...,-0.099225,0.077612,0.015890,0.095092,0.010492,-0.018273,-0.075841,0.062877,-0.141602,0.126040
15707,-0.067877,0.197717,0.118632,-0.044108,-0.013304,0.000297,-0.103470,-0.027654,0.076239,0.047753,...,-0.089681,0.068424,0.062545,0.022933,-0.019549,-0.005379,0.068350,0.036818,0.000432,0.065463
15708,0.042693,-0.086445,0.054753,-0.056397,-0.002129,-0.127126,-0.053702,-0.130947,-0.014795,-0.019460,...,-0.148341,-0.055738,-0.000493,0.075588,0.030998,0.025893,-0.027715,-0.035677,-0.006670,-0.093611


In [17]:
pd.read_csv('./data/knockout_list_conditions.csv')[1::].values.flatten()

array(['Tox2', 'Tpt1', 'Tcf7', 'Il12rb1', 'Ikzf3', 'Nr4a3', 'Litaf',
       'Elf1', 'Irf2', 'Arid5b', 'Zeb2', 'Satb1', 'Dvl2', 'Nr4a1',
       'Hif1a', 'Crem', 'Runx2', 'Ctnnb1', 'Tcf3', 'Foxo1', 'Dvl1',
       'Gsk3b', 'Dkk3', 'Hmgb1', 'Dvl3', 'Sox4', 'Fzd1', 'Stat4', 'Nr4a2',
       'Sp100', 'Rela', 'Ldhb', 'Eomes', 'Zfp292', 'Prdm1', 'Atf2',
       'Il12rb2', 'Egr1', 'Id2', 'Lef1', 'Arid4b', 'Fzd6', 'Foxp1', 'Id3',
       'Fzd3', 'Foxm1', 'Nr3c1', 'Irf9', 'Tox', 'Hmgb2', 'Oxnad1',
       'Sp140', 'Sub1', 'Yy1', 'Lrp1', 'Ep300', 'P2rx7', 'Runx3', 'Rad21',
       'Klf2', 'Ezh2', 'Myb', 'Eef2', 'Batf', 'Tbx21', 'Rps6', 'Aqr',
       'Bach2', 'Bhlhe40', 'Ets1', 'Fosb', 'Mafk', 'Stat3'], dtype=object)

In [8]:
import pandas as pd
pd.read_csv('./data/perturbed_gene_expression.csv').head()

KeyboardInterrupt: 