In [1]:
import os

# Work from the directory that contains the `utils` package (imported in the next
# cell) so that it and the relative `data/...` paths resolve. The guard keeps the
# cell idempotent: re-running it no longer climbs one more level each time.
if not os.path.isdir("utils"):
    os.chdir("..")

In [2]:
import torch
from utils.dataset_highvars import get_loader
from dataclasses import dataclass
import anndata as ad
from torch.utils.data import Dataset
import scanpy as sc
import torch.nn as nn
import pytorch_lightning as pl
import numpy as np

  __import__('pkg_resources').declare_namespace(__name__)


# Dataset Loading

In [3]:
@dataclass
class Config:
    """Run-wide settings for this notebook.

    Every field is annotated so that ``@dataclass`` actually registers it. Without
    annotations, the values were plain class attributes: not constructor arguments
    and absent from ``repr``/``==``. Defaults are unchanged.
    """
    log_dir: str = "logs"
    name: str = "AttentionDiffuse"
    batch_size: int = 64  # cells (dataset items) per DataLoader batch
    version: int = 1
    epochs: int = 30
    lr: float = 1e-3  # NOTE(review): not currently passed to the model/optimizer
    num_workers: int = 15  # DataLoader worker processes
    num_samples: int = 700  # NOTE(review): not referenced in this notebook section
    target_gene_dim: int = 128  # NOTE(review): not referenced in this notebook section
cfg = Config()


In [4]:
dataRoot = "data/vcc_data"
tr_adata_path = "/".join([dataRoot, "adata_Training.h5ad"])

# Load the training AnnData, then transform it in place:
# total-count normalisation per cell followed by log(1 + x).
adata = ad.read_h5ad(tr_adata_path)
sc.pp.normalize_total(adata)
sc.pp.log1p(adata)

# Optional, currently disabled: keep only highly-variable, non-mitochondrial genes.
# NOTE(review): enabling this changes the gene count. The dataset reshapes each cell
# into 32 rows, and the model assumes 565 genes per row by default, so both would
# need to be updated.
#sc.pp.highly_variable_genes(adata)
#maskidx = (~adata.var.index.str.startswith("MT-")) & adata.var.highly_variable
#adata = adata[:,maskidx]

In [5]:
def get_positional_encoding_vector(pos, d_model):
    """
    Sinusoidal (Transformer-style) positional encoding for a single position.

    Element ``2i`` is ``sin(pos / 10000**(2i / d_model))`` and element ``2i + 1``
    is ``cos(pos / 10000**(2i / d_model))``.

    Args:
        pos (int): The position index (e.g. the integer id of a target gene).
        d_model (int): Length of the encoding. Must be even so sin/cos pair up.

    Returns:
        torch.Tensor: float32 tensor of shape (d_model,).

    Raises:
        ValueError: if ``d_model`` is odd.
    """
    if d_model % 2 != 0:
        raise ValueError("d_model must be an even number.")

    # Pair index i for every dimension: [0, 0, 1, 1, 2, 2, ...]
    i = np.arange(d_model) // 2

    # Angle per dimension: pos / 10000^(2i / d_model)
    angles = pos / np.power(10000, (2 * i) / d_model)

    # Every slot is overwritten below, so no zero-initialisation is needed.
    pe_vector = np.empty(d_model)
    pe_vector[0::2] = np.sin(angles[0::2])  # even dimensions -> sin
    pe_vector[1::2] = np.cos(angles[1::2])  # odd dimensions  -> cos

    return torch.tensor(pe_vector).float()

In [6]:
# Sanity check: boolean mask of all cells that share the batch of cell 10.
# (Previously this read `.target_gene`, comparing gene names against the batch
# column, which yields an all-False mask.)
currentBatch = adata.obs.iloc[10].batch
mask = adata.obs.batch == currentBatch
mask.to_numpy()

array([False, False, False, ..., False, False, False])

In [7]:
class GeneExpressionDataset(Dataset):
    """Pairs every perturbed cell with a random control cell from the same batch.

    Cells whose ``target_gene`` is ``"non-targeting"`` are the controls; all other
    cells are the items this dataset iterates over.

    Args:
        adata: AnnData-like object exposing ``.obs`` (with ``target_gene`` and
            ``batch`` columns) and ``.X`` (sparse or dense), indexable by a
            boolean row mask.
        geneMapping: dict from target-gene label to the integer id that is fed to
            ``get_positional_encoding_vector``.
        seqLength: number of rows each expression vector is reshaped into. The
            gene count must be divisible by it.

    Each item is ``(perturbed, control, gene_encoding)``. The first two have
    shape ``(seqLength, n_genes // seqLength)``; ``gene_encoding`` has shape
    ``(seqLength, 1)``.
    """

    def __init__(self, adata, geneMapping, seqLength=32):
        mask = (adata.obs.target_gene != "non-targeting")
        mask2 = (adata.obs.target_gene == "non-targeting")

        # Slice each subset once instead of re-indexing adata for .X and .obs.
        perturbed = adata[mask]
        control = adata[mask2]

        self.adata_X = perturbed.X
        self.adata_obs = perturbed.obs

        self.adata_batchX = control.X
        self.adata_obs_batch = control.obs
        self.seqLength = seqLength
        self.geneMapping = geneMapping

        # Positional row indices (into adata_batchX) of the control cells in each
        # batch. Built once so __getitem__ does not rescan every control cell.
        self._control_rows_by_batch = {}
        for row, batch in enumerate(self.adata_obs_batch["batch"].to_numpy()):
            self._control_rows_by_batch.setdefault(batch, []).append(row)

    def __len__(self):
        return self.adata_X.shape[0]

    @staticmethod
    def _dense_row(matrix, idx):
        """Return row ``idx`` of a sparse or dense matrix as a 1-D numpy array."""
        row = matrix[idx]
        if hasattr(row, "toarray"):
            row = row.toarray()
        return np.asarray(row).squeeze()

    def _sample_control_row(self, batch):
        """Return a random row index into ``adata_batchX`` for a control cell of ``batch``.

        Falls back to a control cell from any batch when ``batch`` has none.
        This raises ValueError if there are no control cells at all.
        """
        rows = self._control_rows_by_batch.get(batch)
        if not rows:
            return np.random.randint(self.adata_batchX.shape[0])
        return rows[np.random.randint(len(rows))]

    def __getitem__(self, idx):
        obs_row = self.adata_obs.iloc[idx]
        currentBatch = obs_row.batch
        gene = obs_row.target_gene

        gene_expression = self._dense_row(self.adata_X, idx)
        # Batch-matched control. The old code passed the boolean mask itself to
        # np.random.choice, which drew True/False and always picked row 0 or 1.
        clean_expression = self._dense_row(
            self.adata_batchX, self._sample_control_row(currentBatch)
        )

        geneidx = self.geneMapping[gene]
        gene_encoding = get_positional_encoding_vector(geneidx, self.seqLength).unsqueeze(-1)

        return (
            gene_expression.reshape(self.seqLength, -1),
            clean_expression.reshape(self.seqLength, -1),
            gene_encoding,
        )

In [8]:
# Integer id for each target-gene label (including "non-targeting"), in order of first appearance.
genemap = {gene: code for code, gene in enumerate(adata.obs["target_gene"].unique())}


In [9]:
# Reproducible 80/20 train/validation split.
# TODO: replace with k-fold cross-validation.
SPLIT_SEED = 0
dataset = GeneExpressionDataset(adata, genemap)
n_train = int(0.8 * len(dataset))
trainDataset, valDataset = torch.utils.data.random_split(
    dataset,
    [n_train, len(dataset) - n_train],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)


def _seed_numpy_in_worker(worker_id):
    """Reseed NumPy's global RNG inside each DataLoader worker.

    GeneExpressionDataset samples control cells with np.random. Older PyTorch
    releases do not reseed NumPy in worker processes, so forked workers would
    draw identical "random" controls. torch.initial_seed() is distinct per
    worker, so derive the NumPy seed from it.
    """
    np.random.seed(torch.initial_seed() % 2**32)


train_loader = torch.utils.data.DataLoader(
    trainDataset, batch_size=cfg.batch_size, shuffle=True,
    num_workers=cfg.num_workers, worker_init_fn=_seed_numpy_in_worker,
)
val_loader = torch.utils.data.DataLoader(
    valDataset, batch_size=cfg.batch_size, shuffle=False,
    num_workers=cfg.num_workers, worker_init_fn=_seed_numpy_in_worker,
)

In [10]:
# Fetch one training batch to inspect shapes: (perturbed expression, control expression, gene encoding).
genes, batchinfo, posencode = next(iter(train_loader))

In [11]:
# Shapes of the three tensors in the sample batch.
tuple(t.shape for t in (genes, batchinfo, posencode))

(torch.Size([64, 32, 565]), torch.Size([64, 32, 565]), torch.Size([64, 32, 1]))

# Design Architecture

In [30]:
from diffusers import DDPMScheduler

# Stand-alone DDPM scheduler for the interactive timestep experiment below;
# AttentionDiffusion constructs its own scheduler internally.
noise_scheduler = DDPMScheduler(num_train_timesteps=5_000)


In [31]:
# One random diffusion step per sample in the example batch.
bs = len(genes)
timesteps = torch.randint(low=0, high=noise_scheduler.config.num_train_timesteps, size=(bs,))

In [32]:
genes.size()

torch.Size([64, 32, 565])

In [39]:
import math
def timestep_embedding(t, dim, max_period=10000):
    """
    Sinusoidal embedding of diffusion timesteps.

    t: (batch,) tensor of timesteps.
    dim: embedding dimension (>= 2).
    max_period: sets the lowest frequency (default 10000, as before).

    Returns a float tensor of shape (batch, dim) on t's device: the first half is
    sin, the second half cos. For odd ``dim``, a zero column is appended so the
    width is exactly ``dim`` (previously ``dim - 1`` columns were returned).
    """
    if dim < 2:
        raise ValueError("dim must be at least 2")
    half = dim // 2
    freqs = torch.exp(
        torch.arange(half, dtype=torch.float32, device=t.device) * -(math.log(max_period) / half)
    )
    args = t[:, None].float() * freqs[None]
    emb = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
    if dim % 2:
        emb = torch.cat([emb, torch.zeros_like(emb[:, :1])], dim=-1)
    return emb

In [67]:

class Uattention(nn.Module):
    """Cross-attention block: timestep-conditioned queries attend over the conditioning.

    A projected sinusoidal timestep embedding is added to ``x``. The result is
    projected to queries; keys and values are projected from ``conditioning``.
    The multi-head attention output is layer-normalised and returned.

    NOTE(review): there is no residual path. The output is therefore an
    (out-projected) weighted average of value vectors computed from
    ``conditioning`` alone, and ``x`` only influences the attention weights.
    Consider returning ``self.norm(q + attn)`` if that is not intended.

    Args:
        inDims: last-dimension size of ``x``.
        outDims: attention/output width; must be divisible by ``num_heads``.
        condDIm: last-dimension size of ``conditioning``.
        time_dim: width of the sinusoidal timestep embedding (default 32).
        num_heads: number of attention heads (default 8).
        time_hidden: hidden width of the timestep MLP (default 64).
    """

    def __init__(self, inDims, outDims, condDIm=1, time_dim=32, num_heads=8, time_hidden=64):
        super().__init__()
        self.time_dim = time_dim
        self.qlayer = nn.Linear(inDims, outDims)
        self.klayer = nn.Linear(condDIm, outDims)
        self.vlayer = nn.Linear(condDIm, outDims)
        self.mha = nn.MultiheadAttention(embed_dim=outDims, num_heads=num_heads, batch_first=True)
        self.norm = nn.LayerNorm(outDims)
        self.time_proj = nn.Sequential(
            nn.Linear(time_dim, time_hidden),
            nn.SiLU(),
            nn.Linear(time_hidden, inDims),
        )

    def forward(self, x, conditioning, timesteps):
        """Attend from ``x`` (batch, seq, inDims) over ``conditioning``; return (batch, seq, outDims)."""
        time_emb = timestep_embedding(timesteps, self.time_dim)
        # Broadcast the per-sample time embedding over the sequence axis.
        x = x + self.time_proj(time_emb).unsqueeze(1)
        q = self.qlayer(x)
        k = self.klayer(conditioning)
        v = self.vlayer(conditioning)
        attn, _ = self.mha(q, k, v)
        attn = self.norm(attn)
        return attn
    
class AttentionEncoder(nn.Module):
    """Down path of the attention U-Net: three Uattention blocks narrowing the width.

    Widths go ``in_dim -> 256 -> 128 -> 64``. All three intermediate states are
    returned so the decoder can use them as skip connections.

    Args:
        in_dim: last-dimension size of the input (default 565, the per-row gene
            width of the data used in this notebook).
    """

    def __init__(self, in_dim=565):
        super().__init__()
        self.down1 = Uattention(inDims=in_dim, outDims=256)
        self.down2 = Uattention(inDims=256, outDims=128)
        self.down3 = Uattention(inDims=128, outDims=64)

    def forward(self, x, conditioning, timesteps):
        """Return ``(state1, state2, state3)`` with last dimensions 256, 128 and 64."""
        state1 = self.down1(x, conditioning, timesteps)
        state2 = self.down2(state1, conditioning, timesteps)
        state3 = self.down3(state2, conditioning, timesteps)
        return state1, state2, state3
    
class AttentionDecoder(nn.Module):
    """Up path of the attention U-Net.

    At each stage the running tensor is concatenated (last dim) with an encoder
    state and passed through a Uattention block. A final linear layer then maps
    the 512-wide result to ``out_dim``.

    Args:
        out_dim: width of the reconstructed output (default 565, the per-row gene
            width of the data used in this notebook).
    """

    def __init__(self, out_dim=565):
        super().__init__()
        self.up3 = Uattention(inDims=64 * 2, outDims=128)
        self.up2 = Uattention(inDims=128 * 2, outDims=256)
        self.up1 = Uattention(inDims=256 * 2, outDims=512)
        # up1's 512-wide output does NOT match the input width; final_layer projects back to out_dim.
        self.final_layer = nn.Linear(512, out_dim)

    def forward(self, x, state1, stat2, state3, conditioning, timesteps):
        """Decode ``x`` (last dim 64) using skips state3 (64), stat2 (128) and state1 (256)."""
        x = torch.concat([x, state3], dim=-1)
        x = self.up3(x, conditioning, timesteps)
        x = torch.concat([x, stat2], dim=-1)
        x = self.up2(x, conditioning, timesteps)
        x = torch.concat([x, state1], dim=-1)
        x = self.up1(x, conditioning, timesteps)
        x = self.final_layer(x)
        return x
    
class AttentionDiffusion(pl.LightningModule):
    """Attention U-Net trained to map a noised control profile to a perturbed profile.

    Each batch is ``(target, inputdata, condition)``. The control expression
    ``inputdata`` is noised by the DDPM forward process at a random timestep,
    encoded, passed through a small latent MLP and decoded. The output is
    compared with ``target`` by MSE.

    NOTE(review): the loss regresses the target directly rather than the added
    noise, which differs from the standard DDPM noise-prediction objective.
    Confirm this is intended.

    Args:
        lr: Adam learning rate (default 1e-3, which equals Adam's own default).
        num_train_timesteps: length of the DDPM noise schedule (default 5000).
    """

    def __init__(self, lr=1e-3, num_train_timesteps=5000):
        super().__init__()
        self.lr = lr
        self.noise_scheduler = DDPMScheduler(num_train_timesteps=num_train_timesteps)
        self.encoder = AttentionEncoder()
        # Bottleneck MLP applied to the 64-wide deepest encoder state.
        self.latentBlock = nn.Sequential(nn.Linear(64, 128), nn.SiLU(), nn.Linear(128, 64))
        self.decoder = AttentionDecoder()
        self.criterion = nn.MSELoss()

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.lr)

    def shared_step(self, batch):
        """Noise the control input, run the U-Net, and return ``(prediction, target)``."""
        target, inputdata, condition = batch
        # Create noise and timesteps on the batch's device. They were previously
        # always on CPU, which breaks add_noise/timestep_embedding on GPU.
        noise = torch.randn_like(inputdata)
        bs = inputdata.shape[0]
        timesteps = torch.randint(
            0, self.noise_scheduler.config.num_train_timesteps, (bs,),
            dtype=torch.int64, device=inputdata.device,
        )
        noisy_input = self.noise_scheduler.add_noise(inputdata, noise, timesteps)
        stat1, stat2, state3 = self.encoder(noisy_input, conditioning=condition, timesteps=timesteps)
        x = self.latentBlock(state3)
        genes = self.decoder(x, stat1, stat2, state3, conditioning=condition, timesteps=timesteps)
        return genes, target

    def training_step(self, batch, batch_idx):
        genes, target = self.shared_step(batch)
        loss = self.criterion(genes, target)
        self.log("train/loss", loss, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        # Noise and timesteps are resampled each call, so this loss is stochastic.
        genes, target = self.shared_step(batch)
        loss = self.criterion(genes, target)
        self.log("val/loss", loss, prog_bar=True)
        return loss

In [68]:
# Build the model with its default hyper-parameters.
# NOTE(review): cfg.lr is not passed in, so the model's default learning rate is used.
model = AttentionDiffusion()

In [69]:
# Train for cfg.epochs epochs on the train/val loaders built above.
# NOTE(review): no GPU/accelerator argument is passed, and the captured log reports
# "GPU available: True, used: False", so training runs on CPU. Before enabling a GPU,
# confirm that every tensor created inside the training step lives on the model's device.
trainer = pl.Trainer(max_epochs=cfg.epochs)
trainer.fit(model,train_loader,val_loader)


GPU available: True, used: False
TPU available: None, using: 0 TPU cores

  | Name        | Type             | Params
-------------------------------------------------
0 | encoder     | AttentionEncoder | 602 K 
1 | latentBlock | Sequential       | 16.6 K
2 | decoder     | AttentionDecoder | 2.1 M 
3 | criterion   | MSELoss          | 0     
-------------------------------------------------
2.7 M     Trainable params
0         Non-trainable params
2.7 M     Total params


Epoch 0:  12%|█▏        | 350/2862 [00:52<06:17,  6.65it/s, loss=0.884, v_num=15, val/loss=1.9, train/loss=0.887]


1