# IY011 Contrastive Learning Model Training (Part 2)

In [1]:
import os
import subprocess
import glob
import tempfile
from pathlib import Path
import numpy as np
import pandas as pd
import time
# plotting 
import matplotlib.pyplot as plt
from visualisation.plots import plot_mRNA_dist, plot_mRNA_trajectory
# ml
import torch, itertools
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, TensorDataset
from models.siamese_transformer import SiameseTransformer
from training.eval import evaluate_model
from training.train import train_model 

# data handling
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

# Build groups
from utils.data_processing import build_groups
import wandb

%load_ext autoreload
%autoreload 2  

In [2]:
# Root of the simulation outputs. Override with the IY011_DATA_ROOT environment
# variable when running on another machine; the default is the original location.
DATA_ROOT = Path(os.environ.get(
    "IY011_DATA_ROOT",
    "/home/ianyang/stochastic_simulations/experiments/EXP-25-IY011/data",
))
RESULTS_PATH = DATA_ROOT / "IY011_simulation_parameters_sobol.csv"  # stores all the simulation parameters used
df_params = pd.read_csv(RESULTS_PATH)

# One trajectory file per parameter row; the CSV column holds the file name
# relative to DATA_ROOT. The .npz siblings are what siamese_data_prep consumes.
TRAJ_PATH = [DATA_ROOT / fname for fname in df_params["trajectory_filename"]]
TRAJ_NPZ_PATH = [traj_file.with_suffix(".npz") for traj_file in TRAJ_PATH]

# extract metadata: one dict of simulation parameters per CSV row
parameter_sets = [
    {
        "sigma_b": row["sigma_b"],
        "sigma_u": row["sigma_u"],
        "rho": row["rho"],
        "d": row["d"],
        "label": 0,
    }
    for _, row in df_params.iterrows()
]
time_points = np.arange(0, 3000, 1.0)
size = 1000

In [3]:
import torch
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from torch.utils.data import Dataset, DataLoader
from utils.data_processing import build_groups  # Assumes your file is accessible

class SiameseGroupDataset(Dataset):
    """
    Adapt build_groups output into (x1, x2, y) triples for the SiameseTransformer.

    Each element of ``groups`` is an ``(X, y)`` pair where X stacks the
    trajectories column-wise (two columns when num_traj=2) and y is the
    scalar pair label. Column 0 becomes x1 and column 1 becomes x2; both keep
    a trailing feature axis, i.e. shape (seq_len, 1).
    """

    def __init__(self, groups):
        self.groups = groups

    def __len__(self):
        return len(self.groups)

    def __getitem__(self, idx):
        pair, label = self.groups[idx]

        # Column slices (not integer indices) keep the trailing axis of size 1.
        first = torch.tensor(pair[:, 0:1], dtype=torch.float32)
        second = torch.tensor(pair[:, 1:2], dtype=torch.float32)
        # BCE expects the target as a 1-element vector, hence unsqueeze(0).
        target = torch.tensor(label, dtype=torch.float32).unsqueeze(0)
        return first, second, target

def siamese_data_prep(
    all_file_paths,
    batch_size=64,
    num_groups_train=1000,
    num_groups_val=200,
    num_groups_test=200,
    num_traj=2,
    seed=42,
    num_workers=4,
):
    """
    Prepare train/val/test DataLoaders of trajectory pairs for Siamese training.

    Steps:
      1. Split FILES (whole simulations, not samples) into Train/Val/Test so the
         model is evaluated on simulations it has never seen.
      2. Call build_groups on each disjoint file set to create positive/negative
         groups of ``num_traj`` trajectories.
      3. Fit one global StandardScaler on the training values only and apply it
         to every split (a single shared scale keeps relative level differences,
         e.g. low vs high mean, between trajectories).
      4. Wrap each split in SiameseGroupDataset and a DataLoader.

    Args:
        all_file_paths: sequence of trajectory file paths handed to build_groups.
        batch_size: batch size used by all three loaders.
        num_groups_train, num_groups_val, num_groups_test: number of groups
            requested from build_groups for each split.
        num_traj: trajectories per group. SiameseGroupDataset only reads the
            first two columns of each group, so this should be 2.
        seed: random_state for both file splits and the seed given to build_groups.
        num_workers: worker processes per DataLoader (default 4).

    Returns:
        (train_loader, val_loader, test_loader); only the training loader shuffles.

    Raises:
        ValueError: if build_groups produced no training groups, because the
            global scaler cannot be fitted without training data.
    """
    # 1. Split files (simulations) to prevent leakage: the model must generalise
    #    to *new* simulations, not just new samples of known ones.
    train_files, test_files = train_test_split(all_file_paths, test_size=0.2, random_state=seed)
    train_files, val_files = train_test_split(train_files, test_size=0.2, random_state=seed)

    print(f"Files split: {len(train_files)} Train, {len(val_files)} Val, {len(test_files)} Test")

    # 2. Build groups (pairs when num_traj == 2) on the disjoint file sets.
    print("Building training pairs...")
    train_groups = build_groups(train_files, num_groups=num_groups_train, num_traj=num_traj, seed=seed)

    print("Building validation pairs...")
    val_groups = build_groups(val_files, num_groups=num_groups_val, num_traj=num_traj, seed=seed)

    print("Building test pairs...")
    test_groups = build_groups(test_files, num_groups=num_groups_test, num_traj=num_traj, seed=seed)

    # 3. Global scaling: fit on training values only, transform every split.
    if len(train_groups) == 0:
        raise ValueError(
            "build_groups returned no training groups; cannot fit the global scaler."
        )
    scaler = StandardScaler()
    # Every time point of every trajectory is one sample of a single 'value'
    # feature, hence the reshape to (-1, 1).
    train_stack = np.vstack([X for X, _ in train_groups])
    scaler.fit(train_stack.reshape(-1, 1))
    print("Global scaler fitted on training data.")

    def scale_groups(group_list):
        """Return new (X_scaled, y) pairs; X is flattened, scaled, and restored to its shape."""
        return [
            (scaler.transform(X.reshape(-1, 1)).reshape(X.shape), y)
            for X, y in group_list
        ]

    train_groups = scale_groups(train_groups)
    val_groups = scale_groups(val_groups)
    test_groups = scale_groups(test_groups)

    # 4. Wrap in datasets and loaders.
    train_ds = SiameseGroupDataset(train_groups)
    val_ds = SiameseGroupDataset(val_groups)
    test_ds = SiameseGroupDataset(test_groups)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader

In [None]:
# === Dataloader hyperparams & data prep ===
batch_size = 64
num_groups_train = 4000
num_groups_val = 800
num_groups_test = 800
num_traj = 2  # trajectories per group; 2 => (x1, x2) pairs
train_loader, val_loader, test_loader = siamese_data_prep(
    TRAJ_NPZ_PATH,
    batch_size=batch_size,
    num_groups_train=num_groups_train,  # number of training groups (pairs) to generate
    num_groups_val=num_groups_val,
    num_groups_test=num_groups_test,
    num_traj=num_traj,
)
# === Dataloader hyperparams & data prep ===
# === Dataloader hyperparams & data prep ===

V2: Sequence cropping (random windows for training, centre crop for validation/test)

In [15]:
class SiameseGroupDataset(Dataset):
    """
    Dataset of fixed-length time windows cut from (X, y) groups.

    Each item is ``(x1, x2, y)``: x1 and x2 are columns 0 and 1 of the window
    as float32 tensors of shape (crop_len, 1), and y is a float32 tensor of
    shape (1,).
    """

    def __init__(self, groups, crop_len=200, training=True):
        """
        groups: List of (X, y); X is indexed as (time, trajectory) with at least 2 columns.
        crop_len: Length of the time window returned for every item.
        training: If True, the window start is drawn uniformly at random
            (augmentation). If False, the centre window is used (deterministic).
        """
        self.groups = groups
        self.crop_len = crop_len
        self.training = training

    def __len__(self):
        return len(self.groups)

    def __getitem__(self, idx):
        X, y = self.groups[idx]
        seq_len = X.shape[0]

        if seq_len > self.crop_len:
            max_start = seq_len - self.crop_len  # last valid window start (inclusive)
            if self.training:
                # randint's upper bound is exclusive; +1 lets the final window
                # (ending at the last time point) be sampled too.
                # NOTE(review): uses NumPy's global RNG; with num_workers > 0,
                # confirm the installed PyTorch re-seeds NumPy in each worker.
                start = np.random.randint(0, max_start + 1)
            else:
                # Centre crop for validation/testing (deterministic).
                start = max_start // 2
            X_crop = X[start : start + self.crop_len, :]
        else:
            # Sequences not longer than crop_len are zero-padded at the end.
            X_crop = np.pad(X, ((0, self.crop_len - seq_len), (0, 0)), mode="constant")

        # Split into x1 and x2; column slices keep the trailing feature axis.
        return (
            torch.tensor(X_crop[:, 0:1], dtype=torch.float32),
            torch.tensor(X_crop[:, 1:2], dtype=torch.float32),
            torch.tensor(y, dtype=torch.float32).unsqueeze(0),
        )
        
def siamese_data_prep(
    all_file_paths,
    batch_size=64,
    num_groups_train=1000,
    num_groups_val=200,
    num_groups_test=200,
    seed=42,
    crop_len=200,
    num_workers=4,
):
    """
    Prepare cropped train/val/test DataLoaders of trajectory pairs for Siamese training.

    Files (whole simulations) are split into train/val/test. The test split is
    20% of all files; the validation split is 20% of the remainder. Pairs are
    built per split with build_groups, and a single StandardScaler is fitted on
    the training values only and applied to every split. Items are then
    cropped to ``crop_len`` by SiameseGroupDataset.

    Args:
        all_file_paths: trajectory file paths handed to build_groups.
        batch_size: batch size used by all three loaders.
        num_groups_train, num_groups_val, num_groups_test: number of groups
            requested from build_groups for each split.
        seed: random_state for the file splits and the seed given to build_groups.
        crop_len: window length for every item (default 200).
        num_workers: worker processes per DataLoader (default 4).

    Returns:
        (train_loader, val_loader, test_loader). Training items are randomly
        cropped and shuffled; val/test items use a deterministic centre crop.

    Raises:
        ValueError: if build_groups produced no training groups, because the
            global scaler cannot be fitted without training data.
    """
    # 1. Split files (simulations), not samples, to prevent leakage.
    train_files, test_files = train_test_split(all_file_paths, test_size=0.2, random_state=seed)
    train_files, val_files = train_test_split(train_files, test_size=0.2, random_state=seed)

    # 2. Build pairs at their raw (uncropped) length; cropping happens per item.
    print("Building pairs...")
    train_groups = build_groups(train_files, num_groups=num_groups_train, num_traj=2, seed=seed)
    val_groups = build_groups(val_files, num_groups=num_groups_val, num_traj=2, seed=seed)
    test_groups = build_groups(test_files, num_groups=num_groups_test, num_traj=2, seed=seed)

    # 3. Global scaling: fit on training values only (every time point is one
    #    sample of a single feature), then transform every split.
    if len(train_groups) == 0:
        raise ValueError(
            "build_groups returned no training groups; cannot fit the global scaler."
        )
    scaler = StandardScaler()
    scaler.fit(np.vstack([X for X, _ in train_groups]).reshape(-1, 1))

    def scale_groups(group_list):
        """Return new (X_scaled, y) pairs; X is flattened, scaled, and restored to its shape."""
        return [
            (scaler.transform(X.reshape(-1, 1)).reshape(X.shape), y)
            for X, y in group_list
        ]

    train_groups = scale_groups(train_groups)
    val_groups = scale_groups(val_groups)
    test_groups = scale_groups(test_groups)

    # 4. Datasets with cropping: random windows for training, centre crops otherwise.
    train_ds = SiameseGroupDataset(train_groups, crop_len=crop_len, training=True)
    val_ds = SiameseGroupDataset(val_groups, crop_len=crop_len, training=False)
    test_ds = SiameseGroupDataset(test_groups, crop_len=crop_len, training=False)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    test_loader = DataLoader(test_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    return train_loader, val_loader, test_loader

V3: Log-transform the data, then apply global scaling

In [12]:
# Load from modularised code
from utils.data_loader import SiameseGroupDataset, siamese_data_prep

In [6]:
# === Dataloader hyperparams & data prep ===
batch_size = 64
num_groups_train = 20000
# validation and test sets are each one tenth of the training set
num_groups_val = num_groups_test = num_groups_train // 10
num_traj = 2  # logged in the wandb config; not passed to siamese_data_prep here
train_loader, val_loader, test_loader = siamese_data_prep(
    TRAJ_NPZ_PATH,
    num_groups_train=num_groups_train,
    num_groups_val=num_groups_val,
    num_groups_test=num_groups_test,
    batch_size=batch_size,
)
# === Dataloader hyperparams & data prep ===
# === Dataloader hyperparams & data prep ===

Files split: 655 Train, 164 Val, 205 Test
Generating 20000 training pairs...


Building positive groups: 100%|██████████| 10000/10000 [00:00<00:00, 24483.25it/s]
Building negative groups: 100%|██████████| 10000/10000 [00:00<00:00, 21469.35it/s]


Generating validation/test pairs...


Building positive groups: 100%|██████████| 1000/1000 [00:00<00:00, 25987.34it/s]
Building negative groups: 100%|██████████| 1000/1000 [00:00<00:00, 20268.70it/s]




Building positive groups: 100%|██████████| 1000/1000 [00:00<00:00, 26093.72it/s]
Building negative groups: 100%|██████████| 1000/1000 [00:00<00:00, 20918.18it/s]


Fitting scaler on Log-Transformed training data...
Applying Log-Scaling to all groups...


In [7]:
# === Dataloader hyperparams & data prep ===
# batch_size = 64
# train_loader, val_loader, test_loader = siamese_data_prep(groups, batch_size)
# === Dataloader hyperparams & data prep ===

In [8]:
# Sanity-check the shapes of one training batch.
X1_b, X2_b, y_b = next(iter(train_loader))
print(X1_b.shape, X2_b.shape, y_b.shape)
# expected: [B, T, 1], [B, T, 1], [B, 1]  (targets carry a trailing singleton dim for BCE)

torch.Size([64, 200, 1]) torch.Size([64, 200, 1]) torch.Size([64, 1])


In [9]:
# === Model hyperparams ===
input_size = 1
num_classes = 2  # not passed to SiameseTransformer below; kept for the wandb config
d_model = 64
nhead = 4
num_layers = 2
dropout = 0.001
use_conv1d = False

model = SiameseTransformer(
    input_size=input_size,  # each trajectory is (T, 1)
    d_model=d_model,
    nhead=nhead,
    num_layers=num_layers,
    dropout=dropout,
    use_conv1d=use_conv1d,
)
# === Model hyperparams ===

# === Training hyperparams ===
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Move the model before constructing the optimizer (torch.optim docs recommend
# this) so the optimizer is built over the parameters in their final placement.
model.to(device)

epochs = 100
patience = 10
lr = 1e-2
optimizer = optim.Adam(model.parameters(), lr=lr)

### schedulers ###
# 1. simple scheduler choice: halve the LR after 3 epochs without improvement of a
#    maximised metric (mode='max'; presumably validation accuracy — confirm in train_siamese_model)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, factor=0.5)

# 2. cosine scheduler with warmup, most commonly used for transformer
# total_steps = epochs * len(train_loader)
# warmup_steps = int(0.1 * total_steps)   # 10% warmup (good default)
# #  (from huggingface)
# from transformers import get_cosine_schedule_with_warmup
# scheduler = get_cosine_schedule_with_warmup(
#     optimizer,
#     num_warmup_steps=warmup_steps,
#     num_training_steps=total_steps,
# )

loss_fn = nn.BCEWithLogitsLoss()
grad_clip = 1.0
save_path = None
verbose = True

model  # last expression of the cell: displays the architecture
# === Training hyperparams ===
# === Training hyperparams ===

SiameseTransformer(
  (backbone): TransformerClassifier(
    (input_proj): Linear(in_features=1, out_features=64, bias=True)
    (pe): PositionalEncoding()
    (encoder): TransformerEncoder(
      (layers): ModuleList(
        (0-1): 2 x TransformerEncoderLayer(
          (self_attn): MultiheadAttention(
            (out_proj): NonDynamicallyQuantizableLinear(in_features=64, out_features=64, bias=True)
          )
          (linear1): Linear(in_features=64, out_features=256, bias=True)
          (dropout): Dropout(p=0.001, inplace=False)
          (linear2): Linear(in_features=256, out_features=64, bias=True)
          (norm1): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (norm2): LayerNorm((64,), eps=1e-05, elementwise_affine=True)
          (dropout1): Dropout(p=0.001, inplace=False)
          (dropout2): Dropout(p=0.001, inplace=False)
        )
      )
    )
    (dropout): Dropout(p=0.001, inplace=False)
    (head): Linear(in_features=64, out_features=2, bias=True

In [10]:
# === wandb config (required for tracking within train_model) ===
wandb_config = {
    "entity": "grignard-reagent",
    "project": "IY011-contrastive-learning",
    "name": f"siamese_logtransform-feature-diff_num_train_groups_{num_groups_train}",  # change this to what you want
    "dataset": DATA_ROOT.name,
    # single batch_size entry, read from the loader so it matches the batches actually produced
    "batch_size": train_loader.batch_size,
    "input_size": input_size,
    "d_model": d_model,
    "nhead": nhead,
    "num_layers": num_layers,
    "num_classes": num_classes,
    "dropout": dropout,
    "use_conv1d": use_conv1d,
    "epochs": epochs,
    "patience": patience,
    "lr": lr,
    "optimizer": type(optimizer).__name__,
    "scheduler": type(scheduler).__name__,
    "loss_fn": type(loss_fn).__name__,
    "model": type(model).__name__,
    "num_traj_per_group": num_traj,
    "num_groups_train": num_groups_train,
    "num_groups_val": num_groups_val,
    "num_groups_test": num_groups_test,
}
# === wandb config === 

In [11]:
from training.train import train_siamese_model

history = train_siamese_model(
    model,
    train_loader,
    val_loader,
    epochs=epochs,
    patience=patience,
    lr=lr,
    optimizer=optimizer,
    scheduler=scheduler,
    loss_fn=loss_fn,
    device=device,
    grad_clip=grad_clip,
    save_path=save_path,
    verbose=verbose,
    wandb_logging=True,  # enables wandb logging inside train_siamese_model
    wandb_config=wandb_config,  # run config dictionary built in the previous cell
)

[34m[1mwandb[0m: Currently logged in as: [33mgrignardreagent[0m ([33mgrignard-reagent[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Starting siamese training...
[Siamese] Epoch [1/100] | train_loss 0.6533 | train_acc 0.6273 | val_loss 0.6009 | val_acc 0.6540
[Siamese] Epoch [2/100] | train_loss 0.5957 | train_acc 0.6818 | val_loss 0.5841 | val_acc 0.7005
[Siamese] Epoch [3/100] | train_loss 0.5613 | train_acc 0.7043 | val_loss 0.5263 | val_acc 0.7305
[Siamese] Epoch [4/100] | train_loss 0.5327 | train_acc 0.7313 | val_loss 0.5339 | val_acc 0.7345
[Siamese] Epoch [5/100] | train_loss 0.5119 | train_acc 0.7441 | val_loss 0.5118 | val_acc 0.7390
[Siamese] Epoch [6/100] | train_loss 0.4489 | train_acc 0.7864 | val_loss 0.4371 | val_acc 0.7900
No improvement (1/10).
[Siamese] Epoch [7/100] | train_loss 0.4273 | train_acc 0.7992 | val_loss 0.4504 | val_acc 0.7785
[Siamese] Epoch [8/100] | train_loss 0.4171 | train_acc 0.8054 | val_loss 0.3807 | val_acc 0.8235
No improvement (1/10).
[Siamese] Epoch [9/100] | train_loss 0.4006 | train_acc 0.8165 | val_loss 0.4123 | val_acc 0.8205
[Siamese] Epoch [10/100] | train_loss 0.364

0,1
epoch,‚ñÅ‚ñÅ‚ñÅ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà
grad/norm,‚ñÉ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñÖ‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñá‚ñà‚ñà‚ñÅ‚ñà‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà
lr,‚ñà‚ñà‚ñà‚ñà‚ñÑ‚ñÑ‚ñÑ‚ñÑ‚ñÉ‚ñÉ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
train/acc,‚ñÅ‚ñÇ‚ñÉ‚ñÑ‚ñÑ‚ñÖ‚ñÜ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
train/loss,‚ñà‚ñá‚ñÜ‚ñÜ‚ñÖ‚ñÑ‚ñÑ‚ñÑ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ
val/acc,‚ñÅ‚ñÇ‚ñÉ‚ñÉ‚ñÉ‚ñÖ‚ñÖ‚ñÜ‚ñÜ‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñá‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà‚ñà
val/loss,‚ñà‚ñà‚ñÜ‚ñá‚ñÜ‚ñÖ‚ñÖ‚ñÉ‚ñÑ‚ñÉ‚ñÉ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÇ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ‚ñÅ

0,1
best_val_acc,0.895
epoch,36.0
grad/norm,1.0
lr,4e-05
train/acc,0.89075
train/loss,0.26534
training_time_sec,268.21944
val/acc,0.8935
val/loss,0.26952


Siamese training complete.


In [None]:
#TODO train test split, scaling up 

