In [3]:
from torch.utils.data import Dataset
from mydataloader import dataset

In [48]:
import torch
import torch.nn as nn
import numpy as np
import datetime
import pytz
import pickle

from torch.utils.data import DataLoader, TensorDataset

import numpy as np
import pandas as pd

# Make float64 the default floating-point dtype for tensors created after this call.
torch.set_default_dtype(torch.float64)

In [31]:
class Generator(nn.Module):
    """Fully connected network applied to the concatenation of ``x`` and ``cond``.

    The hidden widths narrow from 10240 down to 2560 and widen back to 10240.
    Every hidden Linear layer except the first is followed by BatchNorm1d, and
    every Linear layer (including the output layer) is followed by LeakyReLU.

    Args:
        input_dim: Width of the concatenated ``(x, cond)`` input.
        output_dim: Width of the produced output.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()

        hidden_widths = (10240, 5120, 2560, 2560, 5120, 10240)

        # Entry layer: Linear + activation, no batch norm.
        layers = [nn.Linear(input_dim, hidden_widths[0]), nn.LeakyReLU()]

        # Hidden-to-hidden transitions: Linear -> BatchNorm -> activation.
        # NOTE(review): in PyTorch, `momentum` is the weight given to the *new*
        # batch statistic, so 0.9 makes the running stats follow the latest
        # batch closely — confirm this is the intended setting.
        for width_in, width_out in zip(hidden_widths[:-1], hidden_widths[1:]):
            layers.append(nn.Linear(width_in, width_out))
            layers.append(nn.BatchNorm1d(width_out, momentum=.9))
            layers.append(nn.LeakyReLU())

        # Exit layer: Linear + activation.
        layers.append(nn.Linear(hidden_widths[-1], output_dim))
        layers.append(nn.LeakyReLU())

        # Layer order (and therefore the Sequential indices) must stay fixed so
        # that previously saved checkpoints keep matching this definition.
        self.generator_stack = nn.Sequential(*layers)

    def forward(self, x, cond):
        """Concatenate ``x`` and ``cond`` along the last axis and run the stack."""
        joined = torch.cat([x, cond], dim=-1)
        return self.generator_stack(joined)

In [72]:
# Sizes used to build the generator input and to batch the data.
GENE_EXPRESSION_VEC = 15077  # length of one gene-expression vector
GENE_EMBED_DIM = 150         # length of one KO-gene embedding
BATCH_SIZE = 10

# The generator emits one expression vector and consumes an expression vector
# concatenated with a gene embedding.
output_dim = GENE_EXPRESSION_VEC
input_dim = output_dim + GENE_EMBED_DIM

In [75]:
def get_data(train_loader):
    """Return the last batch yielded by ``train_loader``.

    The loader is consumed from start to finish, so the cost grows with the
    number of batches it yields.

    Args:
        train_loader: Iterable whose items unpack into exactly three parts,
            ``(KO_gene, batch_X, batch_y)``.

    Returns:
        The ``(KO_gene, batch_X, batch_y)`` triple of the final batch.

    Raises:
        ValueError: If ``train_loader`` yields no batches at all.
    """
    _missing = object()
    last_batch = _missing

    # Only the final batch is returned, so just remember the latest one seen.
    for last_batch in train_loader:
        pass

    if last_batch is _missing:
        # Previously this surfaced as an opaque UnboundLocalError on return.
        raise ValueError("train_loader yielded no batches")

    KO_gene, batch_X, batch_y = last_batch
    return KO_gene, batch_X, batch_y
            

In [88]:
# Wrap the dataset in a shuffling loader and keep the batch handed back by get_data.
train_loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
KO_gene, batch_X, batch_y = get_data(train_loader)

In [77]:
# One shape per line: KO-gene embedding, then the two expression tensors.
print(KO_gene.shape, batch_X.shape, batch_y.shape, sep="\n")

torch.Size([10, 150])
torch.Size([10, 15077])
torch.Size([10, 15077])


In [89]:

# Load the saved model
# The file holds the whole pickled module (model.eval() is called on the result
# directly), so the Generator class must already be defined when this cell runs.
# NOTE(review): newer PyTorch releases default torch.load to weights_only=True,
# which rejects full-module pickles; if loading fails there, pass
# weights_only=False (only for files you trust).

#TODO: change this path to the well trained model, this is just a test with 50 epochs
# map_location="cpu" places the weights on the CPU whatever device they were
# saved from; the batches fed to the model later are never moved off the CPU.
model = torch.load("./saved_models/test/Model_G_50.pth", map_location="cpu")


# Switch to evaluation mode (affects the BatchNorm layers); as the last
# expression of the cell this also displays the architecture.
model.eval()




Generator(
  (generator_stack): Sequential(
    (0): Linear(in_features=15227, out_features=10240, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=10240, out_features=5120, bias=True)
    (3): BatchNorm1d(5120, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (4): LeakyReLU(negative_slope=0.01)
    (5): Linear(in_features=5120, out_features=2560, bias=True)
    (6): BatchNorm1d(2560, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (7): LeakyReLU(negative_slope=0.01)
    (8): Linear(in_features=2560, out_features=2560, bias=True)
    (9): BatchNorm1d(2560, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (10): LeakyReLU(negative_slope=0.01)
    (11): Linear(in_features=2560, out_features=5120, bias=True)
    (12): BatchNorm1d(5120, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
    (13): LeakyReLU(negative_slope=0.01)
    (14): Linear(in_features=5120, out_features=10240, bias=True)

In [90]:
# Generate some data
# inputs: KO_gene = KO gene embedding, batch_X = unperturbed cell expression
z = batch_X
cond = KO_gene

# Inference only: no_grad() stops autograd from recording the forward pass,
# so no graph is kept alive for this large network.
with torch.no_grad():
    generated_data = model(z, cond)

In [91]:
# Dimensions of the generator output.
generated_data.size()

torch.Size([10, 15077])

### Prepare generated data for input into SVM
Note that the SVM was trained on filtered data (low-variance genes were removed). The input to the SVM must therefore contain exactly the same features it was trained on.

In [92]:
# Gene names for the 15077 generated values; nrows=0 parses only the header row.
columns = pd.read_csv('./data/unperturbed_gene_expression.csv', nrows=0).columns

# Gene names the SVM expects. Only the header is needed, so the table body is
# not loaded.
SVM_filt = pd.read_csv('./data/unperturbed_filtered.csv', nrows=0).columns

# Fail early with a readable message if the SVM needs a gene the generator
# output does not provide.
missing_genes = [gene for gene in SVM_filt if gene not in columns]
assert not missing_genes, (
    f"{len(missing_genes)} SVM feature(s) missing from the generated data, "
    f"e.g. {missing_genes[:5]}"
)

# Turn the generated transcripts into a dataframe for filtering in the next step.
# .cpu() is a no-op for CPU tensors and keeps .numpy() working if the output
# lives on another device.
cell = pd.DataFrame(generated_data.detach().cpu().numpy(), columns=columns)

# Keep just the genes that the SVM uses to predict T cell state.
SVM_input = cell[SVM_filt]

In [93]:
# NOTE: unpickling can execute arbitrary code — only load model files from a
# trusted source.
# NOTE(review): this path points into the parent directory ('../saved_models')
# while the generator checkpoint is read from './saved_models' — confirm both
# locations are intended.
with open('../saved_models/svc_model_unperturbed.sav', 'rb') as model_file:
    # The context manager closes the file handle once the model is loaded.
    loaded_model = pickle.load(model_file)

# One predicted label per row of SVM_input.
preds = loaded_model.predict(SVM_input)

In [94]:
# Display the SVM's predicted label for each row of SVM_input.
preds

array(['effector', 'effector', 'effector', 'effector', 'effector',
       'effector', 'effector', 'effector', 'effector', 'effector'],
      dtype=object)