In [18]:
"""
DTOF Deep Learning Pipeline (DL_Full_Pipeline)

This module implements a complete, modular deep learning pipeline for predicting
the optical properties (μa, μs′) of homogeneous tissue models from time-resolved
DTOF (Distribution of Time-of-Flight) signals.

The pipeline is designed for reproducible experimentation and supports flexible
preprocessing, multi-channel DTOF representations, log-space regression, and
systematic neural architecture comparisons (e.g. baseline vs dilated CNNs).

-------------------------------------------------------------------------------
Applicability
-------------------------------------------------------------------------------
This framework is intended for studies involving:
    • Monte Carlo simulated DTOFs (homogeneous tissue models)
    • Inversion of DTOFs to estimate optical properties (μa, μs′)
    • Benchmarking preprocessing choices (e.g. Savitzky–Golay filtering)
    • Deep learning architecture optimisation (kernel size, dilation, channels)
    • Comparison of single-, multi-, and hybrid-channel DTOF representations
    • Analysis of long-range temporal sensitivity using dilated convolutions

-------------------------------------------------------------------------------
Supported Input Format
-------------------------------------------------------------------------------
Inputs are provided as MATLAB v7.3 (.mat, HDF5) files containing:
    • t : time axis (T,)
    • X : DTOF signals (N, T)
    • y : labels (N, 2) = [μa, μs′]

Data are loaded using h5py and converted to NumPy/PyTorch tensors with
appropriate MATLAB → NumPy orientation correction.

-------------------------------------------------------------------------------
Supported Model Variants
-------------------------------------------------------------------------------
The pipeline supports dynamic selection of DTOF channel representations.
Channel configuration is selected from the configuration dictionary and handled
automatically by DTOFDataset.

    (1) Single-channel DTOF:
        • Cropped to a fixed temporal window (e.g. 0–5 ns)
        • Savitzky–Golay smoothed and clipped
        • Input shape: (1, T)

    (2) Three-channel temporal bin model:
        • Early, mid, late temporal masks
          (e.g. 0–0.5 ns, 0.5–4 ns, 4–5 ns)
        • Each mask multiplied with the DTOF
        • Input shape: (3, T)

    (3) Four-channel hybrid model:
        • Channel 1: full DTOF
        • Channels 2–4: early / mid / late temporal bins
        • Input shape: (4, T)

-------------------------------------------------------------------------------
Pipeline Components
-------------------------------------------------------------------------------
1. Configuration Dictionary
   Centralises all experiment-level parameters, including:
        • preprocessing settings (SG window/order, clipping, cropping)
        • channel representation mode
        • CNN architecture parameters (kernel sizes, dilation, channels)
        • optimisation settings (learning rate, batch size, epochs)
        • early stopping and checkpointing controls

2. DTOFDataset
   Implements the full preprocessing chain:
        • MATLAB v7.3 (.mat) loading via h5py
        • time-axis cropping
        • Savitzky–Golay smoothing
        • negative-value clipping
        • optional per-trace normalisation
        • dynamic construction of 1, 3, or 4 input channels
        • returns (signal, label) pairs where label = [μa, μs′]

3. CNN Model (Net)
        • 1D convolutional architecture
        • optional use of dilated convolutions for enlarged receptive fields
        • batch normalisation, ReLU activations, and pooling
        • automatic flatten-dimension inference
        • regression head outputs log(μa), log(μs′)

4. Training Loop
        • log-space MSE optimisation
        • linear-space RMSE and mean absolute percentage error reporting
        • GPU/CPU device handling
        • early stopping based on validation loss
        • best-model checkpoint saving
        • automatic plotting of training and validation loss curves (PNG)

5. Evaluation Utilities
        • conversion from log-space predictions to linear units
        • MAE, RMSE, and percentage error computation
        • prediction vs. ground-truth arrays for downstream analysis

-------------------------------------------------------------------------------
Outputs
-------------------------------------------------------------------------------
The pipeline produces:
    • trained model checkpoints (.pt / .pth)
    • training and validation loss curves (.png)
    • RMSE / MAE / percentage error metrics for μa and μs′
    • prediction arrays for further visualisation or analysis

-------------------------------------------------------------------------------
Dependencies
-------------------------------------------------------------------------------
Required:
    • torch
    • numpy
    • scipy (Savitzky–Golay filtering)
    • h5py (MATLAB v7.3 data loading)
    • matplotlib (loss curve plotting)

Optional:
    • seaborn (enhanced visualisation)

-------------------------------------------------------------------------------
Notes
-------------------------------------------------------------------------------
• Models are trained in log-space for numerical stability and scale balancing.
• Metrics are always reported in original (linear) physical units.
• The pipeline is designed for controlled benchmarking of architectural and
  preprocessing choices in DTOF inversion tasks.
"""

'\nDTOF Deep Learning Pipeline (DL_Full_Pipeline)\n\nThis module implements a complete, modular deep learning pipeline for predicting\nthe optical properties (μa, μs′) of homogeneous tissue models from time-resolved\nDTOF (Distribution of Time-of-Flight) signals.\n\nThe pipeline is designed for reproducible experimentation and supports flexible\npreprocessing, multi-channel DTOF representations, log-space regression, and\nsystematic neural architecture comparisons (e.g. baseline vs dilated CNNs).\n\n-------------------------------------------------------------------------------\nApplicability\n-------------------------------------------------------------------------------\nThis framework is intended for studies involving:\n    • Monte Carlo simulated DTOFs (homogeneous tissue models)\n    • Inversion of DTOFs to estimate optical properties (μa, μs′)\n    • Benchmarking preprocessing choices (e.g. Savitzky–Golay filtering)\n    • Deep learning architecture optimisation (kernel size, dil

In [None]:
import os
import json
import re
from datetime import datetime

import numpy as np
import pandas as pd
from scipy.signal import savgol_filter
from scipy.optimize import curve_fit

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

import matplotlib.pyplot as plt

from pathlib import Path
from typing import List, Dict, Any, Iterable

import h5py

In [11]:
CONFIG = {
    # Input (.mat struct containing t, X, y)
    "mat_path": "/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/Year 3/Research Project in Biomedical Engineering/Code/Pre-obtained data/dataset_homo_small.mat",

    # Field names inside the .mat struct (safe + explicit)
    # If your .mat loads as {"data": <struct>}, then mat_struct_key="data"
    # If t/X/y are top-level, mat_struct_key can be None.
    "mat_struct_key": None,     # or e.g. "dtof" / "data"
    "t_key": "t",
    "X_key": "X",
    "y_key": "y",

    # Preprocessing
    "sg_window": 21,
    "sg_order": 1,
    "eps": 1e-8,
    "crop_t_max": 5.0,          # crop DTOF to 0–5 ns using t

    # Channels: "single", "early_mid_late", "hybrid_4ch"
    "channel_mode": "single",

    # Data split + loader
    "train_frac": 0.8,
    "batch_size": 32,

    # Model
    "in_channels": 1,           # will be overwritten based on dataset.C
    "output_dim": 2,            # mua + mus'
    "hidden_fc": [128, 64],

    # Training
    "lr": 1e-3,
    "epochs": 20,

    # Outputs / logging
    "run_name": "exp_single_w21_o1",
    "save_dir": "JSON logs",
    "log_path": "JSON logs/dtof_runs_log.json",
}

In [15]:
# 2. DTOFDataset: preprocessing + channel construction 

class DTOFDataset(Dataset):
    """
    DTOF dataset with preprocessing and flexible channel configurations.

    Input format (MATLAB .mat):
        - t: (T_full,) time axis
        - X: (N, T_full) DTOF traces
        - y: (N, 2) labels: y[:,0]=mua, y[:,1]=mus'

    Preprocessing:
        - load .mat
        - crop time axis to [0, crop_t_max]
        - Savitzky-Golay smoothing
        - negative-value clipping
        - per-trace standardisation (currently computed, not applied, matching your code)
        - channel construction:
            * "single"        -> 1 channel, full DTOF
            * "early_mid_late"-> 3 channels (early/mid/late masks)
            * "hybrid_4ch"    -> 4 channels (full + 3 masks)

    Returns:
        signal: (C, T) tensor
        label:  (2,) tensor [mua, mus']
    """

    def __init__(self, mat_path: str, cfg: dict):
        super().__init__()
        self.cfg = cfg

        # --- Load .mat ---
        with h5py.File(mat_path, "r") as f:
            # NOTE: MATLAB stores arrays column-major and often transposed
            t = np.array(f[cfg.get("t_key", "t")]).squeeze()
            X = np.array(f[cfg.get("X_key", "X")])
            y = np.array(f[cfg.get("y_key", "y")])
        
        # MATLAB → NumPy orientation fixes
        if X.shape[0] == t.shape[0]:
            X = X.T   # (T, N) → (N, T)

        if y.shape[0] == 2:
            y = y.T   # (2, N) → (N, 2)

        # --- Crop time axis ---
        t_mask = (t >= 0.0) & (t <= cfg["crop_t_max"])
        time = t[t_mask]            # (T,)
        dtof = X[:, t_mask]         # (N, T)
        labels = y                  # (N, 2)

        # --- Savitzky–Golay smoothing ---
        dtof_smooth = savgol_filter(
            dtof,
            cfg["sg_window"],
            cfg["sg_order"],
            axis=1
        )

        # --- Clip negatives and standardise ---
        eps = cfg["eps"]
        dtof_smooth[dtof_smooth < 0] = eps

        mean = dtof_smooth.mean(axis=1, keepdims=True)
        std = dtof_smooth.std(axis=1, keepdims=True)
        # dtof_std = (dtof_smooth - mean) / (std + eps)  # (N, T)

        # --- Build channels ---
        channels = self.build_channels(time, dtof_smooth, cfg["channel_mode"])
        # channels: (N, C, T)

        self.signals = torch.tensor(channels, dtype=torch.float32)  # (N, C, T)
        self.labels = torch.tensor(labels, dtype=torch.float32)     # (N, 2)

        self.N, self.C, self.T = self.signals.shape

    def build_channels(self, t: np.ndarray, dtof: np.ndarray, mode: str) -> np.ndarray:
        """
        Construct channels based on the chosen mode:
            "single"         -> 1 channel, full DTOF
            "early_mid_late" -> 3 masked channels
            "hybrid_4ch"     -> 1 full + 3 masked = 4 channels
        """
        N, T = dtof.shape

        if mode == "single":
            return dtof[:, None, :]  # (N,1,T)

        early = ((t >= 0.0) & (t < 0.5)).astype(float)
        mid  = ((t >= 0.5) & (t < 4.0)).astype(float)
        late = ((t >= 4.0) & (t <= self.cfg["crop_t_max"])).astype(float)
        masks = np.stack([early, mid, late], axis=0)  # (3,T)

        if mode == "early_mid_late":
            return dtof[:, None, :] * masks[None, :, :]  # (N,3,T)

        if mode == "hybrid_4ch":
            ch_full = dtof[:, None, :]                    # (N,1,T)
            ch_bins = dtof[:, None, :] * masks[None, :, :]  # (N,3,T)
            return np.concatenate([ch_full, ch_bins], axis=1)  # (N,4,T)

        raise ValueError(f"Unknown channel_mode: {mode}")

    def __len__(self) -> int:
        return self.N

    def __getitem__(self, idx: int):
        return self.signals[idx], self.labels[idx] 

In [4]:
# 3. CNN Model (Net) with flexible in_channels and FC head 

class Net(nn.Module):
    """
    1D CNN for DTOF-based regression to (μa, μs').

    - Supports variable input channels (C) from CONFIG["in_channels"]
    - Uses three conv blocks with BatchNorm, ReLU, MaxPool
    - Automatically computes flatten dimension
    - FC layers controlled by CONFIG["hidden_fc"]
    """

    def __init__(self, cfg: dict, input_length: int):

        super().__init__()

        C = cfg["in_channels"]

        self.conv1 = nn.Sequential(
            nn.Conv1d(C, 32, kernel_size=3, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv1d(32, 32, kernel_size=5, padding=2),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2)
        )

        self.conv3 = nn.Sequential(
            nn.Conv1d(32, 16, kernel_size=7, padding=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.MaxPool1d(2)
        )

        # Determine flatten dim automatically
        with torch.no_grad():
            dummy = torch.zeros(1, C, input_length)  # (1,C,T)
            feat = self._forward_features(dummy)
            flatten_dim = feat.shape[1]  # (1, flatten_dim)

        # Build FC head from cfg["hidden_fc"]
        fc_layers = []
        last = flatten_dim
        for h in cfg["hidden_fc"]:
            fc_layers += [nn.Linear(last, h), nn.ReLU()]
            last = h

        fc_layers.append(nn.Linear(last, cfg["output_dim"]))
        self.fc = nn.Sequential(*fc_layers)

    def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.flatten(1)  # (batch, -1)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._forward_features(x)
        x = self.fc(x)
        return x

In [5]:
# 4. Training loop with loss tracking, prediction inspection & plotting 

def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss_fn: nn.Module,
    optimizer: torch.optim.Optimizer,
    cfg: dict,
    device: torch.device
):
    num_epochs = cfg["epochs"]
    save_dir   = cfg["save_dir"]
    run_name   = cfg["run_name"]

    os.makedirs(save_dir, exist_ok=True)

    train_losses = []
    val_losses   = []
    best_val     = float("inf")
    best_path    = os.path.join(save_dir, f"{run_name}_best.pth")

    for epoch in range(num_epochs):
        # ----- TRAIN -----
        model.train()
        running_train = 0.0
        n_train = 0

        for signals, labels in train_loader:
            signals = signals.to(device)
            labels  = labels.to(device).float()
            B = labels.size(0)

            optimizer.zero_grad()
            preds = model(signals)
            loss  = loss_fn(preds, labels)
            loss.backward()
            optimizer.step()

            running_train += loss.item() * B
            n_train += B

        epoch_train = running_train / max(1, n_train)
        train_losses.append(epoch_train)

        # ----- VALIDATION -----
        model.eval()
        running_val = 0.0
        n_val = 0

        with torch.no_grad():
            for batch_idx, (signals, labels) in enumerate(val_loader):
                signals = signals.to(device)
                labels  = labels.to(device).float()
                B = labels.size(0)

                preds = model(signals)
                loss  = loss_fn(preds, labels)
                running_val += loss.item() * B
                n_val += B

                if epoch == 0 and batch_idx == 0:
                    print("\n[VAL SAMPLE] First batch predictions vs labels:")
                    print("Preds (μa, μs'):\n", preds[:5].cpu())
                    print("Labels (μa, μs'):\n", labels[:5].cpu())
                    print("Abs error:\n", (preds[:5] - labels[:5]).abs().cpu())
                    print("---------------------------------------------------")

        epoch_val = running_val / max(1, n_val)
        val_losses.append(epoch_val)

        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train MSE: {epoch_train:.4f} | Val MSE: {epoch_val:.4f}")

        if epoch_val < best_val:
            best_val = epoch_val
            torch.save(model.state_dict(), best_path)

    # ----- Plot -----
    fig_path = os.path.join(save_dir, f"{run_name}_loss_curves.png")
    plt.figure()
    plt.plot(train_losses, label="Train MSE")
    plt.plot(val_losses,   label="Val MSE")
    plt.xlabel("Epoch")
    plt.ylabel("MSE loss")
    plt.title(f"MSE vs. Epoch: {run_name}")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(fig_path, dpi=150)
    plt.close()

    return {
        "train_losses": train_losses,
        "val_losses":   val_losses,
        "best_val":     best_val,
        "best_path":    best_path,
        "loss_plot":    fig_path,
    }


In [6]:
# 5. Model Evaluation and visualisation of performance (MAE / RMSE, percentage error)
class ModelEvaluator:
    """
    Evaluate a trained model on a DataLoader and compute MAE/RMSE for (μa, μs'),
    plus percentage error plots vs actual values.
    """

    def __init__(self, model: nn.Module, device: torch.device):
        self.model = model
        self.device = device
        self.model.to(device)
        self.model.eval()

    def evaluate(self, data_loader: DataLoader, cfg: dict):
        all_preds  = []
        all_labels = []

        with torch.no_grad():
            for signals, labels in data_loader:
                signals = signals.to(self.device)
                labels  = labels.to(self.device).float()

                preds = self.model(signals)

                all_preds.append(preds.cpu())
                all_labels.append(labels.cpu())

        all_preds  = torch.cat(all_preds,  dim=0)  # (N,2)
        all_labels = torch.cat(all_labels, dim=0)  # (N,2)

        # Basic metrics: MAE and RMSE
        abs_err = torch.abs(all_preds - all_labels)
        sq_err  = (all_preds - all_labels) ** 2

        mae  = abs_err.mean(dim=0)             # (2,)
        rmse = torch.sqrt(sq_err.mean(dim=0))  # (2,)

        # ---------- Percentage error vs Actual ----------
        preds_np  = all_preds.detach().numpy()
        labels_np = all_labels.detach().numpy()

        # Avoid division by very small numbers
        eps = cfg.get("eps", 1e-12)
        denom = np.maximum(np.abs(labels_np), eps)  # (N,2)

        pct_error = 100.0 * (preds_np - labels_np) / denom  # signed %
        abs_pct_error = np.abs(pct_error)                   # absolute %

        # Scatter plots: Actual vs % error for μa and μs′
        save_dir = cfg["save_dir"]
        run_name = cfg["run_name"]
        os.makedirs(save_dir, exist_ok=True)

        # error vs true μa plot
        fig_mua = os.path.join(save_dir, f"{run_name}_pct_error_mua.png")
        x = labels_np[:, 0]            # true μa
        y = abs_pct_error[:, 0]        # absolute percentage error

        plt.figure()
        plt.scatter(x, y, s=10, alpha=0.6, label="Absolute % error")

        # Exponential model
        def _exp_model(x, a, b, c):
            return a * np.exp(-b * x) + c

        # Fit curve
        try:
            popt, _ = curve_fit(_exp_model, x, y, p0=[100, 1.0, 0.0], maxfev=5000)
            x_fit = np.linspace(min(x), max(x), 300)
            y_fit = _exp_model(x_fit, *popt)
            plt.plot(
                x_fit, y_fit, "r-", linewidth=2,
                label=f"Fit: a·exp(-b·x) + c\n"
                    f"a={popt[0]:.2f}, b={popt[1]:.2f}, c={popt[2]:.2f}"
            )
        except Exception as e:
            print("[WARN] μa exponential fit failed:", e)

        plt.xlabel("True μa")
        plt.ylabel("Absolute % error")
        plt.title(f"Percentage error vs Actual μa: {run_name}")
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.savefig(fig_mua, dpi=150)
        plt.close()

        # error vs true μs′ plot
        fig_mus = os.path.join(save_dir, f"{run_name}_pct_error_mus.png")
        x = labels_np[:, 1]            # true μs'
        y = abs_pct_error[:, 1]        # absolute percentage error

        plt.figure()
        plt.scatter(x, y, s=10, alpha=0.6, label="Absolute % error")

        # Exponential model
        def _exp_model(x, a, b, c):
            return a * np.exp(-b * x) + c

        # Fit curve
        try:
            popt, _ = curve_fit(_exp_model, x, y, p0=[100, 0.5, 0.0], maxfev=5000)
            x_fit = np.linspace(min(x), max(x), 300)
            y_fit = _exp_model(x_fit, *popt)
            plt.plot(
                x_fit, y_fit, "r-", linewidth=2,
                label=f"Fit: a·e^(-b·x) + c\n"
                    f"a={popt[0]:.2f}, b={popt[1]:.2f}, c={popt[2]:.2f}"
            )
        except Exception as e:
            print("[WARN] μs' exponential fit failed:", e)

        plt.xlabel("True μs'")
        plt.ylabel("Absolute % error")
        plt.title(f"Percentage error vs Actual μs': {run_name}")
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.savefig(fig_mus, dpi=150)
        plt.close()

        metrics = {
            "MAE": mae.numpy(),          # [MAE_mua, MAE_mus]
            "RMSE": rmse.numpy(),        # [RMSE_mua, RMSE_mus]
            "preds": preds_np,
            "labels": labels_np,
            "pct_error": pct_error,      # signed %
            "abs_pct_error": abs_pct_error,
            "pct_error_plots": {
                "mua": fig_mua,
                "mus": fig_mus,
            }
        }
        return metrics


In [7]:
# 6. Logging of runs (config + metrics) to JSON 

def log_run(cfg, results, log_path):
    """
    Append a run entry to a JSON log file, including config and metrics.
    """
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    # Handle both old ('best_path') and new ('model_path') keys safely
    model_path = results.get("model_path", results.get("best_path", None))

    entry = {
    "timestamp": datetime.now().isoformat(timespec="seconds"),
    "run_name": cfg["run_name"],

    # ---- DATA SOURCE ----
    "data_type": "mat",
    "mat_path": cfg.get("mat_path", None),
    "mat_struct_key": cfg.get("mat_struct_key", None),
    "t_key": cfg.get("t_key", "t"),
    "X_key": cfg.get("X_key", "X"),
    "y_key": cfg.get("y_key", "y"),

    # ---- PREPROCESSING / MODEL ----
    "sg_window": cfg["sg_window"],
    "sg_order": cfg["sg_order"],
    "channel_mode": cfg["channel_mode"],
    "crop_t_max": cfg["crop_t_max"],
    "hidden_fc": cfg["hidden_fc"],
    "lr": cfg["lr"],
    "epochs": cfg["epochs"],
    "batch_size": cfg["batch_size"],
    "best_val": results["best_val"],
    "loss_plot": results["loss_plot"],
    "model_path": model_path,
}

    # to store MAE / RMSE 
    if "val_metrics" in results:
        entry["MAE"] = results["val_metrics"]["MAE"].tolist()
        entry["RMSE"] = results["val_metrics"]["RMSE"].tolist()
        entry["pct_error_plots"] = results["val_metrics"]["pct_error_plots"]

    if os.path.exists(log_path):
        with open(log_path, "r") as f:
            data = json.load(f)
        data.append(entry)
    else:
        data = [entry]

    with open(log_path, "w") as f:
        json.dump(data, f, indent=2)




In [40]:
def get_in_channels(mode: str) -> int:
    """
    Returns number of input channels for the CNN depending on the preprocessing mode.

    mode options:
        "single"          -> 1 channel  (raw DTOF only)
        "early_mid_late"  -> 3 channels (masked temporal bins)
        "hybrid_4ch"      -> 4 channels (raw + 3 temporal bins)
    """
    if mode == "single":
        return 1
    elif mode == "early_mid_late":
        return 3
    elif mode == "hybrid_4ch":
        return 4
    else:
        raise ValueError(f"Unknown channel_mode: {mode}")


In [8]:
# 7. Single experiment runner

def run_experiment(cfg):

    cfg = dict(cfg)

    # 1. Build dataset + loaders  (MAT-based)
    dataset = DTOFDataset(
        mat_path=cfg["mat_path"],
        cfg=cfg,
    )

    n_train = int(cfg["train_frac"] * len(dataset))
    n_val = len(dataset) - n_train
    train_dataset, val_dataset = random_split(dataset, [n_train, n_val])

    train_loader = DataLoader(train_dataset, batch_size=cfg["batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=cfg["batch_size"], shuffle=False)

    # 2. Build model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Derive in_channels from the dataset (most reliable)
    cfg["in_channels"] = dataset.C

    model = Net(
        cfg=cfg,
        input_length=dataset.T,
    ).to(device)

    # Loss + optimizer
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])

    # 3. Training Loop
    best_val = float("inf")
    save_path = None

    train_losses = []
    val_losses = []

    os.makedirs("Best paths", exist_ok=True)
    os.makedirs("Model evaluation figs", exist_ok=True)

    for epoch in range(cfg["epochs"]):
        model.train()
        running_loss = 0.0

        for signals, labels in train_loader:
            signals = signals.to(device)
            labels = labels.to(device).float()

            optimizer.zero_grad()
            preds = model(signals)
            loss = loss_fn(preds, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        train_loss = running_loss / len(train_loader)
        train_losses.append(train_loss)

        # ---- Validation ----
        model.eval()
        val_running = 0.0
        with torch.no_grad():
            for signals, labels in val_loader:
                signals = signals.to(device)
                labels = labels.to(device).float()

                preds = model(signals)
                val_running += loss_fn(preds, labels).item()

        val_loss = val_running / len(val_loader)
        val_losses.append(val_loss)

        # ---- Save best model ----
        if val_loss < best_val:
            best_val = val_loss
            save_path = f"Best paths/{cfg['run_name']}_best.pth"
            torch.save(model.state_dict(), save_path)

    # ---- Plot curves for this run ----
    fig_path = f"Model evaluation figs/{cfg['run_name']}_loss_curves.png"
    plt.figure()
    plt.plot(train_losses, label="Train MSE")
    plt.plot(val_losses, label="Val MSE")
    plt.xlabel("Epoch")
    plt.ylabel("MSE Loss")
    plt.title(f"Loss Curves: {cfg['run_name']}")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(fig_path, dpi=150)
    plt.close()

    # ---- Evaluate best model on validation set with ModelEvaluator ----
    best_model = Net(cfg=cfg, input_length=dataset.T).to(device)

    if save_path is not None:
        best_model.load_state_dict(torch.load(save_path, map_location=device))

    evaluator = ModelEvaluator(model=best_model, device=device)

    eval_cfg = {
        "run_name": cfg["run_name"],
        "save_dir": cfg.get("eval_save_dir", "Model evaluation figs"),
        "eps": cfg.get("eps", 1e-8),
    }

    val_metrics = evaluator.evaluate(val_loader, cfg=eval_cfg)

    return {
        "run_name": cfg["run_name"],
        "best_val": best_val,
        "model_path": save_path,
        "cfg": cfg,
        "train_losses": train_losses,
        "val_losses": val_losses,
        "loss_plot": fig_path,
        "val_metrics": val_metrics,
    }

In [9]:
# 8. Grid search over a few configurations 

def grid_search():
    base_cfg = CONFIG

    windows = [11, 21, 31, 41]
    orders  = [1,  2,  3,  4]
    modes   = ["single", "early_mid_late", "hybrid_4ch"]
    lrs     = [1e-3, 3e-4]

    results = []

    for w in windows:
        for o in orders:
            if w <= o:
                continue  # SG requires window_length > polyorder
            for mode in modes:
                for lr in lrs:
                    cfg = dict(base_cfg)
                    cfg["sg_window"]    = w
                    cfg["sg_order"]     = o
                    cfg["channel_mode"] = mode
                    cfg["lr"]           = lr

                    lr_tag = f"{lr:.0e}"
                    cfg["run_name"] = f"{mode}_w{w}_o{o}_lr{lr_tag}"

                    print(f"\n=== Running {cfg['run_name']} ===")
                    res = run_experiment(cfg)

                    results.append(res)

                    # log run to JSON (should now include mat_path in log_run)
                    log_run(cfg, res, cfg["log_path"])

    results_sorted = sorted(results, key=lambda x: x["best_val"])

    print("\n==================== GRID SEARCH SUMMARY ====================")
    for r in results_sorted[:10]:
        best_val = r["best_val"]
        mae      = r["val_metrics"]["MAE"]
        rmse     = r["val_metrics"]["RMSE"]

        print(
            f"{r['run_name']} | "
            f"val_MSE={best_val:.6f} | "
            f"MAE=(μa={mae[0]:.4f}, μs'={mae[1]:.4f}) | "
            f"RMSE=(μa={rmse[0]:.4f}, μs'={rmse[1]:.4f})"
        )

    best = results_sorted[0]
    print("\nBEST MODEL (by val MSE):")
    print(best["run_name"])
    print(f"  val_MSE = {best['best_val']:.6f}")
    print(f"  model_path = {best['model_path']}")
    print(f"  loss_plot  = {best['loss_plot']}")
    print(f"  MAE  = {best['val_metrics']['MAE']}")
    print(f"  RMSE = {best['val_metrics']['RMSE']}")

    return results_sorted


In [16]:
# Single experiment run 
if __name__ == "__main__":
    run_info = run_experiment(CONFIG)
    print(run_info["run_name"], run_info["best_val"])

exp_single_w21_o1 12.582507729530334


In [17]:
# Run grid search 
if __name__ == "__main__": 
    results = grid_search()

    # print the best configuration 
    best = min(results, key=lambda r: r["best_val"])
    print("\nBEST MODEL FROM GRID SEARCH:")
    print(best["run_name"], best["best_val"])


=== Running single_w11_o1_lr1e-03 ===

=== Running single_w11_o1_lr3e-04 ===

=== Running early_mid_late_w11_o1_lr1e-03 ===

=== Running early_mid_late_w11_o1_lr3e-04 ===

=== Running hybrid_4ch_w11_o1_lr1e-03 ===

=== Running hybrid_4ch_w11_o1_lr3e-04 ===

=== Running single_w11_o2_lr1e-03 ===

=== Running single_w11_o2_lr3e-04 ===

=== Running early_mid_late_w11_o2_lr1e-03 ===

=== Running early_mid_late_w11_o2_lr3e-04 ===

=== Running hybrid_4ch_w11_o2_lr1e-03 ===
[WARN] μs' exponential fit failed: Optimal parameters not found: Number of calls to function has reached maxfev = 5000.

=== Running hybrid_4ch_w11_o2_lr3e-04 ===

=== Running single_w11_o3_lr1e-03 ===

=== Running single_w11_o3_lr3e-04 ===

=== Running early_mid_late_w11_o3_lr1e-03 ===

=== Running early_mid_late_w11_o3_lr3e-04 ===

=== Running hybrid_4ch_w11_o3_lr1e-03 ===

=== Running hybrid_4ch_w11_o3_lr3e-04 ===

=== Running single_w11_o4_lr1e-03 ===

=== Running single_w11_o4_lr3e-04 ===

=== Running early_mid_late_