In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import copy
import math
import numpy as np
import glob
import os
import json
from pathlib import Path
import matplotlib.pyplot as plt
import random
from dataclasses import dataclass
import itertools
import biojepa_ac_model as model


## Model

In [2]:
torch.manual_seed(1337)
random.seed(1337)

In [3]:
def get_device():
    device = 'cpu'
    if torch.cuda.is_available():
        torch.cuda.manual_seed(1337)
        device = 'cuda'
    # elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
    #     device = 'mps'
    print(f'using {device}')
    return device

DEVICE = get_device()

using cpu


In [4]:
BATCH_SIZE = 32
n_embd = 8
n_pathways = 1024
LR = 1e-3
EPOCHS = 2
tok_file_chunk_size = 10000

## Data Loading and Model instantiation

In [5]:
data_dir = Path('/Users/djemec/data/jepa')
val_dir = data_dir / 'tokenized' / 'val'
mask_path = data_dir / 'binary_pathway_mask.npy'
metadata_path = data_dir / 'perturbation_map.json'
checkpoint_dir = data_dir / 'checkpoint'
gene_names_path = data_dir / 'gene_names.json'

In [6]:
print('Loading Pathway Mask...')
binary_mask = np.load(mask_path)
N_GENES, N_PATHWAYS = binary_mask.shape
print(f'Mask Loaded: {N_GENES} Genes -> {N_PATHWAYS} Pathways')

with open(metadata_path, 'r') as f:
    pert_map = json.load(f)
id_to_pert = {v: k for k, v in pert_map.items()}
print(f'Loaded {len(id_to_pert.keys())} perturbations')

with open(data_dir / 'gene_names.json', "r") as f:
    gene_names = json.load(f)
print(f'Loaded {len(gene_names)} genes')

Loading Pathway Mask...
Mask Loaded: 4096 Genes -> 1024 Pathways
Loaded 2058 perturbations
Loaded 4096 genes


In [7]:
config = model.BioJepaConfig(
    mask_matrix=binary_mask, 
    num_genes=N_GENES,
    num_pathways=N_PATHWAYS,
    embed_dim=n_embd,
    heads=1
)
model = model.BioJepa(config).to(DEVICE)

## Checkpoint Loading

In [16]:
# 3. Load Checkpoint
# Your script saves a dict: {'model': ..., 'optimizer': ..., 'step': ...}
checkpoint_path = checkpoint_dir / 'bio_jepa_ckpt_4epoch_final.pt'
checkpoint = torch.load(checkpoint_path, map_location=DEVICE)

keys = model.load_state_dict(checkpoint['model'])
keys

<All keys matched successfully>

## Validation

In [17]:
val_file = sorted(val_dir.glob('*.npz'))
val_file

[PosixPath('/Users/djemec/data/jepa/tokenized/val/shard_0000.npz')]

In [18]:
with np.load(val_file[0]) as data:
    control_x_all = data['control']        # [2000]
    control_total_all = data['control_total']  # Scalar
    case_x_all = data['case']         # [2000]
    case_total_all = data['case_total']  # Scalar
    action_ids_all = data['action_ids']     # Scalar

In [19]:
examples_to_validate = 25
validation = {}
total_examples = len(action_ids_all)

**Pathway Weights**
As part of our validation, we need the learned pathway weights.  We could use the binary mask but this would negate any learning on the pathways our model did.  We'll detach the pathway weights before starting our eval as it remains a constant. 

In [20]:
model.eval()

learned_gene_network_matrix = model.student.pathway_weights.detach().T.abs()
learned_gene_network_matrix

tensor([[0.0469, 0.0414, 0.0600,  ..., 0.0205, 0.0459, 0.0481],
        [0.0129, 0.0094, 0.0609,  ..., 0.0827, 0.0652, 0.0017],
        [0.0200, 0.0410, 0.0319,  ..., 0.0629, 0.0007, 0.0003],
        ...,
        [0.0233, 0.0226, 0.0096,  ..., 0.0169, 0.0588, 0.0128],
        [0.0552, 0.0323, 0.0386,  ..., 0.0380, 0.0168, 0.0265],
        [0.0449, 0.0719, 0.0238,  ..., 0.0867, 0.0212, 0.0016]])

In [21]:
for i in range(examples_to_validate):
    # pick a random index out 
    idx = np.random.randint(total_examples)

    # extract inputs for the index
    control_x = torch.tensor(control_x_all[idx]).float().unsqueeze(0).to(DEVICE)
    control_tot = torch.tensor(control_total_all[idx]).float().unsqueeze(0).to(DEVICE)
    case_x = torch.tensor(case_x_all[idx]).float().unsqueeze(0).to(DEVICE)
    case_tot = torch.tensor(case_total_all[idx]).float().unsqueeze(0).to(DEVICE)
    action_id = torch.tensor([action_ids_all[idx]]).long().to(DEVICE)

    # get perturbation name
    pert_name = id_to_pert[action_id.item()]

    # Run model, both teacher and predictor
    with torch.no_grad():
        z_case = model.teacher(case_x, case_tot)       
        z_context = model.student(control_x, control_tot)
        z_control = model.teacher(control_x, control_tot) 
        z_predicted = model.predictor(z_context, action_id)

        l1_loss = model(control_x, control_tot, case_x, case_tot, action_id)
    
    delta_latent = (z_predicted - z_control).squeeze(0)

    #gene projection
    pathway_magnitudes = delta_latent.norm(dim=1) 
    gene_impact_scores = torch.mv(learned_gene_network_matrix, pathway_magnitudes)

    # comparing top 5% expressed genes for overlap
    top_k = int(0.05*len(gene_names))
    pred_values, pred_indices = torch.topk(gene_impact_scores, top_k)
    
    real_gene_delta = (case_x - control_x).abs().squeeze(0) 
    real_values, real_indices = torch.topk(real_gene_delta, top_k)

    # if i == 0:
    #     print(top_k)
    #     print(f'predicted_genes {[gene_names[i.item()] for i in pred_indices]}')
    #     print(f'actual_genes {[gene_names[i.item()] for i in real_indices]}')

    overlap = len(set(pred_indices).intersection(set(real_indices)))
    precision = overlap/top_k

    validation[idx] = {
        'perturbation': pert_name,
        'l1_loss': l1_loss,
        'overlap': overlap,
        'precision': precision,
    }
    

In [22]:
validation

{8646: {'perturbation': 'TSR1',
  'l1_loss': tensor(0.0204),
  'overlap': 0,
  'precision': 0.0},
 8771: {'perturbation': 'MTPAP',
  'l1_loss': tensor(0.0132),
  'overlap': 0,
  'precision': 0.0},
 5414: {'perturbation': 'NUP205',
  'l1_loss': tensor(0.0154),
  'overlap': 0,
  'precision': 0.0},
 2231: {'perturbation': 'GLMN',
  'l1_loss': tensor(0.0187),
  'overlap': 0,
  'precision': 0.0},
 4378: {'perturbation': 'FNBP4',
  'l1_loss': tensor(0.0397),
  'overlap': 0,
  'precision': 0.0},
 7923: {'perturbation': 'RPL3',
  'l1_loss': tensor(0.0184),
  'overlap': 0,
  'precision': 0.0},
 9781: {'perturbation': 'DRAP1',
  'l1_loss': tensor(0.0164),
  'overlap': 0,
  'precision': 0.0},
 1489: {'perturbation': 'COG3',
  'l1_loss': tensor(0.0192),
  'overlap': 0,
  'precision': 0.0},
 3364: {'perturbation': 'MED22',
  'l1_loss': tensor(0.0364),
  'overlap': 0,
  'precision': 0.0},
 7746: {'perturbation': 'HGS',
  'l1_loss': tensor(0.0108),
  'overlap': 0,
  'precision': 0.0},
 5850: {'pertur