# Bio-JEPA AC Explainer


In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import copy
import math
import numpy as np
import glob
import os
import json
from pathlib import Path
import matplotlib.pyplot as plt
import random
from dataclasses import dataclass

In [2]:
torch.manual_seed(1337)
random.seed(1337)

In [3]:
data_dir = Path('/Users/djemec/data/jepa')
shard_dir = data_dir / 'tokenized'
metadata_path = data_dir / 'perturbation_map.json'
checkpoint_dir = data_dir / 'checkpoint'

## Text Prep/Tokenization

We previously loaded and saved our data into a tokenized files.  For a full explainer, look at the data_prep_explainer.ipynb notebook.   In this step we'll pull the data from the files to use for training.  First we'll start by getting all the files in the tokenized directory. 

In [4]:
data_files = sorted(shard_dir.glob('*.npz'))
print(f'Found {len(data_files)} shards.')

Found 30 shards.


Now that we have the files, we need just a single file. Because of how we tokanized our files, we do not have to load them in any particular order so we'll select a random file from the list.  In training, we'd want to remove this from the file list to ensure we go through the full epoch before restarting.   

In [5]:
random_file_idx = np.random.randint(len(data_files))
file_path = data_files[random_file_idx]
random_file_idx, file_path

(15, PosixPath('/Users/djemec/data/jepa/tokenized/shard_0015.npz'))

# Modeling

## Data Loading


To start, we need to get enough data to run the forward and backward passes.  Since our total dataset is likely too big to be held in memory all at once in real practice, we will read just enough file information into memory so that we can run the passes, leaving memory and compute to be used on the passes instead of static data holding. 

Recall that we saved the data in the following structure:

```
{
    'control':[array of network expression information],
    'case':[array of network expression information],
    'action_ids':[array of ids of the perturbation]
}
```

Because of this, taking the same index out of each array gives us our action_id, pertrubed data, and control data. For training we'll do more than 1 cell at a time.  We'll use a variable `BATCH_SIZE` to control how much data to train on at a time.  In this case we'll start with 2. During training we'll want a increase our batch size.  


In [6]:
BATCH_SIZE = 2
with np.load(file_path) as data:
    # Keys: 'control', 'treated', 'action_ids'
    n_rows = data['action_ids'].shape[0]
    row_idx = np.random.randint(n_rows, size=BATCH_SIZE)
    
    control_raw = data['control'][row_idx]
    case_raw = data['case'][row_idx]
    act_id = data['action_ids'][row_idx]

x_control = torch.tensor(control_raw)
x_case = torch.tensor(case_raw)
action_id = torch.tensor(act_id, dtype=torch.long)

row_idx, action_id, x_control, x_case

(array([1391, 2491]),
 tensor([1045,  802]),
 tensor([[0.7370, 0.4793, 0.5609,  ..., 0.4051, 0.2653, 0.4966],
         [0.8382, 0.6693, 0.5777,  ..., 0.4235, 0.3463, 0.5699]]),
 tensor([[0.6636, 0.5720, 0.5646,  ..., 0.2728, 0.2980, 0.5147],
         [0.7203, 0.6128, 0.6233,  ..., 0.3658, 0.2489, 0.5796]]))

## Forward pass

![network_impact_math](../resources/biojepa_v1.png)

Our current architecture has 3 models:  

1. **The Student** - starts from the control cell and learns how to create a latent embedding of the cell state. 
2. **The Predictor** - learns how to adjust the student's embedded cell state based on different perturbations to hallucinate the future diseased state in the latent space shared with the teacher. 
3. **The Teacher** - starts from the actual perturbed cell and creates a latent space embedding for the pertured state to grade the predictors output against.


We'll start by building the Pre-Norm Transformer Encoder block with Rotary Positional Embeddings (RoPE) that are the basis for both the Teacher and Student models. 

### Teacher/Student Setup
**RoPE**

In [7]:
class RotaryEmbedding(nn.Module):
    def __init__(self, dim, max_seq_len=2048):
        super().__init__()
        # Ensure dim is the head dimension, not the full embedding dimension
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        t = torch.arange(max_seq_len).type_as(inv_freq)
        freqs = torch.outer(t, inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer('emb', emb)

    def forward(self, x):
        # x: [Batch, Seq, Dim]
        # We only care about the sequence length here
        n = x.shape[1]
        # Returns [Seq, Head_Dim]
        return self.emb[:n, :].cos(), self.emb[:n, :].sin()

def apply_rotary_pos_emb(q, k, cos, sin):
    # q, k: [Batch, Heads, Seq, Head_Dim]
    # cos, sin: [Seq, Head_Dim] -> reshape to [1, 1, Seq, Head_Dim]
    
    cos = cos.unsqueeze(0).unsqueeze(0)
    sin = sin.unsqueeze(0).unsqueeze(0)

    # Standard RoPE rotation logic
    # split last dim into half
    q_d = q.shape[-1] // 2
    k_d = k.shape[-1] // 2
    
    q1, q2 = q[..., :q_d], q[..., q_d:]
    k1, k2 = k[..., :k_d], k[..., k_d:]
    
    q_rotated = torch.cat((-q2, q1), dim=-1)
    k_rotated = torch.cat((-k2, k1), dim=-1)
    
    q_out = (q * cos) + (q_rotated * sin)
    k_out = (k * cos) + (k_rotated * sin)
    
    return q_out, k_out

**Multi-Headed Attention**

In [8]:
class BioMultiHeadAttention(nn.Module):
    # mirrors nn.MultiheadAttention(dim, heads, batch_first=True) 
    def __init__(self, config):
        super().__init__()
        assert config.embed_dim % config.heads == 0
        
        self.head_dim = config.embed_dim // config.heads
        self.heads = config.heads
        self.embed_dim = config.embed_dim
        
        # Projections
        self.q_proj = nn.Linear(config.embed_dim, config.embed_dim)
        self.k_proj = nn.Linear(config.embed_dim, config.embed_dim)
        self.v_proj = nn.Linear(config.embed_dim, config.embed_dim)
        
        self.c_proj = nn.Linear(config.embed_dim, config.embed_dim)

    def forward(self, x, cos, sin):
        B, T, C = x.size() # Batch, Seq, Embed Dim
        
        # 1. Project
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # 2. Reshape for multi-head attention
        # (B, T, nh, hs) -> (B, nh, T, hs)
        q = q.view(B, T, self.heads, self.head_dim).transpose(1, 2)
        k = k.view(B, T, self.heads, self.head_dim).transpose(1, 2)
        v = v.view(B, T, self.heads, self.head_dim).transpose(1, 2)

        # 3. Apply RoPE to Q and K (Rotary is applied per head)
        q, k = apply_rotary_pos_emb(q, k, cos, sin)

        # 4. Attention
        # is_causal=False because this is a bidirectional encoder
        y = F.scaled_dot_product_attention(q, k, v, is_causal=False)
        
        # 5. Reassemble
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        y = self.c_proj(y)
        
        return y

**MLP**

In [9]:
class MLP(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.c_fc = nn.Linear(config.embed_dim, int(config.mlp_ratio * config.embed_dim))
        self.gelu = nn.GELU(approximate='tanh')
        self.c_proj = nn.Linear(int(config.mlp_ratio * config.embed_dim), config.embed_dim)

    def forward(self, x):
        x = self.c_fc(x)
        x = self.gelu(x)
        x = self.c_proj(x)
        return x

**Hidden Transformer Block**

In [10]:
class PathwayBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.embed_dim)
        self.attn = BioMultiHeadAttention(config)
        self.ln_2 = nn.LayerNorm(config.embed_dim)
        self.mlp = MLP(config)

    def forward(self, x, cos, sin):
        # 1. Attention with RoPE
        x_norm = self.ln_1(x)
        # Pass cos/sin into attn to be applied to Q/K
        attn_out = self.attn(x_norm, cos, sin)
        x = x + attn_out

        # 2. MLP
        x = x + self.mlp(self.ln_2(x))
        return x

**Main Pathway Model**

In [11]:
@dataclass
class PathwayEncoderConfig:
    num_pathways: int = 1024 
    n_layer: int = 24 
    heads: int = 12
    embed_dim: int = 768
    mlp_ratio: float = 4.0 # Changed to float for precision

class PathwayEncoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config

        self.transformer = nn.ModuleDict(dict(
            # input projects to embedding
            input_proj = nn.Linear(1, config.embed_dim),
            
            # RoPE needs the HEAD dimension, not the full embedding dimension
            rope = RotaryEmbedding(config.embed_dim // config.heads),
            
            # transformer block
            blocks = nn.ModuleList([PathwayBlock(config) for _ in range(config.n_layer)]),
            
            # final layer norm
            ln_f = nn.LayerNorm(config.embed_dim) 
        ))

        self.apply(self._init_weights)

    def _init_weights(self, module):
        std = 0.02
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
        
    def forward(self, x):
        # x: [Batch, Num_Pathways]
        x = x.unsqueeze(-1) # [B, N, 1]
        x = self.transformer.input_proj(x) # [B, N, Dim]
        
        # Generate RoPE cache
        # cos, sin are [Seq, Head_Dim]
        cos, sin = self.transformer.rope(x)
        cos, sin = cos.to(x.device), sin.to(x.device)
        
        for block in self.transformer.blocks:
            x = block(x, cos, sin)
            
        x = self.transformer.ln_f(x)

        return x

### Create The Student And Teacher Models

In [12]:
n_embd = 4
n_heads = 2
n_pathways = 1024
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [13]:
student = PathwayEncoder(PathwayEncoderConfig(n_layer= 1, heads=n_heads, embed_dim= n_embd))
student

PathwayEncoder(
  (transformer): ModuleDict(
    (input_proj): Linear(in_features=1, out_features=4, bias=True)
    (rope): RotaryEmbedding()
    (blocks): ModuleList(
      (0): PathwayBlock(
        (ln_1): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
        (attn): BioMultiHeadAttention(
          (q_proj): Linear(in_features=4, out_features=4, bias=True)
          (k_proj): Linear(in_features=4, out_features=4, bias=True)
          (v_proj): Linear(in_features=4, out_features=4, bias=True)
          (c_proj): Linear(in_features=4, out_features=4, bias=True)
        )
        (ln_2): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=4, out_features=16, bias=True)
          (gelu): GELU(approximate='tanh')
          (c_proj): Linear(in_features=16, out_features=4, bias=True)
        )
      )
    )
    (ln_f): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
  )
)

**Teacher** Now let's make a copy for the teacher. Since the teacher does not require updates, we can update the model to not require gradient updates. 

In [14]:
teacher = copy.deepcopy(student)
teacher

PathwayEncoder(
  (transformer): ModuleDict(
    (input_proj): Linear(in_features=1, out_features=4, bias=True)
    (rope): RotaryEmbedding()
    (blocks): ModuleList(
      (0): PathwayBlock(
        (ln_1): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
        (attn): BioMultiHeadAttention(
          (q_proj): Linear(in_features=4, out_features=4, bias=True)
          (k_proj): Linear(in_features=4, out_features=4, bias=True)
          (v_proj): Linear(in_features=4, out_features=4, bias=True)
          (c_proj): Linear(in_features=4, out_features=4, bias=True)
        )
        (ln_2): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
        (mlp): MLP(
          (c_fc): Linear(in_features=4, out_features=16, bias=True)
          (gelu): GELU(approximate='tanh')
          (c_proj): Linear(in_features=16, out_features=4, bias=True)
        )
      )
    )
    (ln_f): LayerNorm((4,), eps=1e-05, elementwise_affine=True)
  )
)

In [15]:
for p in teacher.parameters():
    p.requires_grad = False

### Forward Pass: Teacher
Teacher encodes the Target (Case). We'll make sure to run it without gradient accumulation.  This teacher creates an embedding of the case cell

In [16]:
with torch.no_grad():
    target_latents = teacher(x_case)
target_latents

tensor([[[ 0.7027,  0.8507, -1.1759, -0.3774],
         [ 0.6592,  0.8491, -1.0908, -0.4175],
         [ 0.6550,  0.8488, -1.0831, -0.4207],
         ...,
         [ 0.3765,  0.7773, -0.6491, -0.5047],
         [ 0.4124,  0.7902, -0.6988, -0.5038],
         [ 0.6244,  0.8451, -1.0278, -0.4418]],

        [[ 0.7226,  0.8514, -1.2200, -0.3541],
         [ 0.6788,  0.8525, -1.1302, -0.4011],
         [ 0.6838,  0.8526, -1.1400, -0.3964],
         ...,
         [ 0.4938,  0.8193, -0.8186, -0.4945],
         [ 0.3378,  0.7663, -0.5985, -0.5057],
         [ 0.6618,  0.8516, -1.0978, -0.4156]]])

### Forward Pass: Student
Student encodes the Context (Control), With the student we need to accumulate gradients since we need to be able to train the model. 
> TO DO: Add masking here if you want extra difficulty, but the task Control->Treated is enough for now.

In [17]:
context_latents = student(x_control)
context_latents

tensor([[[ 0.7284,  0.8506, -1.2322, -0.3467],
         [ 0.5979,  0.8431, -0.9838, -0.4573],
         [ 0.6517,  0.8504, -1.0786, -0.4235],
         ...,
         [ 0.5346,  0.8296, -0.8807, -0.4835],
         [ 0.3636,  0.7756, -0.6330, -0.5063],
         [ 0.6106,  0.8452, -1.0054, -0.4504]],

        [[ 0.7536,  0.8499, -1.2958, -0.3077],
         [ 0.7005,  0.8572, -1.1785, -0.3792],
         [ 0.6573,  0.8565, -1.0942, -0.4197],
         ...,
         [ 0.5476,  0.8399, -0.9058, -0.4816],
         [ 0.4675,  0.8191, -0.7837, -0.5030],
         [ 0.6530,  0.8562, -1.0861, -0.4231]]],
       grad_fn=<NativeLayerNormBackward0>)

## Predictor Setup
We need to create a predictor that tries to guess Target (from the teacher) given Context + Action


**Adaptive Layer Normalization AdaLN**

In [18]:
class AdaLN(nn.Module):
    '''
    Adaptive Layer Norm for conditioning the predictor on action embeddings.
    The action vector regresses the Scale (gamma) and Shift (beta) of the normalization.
    '''
    def __init__(self, embed_dim, action_embed_dim):
        super().__init__()
        self.norm = nn.LayerNorm(embed_dim, elementwise_affine=False)
        self.action_mlp = nn.Sequential(
            nn.SiLU(),
            nn.Linear(action_embed_dim, 2 * embed_dim)
        )
        # Initialize to identity (gamma=0, beta=0 originally, effectively gamma=1 after logic)
        # Zero-init the last layer so the action starts as a "no-op" (identity)
        nn.init.zeros_(self.action_mlp[1].weight)
        nn.init.zeros_(self.action_mlp[1].bias)

    def forward(self, x, action_emb):
        # x: [Batch, Seq, Dim]
        # action_emb: [Batch, action_embed_dim]
        
        # Project action to style: [B, 2*D] -> [B, 1, 2*D]
        style = self.action_mlp(action_emb).unsqueeze(1) 
        gamma, beta = style.chunk(2, dim=-1)
        
        # Apply affine transformation based on action
        return self.norm(x) * (1 + gamma) + beta

**Predictor Block**

In [19]:
class PredictorBlock(nn.Module):
    def __init__(self, config):
        super().__init__()
        # 1. Conditioning (AdaLN) replaces standard LayerNorm
        self.ada_ln1 = AdaLN(config.embed_dim, config.action_embed_dim)
        
        # 2. Attention (Using the shared BioMultiHeadAttention)
        self.attn = BioMultiHeadAttention(config)
        
        # 3. Conditioning (AdaLN) for the MLP block
        self.ada_ln2 = AdaLN(config.embed_dim, config.action_embed_dim)
        
        # 4. MLP (Using the shared MLP)
        self.mlp = MLP(config)

    def forward(self, x, action_emb, cos, sin):
        # 1. AdaLN -> Attention (with internal RoPE) -> Residual
        x_norm = self.ada_ln1(x, action_emb)
        
        # Note: BioMultiHeadAttention handles q/k/v projection and apply_rotary_pos_emb internally
        attn_out = self.attn(x_norm, cos, sin)
        x = x + attn_out
        
        # 2. AdaLN -> MLP -> Residual
        x_norm = self.ada_ln2(x, action_emb)
        x = x + self.mlp(x_norm)
        
        return x

**Main Predictor Model**

In [20]:
@dataclass
class ACPredictorConfig:
    num_pathways: int = 1024
    n_layer: int = 6 
    heads: int = 4
    embed_dim: int = 384
    action_embed_dim: int=256 
    mlp_ratio: float = 4.0
    max_perturb: int= 2058 ## eventually try to get to a 2**N power

class ACPredictor(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        
        # Action Embedding (Discrete ID -> Vector)
        self.action_embed = nn.Embedding(config.max_perturb, config.action_embed_dim)
        
        # Learnable Queries ("Mask Tokens") for the future state
        # One query vector per pathway position
        self.mask_queries = nn.Parameter(torch.randn(1, config.num_pathways, config.embed_dim) * 0.02)
        
        # RoPE: initialized with HEAD dimension (dim // heads)
        head_dim = config.embed_dim // config.heads
        self.rope = RotaryEmbedding(head_dim)
        
        self.blocks = nn.ModuleList([
            PredictorBlock(config) for _ in range(config.n_layer)
        ])
        
        self.final_norm = AdaLN(config.embed_dim, config.action_embed_dim)
        
        self.apply(self._init_weights)

    def _init_weights(self, module):
        std = 0.02
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            torch.nn.init.normal_(module.weight, mean=0.0, std=std)

    def forward(self, context_latents, action_ids):
        """
        context_latents: [Batch, N, Dim] (From Student Encoder)
        action_ids: [Batch] (Ints)
        """
        B, N, D = context_latents.shape
        
        # 1. Embed Action
        action_emb = self.action_embed(action_ids) # [B, action_embed_dim]
        
        # 2. Construct Input: [Context, Mask_Queries]
        # We concatenate the learned queries to the context. 
        # The predictor will attend to the context to update the queries.
        queries = self.mask_queries.repeat(B, 1, 1) # [B, N, D]
        sequence = torch.cat([context_latents, queries], dim=1) # [B, 2N, D]
        
        # 3. Generate RoPE for the full sequence (2N)
        # cos, sin are [2N, Head_Dim]
        cos, sin = self.rope(sequence)
        cos, sin = cos.to(sequence.device), sin.to(sequence.device)
        
        # 4. Pass through AdaLN Blocks
        for block in self.blocks:
            sequence = block(sequence, action_emb, cos, sin)
            
        sequence = self.final_norm(sequence, action_emb)
        
        # 5. Return only the predicted part (The Queries corresponding to N..2N)
        predictions = sequence[:, N:, :] 
        return predictions

### Create the Predictor Model

In [21]:
predictor = ACPredictor(ACPredictorConfig(n_layer=1, heads=n_heads, embed_dim=n_embd, action_embed_dim=n_embd))
predictor

ACPredictor(
  (action_embed): Embedding(2058, 4)
  (rope): RotaryEmbedding()
  (blocks): ModuleList(
    (0): PredictorBlock(
      (ada_ln1): AdaLN(
        (norm): LayerNorm((4,), eps=1e-05, elementwise_affine=False)
        (action_mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=4, out_features=8, bias=True)
        )
      )
      (attn): BioMultiHeadAttention(
        (q_proj): Linear(in_features=4, out_features=4, bias=True)
        (k_proj): Linear(in_features=4, out_features=4, bias=True)
        (v_proj): Linear(in_features=4, out_features=4, bias=True)
        (c_proj): Linear(in_features=4, out_features=4, bias=True)
      )
      (ada_ln2): AdaLN(
        (norm): LayerNorm((4,), eps=1e-05, elementwise_affine=False)
        (action_mlp): Sequential(
          (0): SiLU()
          (1): Linear(in_features=4, out_features=8, bias=True)
        )
      )
      (mlp): MLP(
        (c_fc): Linear(in_features=4, out_features=16, bias=True)
        (gelu)

### Forward Pass: Action Predictor
The Action predictor tries to guess Target encoding given Context + Action.  The context is the ouput of the Student and the action is from our tokenized data.  The goal here is tha the output attempts to match the teacher. 

In [22]:
action_id

tensor([1045,  802])

In [23]:
predicted_latents = predictor(context_latents, action_id)
predicted_latents

tensor([[[-1.3139, -0.4016,  0.3994,  1.3179],
         [ 1.5855, -0.7184, -0.8062, -0.0598],
         [-0.3741,  1.6877, -0.8541, -0.4585],
         ...,
         [ 0.4311,  1.2780, -1.1620, -0.5462],
         [-1.5899, -0.0922,  0.7758,  0.9081],
         [ 1.3576, -0.6968, -1.1525,  0.4929]],

        [[-1.3131, -0.4038,  0.4005,  1.3167],
         [ 1.5844, -0.7207, -0.8036, -0.0603],
         [-0.3742,  1.6872, -0.8518, -0.4595],
         ...,
         [ 0.4309,  1.2770, -1.1594, -0.5471],
         [-1.5892, -0.0939,  0.7765,  0.9069],
         [ 1.3567, -0.6989, -1.1497,  0.4921]]], grad_fn=<SliceBackward0>)

# Loss Calculation

We'll use L1 Loss (Mean Absolute Error) which measures the average absolute difference between predicted and actual values, making it robust to outliers.  We'll call this latent loss since we'll  use an L1 loss function within the latent space of a model to ensure the compressed representation (latent) accurately reflects the original data. 

In [24]:
loss = F.l1_loss(predicted_latents, target_latents)
loss

tensor(1.0196, grad_fn=<MeanBackward0>)

# Back Propogation

In [25]:
loss.backward()

## Gradient Clipping

In [26]:
torch.nn.utils.clip_grad_norm_(predictor.parameters(), 1.0)
torch.nn.utils.clip_grad_norm_(student.parameters(), 1.0)

tensor(5.5633)

## Update the Teacher

Recall that our teacher was created from a copy of our student. This means the weights are initiated randomly. We also didn't include any backprop for our teacher as since our loss is calculated using the action predictor and the student model.  To update the teacher we'll use a mechanism called  **Exponential Moving Average (EMA)**

In simple terms, this mechanism forces the Teacher to evolve very slowly, becoming a stable, wiser version of the Student.

Here is the breakdown of the mechanics. Mathematically, it calculates:

$Teacher_{new} = (m * Teacher_{old})+((1−m)*Student_{current})$

With a typical momentum ($m$) of 0.996, the code tells the Teacher: *"Keep 99.6% of your current weights, and mix in just 0.4% of what the Student learned in this last step."*

The reason we do this is Stability. The Student network is updating rapidly and noisily using Gradient Descent ($Param←Param−LearningRate×Gradient$). If we simply copied the Student to the Teacher (`Teacher = Student`), the target we are trying to predict would move erratically every single step, making it impossible to learn (like trying to hit a bullseye that is vibrating wildly).

By updating the Teacher this way, it effectively represents the **average** of the Student's weights over the last several thousand steps. This smooths out the noise and provides a steady, high-quality target for the Student to predict, preventing "Model Collapse" (where the model cheats by outputting zeros for everything).

TO:DO  we could also replace the teacher with a pretrained model and use this as an update mechanism. 

In [27]:
m=0.996
for param_s, param_t in zip(student.parameters(), teacher.parameters()):
    param_t.data.mul_(m).add_((1 - m) * param_s.data)

## Inference
Now we've done a walkthrough of the forward pass of the model let's see how we can run an inference test since this is an encoder model. We'll reload our tokenized data and perturbation map so that we can convert our pertubation to text and sample from our dataset. 

In [28]:
metadata_path, shard_dir

(PosixPath('/Users/djemec/data/jepa/perturbation_map.json'),
 PosixPath('/Users/djemec/data/jepa/tokenized'))

In [29]:
# 1. Load Map (ID -> Name)
with open(metadata_path, "r") as f:
    pert_map = json.load(f)
# Invert map to: ID -> Name
id_to_name = {v: k for k, v in pert_map.items()}

**Sample Data**
We'll build a modification of our data loader which pulls a random file and a random perturbed cell example from it.  

In [30]:
def get_random_test_pair(shard_dir):
    '''Grab a single real pair from a random shard'''
    files = sorted(shard_dir.glob('*.npz'))
    file_path = files[np.random.randint(len(files))]
    
    with np.load(file_path) as data:
        idx = np.random.randint(data['action_ids'].shape[0])
        
        # Extract raw uint32
        control_raw = data['control'][idx]
        case_raw = data['case'][idx]
        act_id = data['action_ids'][idx]
        
    # Dequantize
    x_control = torch.tensor(control_raw.astype(np.float32)).unsqueeze(0).to(DEVICE) # [1, 1024]
    x_case = torch.tensor(case_raw.astype(np.float32)).unsqueeze(0).to(DEVICE)
    action_id = torch.tensor([act_id], dtype=torch.long).to(DEVICE)
    
    return x_control, x_case, action_id

In [31]:
x_control, x_case, action_id = get_random_test_pair(shard_dir)
pert_name = id_to_name[action_id.item()]
pert_name

'MTBP'

**Get Baseline**

To start, we'll take our `teacher` model to create our baseline expected embedding for both our control and case examples.  These create our "expected" values 

In [32]:
# 2. Get Baselines (Teacher View)
# We need the Teacher to tell us where the "Control" and "Real Treated" 
# sit in the abstract latent space.
with torch.no_grad():
    z_control = teacher(x_control)       # Where the cell started
    z_case = teacher(x_case)     # Where the cell actually went

**Get Predicted**

Now we need to run our predicted model. This requires first running our control (baseline cell) through our `student`, then use the output with our `predictor` along with the perturbation to see how good our predictor is. 

In [33]:
# 3. Run The Physics Engine (Predictor)
# Student encodes context -> Predictor adds Action -> Output
with torch.no_grad():
    z_context = student(x_control)
    z_predicted = predictor(z_context, action_id) # Where the model thinks it went

### Metrics

Now we'll calculate 3 metrics:
1. **Baseline Drift** - First we have to understand how much our teacher expects the cell state to change.  If the teacher is not expecting a change then we'd want to know that and the value would be `0`.  
2. **Prediction Error** - Second we'll now check to see how far our predictor is from our expected case based on the teacher.  This is the value we'd like to minimize as best as possible. 
3. **Simulation Magnitude** - Third we'll see how much the predictor is different from the teacher based control. This value should ideally be as close to the drift value as possible since the predictor's output should share the latent space with the teacher

With these values, we can understand a few things: If the error is larger than the drift, we can infer that the model is still worse than just predicting no change. Also, by comparing drift and simulated magnitude move we can understand how close directionally our predictor is to the teacher's impact. If that is close we can see that we're improving our model.  

In [34]:
# Metric 1: Baseline Drift (How much did the drug actually change the cell?)
# If this is 0, the drug did nothing, so prediction is trivial.
drift = F.l1_loss(z_control, z_case).item()
drift

0.03951822593808174

In [35]:
# Metric 2: Prediction Error (How close is our guess to the real result?)
error = F.l1_loss(z_predicted, z_case).item()
error

1.0133721828460693

In [36]:
# Metric 3: Simulation Magnitude (How much did our model decide to move the cell?)
sim_move = F.l1_loss(z_control,z_predicted).item()
sim_move

1.0050748586654663

## Conclusion
As you can see, untrained we have quite a lot of error.  Also, we're juggling 3 models and our predictor relies on a chain of two models meaning that overall we'll need a lot of training to start learning the cell physics. 

Also, looking at how this works we can see that we have a limitation on our training data currently as we're collapsing full networks down to a single number. 