In [51]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import json
from pathlib import Path
from tqdm import tqdm
from scipy.stats import pearsonr
import random
import matplotlib.pyplot as plt
from dataclasses import dataclass
import math
from collections import defaultdict
from sklearn.metrics import r2_score, mean_squared_error

import biojepa_ac_model as model

## BioJEPA Model load

In [28]:
torch.manual_seed(1337)
random.seed(1337)

In [29]:
def get_device():
    device = 'cpu'
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1337)
        device = 'cuda'
    # elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    #     device = 'mps'
    print(f'using {device}')
    return device

DEVICE = get_device()

using cpu


In [30]:
BATCH_SIZE = 32
n_embd = 8
n_pathways = 1024
training_file_chunk = 25000
pretraining_file_chunk = 50000
n_heads = 2
n_layers = 2

In [31]:
data_dir = Path('/Users/djemec/data/jepa/v0_2')
train_dir = data_dir / 'training'
pretrain_dir = data_dir / 'pretraining'
mask_path = data_dir / 'binary_pathway_mask.npy'
checkpoint_dir = data_dir / 'checkpoint'
pert_dir = data_dir / 'pert_embd'
pert_embd_path = pert_dir / 'action_embeddings_esm2.npy'

In [32]:
print('Loading Pathway Mask...')
binary_mask = np.load(mask_path)
n_genes, n_pathways = binary_mask.shape
print(f'Mask Loaded: {n_genes} Genes -> {n_pathways} Pathways')

print('Loading Action Embedding ...')
pert_embd = np.load(pert_embd_path)
print(f'Bank Loaded. Shape: {pert_embd.shape}')

Loading Pathway Mask...
Mask Loaded: 5000 Genes -> 1024 Pathways
Loading Action Embedding ...
Bank Loaded. Shape: (1087, 320)


In [63]:
config = model.BioJepaConfig(
    mask_matrix=binary_mask, 
    num_genes=n_genes,
    num_pathways=n_pathways,
    embed_dim=n_embd,
    n_layer=n_layers,
    heads=n_heads,
    n_pre_layer = n_layers
)
model = model.BioJepa(config, pert_embd=pert_embd).to(DEVICE)

**Load Checkpoint**

In [64]:
checkpoint_path = checkpoint_dir / 'bio_jepa_ckpt_31769_final.pt'
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)

keys = model.load_state_dict(checkpoint['model'])
keys

<All keys matched successfully>

**Freeze Model**

In [65]:
model.eval()
for param in model.parameters():
    param.requires_grad = False

## Build Decoder

In [36]:
@dataclass
class BenchmarkDecoderConfig:
    embed_dim: int = 384
    num_pathways: int = 1024
    num_genes: int = 4096
    
class BenchmarkDecoder(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.config = config
        # Step 1: Collapse the embedding dimension (384 -> 1)
        # This asks: "How active is this pathway overall?"
        self.pool = nn.Linear(config.embed_dim, 1) 
        
        # Step 2: Decode Pathway Activity -> Gene Expression
        # This learns the specific contribution of each pathway to each gene
        self.decode = nn.Linear(config.num_pathways, config.num_genes) 

        self.apply(self._init_weights)
        
    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
            if module.bias is not None: 
                torch.nn.init.zeros_(module.bias)
        
    def forward(self, latents):
        # latents: [Batch, 1024, 384]
        
        # 1. Calculate Pathway Scores
        # [B, 1024, 384] -> [B, 1024, 1] -> [B, 1024]
        scores = self.pool(latents).squeeze(-1)
        
        # 2. Project to Genes
        # [B, 1024] @ [1024, 2000] -> [B, 2000]
        gene_preds = self.decode(scores)
        
        return gene_preds

### Data Loader

In [37]:
def load_shard(filename):
    print(f'loading {filename}') # Optional: reduce noise
    with np.load(filename) as data:
        # Load all arrays into memory
        # We convert to correct types immediately to save hassle later
        control_x = data['control'].astype(np.float32)
        control_tot = data['control_total'].astype(np.float32)
        case_x = data['case'].astype(np.float32)
        case_tot = data['case_total'].astype(np.float32)
        action_ids = data['action_ids'].astype(np.int64)
        
    return control_x, control_tot, case_x, case_tot, action_ids

class DataLoaderLite:
    def __init__(self, B, split, device):
        self.B = B
        self.split = split
        self.device = device
        
        # 1. Find Shards
        data_root = train_dir / f'{split}'
        shards = list(data_root.glob('*.npz'))

        self.total_files = len(shards)
        self.shards = sorted(shards)

        assert len(shards) > 0, f'no shards found for split {split}'
        print(f'found {len(shards)} shards for split {split}')
        
        self.reset()

    def reset(self):
        # Create a randomized queue of shards
        self.remaining_shards = list(self.shards)
        random.shuffle(self.remaining_shards)
        
        self.current_shard_idx = -1
        self.load_next_shard()

    def load_next_shard(self):
        self.current_shard_idx += 1
        
        # If we ran out of shards, reset (Epoch done)
        if self.current_shard_idx >= len(self.remaining_shards):
            self.reset() # This resets shard_idx to -1 and reshuffles
            return 

        # Load the file
        filename = self.remaining_shards[self.current_shard_idx]
        self.data_tuple = load_shard(filename)
        
        # Shuffle the items INSIDE the shard
        # This is critical so we don't just memorize the sorted order of the shard
        n_samples = len(self.data_tuple[0])
        self.perm = np.random.permutation(n_samples)
        self.current_position = 0
        self.total_samples_in_shard = n_samples

    def next_batch(self):
        B = self.B
        
        # Check if we have enough data left in current shard
        if self.current_position + B > self.total_samples_in_shard:
            self.load_next_shard()
            # Recursively call to get batch from the new shard
            return self.next_batch()
            
        # Get indices for this batch
        indices = self.perm[self.current_position : self.current_position + B]
        self.current_position += B
        
        # Slice data using the shuffled indices
        # data_tuple structure: (xc, xct, xt, xtt, aid)
        batch_xc  = torch.from_numpy(self.data_tuple[0][indices]).to(self.device)
        batch_xct = torch.from_numpy(self.data_tuple[1][indices]).to(self.device)
        batch_xt  = torch.from_numpy(self.data_tuple[2][indices]).to(self.device)
        batch_xtt = torch.from_numpy(self.data_tuple[3][indices]).to(self.device)
        batch_aid = torch.from_numpy(self.data_tuple[4][indices]).to(self.device)
        
        return batch_xc, batch_xct, batch_xt, batch_xtt, batch_aid

**Data Loader**

In [38]:
train_loader = DataLoaderLite(B=BATCH_SIZE, split='train', device=DEVICE)
val_loader = DataLoaderLite(B=BATCH_SIZE, split='val', device=DEVICE)


found 5 shards for split train
loading /Users/djemec/data/jepa/v0_2/training/train/shard_k562e_train_0001.npz
found 1 shards for split val
loading /Users/djemec/data/jepa/v0_2/training/val/shard_k562e_val_0000.npz


## Training Decoder

### Training Config/Setup

In [39]:
lr_decoder = 1e-2
epochs = 5

**Initialize Decoder** 

In [40]:
config = BenchmarkDecoderConfig(
    embed_dim= n_embd,
    num_pathways= n_pathways,
    num_genes= n_genes
)

decoder = BenchmarkDecoder(config).to(DEVICE)

## Load Checkpoint

In [41]:
decoder_checkpoint_path = checkpoint_dir / 'biojepa_decoder_ckpt_15884_final.pt'
decode_checkpoint = torch.load(decoder_checkpoint_path, map_location=DEVICE)

d_keys = decoder.load_state_dict(decode_checkpoint['model'])
d_keys

<All keys matched successfully>

**Optimizer**

In [42]:
optimizer = torch.optim.AdamW(decoder.parameters(), lr=lr_decoder)

**Training Lenght**

In [43]:
train_total_examples = 101682
val_total_examples = 11044
test_total_examples = 38829

In [44]:
steps_per_epoch = train_total_examples // BATCH_SIZE
max_steps = epochs * steps_per_epoch
steps_per_epoch, max_steps

(3177, 15885)

**Scheduler**

In [None]:
scheduler = torch.optim.lr_scheduler.OneCycleLR(
    optimizer, max_lr=lr_decoder, total_steps=max_steps, pct_start=0.05
)

### Training Loop

In [None]:
lossi = []
step = 0
total_epoch_loss = 0

In [None]:
for step in range(max_steps):

    last_step = (step == max_steps - 1)

    # once in a while evaluate our validation set loss
    if step % 100 == 0 or last_step:
        model.eval()
        with torch.no_grad():
            val_loss_accum = 0.0
            val_loss_steps = 25
            for i in range(val_loss_steps):
                cont_x, cont_tot, case_x, case_tot, act_id = val_loader.next_batch()

                # run BioJEPA
                with torch.no_grad():
                    z_context = model.student(cont_x, cont_tot)
                    z_pred = model.predictor(z_context, act_id)

                # run new decoder
                pred_delta = decoder(z_pred) - decoder(z_context)
                real_delta = case_x - cont_x

                loss = F.mse_loss(pred_delta, real_delta)
                loss = loss / val_loss_steps
                val_loss_accum += loss.detach()

        print(f'val loss: {val_loss_accum.item():.4f}')


    # periodically save checkpoint
    if step > 0 and  (step+1) % steps_per_epoch ==0 and not last_step:
        # Save Checkpoint
        torch.save({
            'model': decoder.state_dict(),
            'optimizer': optimizer.state_dict(),
            'step': step
        }, checkpoint_dir / f'biojepa_decoder_ckpt_{step}.pt')

    # actual training
    decoder.train
    cont_x, cont_tot, case_x, case_tot, act_id = train_loader.next_batch()

    # run frozen BioJEPA
    with torch.no_grad():
        z_context = model.student(cont_x, cont_tot)
        z_pred = model.predictor(z_context, act_id)

    # run decoder
    pred_delta = decoder(z_pred) - decoder(z_context)
    real_delta = case_x - cont_x
    real_delta = case_x - cont_x

    # loss
    pred_delta = decoder(z_pred) - decoder(z_context)
    real_delta = case_x - cont_x

    optimizer.zero_grad()

    loss = F.mse_loss(pred_delta, real_delta)
    loss.backward()

    optimizer.step()
    scheduler.step()

    # loss caching
    lossi.append(loss.item())
    total_epoch_loss += loss.item()

    if step % 25 == 0:
        print(f"Step {step} | Loss: {loss.item():.5f} | LR: {scheduler.get_last_lr()[0]:.2e}")
    
    if step > 0 and (step+1) % steps_per_epoch == 0:   
        avg_loss = total_epoch_loss / steps_per_epoch
        print(f"=== Step {step} Done. Avg Loss: {avg_loss:.5f} ===")
        total_epoch_loss = 0
    
    if last_step:
        # Save Checkpoint
        torch.save({
            'model': decoder.state_dict(),
            'optimizer': optimizer.state_dict(),
            'step': step
        }, checkpoint_dir / f'biojepa_decoder_ckpt_{step}_final.pt')

    step += 1

**Training Loss Plot**

In [None]:
plt.plot(lossi)
plt.yscale('log')

## Trained Decoder Validation

In [45]:
import warnings
from scipy.stats import ConstantInputWarning

In [46]:
decoder.eval()

BenchmarkDecoder(
  (pool): Linear(in_features=8, out_features=1, bias=True)
  (decode): Linear(in_features=1024, out_features=5000, bias=True)
)

In [47]:
val_steps_per_epoch = val_total_examples // BATCH_SIZE
test_steps_per_epoch = test_total_examples // BATCH_SIZE

In [48]:
bulk_preds = defaultdict(list)      
bulk_reals = defaultdict(list)     
bulk_reals_delta = defaultdict(list)

val_r2_all = []
val_r2_top50 = []
val_correlations = []
val_mses = []

In [49]:
for step in tqdm(range(val_steps_per_epoch), desc='Val Dataset'):
    
    # Custom Loader Call
    cont_x, cont_tot, case_x, case_tot, act_id = val_loader.next_batch()
    B, N = cont_x.shape
    
    with torch.no_grad():
        z_context = model.student(cont_x, cont_tot)
        z_pred = model.predictor(z_context, act_id)
        
        pred_delta = decoder(z_pred) - decoder(z_context)
        real_delta = case_x - cont_x

        pred_absolute = cont_x + pred_delta

    pred_delta_np = pred_delta.cpu().numpy()
    real_delta_np = real_delta.cpu().numpy()

    pred_abs_np = pred_absolute.cpu().numpy()
    real_abs_np = case_x.cpu().numpy()
    
    act_id_np = act_id.cpu().numpy().flatten()
        
    # Per-Sample Metrics
    for i in range(B):
        p_delta = pred_delta_np[i]
        t_delta = real_delta_np[i]
        pid = act_id_np[i] 

        val_mses.append(np.mean((p_delta - t_delta)**2))

        top_20_idx = np.argsort(np.abs(t_delta))[-20:]
        p_top = p_delta[top_20_idx]
        t_top = t_delta[top_20_idx]
        
        if np.std(p_top) > 1e-9 and np.std(t_top) > 1e-9:
            corr, _ = pearsonr(p_top, t_top)
            val_correlations.append(0.0 if np.isnan(corr) else corr)
        else:
            val_correlations.append(0.0)

        bulk_preds[pid].append(pred_abs_np[i])
        bulk_reals[pid].append(real_abs_np[i])
        bulk_reals_delta[pid].append(real_delta_np[i])

Val Dataset: 100%|████████████████████████████████████████████████████████████| 345/345 [01:31<00:00,  3.78it/s]


In [52]:
for pid in bulk_preds:
    p_mean = np.mean(np.stack(bulk_preds[pid]), axis=0)
    t_mean = np.mean(np.stack(bulk_reals[pid]), axis=0)
    t_mean_delta = np.mean(np.stack(bulk_reals_delta[pid]), axis=0)

    if np.std(t_mean) > 1e-9:
        val_r2_all.append(r2_score(t_mean, p_mean))
    
    top_50_idx = np.argsort(np.abs(t_mean_delta))[-50:] 
    
    val_r2_top50.append(r2_score(t_mean[top_50_idx], p_mean[top_50_idx]))

In [53]:
val_mean_mse = np.mean(val_mses)
val_mean_corr = np.mean(val_correlations)
val_mean_r2_all = np.mean(val_r2_all)
val_median_r2_all = np.median(val_r2_all)
val_mean_r2_top50 = np.mean(val_r2_top50)
val_median_r2_top50 = np.median(val_r2_top50)

In [54]:
print(f'Global MSE: {val_mean_mse:.4f}')
print(f'Top-20 Pearson R: {val_mean_corr:.4f}')
print(f'R^2 (All Genes): Mean: {val_mean_r2_all:.4f}, Median: {val_median_r2_all:.4f}')
print(f'R^2 (Top 50 DEGs): Mean: {val_mean_r2_top50:.4f}, Median: {val_median_r2_top50:.4f}')


Global MSE: 0.7929
Top-20 Pearson R: 0.5565
R^2 (All Genes): Mean: 0.9500, Median: 0.9554
R^2 (Top 50 DEGs): Mean: -0.1257, Median: 0.1082


## Trained Decoder Evaluation

In [66]:
bulk_preds = defaultdict(list)
bulk_reals = defaultdict(list)
bulk_reals_delta = defaultdict(list)

test_r2_all = []
test_r2_top50 = []
test_correlations = []
test_mses = []

In [67]:
test_loader = DataLoaderLite(B=BATCH_SIZE, split='test', device=DEVICE)
test_steps_per_epoch = test_total_examples // BATCH_SIZE


found 2 shards for split test
loading /Users/djemec/data/jepa/v0_2/training/test/shard_k562e_test_0000.npz


In [68]:
for step in tqdm(range(test_steps_per_epoch), desc='Benchmarking'):
    
    # Custom Loader Call
    cont_x, cont_tot, case_x, case_tot, act_id = test_loader.next_batch()
    B, N = cont_x.shape
    
    with torch.no_grad():
        z_context = model.student(cont_x, cont_tot)
        z_pred = model.predictor(z_context, act_id)
        
        pred_delta = decoder(z_pred) - decoder(z_context)
        real_delta = case_x - cont_x

        pred_absolute = cont_x + pred_delta

    pred_delta_np = pred_delta.cpu().numpy()
    real_delta_np = real_delta.cpu().numpy()

    pred_abs_np = pred_absolute.cpu().numpy()
    real_abs_np = case_x.cpu().numpy()
    
    act_id_np = act_id.cpu().numpy().flatten()
        
    # Per-Sample Metrics
    for i in range(B):
        p_delta = pred_delta_np[i]
        t_delta = real_delta_np[i]
        pid = act_id_np[i] 

        test_mses.append(np.mean((p_delta - t_delta)**2))

        top_20_idx = np.argsort(np.abs(t_delta))[-20:]
        p_top = p_delta[top_20_idx]
        t_top = t_delta[top_20_idx]
        
        if np.std(p_top) > 1e-9 and np.std(t_top) > 1e-9:
            corr, _ = pearsonr(p_top, t_top)
            test_correlations.append(0.0 if np.isnan(corr) else corr)
        else:
            test_correlations.append(0.0)

        bulk_preds[pid].append(pred_abs_np[i])
        bulk_reals[pid].append(real_abs_np[i])
        bulk_reals_delta[pid].append(real_delta_np[i])

Benchmarking:  64%|█████████████████████████████████████▎                    | 781/1213 [03:34<02:09,  3.33it/s]

loading /Users/djemec/data/jepa/v0_2/training/test/shard_k562e_test_0001.npz


Benchmarking: 100%|█████████████████████████████████████████████████████████| 1213/1213 [05:37<00:00,  3.59it/s]


In [69]:
for pid in bulk_preds:
    p_mean = np.mean(np.stack(bulk_preds[pid]), axis=0)
    t_mean = np.mean(np.stack(bulk_reals[pid]), axis=0)
    t_mean_delta = np.mean(np.stack(bulk_reals_delta[pid]), axis=0)

    if np.std(t_mean) > 1e-9:
        test_r2_all.append(r2_score(t_mean, p_mean))
    
    top_50_idx = np.argsort(np.abs(t_mean_delta))[-50:] 
    
    test_r2_top50.append(r2_score(t_mean[top_50_idx], p_mean[top_50_idx]))

In [70]:
test_mean_mse = np.mean(test_mses)
test_mean_corr = np.mean(test_correlations)
test_mean_r2_all = np.mean(test_r2_all)
test_median_r2_all = np.median(test_r2_all)
test_mean_r2_top50 = np.mean(test_r2_top50)
test_median_r2_top50 = np.median(test_r2_top50)

In [72]:
print(f'Global MSE: {test_mean_mse:.4f}')
print(f'Top-20 Pearson R: {test_mean_corr:.4f}')
print(f'R^2 (All Genes): Mean: {test_mean_r2_all:.4f}, Median: {test_median_r2_all:.4f}')
print(f'R^2 (Top 50 DEGs): Mean: {test_mean_r2_top50:.4f}, Median: {test_median_r2_top50:.4f}')

Global MSE: 0.7904
Top-20 Pearson R: 0.5544
R^2 (All Genes): Mean: 0.9498, Median: 0.9584
R^2 (Top 50 DEGs): Mean: -0.1000, Median: 0.1886


In [73]:
diff = np.abs(pred_absolute - cont_x).mean()
print(f'Average Predicted Shift magnitude: {diff:.4f}')

Average Predicted Shift magnitude: 0.2665


  diff = np.abs(pred_absolute - cont_x).mean()
