### VirtCell Stuff

In [1]:
import polars as pl 

In [2]:
# Define our path
pert_counts_path = "vcc_data/pert_counts_Test.csv"

# Read in the csv
pert_counts = pl.read_csv(pert_counts_path)

# Show the dimensions
print(f"Dimensions: {pert_counts.shape}")
pert_counts.head()



Dimensions: (100, 3)


target_gene,n_cells,median_umi_per_cell
str,i64,f64
"""GRB10""",4074,54815.0
"""STXBP1""",3751,55077.0
"""PLAGL2""",3642,54355.0
"""MED15""",3637,50285.0
"""SOX4""",3343,57576.0


In [3]:
gene_names_path = "vcc_data/gene_names.csv"

# Read this in and immediately convert to array
gene_names = pl.read_csv(gene_names_path, has_header=False).to_numpy().flatten()

gene_names

array(['SAMD11', 'NOC2L', 'KLHL17', ..., 'MT-ND5', 'MT-ND6', 'MT-CYB'],
      dtype=object)

### Prepping Model

In [4]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import scanpy as sc
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt
import pyensembl
import time

In [5]:
import src.updatedmodel as m
import src.data_cleaning as dc
import src.position_encoding as pos
import src.pathway_encoding as path
import src.updatetraining as train

In [6]:
adata = sc.read_h5ad('vcc_data/adata_Training.h5ad')
#adata.obs['perturbation_idx'] = np.random.randint(0, CONFIG["n_perturbations"], size=adata.n_obs)
#gene_names = pd.read_csv('vcc_data/gene_names.csv', header = None)
#gene_names = gene_names.iloc[:, 0].tolist()

CONFIG = {
    "n_genes": len(gene_names), #total number of expression genes needed to predict
    "n_perturbations": len(adata.obs['target_gene'].unique().tolist()), #number of unique perturbations
    "n_chromosomes": 25, #chromosome number 23+X+Y

    "perturbation_dim": 256,      # Condition embedding
    "chrom_embedding_dim": 16,     # Learned in-model for chromosome identity
    "locus_fourier_features": 8,   # Number of Fourier frequency pairs (2*F)
    "pathway_dim": 50,             # From pre-trained Autoencoder(based on hallmark MSigDB)
    "gene_identity_dim": 189,       # Main learnable gene embedding

    # Backbone dims
    "d_model": 512,                # Mamba hidden size
    "mamba_layers": 4,
    "n_heads": 8,
    "n_layers": 4,

    # Head
    "prediction_head": "probabilistic", # "linear" | "probabilistic"

    # Training
    "batch_size": 4,
    "learning_rate": 5e-5, # Lowered LR for AdamW stability
    "epochs": 50,
}

POS_DIM = CONFIG["chrom_embedding_dim"] + CONFIG["chrom_embedding_dim"]
GENE_FEAT_DIM = CONFIG["gene_identity_dim"] + CONFIG["pathway_dim"] + POS_DIM

ensembl_data = pyensembl.EnsemblRelease(109)
ensembl_data.download()
ensembl_data.index()

unique_genes = adata.var.index.tolist()
perturbation_to_idx_map = {name: i for i, name in enumerate(unique_genes)}
adata.obs['perturbation_idx'] = adata.obs['target_gene'].map(perturbation_to_idx_map)
adata.obs['perturbation_idx'] = adata.obs['perturbation_idx'].fillna(CONFIG['n_genes']).astype(int)
control_adata = dc.get_control_data(adata)
chr_idx, locus_norm, locus_fourier = pos.precompute_positional_indices(ensembl_data, gene_names, CONFIG)

pathway_feats = np.load('pathway_matrix.npy')

INFO:pyensembl.sequence_data:Loaded sequence dictionary from /hpc/home/zy231/.cache/pyensembl/GRCh38/ensembl109/Homo_sapiens.GRCh38.cdna.all.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /hpc/home/zy231/.cache/pyensembl/GRCh38/ensembl109/Homo_sapiens.GRCh38.ncrna.fa.gz.pickle
INFO:pyensembl.sequence_data:Loaded sequence dictionary from /hpc/home/zy231/.cache/pyensembl/GRCh38/ensembl109/Homo_sapiens.GRCh38.pep.all.fa.gz.pickle



--- Preparing positional indices using pyensembl ---


Fetching gene positions: 100%|██████████| 18080/18080 [02:22<00:00, 126.75it/s]


SUCCESS: Generated positional tensors with shapes:
Chromosome Indices (chr_idx): torch.Size([18080])
Normalized Locus (locus_norm): torch.Size([18080, 1])
Locus Fourier Features (locus_fourier): torch.Size([18080, 16])





In [7]:
dataset = m.PerturbationDataset(adata,CONFIG)
test_loader = DataLoader(dataset, shuffle=True)

Dataset using log-normalized data from adata.X for linear head.


In [8]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")


print("Creating a new, untrained model instance...")
model = m.TranscriptomePredictor(CONFIG, GENE_FEAT_DIM, pathway_feats, chr_idx, locus_fourier)
model.to(device)

Using device: cuda
Creating a new, untrained model instance...


TranscriptomePredictor(
  (gene_id_emb): Embedding(18080, 189)
  (chr_emb): Embedding(25, 16)
  (locus_mlp): Sequential(
    (0): Linear(in_features=16, out_features=16, bias=True)
    (1): GELU(approximate='none')
  )
  (cond_proj): Linear(in_features=271, out_features=271, bias=True)
  (input_proj): Linear(in_features=271, out_features=512, bias=True)
  (input_norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  (backbone): BiMamba(
    (mamba_fwd): Mamba2(
      (in_proj): Linear(in_features=512, out_features=2320, bias=False)
      (conv1d): Conv1d(1280, 1280, kernel_size=(4,), stride=(1,), padding=(3,), groups=1280)
      (act): SiLU()
      (norm): RMSNorm()
      (out_proj): Linear(in_features=1024, out_features=512, bias=False)
    )
    (mamba_bwd): Mamba2(
      (in_proj): Linear(in_features=512, out_features=2320, bias=False)
      (conv1d): Conv1d(1280, 1280, kernel_size=(4,), stride=(1,), padding=(3,), groups=1280)
      (act): SiLU()
      (norm): RMSNorm()
    

In [9]:
# --- 3. LOAD THE SAVED WEIGHTS (THE .pth FILE) ---
model_path = "update_weights(2).pth"
print(f"Loading saved weights from {model_path}...")

state_dict = torch.load(model_path, map_location=device)
model.load_state_dict(state_dict)


model.eval()

print("\n--- ✅ Model loaded successfully! ---")

Loading saved weights from update_weights(2).pth...

--- ✅ Model loaded successfully! ---


### Get Results

In [14]:

import gc

# --- 1. Define a manageable batch size ---
# This is the "fix". We will only process this many cells at a time.
EVAL_BATCH_SIZE = 4096 

all_predictions = []
all_target_names = []

# --- 2. Setup Model ---
model.eval()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)


with torch.no_grad():
    pbar = tqdm(pert_counts.rows(named=True), total=len(pert_counts), desc="Generating predictions")
    
    # --- This loop is fine ---
    for row in pbar:
        pert_name = row['target_gene']
        num_cells = row['n_cells']
        pert_idx = perturbation_to_idx_map[pert_name]
        
        # -----------------------------------------------------------------
        # --- FIX: Inner loop to process in small batches ---
        # This new loop breaks the large 'num_cells' into small chunks
        # -----------------------------------------------------------------
        for start_idx in range(0, num_cells, EVAL_BATCH_SIZE):
            
            # Calculate the size of this specific (and small) batch
            end_idx = min(start_idx + EVAL_BATCH_SIZE, num_cells)
            current_batch_size = end_idx - start_idx
            
            # Create a *small* input tensor (e.g., [4096])
            input_tensor = torch.tensor([pert_idx] * current_batch_size, dtype=torch.long).to(device)
            
            # Run the model *only* on the small batch
            # This will NOT try to allocate 74.37 GiB
            with torch.amp.autocast('cuda'):
                if CONFIG["prediction_head"] == "linear":
                    predictions_batch = model(input_tensor)
                else:
                    mean_log, _ = model(input_tensor)
                    predictions_batch = torch.exp(mean_log) 

            # Move the small result to the CPU, freeing VRAM
            all_predictions.append(predictions_batch.cpu().numpy().squeeze(-1))
            all_target_names.extend([pert_name] * current_batch_size)
            
        # --- End of the new inner loop ---


# --- 5. Create the final AnnData object on the CPU ---
print("\nCreating final AnnData submission file...")

final_X = np.concatenate(all_predictions, axis=0)
final_X = final_X.astype(np.float32) 

final_obs = pd.DataFrame({'target_gene': all_target_names})
final_obs.index = final_obs.index.astype(str) 

final_var = pd.DataFrame(index=gene_names)
final_var.index.name = "gene_name" 

submission_adata = ad.AnnData(X=final_X, obs=final_obs, var=final_var)

# 6. Save the file
submission_adata.write_h5ad("my_submission.h5ad", compression="gzip")

print(f"\n✅ Submission file created: 'my_submission.h5ad'")
print(submission_adata)

Generating predictions:   0%|          | 0/100 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 74.37 GiB. GPU 0 has a total capacity of 23.55 GiB of which 23.13 GiB is free. Including non-PyTorch memory, this process has 420.00 MiB memory in use. Of the allocated memory 143.45 MiB is allocated by PyTorch, and 24.55 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [13]:
pert_counts_path = "vcc_data/pert_counts_Test.csv"
pert_counts = pl.read_csv(pert_counts_path)
gene_names_path = "vcc_data/gene_names.csv"
gene_names = pl.read_csv(gene_names_path, has_header=False).to_numpy().flatten()


all_predictions = []
all_target_names = []


with torch.no_grad():
    pbar = tqdm(pert_counts.rows(named=True), total=len(pert_counts), desc="Generating predictions")
    
    for row in pbar:
        pert_name = row['target_gene']
        num_cells = row['n_cells']
        
        # Find the correct integer ID for your model
        pert_idx = mapping[pert_name]
        
        # Create a batch of inputs (e.g., 200 copies of ID 1)
        input_tensor = torch.tensor([pert_idx] * num_cells, dtype=torch.long).to(device)
        

        with torch.amp.autocast('cuda'):
            if CONFIG["prediction_head"] == "linear":
                predictions = model(input_tensor)
            else:
                # For submission, we just want the "best guess" (the mean)
                mean_log, _ = model(input_tensor)
                predictions = torch.exp(mean_log) # Get mean from log-mean

        all_predictions.append(predictions.cpu().numpy().squeeze())
        all_target_names.extend([pert_name] * num_cells)

# --- 5. Create the final AnnData object ---

print("\nCreating final AnnData submission file...")

# Concatenate all predicted expression matrices
final_X = np.concatenate(all_predictions, axis=0)
final_X = final_X.astype(np.float32) # Rule 5: Must be float32

# Create the .obs dataframe (this is the correct structure)
final_obs = pd.DataFrame({'target_gene': all_target_names})
final_obs.index = final_obs.index.astype(str) # Make index strings

# Create the .var dataframe (must match training gene list)
final_var = pd.DataFrame(index=gene_names)
final_var.index.name = "gene_name" # Good practice

# Build the submission anndata
submission_adata = ad.AnnData(X=final_X, obs=final_obs, var=final_var)

# 9. Save the file
submission_adata.write_h5ad("my_submission.h5ad", compression="gzip")

print(f"\n✅ Submission file created: 'my_submission.h5ad'")
print(submission_adata)

Generating predictions:   0%|          | 0/100 [00:00<?, ?it/s]

NameError: name 'mapping' is not defined