In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torchvision import transforms
from torchvision.utils import save_image
import torchvision.datasets as datasets  
from torch.utils.data import DataLoader
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
from tqdm import tqdm
import pandas as pd

In [5]:
import re
expr = pd.read_csv('/home/ubuntu/spatial.variableGeneExpr.csv',index_col=0)
prop = pd.read_csv('/home/ubuntu/STdeconvolve.K15.pixelProp.csv',index_col=0)
expr.index = list(map(lambda x: re.sub(r'-[1234]-','-',x),expr.index))
prop = prop.loc[map(lambda x: x in set(expr.index),prop.index),:]
print(expr)
print(prop)

                              Kap      Mmp7     Cuzd1     Fabp4      Cd24  \
AAACCGGGTAGGTACC-JY97A   0.000000  0.000000  0.000000  0.000000  0.000000   
AAATCGTGTACCACAA-JY97A   0.000000  0.000000  0.000000  0.000000  0.000000   
AAATGGTCAATGTGCC-JY97A   0.000000  0.000000  0.000000  0.000000  0.000000   
AAATTAACGGGTAGCT-JY97A   0.000000  0.000000  0.000000  0.000000  0.000000   
AACTCAAGTTAATTGC-JY97A   0.000000  0.000000  0.000000  0.000000  0.000000   
...                           ...       ...       ...       ...       ...   
TTGTGGTAGGAGGGAT-JY102A  0.608742  0.349994  0.000000  0.349994  0.349994   
TTGTTAGCAAATTCGA-JY102A  0.795374  0.474764  0.000000  0.565055  1.037767   
TTGTTCAGTGTGCTAC-JY102A  1.396384  0.593778  0.475061  0.184572  0.963797   
TTGTTTCCATACAACT-JY102A  1.255837  0.487164  0.000000  0.608093  0.608093   
TTGTTTGTGTAAATTC-JY102A  0.670376  1.068134  0.670376  0.670376  0.390342   

                               C3  RGD1304870       Cfd    Tspan1      Krt8

In [7]:
label = prop.apply(lambda x: list(x).index(max(x))+1,axis=1)
label

AAACCGGGTAGGTACC-JY97A      7
AAATCGTGTACCACAA-JY97A      7
AAATGGTCAATGTGCC-JY97A     11
AAATTAACGGGTAGCT-JY97A     15
AACTCAAGTTAATTGC-JY97A      4
                           ..
TTGTGGTAGGAGGGAT-JY102A    11
TTGTTAGCAAATTCGA-JY102A    10
TTGTTCAGTGTGCTAC-JY102A     7
TTGTTTCCATACAACT-JY102A     7
TTGTTTGTGTAAATTC-JY102A     8
Length: 2256, dtype: int64

In [28]:
from torch.utils.data import Dataset
import numpy as np
class spatialDataset(Dataset):
    def __init__(self,exprMat,label):
        self.exprMat = exprMat
        self.labels = label
        
    def __len__(self):
        return(len(self.labels))
    
    def __getitem__(self,idx):
        exprVector = np.array(self.exprMat.iloc[idx,:])
        label = self.labels[idx]
        return exprVector,label

In [29]:
dataset = spatialDataset(expr,label)
print(len(dataset))

2256


In [38]:
dataset[1][0].shape

(1804,)

In [39]:
# Configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)
INPUT_DIM = 1804
Z_DIM = 10
H_DIM = 200
NUM_EPOCHS = 10
BATCH_SIZE = 32
LR_RATE = 3e-4

cpu


In [40]:
train_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True)

In [64]:
class VariationalAutoEncoder(nn.Module):
    def __init__(self, input_dim, z_dim, h_dim=200):
        super().__init__()
        # encoder
        print(input_dim)
        print(h_dim)
        self.img_2hid = nn.Linear(input_dim, h_dim)

        # one for mu and one for stds, note how we only output
        # diagonal values of covariance matrix. Here we assume
        # the pixels are conditionally independent 
        self.hid_2mu = nn.Linear(h_dim, z_dim)
        self.hid_2sigma = nn.Linear(h_dim, z_dim)

        # decoder
        self.z_2hid = nn.Linear(z_dim, h_dim)
        self.hid_2img = nn.Linear(h_dim, input_dim)

    def encode(self, x):
        h = F.relu(self.img_2hid(x))
        mu = self.hid_2mu(h)
        sigma = self.hid_2sigma(h)
        return mu, sigma

    def decode(self, z):
        new_h = F.relu(self.z_2hid(z))
        x = torch.sigmoid(self.hid_2img(new_h))
        return x

    def forward(self, x):
        mu, sigma = self.encode(x)
        sigma = torch.exp(sigma)

        # Sample from latent distribution from encoder
        epsilon = torch.randn_like(sigma)
        z_reparametrized = mu + sigma*epsilon

        x = self.decode(z_reparametrized)
        return x, mu, sigma


In [65]:
# Define train function
def train(num_epochs, model, optimizer, loss_fn):
    # Start training
    for epoch in range(num_epochs):
        loop = tqdm(enumerate(train_loader))
        for i, (x, y) in loop:
            # Forward pass
            x = x.to(device).view(-1, INPUT_DIM).float()
            x_reconst, mu, sigma = model(x)

            # loss, formulas from https://www.youtube.com/watch?v=igP03FXZqgo&t=2182s
            reconst_loss = loss_fn(x_reconst, x)
            kl_div = - torch.sum(1 + torch.log(sigma.pow(2)) - mu.pow(2) - sigma.pow(2))

            # Backprop and optimize
            loss = reconst_loss + kl_div
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            loop.set_postfix(loss=loss.item())


In [66]:
# Initialize model, optimizer, loss
model = VariationalAutoEncoder(INPUT_DIM, Z_DIM, H_DIM).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=LR_RATE)
loss_fn = nn.BCELoss(reduction="sum")

1804
200


In [67]:
# Run training
train(NUM_EPOCHS, model, optimizer, loss_fn)

71it [00:01, 43.10it/s, loss=2.85e+3]
71it [00:01, 57.01it/s, loss=-88535.5]
71it [00:01, 44.12it/s, loss=-1.19e+5]
71it [00:01, 51.88it/s, loss=-1.12e+5]
71it [00:01, 54.80it/s, loss=-1.31e+5]
71it [00:01, 45.33it/s, loss=-1.05e+5]
71it [00:01, 48.07it/s, loss=-1.31e+5]
71it [00:01, 53.09it/s, loss=-1.34e+5]
71it [00:01, 57.63it/s, loss=-1.27e+5]
71it [00:01, 55.85it/s, loss=-1.21e+5]
