# IY011 Contrastive Learning: Model Training
Randomly pick pairs of samples from the dataset, randomly assign labels to each, and train a model to distinguish them

In [1]:
import os
import subprocess
import glob
import tempfile
from pathlib import Path
import numpy as np
import pandas as pd
import time
# plotting 
import matplotlib.pyplot as plt
from visualisation.plots import plot_mRNA_dist, plot_mRNA_trajectory
# simulation
from simulation.julia_simulate_telegraph_model import simulate_telegraph_model
# ml
import torch, itertools
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, TensorDataset
from classifiers.transformer_classifier import transformer_classifier
from models.simple_transformer import SimpleTransformer
from models.transformer import TransformerClassifier, train_model, evaluate_model
# data handling
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from utils.load_data import load_and_split_data
from utils.data_processing import add_binary_labels
from utils.standardise_time_series import standardise_time_series
from utils.steady_state import find_steady_state

# Build groups
from utils.data_processing import build_groups

import wandb
%load_ext autoreload
%autoreload 2  

In [2]:
DATA_ROOT = Path("/home/ianyang/stochastic_simulations/experiments/EXP-25-IY011/data")
RESULTS_PATH = DATA_ROOT / "IY011_simulation_parameters_sobol.csv" #  this csv file stores all the simulation parameters used
df_params = pd.read_csv(RESULTS_PATH) 
# TRAJ_PATH = [DATA_ROOT / f"mRNA_trajectories_mu{row['mu_target']:.3f}_cv{row['cv_target']:.3f}_tac{row['t_ac_target']:.3f}.csv" for idx, row in df_params.iterrows()] # the trajectories 
TRAJ_PATH = [DATA_ROOT / df_params['trajectory_filename'].values[i] for i in range(len(df_params))]
TRAJ_NPZ_PATH = [traj_file.with_suffix('.npz') for traj_file in TRAJ_PATH]

# extract meta data
parameter_sets = [{
    'sigma_b': row['sigma_b'],
    'sigma_u': row['sigma_u'],
    'rho': row['rho'],
    'd': row['d'],
    'label': 0
} for idx, row in df_params.iterrows()]
time_points = np.arange(0, 3000, 1.0)
size = 1000

In [36]:
num_traj = 500
NUM_GROUPS = 6  # a pair: 1 pos, 1 neg
groups = build_groups(TRAJ_NPZ_PATH, num_groups=NUM_GROUPS, num_traj=num_traj) # list of tuples (X, y)

Building positive groups: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:00<00:00, 237.45it/s]
Building negative groups: 100%|‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà| 3/3 [00:00<00:00, 223.68it/s]


In [37]:
groups

[(array([[ 448.,  640., 3033., ...,  791., 2449., 4989.],
         [ 429.,  615., 2939., ...,  768., 2360., 4821.],
         [ 417.,  598., 2843., ...,  741., 2291., 4654.],
         ...,
         [1646., 1872.,   48., ..., 3108., 1965., 3051.],
         [1595., 1812.,   46., ..., 3001., 1910., 2960.],
         [1544., 1752.,   46., ..., 2912., 1851., 2869.]],
        shape=(1811, 500), dtype=float32),
  1),
 (array([[ 9576., 11189., 17969., ...,  2469.,     0.,  1050.],
         [ 9465., 11053., 17789., ...,  2264.,     0.,   935.],
         [ 9365., 10931., 17589., ...,  2060.,     0.,   853.],
         ...,
         [ 4240.,  1386.,  5725., ...,   253.,  1576., 31987.],
         [ 4199.,  1373.,  5661., ...,   218.,  1433., 28943.],
         [ 4164.,  1352.,  5615., ...,   209.,  1306., 26245.]],
        shape=(1811, 500), dtype=float32),
  0),
 (array([[2254.,   96., 1108., ..., 4067., 2857.,  924.],
         [2229.,   95., 1098., ..., 4028., 2835.,  915.],
         [2207.,   91., 

## Data Prep

In [38]:
def data_prep(groups, NUM_GROUPS):
    # Stacked groups -> individual trajectory samples
    X_samples = []
    y_samples = []
    for Xg, yg in groups:          # Xg shape (seq_len, K)
        L, K = Xg.shape
        for k in range(K):
            X_samples.append(Xg[:, k:k+1])  # (seq_len, 1)
            y_samples.append(yg)            # or some other per-trajectory label
    X_samples = np.stack(X_samples, 0)      # (N_samples, seq_len, 1)
    y_samples = np.array(y_samples)
    print(f'X_samples shape: {X_samples.shape}, y_samples shape: {y_samples.shape}')

    # with the stacked samples
    X_train, X_test, y_train, y_test = train_test_split(
        X_samples, y_samples, test_size=0.2, random_state=42, stratify=y_samples
    )
    X_train, X_val,  y_train, y_val  = train_test_split(
        X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
    )

    print("Data preparation:")
    print(f"  Train groups: {len(y_train)}, Val groups: {len(y_val)}, Test groups: {len(y_test)}")
    # === Standardise features (across time*batch, per-channel) ===
    scaler = StandardScaler()

    # Reshape 3D data to 2D for scaling
    original_shape_train = X_train.shape
    original_shape_val = X_val.shape
    original_shape_test = X_test.shape

    # Reshape to 2D: (batch * seq_len, features)
    X_train_2d = X_train.reshape(-1, X_train.shape[-1])
    X_val_2d = X_val.reshape(-1, X_val.shape[-1])
    X_test_2d = X_test.reshape(-1, X_test.shape[-1])

    # Scale the data
    X_train_2d = scaler.fit_transform(X_train_2d)
    X_val_2d = scaler.transform(X_val_2d)
    X_test_2d = scaler.transform(X_test_2d)

    # Reshape back to 3D
    X_train = X_train_2d.reshape(original_shape_train)
    X_val = X_val_2d.reshape(original_shape_val)
    X_test = X_test_2d.reshape(original_shape_test)

    print("X_train shape:", X_train.shape)
    print("X_val shape:", X_val.shape)
    print("X_test shape:", X_test.shape)
    
    # Torch loaders
    batch_size = NUM_GROUPS

    # === Convert to tensors and loaders ===
    X_train_t = torch.tensor(X_train, dtype=torch.float32)
    y_train_t = torch.tensor(y_train, dtype=torch.long)
    X_val_t   = torch.tensor(X_val,   dtype=torch.float32)
    y_val_t   = torch.tensor(y_val,   dtype=torch.long)
    X_test_t  = torch.tensor(X_test,  dtype=torch.float32)
    y_test_t  = torch.tensor(y_test,  dtype=torch.long)

    train_loader = DataLoader(TensorDataset(X_train_t, y_train_t), batch_size=batch_size, shuffle=True,  num_workers=4, pin_memory=True)
    val_loader   = DataLoader(TensorDataset(X_val_t,   y_val_t),   batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
    test_loader  = DataLoader(TensorDataset(X_test_t,  y_test_t),  batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

    # check the data loaders
    for X_batch, y_batch in train_loader:
        print(X_batch.shape, y_batch.shape)
        break 
    
    return train_loader, val_loader, test_loader
    
train_loader, val_loader, test_loader = data_prep(groups, NUM_GROUPS)

X_samples shape: (3000, 1811, 1), y_samples shape: (3000,)
Data preparation:
  Train groups: 1920, Val groups: 480, Test groups: 600
X_train shape: (1920, 1811, 1)
X_val shape: (480, 1811, 1)
X_test shape: (600, 1811, 1)
torch.Size([6, 1811, 1]) torch.Size([6])


## Transformer Model Eval
 Start a new wandb run to track this script.

In [39]:
import wandb
import torch.optim as optim
import time

# === Model hyperparams ===
input_size = 1
num_classes = 2
d_model=64
nhead=4
num_layers=2
dropout=0.001
use_conv1d=False 

model = TransformerClassifier(
    input_size=input_size,
    d_model=d_model,
    nhead=nhead,
    num_layers=num_layers,
    num_classes=num_classes,
    dropout=dropout, 
    use_conv1d=use_conv1d 
)
# === Model hyperparams ===

# === Training hyperparams ===
epochs = 50
patience = 10
lr = 1e-2
optimizer = optim.Adam(model.parameters(), lr=lr)
loss_fn = nn.BCEWithLogitsLoss()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
grad_clip = 1.0
save_path = None
verbose = True

model.to(device)
# === Training hyperparams ===

TransformerClassifier(
  (input_proj): Linear(in_features=1, out_features=64, bias=True)
  (pe): PositionalEncoding()
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-1): 2 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
        )
        (linear1): Linear(in_features=64, out_features=256, bias=True)
        (dropout): Dropout(p=0.001, inplace=False)
        (linear2): Linear(in_features=256, out_features=64, bias=True)
        (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.001, inplace=False)
        (dropout2): Dropout(p=0.001, inplace=False)
      )
    )
  )
  (dropout): Dropout(p=0.001, inplace=False)
  (head): Linear(in_features=64, out_features=2, bias=True)
)

In [40]:
# === Init wandb run ===
wandb_config = {
    "dataset": DATA_ROOT.name,
    "input_size": input_size,
    "d_model": d_model,
    "nhead": nhead,
    "num_layers": num_layers,
    "num_classes": num_classes,
    "dropout": dropout,
    "use_conv1d": use_conv1d,
    "epochs": epochs,
    "patience": patience,
    "lr": lr,
    "optimizer": type(optimizer).__name__,
    "loss_fn": type(loss_fn).__name__,
    "model": type(model).__name__,
    "batch_size": getattr(train_loader, "batch_size", None),
    "num_traj_per_group": num_traj,
    "num_groups": NUM_GROUPS
}
run = wandb.init(entity="grignard-reagent",
                 project="IY011-contrastive-learning",
                 name=f"groups_{NUM_GROUPS}_traj_{num_traj}_random_neg_split",
                 config=wandb_config)
# === Init wandb run ===

In [41]:
if verbose:
    print("Starting training...")

best_val_acc = -1.0
epochs_no_improve = 0
history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
start_time = time.time()

for epoch in range(epochs):
    model.train()
    total_loss, correct, total = 0.0, 0, 0

    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        optimizer.zero_grad()
        outputs = model(X_batch)

        # adjust targets if BCE-type loss
        y_batch_mod = y_batch
        if isinstance(loss_fn, (nn.BCEWithLogitsLoss, nn.BCELoss)):
            y_batch_mod = y_batch_mod.float().unsqueeze(1) if y_batch_mod.dim() == 1 else y_batch_mod.float()
            if outputs.dim() == 2 and outputs.size(1) == 2 and y_batch_mod.size(1) == 1:
                outputs = outputs[:, 1].unsqueeze(1)

        loss = loss_fn(outputs, y_batch_mod)
        loss.backward()
        if grad_clip:
            nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
        optimizer.step()

        total_loss += loss.item() * X_batch.size(0)

        # compute accuracy
        if isinstance(loss_fn, (nn.BCEWithLogitsLoss, nn.BCELoss)):
            probs = torch.sigmoid(outputs).view(-1)
            preds = (probs > 0.5).long()
            tgt = y_batch.view(-1).long()
        else:
            preds = outputs.argmax(1)
            tgt = y_batch
        correct += (preds == tgt).sum().item()
        total += tgt.size(0)

    train_loss = total_loss / len(train_loader.dataset)
    train_acc = correct / total
    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)

    # Validation
    val_loss, val_acc = (None, None)
    if val_loader is not None:
        val_loss, val_acc = evaluate_model(model, val_loader, loss_fn=loss_fn, device=device, verbose=False)

        # Early stopping
        if val_acc is not None and val_acc > best_val_acc:
            best_val_acc = val_acc
            epochs_no_improve = 0
            if save_path:
                torch.save(model.state_dict(), save_path)
                print(f"‚úÖ Model saved at {save_path} (Best Val Acc: {best_val_acc:.4f})")
                # also upload to wandb
                try:
                    wandb.save(save_path)
                except Exception:
                    pass
        else:
            epochs_no_improve += 1
            if verbose:
                print(f"No improvement ({epochs_no_improve}/{patience}).")

        if epochs_no_improve >= patience:
            print("üõë Early stopping.")
            break

    history["val_loss"].append(val_loss)
    history["val_acc"].append(val_acc)

    # Log metrics to wandb each epoch
    log_dict = {
        "epoch": epoch + 1,
        "train/loss": train_loss,
        "train/acc": train_acc,
    }
    if val_loss is not None:
        log_dict.update({"val/loss": val_loss, "val/acc": val_acc})
    # also log current LR
    try:
        current_lr = optimizer.param_groups[0]["lr"]
        log_dict["lr"] = current_lr
    except Exception:
        pass

    # optional: log gradient norm (approx)
    total_grad_norm = 0.0
    for p in model.parameters():
        if p.grad is not None:
            total_grad_norm += p.grad.data.norm(2).item() ** 2
    total_grad_norm = total_grad_norm ** 0.5 if total_grad_norm > 0 else 0.0
    log_dict["grad/norm"] = total_grad_norm

    run.log(log_dict)

    if verbose:
        msg = f"Epoch [{epoch+1}/{epochs}] | train_loss {train_loss:.4f} | train_acc {train_acc:.4f}"
        if val_loader is not None:
            msg += f" | val_loss {val_loss:.4f} | val_acc {val_acc:.4f}"
        print(msg)

# final save + finish wandb run
elapsed = time.time() - start_time
run.summary["training_time_sec"] = elapsed
run.summary["best_val_acc"] = best_val_acc
run.finish()

print("Training complete.")

Starting training...


Epoch [1/50] | train_loss 0.2729 | train_acc 0.9578 | val_loss 1.5516 | val_acc 0.9104
Epoch [2/50] | train_loss 0.2439 | train_acc 0.9714 | val_loss 0.2718 | val_acc 0.9667
Epoch [3/50] | train_loss 0.1548 | train_acc 0.9755 | val_loss 0.1564 | val_acc 0.9750
No improvement (1/10).
Epoch [4/50] | train_loss 0.1901 | train_acc 0.9651 | val_loss 0.1385 | val_acc 0.9667
No improvement (2/10).
Epoch [5/50] | train_loss 0.2637 | train_acc 0.9734 | val_loss 0.2000 | val_acc 0.9708
No improvement (3/10).
Epoch [6/50] | train_loss 0.1432 | train_acc 0.9760 | val_loss 0.2723 | val_acc 0.9688
No improvement (4/10).
Epoch [7/50] | train_loss 0.1896 | train_acc 0.9714 | val_loss 0.2721 | val_acc 0.9729
No improvement (5/10).
Epoch [8/50] | train_loss 0.1550 | train_acc 0.9797 | val_loss 0.2406 | val_acc 0.9729
No improvement (6/10).
Epoch [9/50] | train_loss 0.1239 | train_acc 0.9802 | val_loss 0.1635 | val_acc 0.9708
No improvement (7/10).
Epoch [10/50] | train_loss 0.1294 | train_acc 0.9786 | v

0,1
epoch,‚ñÅ‚ñÇ‚ñÇ‚ñÉ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÜ‚ñá‚ñá‚ñà
grad/norm,‚ñÅ‚ñà‚ñà‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÇ
lr,‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
train/acc,‚ñÅ‚ñÖ‚ñá‚ñÉ‚ñÜ‚ñá‚ñÖ‚ñà‚ñà‚ñà‚ñà‚ñá
train/loss,‚ñà‚ñá‚ñÇ‚ñÑ‚ñà‚ñÇ‚ñÑ‚ñÇ‚ñÅ‚ñÅ‚ñÇ‚ñà
val/acc,‚ñÅ‚ñá‚ñà‚ñá‚ñà‚ñá‚ñà‚ñà‚ñà‚ñà‚ñÖ‚ñà
val/loss,‚ñà‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ

0,1
best_val_acc,0.975
epoch,12.0
grad/norm,0.08168
lr,0.01
train/acc,0.97656
train/loss,0.27348
training_time_sec,156.78121
val/acc,0.97083
val/loss,0.18197


Training complete.


## SVM Model Benchmark


In [42]:
# Stacked groups -> individual trajectory samples
X_samples = []
y_samples = []
for Xg, yg in groups:          # Xg shape (seq_len, K)
    L, K = Xg.shape
    for k in range(K):
        X_samples.append(Xg[:, k:k+1])  # (seq_len, 1)
        y_samples.append(yg)            # or some other per-trajectory label
X_samples = np.stack(X_samples, 0)      # (N_samples, seq_len, 1)
y_samples = np.array(y_samples)
print(f'X_samples shape: {X_samples.shape}, y_samples shape: {y_samples.shape}')

X_samples shape: (3000, 1811, 1), y_samples shape: (3000,)


In [43]:
# # Train/val/test with stratify on group label
X_train, X_test, y_train, y_test = train_test_split(
    X_samples, y_samples, test_size=0.2, random_state=42, stratify=y_samples
)
X_train, X_val,  y_train, y_val  = train_test_split(
    X_train, y_train, test_size=0.2, random_state=42, stratify=y_train
)

print("Data preparation:")
print(f"  Train groups: {len(y_train)}, Val groups: {len(y_val)}, Test groups: {len(y_test)}")


Data preparation:
  Train groups: 1920, Val groups: 480, Test groups: 600


In [44]:
# === Standardise features (across time*batch, per-channel) ===
scaler = StandardScaler()

# Reshape 3D data to 2D for scaling
original_shape_train = X_train.shape
original_shape_val = X_val.shape
original_shape_test = X_test.shape

# Reshape to 2D: (batch * seq_len, features)
X_train_2d = X_train.reshape(-1, X_train.shape[-1])
X_val_2d = X_val.reshape(-1, X_val.shape[-1])
X_test_2d = X_test.reshape(-1, X_test.shape[-1])

# Scale the data
X_train_2d = scaler.fit_transform(X_train_2d)
X_val_2d = scaler.transform(X_val_2d)
X_test_2d = scaler.transform(X_test_2d)

# Reshape back to 3D
X_train = X_train_2d.reshape(original_shape_train)
X_val = X_val_2d.reshape(original_shape_val)
X_test = X_test_2d.reshape(original_shape_test)

print("X_train shape:", X_train.shape)
print("X_val shape:", X_val.shape)
print("X_test shape:", X_test.shape)

X_train shape: (1920, 1811, 1)
X_val shape: (480, 1811, 1)
X_test shape: (600, 1811, 1)


In [45]:
# Torch loaders
batch_size = 64

# === Convert to tensors and loaders ===
X_train_t = torch.tensor(X_train, dtype=torch.float32)
y_train_t = torch.tensor(y_train, dtype=torch.long)
X_val_t   = torch.tensor(X_val,   dtype=torch.float32)
y_val_t   = torch.tensor(y_val,   dtype=torch.long)
X_test_t  = torch.tensor(X_test,  dtype=torch.float32)
y_test_t  = torch.tensor(y_test,  dtype=torch.long)

train_loader = DataLoader(TensorDataset(X_train_t, y_train_t), batch_size=batch_size, shuffle=True,  num_workers=4, pin_memory=True)
val_loader   = DataLoader(TensorDataset(X_val_t,   y_val_t),   batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
test_loader  = DataLoader(TensorDataset(X_test_t,  y_test_t),  batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)

# check the data loaders
for X_batch, y_batch in train_loader:
    print(X_batch.shape, y_batch.shape)
    break 

torch.Size([64, 1811, 1]) torch.Size([64])


In [46]:
from classifiers.svm_classifier import svm_classifier

# Flatten the time series data for SVM (reshape from (n_samples, seq_len, features) to (n_samples, seq_len * features))
X_train_svm = X_train.reshape(X_train.shape[0], -1)
X_test_svm = X_test.reshape(X_test.shape[0], -1)

svm_accuracy = svm_classifier(
    X_train_svm,
    X_test_svm,
    y_train,
    y_test,
)

=== SVM (RBF Kernel) Classification Accuracy: 0.98 ===
