# **Virtual Cell Challenge - Transformer NH**

### Set up
---

In [1]:
# Import relevant packages
import pandas as pd
import os 
import numpy as np
import torch 
import scanpy as sc
import seaborn as sns
import matplotlib.pyplot as plt
import torch.nn as nn
import anndata as ad
import seaborn as sns

In [2]:
# Set working directory. The path can be overridden with the VCC_PROJECT_DIR
# environment variable so the notebook runs on other machines without edits.
PROJECT_DIR = os.environ.get("VCC_PROJECT_DIR", "/home/____/Documents/VirtualCell")
os.chdir(PROJECT_DIR)
os.getcwd()

'/home/____/Documents/VirtualCell'

### Preparing Training Data
---

#### Obtain VCC training data

In [3]:
# Load the VCC training AnnData (relative to the working directory set above)
TRAINING_H5AD_PATH = "adata_Training.h5ad"
VCC_Training = sc.read_h5ad(TRAINING_H5AD_PATH)

In [4]:
# Log1p-transform X in place (X <- log(1 + X)). This is a variance-stabilising
# transform, not library-size normalisation.
# scanpy records the transform in `.uns['log1p']`; skip when it is already
# present so re-running this cell does not apply log1p a second time.
if "log1p" not in VCC_Training.uns:
    sc.pp.log1p(VCC_Training)

In [5]:
# Select cells whose target_gene is 'non-targeting' (controls); .copy()
# materialises the subset rather than keeping a view of VCC_Training
is_control_cell = VCC_Training.obs['target_gene'] == 'non-targeting'
Control_Cells = VCC_Training[is_control_cell].copy()

## **Phase 1: Masked-expression pretraining on control cells**

### Organize Inputs for Transformer
---

#### Create the gene-ID vocabulary and gene-ID tensor

In [6]:
# Gene vocabulary: each gene in var_names gets its column index as its ID.
# Duplicate names would collapse dict entries, and the '<pad>' ID assigned below
# could then collide with a real gene ID, so require uniqueness up front.
assert VCC_Training.var_names.is_unique, "var_names must be unique to build the gene vocabulary"
gene_list = list(VCC_Training.var_names)

# Create gene to ID mapping
gene_to_id = {gene: i for i, gene in enumerate(gene_list)}

# Pad token takes the next free ID after all genes
gene_to_id['<pad>'] = len(gene_to_id)

# Gene-ID tensor in var_names order (i.e. 0..n_genes-1)
gene_ids = [gene_to_id[gene] for gene in gene_list]
gene_ids_tensor = torch.tensor(gene_ids, dtype=torch.long)

#### Create UMI tensor

In [7]:
# Per-cell total UMI count for the control cells.
# Control_Cells.X holds log1p-transformed values (sc.pp.log1p ran on the full
# object before subsetting), so summing X directly adds up log values, not
# counts. Invert with expm1 first. Assumes adata_Training.h5ad stored raw
# counts in X before the log1p step -- TODO confirm.
X_control = Control_Cells.X
if hasattr(X_control, "expm1"):
    raw_counts_control = X_control.expm1()  # scipy.sparse: element-wise, keeps sparsity
else:
    raw_counts_control = np.expm1(X_control)
umi_counts_control = np.asarray(raw_counts_control.sum(axis=1)).ravel()

# Log-compress the totals before binning
log1p_umi_control = np.log1p(umi_counts_control)

# Equal-width binning of log1p(UMI). The edges are kept so other cells can be
# binned on the same scale later.
n_umi_bins = 32
umi_bins, umi_bin_edges = pd.cut(
    log1p_umi_control, bins=n_umi_bins, labels=False, include_lowest=True, retbins=True
)
binned_umi_tensor = torch.tensor(umi_bins, dtype=torch.long)

#### Create Perturbation ID Input

In [8]:
# Perturbation vocabulary: all gene entries plus a dedicated control token
# ('non-targeting') appended after them.
perturb_vocab = {**gene_to_id, 'non-targeting': len(gene_to_id)}
vocab_size = len(perturb_vocab)

# Every Phase-1 cell is a control, so all perturbation IDs are the control token
non_targeting_id = perturb_vocab['non-targeting']
num_control_cells = Control_Cells.n_obs
perturb_ids_tensor = torch.full(
    (num_control_cells,), fill_value=non_targeting_id, dtype=torch.long
)

#### Create Expression Bins 

In [9]:
# Expression-bin IDs: 0 = zero expression, 1..n_expression_bins-1 = equal-width
# bins over the non-zero range, n_expression_bins = pad/mask token.
n_expression_bins = 51

# Dense copy of the (log1p) control expression; handles sparse or dense X
X_ctrl = Control_Cells.X
expression_data_binning = X_ctrl.toarray() if hasattr(X_ctrl, "toarray") else np.asarray(X_ctrl)
non_zero_mask = expression_data_binning > 0
non_zero_values = expression_data_binning[non_zero_mask]
min_val, max_val = np.min(non_zero_values), np.max(non_zero_values)

# n_expression_bins edges -> n_expression_bins - 1 intervals
bins = np.linspace(min_val, max_val, num=n_expression_bins)
binned_data = np.zeros(expression_data_binning.shape, dtype=np.int32)
# np.digitize yields 1..n_expression_bins-1 for values in [min_val, max_val) but
# n_expression_bins for values equal to max_val, which would collide with the
# pad token. Clip so the maximum falls in the last real bin.
binned_data[non_zero_mask] = np.clip(
    np.digitize(non_zero_values, bins=bins), 1, n_expression_bins - 1
)
Control_Cells.layers['binned'] = binned_data

# Pad token for masking (reserved; no real value maps to it)
pad_token_id = n_expression_bins

#### Construct Dataset

In [10]:
from torch.utils.data import Dataset, DataLoader, random_split

# Dataset for Phase 1: masked reconstruction of control-cell expression
class VCC_Training_Dataset_Phase1(Dataset):
    """Per-cell samples for Phase-1 masked-expression training.

    Each item bundles the shared gene-ID sequence, the cell's binned expression
    with a random subset of positions overwritten by ``pad_token_id``, the
    cell's continuous expression values (the regression target), its UMI bin,
    its perturbation ID, and the boolean mask of overwritten positions.

    Args:
        adata: AnnData-like object providing ``layers['binned']`` (a 2-D array
            indexed as ``[row, :]``) and ``X`` with a ``toarray()`` method.
        gene_ids_tensor: gene IDs, returned unchanged with every item.
        binned_umi_tensor: per-cell UMI bins, indexed by the item index.
        perturb_ids_tensor: per-cell perturbation IDs, indexed by the item index.
        mask_prob: probability that each gene position is masked, drawn
            independently per position on every access.
        pad_token_id: bin ID written into masked positions.
    """

    def __init__(self, adata, gene_ids_tensor, binned_umi_tensor, perturb_ids_tensor, mask_prob=0.15, pad_token_id = n_expression_bins):
        # Expression matrices, one row per cell. X is densified once here so
        # __getitem__ only slices a row.
        self.X_binned = adata.layers['binned']
        self.X_true = adata.X.toarray()

        # Shared gene-ID sequence and per-cell side inputs
        self.gene_ids_tensor = gene_ids_tensor
        self.binned_umi_tensor = binned_umi_tensor
        self.perturb_ids_tensor = perturb_ids_tensor

        # Masking configuration
        self.mask_prob = mask_prob
        self.pad_token_id = pad_token_id

    def __len__(self):
        """Number of cells (rows of the binned matrix)."""
        n_cells = self.X_binned.shape[0]
        return n_cells

    def __getitem__(self, idx):
        """Return the model inputs and training targets for cell ``idx``."""
        bins_row = torch.tensor(self.X_binned[idx, :], dtype=torch.long)
        target = torch.tensor(self.X_true[idx, :], dtype=torch.float)

        # Fresh random mask on every access; masked positions get the pad ID
        is_masked = torch.rand(bins_row.shape) < self.mask_prob
        masked_bins = bins_row.masked_fill(is_masked, self.pad_token_id)

        return {
            "gene_ids": self.gene_ids_tensor,
            "binned_expression_masked": masked_bins,
            "true_expression": target,
            "binned_umi": self.binned_umi_tensor[idx],
            "perturb_id": self.perturb_ids_tensor[idx],
            "mask": is_masked,
        }

# Instantiate the Phase 1 dataset
Phase1_Dataset = VCC_Training_Dataset_Phase1(Control_Cells, gene_ids_tensor, binned_umi_tensor, perturb_ids_tensor)

# 80/20 train/validation split. A dedicated seeded generator keeps the split
# identical across kernel restarts; without it random_split draws from torch's
# global RNG, which is not seeded before this point.
TRAIN_FRACTION = 0.8
SPLIT_SEED = 42
train_size = int(TRAIN_FRACTION * len(Phase1_Dataset))
val_size = len(Phase1_Dataset) - train_size
train_dataset, val_dataset = random_split(
    Phase1_Dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

# Each sample is a full gene sequence (+2 context tokens) under full
# self-attention; batch size 1 is presumably chosen to fit memory -- TODO confirm
BATCH_SIZE = 1
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

### Constructing Transformer Architecture
---

#### Build Model

In [11]:
# Transformer Class
class VCC_Transformer(nn.Module):
    """Encoder-only transformer that predicts one expression value per gene.

    The input sequence for each cell is
    ``[UMI token, perturbation token, gene_1, ..., gene_G]``. A gene token is
    the LayerNorm'd, dropped-out sum of a gene-identity embedding and an
    expression-bin embedding. After the encoder, the two context positions are
    discarded and a linear head maps each gene position to a scalar.

    Args:
        vocab_size: rows in the embedding table shared by gene IDs and
            perturbation IDs.
        n_expression_bins: the expression-bin table has ``n_expression_bins + 1``
            rows (IDs ``0..n_expression_bins``).
        n_umi_bins: number of UMI-bin IDs.
        embedding_dim: model width; must be divisible by ``n_heads``.
        n_heads: attention heads per encoder layer.
        n_layers: number of encoder layers.
        pad_token_id: expression-bin ID used as ``padding_idx`` (its embedding is
            initialised to zero and receives no gradient).
    """

    def __init__(self, vocab_size, n_expression_bins, n_umi_bins, embedding_dim, n_heads, n_layers, pad_token_id):
        super().__init__()

        # --- Token embeddings. Modules are created in a fixed order; changing
        # it changes the random initial weights under a fixed seed.
        self.shared_embedding = nn.Embedding(vocab_size, embedding_dim)  # genes + perturbations
        self.expression_embedding = nn.Embedding(
            n_expression_bins + 1, embedding_dim, padding_idx=pad_token_id
        )
        self.umi_embedding = nn.Embedding(n_umi_bins, embedding_dim)
        self.perturb_transform = nn.Linear(embedding_dim, embedding_dim)
        self.embedding_norm = nn.LayerNorm(embedding_dim)
        self.embedding_dropout = nn.Dropout(0.1)

        # --- Encoder stack (inputs are batch-first: (batch, seq, dim))
        layer = nn.TransformerEncoderLayer(d_model=embedding_dim, nhead=n_heads, batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(layer, num_layers=n_layers)

        # --- Regression head: one scalar per gene position
        self.prediction_head = nn.Linear(embedding_dim, 1)

    def forward(self, batch):
        """Predict per-gene expression for a batch.

        Args:
            batch: dict with ``'gene_ids'`` and ``'binned_expression_masked'``
                of shape ``(batch, n_genes)`` plus ``'binned_umi'`` and
                ``'perturb_id'`` of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, n_genes)``.
        """
        gene_ids = batch['gene_ids']
        expr_bins = batch['binned_expression_masked']
        umi_bins = batch['binned_umi']
        perturb_ids = batch['perturb_id']

        gene_tokens = self._embed_gene_tokens(gene_ids, expr_bins)
        umi_token = self.umi_embedding(umi_bins).unsqueeze(1)
        perturb_token = self.perturb_transform(self.shared_embedding(perturb_ids).unsqueeze(1))

        # [UMI, perturbation, genes...] along the sequence axis
        sequence = torch.cat((umi_token, perturb_token, gene_tokens), dim=1)
        encoded = self.transformer_encoder(sequence)

        # Only gene positions are scored; drop the two leading context tokens
        gene_states = encoded[:, 2:, :]
        return self.prediction_head(gene_states).squeeze(-1)

    def _embed_gene_tokens(self, gene_ids, expr_bins):
        """Gene-identity + expression-bin embedding, then LayerNorm and dropout."""
        summed = self.shared_embedding(gene_ids) + self.expression_embedding(expr_bins)
        return self.embedding_dropout(self.embedding_norm(summed))

#### Instantiate Model

In [12]:
# Model hyper-parameters
EMBED_DIM = 256
N_HEADS = 8
N_LAYERS = 6
INIT_SEED = 0  # fixed seed for reproducible weight initialisation

# Multi-head attention requires the model width to split evenly across heads
assert EMBED_DIM % N_HEADS == 0, "EMBED_DIM must be divisible by N_HEADS"

# Seed torch's global RNG right before building the model so the initial
# weights are identical on every Restart & Run All. This also fixes the global
# RNG state seen by later random draws that use it (e.g. per-item masks).
torch.manual_seed(INIT_SEED)

# Instantiate Model
model = VCC_Transformer(
    vocab_size=vocab_size,
    n_expression_bins=n_expression_bins,
    n_umi_bins=n_umi_bins,
    embedding_dim=EMBED_DIM,
    n_heads=N_HEADS,
    n_layers=N_LAYERS,
    pad_token_id=pad_token_id,
)

### Running Model
---

#### Setup

In [13]:
import torch.optim as optim
from tqdm import tqdm 

# Use the GPU when available; fall back to CPU so the notebook still runs without CUDA
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

# Mixed precision is only used on CUDA; a disabled GradScaler is a pass-through
use_amp = device.type == "cuda"

# Setup loss and optimizer
LEARNING_RATE = 1e-4
criterion = nn.MSELoss()
optimizer = optim.AdamW(model.parameters(), lr=LEARNING_RATE)
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

# Number of epochs
num_epochs = 20

#### Training and Validation Passes

In [None]:
def _masked_mse(model, batch, criterion, device, use_amp):
    """Run one batch through `model` and return the loss over masked gene positions.

    Args:
        model: the transformer; called as ``model(batch)``.
        batch: dict of tensors from the DataLoader; must include
            ``'true_expression'`` and the boolean ``'mask'``.
        criterion: loss module applied to masked predictions vs. targets.
        device: device the batch tensors are moved to.
        use_amp: whether to run the forward pass under CUDA autocast.

    Returns:
        A scalar loss tensor, or ``None`` when no position in the batch was
        masked (MSE over an empty selection is NaN and would poison the totals).
    """
    # Check on the host copy first to avoid a device sync for the rare empty case
    if not batch['mask'].any():
        return None
    batch = {k: v.to(device) for k, v in batch.items()}
    mask = batch['mask']
    with torch.cuda.amp.autocast(enabled=use_amp):
        predictions = model(batch)
        loss = criterion(predictions[mask], batch['true_expression'][mask])
    return loss


# Autocast only on CUDA, for both training and validation
use_amp = device.type == "cuda"

for epoch in range(num_epochs):

    # --- Training Phase
    model.train()
    total_train_loss, n_train_batches = 0.0, 0

    train_pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Train]")
    for batch in train_pbar:
        loss = _masked_mse(model, batch, criterion, device, use_amp)
        if loss is None:
            continue

        # Backward Pass
        optimizer.zero_grad()
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # .item() forces a device sync; call it once per batch
        loss_value = loss.item()
        total_train_loss += loss_value
        n_train_batches += 1
        train_pbar.set_postfix({'loss': loss_value})

    avg_train_loss = total_train_loss / n_train_batches if n_train_batches else float('nan')

    # --- Validation Phase
    model.eval()
    total_val_loss, n_val_batches = 0.0, 0

    val_pbar = tqdm(val_loader, desc=f"Epoch {epoch+1}/{num_epochs} [Val]")
    with torch.no_grad():
        for batch in val_pbar:
            loss = _masked_mse(model, batch, criterion, device, use_amp)
            if loss is None:
                continue
            loss_value = loss.item()
            total_val_loss += loss_value
            n_val_batches += 1
            val_pbar.set_postfix({'loss': loss_value})

    avg_val_loss = total_val_loss / n_val_batches if n_val_batches else float('nan')

    print(f"Epoch {epoch+1}/{num_epochs} Train Loss: {avg_train_loss:.6f}, Val Loss: {avg_val_loss:.6f}")