In [2]:
import os 
import re
import random 
import torch 
from torch import nn 
import numpy as np 
import matplotlib.pyplot as plt 
from torch.utils.data import Dataset, DataLoader,Subset
from scipy.signal import savgol_filter 
import h5py
# Seed every RNG the notebook relies on, not just Python's `random`:
# model weight initialisation and DataLoader shuffling draw from the torch RNG,
# and NumPy's legacy global RNG is seeded as well for completeness.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)
import math

In [3]:
# Savitzky-Golay filter parameters.
# The trial-run cell passes these into cfg as "sg_order" and "sg_window".
order = 1            # polynomial order of the local fit
frame_length = 21    # window length in samples
# NOTE(review): this module-level eps does not appear to be read by the cells
# visible here (the cfg dicts set "eps": 1e-12 explicitly) — confirm it is still needed.
eps = 1e-8

In [4]:
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
class DTOFDataset(Dataset):
    """
    DTOF dataset for MATLAB v7.3 (.mat, HDF5) files.

    Required datasets inside .mat:
        X : (Nt, N) or (N, Nt)  reflectance DTOFs
        y : (2, N)  or (N, 2)   [mua, mus']
        t : (Nt,)              time vector in seconds (~1e-12 resolution)

    Preprocessing:
        - convert t from seconds -> ns
        - crop [0, crop_t_max] ns
        - Savitzky–Golay smoothing
        - clip small values to eps
        - input representation:
            * "raw"      -> use reflectance (smoothed+clipped)
            * "log"      -> use log(reflectance)
            * "raw_log"  -> concatenate raw and log along channel dimension
        - channel construction:
            * "single"         -> 1 channel (full)
            * "early_mid_late" -> 3 channels (early/mid/late masks)
            * "hybrid_4ch"     -> 4 channels (full + early/mid/late masks)

    cfg keys read here:
        "crop_t_max" (required), "sg_window" (required), "sg_order" (required),
        "channel_mode" (required), "eps" (default 1e-12), "input_rep" (default "log").

    Attributes set after construction:
        signals : (N, C, T) float32 tensor
        labels  : (N, 2) float32 tensor
        N, C, T : the three dimensions of `signals`

    Returns (per item):
        signal: (C, T) float32 tensor
        label:  (2,) float32 tensor [mua, mus']  (raw labels for now)

    Raises:
        ValueError: on inconsistent array shapes, an empty crop window,
            or an unknown "input_rep" / "channel_mode".
    """

    def __init__(self, mat_path: str, cfg: dict):
        super().__init__()
        self.cfg = cfg

        # ---------- load HDF5 (.mat v7.3) ----------
        with h5py.File(mat_path, "r") as f:
            X = np.array(f["X"], dtype=np.float32)
            y = np.array(f["y"], dtype=np.float32)
            t = np.array(f["t"], dtype=np.float32).squeeze()

        self._init_from_arrays(X, y, t)

    def _init_from_arrays(self, X: np.ndarray, y: np.ndarray, t: np.ndarray) -> None:
        """Validate, preprocess and store already-loaded arrays (reads self.cfg)."""
        cfg = self.cfg

        # ---------- rank checks first, so the shape accesses below cannot fail ----------
        if X.ndim != 2:
            raise ValueError(f"Expected X to be 2D, got {X.shape}")
        if t.ndim != 1:
            raise ValueError(f"Expected t to be (Nt,), got {t.shape}")
        if y.ndim != 2:
            raise ValueError(f"Expected y to be (N,2), got {y.shape}")

        # ---------- normalise shapes ----------
        # X -> (N, Nt)
        # NOTE(review): when N == Nt the orientation is ambiguous and X is transposed.
        if X.shape[0] == t.shape[0]:
            X = X.T

        # y -> (N, 2)
        if y.shape[0] == 2:
            y = y.T

        if y.shape[1] != 2:
            raise ValueError(f"Expected y to be (N,2), got {y.shape}")
        if X.shape[1] != t.shape[0]:
            raise ValueError(f"X and t mismatch: X Nt={X.shape[1]} vs t Nt={t.shape[0]}")
        if X.shape[0] != y.shape[0]:
            # Without this check signals and labels would silently be mis-paired
            # (or indexing would fail much later inside a DataLoader).
            raise ValueError(f"X and y mismatch: X N={X.shape[0]} vs y N={y.shape[0]}")

        # ---------- time: seconds -> ns ----------
        t_ns = t * 1e9

        # ---------- crop ----------
        crop_t_max = float(cfg["crop_t_max"])  # ns
        t_mask = (t_ns >= 0.0) & (t_ns <= crop_t_max)
        if not np.any(t_mask):
            raise ValueError(
                f"Cropping removed all points. "
                f"t_ns range=[{t_ns.min():.3g}, {t_ns.max():.3g}] ns, crop_t_max={crop_t_max}"
            )

        t_ns = t_ns[t_mask]              # (T,)
        dtof = X[:, t_mask]              # (N,T) — boolean indexing returns a copy

        T = dtof.shape[1]

        # ---------- Savitzky–Golay ----------
        sg_window = int(cfg["sg_window"])
        sg_order = int(cfg["sg_order"])

        # enforce validity (odd, > order, <= T)
        if sg_window % 2 == 0:
            sg_window += 1
        if sg_window <= sg_order:
            sg_window = sg_order + 2
            if sg_window % 2 == 0:
                sg_window += 1
        if sg_window > T:
            sg_window = T if (T % 2 == 1) else (T - 1)

        # Clamping the window to a short T can leave it <= sg_order, which
        # savgol_filter rejects; in that case smoothing is skipped instead of crashing.
        if sg_window >= 3 and sg_window > sg_order:
            dtof = savgol_filter(dtof, sg_window, sg_order, axis=1)

        # ---------- clip ----------
        eps = float(cfg.get("eps", 1e-12))
        dtof = np.maximum(dtof, eps)     # keeps log() finite; smoothing can undershoot below 0

        # ---------- choose representation ----------
        input_rep = cfg.get("input_rep", "log")  # "raw" | "log" | "raw_log"
        channel_mode = cfg["channel_mode"]

        dtof_raw = dtof.astype(np.float32)
        dtof_log = np.log(dtof_raw).astype(np.float32)

        # ---------- build channels ----------
        if input_rep == "raw":
            channels = self.build_channels(t_ns, dtof_raw, channel_mode)       # (N,C,T)

        elif input_rep == "log":
            channels = self.build_channels(t_ns, dtof_log, channel_mode)       # (N,C,T)

        elif input_rep == "raw_log":
            ch_raw = self.build_channels(t_ns, dtof_raw, channel_mode)         # (N,C,T)
            ch_log = self.build_channels(t_ns, dtof_log, channel_mode)         # (N,C,T)
            channels = np.concatenate([ch_raw, ch_log], axis=1)                # (N,2C,T)

        else:
            raise ValueError(f"Unknown input_rep: {input_rep}")

        # ---------- to torch ----------
        self.signals = torch.tensor(channels, dtype=torch.float32)  # (N,C,T)
        self.labels = torch.tensor(y, dtype=torch.float32)          # (N,2)

        self.N, self.C, self.T = self.signals.shape

    def build_channels(self, t_ns: np.ndarray, dtof: np.ndarray, mode: str) -> np.ndarray:
        """
        Stack `dtof` (N,T) into channels (N,C,T) according to `mode`.

        Channel masks in ns:
            early: 0–0.5 ns
            mid:   0.5–4 ns
            late:  4–crop_t_max ns

        Raises:
            ValueError: if `mode` is not one of the supported channel modes.
        """
        crop_t_max = float(self.cfg["crop_t_max"])

        if mode == "single":
            return dtof[:, None, :]  # (N,1,T)

        early = ((t_ns >= 0.0) & (t_ns < 0.5)).astype(np.float32)
        mid   = ((t_ns >= 0.5) & (t_ns < 4.0)).astype(np.float32)
        late  = ((t_ns >= 4.0) & (t_ns <= crop_t_max)).astype(np.float32)

        masks = np.stack([early, mid, late], axis=0)  # (3,T)

        # NOTE(review): masking multiplies by 0 outside each gate. For log inputs
        # that yields 0 == log(1) rather than "no signal" — confirm this is intended.
        gated = dtof[:, None, :] * masks[None, :, :]  # (N,3,T)

        if mode == "early_mid_late":
            return gated

        if mode == "hybrid_4ch":
            full = dtof[:, None, :]                         # (N,1,T)
            return np.concatenate([full, gated], axis=1)    # (N,4,T)

        raise ValueError(f"Unknown channel_mode: {mode}")

    def __len__(self) -> int:
        return self.N

    def __getitem__(self, idx: int):
        return self.signals[idx], self.labels[idx]

In [None]:
class Net(nn.Module):
    """
    CNN for 1D DTOF signals with flexible input channels.

    Channel counts (C) depend on:
        channel_mode:
            - "single"         -> 1
            - "early_mid_late" -> 3
            - "hybrid_4ch"     -> 4
        input_rep:
            - "raw" / "log"    -> multiplier 1
            - "raw_log"        -> multiplier 2

    So:
        C = base_C * (2 if input_rep == "raw_log" else 1)

    Optional tunables in cfg:
        cfg["use_dilation"] = True / False
        cfg["kernels"]      = [3, 5, 5]      # keep final kernel smaller to reduce over-smoothing
        cfg["dilations"]    = [1, 2, 4]      # increasing dilation expands receptive field
        cfg["channels"]     = [32, 32, 16]   # out_channels per block
        cfg["pool_k"]       = 2              # MaxPool kernel
        cfg["pool_s"]       = 2              # MaxPool stride

    Args:
        cfg: configuration dict; "channel_mode" is required, the rest are optional.
        input_length: number of time samples T of the input; used to size the FC head.
        output_dim: number of regression outputs.

    Raises:
        KeyError: if cfg["channel_mode"] is missing or not a known mode.
        ValueError: if a kernel size is even, or kernels/dilations do not have 3 entries.
    """

    def __init__(self, cfg: dict, input_length: int = 3000, output_dim: int = 2):
        super().__init__()
        self.cfg = cfg

        # -----------------------------
        # Infer input channels from cfg
        # -----------------------------
        base_C = {"single": 1, "early_mid_late": 3, "hybrid_4ch": 4}[cfg["channel_mode"]]
        mult = 2 if cfg.get("input_rep", "log") == "raw_log" else 1
        in_channels = base_C * mult

        # -----------------------------
        # Optional dilation settings
        # -----------------------------
        use_dilation = bool(cfg.get("use_dilation", False))

        # Kernels: increasing early->late but keep the final kernel modest
        kernels = cfg.get("kernels", [3, 5, 5])

        # Dilations: increase only when use_dilation=True
        dilations = cfg.get("dilations", [1, 2, 4]) if use_dilation else [1, 1, 1]

        # Out channels per conv block
        chs = cfg.get("channels", [32, 32, 16])

        # Pooling
        pool_k = int(cfg.get("pool_k", 2))
        pool_s = int(cfg.get("pool_s", 2))

        # Unpack (exactly three conv blocks are built below)
        if len(kernels) != 3 or len(dilations) != 3:
            raise ValueError(
                f"Expected 3 kernels and 3 dilations, got kernels={kernels}, dilations={dilations}"
            )
        k1, k2, k3 = kernels
        d1, d2, d3 = dilations

        # -----------------------------
        # Helper: SAME padding for 1D conv
        # padding = dilation * (kernel - 1) / 2 (requires odd kernel)
        # -----------------------------
        def same_padding(kernel: int, dilation: int) -> int:
            if kernel % 2 == 0:
                raise ValueError(f"Kernel size must be odd for SAME padding. Got kernel={kernel}.")
            return (dilation * (kernel - 1)) // 2

        # -----------------------------
        # Convolution blocks
        # -----------------------------
        self.conv1 = nn.Conv1d(
            in_channels=in_channels,
            out_channels=chs[0],
            kernel_size=k1,
            dilation=d1,
            padding=same_padding(k1, d1),
        )
        self.bn1 = nn.BatchNorm1d(chs[0])
        self.pool1 = nn.MaxPool1d(kernel_size=pool_k, stride=pool_s)

        self.conv2 = nn.Conv1d(
            in_channels=chs[0],
            out_channels=chs[1],
            kernel_size=k2,
            dilation=d2,
            padding=same_padding(k2, d2),
        )
        self.bn2 = nn.BatchNorm1d(chs[1])
        self.pool2 = nn.MaxPool1d(kernel_size=pool_k, stride=pool_s)

        self.conv3 = nn.Conv1d(
            in_channels=chs[1],
            out_channels=chs[2],
            kernel_size=k3,
            dilation=d3,
            padding=same_padding(k3, d3),
        )
        self.bn3 = nn.BatchNorm1d(chs[2])
        self.pool3 = nn.MaxPool1d(kernel_size=pool_k, stride=pool_s)

        self.act = nn.ReLU()

        # -----------------------------
        # Dynamic flatten dimension
        # -----------------------------
        # Probe in eval mode: in training mode the all-zero dummy batch would update
        # the BatchNorm running statistics and num_batches_tracked before any real
        # data has been seen.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                dummy = torch.zeros(1, in_channels, input_length)
                feat = self._forward_features(dummy)
                self.flatten_dim = feat.shape[1]
        finally:
            self.train(was_training)

        # -----------------------------
        # Fully connected head
        # -----------------------------
        self.fc1 = nn.Linear(self.flatten_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_dim)

    def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
        """Run the three conv -> BN -> ReLU -> pool blocks and flatten to (B, features)."""
        x = self.pool1(self.act(self.bn1(self.conv1(x))))
        x = self.pool2(self.act(self.bn2(self.conv2(x))))
        x = self.pool3(self.act(self.bn3(self.conv3(x))))
        return x.flatten(1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map a (B, C, T) batch to (B, output_dim) predictions."""
        x = self._forward_features(x)
        x = self.act(self.fc1(x))
        x = self.act(self.fc2(x))
        return self.fc3(x)

In [None]:
import torch
import matplotlib.pyplot as plt
import os

def _run_epoch(model, loader, loss_fn, device, eps, exp_clip, optimizer=None):
    """
    Run one pass over `loader` and return (mean_loss, rmse, mean_abs_pct_err).

    With an optimizer the model is put in train mode and updated after every
    batch; without one it is put in eval mode and gradients are disabled.
    The loss is computed in log space; RMSE and percentage error are computed
    in original (linear) units as numpy arrays with one entry per output.
    """
    is_train = optimizer is not None
    model.train(is_train)

    loss_sum = 0.0
    n_samples = 0
    sum_sq = torch.zeros(2, device=device)
    pct_sum = torch.zeros(2, device=device)

    for signals, labels in loader:
        signals = signals.to(device)
        labels = labels.to(device).float()                 # (B,2) linear
        batch = labels.shape[0]

        log_labels = torch.log(labels.clamp_min(eps))       # (B,2) log

        with torch.set_grad_enabled(is_train):
            preds_log = model(signals).view_as(log_labels)  # (B,2)
            loss = loss_fn(preds_log, log_labels)

        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Weight by batch size so a smaller final batch does not skew the mean.
        loss_sum += loss.item() * batch
        n_samples += batch

        # Metrics in original units (clamped so exp() cannot overflow)
        with torch.no_grad():
            preds_lin = torch.exp(preds_log.clamp(-exp_clip, exp_clip))
            err = preds_lin - labels
            abs_pct_err = (100.0 * err / labels.clamp_min(eps)).abs()  # (B,2)

            sum_sq += (err ** 2).sum(dim=0)
            pct_sum += abs_pct_err.sum(dim=0)

    denom = max(1, n_samples)  # avoid division by zero on an empty loader
    mean_loss = loss_sum / denom
    rmse = torch.sqrt(sum_sq / denom).detach().cpu().numpy()
    pct_err = (pct_sum / denom).detach().cpu().numpy()
    return mean_loss, rmse, pct_err


def train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    num_epochs,
    device,
    save_path=None,
    eps=1e-12,
    print_every=1,
    exp_clip=20.0,    # exp(20) ~ 4.85e8, safety for overflow
    patience=20,
    min_delta=0.0,
    plot_path=None,   # optional: where to save loss curve png
):
    """
    Trains a model that predicts log(mua), log(mus').
    - Optimizes log-MSE loss
    - Reports RMSE and mean absolute % error in original (linear) units
    - Early stopping + optional best checkpoint saving
    - Tracks and plots train/val loss curves over epochs

    Args:
        model: network mapping (B, C, T) signals to (B, 2) log-space predictions.
        train_loader / val_loader: iterables of (signals, labels) with linear-unit labels.
        optimizer: optimizer already bound to `model`'s parameters.
        num_epochs: maximum number of epochs.
        device: torch device to train on.
        save_path: if given, the state_dict is saved here whenever validation improves.
        eps: lower clamp applied to labels before log / division.
        print_every: print metrics every this many epochs; values <= 0 disable printing.
        exp_clip: log-space predictions are clamped to [-exp_clip, exp_clip] before exp().
        patience: stop after this many consecutive epochs without improvement.
        min_delta: minimum decrease in validation loss that counts as an improvement.
        plot_path: where to save the loss-curve PNG; derived from save_path when omitted.

    Returns:
        dict with "model", "train_losses", "val_losses", "best_val_loss",
        "best_epoch" (1-based, None if validation never improved), "save_path"
        and "loss_plot".

    Note:
        The returned model holds the weights of the last epoch run, which may
        differ from the best checkpoint written to save_path.
    """

    model.to(device)
    loss_fn = torch.nn.MSELoss()

    best_val_loss = float("inf")
    best_epoch = None
    epochs_no_improve = 0

    train_losses = []
    val_losses = []

    for epoch in range(num_epochs):
        train_loss, train_rmse, train_pct_err = _run_epoch(
            model, train_loader, loss_fn, device, eps, exp_clip, optimizer=optimizer
        )
        val_loss, val_rmse, val_pct_err = _run_epoch(
            model, val_loader, loss_fn, device, eps, exp_clip, optimizer=None
        )

        # Track loss curves
        train_losses.append(train_loss)
        val_losses.append(val_loss)

        # print_every <= 0 means "never print" (and avoids a modulo-by-zero)
        should_print = print_every > 0 and (epoch + 1) % print_every == 0

        if should_print:
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            print(f"Train Loss (log-MSE): {train_loss:.6f} | Train RMSE [μa, μs′]: {train_rmse}")
            print(f"Train Mean Abs %Err [μa, μs′]: {train_pct_err}")
            print(f"Val   Loss (log-MSE): {val_loss:.6f} | Val   RMSE [μa, μs′]: {val_rmse}")
            print(f"Val   Mean Abs %Err [μa, μs′]: {val_pct_err}")

        # Early stopping + checkpoint
        if val_loss < (best_val_loss - min_delta):
            best_val_loss = val_loss
            best_epoch = epoch + 1
            epochs_no_improve = 0
            if save_path is not None:
                os.makedirs(os.path.dirname(save_path) or ".", exist_ok=True)
                torch.save(model.state_dict(), save_path)
                if should_print:
                    print(" -> Best validation so far, saved.")
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print(f"\nEarly stopping at epoch {epoch + 1}: "
                      f"no val improvement for {patience} epochs.")
                break

    # -------------------------
    # Plot loss curves
    # -------------------------
    if plot_path is None and save_path is not None:
        # If caller passed .pt or .pth, this handles both nicely
        root, _ext = os.path.splitext(save_path)
        plot_path = root + "_loss_curves.png"

    if plot_path is not None:
        os.makedirs(os.path.dirname(plot_path) or ".", exist_ok=True)
        plt.figure()
        plt.plot(train_losses, label="Train log-MSE")
        plt.plot(val_losses,   label="Val log-MSE")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training and Validation Loss")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(plot_path, dpi=150)
        plt.close()

    return {
        "model": model,
        "train_losses": train_losses,
        "val_losses": val_losses,
        "best_val_loss": best_val_loss,
        "best_epoch": best_epoch,
        "save_path": save_path,
        "loss_plot": plot_path,
    }




In [7]:
def make_train_val_loaders(dataset, batch_size=32, val_frac=0.2, seed=42, shuffle_train=True):
    """
    Split `dataset` into reproducible train/validation DataLoaders.

    Args:
        dataset: any indexable dataset with __len__.
        batch_size: batch size for both loaders.
        val_frac: fraction of samples held out for validation, in [0, 1).
        seed: seeds both the split permutation and the training shuffle order,
            so two calls with the same seed yield identical batches. None leaves
            both unseeded.
        shuffle_train: whether the training loader reshuffles every epoch.

    Returns:
        (train_loader, val_loader); the validation loader is never shuffled.

    Raises:
        ValueError: if val_frac is outside [0, 1) or the training split is empty.
    """
    if not 0.0 <= val_frac < 1.0:
        raise ValueError(f"val_frac must be in [0, 1), got {val_frac}")

    n = len(dataset)
    order = np.arange(n)
    np.random.default_rng(seed).shuffle(order)

    split = int(n * (1 - val_frac))
    if split == 0:
        raise ValueError(f"Training split is empty (n={n}, val_frac={val_frac})")

    # Plain Python ints keep Subset indexing independent of numpy scalar types.
    train_ds = Subset(dataset, order[:split].tolist())
    val_ds = Subset(dataset, order[split:].tolist())

    # A dedicated generator makes the shuffle order depend on `seed` only,
    # not on whatever the global torch RNG state happens to be.
    shuffle_gen = None
    if shuffle_train and seed is not None:
        shuffle_gen = torch.Generator().manual_seed(seed)

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=shuffle_train, generator=shuffle_gen
    )
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader

In [8]:
# Trial run: build one config, then sanity-check the dataset, the model and the loaders.

matlab_path = "/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/Year 3/Research Project in Biomedical Engineering/Code/Pre-obtained data/dataset_homo_small.mat"

if __name__ == "__main__":
    cfg = {
        "crop_t_max": 6.0,
        "sg_window": frame_length,
        "sg_order": order,
        "eps": 1e-12,
        "channel_mode": "hybrid_4ch",  # "single" | "early_mid_late" | "hybrid_4ch"
        "input_rep": "raw_log",        # "raw" | "log" | "raw_log"

        # ---- OPTIONAL DILATION EXPERIMENT ----
        "use_dilation": True,          # <-- set False for baseline comparison
        "kernels": [3, 5, 5],          # increasing kernels but smaller final kernel
        "dilations": [1, 2, 4],        # increasing dilation
        "channels": [32, 32, 16],      # optional: keep your original channels
        "pool_k": 2,
        "pool_s": 2,
    }

    print("Dilation enabled:", cfg["use_dilation"])
    if cfg["use_dilation"]:
        print("kernels:", cfg["kernels"], "dilations:", cfg["dilations"])

    # Dataset sanity check

    ds = DTOFDataset(matlab_path, cfg)
    x, y = ds[0]
    print("sample:", x.shape, y.shape)

    # Model sanity check
    # Run the single-sample forward in eval mode without autograd: in train mode
    # BatchNorm would update its running statistics from this one sample.

    model = Net(cfg, input_length=ds.T, output_dim=2).to(device)
    x = x.to(device)
    model.eval()
    with torch.no_grad():
        out = model(x.unsqueeze(0))
    model.train()  # leave the model in its default mode for any later use
    print("model out:", out.shape)

    # Loader sanity check

    train_loader, val_loader = make_train_val_loaders(ds, batch_size=32, val_frac=0.2, seed=42)

    signals, labels = next(iter(train_loader))
    print("signals:", signals.shape)  # (B, C, T)
    print("labels:", labels.shape)    # (B, 2)


Dilation enabled: True
kernels: [3, 5, 5] dilations: [1, 2, 4]
sample: torch.Size([8, 3000]) torch.Size([2])
model out: torch.Size([1, 2])
signals: torch.Size([32, 8, 3000])
labels: torch.Size([32, 2])


In [9]:
class ModelEvaluator:
    """
    Evaluates a trained model that outputs log(mua), log(mus).
    Reports MAE/RMSE in original units by exp().

    Args:
        model: network returning log-space predictions of shape (B, 2).
        device: torch device used for inference.
        eps: stored on the instance; not used by `evaluate` itself.
        exp_clip: log-space predictions are clamped to [-exp_clip, exp_clip]
            before exp() so a wild prediction cannot overflow to inf. Pass None
            to disable clamping.
    """

    def __init__(self, model, device, eps=1e-12, exp_clip=20.0):
        self.model = model
        self.device = device
        self.eps = eps
        self.exp_clip = exp_clip
        self.model.to(device)
        self.model.eval()

    def evaluate(self, data_loader):
        """
        Run the model over `data_loader` and compute linear-unit error metrics.

        Returns:
            dict with "MAE" and "RMSE" (one value per output), plus the
            concatenated "preds" and "labels" as numpy arrays.
        """
        # Re-assert eval mode: the model may have been switched back to train
        # mode (e.g. by a training loop) after this evaluator was constructed.
        self.model.eval()

        all_preds_lin = []
        all_labels_lin = []

        with torch.no_grad():
            for signals, labels in data_loader:
                signals = signals.to(self.device)
                labels = labels.to(self.device).float()              # (B,2) linear

                preds_log = self.model(signals)                      # (B,2) log
                preds_log = preds_log.view_as(labels)

                if self.exp_clip is not None:
                    preds_log = preds_log.clamp(-self.exp_clip, self.exp_clip)
                preds_lin = torch.exp(preds_log)                     # back to linear

                all_preds_lin.append(preds_lin.cpu())
                all_labels_lin.append(labels.cpu())

        preds = torch.cat(all_preds_lin, dim=0)
        labs  = torch.cat(all_labels_lin, dim=0)

        err = preds - labs

        mae = err.abs().mean(dim=0)
        rmse = torch.sqrt((err ** 2).mean(dim=0))

        return {
            "MAE": mae.numpy(),
            "RMSE": rmse.numpy(),
            "preds": preds.numpy(),
            "labels": labs.numpy(),
        }


In [None]:
# ============================
# RUN SCRIPT (baseline vs dilated) + auto epoch scheduling + early stopping
# ============================

# ----------------------------
# Helpers
# ----------------------------
def make_train_val_loaders(dataset, batch_size=32, val_frac=0.2, seed=42, shuffle_train=True):
    """
    Split `dataset` into reproducible train/validation DataLoaders.

    NOTE(review): a function with this name is also defined in an earlier cell;
    this definition shadows it once the cell runs, so the signature is kept
    identical (including `shuffle_train`).

    Args:
        dataset: any indexable dataset with __len__.
        batch_size: batch size for both loaders.
        val_frac: fraction of samples held out for validation, in [0, 1).
        seed: seeds both the split permutation and the training shuffle order.
            None leaves both unseeded.
        shuffle_train: whether the training loader reshuffles every epoch.

    Returns:
        (train_loader, val_loader); the validation loader is never shuffled.

    Raises:
        ValueError: if val_frac is outside [0, 1) or the training split is empty.
    """
    if not 0.0 <= val_frac < 1.0:
        raise ValueError(f"val_frac must be in [0, 1), got {val_frac}")

    n = len(dataset)
    order = np.arange(n)
    np.random.default_rng(seed).shuffle(order)

    split = int(n * (1 - val_frac))
    if split == 0:
        raise ValueError(f"Training split is empty (n={n}, val_frac={val_frac})")

    train_ds = Subset(dataset, order[:split].tolist())
    val_ds = Subset(dataset, order[split:].tolist())

    # Dedicated generator: the shuffle order depends on `seed` only, not on the
    # global torch RNG state at the time the loader is iterated.
    shuffle_gen = None
    if shuffle_train and seed is not None:
        shuffle_gen = torch.Generator().manual_seed(seed)

    train_loader = DataLoader(
        train_ds, batch_size=batch_size, shuffle=shuffle_train, generator=shuffle_gen
    )
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader


def quick_eval_lin_rmse(model, loader, device, eps=1e-12, exp_clip=20.0):
    """
    RMSE in original (linear) units for a model that predicts in log space.

    The model is switched to eval mode and left there. Log-space predictions
    are clamped to [-exp_clip, exp_clip] before exponentiation. Returns a numpy
    array with one RMSE per output (two outputs are assumed).
    """
    model.eval()

    squared_error_total = torch.zeros(2, device=device)
    samples_seen = 0

    with torch.no_grad():
        for batch_signals, batch_targets in loader:
            targets = batch_targets.to(device).float()
            log_targets = targets.clamp_min(eps).log()   # only its shape is used below

            log_preds = model(batch_signals.to(device)).view_as(log_targets)
            lin_preds = log_preds.clamp(-exp_clip, exp_clip).exp()

            residual = lin_preds - targets
            squared_error_total += residual.pow(2).sum(dim=0)
            samples_seen += targets.shape[0]

    mean_squared_error = squared_error_total / max(1, samples_seen)
    return mean_squared_error.sqrt().detach().cpu().numpy()


# ----------------------------
# 1) Paths + base config
# ----------------------------
# NOTE(review): absolute, machine-specific path; the same file is also referenced
# as `matlab_path` in the trial-run cell.
mat_path = r"/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/Year 3/Research Project in Biomedical Engineering/Code/Pre-obtained data/dataset_homo_small.mat"

# Shared configuration: preprocessing keys are read by DTOFDataset, architecture
# keys by Net. Both experiment configs below are copies of this dict.
# NOTE(review): sg_window is 31 here but the trial-run cell uses frame_length (21)
# — confirm which value is intended.
base_cfg = {
    "crop_t_max": 6.0,
    "sg_window": 31,
    "sg_order": 1,
    "eps": 1e-12,
    "channel_mode": "hybrid_4ch",
    "input_rep": "raw_log",

    # ---- DILATION EXPERIMENT DEFAULTS (OFF unless enabled below) ----
    "use_dilation": False,
    "kernels": [3, 5, 5],      # increasing kernels, smaller final kernel
    "dilations": [1, 2, 4],    # increasing dilation (only used if use_dilation=True)
    "channels": [32, 32, 16],
    "pool_k": 2,
    "pool_s": 2,
}

# ----------------------------
# 2) Device
# ----------------------------
# Rebinds the module-level `device` defined in an earlier cell (same expression).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ----------------------------
# 3) Dataset + split (DO THIS ONCE so both models share identical split)
# ----------------------------
# The dataset is built from base_cfg only; the two experiment configs differ
# solely in "use_dilation", which DTOFDataset does not read.
dataset = DTOFDataset(mat_path, base_cfg)

batch_size = 32
val_frac = 0.2
seed = 42

train_loader, val_loader = make_train_val_loaders(dataset, batch_size=batch_size, val_frac=val_frac, seed=seed)

N_total = len(dataset)
N_train = len(train_loader.dataset)
N_val = len(val_loader.dataset)

print("Total samples:", N_total)
print("Train samples:", N_train)
print("Val samples  :", N_val)
print("Signal shape :", dataset[0][0].shape, "Label shape:", dataset[0][1].shape)

# ----------------------------
# 4) Epoch scheduling (scaled by dataset size via target optimiser steps)
# ----------------------------
# num_epochs is chosen so that roughly `target_steps` optimiser updates happen:
# one update per batch, ceil(N_train / batch_size) batches per epoch.
target_steps = 50_000  # adjust later
steps_per_epoch = math.ceil(N_train / batch_size)
num_epochs = math.ceil(target_steps / steps_per_epoch)

# Optional cap (if you want a hard max like 800)
# With a small training set the cap, not target_steps, decides the epoch count.
num_epochs = min(num_epochs, 800)

print(f"steps_per_epoch: {steps_per_epoch}")
print(f"target_steps   : {target_steps}")
print(f"num_epochs     : {num_epochs}")

# ----------------------------
# 5) Run function (so we can compare baseline vs dilated cleanly)
# ----------------------------
def run_experiment(cfg, tag, seed=42, save_dir=None):
    """
    Train one model configuration on the shared split and report its RMSE.

    Relies on the module-level `dataset`, `train_loader`, `val_loader`,
    `num_epochs` and `device` prepared earlier in this cell.

    Args:
        cfg: configuration dict passed to Net; cfg["eps"] is also used for the metrics.
        tag: short experiment name, used in the checkpoint / plot file names.
        seed: torch RNG seed applied before the model is built, so every
            experiment starts from the same random state. None skips seeding.
        save_dir: directory for the checkpoint and loss plot; defaults to the
            project's saved-weights folder.

    Returns:
        dict with "tag", "rmse_train", "rmse_val" (linear units, computed on the
        best-validation checkpoint when one was saved) and "save_path".
    """
    print("\n" + "=" * 70)
    print(f"EXPERIMENT: {tag}")
    print(f"use_dilation={cfg.get('use_dilation', False)} | kernels={cfg.get('kernels')} | dilations={cfg.get('dilations')}")
    print("=" * 70)

    # Same seed for every experiment: differences between runs then come from
    # the configuration, not from a different weight initialisation.
    if seed is not None:
        torch.manual_seed(seed)

    model = Net(cfg, input_length=dataset.T, output_dim=2).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    # Pre-train sanity RMSE
    rmse0_train = quick_eval_lin_rmse(model, train_loader, device, eps=cfg["eps"])
    rmse0_val = quick_eval_lin_rmse(model, val_loader, device, eps=cfg["eps"])
    print("Pre-train RMSE train [mua, mus]:", rmse0_train)
    print("Pre-train RMSE val   [mua, mus]:", rmse0_val)

    if save_dir is None:
        save_dir = "/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/Year 3/Research Project in Biomedical Engineering/Code/CNN_initial_saved_pytorch_model_weights"
    save_path = os.path.join(save_dir, f"best_{tag}.pt")

    # splitext only touches the final extension; str.replace would rewrite any
    # ".pt" occurring elsewhere in the path.
    plot_path = os.path.splitext(save_path)[0] + "_loss_curves.png"

    train_out = train_model(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        num_epochs=num_epochs,
        device=device,
        save_path=save_path,
        plot_path=plot_path,
        eps=cfg["eps"],
        print_every=1,
        patience=20,
        min_delta=1e-4,
    )

    print("Saved loss curves to:", train_out["loss_plot"])

    # train_model leaves the model at its last epoch, which after early stopping
    # is `patience` epochs past the best one. Reload the best checkpoint so the
    # reported RMSE matches the weights stored at save_path. A non-finite best
    # loss means nothing was saved in this run (a stale file may still exist).
    if math.isfinite(train_out["best_val_loss"]) and os.path.isfile(save_path):
        model.load_state_dict(torch.load(save_path, map_location=device))

    # Post-train RMSE
    rmse1_train = quick_eval_lin_rmse(model, train_loader, device, eps=cfg["eps"])
    rmse1_val = quick_eval_lin_rmse(model, val_loader, device, eps=cfg["eps"])
    print("Post-train RMSE train [mua, mus]:", rmse1_train)
    print("Post-train RMSE val   [mua, mus]:", rmse1_val)

    # One-batch forward shape check (no autograd graph needed for a shape check)
    signals, labels = next(iter(train_loader))
    with torch.no_grad():
        out = model(signals.to(device))
    print("One-batch shapes:")
    print("  signals:", signals.shape)
    print("  labels :", labels.shape)
    print("  output :", out.shape)

    return {
        "tag": tag,
        "rmse_train": rmse1_train,
        "rmse_val": rmse1_val,
        "save_path": save_path,
    }


# ----------------------------
# 6) Baseline vs Dilated configs
# ----------------------------
# dict(base_cfg) is a shallow copy: the list values ("kernels", "dilations",
# "channels") are shared with base_cfg, so assign new lists rather than
# mutating them in place.
cfg_baseline = dict(base_cfg)
cfg_baseline["use_dilation"] = False

cfg_dilated = dict(base_cfg)
cfg_dilated["use_dilation"] = True
# kernels/dilations already present in base_cfg; tweak here if desired:
# cfg_dilated["kernels"] = [3, 5, 5]
# cfg_dilated["dilations"] = [1, 2, 4]

# ----------------------------
# 7) Run both experiments
# ----------------------------
# Both runs reuse the same train/val loaders created above; each call builds
# and trains its own model.
res_base = run_experiment(cfg_baseline, tag="baseline_no_dilation")
res_dil  = run_experiment(cfg_dilated,  tag="dilated_1_2_4")

# Side-by-side validation RMSE (linear units) for the two configurations.
print("\n" + "=" * 70)
print("SUMMARY (Post-train RMSE)")
print("=" * 70)
print("Baseline RMSE val [mua, mus]:", res_base["rmse_val"])
print("Dilated  RMSE val [mua, mus]:", res_dil["rmse_val"])
print("=" * 70)

Using device: cpu
Total samples: 500
Train samples: 400
Val samples  : 100
Signal shape : torch.Size([8, 3000]) Label shape: torch.Size([2])
steps_per_epoch: 13
target_steps   : 50000
num_epochs     : 800

EXPERIMENT: baseline_no_dilation
use_dilation=False | kernels=[3, 5, 5] | dilations=[1, 2, 4]
Pre-train RMSE train [mua, mus]: [ 4.6593733 13.841332 ]
Pre-train RMSE val   [mua, mus]: [ 3.85406  13.790935]

Epoch 1/800
Train Loss (log-MSE): 2.751207 | Train RMSE [μa, μs′]: [ 0.3076432 10.961098 ]
Train Mean Abs %Err [μa, μs′]: [1236.8676   104.04612]
Val   Loss (log-MSE): 3.564036 | Val   RMSE [μa, μs′]: [ 0.17821103 12.538977  ]
Val   Mean Abs %Err [μa, μs′]: [953.43677  63.33407]
 -> Best validation so far, saved.

Epoch 2/800
Train Loss (log-MSE): 0.388672 | Train RMSE [μa, μs′]: [0.02745555 6.278013  ]
Train Mean Abs %Err [μa, μs′]: [64.05608  44.462616]
Val   Loss (log-MSE): 0.671356 | Val   RMSE [μa, μs′]: [0.03568224 8.792567  ]
Val   Mean Abs %Err [μa, μs′]: [158.07097   38.2