In [1]:
import os
os.chdir("..")

In [2]:
import torch
from utils.dataset_highvars import get_loader
from dataclasses import dataclass
import anndata as ad
from torch.utils.data import Dataset
import scanpy as sc
import torch.nn as nn
import pytorch_lightning as pl

  __import__('pkg_resources').declare_namespace(__name__)


# Dataset Loading

In [3]:
@dataclass
class Config:
    log_dir = "logs"
    name = "Highly_Var"
    batch_size = 64 # Genes processed at once
    version = 1
    epochs = 30
    lr = 1e-3
    num_workers = 15
    num_samples = 700
    target_gene_dim = 128
cfg = Config()


In [4]:
dataRoot = "data/vcc_data"
tr_adata_path = f"{dataRoot}/adata_Training.h5ad"
adata = ad.read_h5ad(tr_adata_path)
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)
#sc.pp.highly_variable_genes(adata)
#maskidx = (~adata.var.index.str.startswith("MT-")) & adata.var.highly_variable
#adata = adata[:,maskidx]

In [5]:
class GeneExpressionDataset(Dataset):
    def __init__(self, adata,batchMapping,seqLength = 32):
        self.adata_batchX = adata.X
        self.batch = adata.obs.batch.to_numpy()
        self.batchmap = batchMapping
        self.seqLength = seqLength

    def __len__(self):
        return self.adata_batchX.shape[0]

    def __getitem__(self, idx):
        gene_expression = self.adata_batchX[idx].toarray().squeeze()
        gene_expressionseq = gene_expression.reshape(self.seqLength,-1)
        origin =  self.batchmap[self.batch[idx]]
        return gene_expressionseq, origin

In [6]:
batchMapping = {b:i for i,b in enumerate(adata.obs.batch.unique())}


In [7]:
dataset = GeneExpressionDataset(adata[adata.obs.target_gene == "non-targeting"],batchMapping)
trainDataset,valDataset = torch.utils.data.random_split(dataset,[int(0.8*len(dataset)),len(dataset)-int(0.8*len(dataset))])
#implement k-fold later
train_loader = torch.utils.data.DataLoader(trainDataset, batch_size=cfg.batch_size, shuffle=True, num_workers=cfg.num_workers)
val_loader = torch.utils.data.DataLoader(valDataset, batch_size=cfg.batch_size, shuffle=False, num_workers=cfg.num_workers)

In [None]:
genes, batchinfo = next(iter(train_loader))

# Design Architecture

In [None]:
qLayer = nn.Linear(565,64)
kLayer = nn.Linear(565,64)
vLayer = nn.Linear(565,64)
q = qLayer(genes)
k = kLayer(genes)
v = vLayer(genes)
multiAttention = nn.MultiheadAttention(embed_dim=64,num_heads = 8)
latent, grad= multiAttention(q,k,v)


In [16]:
class UpAttention(nn.Module):
    def __init__(self, inDims, outDims):
        super().__init__()
        self.expand_layer = nn.Linear(inDims, outDims)
        self.mha = nn.MultiheadAttention(embed_dim=outDims, num_heads=8, batch_first=True)
        self.norm = nn.LayerNorm(outDims)

    def forward(self, x):
        x = self.expand_layer(x)
        x, _ = self.mha(x, x, x)
        x = self.norm(x)
        return x

class AttentionDecoder(nn.Module):
    def __init__(self, latentDim, seq_len=32):
        super().__init__()
        self.seq_len = seq_len
        
        # 1. Inverse of the Encoder's projection
        # We map z back to (Seq_Len * Smallest_Feature_Dim)
        # Based on your encoder, the smallest dim was 64.
        self.linear_map = nn.Linear(latentDim, 64 * seq_len)
        
        # 2. Inverse of the DownAttention layers
        self.up3 = UpAttention(inDims=64, outDims=128)
        self.up2 = UpAttention(inDims=128, outDims=256)
        self.up1 = UpAttention(inDims=256, outDims=565) # Final dimension matches input
        
    def forward(self, z):
        # z: [Batch, latentDim]
        
        # 1. Project and Unflatten
        x = self.linear_map(z)
        # Reshape to [Batch, Seq_Len, 64]
        x = x.view(-1, self.seq_len, 64) 
        
        # 2. Expand features back up
        x = self.up3(x)
        x = self.up2(x)
        x = self.up1(x)
        
        return x

In [None]:
class MyLoss(nn.Module):
    def __init__(self):
        super().__init__()
        self.mse = nn.MSELoss()
    
    def forward(self,recon_x,x,mu,logvar):
        recon_loss = self.mse(recon_x,x)
        kld_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_loss + kld_loss


In [None]:

class DownAttention(nn.Module):
    def __init__(self,inDims,outDims):
        super().__init__()
        self.qlayer = nn.Linear(inDims,outDims)
        self.klayer = nn.Linear(inDims,outDims)
        self.vlayer = nn.Linear(inDims,outDims)
        self.mha = nn.MultiheadAttention(embed_dim = outDims,num_heads=8,batch_first=True) # Might be scaled
        self.norm = nn.LayerNorm(outDims)
    def forward(self,x):
        q = self.qlayer(x)
        k = self.klayer(x)
        v = self.vlayer(x)
        attn,_ = self.mha(q,k,v)
        attn = self.norm(attn)
        return attn
    
class AttentionEncoder(nn.Module):
    def __init__(self,latentDim):
        super().__init__()
        self.down1 = DownAttention(inDims = 565,outDims=256)
        self.down2 = DownAttention(inDims = 256,outDims=128)
        self.down3 = DownAttention(inDims = 128,outDims=64)
        self.flat = nn.Flatten()
        self.proj = nn.Linear(64*32,latentDim*2)
    
    def forward(self,x):
        x = self.down1(x)
        x = self.down2(x)
        x = self.down3(x)
        x = self.proj(self.flat(x))
        return torch.chunk(x,chunks = 2,dim = 1)
    
class AttentionDecoder(nn.Module):
    def __init__(self,latentDim,numseqs):
        super().__init__()
        self.numseqs = numseqs
        self.unflatten = nn.Linear(latentDim,numseqs*64)
        self.up1 = DownAttention(inDims = 64,outDims = 128)
        self.up2 = DownAttention(inDims = 128,outDims = 256)
        self.up3 = DownAttention(inDims = 256,outDims = 512)
        self.final = nn.Linear(512,565)

    def forward(self,x):
        x = self.unflatten(x)
        x = x.view(-1,self.numseqs,64)
        x = self.up1(x)
        x = self.up2(x)
        x = self.up3(x)
        return self.final(x)

class AttentionVae(pl.LightningModule):
    def __init__(self,latentDim):
        super().__init__()
        self.encoder = AttentionEncoder(latentDim)
        self.decoder = AttentionDecoder(latentDim,32)
        self.criterion = MyLoss() 
    
    def reparametrise(self,mu,logvar,mode:str="train"):
        if mode  == "train":
            std = torch.exp(0.5*logvar)
            eps = torch.randn_like(std)
            return eps.mul(std).add_(mu)
        else:
            return mu
        
    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters())
    
    def shared_step(self,batch,mode):
        genes, _ = batch
        mu,logvar = self.encoder(genes)
        z = self.reparametrise(mu,logvar,mode)
        x = self.decoder(z)
        return genes,x,mu,logvar
    
    def training_step(self, batch,batch_idx):
        genes, x,mu,logvar = self.shared_step(batch,mode="train")
        loss = self.criterion(x, genes, mu, logvar)
        self.log("train/loss", loss,prog_bar=True)
        return loss
    
    def validation_step(self, batch,batch_idx):
        genes, x,mu,logvar = self.shared_step(batch,mode="val")
        loss = self.criterion(x, genes, mu, logvar)
        self.log("val/loss", loss,prog_bar=True)
        return loss

In [18]:
model = AttentionVae(latentDim=512)

In [19]:
trainer = pl.Trainer(max_epochs=cfg.epochs)
trainer.fit(model,train_loader,val_loader)


GPU available: True, used: False
TPU available: None, using: 0 TPU cores

  | Name      | Type             | Params
-----------------------------------------------
0 | encoder   | AttentionEncoder | 3.0 M 
1 | decoder   | AttentionDecoder | 3.2 M 
2 | criterion | MyLoss           | 0     
-----------------------------------------------
6.2 M     Trainable params
0         Non-trainable params
6.2 M     Total params


Epoch 2:  29%|██▊       | 171/598 [00:20<00:50,  8.53it/s, loss=0.921, v_num=2, val/loss=0.915, train/loss=0.922]      



Epoch 2:  29%|██▊       | 171/598 [00:21<00:53,  8.05it/s, loss=0.921, v_num=2, val/loss=0.915, train/loss=0.922]


1