DTOF Deep Learning Pipeline

This module implements a complete, modular DL pipeline for predicting
the optical properties (μa, μs′) of homogeneous tissue models from time resolved 
DTOF signals. The system is designed for reproducible experimentation and supports highly flexible preprocessing and neural architecture configurations 

Applicability 

This framework is intended for studies involving: 
    * Monte Carlo simulated DTOFs (homogeneous or multilayer)
    * inversion of DTOFs to estimate μa, μs′
    * benchmarking the effects of preprocessing choices such as SG filters 
    * deep learning architecture optimisation 
    * systematic comparison of multi-channel DTOF representaitons 

Supported Model Variants 

This pipeline allows dynamic selection of the number of input channels. 3 commonly used configurations are supported: 

    (1) Single channel DTOF:
        * DTOF cropped to 0-5 ns 
        * Standardised after smoothing 
        * Input shape: (1, T)

    (2) 3 channel temporal bin model: 
        * Early, mid, and late temporal masks (0 - 0.5 ns, 0.5 - 4 ns, 4 -5 ns)
        * Each bin multiplied with the DTOF to form 3 channels 
        * Input shape: (3, T)
    
    (3) 4 channel hybrid model:
        * Channel 1: Full DTOF (0 - 0.5 ns)
        * Channel 2-4: Early, mid, late temporal bins
        * Input shape: (4, T)

Channel configuration is selected from the CONFIG block and handled automatically by the DTOFDataset 

Pipeline Components 

This module consists of: 

1. CONFIG dictionary: 
    Centralises all experiment-level parameters including: 
        - preprocessing settings (SG frame / order, clipping, cropping)
        - number of input channels 
        - temporal mask definitions 
        - CNN architecture (layer widths, output dimensions)
        - learning rate, batch size, epochs 

2. DTOFDataset: 
    Implements the preprocessing chain: 
        - CSV loading 
        - Savitzky-Golay smoothing 
        - negative-value clipping 
        - mean / std standardisation 
        - crop to 0 - 5 ns
        - dynamic construction of 1, 3, or 4 input channels 
        - returns (signal, [mua, mus])

3. Net (CNN architecture): 
    * Input channels match CONFIG["in_channels"]
    * 3 Conv blocks with BatchNorm, ReLU, and pooling 
    * Automatic flatten dimension calculatino 
    * Fully connected regression head configurable via CONFIG 
    * Output dimension fixed at 2 (μa, μs′)

4. Training loop: 
    * GPU/CPU detection 
    * Train/validation loss computation 
    * best-model checkpoint saving 
    * optional printing of sample predictions 
    * automatic plotting of train/val loss curves (PNG)

5. Evaluation and logging: 
    Provides: 
        - MAE and RMSE for μa, μs′
        - prediction vs.ground truth arrays 
        - JSON logging of each run's configuration and metrics 
        - loss curve visualisation per run 
Dependencies 

Required: 
    * torch 
    * numpy
    * pandas 
    * scipy (Savitzky-Golay filtering)
    * matplotlib (loss curve plotting)

Recommended: 
    * seaborn (optional visualisation enhancements)

Inputs & Outputs 

Inputs: 
    * DTOFs_Homo_raw.csv - raw DTOF data with time in column 0
    * DTOFs_Homo_labels.csv - extracted (μa, μs′) labels 

Outputs: 
    * training and validation loss curves .png
    * MAE / RMSE metrics 
    * prediction arrays for downstream plotting 
    * JSON log for all runs

In [40]:
import os
import json
import re
from datetime import datetime

import numpy as np
import pandas as pd
from scipy.signal import savgol_filter

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

import matplotlib.pyplot as plt

In [41]:
# 1. Global configuration 

CONFIG = {
    "csv_path": "/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/Year 3/Research Project in Biomedical Engineering/Code/Pre-obtained data/DTOFs_Homo_raw.csv", 
    "label_csv_path": "/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/Year 3/Research Project in Biomedical Engineering/Code/Pre-obtained data/DTOFs_Homo_labels.csv", 

    # Preprocessing 
    "sg_window": 21, 
    "sg_order": 1, 
    "eps": 1e-8, 
    "crop_t_max": 5.0,      # crop DTOF to 0-5ns

    # Channels: "single", "early_mid_late", "hybrid_4ch"
    "channel_mode": "single",

    # Data split + loader
    "train_frac": 0.8,
    "batch_size": 32,

    # Model
    "in_channels": 1,            # will be overwritten based on dataset.C
    "output_dim": 2,             # mua + mus
    "hidden_fc": [128, 64],

    # Training
    "lr": 1e-3,
    "epochs": 20,

    # Outputs / logging
    "run_name": "exp_single_w21_o1",
    "save_dir": "runs",
    "log_path": "runs/dtof_runs_log.json",
}

In [42]:
# 2. Label extraction helper

def extract_labels_from_dtof_csv(csv_path: str, label_csv_path: str):
    """
    Extracts (mua, mus) labels from DTOF column headers in the raw CSV.

    Expects column names: "mua: 0.005  mus: 2.0"
    Writes a CSV of labels with columns ["mua", "mus"].
    """

    df = pd.read_csv(csv_path)
    dtof_columns = df.columns[1:]  # skip time column
    labels = []

    for col in dtof_columns:
            col_clean = str(col).strip()
            match = re.search(r"mua:\s*([0-9.]+)\s+mus:\s*([0-9.]+)", col_clean)
            if not match:
                raise ValueError(f"Could not parse mua/mus from column '{col}'")
            mua_val = float(match.group(1))
            mus_val = float(match.group(2))
            labels.append((mua_val, mus_val))   
    
    label_df = pd.DataFrame(labels, columns=["mua", "mus"])
    label_df.to_csv(label_csv_path, index=False)
    print(f"[INFO] Saved labels to {label_csv_path} (N={len(label_df)})")

In [43]:
# 3. DTOFDataset: preprocessing + channel construction 

class DTOFDataset(Dataset):
    """
    DTOF dataset with preprocessing and flexible channel configurations.

    Preprocessing:
        - load CSV
        - crop time axis to [0, crop_t_max]
        - Savitzky-Golay smoothing
        - negative-value clipping
        - per-trace standardisation
        - channel construction:
            * "single"        -> 1 channel, full DTOF
            * "early_mid_late"-> 3 channels (early/mid/late masks)
            * "hybrid_4ch"    -> 4 channels (full + 3 masks)

    Returns:
        signal: (C, T) tensor (C = n. of channels)
        label:  (2,) tensor [mua, mus]
    """

    def __init__(self, csv_path: str, labels: np.ndarray, cfg: dict):
        super().__init__()
        self.cfg = cfg

        df = pd.read_csv(csv_path)

        # Time and DTOFs
        time_full = df.iloc[:, 0].values               # (T_full,)
        dtof_full = df.iloc[:, 1:].values.T            # (N, T_full)
        N, T_full = dtof_full.shape

        # Crop time axis
        t_mask = (time_full >= 0.0) & (time_full <= cfg["crop_t_max"])
        time = time_full[t_mask]                       # (T,)
        dtof = dtof_full[:, t_mask]                    # (N, T)

        # Savitzky–Golay smoothing
        dtof_smooth = savgol_filter(
            dtof,
            cfg["sg_window"],
            cfg["sg_order"],
            axis=1
        )

        # Clip negatives and standardise
        eps = cfg["eps"]
        dtof_smooth[dtof_smooth < 0] = eps

        mean = dtof_smooth.mean(axis=1, keepdims=True)
        std = dtof_smooth.std(axis=1, keepdims=True)
        dtof_std = (dtof_smooth - mean) / (std + eps)  # (N, T)

        # Build channels
        channels = self.build_channels(time, dtof_std, cfg["channel_mode"])
        # channels: (N, C, T)

        self.signals = torch.tensor(channels, dtype=torch.float32)  # (N,C,T)
        self.labels = torch.tensor(labels, dtype=torch.float32)     # (N,2)

        self.N, self.C, self.T = self.signals.shape

    def build_channels(self, t: np.ndarray, dtof: np.ndarray, mode: str) -> np.ndarray:
        """
        Construct channels based on the chosen mode:
            "single"         -> 1 channel, full DTOF
            "early_mid_late" -> 3 masked channels
            "hybrid_4ch"     -> 1 full + 3 masked = 4 channels
        """
        N, T = dtof.shape

        if mode == "single":
            # (N, 1, T)
            return dtof[:, None, :]

        # Define early/mid/late masks within cropped time
        early = ((t >= 0.0) & (t < 0.5)).astype(float)
        mid = ((t >= 0.5) & (t < 4.0)).astype(float)
        late = ((t >= 4.0) & (t <= self.cfg["crop_t_max"])).astype(float)
        masks = np.stack([early, mid, late], axis=0)  # (3, T)

        if mode == "early_mid_late":
            # Multiply each DTOF by masks to get 3 channels
            out = dtof[:, None, :] * masks[None, :, :]  # (N,3,T)
            return out

        if mode == "hybrid_4ch":
            # Channel 1: full DTOF
            ch_full = dtof[:, None, :]                   # (N,1,T)
            ch_bins = dtof[:, None, :] * masks[None, :, :]  # (N,3,T)
            return np.concatenate([ch_full, ch_bins], axis=1)  # (N,4,T)

        raise ValueError(f"Unknown channel_mode: {mode}")

    def __len__(self) -> int:
        return self.N

    def __getitem__(self, idx: int):
        return self.signals[idx], self.labels[idx]   

In [44]:
# 4. CNN Model (Net) with flexible in_channels and FC head 

class Net(nn.Module):
    """
    1D CNN for DTOF-based regression to (μa, μs′).

    - Supports variable input channels (C) from CONFIG["in_channels"]
    - Uses three conv blocks with BatchNorm, ReLU, MaxPool
    - Automatically computes flatten dimension
    - FC layers controlled by CONFIG["hidden_fc"]
    """

    def __init__(self, cfg: dict, input_length: int):
        super().__init__()

        C = cfg["in_channels"]

        self.conv1 = nn.Sequential(
            nn.Conv1d(C, 32, kernel_size=7, padding=3),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2)
        )

        self.conv2 = nn.Sequential(
            nn.Conv1d(32, 32, kernel_size=5, padding=2),
            nn.BatchNorm1d(32),
            nn.ReLU(),
            nn.MaxPool1d(2)
        )

        self.conv3 = nn.Sequential(
            nn.Conv1d(32, 16, kernel_size=3, padding=1),
            nn.BatchNorm1d(16),
            nn.ReLU(),
            nn.MaxPool1d(2)
        )

        # Determine flatten dim automatically
        with torch.no_grad():
            dummy = torch.zeros(1, C, input_length)  # (1,C,T)
            feat = self._forward_features(dummy)
            flatten_dim = feat.shape[1]  # (1, flatten_dim)

        # Build FC head from cfg["hidden_fc"]
        fc_layers = []
        last = flatten_dim
        for h in cfg["hidden_fc"]:
            fc_layers += [nn.Linear(last, h), nn.ReLU()]
            last = h

        fc_layers.append(nn.Linear(last, cfg["output_dim"]))
        self.fc = nn.Sequential(*fc_layers)

    def _forward_features(self, x: torch.Tensor) -> torch.Tensor:
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = x.flatten(1)  # (batch, -1)
        return x

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self._forward_features(x)
        x = self.fc(x)
        return x

In [45]:

# 5. Training loop with loss tracking, prediction inspection & plotting 

def train_model(
    model: nn.Module,
    train_loader: DataLoader,
    val_loader: DataLoader,
    loss_fn: nn.Module,
    optimizer: torch.optim.Optimizer,
    cfg: dict,
    device: torch.device
): 
    """
    Train the model with tracking of train/val loss curves, saving best model,
    and basic prediction inspection on the first validation batch.
    """
    num_epochs = cfg["epochs"]
    save_dir = cfg["save_dir"]
    run_name = cfg["run_name"]

    os.makedirs(save_dir, exist_ok=True)

    train_losses = []
    val_losses = []
    best_val = float("inf")
    best_path = os.path.join(save_dir, f"{run_name}_best.pth")

    for epoch in range(num_epochs):
        # Training phase 
        model.train()
        running_train = 0.0

        for signals, labels in train_loader:
            signals = signals.to(device)
            labels = labels.to(device).float()

            optimizer.zero_grad()
            preds = model(signals)            # (B, 2)
            loss = loss_fn(preds, labels)
            loss.backward()
            optimizer.step()

            running_train += loss.item()

        epoch_train = running_train / len(train_loader)
        train_losses.append(epoch_train)

        # Validation phase 
        model.eval()
        running_val = 0.0

        with torch.no_grad():
            for batch_idx, (signals, labels) in enumerate(val_loader):
                signals = signals.to(device)
                labels = labels.to(device).float()

                preds = model(signals)
                loss = loss_fn(preds, labels)
                running_val += loss.item()

                # Inspect prediction from first val batch of first epoch 
                print("[VAL SAMPLE] First batch predictions vs labels:")
                print("Preds (μa, μs'):", preds[:5].cpu())
                print("Labels (μa, μs'):", labels[:5].cpu())
                print("Abs error:", (preds[:5] - labels[:5]).abs().cpu())
                print("---------------------------------------------------")

        epoch_val = running_val / len(val_loader)
        val_losses.append(epoch_val)

        print(f"Epoch {epoch+1}/{num_epochs} | "
              f"Train: {epoch_train:.4f} | Val: {epoch_val:.4f}")

        # Save best model
        if epoch_val < best_val:
            best_val = epoch_val
            torch.save(model.state_dict(), best_path)

        # Plot loss curves 
        fig_path = os.path.join(save_dir, f"{run_name}_loss_curves.png")
        plt.figure()
        plt.plot(train_losses, label="Train loss")
        plt.plot(val_losses, label="Val loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title(f"Loss curves: {run_name}")
        plt.legend()
        plt.grid(True)
        plt.tight_layout()
        plt.savefig(fig_path, dpi=150)
        plt.close()

    results = {
        "train_losses": train_losses,
        "val_losses": val_losses,
        "best_val": best_val,
        "best_path": best_path,
        "loss_plot": fig_path
    }
    return results


In [46]:
# 6. Evaluation helper (MAE / RMSE) on a given DataLoader 

class ModelEvaluator:
    """
    Evaluate a trained model on a DataLoader and compute MAE/RMSE for (μa, μs').
    """

    def __init__(self, model: nn.Module, device: torch.device):
        self.model = model
        self.device = device
        self.model.to(device)
        self.model.eval()

    def evaluate(self, data_loader: DataLoader):
        all_preds = []
        all_labels = []

        with torch.no_grad():
            for signals, labels in data_loader:
                signals = signals.to(self.device)
                labels = labels.to(self.device).float()

                preds = self.model(signals)

                all_preds.append(preds.cpu())
                all_labels.append(labels.cpu())

        all_preds = torch.cat(all_preds, dim=0)
        all_labels = torch.cat(all_labels, dim=0)

        abs_err = torch.abs(all_preds - all_labels)
        sq_err = (all_preds - all_labels) ** 2

        mae = abs_err.mean(dim=0)                 # (2,)
        rmse = torch.sqrt(sq_err.mean(dim=0))     # (2,)

        metrics = {
            "MAE": mae.numpy(),        # [MAE_mua, MAE_mus]
            "RMSE": rmse.numpy(),      # [RMSE_mua, RMSE_mus]
            "preds": all_preds.numpy(),
            "labels": all_labels.numpy()
        }
        return metrics


In [47]:
# 7. logging of runs (config + metrics) to JSON 

def log_run(cfg: dict, results: dict, log_path: str):
    """
    Append a run entry to a JSON log file, including config and metrics.
    """
    os.makedirs(os.path.dirname(log_path), exist_ok=True)

    entry = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "run_name": cfg["run_name"],
        "sg_window": cfg["sg_window"],
        "sg_order": cfg["sg_order"],
        "channel_mode": cfg["channel_mode"],
        "crop_t_max": cfg["crop_t_max"],
        "hidden_fc": cfg["hidden_fc"],
        "lr": cfg["lr"],
        "epochs": cfg["epochs"],
        "batch_size": cfg["batch_size"],
        "best_val": results["best_val"],
        "loss_plot": results["loss_plot"],
        "model_path": results["best_path"]
    }

    if os.path.exists(log_path):
        with open(log_path, "r") as f:
            data = json.load(f)
    else:
        data = []

    data.append(entry)
    with open(log_path, "w") as f:
        json.dump(data, f, indent=2)

    print(f"[LOG] Appended run entry to {log_path}")



In [48]:
# 8. Single experiment runner

def run_experiment(cfg: dict):
    """
    Run a single experiment:
        - ensure labels CSV exists (optional)
        - build dataset, split, loaders
        - build model & train
        - log results
    """

    # Load labels
    label_df = pd.read_csv(cfg["label_csv_path"])
    labels_arr = label_df.values.astype(np.float32)  # (N,2)

    # Build dataset
    dataset = DTOFDataset(cfg["csv_path"], labels_arr, cfg)

    # Ensure in_channels matches dataset
    cfg["in_channels"] = dataset.C

    # Train/val split
    n_total = len(dataset)
    n_train = int(cfg["train_frac"] * n_total)
    n_val = n_total - n_train

    generator = torch.Generator().manual_seed(42)
    train_ds, val_ds = random_split(dataset, [n_train, n_val], generator=generator)

    train_loader = DataLoader(train_ds, batch_size=cfg["batch_size"], shuffle=True)
    val_loader = DataLoader(val_ds, batch_size=cfg["batch_size"], shuffle=False)

    # Model, loss, optimizer, device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"[INFO] Using device: {device}")

    model = Net(cfg, input_length=dataset.T).to(device)
    loss_fn = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=cfg["lr"])

    # Train
    results = train_model(model, train_loader, val_loader, loss_fn, optimizer, cfg, device)

    # Evaluate (optional, on validation set)
    evaluator = ModelEvaluator(model, device)
    metrics = evaluator.evaluate(val_loader)
    print(f"[EVAL] MAE (μa, μs'): {metrics['MAE']}")
    print(f"[EVAL] RMSE (μa, μs'): {metrics['RMSE']}")

    # Log
    log_run(cfg, results, cfg["log_path"])

    return results, metrics



In [49]:
# 9. Simple grid search over a few configurations 

def grid_search():
    """
    Example grid search over SG window length, channel mode, and learning rate.
    """
    base_cfg = CONFIG

    windows = [11, 21, 31, 41]
    orders = [1,2,3,4]
    modes = ["single", "early_mid_late", "hybrid_4ch"]
    lrs = [1e-3, 3e-4]

    for w in windows:
        for o in orders:
            for mode in modes:
                for lr in lrs:
                    cfg = dict(base_cfg)  # shallow copy
                    cfg["sg_window"] = w
                    cfg["sg_order"] = o
                    cfg["channel_mode"] = mode
                    cfg["lr"] = lr
                    cfg["run_name"] = f"{mode}_w{w}_o{o}_lr{lr}"

                    print(f"\n=== Running {cfg['run_name']} ===")
                    run_experiment(cfg)

In [50]:
# Entry point, to make the training run only when the files are executed from this pipeline template and not imported
if __name__ == "__main__":
    results, metrics = run_experiment(CONFIG)

[INFO] Using device: cpu
[VAL SAMPLE] First batch predictions vs labels:
Preds (μa, μs'): tensor([[0.0103, 1.5766],
        [0.0104, 1.6607],
        [0.0290, 1.7042],
        [0.0201, 1.4921],
        [0.0633, 1.3061]])
Labels (μa, μs'): tensor([[7.1922e-03, 2.8356e+00],
        [1.6799e-02, 8.8183e+00],
        [1.1679e-02, 1.0500e+01],
        [1.8964e-02, 3.3764e+00],
        [3.9238e-02, 2.0000e+00]])
Abs error: tensor([[3.0732e-03, 1.2589e+00],
        [6.4133e-03, 7.1576e+00],
        [1.7363e-02, 8.7958e+00],
        [1.1176e-03, 1.8842e+00],
        [2.4096e-02, 6.9387e-01]])
---------------------------------------------------
[VAL SAMPLE] First batch predictions vs labels:
Preds (μa, μs'): tensor([[0.0023, 1.6060],
        [0.0193, 1.4890],
        [0.0388, 1.4024],
        [0.0082, 1.6415],
        [0.0490, 1.3753]])
Labels (μa, μs'): tensor([[0.0117, 4.3869],
        [0.0242, 4.0203],
        [0.0392, 3.6843],
        [0.0190, 8.8183],
        [0.0273, 2.3814]])
Abs error: 