In [None]:
import os 
import re
import random 
import torch 
from torch import nn 
import numpy as np 
import matplotlib.pyplot as plt 
from torch.utils.data import Dataset, DataLoader,Subset
from scipy.signal import savgol_filter 
import h5py
random.seed(0)

In [None]:
# Define Savitzky Golay filter parameters 
order = 1
frame_length = 21
eps = 1e-8

In [34]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [24]:
class DTOFDataset(Dataset):
    """
        DTOF dataset for MATLAB v7.3 (.mat, HDF5) files.

        Required datasets inside .mat:
            X : (Nt, N) or (N, Nt)  reflectance DTOFs
            y : (2, N)  or (N, 2)   [mua, mus']
            t : (Nt,)              time vector in seconds (~1e-12 resolution)

        Preprocessing:
            - convert t from seconds -> ns
            - crop [0, crop_t_max] ns
            - Savitzky–Golay smoothing
            - clip small values to eps
            - input representation:
                * "raw"      -> use reflectance (smoothed+clipped)
                * "log"      -> use log(reflectance)
                * "raw_log"  -> concatenate raw and log along channel dimension
            - channel construction:
                * "single"         -> 1 channel (full)
                * "early_mid_late" -> 3 channels (early/mid/late masks)
                * "hybrid_4ch"     -> 4 channels (full + early/mid/late masks)

        Returns:
            signal: (C, T) float32 tensor
            label:  (2,) float32 tensor [mua, mus']  (raw labels for now)
        """

    def __init__(self, mat_path: str, cfg: dict):
        super().__init__()
        self.cfg = cfg

        # ---------- load HDF5 (.mat v7.3) ----------
        with h5py.File(mat_path, "r") as f:
            X = np.array(f["X"], dtype=np.float32)
            y = np.array(f["y"], dtype=np.float32)
            t = np.array(f["t"], dtype=np.float32).squeeze()

        # ---------- normalise shapes ----------
        # X -> (N, Nt)
        if X.shape[0] == t.shape[0]:
            X = X.T

        # y -> (N, 2)
        if y.shape[0] == 2:
            y = y.T

        if X.ndim != 2:
            raise ValueError(f"Expected X to be 2D, got {X.shape}")
        if y.ndim != 2 or y.shape[1] != 2:
            raise ValueError(f"Expected y to be (N,2), got {y.shape}")
        if t.ndim != 1:
            raise ValueError(f"Expected t to be (Nt,), got {t.shape}")
        if X.shape[1] != t.shape[0]:
            raise ValueError(f"X and t mismatch: X Nt={X.shape[1]} vs t Nt={t.shape[0]}")

        # ---------- time: seconds -> ns ----------
        t_ns = t * 1e9

        # ---------- crop ----------
        crop_t_max = float(cfg["crop_t_max"])  # ns
        t_mask = (t_ns >= 0.0) & (t_ns <= crop_t_max)
        if not np.any(t_mask):
            raise ValueError(
                f"Cropping removed all points. "
                f"t_ns range=[{t_ns.min():.3g}, {t_ns.max():.3g}] ns, crop_t_max={crop_t_max}"
            )

        t_ns = t_ns[t_mask]              # (T,)
        dtof = X[:, t_mask]              # (N,T)

        N, T = dtof.shape

        # ---------- Savitzky–Golay ----------
        sg_window = int(cfg["sg_window"])
        sg_order = int(cfg["sg_order"])

        # enforce validity (odd, > order, <= T)
        if sg_window % 2 == 0:
            sg_window += 1
        if sg_window <= sg_order:
            sg_window = sg_order + 2
            if sg_window % 2 == 0:
                sg_window += 1
        if sg_window > T:
            sg_window = T if (T % 2 == 1) else (T - 1)

        if sg_window >= 3:
            dtof = savgol_filter(dtof, sg_window, sg_order, axis=1)

        # ---------- clip ----------
        eps = float(cfg.get("eps", 1e-12))
        dtof[dtof < eps] = eps

        # ---------- choose representation ----------
        input_rep = cfg.get("input_rep", "log")  # "raw" | "log" | "raw_log"

        dtof_raw = dtof.astype(np.float32)
        dtof_log = np.log(dtof_raw).astype(np.float32)

        # ---------- build channels ----------
        if input_rep == "raw":
            channels = self.build_channels(t_ns, dtof_raw, cfg["channel_mode"])  # (N,C,T)

        elif input_rep == "log":
            channels = self.build_channels(t_ns, dtof_log, cfg["channel_mode"])  # (N,C,T)

        elif input_rep == "raw_log":
            ch_raw = self.build_channels(t_ns, dtof_raw, cfg["channel_mode"])    # (N,C,T)
            ch_log = self.build_channels(t_ns, dtof_log, cfg["channel_mode"])    # (N,C,T)
            channels = np.concatenate([ch_raw, ch_log], axis=1)                  # (N,2C,T)

        else:
            raise ValueError(f"Unknown input_rep: {input_rep}")

        # ---------- to torch ----------
        self.signals = torch.tensor(channels, dtype=torch.float32)  # (N,C,T)
        self.labels = torch.tensor(y, dtype=torch.float32)          # (N,2)

        self.N, self.C, self.T = self.signals.shape

    def build_channels(self, t_ns: np.ndarray, dtof: np.ndarray, mode: str) -> np.ndarray:
        """
        Channel masks in ns:
            early: 0–0.5 ns
            mid:   0.5–4 ns
            late:  4–crop_t_max ns
        """
        N, T = dtof.shape
        crop_t_max = float(self.cfg["crop_t_max"])

        if mode == "single":
            return dtof[:, None, :]  # (N,1,T)

        early = ((t_ns >= 0.0) & (t_ns < 0.5)).astype(np.float32)
        mid   = ((t_ns >= 0.5) & (t_ns < 4.0)).astype(np.float32)
        late  = ((t_ns >= 4.0) & (t_ns <= crop_t_max)).astype(np.float32)

        masks = np.stack([early, mid, late], axis=0)  # (3,T)

        if mode == "early_mid_late":
            return dtof[:, None, :] * masks[None, :, :]  # (N,3,T)

        if mode == "hybrid_4ch":
            full = dtof[:, None, :]                         # (N,1,T)
            gated = dtof[:, None, :] * masks[None, :, :]    # (N,3,T)
            return np.concatenate([full, gated], axis=1)    # (N,4,T)

        raise ValueError(f"Unknown channel_mode: {mode}")

    def __len__(self) -> int:
        return self.N

    def __getitem__(self, idx: int):
        return self.signals[idx], self.labels[idx]

In [25]:
class Net(nn.Module):
    """
    CNN for 1D DTOF signals with flexible input channels.

    Channel counts (C) depend on:
        channel_mode:
            - "single"         -> 1
            - "early_mid_late" -> 3
            - "hybrid_4ch"     -> 4
        input_rep:
            - "raw" / "log"    -> multiplier 1
            - "raw_log"        -> multiplier 2

    So:
        C = base_C * (2 if input_rep == "raw_log" else 1)
    """

    def __init__(self, cfg: dict, input_length: int = 3000, output_dim: int = 2):
        super().__init__()

        base_C = {"single": 1, "early_mid_late": 3, "hybrid_4ch": 4}[cfg["channel_mode"]]
        mult = 2 if cfg.get("input_rep", "log") == "raw_log" else 1
        in_channels = base_C * mult

        # Convolution blocks
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=32, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(32)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(32)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(16)
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.act = nn.ReLU()

        # Compute flattened feature size dynamically
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, input_length)
            feat = self._forward_features(dummy)
            self.flatten_dim = feat.shape[1]

        # Fully connected layers
        self.fc1 = nn.Linear(self.flatten_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_dim)

    def _forward_features(self, x):
        # Block 1
        x = self.pool1(self.act(self.bn1(self.conv1(x))))
        # Block 2
        x = self.pool2(self.act(self.bn2(self.conv2(x))))
        # Block 3
        x = self.pool3(self.act(self.bn3(self.conv3(x))))
        # Flatten
        return x.view(x.size(0), -1)

    def forward(self, x):
        x = self._forward_features(x)
        x = self.act(self.fc1(x))
        x = self.act(self.fc2(x))
        return self.fc3(x)

In [41]:
def train_model(
    model,
    train_loader,
    val_loader,
    optimizer,
    num_epochs,
    device,
    save_path=None,
    eps=1e-12,
    print_every=1,
    exp_clip=20.0,   # exp(20) ~ 4.85e8, safety for overflow
):
    model.to(device)
    best_val_loss = float("inf")
    loss_fn = torch.nn.MSELoss()

    for epoch in range(num_epochs):
        # -------------------------
        # TRAIN
        # -------------------------
        model.train()
        running_loss = 0.0

        tr_sum_sq = torch.zeros(2, device=device)
        tr_count = 0

        for signals, labels in train_loader:
            signals = signals.to(device)
            labels = labels.to(device).float()                 # (B,2)

            log_labels = torch.log(labels.clamp_min(eps))      # (B,2)

            optimizer.zero_grad()

            preds_log = model(signals).view_as(log_labels)     # (B,2)
            loss = loss_fn(preds_log, log_labels)
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # RMSE in original units
            with torch.no_grad():
                preds_lin = torch.exp(preds_log.clamp(-exp_clip, exp_clip))
                err = preds_lin - labels
                tr_sum_sq += (err ** 2).sum(dim=0)
                tr_count += labels.shape[0]

        train_loss = running_loss / max(1, len(train_loader))
        train_rmse = torch.sqrt(tr_sum_sq / max(1, tr_count)).detach().cpu().numpy()

        # -------------------------
        # VALIDATE
        # -------------------------
        model.eval()
        val_running_loss = 0.0

        va_sum_sq = torch.zeros(2, device=device)
        va_count = 0

        with torch.no_grad():
            for signals, labels in val_loader:
                signals = signals.to(device)
                labels = labels.to(device).float()

                log_labels = torch.log(labels.clamp_min(eps))
                preds_log = model(signals).view_as(log_labels)

                loss = loss_fn(preds_log, log_labels)
                val_running_loss += loss.item()

                preds_lin = torch.exp(preds_log.clamp(-exp_clip, exp_clip))
                err = preds_lin - labels
                va_sum_sq += (err ** 2).sum(dim=0)
                va_count += labels.shape[0]

        val_loss = val_running_loss / max(1, len(val_loader))
        val_rmse = torch.sqrt(va_sum_sq / max(1, va_count)).detach().cpu().numpy()

        if (epoch + 1) % print_every == 0:
            print(f"\nEpoch {epoch + 1}/{num_epochs}")
            print(f"Train Loss (log-MSE): {train_loss:.6f} | Train RMSE [mua, mus]: {train_rmse}")
            print(f"Val   Loss (log-MSE): {val_loss:.6f} | Val   RMSE [mua, mus]: {val_rmse}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            if save_path is not None:
                torch.save(model.state_dict(), save_path)
                if (epoch + 1) % print_every == 0:
                    print(" -> Best validation so far, saved.")

    return model


In [42]:
def make_train_val_loaders(dataset, batch_size=32, val_frac=0.2, seed=42, shuffle_train=True):
    n = len(dataset)
    idx = np.arange(n)
    rng = np.random.default_rng(seed)
    rng.shuffle(idx)

    split = int(n * (1 - val_frac))
    train_idx = idx[:split]
    val_idx = idx[split:]

    train_ds = Subset(dataset, train_idx)
    val_ds = Subset(dataset, val_idx)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=shuffle_train)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)

    return train_loader, val_loader

In [43]:
# Trial run

matlab_path = "/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/Year 3/Research Project in Biomedical Engineering/Code/Pre-obtained data/dataset_homo_small.mat"

if __name__ == "__main__":
    cfg = {
        "crop_t_max": 6.0,
        "sg_window": frame_length,
        "sg_order": order,
        "eps": 1e-12,
        "channel_mode": "hybrid_4ch",  # "single" | "early_mid_late" | "hybrid_4ch"
        "input_rep": "raw_log",        # "raw" | "log" | "raw_log"
    }

    ds = DTOFDataset(matlab_path, cfg)
    x, y = ds[0]
    print("sample:", x.shape, y.shape)

    model = Net(cfg, input_length=ds.T, output_dim=2).to(device)
    out = model(x.unsqueeze(0))
    print("model out:", out.shape)

    train_loader, val_loader = make_train_val_loaders(ds, batch_size=32, val_frac=0.2, seed=42)

    signals, labels = next(iter(train_loader))
    print("signals:", signals.shape)  # (B, C, T)
    print("labels:", labels.shape)    # (B, 2)


sample: torch.Size([8, 3000]) torch.Size([2])
model out: torch.Size([1, 2])
signals: torch.Size([32, 8, 3000])
labels: torch.Size([32, 2])


In [44]:
class ModelEvaluator: 
    """
    Evaluates a trained model that outputs log(mua), log(mus).
    Reports MAE/RMSE in original units by exp().
    """

    def __init__(self, model, device, eps=1e-12):
        self.model = model
        self.device = device
        self.eps = eps
        self.model.to(device)
        self.model.eval()

    def evaluate(self, data_loader):
        all_preds_lin = []
        all_labels_lin = []

        with torch.no_grad():
            for signals, labels in data_loader:
                signals = signals.to(self.device)
                labels = labels.to(self.device).float()              # (B,2) linear

                preds_log = self.model(signals)                      # (B,2) log
                preds_log = preds_log.view_as(labels)

                preds_lin = torch.exp(preds_log)                     # back to linear

                all_preds_lin.append(preds_lin.cpu())
                all_labels_lin.append(labels.cpu())

        preds = torch.cat(all_preds_lin, dim=0)
        labs  = torch.cat(all_labels_lin, dim=0)

        abs_err = torch.abs(preds - labs)
        sq_err  = (preds - labs) ** 2

        mae = abs_err.mean(dim=0)
        rmse = torch.sqrt(sq_err.mean(dim=0))

        return {
            "MAE": mae.numpy(),
            "RMSE": rmse.numpy(),
            "preds": preds.numpy(),
            "labels": labs.numpy(),
        }


In [30]:
frame_length =21 
order = 1


In [None]:
# ============================
# TRIAL RUN
# dataset -> split -> 1-3 epochs -> metrics sanity checks
# ============================

# Assumes these are already defined in your script:
# - DTOFDataset (h5py loader, channel_mode + input_rep)
# - Net (in_channels inferred from cfg)
# - train_model (your log-label training loop)


def make_train_val_loaders(dataset, batch_size=32, val_frac=0.2, seed=42):
    n = len(dataset)
    idx = np.arange(n)
    rng = np.random.default_rng(seed)
    rng.shuffle(idx)

    split = int(n * (1 - val_frac))
    train_idx = idx[:split]
    val_idx = idx[split:]

    train_ds = Subset(dataset, train_idx)
    val_ds = Subset(dataset, val_idx)

    train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False)
    return train_loader, val_loader


def quick_eval_lin_rmse(model, loader, device, eps=1e-12, exp_clip=20.0):
    """Compute RMSE in original units on a loader (model outputs log-space)."""
    model.eval()
    sum_sq = torch.zeros(2, device=device)
    n = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device).float()  # linear labels
            ylog = torch.log(y.clamp_min(eps))
            plog = model(x).view_as(ylog)
            plin = torch.exp(plog.clamp(-exp_clip, exp_clip))
            err = plin - y
            sum_sq += (err ** 2).sum(dim=0)
            n += y.shape[0]
    return torch.sqrt(sum_sq / max(1, n)).detach().cpu().numpy()


# ----------------------------
# 1) Paths + config
# ----------------------------
mat_path = r"/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/Year 3/Research Project in Biomedical Engineering/Code/Pre-obtained data/dataset_homo_small.mat"

cfg = {
    "crop_t_max": 6.0,            # ns
    "sg_window": 21,
    "sg_order": 1,
    "eps": 1e-12,
    "channel_mode": "hybrid_4ch", # "single" | "early_mid_late" | "hybrid_4ch"
    "input_rep": "raw_log",       # "raw" | "log" | "raw_log"
}

# ----------------------------
# 2) Device
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# ----------------------------
# 3) Dataset + split
# ----------------------------
dataset = DTOFDataset(mat_path, cfg)
train_loader, val_loader = make_train_val_loaders(dataset, batch_size=32, val_frac=0.2, seed=42)

print("Total samples:", len(dataset))
print("Train samples:", len(train_loader.dataset))
print("Val samples  :", len(val_loader.dataset))
print("Signal shape :", dataset[0][0].shape, "Label shape:", dataset[0][1].shape)

# ----------------------------
# 4) Model + optimizer
# ----------------------------
model = Net(cfg, input_length=dataset.T, output_dim=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# ----------------------------
# 5) Pre-train sanity RMSE (random model)
# ----------------------------
rmse0_train = quick_eval_lin_rmse(model, train_loader, device, eps=cfg["eps"])
rmse0_val = quick_eval_lin_rmse(model, val_loader, device, eps=cfg["eps"])
print("Pre-train RMSE train [mua, mus]:", rmse0_train)
print("Pre-train RMSE val   [mua, mus]:", rmse0_val)

# ----------------------------
# 6) Trial training (few epochs)
# ----------------------------
_ = train_model(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    num_epochs=3,          # TRIAL RUN
    device=device,
    save_path="/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/Year 3/Research Project in Biomedical Engineering/Code/CNN_initial_saved_pytorch_model_weights/best_trial_TRIAL.pt",
    eps=cfg["eps"],
    print_every=1,
)

# ----------------------------
# 7) Post-train sanity RMSE
# ----------------------------
rmse1_train = quick_eval_lin_rmse(model, train_loader, device, eps=cfg["eps"])
rmse1_val = quick_eval_lin_rmse(model, val_loader, device, eps=cfg["eps"])
print("Post-train RMSE train [mua, mus]:", rmse1_train)
print("Post-train RMSE val   [mua, mus]:", rmse1_val)

# ----------------------------
# 8) One-batch forward shape check
# ----------------------------
signals, labels = next(iter(train_loader))
out = model(signals.to(device))
print("One-batch:")
print("  signals:", signals.shape)   # (B,C,T)
print("  labels :", labels.shape)    # (B,2)
print("  output :", out.shape)       # (B,2)


Using device: cpu
Total samples: 500
Train samples: 400
Val samples  : 100
Signal shape : torch.Size([8, 3000]) Label shape: torch.Size([2])
Pre-train RMSE train [mua, mus]: [ 0.8705639 13.795304 ]
Pre-train RMSE val   [mua, mus]: [ 0.88881505 13.737743  ]

Epoch 1/3
Train Loss (log-MSE): 2.147514 | Train RMSE [mua, mus]: [ 0.26125002 20.974863  ]
Val   Loss (log-MSE): 3.363873 | Val   RMSE [mua, mus]: [ 0.17471175 12.410777  ]
 -> Best validation so far, saved.

Epoch 2/3
Train Loss (log-MSE): 0.292014 | Train RMSE [mua, mus]: [0.01910768 7.4966664 ]
Val   Loss (log-MSE): 0.767043 | Val   RMSE [mua, mus]: [0.05236467 9.308151  ]
 -> Best validation so far, saved.

Epoch 3/3
Train Loss (log-MSE): 0.197030 | Train RMSE [mua, mus]: [0.01426826 4.5020103 ]
Val   Loss (log-MSE): 0.517796 | Val   RMSE [mua, mus]: [0.03221161 5.2897    ]
 -> Best validation so far, saved.
Post-train RMSE train [mua, mus]: [0.03507379 5.1658125 ]
Post-train RMSE val   [mua, mus]: [0.03221161 5.2897    ]
One-b