# Bio-JEPA AC Explainer (NEEDS UPDATE)


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import copy
import math
import numpy as np
import glob
import os
import json
from pathlib import Path
import matplotlib.pyplot as plt
import random
from dataclasses import dataclass

In [2]:
torch.manual_seed(1337)
random.seed(1337)
np.random.seed(1337)  # the shard/row sampling below uses np.random

In [3]:
def get_device():
    device = 'cpu'
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1337)
        device = 'cuda'
    print(f'using {device}')
    return device

DEVICE = get_device()

using cpu


In [4]:
data_dir = Path('/Users/djemec/data/jepa')
tok_dir = data_dir / 'tokenized'
mask_path = data_dir / 'binary_pathway_mask.npy'
metadata_path = data_dir / 'perturbation_map.json'
checkpoint_dir = data_dir / 'checkpoint'

## Text Prep/Tokenization

We previously loaded our data and saved it into tokenized files. For a full explainer, see the data_prep_explainer.ipynb notebook. In this step we'll pull the data from those files to use for training, starting by listing all the files in the tokenized directory.

In [6]:
split = 'train'
data_root = tok_dir / f'{split}'
shards = list(data_root.glob('*.npz'))
print(f'Found {len(shards)} shards.')

Found 29 shards.


Now that we have the files, we only need a single one for this walkthrough. Because of how we tokenized our files, we do not have to load them in any particular order, so we'll select a random file from the list. In training, we'd want to remove each file from the list once it's used, so that we go through the full epoch before restarting.
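To go through every shard exactly once per epoch, one option is to shuffle the shard list and pop files off it until it is empty. A minimal sketch, assuming a hypothetical `iterate_shards` helper and stand-in file names (in the notebook you would pass the `shards` list from above):

```python
import random
from pathlib import Path

def iterate_shards(shards, seed=None):
    """Hypothetical helper: yield every shard exactly once, in a random order."""
    rng = random.Random(seed)
    remaining = list(shards)   # copy so the caller's list is untouched
    rng.shuffle(remaining)
    while remaining:
        yield remaining.pop()  # each shard is removed once it has been used

demo_shards = [Path(f'shard_{i:04d}.npz') for i in range(5)]  # stand-in file names
epoch = list(iterate_shards(demo_shards, seed=0))
print(epoch)
```

Calling `iterate_shards` again at the start of each epoch reshuffles, so the order differs between epochs while coverage stays complete.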

In [8]:
random_file_idx = np.random.randint(len(shards))
file_path = shards[random_file_idx]
random_file_idx, file_path

(4, PosixPath('/Users/djemec/data/jepa/tokenized/train/shard_0005.npz'))

# Modeling

## Data Loading

To start, we need enough data to run the forward and backward passes. Since in practice our full dataset is likely too big to hold in memory all at once, we will read only enough of a file into memory to run the passes, leaving memory and compute for the passes themselves instead of holding static data.

Recall that we saved the data in the following structure:

```
{
    'control':[array of gene expression information],
    'control_total':[array of total expression counts],
    'case':[array of gene expression information],
    'case_total':[array of total expression counts],
    'action_ids':[array of ids of the perturbation]
}
```

Because of this, taking the same index out of each array gives us the control, case, and perturbation information for one example. For training we'll process more than one cell at a time, using a variable `BATCH_SIZE` to control how many examples we train on at once. Here we'll start with 2; during real training we'll want to increase the batch size.
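This row alignment is what lets a standard PyTorch `Dataset` return one (control, case, action) example per index, with `DataLoader` handling the batching. The notebook imports `Dataset` and `DataLoader` but doesn't use them; the sketch below shows one way they could be used, assuming a hypothetical `ShardPairDataset` class and small synthetic arrays in place of a real shard:

```python
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader

class ShardPairDataset(Dataset):
    """Hypothetical Dataset over one shard: row i of every array is the same example."""
    def __init__(self, arrays):
        self.arrays = arrays
        self.n = arrays['action_ids'].shape[0]
        # every array must have the same number of rows for index alignment to hold
        assert all(a.shape[0] == self.n for a in arrays.values())

    def __len__(self):
        return self.n

    def __getitem__(self, i):
        return (
            torch.from_numpy(self.arrays['control'][i].astype(np.float32)),
            torch.tensor(self.arrays['control_total'][i], dtype=torch.float32),
            torch.from_numpy(self.arrays['case'][i].astype(np.float32)),
            torch.tensor(self.arrays['case_total'][i], dtype=torch.float32),
            torch.tensor(self.arrays['action_ids'][i], dtype=torch.int64),
        )

rng = np.random.default_rng(0)
n_rows, n_genes = 6, 16  # toy sizes, not the real 4096 genes
fake_shard = {
    'control': rng.random((n_rows, n_genes)),
    'control_total': rng.random(n_rows),
    'case': rng.random((n_rows, n_genes)),
    'case_total': rng.random(n_rows),
    'action_ids': rng.integers(0, 10, n_rows),
}
loader = DataLoader(ShardPairDataset(fake_shard), batch_size=2, shuffle=True)
control_b, control_tot_b, case_b, case_tot_b, action_b = next(iter(loader))
print(control_b.shape, control_tot_b.shape, action_b.shape)
```

With a real shard, the arrays could be pulled into memory with `{k: data[k] for k in data.files}` inside a `with np.load(file_path) as data:` block. That holds the whole shard in memory, trading memory for simpler indexing.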


In [22]:
BATCH_SIZE = 2
with np.load(file_path) as data:
    n_rows = data['action_ids'].shape[0]
    row_idx = np.random.randint(n_rows, size=BATCH_SIZE)
    
    control_x = torch.from_numpy(data['control'][row_idx].astype(np.float32)).to(DEVICE)
    control_tot = torch.from_numpy(data['control_total'][row_idx].astype(np.float32)).to(DEVICE)
    case_x = torch.from_numpy(data['case'][row_idx].astype(np.float32)).to(DEVICE)
    case_tot = torch.from_numpy(data['case_total'][row_idx].astype(np.float32)).to(DEVICE)
    action_ids = torch.from_numpy(data['action_ids'][row_idx].astype(np.int64)).to(DEVICE)


row_idx, control_x, control_tot, case_x, case_tot, action_ids

(array([2816, 4658]),
 tensor([[0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.0000],
         [0.0000, 0.0000, 0.0000,  ..., 0.0000, 0.0000, 0.4874]]),
 tensor([9.8303, 9.6755]),
 tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 tensor([8.9352, 8.2506]),
 tensor([1901,   55]))

## Forward pass

![network_impact_math](../resources/biojepa_v1.png)

Our current architecture has 3 models:  

1. **The Student** - starts from the control cell and learns to create a latent embedding of the cell state.
2. **The Predictor** - learns how to adjust the student's embedding of the cell state based on a given perturbation, hallucinating the future perturbed state in the latent space it shares with the teacher.
3. **The Teacher** - starts from the actual perturbed cell and creates a latent embedding of the perturbed state, which the predictor's output is graded against.


We'll start by building the Pre-Norm Transformer Encoder block that forms the basis for both the Teacher and Student models.

### Teacher/Student Setup

**Multi-Headed Attention**

In [23]:
class BioMultiHeadAttention(nn.Module):
    # mirrors nn.MultiheadAttention(dim, heads, batch_first=True) 
    def __init__(self, config):
        super().__init__()
        self.config = config

        assert config.embed_dim % config.heads == 0
        
        self.head_dim = config.embed_dim // config.heads
        self.heads = config.heads
        
        # Projections
        self.q_proj = nn.Linear(config.embed_dim, config.embed_dim)
        self.k_proj = nn.Linear(config.embed_dim, config.embed_dim)
        self.v_proj = nn.Linear(config.embed_dim, config.embed_dim)
        
        self.c_proj = nn.Linear(config.embed_dim, config.embed_dim)

    def forward(self, x):
        B, T, C = x.size() # Batch, Seq, Embed Dim
        
        # 1. Project
        q = self.q_proj(x).view(B, T, self.heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(x).view(B, T, self.heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(x).view(B, T, self.heads, self.head_dim).transpose(1, 2)

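        # 2. Scaled dot-product attention with no causal mask; with no positional
        #    encodings it is permutation-equivariant over the pathway tokens
        y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        
        # 3. Reassemble heads: [B, heads, T, head_dim] -> [B, T, C]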
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        
        return y
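`F.scaled_dot_product_attention` computes $\mathrm{softmax}(QK^\top / \sqrt{d_{head}})\,V$ for each head. The sketch below, using random toy tensors, checks that against a manual computation and also shows that, because the encoder adds no positional encodings, shuffling the input tokens simply shuffles the outputs the same way:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, H, T, Dh = 2, 4, 5, 8  # batch, heads, tokens (pathways), head dim
q, k, v = (torch.randn(B, H, T, Dh) for _ in range(3))

# Manual attention: softmax(Q K^T / sqrt(d_head)) V, no causal mask
scores = q @ k.transpose(-2, -1) / math.sqrt(Dh)
manual = torch.softmax(scores, dim=-1) @ v

fused = F.scaled_dot_product_attention(q, k, v, is_causal=False)
print(torch.allclose(manual, fused, atol=1e-5))  # True

# Permuting the tokens permutes the outputs identically (equivariance)
perm = torch.randperm(T)
permuted = F.scaled_dot_product_attention(q[:, :, perm], k[:, :, perm], v[:, :, perm])
print(torch.allclose(permuted, fused[:, :, perm], atol=1e-5))  # True
```

This is why each pathway token keeps its identity only through its content (the pooled gene embeddings), not through its position in the sequence.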

**MLP**

In [24]:
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.c_fc = nn.Linear(config.embed_dim, int(config.mlp_ratio * config.embed_dim))
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(int(config.mlp_ratio * config.embed_dim), config.embed_dim)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x

**Hidden Transformer Block**

In [25]:
class CellStateBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.ln_1 = nn.LayerNorm(config.embed_dim)
        self.attn = BioMultiHeadAttention(config)
        self.ln_2 = nn.LayerNorm(config.embed_dim)
        self.mlp = MLP(config)

    def forward(self, x):
        # 1. Attention 
        x = x + self.attn(self.ln_1(x))

        # 2. MLP
        x = x + self.mlp(self.ln_2(x))
        return x

**Cell State Encoder**

In [26]:
@dataclass
class CellStateEncoderConfig:
    num_genes: int = 4096
    num_pathways: int = 1024 
    n_layer: int = 24 
    heads: int = 12
    embed_dim: int = 768
    mlp_ratio: float = 4.0 # MLP hidden width = mlp_ratio * embed_dim
    mask_matrix: np.ndarray = None 

class CellStateEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # Learnable gene->pathway weights, initialized from the known pathway mask
        # (1 == gene belongs to pathway); wrapping it in nn.Parameter lets training refine it
        assert config.mask_matrix is not None, 'Must provide binary_pathway_mask!'
        init_weights = torch.tensor(config.mask_matrix).float().T  # [num_pathways, num_genes]
        self.pathway_weights = nn.Parameter(init_weights)
        
        # Learnable Gene Embeddings [num_genes, Dim]
        self.gene_embeddings = nn.Parameter(torch.randn(config.num_genes, config.embed_dim) * 0.02)
        
        # Context Injector
        self.total_count_proj = nn.Linear(1, config.embed_dim)

        # Transformer
        self.blocks = nn.ModuleList([CellStateBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.embed_dim)

        # Weight initialization
        self.apply(self._init_weights)


    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None: 
                torch.nn.init.zeros_(module.bias)
        
    def forward(self, x_genes, x_total_ct):
        # 1. Scale each gene's embedding by its expression: [B, G] -> [B, G, D]
        x_genes = x_genes.unsqueeze(-1) 
        gene_repr = x_genes * self.gene_embeddings.unsqueeze(0)

        # 2. Pool genes into pathway tokens: [P, G] @ [B, G, D] -> [B, P, D]
        x_pathway = self.pathway_weights @ gene_repr

        # 3. Context Injection
        x_total_ct = x_total_ct.unsqueeze(-1)
        x_total_ct = self.total_count_proj(x_total_ct)
        x_total_ct = x_total_ct.unsqueeze(1)
        x = x_pathway + x_total_ct

        # 4. Set Transformer
        for block in self.blocks:
            x = block(x)
        
        # 5. Layer Norm
        x = self.ln_f(x)

        return x
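The first two steps of `CellStateEncoder.forward` are the least familiar part of the encoder, so here they are on toy tensors: each gene's embedding is scaled by its expression value, then the `[pathways, genes]` weight matrix sums member genes into one token per pathway. The sizes and the 6-gene, 3-pathway mask below are made up for illustration:

```python
import torch

B, G, P, D = 2, 6, 3, 4                     # toy batch, genes, pathways, embed dim
x_genes = torch.rand(B, G)                  # expression value per gene
gene_embeddings = torch.randn(G, D)         # one learnable vector per gene
mask = torch.tensor([[1, 0, 0],
                     [1, 1, 0],
                     [0, 1, 0],
                     [0, 1, 1],
                     [0, 0, 1],
                     [0, 0, 1]], dtype=torch.float32)  # toy [G, P] membership mask
pathway_weights = mask.T                    # [P, G], as in CellStateEncoder.__init__

# Step 1: scale each gene's embedding by its expression -> [B, G, D]
gene_repr = x_genes.unsqueeze(-1) * gene_embeddings.unsqueeze(0)
# Step 2: [P, G] @ [B, G, D] broadcasts over the batch -> [B, P, D]
x_pathway = pathway_weights @ gene_repr
print(gene_repr.shape, x_pathway.shape)

# At initialization, each pathway token is the sum of its member genes' scaled embeddings
manual_p0 = gene_repr[:, 0] + gene_repr[:, 1]   # pathway 0 = genes 0 and 1
print(torch.allclose(x_pathway[:, 0], manual_p0))
```

Because the pathway weights are learnable, training can later move them away from this pure sum.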

### Create The Student And Teacher Models

In [27]:
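# Model sizes are kept tiny (1 layer, 8-dim embeddings) so the walkthrough runs quickly on a CPU
num_genes = 4096
num_pathways = 1024
n_layer = 1
heads = 4
embed_dim = 8
mlp_ratio = 4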

**Binary Mask**
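Our student/teacher encoder models use a known gene-to-pathway network to initialize their pathway weights. During tokenization we saved this network as a binary mask of shape `[genes, pathways]`, where a 1 means the gene belongs to that pathway (a gene can belong to several pathways). We'll load it now as part of our initialization.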
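To make the mask layout concrete, here is a sketch of how a `[genes, pathways]` binary mask could be built from a gene-set membership list. The gene names, pathway names, and memberships are hypothetical placeholders, not the real data:

```python
import numpy as np

genes = ['GENE_A', 'GENE_B', 'GENE_C', 'GENE_D']     # hypothetical gene order
pathways = ['PATHWAY_1', 'PATHWAY_2']                # hypothetical pathway order
membership = {                                       # hypothetical gene sets
    'PATHWAY_1': ['GENE_A', 'GENE_C', 'GENE_D'],
    'PATHWAY_2': ['GENE_A', 'GENE_B'],
}

gene_idx = {g: i for i, g in enumerate(genes)}
mask = np.zeros((len(genes), len(pathways)), dtype=np.float32)  # [genes, pathways]
for p, pathway in enumerate(pathways):
    for g in membership[pathway]:
        mask[gene_idx[g], p] = 1.0

print(mask)
print(mask.sum(axis=1))  # pathways per gene: GENE_A belongs to both
```

`CellStateEncoder` transposes this to `[pathways, genes]` (`mask.T`) before wrapping it in `nn.Parameter`.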

In [28]:
mask_matrix = np.load(mask_path)
N_GENES, N_PATHWAYS = mask_matrix.shape
print(f'Mask Loaded: {N_GENES} Genes -> {N_PATHWAYS} Pathways')

Mask Loaded: 4096 Genes -> 1024 Pathways


**Student Model** Now we'll build the encoder config from our hyperparameters and instantiate our student model.

In [29]:
enc_conf = CellStateEncoderConfig(
    num_genes=num_genes,
    num_pathways=num_pathways,
    n_layer=n_layer,
    heads=heads,
    embed_dim=embed_dim,
    mlp_ratio=mlp_ratio,
    mask_matrix=mask_matrix
)
student = CellStateEncoder(enc_conf) 
student

CellStateEncoder(
  (total_count_proj): Linear(in_features=1, out_features=8, bias=True)
  (blocks): ModuleList(
    (0): CellStateBlock(
      (ln_1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
      (attn): BioMultiHeadAttention(
        (q_proj): Linear(in_features=8, out_features=8, bias=True)
        (k_proj): Linear(in_features=8, out_features=8, bias=True)
        (v_proj): Linear(in_features=8, out_features=8, bias=True)
        (c_proj): Linear(in_features=8, out_features=8, bias=True)
      )
      (ln_2): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (c_fc): Linear(in_features=8, out_features=32, bias=True)
        (gelu): GELU(approximate='tanh')
        (c_proj): Linear(in_features=32, out_features=8, bias=True)
      )
    )
  )
  (ln_f): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
)

**Teacher** Now let's make a copy for the teacher. Since the teacher is not updated by gradient descent, we can freeze its parameters so autograd does not track gradients for them.

In [30]:
teacher = copy.deepcopy(student)
teacher

CellStateEncoder(
  (total_count_proj): Linear(in_features=1, out_features=8, bias=True)
  (blocks): ModuleList(
    (0): CellStateBlock(
      (ln_1): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
      (attn): BioMultiHeadAttention(
        (q_proj): Linear(in_features=8, out_features=8, bias=True)
        (k_proj): Linear(in_features=8, out_features=8, bias=True)
        (v_proj): Linear(in_features=8, out_features=8, bias=True)
        (c_proj): Linear(in_features=8, out_features=8, bias=True)
      )
      (ln_2): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (c_fc): Linear(in_features=8, out_features=32, bias=True)
        (gelu): GELU(approximate='tanh')
        (c_proj): Linear(in_features=32, out_features=8, bias=True)
      )
    )
  )
  (ln_f): LayerNorm((8,), eps=1e-05, elementwise_affine=True)
)

In [31]:
for p in teacher.parameters():
    p.requires_grad = False

### Forward Pass: Teacher
Teacher encodes the Target (Case). We'll run it inside `torch.no_grad()` so no gradients are tracked. You may wonder at this point how the teacher is updated: it doesn't receive gradients at all. Instead, its weights are updated as a slow exponential moving average (EMA) of the student's weights.

This teacher creates an embedding of the case cell that the student + predictor try to match.

In [32]:
with torch.no_grad():
    target_latents = teacher(case_x, case_tot)

### Forward Pass: Student
Student encodes the Context (Control). With the student we need to track gradients, since this is the model we're training.
> TO DO: Add masking here if you want extra difficulty, but the task Control->Treated is enough for now.

In [34]:
context_latents = student(control_x, control_tot)
context_latents

tensor([[[-0.3264,  0.8426, -0.6142,  ..., -0.7733,  1.5373, -1.7331],
         [-0.4907,  0.5296, -0.3648,  ..., -0.8971,  1.5316, -1.6865],
         [-0.0982, -0.3098, -1.2579,  ..., -0.0050,  1.7568, -1.2275],
         ...,
         [-1.0808, -0.4855,  0.0171,  ...,  0.3515,  0.9369, -0.3291],
         [-0.6165,  0.7703, -0.0679,  ..., -0.3916,  1.8222, -1.8078],
         [-0.6518, -0.0421, -0.7640,  ..., -0.4835,  1.5784, -1.4388]],

        [[-0.2929,  0.9335, -0.3322,  ..., -0.6493,  1.7300, -1.8462],
         [-0.3199,  0.5634, -0.3114,  ..., -0.8531,  1.7809, -1.7569],
         [-0.1883, -0.3070, -1.4775,  ..., -0.1232,  1.9010, -0.8828],
         ...,
         [-0.8064,  0.1704,  0.8994,  ..., -0.6685,  0.8774, -0.9867],
         [-0.3169,  0.3114,  0.0949,  ..., -0.8969,  1.7992, -1.7457],
         [-0.3654, -0.2884, -0.8097,  ..., -0.6777,  1.8492, -1.2775]]],
       grad_fn=<NativeLayerNormBackward0>)

## Predictor Setup
We need to create a predictor that tries to guess Target (from the teacher) given Context + Action


**Adaptive Layer Normalization (AdaLN)**

In [35]:
class AdaLN(nn.Module):
    '''
    Adaptive Layer Norm for conditioning the predictor on action embeddings.
    The action vector regresses the Scale (gamma) and Shift (beta) of the normalization.
    '''
    def __init__(self, embed_dim, action_embed_dim):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.action_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(action_embed_dim, 2 * embed_dim)
        )
        # Zero-init the last layer so gamma = beta = 0 at the start; the output is then
        # norm(x) * (1 + 0) + 0, i.e. the action starts as a no-op (plain LayerNorm)
        nn.init.zeros_(self.action_mlp[1].weight)
        nn.init.zeros_(self.action_mlp[1].bias)

    def forward(self, x, action_emb):
        # x: [Batch, Seq, Dim]
        # action_emb: [Batch, action_embed_dim]
        
        # Project action to style: [B, 2*D] -> [B, 1, 2*D]
        style = self.action_mlp(action_emb).unsqueeze(1) 
        gamma, beta = style.chunk(2, dim=-1)
        
        # Apply affine transformation based on action
        return self.norm(x) * (1 + gamma) + beta
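The zero initialization means AdaLN starts out as plain LayerNorm, so the untrained predictor initially ignores the action. A minimal check that reproduces the `AdaLN` logic inline with toy sizes:

```python
import torch
import torch.nn as nn

embed_dim, action_dim = 8, 8
norm = nn.LayerNorm(embed_dim, elementwise_affine=False)
action_mlp = nn.Sequential(nn.SiLU(), nn.Linear(action_dim, 2 * embed_dim))
nn.init.zeros_(action_mlp[1].weight)   # same zero-init as AdaLN.__init__
nn.init.zeros_(action_mlp[1].bias)

x = torch.randn(2, 5, embed_dim)          # [B, Seq, D]
action_emb = torch.randn(2, action_dim)   # [B, action_dim]

gamma, beta = action_mlp(action_emb).unsqueeze(1).chunk(2, dim=-1)  # each [B, 1, D]
out = norm(x) * (1 + gamma) + beta
print(torch.allclose(out, norm(x)))  # True: zero-init AdaLN == plain LayerNorm
```

As training updates the last linear layer, `gamma` and `beta` become action-dependent, which is how each perturbation gets to reshape the normalized features.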

**Predictor Block**

In [36]:
class PredictorBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        # 1. Conditioning (AdaLN) replaces standard LayerNorm
        self.ada_ln1 = AdaLN(config.embed_dim, config.action_embed_dim)
        
        # 2. Attention (Using the shared BioMultiHeadAttention)
        self.attn = BioMultiHeadAttention(config)
        
        # 3. Conditioning (AdaLN) for the MLP block
        self.ada_ln2 = AdaLN(config.embed_dim, config.action_embed_dim)
        
        # 4. MLP (Using the shared MLP)
        self.mlp = MLP(config)

    def forward(self, x, action_emb):
        # 1. AdaLN -> Attention  -> Residual
        x_norm = self.ada_ln1(x, action_emb)
        x = x + self.attn(x_norm)
        
        # 2. AdaLN -> MLP -> Residual
        x_norm = self.ada_ln2(x, action_emb)
        x = x + self.mlp(x_norm)
        
        return x

**Main Predictor Model**

In [37]:
@dataclass
class ACPredictorConfig:
    num_pathways: int = 1024
    n_layer: int = 6 
    heads: int = 4
    embed_dim: int = 384
    action_embed_dim: int = 256 
    mlp_ratio: float = 4.0
    max_perturb: int = 2058 ## eventually try to get to a 2**N power

class ACPredictor(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # Action Embedding (Discrete ID -> Vector)
        self.action_embed = nn.Embedding(config.max_perturb, config.action_embed_dim)
        
        # Learnable Queries ("Mask Tokens") for the future state
        # One query vector per pathway position
        self.mask_queries = nn.Parameter(torch.randn(1, config.num_pathways, config.embed_dim) * 0.02)
        
        self.blocks = nn.ModuleList([
            PredictorBlock(config) for _ in range(config.n_layer)
        ])
        
        self.final_norm = AdaLN(config.embed_dim, config.action_embed_dim)
        
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)

    def forward(self, context_latents, action_ids):
        """
        context_latents: [Batch, N, Dim] (From Student Encoder)
        action_ids: [Batch] (Ints)
        """
        B, N, D = context_latents.shape
        
        # 1. Embed Action
        action_emb = self.action_embed(action_ids) # [B, action_embed_dim]
        
        # 2. Construct Input: [Context, Mask_Queries]
        # We concatenate the learned queries to the context. 
        # The predictor will attend to the context to update the queries.
        queries = self.mask_queries.repeat(B, 1, 1) # [B, N, D]
        sequence = torch.cat([context_latents, queries], dim=1) # [B, 2N, D]     
        
        # 3. Pass through AdaLN Blocks
        for block in self.blocks:
            sequence = block(sequence, action_emb)
            
        sequence = self.final_norm(sequence, action_emb)
        
        # 4. Return only the predicted part (The Queries corresponding to N..2N)
        predictions = sequence[:, N:, :] 
        return predictions

### Create the Predictor Model

In [40]:
pred_conf = ACPredictorConfig(
    num_pathways=num_pathways,
    n_layer=n_layer,
    heads=heads,
    embed_dim=embed_dim,
    action_embed_dim=embed_dim,
    mlp_ratio=mlp_ratio
)

In [41]:
predictor = ACPredictor(pred_conf)
predictor

ACPredictor(
  (action_embed): Embedding(2058, 8)
  (blocks): ModuleList(
    (0): PredictorBlock(
      (ada_ln1): AdaLN(
        (norm): LayerNorm((8,), eps=1e-05, elementwise_affine=False)
        (action_mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=8, out_features=16, bias=True)
        )
      )
      (attn): BioMultiHeadAttention(
        (q_proj): Linear(in_features=8, out_features=8, bias=True)
        (k_proj): Linear(in_features=8, out_features=8, bias=True)
        (v_proj): Linear(in_features=8, out_features=8, bias=True)
        (c_proj): Linear(in_features=8, out_features=8, bias=True)
      )
      (ada_ln2): AdaLN(
        (norm): LayerNorm((8,), eps=1e-05, elementwise_affine=False)
        (action_mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=8, out_features=16, bias=True)
        )
      )
      (mlp): MLP(
        (c_fc): Linear(in_features=8, out_features=32, bias=True)
        (gelu): GELU(approximate='tanh')

### Forward Pass: Action Predictor
The action predictor tries to guess the Target encoding given the Context + Action. The context is the output of the Student and the action comes from our tokenized data. The goal is for the output to match the teacher's embedding of the case cell.

In [43]:
action_ids

tensor([1901,   55])

In [45]:
predicted_latents = predictor(context_latents, action_ids)
predicted_latents

tensor([[[ 0.0078,  1.6829, -1.1367,  ..., -1.0413,  0.1867,  1.3981],
         [ 0.1166,  1.6274, -1.1412,  ..., -1.1031,  0.1151,  1.4478],
         [ 0.2033,  1.6339, -1.0711,  ..., -1.2691,  0.2119,  1.3527],
         ...,
         [ 0.1031,  1.3992, -1.1596,  ..., -1.2680,  0.4441,  1.4944],
         [ 0.2364,  1.3741, -1.1901,  ..., -1.3242,  0.2741,  1.5036],
         [ 0.1904,  1.5286, -1.2348,  ..., -1.1161,  0.2004,  1.4497]],

        [[-0.0266,  1.6780, -1.1313,  ..., -1.0113,  0.1162,  1.4369],
         [ 0.0838,  1.6175, -1.1394,  ..., -1.0765,  0.0527,  1.4874],
         [ 0.1740,  1.6284, -1.0700,  ..., -1.2384,  0.1352,  1.4009],
         ...,
         [ 0.0648,  1.3932, -1.1603,  ..., -1.2432,  0.3783,  1.5407],
         [ 0.1993,  1.3728, -1.1908,  ..., -1.2878,  0.1942,  1.5496],
         [ 0.1578,  1.5301, -1.2293,  ..., -1.0842,  0.1256,  1.4852]]],
       grad_fn=<SliceBackward0>)

# Loss Calculation

We'll use L1 Loss (Mean Absolute Error), which measures the average absolute difference between predicted and actual values and is less sensitive to outliers than squared-error losses. We'll call this the latent loss, since it is computed in the model's latent space: it measures how closely the predictor's output matches the teacher's embedding of the real case cell.
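As a quick worked example with made-up numbers: `F.l1_loss` is the mean of the absolute element-wise differences, and inflating one error grows it linearly, whereas a squared-error loss (`F.mse_loss`) grows quadratically:

```python
import torch
import torch.nn.functional as F

pred = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
target = torch.tensor([[1.5, 2.0], [2.0, 5.0]])

manual = (pred - target).abs().mean()              # (0.5 + 0 + 1 + 1) / 4
print(manual.item(), F.l1_loss(pred, target).item())  # 0.625 0.625

# Robustness: make one error 10 and compare how much each loss grows
target_outlier = target.clone()
target_outlier[1, 1] = 14.0
print(F.l1_loss(pred, target_outlier).item())   # (0.5 + 0 + 1 + 10) / 4 = 2.875
print(F.mse_loss(pred, target_outlier).item())  # (0.25 + 0 + 1 + 100) / 4 = 25.3125
```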

In [46]:
loss = F.l1_loss(predicted_latents, target_latents)
loss

tensor(1.1376, grad_fn=<MeanBackward0>)

# Backpropagation

In [47]:
loss.backward()

## Gradient Clipping

In [48]:
torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)
torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)

tensor(1.4932)
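`clip_grad_norm_` rescales all gradients in place so that their combined L2 norm is at most `max_norm`, and returns the total norm measured *before* clipping. The `tensor(1.4932)` above is therefore the student's pre-clipping gradient norm; the predictor's norm was computed too but not displayed, since a notebook cell only shows its last expression. A toy example:

```python
import torch

p = torch.nn.Parameter(torch.zeros(2))
p.grad = torch.tensor([3.0, 4.0])          # gradient with L2 norm 5

total_norm = torch.nn.utils.clip_grad_norm_([p], max_norm=1.0)
print(total_norm.item())                   # 5.0: the norm before clipping
print(p.grad.norm().item())                # ~1.0 after rescaling
```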

## Update the Teacher

Recall that our teacher was created as a copy of our student, so it starts with the same random weights. We also didn't backpropagate through the teacher, since our loss only flows through the action predictor and the student. To update the teacher, we'll use a mechanism called **Exponential Moving Average (EMA)**.

In simple terms, this mechanism forces the Teacher to evolve very slowly, becoming a stable, wiser version of the Student.

Here is the breakdown of the mechanics. Mathematically, it calculates:

$Teacher_{new} = (m \cdot Teacher_{old}) + ((1 - m) \cdot Student_{current})$

With a typical momentum ($m$) of 0.996, the code tells the Teacher: *"Keep 99.6% of your current weights, and mix in just 0.4% of what the Student learned in this last step."*
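A worked example of the formula on a single scalar weight (the values are illustrative): one update keeps 99.6% of the teacher, and if the student held still, each step would close 0.4% of the remaining gap, giving a half-life of about 173 steps:

```python
import math

m = 0.996
teacher, student = 1.0, 0.0          # a single scalar weight, for illustration
teacher = m * teacher + (1 - m) * student
print(round(teacher, 6))             # 0.996: keep 99.6%, mix in 0.4%

# With a fixed student, the remaining gap shrinks by a factor m every step
half_life = math.log(0.5) / math.log(m)
print(round(half_life, 1))           # ~172.9 steps to close half the gap
print(round(1 / (1 - m)))            # 250: effective averaging window in steps
```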

The reason we do this is stability. The Student network updates rapidly and noisily via gradient descent ($Param \leftarrow Param - LearningRate \times Gradient$). If we simply copied the Student to the Teacher (`Teacher = Student`), the target we are trying to predict would move erratically at every step, making it very hard to learn (like trying to hit a bullseye that is vibrating wildly).

By updating the Teacher this way, its weights become an exponentially weighted **average** of the Student's recent weights; with $m = 0.996$ the effective window is roughly $1/(1-m) = 250$ steps. This smooths out the noise and provides a steady, high-quality target for the Student to predict, which helps guard against "Model Collapse" (where the model cheats by mapping every cell to the same embedding, e.g. all zeros).

> TO DO: We could also replace the teacher with a pretrained model and keep using EMA as its update mechanism.

In [49]:
m = 0.996  # EMA momentum
with torch.no_grad():
    for param_s, param_t in zip(student.parameters(), teacher.parameters()):
        # Teacher_new = m * Teacher_old + (1 - m) * Student_current
        param_t.mul_(m).add_(param_s, alpha=1 - m)
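The walkthrough computes and clips gradients but never applies them: there is no optimizer step, so the student and predictor weights don't actually change before the EMA update. A minimal sketch of one complete training step, assuming AdamW with `lr=1e-3` (the notebook does not specify an optimizer) and tiny `nn.Linear` stand-ins for the three models, named `toy_*` so they don't overwrite the notebook's models:

```python
import copy
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
toy_student = nn.Linear(4, 4)      # stand-in for CellStateEncoder
toy_predictor = nn.Linear(4, 4)    # stand-in for ACPredictor (action input omitted)
toy_teacher = copy.deepcopy(toy_student)
for p in toy_teacher.parameters():
    p.requires_grad = False

# Assumption: AdamW over student + predictor parameters
trainable = list(toy_student.parameters()) + list(toy_predictor.parameters())
optimizer = torch.optim.AdamW(trainable, lr=1e-3)
toy_control, toy_case = torch.randn(8, 4), torch.randn(8, 4)  # toy batch
m = 0.996

def train_step():
    with torch.no_grad():
        target = toy_teacher(toy_case)               # teacher encodes the case
    pred = toy_predictor(toy_student(toy_control))   # student + predictor
    loss = F.l1_loss(pred, target)
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(trainable, 1.0)
    optimizer.step()                                 # gradient descent: student + predictor
    with torch.no_grad():                            # EMA pulls the teacher toward the student
        for p_s, p_t in zip(toy_student.parameters(), toy_teacher.parameters()):
            p_t.mul_(m).add_(p_s, alpha=1 - m)
    return loss.item()

losses = [train_step() for _ in range(3)]
print(losses)
```

The order matters: the optimizer step comes first, so the EMA mixes in the student's freshly updated weights.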

## Inference
Now that we've walked through the forward pass, let's see how to run an inference test. We'll reload the perturbation map and gene names so we can convert IDs back to readable names, and sample a pair from our validation set.

In [51]:
val_dir = tok_dir / 'val'
gene_names_path = data_dir / 'gene_names.json'
gene_names_path, metadata_path, val_dir

(PosixPath('/Users/djemec/data/jepa/gene_names.json'),
 PosixPath('/Users/djemec/data/jepa/perturbation_map.json'),
 PosixPath('/Users/djemec/data/jepa/tokenized/val'))

In [68]:
# Load Perturbation Map (Name -> ID) and invert it to ID -> Name
with open(metadata_path, 'r') as f:
    pert_map = json.load(f)
id_to_pert = {v: k for k, v in pert_map.items()}

# Load Gene Names
with open(gene_names_path, 'r') as f:
    gene_names = json.load(f)

**Sample Data**
We'll build a modification of our data loader which pulls a random file and a random perturbed cell example from it.  

In [69]:
def get_random_test_pair(shard_dir):
    '''Grab a single real (control, case, action) example from a random shard'''
    files = sorted(shard_dir.glob('*.npz'))
        
    file_path = files[np.random.randint(len(files))]
    
    with np.load(file_path) as data:
        idx = np.random.randint(data['action_ids'].shape[0])
        
        # Extract the five aligned items for this row
        c_raw = data['control'][idx]         # [num_genes]
        ct_raw = data['control_total'][idx]  # Scalar
        case_raw = data['case'][idx]         # [num_genes]
        caset_raw = data['case_total'][idx]  # Scalar
        act_id = data['action_ids'][idx]     # Scalar
        
    # Convert to Tensor & Add Batch Dim [1, ...]
    xc = torch.tensor(c_raw).float().unsqueeze(0).to(DEVICE)
    xct = torch.tensor(ct_raw).float().unsqueeze(0).to(DEVICE)
    xt = torch.tensor(case_raw).float().unsqueeze(0).to(DEVICE)
    xtt = torch.tensor(caset_raw).float().unsqueeze(0).to(DEVICE)
    aid = torch.tensor([act_id]).long().to(DEVICE)
    
    return xc, xct, xt, xtt, aid

In [77]:
control_x, control_tot, case_x, case_tot, action_id = get_random_test_pair(val_dir)
pert_name = id_to_pert[action_id.item()]
pert_name

'ACD'

In [79]:
control_x.shape

torch.Size([1, 4096])

**Get Baseline**

To start, we'll use our `teacher` model to create baseline embeddings for both our control and case examples. These are our "expected" values.

In [71]:
with torch.no_grad():
    z_control = teacher(control_x, control_tot)       # Where the cell started
    z_case = teacher(case_x, case_tot)     # Where the cell actually went

**Get Predicted**

Now we need to run our predictor. This requires first running our control (baseline) cell through our `student`, then passing the output to our `predictor` along with the perturbation to see how well the predictor does.

In [82]:
# Run the "physics engine" (predictor)
# Student encodes context -> Predictor adds Action -> Output
with torch.no_grad():
    z_context = student(control_x, control_tot)
    z_predicted = predictor(z_context, action_id) # Where the model thinks it went

### Metrics 
Now that we have our predicted and expected latents, we can evaluate our prediction. We'll measure how far the predicted latents moved away from the control embedding, project that movement back onto genes using the gene-by-pathway matrix, and derive the genes most perturbed by the action.

*Note that currently our model is untrained so this will be BAD*


First we'll calculate how much the latent space changed in response to the perturbation. This means subtracting the teacher's control embedding from the predictor output. The predictor is trained to produce embeddings in the teacher's latent space, so the two live in the same space and this delta should signal how much of a change the action caused.

In [100]:
# Calculate Latent Movement
delta_latent = (z_predicted - z_control).squeeze(0)  # [num_pathways, embed_dim] = [1024, 8] here
delta_latent

tensor([[ 3.5257e-01,  8.6591e-01, -8.3257e-01,  ..., -7.6716e-02,
         -1.4727e+00,  3.1437e+00],
        [ 3.0687e-01,  1.3582e+00, -8.7432e-01,  ...,  1.2150e-01,
         -1.6244e+00,  3.0174e+00],
        [ 4.4986e-01,  2.1124e+00,  9.8932e-04,  ..., -8.9967e-01,
         -1.9176e+00,  2.4622e+00],
        ...,
        [ 1.4239e+00,  1.9808e+00, -1.7845e+00,  ..., -1.0847e+00,
         -1.1714e+00,  1.8653e+00],
        [ 7.5283e-01,  7.1341e-01, -1.0429e+00,  ..., -6.0725e-01,
         -1.6332e+00,  3.2135e+00],
        [ 5.4587e-01,  1.6708e+00, -7.4546e-01,  ..., -5.5125e-02,
         -1.6411e+00,  2.8595e+00]])

**Pathway Energies**

Next we want to score every gene based on how much its parent pathways shifted. Our latent delta has shape `[pathway, embedding]`. We'll reduce each pathway's embedding delta to a single magnitude (its L2 norm), then use the absolute values of the learned `[gene, pathway]` weights to project the pathway impacts back out to the genes. This gives $Score_g = \sum_p |W_{gp}| \cdot \lVert \Delta z_p \rVert$, where $W$ starts out as the binary mask. With this, if a gene is in 5 pathways that all moved violently, that gene is highly implicated.
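A toy example of the scoring rule, with a made-up 3-gene, 2-pathway routing matrix (non-binary, as learned weights can be):

```python
import torch

# Toy routing matrix |W|: 3 genes x 2 pathways (gene 1 sits in both pathways)
routing = torch.tensor([[1.0, 0.0],
                        [1.0, 1.0],
                        [0.0, 2.0]])
# Toy latent deltas: 2 pathways x 2 embedding dims
delta = torch.tensor([[3.0, 4.0],    # pathway 0 moved by ||(3, 4)|| = 5
                      [0.0, 1.0]])   # pathway 1 moved by 1

magnitudes = delta.norm(dim=1)              # one magnitude per pathway: [5., 1.]
scores = torch.mv(routing, magnitudes)      # one score per gene
print(scores)                               # tensor([5., 6., 2.])
print(torch.topk(scores, 2).indices)        # tensor([1, 0])
```

Gene 1 ranks highest because it collects impact from both pathways.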

*Note that this is a very rudimentary way of validating; full validation would require a trained head.*

First we'll reduce our latents to a single value per pathway by taking the L2 norm of each pathway's delta.

In [102]:
pathway_magnitudes = delta_latent.norm(dim=1) 
pathway_magnitudes

tensor([3.9492, 4.1390, 4.2119,  ..., 4.1820, 4.0623, 4.2419])

**Network X Gene**

Now we'll extract our pathway weights, which are the learned gene-to-pathway mapping. We take the absolute value since, for our testing, a strong negative weight implies just as much biological "impact" as a strong positive one.

In [103]:
# 1. Access the Learnable Weights
weights = student.pathway_weights.detach()

# 2. Transpose back to [Genes, Pathways] for the projection
routing_matrix = weights.T.abs() # Shape: [num_genes, num_pathways] = [4096, 1024]
routing_matrix

tensor([[0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        ...,
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.],
        [0., 0., 0.,  ..., 0., 0., 0.]])

**Matrix-vector product**
Now we're ready to perform the matrix-vector product that expands our pathway impacts back to their member genes. This outputs a single score for each gene, summed over the pathways it belongs to.

In [104]:
gene_impact_scores = torch.mv(routing_matrix, pathway_magnitudes)
gene_impact_scores

tensor([  0.0000,   7.0570, 126.2627,  ...,   0.0000,   0.0000,   0.0000])

### Validate top K genes impacted
Now that we have our gene scores, we need a way to compare them to what actually happened. We'll take the top K scores and convert their indices to gene names, marking those genes as impacted. We'll then do the same with the real expression changes and see how much the two lists overlap.

Let's start by extracting the top gene impact scores and indices.

In [106]:
top_k = 100
pred_values, pred_indices = torch.topk(gene_impact_scores, top_k)

Now we'll convert those indices into gene names.

In [108]:
print('Model Predicted')
for val, idx in zip(pred_values, pred_indices):
    g_name = gene_names[idx.item()]
    print(f'{g_name:<10} (Score: {val.item():.4f})')

Model Predicted
CASP3      (Score: 1732.7908)
VEGFA      (Score: 1525.5117)
AURKB      (Score: 1229.8809)
SLC7A1     (Score: 1091.5343)
MYC        (Score: 1081.3484)
ICAM1      (Score: 969.7882)
SOD1       (Score: 918.0603)
LDHA       (Score: 914.0148)
NR3C1      (Score: 906.3495)
CYCS       (Score: 887.7327)
PRDX2      (Score: 885.9540)
JUN        (Score: 880.2817)
MCL1       (Score: 874.4275)
CDT1       (Score: 858.1977)
BCL2L1     (Score: 840.5076)
NFKB1      (Score: 834.9631)
NFE2L2     (Score: 817.7412)
STAT3      (Score: 815.0062)
PCNA       (Score: 799.4529)
CASP8      (Score: 792.2278)
NFKBIA     (Score: 763.6768)
MAPK1      (Score: 737.3122)
DDIT3      (Score: 728.9926)
POLD4      (Score: 722.5864)
DECR1      (Score: 715.1577)
TNFRSF10B  (Score: 684.7576)
EGR1       (Score: 674.3992)
PMAIP1     (Score: 665.2002)
HSPA1A     (Score: 664.4025)
GADD45A    (Score: 662.5065)
TAOK1      (Score: 662.2225)
SAT1       (Score: 662.1367)
MTR        (Score: 652.7802)
GDF15      (Score: 624

**Ground Truth Comparison**
Now we need to compare this list to our ground truth. For the ground truth we can directly compare our raw inputs `case_x` and `control_x` and pick the K genes with the largest absolute change.

Let's start by calculating the difference in expression.

In [109]:
real_gene_delta = (case_x - control_x).abs().squeeze(0) 
real_gene_delta

tensor([0.0000, 0.5572, 0.0000,  ..., 0.5572, 0.0000, 0.0000])

Now we'll again extract the top K genes. 

In [111]:
real_values, real_indices = torch.topk(real_gene_delta, top_k)

**Comparison of predicted and real**
And now we'll print out the genes and, if they match, add a checkmark. 

In [112]:
print('Ground Truth')
for val, idx in zip(real_values, real_indices):
    g_name = gene_names[idx.item()]
    
    # Check if our model also listed this gene in its top K
    match_icon = '✅' if idx in pred_indices else '❌'

    print(f'{match_icon} {g_name:<10} (Delta: {val.item():.4f})')

Ground Truth
❌ PFDN2      (Delta: 2.1350)
❌ CMTM6      (Delta: 2.1350)
❌ HBG1       (Delta: 2.0921)
❌ CKB        (Delta: 2.0921)
❌ KIF2C      (Delta: 2.0427)
❌ UBC        (Delta: 2.0188)
❌ HLTF       (Delta: 1.9583)
❌ DNAJA1     (Delta: 1.9583)
❌ ARF1       (Delta: 1.9410)
❌ KDM5B      (Delta: 1.9410)
❌ RAP1B      (Delta: 1.8593)
❌ BTG2       (Delta: 1.8278)
❌ ARL6IP1    (Delta: 1.8278)
❌ TPP2       (Delta: 1.8038)
❌ AC079466.1 (Delta: 1.8038)
❌ SPDL1      (Delta: 1.8038)
❌ CBLL1      (Delta: 1.8038)
❌ NRDC       (Delta: 1.8038)
❌ POP7       (Delta: 1.8038)
❌ HDLBP      (Delta: 1.8038)
❌ ACTG1      (Delta: 1.7579)
❌ HDGF       (Delta: 1.7001)
❌ ETF1       (Delta: 1.7001)
❌ TRIB2      (Delta: 1.7001)
❌ PTPN1      (Delta: 1.7001)
❌ SDF4       (Delta: 1.7001)
❌ PSAP       (Delta: 1.7001)
❌ NCL        (Delta: 1.6382)
❌ NDUFB6     (Delta: 1.6210)
❌ ECI1       (Delta: 1.6210)
❌ LMO2       (Delta: 1.6210)
❌ NDUFC1     (Delta: 1.6210)
❌ LRPPRC     (Delta: 1.6210)
❌ RHOBTB3    (Delta: 1.6210)
❌

**Precision Calculation**
Now let's calculate the precision, which is simply the number of correctly predicted genes divided by K.
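Because both lists have exactly K entries, precision@K here also equals recall@K. It helps to compare the result against chance: if K of N genes were picked uniformly at random, the expected overlap with the true top K would be $K \cdot K / N$. A sketch with a hypothetical `precision_at_k` helper, assuming selection over all 4096 genes:

```python
def precision_at_k(pred_idx, true_idx, k):
    """Hypothetical helper: share of the top-k predictions found in the true top-k."""
    return len(set(pred_idx[:k]) & set(true_idx[:k])) / k

print(precision_at_k([3, 7, 1, 9], [7, 2, 9, 5], k=4))  # 2 of 4 shared -> 0.5

# Chance baseline: picking k of n genes uniformly at random, the expected overlap
# with the true top k is k * k / n (the mean of a hypergeometric draw)
n_genes, k = 4096, 100
expected_overlap = k * k / n_genes
print(round(expected_overlap, 2), round(expected_overlap / k, 4))  # 2.44 0.0244
```

Under that assumption the chance level is about 0.024, so a precision of 0.02 from an untrained model is in line with random guessing. Many genes belong to no pathway and score exactly zero, so the true chance baseline for this scoring scheme may differ somewhat.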

In [113]:
# --- Precision@K Metric ---
# How many of the top K real genes were in the top K predicted genes?
# (Set intersection)
set_pred = set(pred_indices.tolist())
set_real = set(real_indices.tolist())
overlap = len(set_pred.intersection(set_real))

print(f'Precision @ {top_k} | {overlap/top_k:.4f}')

Precision @ 100 | 0.0200


## Conclusion
As you can see, the untrained model's predictions barely overlap with the ground truth. Also, we're juggling 3 models and our predictor relies on a chain of two models, which means we'll need a lot of training before the model starts learning the cell physics.

Also, looking at how this works, we can see that our current training data is a limitation, since we're learning both the genes and the networks from it.