In [31]:
"""
DTOF Deep Learning Pipeline

This module implements a complete, modular DL pipeline for predicting
the optical properties (μa, μs') of homogeneous tissue models from time-resolved
DTOF signals. The system is designed for reproducible experimentation and supports flexible preprocessing and neural architecture configurations.

Applicability 

This framework is intended for studies involving: 
    * Monte Carlo simulated DTOFs (homogeneous or multilayer)
    * inversion of DTOFs to estimate μa, μs'
    * benchmarking the effects of preprocessing choices such as SG filters 
    * deep learning architecture optimisation 
    * systematic comparison of multi-channel DTOF representations 

Supported Model Variants 

This pipeline allows dynamic selection of the number of input channels. Three commonly used configurations are supported: 

    (1) Single channel DTOF:
        * DTOF cropped to 0-5 ns 
        * Smoothed and clipped (standardisation is computed but currently not applied)
        * Input shape: (1, T)

    (2) 3 channel temporal bin model: 
        * Early, mid, and late temporal masks (0 - 0.5 ns, 0.5 - 4 ns, 4 - 5 ns)
        * Each bin multiplied with the DTOF to form 3 channels 
        * Input shape: (3, T)
    
    (3) 4 channel hybrid model:
        * Channel 1: Full DTOF (0 - 5 ns)
        * Channel 2-4: Early, mid, late temporal bins
        * Input shape: (4, T)

Channel configuration is selected from the CONFIG block and handled automatically by the DTOFDataset 

Pipeline Components 

This module consists of: 

1. CONFIG dictionary: 
    Centralises all experiment-level parameters including: 
        - preprocessing settings (SG frame / order, clipping, cropping)
        - number of input channels 
        - temporal mask definitions 
        - CNN architecture (layer widths, output dimensions)
        - learning rate, batch size, epochs 

2. DTOFDataset: 
    Implements the preprocessing chain: 
        - CSV loading 
        - crop to 0 - 5 ns
        - Savitzky-Golay smoothing 
        - negative-value clipping 
        - mean / std standardisation (computed, but currently not applied)
        - dynamic construction of 1, 3, or 4 input channels 
        - returns (signal, [mua, mus])

3. Net (CNN architecture): 
    * Input channels match CONFIG["in_channels"]
    * 3 Conv blocks with BatchNorm, ReLU, and pooling 
    * Automatic flatten dimension calculation 
    * Fully connected regression head configurable via CONFIG 
    * Output dimension fixed at 2 (μa, μs′)

4. Training loop: 
    * GPU/CPU detection 
    * Train/validation loss computation 
    * best-model checkpoint saving 
    * optional printing of sample predictions 
    * automatic plotting of train/val loss curves (PNG)

5. Evaluation and logging: 
    Provides: 
        - MAE and RMSE for μa, μs′
        - prediction vs. ground truth arrays 
        - JSON logging of each run's configuration and metrics 
        - loss curve visualisation per run 

Dependencies 

Required: 
    * torch 
    * numpy
    * pandas 
    * scipy (Savitzky-Golay filtering)
    * matplotlib (loss curve plotting)

Recommended: 
    * seaborn (optional visualisation enhancements)

Inputs & Outputs 

Inputs: 
    * DTOFs_Homo_raw.csv - raw DTOF data with time in column 0
    * DTOFs_Homo_labels.csv - extracted (μa, μs') labels 

Outputs: 
    * training and validation loss curves .png
    * MAE / RMSE metrics 
    * prediction arrays for downstream plotting 
    * JSON log for all runs

"""

In [32]:
import os
import json
import re
from datetime import datetime

import numpy as np
import pandas as pd
from scipy.signal import savgol_filter
from scipy.optimize import curve_fit

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

import matplotlib.pyplot as plt

from pathlib import Path
from typing import List, Dict, Any, Iterable

In [33]:
# 1. Single configuration: initial starting point

CONFIG = {
    "csv_path": "/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/Year 3/Research Project in Biomedical Engineering/Code/Pre-obtained data/DTOFs_Homo_raw.csv", 
    "label_csv_path": "/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/Year 3/Research Project in Biomedical Engineering/Code/Pre-obtained data/DTOFs_Homo_labels.csv", 

    # Preprocessing 
    "sg_window": 21, 
    "sg_order": 1, 
    "eps": 1e-8, 
    "crop_t_max": 5.0,      # crop DTOF to 0-5ns

    # Channels: "single", "early_mid_late", "hybrid_4ch"
    "channel_mode": "single",

    # Data split + loader
    "train_frac": 0.8,
    "batch_size": 32,

    # Model
    "in_channels": 1,            # overwritten in run_experiment via get_in_channels(channel_mode)
    "output_dim": 2,             # mua + mus
    "hidden_fc": [128, 64],

    # Training
    "lr": 1e-3,
    "epochs": 20,

    # Outputs / logging
    "run_name": "exp_single_w21_o1",
    "save_dir": "JSON logs",
    "log_path": "JSON logs/dtof_runs_log.json",
}
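
The Savitzky-Golay settings in CONFIG are only exercised once `savgol_filter` runs inside the dataset, so an inconsistent combination in a long grid search would surface late. Below is a minimal sketch of an up-front check, assuming a hypothetical helper `validate_config` that is not part of the pipeline. The rule that the polynomial order must be smaller than the window comes from the filter itself; the odd-window check is a conservative assumption, not something this notebook requires.

```python
def validate_config(cfg: dict) -> None:
    """Hypothetical helper: fail fast on inconsistent experiment settings."""
    window, order = cfg["sg_window"], cfg["sg_order"]
    if order >= window:
        raise ValueError(f"sg_order ({order}) must be smaller than sg_window ({window})")
    if window % 2 == 0:
        # Assumption: keep the window odd so the filter is centred on each sample.
        raise ValueError(f"sg_window ({window}) should be odd")
    if not 0.0 < cfg["train_frac"] < 1.0:
        raise ValueError("train_frac must lie strictly between 0 and 1")
    if cfg["channel_mode"] not in {"single", "early_mid_late", "hybrid_4ch"}:
        raise ValueError(f"Unknown channel_mode: {cfg['channel_mode']}")


example_cfg = {
    "sg_window": 21,
    "sg_order": 1,
    "train_frac": 0.8,
    "channel_mode": "single",
}
validate_config(example_cfg)  # passes silently for the default settings
```

Calling such a check at the top of each run would reject a bad grid point before any data is loaded.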

In [34]:
# 2. Label extraction helper

def extract_labels_from_dtof_csv(csv_path: str, label_csv_path: str):
    """
    Extracts (mua, mus) labels from DTOF column headers in the raw CSV.

    Expects column names: "mua: 0.005  mus: 2.0"
    Writes a CSV of labels with columns ["mua", "mus"].
    """

    df = pd.read_csv(csv_path)
    dtof_columns = df.columns[1:]  # skip time column
    labels = []

    for col in dtof_columns:
        col_clean = str(col).strip()
        match = re.search(r"mua:\s*([0-9.]+)\s+mus:\s*([0-9.]+)", col_clean)
        if not match:
            raise ValueError(f"Could not parse mua/mus from column '{col}'")
        mua_val = float(match.group(1))
        mus_val = float(match.group(2))
        labels.append((mua_val, mus_val))
    
    label_df = pd.DataFrame(labels, columns=["mua", "mus"])
    label_df.to_csv(label_csv_path, index=False)
    print(f"[INFO] Saved labels to {label_csv_path} (N={len(label_df)})")
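
The header parsing can be exercised on its own, without a CSV file. The sketch below uses a hypothetical function `parse_label_header`. Its pattern is the one used above, extended to accept exponent notation such as `5e-3`; that extension is an assumption, since the headers in the real file may never use it.

```python
import re
from typing import Tuple

# Number pattern: digits and dots, plus an optional exponent (assumed extension).
_NUM = r"([0-9.]+(?:[eE][-+]?[0-9]+)?)"
_LABEL_RE = re.compile(rf"mua:\s*{_NUM}\s+mus:\s*{_NUM}")


def parse_label_header(header: str) -> Tuple[float, float]:
    """Hypothetical helper: return (mua, mus) parsed from one column header."""
    match = _LABEL_RE.search(header.strip())
    if not match:
        raise ValueError(f"Could not parse mua/mus from column '{header}'")
    return float(match.group(1)), float(match.group(2))


print(parse_label_header("mua: 0.005  mus: 2.0"))  # (0.005, 2.0)
print(parse_label_header("mua: 5e-3 mus: 2"))      # (0.005, 2.0)
```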

In [35]:
# 3. DTOFDataset: preprocessing + channel construction 

class DTOFDataset(Dataset):
    """
    DTOF dataset with preprocessing and flexible channel configurations.

    Preprocessing:
        - load CSV
        - crop time axis to [0, crop_t_max]
        - Savitzky-Golay smoothing
        - negative-value clipping
        - per-trace standardisation (currently disabled; see __init__)
        - channel construction:
            * "single"        -> 1 channel, full DTOF
            * "early_mid_late"-> 3 channels (early/mid/late masks)
            * "hybrid_4ch"    -> 4 channels (full + 3 masks)

    Returns:
        signal: (C, T) tensor (C = n. of channels)
        label:  (2,) tensor [mua, mus]
    """

    def __init__(self, csv_path: str, labels: np.ndarray, cfg: dict):
        super().__init__()
        self.cfg = cfg

        df = pd.read_csv(csv_path)

        # Time and DTOFs
        time_full = df.iloc[:, 0].values               # (T_full,)
        dtof_full = df.iloc[:, 1:].values.T            # (N, T_full)
        N, T_full = dtof_full.shape

        # Crop time axis
        t_mask = (time_full >= 0.0) & (time_full <= cfg["crop_t_max"])
        time = time_full[t_mask]                       # (T,)
        dtof = dtof_full[:, t_mask]                    # (N, T)

        # Savitzky–Golay smoothing
        dtof_smooth = savgol_filter(
            dtof,
            cfg["sg_window"],
            cfg["sg_order"],
            axis=1
        )

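        # Clip negatives (replaced with eps)
        eps = cfg["eps"]
        dtof_smooth[dtof_smooth < 0] = eps

        # Per-trace standardisation is currently disabled: mean/std are computed
        # but unused, and the channels below are built from the unstandardised signal.
        mean = dtof_smooth.mean(axis=1, keepdims=True)
        std = dtof_smooth.std(axis=1, keepdims=True)
        #dtof_std = (dtof_smooth - mean) / (std + eps)  # (N, T)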

        # Build channels
        channels = self.build_channels(time, dtof_smooth, cfg["channel_mode"])
        # channels: (N, C, T)

        self.signals = torch.tensor(channels, dtype=torch.float32)  # (N,C,T)
        self.labels = torch.tensor(labels, dtype=torch.float32)     # (N,2)

        self.N, self.C, self.T = self.signals.shape

    def build_channels(self, t: np.ndarray, dtof: np.ndarray, mode: str) -> np.ndarray:
        """
        Construct channels based on the chosen mode:
            "single"         -> 1 channel, full DTOF
            "early_mid_late" -> 3 masked channels
            "hybrid_4ch"     -> 1 full + 3 masked = 4 channels
        """
        N, T = dtof.shape

        if mode == "single":
            # (N, 1, T)
            return dtof[:, None, :]

        # Define early/mid/late masks within cropped time
        early = ((t >= 0.0) & (t < 0.5)).astype(float)
        mid = ((t >= 0.5) & (t < 4.0)).astype(float)
        late = ((t >= 4.0) & (t <= self.cfg["crop_t_max"])).astype(float)
        masks = np.stack([early, mid, late], axis=0)  # (3, T)

        if mode == "early_mid_late":
            # Multiply each DTOF by masks to get 3 channels
            out = dtof[:, None, :] * masks[None, :, :]  # (N,3,T)
            return out

        if mode == "hybrid_4ch":
            # Channel 1: full DTOF
            ch_full = dtof[:, None, :]                   # (N,1,T)
            ch_bins = dtof[:, None, :] * masks[None, :, :]  # (N,3,T)
            return np.concatenate([ch_full, ch_bins], axis=1)  # (N,4,T)

        raise ValueError(f"Unknown channel_mode: {mode}")

    def __len__(self) -> int:
        return self.N

    def __getitem__(self, idx: int):
        return self.signals[idx], self.labels[idx]   
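
The channel construction in `build_channels` relies on NumPy broadcasting: an `(N, 1, T)` signal multiplied by a `(1, 3, T)` mask stack gives `(N, 3, T)`. The sketch below reproduces that step on a toy time axis so the shapes can be checked in isolation. The time grid and the constant traces are made up for illustration.

```python
import numpy as np

t = np.linspace(0.0, 5.0, 11)   # toy time axis in ns: 0, 0.5, ..., 5
dtof = np.ones((2, t.size))     # two toy traces, shape (N, T)

# Same bin edges as build_channels: early, mid, late
early = ((t >= 0.0) & (t < 0.5)).astype(float)
mid = ((t >= 0.5) & (t < 4.0)).astype(float)
late = ((t >= 4.0) & (t <= 5.0)).astype(float)
masks = np.stack([early, mid, late], axis=0)                 # (3, T)

binned = dtof[:, None, :] * masks[None, :, :]                # (N, 3, T)
hybrid = np.concatenate([dtof[:, None, :], binned], axis=1)  # (N, 4, T)

print(binned.shape, hybrid.shape)  # (2, 3, 11) (2, 4, 11)

# The masks partition the time axis, so the three bins sum back to the full trace.
print(np.allclose(binned.sum(axis=1), dtof))  # True
```

Because the bins do not overlap, the 3-channel input carries the same samples as the single-channel trace, only separated by time window.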

In [36]:
# 4. CNN Model (Net) with flexible in_channels and FC head 

class Net(nn.Module):
    """
    1D CNN for DTOF-based regression to (μa, μs').

    - Supports variable input channels (C) from CONFIG["in_channels"]
    - Uses three conv blocks with BatchNorm, ReLU, MaxPool
    - Automatically computes flatten dimension
    - FC layers controlled by CONFIG["hidden_fc"]
    """

    def __init__(self, cfg: dict, input_length: int):

        super().__init__()

        C = cfg["in_channels"]

        self.conv1 = nn.Sequential(
            nn.Conv1d(C, 32, kernel_size=7, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv1d(32, 32, kernel_size=5, padding=2),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2)
        )

        self.conv3 = nn.Sequential(
            nn.Conv1d(32, 16, kernel_size=3, padding=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.MaxPool1d(2)
        )

        # Determine flatten dim automatically
        with torch.no_grad():
            dummy = torch.zeros(1, C, input_length)  # (1,C,T)
            feat = self._forward_features(dummy)
            flatten_dim = feat.shape[1]  # (1, flatten_dim)

        # Build FC head from cfg["hidden_fc"]
        fc_layers = []
        last = flatten_dim
        for h in cfg["hidden_fc"]:
            fc_layers += [nn.Linear(last, h), nn.ReLU()]
            last = h

        fc_layers.append(nn.Linear(last, cfg["output_dim"]))
        self.fc = nn.Sequential(*fc_layers)

    def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.flatten(1)  # (batch, -1)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._forward_features(x)
        x = self.fc(x)
        return x
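
The flatten dimension that `Net` obtains from a dummy forward pass can be cross-checked by hand with the standard output-length formula for 1D convolution and pooling layers, `floor((L + 2*padding - kernel) / stride) + 1`. A minimal sketch in plain Python follows; the helper names are hypothetical, and the layer parameters are copied from the three conv blocks above.

```python
def conv1d_out_len(length: int, kernel: int, padding: int, stride: int = 1) -> int:
    """Output length of a 1D conv or pooling layer (dilation 1)."""
    return (length + 2 * padding - kernel) // stride + 1


def expected_flatten_dim(input_length: int) -> int:
    """Hypothetical helper: flatten size after the three conv blocks in Net."""
    length = input_length
    # (kernel, padding) of the Conv1d layer in each block
    for kernel, padding in [(7, 3), (5, 2), (3, 1)]:
        length = conv1d_out_len(length, kernel, padding)      # padding keeps the length
        length = conv1d_out_len(length, 2, 0, stride=2)       # MaxPool1d(2) halves it
    return 16 * length  # the last conv block has 16 output channels


print(expected_flatten_dim(1000))  # 2000
print(expected_flatten_dim(501))   # 992
```

Each block keeps the length through the convolution and halves it (rounding down) in the pool, so the result is `16 * (T // 8)` for the input length `T`.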

In [37]:
# 5. Training loop with loss tracking, prediction inspection & plotting 

def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss_fn: nn.Module,
    optimizer: torch.optim.Optimizer,
    cfg: dict,
    device: torch.device
):
    """
    Train the model with tracking of train/val MSE loss curves, saving best model,
    and basic prediction inspection on the first validation batch.
    """
    num_epochs = cfg["epochs"]
    save_dir   = cfg["save_dir"]
    run_name   = cfg["run_name"]

    os.makedirs(save_dir, exist_ok=True)

    train_losses = []  # MSE(train) per epoch
    val_losses   = []  # MSE(val) per epoch
    best_val     = float("inf")
    best_path    = os.path.join(save_dir, f"{run_name}_best.pth")

    for epoch in range(num_epochs):
        # ----- TRAIN -----
        model.train()
        running_train = 0.0

        for signals, labels in train_loader:
            signals = signals.to(device)
            labels  = labels.to(device).float()

            optimizer.zero_grad()
            preds = model(signals)            # (B, 2)
            loss  = loss_fn(preds, labels)    # this is MSE
            loss.backward()
            optimizer.step()

            running_train += loss.item()

        epoch_train = running_train / len(train_loader)
        train_losses.append(epoch_train)

        # ----- VALIDATION -----
        model.eval()
        running_val = 0.0

        with torch.no_grad():
            for batch_idx, (signals, labels) in enumerate(val_loader):
                signals = signals.to(device)
                labels  = labels.to(device).float()

                preds = model(signals)
                loss  = loss_fn(preds, labels)  # MSE on val
                running_val += loss.item()

                if epoch == 0 and batch_idx == 0:
                    print("\n[VAL SAMPLE] First batch predictions vs labels:")
                    print("Preds (μa, μs'):\n", preds[:5].cpu())
                    print("Labels (μa, μs'):\n", labels[:5].cpu())
                    print("Abs error:\n", (preds[:5] - labels[:5]).abs().cpu())
                    print("---------------------------------------------------")

        epoch_val = running_val / len(val_loader)
        val_losses.append(epoch_val)

        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train MSE: {epoch_train:.4f} | Val MSE: {epoch_val:.4f}")

        if epoch_val < best_val:
            best_val = epoch_val
            torch.save(model.state_dict(), best_path)

    # ----- Plot MSE vs epoch (train + val) -----
    fig_path = os.path.join(save_dir, f"{run_name}_loss_curves.png")
    plt.figure()
    plt.plot(train_losses, label="Train MSE")
    plt.plot(val_losses,   label="Val MSE")
    plt.xlabel("Epoch")
    plt.ylabel("MSE loss")
    plt.title(f"MSE vs. Epoch: {run_name}")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(fig_path, dpi=150)
    plt.close()

    results = {
        "train_losses": train_losses,
        "val_losses":   val_losses,
        "best_val":     best_val,
        "best_path":    best_path,
        "loss_plot":    fig_path,
    }
    return results
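
The epoch loss above is the mean of the per-batch MSE values, which weights a short final batch the same as a full one. When the dataset size is not a multiple of the batch size, this differs slightly from the true per-sample mean. The sketch below uses hypothetical helper names and made-up numbers to show the size of that effect.

```python
def mean_of_batch_means(batch_losses):
    """What the loop above computes: running_loss / len(loader)."""
    return sum(batch_losses) / len(batch_losses)


def per_sample_mean(batch_losses, batch_sizes):
    """Weight each batch loss by its number of samples."""
    total = sum(loss * n for loss, n in zip(batch_losses, batch_sizes))
    return total / sum(batch_sizes)


losses = [1.0, 3.0]   # toy per-batch MSE values
sizes = [32, 8]       # a full batch followed by a short final batch

print(mean_of_batch_means(losses))     # 2.0
print(per_sample_mean(losses, sizes))  # 1.4
```

In a training loop the weighted version would accumulate `loss.item() * signals.size(0)` and divide by the number of samples. The difference only matters when the last batch is much smaller than the others.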



In [38]:
# 6. Model Evaluation and visualisation of performance (MAE / RMSE, percentage error)
class ModelEvaluator:
    """
    Evaluate a trained model on a DataLoader and compute MAE/RMSE for (μa, μs'),
    plus percentage error plots vs actual values.
    """

    def __init__(self, model: nn.Module, device: torch.device):
        self.model = model
        self.device = device
        self.model.to(device)
        self.model.eval()

    def evaluate(self, data_loader: DataLoader, cfg: dict):
        all_preds  = []
        all_labels = []

        with torch.no_grad():
            for signals, labels in data_loader:
                signals = signals.to(self.device)
                labels  = labels.to(self.device).float()

                preds = self.model(signals)

                all_preds.append(preds.cpu())
                all_labels.append(labels.cpu())

        all_preds  = torch.cat(all_preds,  dim=0)  # (N,2)
        all_labels = torch.cat(all_labels, dim=0)  # (N,2)

        # Basic metrics: MAE and RMSE
        abs_err = torch.abs(all_preds - all_labels)
        sq_err  = (all_preds - all_labels) ** 2

        mae  = abs_err.mean(dim=0)             # (2,)
        rmse = torch.sqrt(sq_err.mean(dim=0))  # (2,)

        # ---------- Percentage error vs Actual ----------
        preds_np  = all_preds.numpy()
        labels_np = all_labels.numpy()

        # Avoid division by very small numbers
        eps = 1e-8
        denom = np.maximum(np.abs(labels_np), eps)  # (N,2)

        pct_error = 100.0 * (preds_np - labels_np) / denom  # signed %
        abs_pct_error = np.abs(pct_error)                   # absolute %

        # Scatter plots: Actual vs % error for μa and μs′
        save_dir = cfg["save_dir"]
        run_name = cfg["run_name"]
        os.makedirs(save_dir, exist_ok=True)

        # error vs true μa plot
        fig_mua = os.path.join(save_dir, f"{run_name}_pct_error_mua.png")
        x = labels_np[:, 0]            # true μa
        y = abs_pct_error[:, 0]        # absolute percentage error

        plt.figure()
        plt.scatter(x, y, s=10, alpha=0.6, label="Absolute % error")

        # Exponential model
        def _exp_model(x, a, b, c):
            return a * np.exp(-b * x) + c

        # Fit curve
        try:
            popt, _ = curve_fit(_exp_model, x, y, p0=[100, 1.0, 0.0], maxfev=5000)
            x_fit = np.linspace(min(x), max(x), 300)
            y_fit = _exp_model(x_fit, *popt)
            plt.plot(
                x_fit, y_fit, "r-", linewidth=2,
                label=f"Fit: a·exp(-b·x) + c\n"
                    f"a={popt[0]:.2f}, b={popt[1]:.2f}, c={popt[2]:.2f}"
            )
        except Exception as e:
            print("[WARN] μa exponential fit failed:", e)

        plt.xlabel("True μa")
        plt.ylabel("Absolute % error")
        plt.title(f"Percentage error vs Actual μa: {run_name}")
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.savefig(fig_mua, dpi=150)
        plt.close()

        # error vs true μs′ plot
        fig_mus = os.path.join(save_dir, f"{run_name}_pct_error_mus.png")
        x = labels_np[:, 1]            # true μs'
        y = abs_pct_error[:, 1]        # absolute percentage error

        plt.figure()
        plt.scatter(x, y, s=10, alpha=0.6, label="Absolute % error")

        # Reuse _exp_model defined above for the μa fit

        # Fit curve
        try:
            popt, _ = curve_fit(_exp_model, x, y, p0=[100, 0.5, 0.0], maxfev=5000)
            x_fit = np.linspace(min(x), max(x), 300)
            y_fit = _exp_model(x_fit, *popt)
            plt.plot(
                x_fit, y_fit, "r-", linewidth=2,
                label=f"Fit: a·exp(-b·x) + c\n"
                    f"a={popt[0]:.2f}, b={popt[1]:.2f}, c={popt[2]:.2f}"
            )
        except Exception as e:
            print("[WARN] μs' exponential fit failed:", e)

        plt.xlabel("True μs'")
        plt.ylabel("Absolute % error")
        plt.title(f"Percentage error vs Actual μs': {run_name}")
        plt.grid(True)
        plt.legend()
        plt.tight_layout()
        plt.savefig(fig_mus, dpi=150)
        plt.close()

        metrics = {
            "MAE": mae.numpy(),          # [MAE_mua, MAE_mus]
            "RMSE": rmse.numpy(),        # [RMSE_mua, RMSE_mus]
            "preds": preds_np,
            "labels": labels_np,
            "pct_error": pct_error,      # signed %
            "abs_pct_error": abs_pct_error,
            "pct_error_plots": {
                "mua": fig_mua,
                "mus": fig_mus,
            }
        }
        return metrics
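
The metric definitions used in `evaluate` can be checked on a toy example where the answer is known by inspection. The sketch below uses NumPy only and two made-up samples in which every prediction is off by half of the true value.

```python
import numpy as np

preds = np.array([[0.01, 1.0], [0.03, 3.0]])    # toy predictions (μa, μs')
labels = np.array([[0.02, 2.0], [0.02, 2.0]])   # toy ground truth

err = preds - labels
mae = np.abs(err).mean(axis=0)                   # per-parameter MAE, shape (2,)
rmse = np.sqrt((err ** 2).mean(axis=0))          # per-parameter RMSE, shape (2,)

# Signed percentage error, guarding against division by very small labels
pct_error = 100.0 * err / np.maximum(np.abs(labels), 1e-8)

print("MAE :", mae)
print("RMSE:", rmse)
print("% error:\n", pct_error)
```

Here MAE and RMSE both come out at about 0.01 for μa and 1.0 for μs', and the signed percentage errors are about -50 % and +50 %. This also shows why the two parameters need to be reported separately: they live on very different scales.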


In [39]:
# 7. Logging of runs (config + metrics) to JSON 

def log_run(cfg, results, log_path):
    """
    Append a run entry to a JSON log file, including config and metrics.
    """
    log_dir = os.path.dirname(log_path)
    if log_dir:  # dirname is "" for a bare filename; os.makedirs("") would fail
        os.makedirs(log_dir, exist_ok=True)

    # Handle both old ('best_path') and new ('model_path') keys safely
    model_path = results.get("model_path", results.get("best_path", None))

    entry = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "run_name": cfg["run_name"],
        "sg_window": cfg["sg_window"],
        "sg_order": cfg["sg_order"],
        "channel_mode": cfg["channel_mode"],
        "crop_t_max": cfg["crop_t_max"],
        "hidden_fc": cfg["hidden_fc"],
        "lr": cfg["lr"],
        "epochs": cfg["epochs"],
        "batch_size": cfg["batch_size"],
        "best_val": results["best_val"],
        "loss_plot": results["loss_plot"],
        "model_path": model_path,
    }

    # to store MAE / RMSE 
    if "val_metrics" in results:
        entry["MAE"] = results["val_metrics"]["MAE"].tolist()
        entry["RMSE"] = results["val_metrics"]["RMSE"].tolist()
        entry["pct_error_plots"] = results["val_metrics"]["pct_error_plots"]

    if os.path.exists(log_path):
        with open(log_path, "r") as f:
            data = json.load(f)
        data.append(entry)
    else:
        data = [entry]

    with open(log_path, "w") as f:
        json.dump(data, f, indent=2)




In [40]:
def get_in_channels(mode: str) -> int:
    """
    Returns number of input channels for the CNN depending on the preprocessing mode.

    mode options:
        "single"          -> 1 channel  (full smoothed DTOF only)
        "early_mid_late"  -> 3 channels (masked temporal bins)
        "hybrid_4ch"      -> 4 channels (full DTOF + 3 temporal bins)
    """
    if mode == "single":
        return 1
    elif mode == "early_mid_late":
        return 3
    elif mode == "hybrid_4ch":
        return 4
    else:
        raise ValueError(f"Unknown channel_mode: {mode}")


In [41]:
# 8. Single experiment runner

def run_experiment(cfg):

    # Work on a copy so the shared CONFIG dict is never mutated between runs
    cfg = dict(cfg)

    # 1. Build dataset + loaders
    labels = pd.read_csv(cfg["label_csv_path"]).values # shape (N, 2)
    dataset = DTOFDataset(
        csv_path=cfg["csv_path"],
        labels=labels, # shape (N,2)
        cfg = cfg, 
    )

    n_train = int(cfg.get("train_frac", 0.8) * len(dataset))
    n_val = len(dataset) - n_train
    train_dataset, val_dataset = random_split(dataset, [n_train, n_val])

    train_loader = DataLoader(train_dataset, batch_size=cfg["batch_size"], shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=cfg["batch_size"], shuffle=False)

    # 2. Build model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Derive in_channels from channel_mode and store in cfg for Net (the copy)
    in_channels = get_in_channels(cfg["channel_mode"])
    cfg["in_channels"] = in_channels

    model = Net(
        cfg = cfg, 
        input_length = dataset.T, 
    ).to(device)

    # Loss + optimizer
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])

    # 3. Training Loop
    best_val = float("inf")       
    save_path = None                 

    train_losses = []
    val_losses = []

    for epoch in range(cfg["epochs"]):
        model.train()
        running_loss = 0.0

        for signals, labels in train_loader:
            signals = signals.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            preds = model(signals)
            loss = loss_fn(preds, labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        train_loss = running_loss / len(train_loader)
        train_losses.append(train_loss)

        # ---- Validation ----
        model.eval()
        val_running = 0.0
        with torch.no_grad():
            for signals, labels in val_loader:
                signals = signals.to(device)
                labels = labels.to(device)

                preds = model(signals)
                val_running += loss_fn(preds, labels).item()

        val_loss = val_running / len(val_loader)
        val_losses.append(val_loss)

        # ---- Save best model ----
        if val_loss < best_val:
            best_val = val_loss
            os.makedirs("Best paths", exist_ok=True)  # torch.save does not create directories
            save_path = f"Best paths/{cfg['run_name']}_best.pth"
            torch.save(model.state_dict(), save_path)

    # ---- Plot curves for this run ----
    os.makedirs("Model evaluation figs", exist_ok=True)  # savefig needs an existing directory
    fig_path = f"Model evaluation figs/{cfg['run_name']}_loss_curves.png"
    plt.figure()
    plt.plot(train_losses, label="Train MSE")
    plt.plot(val_losses, label="Val MSE")
    plt.xlabel("Epoch")
    plt.ylabel("MSE Loss")
    plt.title(f"Loss Curves: {cfg['run_name']}")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    plt.savefig(fig_path, dpi=150)
    plt.close()

    # ---- Evaluate best model on validation set with ModelEvaluator ----
    # Rebuild best model and load best weights
    best_model = Net(
        cfg=cfg,
        input_length=dataset.T,
    ).to(device)

    if save_path is not None:
        best_model.load_state_dict(torch.load(save_path, map_location=device))

    evaluator = ModelEvaluator(model=best_model, device=device)

    eval_cfg = {
        "run_name": cfg["run_name"],
        "save_dir": cfg.get("eval_save_dir", "Model evaluation figs"),
    }

    val_metrics = evaluator.evaluate(val_loader, cfg=eval_cfg)

    # 4. Return results (including metrics)
    return {
        "run_name": cfg["run_name"],
        "best_val": best_val,         # validation MSE at best epoch
        "model_path": save_path,      # path to best model weights
        "cfg": cfg,
        "train_losses": train_losses,
        "val_losses": val_losses,
        "loss_plot": fig_path,
        "val_metrics": val_metrics,   # MAE, RMSE, % errors, plot paths
    }
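
`random_split` draws a fresh random partition on every call, so two runs of the same configuration are validated on different samples unless the split is seeded. The idea is sketched below with NumPy index arrays; the helper `split_indices` and the seed value are hypothetical. In the pipeline itself the same effect could probably be obtained by passing a seeded `torch.Generator` through the `generator` argument of `random_split`.

```python
import numpy as np


def split_indices(n_samples: int, train_frac: float, seed: int):
    """Hypothetical helper: shuffled train/val index arrays from a fixed seed."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_samples)
    n_train = int(train_frac * n_samples)
    return perm[:n_train], perm[n_train:]


train_a, val_a = split_indices(100, 0.8, seed=0)
train_b, val_b = split_indices(100, 0.8, seed=0)

print(len(train_a), len(val_a))           # 80 20
print(np.array_equal(train_a, train_b))   # True: same seed, same split
```

With a fixed split, differences in validation MSE between grid points reflect the configuration rather than which samples happened to land in the validation set.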


In [42]:
# 9. Simple grid search over a few configurations 

def grid_search():
    base_cfg = CONFIG

    windows = [11, 21, 31, 41]
    orders  = [1,  2,  3,  4]
    modes   = ["single", "early_mid_late", "hybrid_4ch"]
    lrs     = [1e-3, 3e-4]

    results = []

    for w in windows:
        for o in orders:
            for mode in modes:
                for lr in lrs:
                    cfg = dict(base_cfg)
                    cfg["sg_window"]   = w
                    cfg["sg_order"]    = o
                    cfg["channel_mode"] = mode
                    cfg["lr"]          = lr
                    cfg["run_name"]    = f"{mode}_w{w}_o{o}_lr{lr}"

                    print(f"\n=== Running {cfg['run_name']} ===")
                    res = run_experiment(cfg)

                    # res already exposes best_val at the top level
                    results.append(res)

                    # log run to JSON (now includes val_metrics)
                    log_run(cfg, res, cfg["log_path"])

    # Sort by validation loss (MSE). NOTE: key name is 'best_val', not 'val_loss'.
    results_sorted = sorted(results, key=lambda x: x["best_val"])

    print("\n==================== GRID SEARCH SUMMARY ====================")
    for r in results_sorted[:10]:  # print top 10
        best_val = r["best_val"]
        mae      = r["val_metrics"]["MAE"]   # [MAE_mua, MAE_mus]
        rmse     = r["val_metrics"]["RMSE"]  # [RMSE_mua, RMSE_mus]

        print(
            f"{r['run_name']} | "
            f"val_MSE={best_val:.6f} | "
            f"MAE=(μa={mae[0]:.4f}, μs'={mae[1]:.4f}) | "
            f"RMSE=(μa={rmse[0]:.4f}, μs'={rmse[1]:.4f})"
        )

    best = results_sorted[0]
    print("\nBEST MODEL (by val MSE):")
    print(best["run_name"])
    print(f"  val_MSE = {best['best_val']:.6f}")
    print(f"  model_path = {best['model_path']}")
    print(f"  loss_plot  = {best['loss_plot']}")
    print(f"  MAE  = {best['val_metrics']['MAE']}")
    print(f"  RMSE = {best['val_metrics']['RMSE']}")

    return results_sorted
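
The four nested loops in `grid_search` can be flattened with `itertools.product`, which visits the combinations in the same order. A minimal sketch follows, with a hypothetical generator `iter_grid` and a stand-in base config; it only builds the configurations and does not train anything.

```python
from itertools import product

windows = [11, 21, 31, 41]
orders = [1, 2, 3, 4]
modes = ["single", "early_mid_late", "hybrid_4ch"]
lrs = [1e-3, 3e-4]


def iter_grid(base_cfg: dict):
    """Yield one config copy per grid point, in nested-loop order."""
    for w, o, mode, lr in product(windows, orders, modes, lrs):
        cfg = dict(base_cfg)  # shallow copy so base_cfg is left untouched
        cfg.update(
            sg_window=w,
            sg_order=o,
            channel_mode=mode,
            lr=lr,
            run_name=f"{mode}_w{w}_o{o}_lr{lr}",
        )
        yield cfg


grid = list(iter_grid({"epochs": 20}))  # stand-in for CONFIG
print(len(grid))             # 96
print(grid[0]["run_name"])   # single_w11_o1_lr0.001
```

Listing the grid first also makes the cost visible up front: 96 runs of 20 epochs each.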


In [43]:
# Single experiment run 
if __name__ == "__main__":
    run_info = run_experiment(CONFIG)
    print(run_info["run_name"], run_info["best_val"])

[WARN] μs' exponential fit failed: Optimal parameters not found: Number of calls to function has reached maxfev = 5000.
exp_single_w21_o1 0.04363725893199444


In [44]:
# Run grid search 
if __name__ == "__main__": 
    results = grid_search()

    # print the best configuration 
    best = min(results, key=lambda r: r["best_val"])
    print("\nBEST MODEL FROM GRID SEARCH:")
    print(best["run_name"], best["best_val"])


=== Running single_w11_o1_lr0.001 ===

=== Running single_w11_o1_lr0.0003 ===

=== Running early_mid_late_w11_o1_lr0.001 ===

=== Running early_mid_late_w11_o1_lr0.0003 ===

=== Running hybrid_4ch_w11_o1_lr0.001 ===

=== Running hybrid_4ch_w11_o1_lr0.0003 ===

=== Running single_w11_o2_lr0.001 ===

=== Running single_w11_o2_lr0.0003 ===

=== Running early_mid_late_w11_o2_lr0.001 ===

=== Running early_mid_late_w11_o2_lr0.0003 ===

=== Running hybrid_4ch_w11_o2_lr0.001 ===

=== Running hybrid_4ch_w11_o2_lr0.0003 ===

=== Running single_w11_o3_lr0.001 ===

=== Running single_w11_o3_lr0.0003 ===

=== Running early_mid_late_w11_o3_lr0.001 ===

=== Running early_mid_late_w11_o3_lr0.0003 ===

=== Running hybrid_4ch_w11_o3_lr0.001 ===

=== Running hybrid_4ch_w11_o3_lr0.0003 ===

=== Running single_w11_o4_lr0.001 ===

=== Running single_w11_o4_lr0.0003 ===

=== Running early_mid_late_w11_o4_lr0.001 ===

=== Running early_mid_late_w11_o4_lr0.0003 ===

=== Running hybrid_4ch_w11_o4_lr0.001 ===


Overall top 5 runs (across all channel modes):
  - run_name: early_mid_late_w11_o4_lr0.001
    channel_mode: early_mid_late
    sg_window: 11, sg_order: 4
    lr: 0.001, epochs: 20, batch_size: 32
    best_val (MSE): 0.020709
    MAE: (n/a)
    RMSE: (n/a)
    model_path: Best paths/early_mid_late_w11_o4_lr0.001_best.pth
    loss_plot: Model evaluation figs/early_mid_late_w11_o4_lr0.001_loss_curves.png

  - run_name: single_w31_o3_lr0.0003
    channel_mode: single
    sg_window: 31, sg_order: 3
    lr: 0.0003, epochs: 20, batch_size: 32
    best_val (MSE): 0.023788
    MAE: (n/a)
    RMSE: (n/a)
    model_path: Best paths/single_w31_o3_lr0.0003_best.pth
    loss_plot: Model evaluation figs/single_w31_o3_lr0.0003_loss_curves.png

  - run_name: hybrid_4ch_w31_o3_lr0.0003
    channel_mode: hybrid_4ch
    sg_window: 31, sg_order: 3
    lr: 0.0003, epochs: 20, batch_size: 32
    best_val (MSE): 0.026983
    MAE: (μa=0.0304, μs'=0.1914)
    RMSE: (μa=0.0388, μs'=0.2283)
    model_path: Best 