In [3]:
import os 
import re
import random 
import torch 
from torch import nn 
import numpy as np 
import matplotlib.pyplot as plt 
from torch.utils.data import Dataset, DataLoader , random_split
import pandas as pd 
import scipy 
from scipy.signal import savgol_filter 
import h5py
random.seed(0)

In [6]:
# Savitzky-Golay smoothing settings; these are fed into the dataset config below.
frame_length = 21  # filter window length, in samples
order = 1  # polynomial order of the local fit
eps = 1e-8  # small positive constant (presumably a clipping floor -- TODO confirm where it is used)

In [9]:
class DTOFDataset(Dataset):
    """
    DTOF dataset loaded from an HDF5-based MATLAB .mat file (read with h5py).

    Expected variables in .mat:
        X : (Nt, N) or (N, Nt)  reflectance DTOFs
        y : (2, N)  or (N, 2)   [mua, mus]
        t : (Nt,)              time vector (seconds, ~1e-12 resolution)

    Keys read from ``cfg``:
        crop_t_max            : upper edge of the time crop, in ns (required)
        sg_window             : Savitzky-Golay window length, in samples (required)
        sg_order              : Savitzky-Golay polynomial order (required)
        channel_mode          : "single" | "early_mid_late" | "hybrid_4ch" (required)
        eps                   : clipping floor applied before the log (default 1e-12)
        per_trace_standardise : if truthy, each log-trace is shifted/scaled to
                                zero mean and unit std along time (default False)

    Preprocessing (applied in this order):
        - convert t from seconds -> ns
        - crop time axis to [0, crop_t_max] ns
        - Savitzky–Golay smoothing
        - clip small / negative values
        - log-transform reflectance (Option B)
        - optional per-trace standardisation
        - channel construction:
            * "single"
            * "early_mid_late"
            * "hybrid_4ch"

    Returns (per item):
        signal: (C, T) tensor
        label:  (2,) tensor  [mua, mus]  (raw, for now)

    Raises:
        ValueError: if the time axis does not match the DTOF length, if the
            number of labels differs from the number of traces, if cropping
            removes every sample, or if ``channel_mode`` is unknown.
    """

    def __init__(
        self,
        mat_path: str,
        cfg: dict
    ):
        super().__init__()
        self.cfg = cfg

        X, y, t = self._load_mat(mat_path)
        X, y = self._align_shapes(X, y, t)

        # time: seconds -> ns
        t_ns = t * 1e9

        # crop to [0, crop_t_max] ns
        crop_t_max = float(cfg["crop_t_max"])  # ns
        t_mask = (t_ns >= 0.0) & (t_ns <= crop_t_max)

        if not np.any(t_mask):
            raise ValueError(
                f"Cropping removed all samples. "
                f"t_ns range=[{t_ns.min():.3g}, {t_ns.max():.3g}] ns"
            )

        t_ns = t_ns[t_mask]
        dtof = X[:, t_mask]          # (N, T); boolean indexing yields a copy

        # Savitzky-Golay smoothing along the time axis
        dtof = self._smooth(dtof, int(cfg["sg_window"]), int(cfg["sg_order"]))

        # clip + log-transform (the floor keeps log() finite for <= 0 samples)
        eps = float(cfg.get("eps", 1e-12))
        dtof = np.log(np.maximum(dtof, eps))

        # optional per-trace standardisation (previously documented and passed
        # in cfg, but never applied)
        if cfg.get("per_trace_standardise", False):
            dtof = self._standardise_traces(dtof)

        # channel construction
        channels = self.build_channels(t_ns, dtof, cfg["channel_mode"])

        self.signals = torch.tensor(channels, dtype=torch.float32)
        self.labels = torch.tensor(y, dtype=torch.float32)

        self.N, self.C, self.T = self.signals.shape

    @staticmethod
    def _load_mat(mat_path: str):
        """Read X, y and t from the HDF5 .mat file as float32 arrays (t flattened to 1-D)."""
        with h5py.File(mat_path, "r") as f:
            X = np.array(f["X"], dtype=np.float32)
            y = np.array(f["y"], dtype=np.float32)
            # reshape(-1) instead of squeeze() so a one-sample vector stays 1-D
            t = np.array(f["t"], dtype=np.float32).reshape(-1)
        return X, y, t

    @staticmethod
    def _align_shapes(X: np.ndarray, y: np.ndarray, t: np.ndarray):
        """Orient X as (N, Nt) and y as (N, 2), then check they agree with t and each other."""
        if X.ndim != 2:
            raise ValueError(f"X must be 2-D, got shape {X.shape}")

        # X -> (N, Nt)
        if X.shape[0] == t.shape[0]:
            X = X.T

        # y -> (N, 2)
        if y.shape[0] == 2:
            y = y.T

        if X.shape[1] != t.shape[0]:
            raise ValueError("Time axis length does not match DTOF length")
        if y.shape[0] != X.shape[0]:
            raise ValueError(
                f"Number of labels ({y.shape[0]}) does not match "
                f"number of DTOF traces ({X.shape[0]})"
            )
        return X, y

    @staticmethod
    def _smooth(dtof: np.ndarray, sg_window: int, sg_order: int) -> np.ndarray:
        """Apply a Savitzky-Golay filter along time after coercing the window to a usable size."""
        T = dtof.shape[1]

        # window must be odd and larger than the polynomial order ...
        if sg_window % 2 == 0:
            sg_window += 1
        if sg_window <= sg_order:
            sg_window = sg_order + 2
        if sg_window % 2 == 0:
            sg_window += 1
        # ... and cannot exceed the number of time samples
        if sg_window > T:
            sg_window = T if T % 2 == 1 else T - 1

        # Capping to T can push the window back down to <= sg_order, which
        # savgol_filter rejects; in that case the trace is left unsmoothed.
        if sg_window >= 3 and sg_window > sg_order:
            return savgol_filter(dtof, sg_window, sg_order, axis=1)
        return dtof

    @staticmethod
    def _standardise_traces(dtof: np.ndarray) -> np.ndarray:
        """Give every trace (row) zero mean and unit std along the time axis."""
        mean = dtof.mean(axis=1, keepdims=True)
        std = dtof.std(axis=1, keepdims=True)
        # a constant trace has std == 0; divide by 1 so it maps to all zeros
        std = np.where(std > 0, std, 1.0)
        return (dtof - mean) / std

    def build_channels(self, t_ns: np.ndarray, dtof: np.ndarray, mode: str) -> np.ndarray:
        """
        Channel construction using ns time gates:
            early: 0–0.5 ns
            mid:   0.5–4 ns
            late:  4–crop_t_max ns

        Gated channels are the input trace multiplied by a 0/1 mask, so
        samples outside a gate are exactly 0 in that channel.

        Args:
            t_ns: (T,) time axis in ns.
            dtof: (N, T) preprocessed traces.
            mode: "single" -> (N, 1, T); "early_mid_late" -> (N, 3, T);
                  "hybrid_4ch" -> (N, 4, T) with the full trace as channel 0.

        Raises:
            ValueError: if ``mode`` is not one of the three names above.
        """
        crop_t_max = float(self.cfg["crop_t_max"])

        if mode == "single":
            return dtof[:, None, :]  # (N,1,T)

        early = ((t_ns >= 0.0) & (t_ns < 0.5)).astype(np.float32)
        mid   = ((t_ns >= 0.5) & (t_ns < 4.0)).astype(np.float32)
        late  = ((t_ns >= 4.0) & (t_ns <= crop_t_max)).astype(np.float32)

        masks = np.stack([early, mid, late], axis=0)  # (3,T)

        if mode == "early_mid_late":
            return dtof[:, None, :] * masks[None, :, :]  # (N,3,T)

        if mode == "hybrid_4ch":
            full = dtof[:, None, :]
            gated = dtof[:, None, :] * masks[None, :, :]
            return np.concatenate([full, gated], axis=1)  # (N,4,T)

        raise ValueError(f"Unknown channel_mode: {mode}")

    def __len__(self):
        """Number of traces in the dataset."""
        return self.N

    def __getitem__(self, idx):
        """Return the (signal, label) pair at ``idx``: a (C, T) tensor and a (2,) tensor."""
        return self.signals[idx], self.labels[idx]

In [10]:
# Location of the .mat file holding the DTOF traces
matlab_path = "/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/Year 3/Research Project in Biomedical Engineering/Code/Pre-obtained data/dataset_homo_small.mat"

# Preprocessing / channel options handed to DTOFDataset
cfg = dict(
    crop_t_max=6.0,  # ns
    sg_window=frame_length,
    sg_order=order,
    eps=1e-12,
    per_trace_standardise=True,
    channel_mode="hybrid_4ch",
)

# Build the dataset and inspect the first sample as a quick shape check
ds = DTOFDataset(matlab_path, cfg)
x, y = ds[0]
print(x.shape, y.shape)  # (C, T), (2,)

torch.Size([4, 3000]) torch.Size([2])


In [5]:
class Net(nn.Module):
    def __init__(self, in_channels = 3, input_length = 3000, output_dim = 2):
        """
        CNN for 1D DTOF signals.
        Blocks: [Conv1d -> BN -> ReLU -> MaxPool1d] x 3 -> Flatten -> FCs.

        Args:
            in_channels: number of input channels (default 3, e.g. the
                early / mid / late gated channels).
            input_length: number of time samples per channel. Each of the
                three pooling stages halves the length, so at least 8 samples
                are required.
            output_dim: size of the regression output (default 2).

        Raises:
            ValueError: if ``input_length`` is too short to survive the three
                pooling stages.
        """

        super().__init__()

        if input_length < 8:
            raise ValueError(
                f"input_length must be >= 8 for three stride-2 poolings, got {input_length}"
            )

        # Convolution blocks (padding keeps the length; pooling halves it)
        self.conv1 = nn.Conv1d(in_channels=in_channels, out_channels=32, kernel_size=7, padding=3)
        self.bn1 = nn.BatchNorm1d(32)
        self.pool1 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv2 = nn.Conv1d(in_channels=32, out_channels=32, kernel_size=5, padding=2)
        self.bn2 = nn.BatchNorm1d(32)
        self.pool2 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.conv3 = nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm1d(16)
        self.pool3 = nn.MaxPool1d(kernel_size=2, stride=2)

        self.act = nn.ReLU()

        # Compute the flattened feature size dynamically.
        # The probe runs in eval mode: in training mode BatchNorm would fold
        # the all-zero dummy batch into its running mean/var (no_grad does not
        # prevent buffer updates), skewing the statistics before real data.
        was_training = self.training
        self.eval()
        with torch.no_grad():
            dummy = torch.zeros(1, in_channels, input_length)
            feat = self._forward_features(dummy)
            self.flatten_dim = feat.shape[1]
        self.train(was_training)

        # Fully connected layers
        self.fc1 = nn.Linear(self.flatten_dim, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, output_dim)

    def _forward_features(self, x):
        """Convolutional feature extractor followed by flatten; returns (batch, features)."""
        # Block 1
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.act(x)
        x = self.pool1(x)

        # Block 2
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.act(x)
        x = self.pool2(x)

        # Block 3
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.act(x)
        x = self.pool3(x)

        # Flatten to (batch, features)
        return x.flatten(start_dim=1)

    def forward(self, x):
        """
        x: (batch_size, in_channels, input_length)
        returns: (batch_size, output_dim)
        """
        x = self._forward_features(x) # (batch, flatten_dim)
        x = self.act(self.fc1(x))
        x = self.act(self.fc2(x))
        x = self.fc3(x)
        return x


In [6]:
def train_model(model, train_loader, val_loader, loss_fn, optimizer, num_epochs, device, save_path = None): 
    """ 
    Train a model with a per-epoch training + validation loop.

    Every epoch runs one pass over ``train_loader`` (with parameter updates),
    then one pass over ``val_loader`` (no gradients). Whenever the validation
    loss improves on the best seen so far, the model's ``state_dict`` is saved
    to ``save_path`` (if given).

    Inputs: 
        model: instance of Net (or any nn.Module)
        train_loader: iterable of (signals, labels) batches for training
        val_loader: iterable of (signals, labels) batches for validation
        loss_fn: loss function, e.g. nn.MSELoss()
        optimizer: optimiser, e.g. torch.optim.Adam(...)
        num_epochs: number of epochs to train 
        device: torch.device (or device string) to run on
        save_path: optional path to save best model, to use later when we develop more models (str or None)

    Returns:
        dict with keys
            "train_loss":    list of mean training loss per epoch
            "val_loss":      list of mean validation loss per epoch
            "best_val_loss": lowest validation loss seen (inf if num_epochs == 0)

    Raises:
        ValueError: if either loader yields no batches.
    """
    # Move model to device 
    model.to(device)
    best_val_loss = float("inf")
    history = {"train_loss": [], "val_loss": [], "best_val_loss": best_val_loss}

    for epoch in range(num_epochs): 
        print(f"\nEpoch {epoch + 1}/ {num_epochs}")

        # TRAINING PHASE 
        model.train()
        running_loss = 0.0
        n_train_batches = 0
        for signals, labels in train_loader: 
            # Move the batch to device 
            signals = signals.to(device)
            labels = labels.to(device).float()

            # Zero gradients 
            optimizer.zero_grad()

            # Forward pass 
            preds = model(signals)
            preds = preds.view_as(labels) # reshape the predictions to have the same shape as labels 

            # Loss
            loss = loss_fn(preds, labels)

            # Backward + update 
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            n_train_batches += 1

        if n_train_batches == 0:
            raise ValueError("train_loader yielded no batches.")

        train_loss = running_loss / n_train_batches
        print(f"Train Loss: {train_loss:.4f}")

        # VALIDATION PHASE (inside the epoch loop so the best epoch is tracked)
        model.eval()
        val_loss = 0.0
        n_val_batches = 0

        with torch.no_grad():
            for signals, labels in val_loader: 
                signals = signals.to(device)
                labels = labels.to(device).float()

                preds = model(signals)
                preds = preds.view_as(labels)

                loss = loss_fn(preds, labels)
                val_loss += loss.item()
                n_val_batches += 1

        if n_val_batches == 0:
            raise ValueError("val_loader yielded no batches.")

        val_loss /= n_val_batches
        print(f"Validation Loss: {val_loss:.4f}")

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)

        if val_loss < best_val_loss: 
            print(" -> Best validation loss so far, saving the model.")
            best_val_loss = val_loss
            if save_path is not None: 
                torch.save(model.state_dict(), save_path)

    history["best_val_loss"] = best_val_loss
    return history




In [7]:
def train_test_split(data, target, test_size = 0.2, shuffle = True, random_state = None):
    """
    Splits the data and target sequences into training and validation subsets.

    Args:
        data: sequence of samples (must support len() and slicing).
        target: sequence of targets, same length as ``data``.
        test_size: fraction in [0, 1] assigned to the validation subset.
        shuffle: if True, shuffle (data, target) pairs together before splitting.
        random_state: optional seed. When given, a private generator is used
            so the global ``random`` state is left untouched; when None, the
            global ``random`` module state drives the shuffle.

    Returns:
        (X_train, X_val, y_train, y_val). With shuffling on non-empty input
        the subsets are tuples; otherwise they are slices of the inputs.

    Raises:
        ValueError: if the lengths differ or ``test_size`` is outside [0, 1].
    """

    n_samples = len(data)
    if n_samples != len(target): 
        raise ValueError("Data and target must have the same length.")
    if not 0.0 <= test_size <= 1.0:
        raise ValueError(f"test_size must be in [0, 1], got {test_size}")

    # zip(*pairs) cannot be unpacked for empty input, so skip shuffling then
    if shuffle and n_samples > 0: 
        # random.Random(seed) produces the same permutation as seeding the
        # global generator, without clobbering state other code relies on
        rng = random.Random(random_state) if random_state is not None else random
        pairs = list(zip(data, target))
        rng.shuffle(pairs)
        data, target = zip(*pairs)

    split_idx = int(n_samples * (1 - test_size))

    return (
        data[:split_idx], # X_train
        data[split_idx:], # X_val 
        target[:split_idx], # y_train
        target[split_idx:] # y_val
    )

In [8]:
def extract_labels_from_dtof_csv(csv_path, label_csv_path):
    """
    Extract (mua, mus) labels from DTOF CSV column headers and save to a new CSV file.

    Only the header row of the DTOF file is read; the signal values are not
    needed to build the labels.

    Input: 
        csv_path : str
            Path to the large DTOF CSV file (first column = time_ns, others = DTOFs).
        label_csv_path : str
            Where to save the generated labels CSV (columns "mua", "mus").

    Raises:
        ValueError: if a DTOF column header does not contain "mua: <num> mus: <num>".
    """

    # nrows=0 parses just the header instead of the whole (large) file
    df = pd.read_csv(csv_path, nrows=0)

    # DTOF columns start from index 1
    dtof_columns = df.columns[1:]

    # One number: optional sign, digits with optional decimal point, optional
    # exponent (so "0.015", "5.", ".5" and "1e-3" all parse). Unlike a bare
    # [0-9.]+ it stops before a ".1"-style suffix pandas appends to duplicated
    # column names.
    number = r"[-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?"
    # Matches text like "mua: 0.015 mus: 0.75"
    header_pattern = re.compile(rf"mua:\s*({number})\s+mus:\s*({number})")

    labels = []

    for col in dtof_columns:
        col_clean = col.strip()

        match = header_pattern.search(col_clean)
        if not match:
            raise ValueError(f"Could not parse mua/mus from column '{col}'")

        mua_val = float(match.group(1))
        mus_val = float(match.group(2))
        labels.append((mua_val, mus_val))

    # reshape keeps the (n, 2) layout even when there are no DTOF columns
    labels = np.asarray(labels, dtype=np.float32).reshape(-1, 2)

    # Save to CSV
    label_df = pd.DataFrame(labels, columns=["mua", "mus"])
    label_df.to_csv(label_csv_path, index=False)

    print(f"Labels extracted and saved to: {label_csv_path}")
    print(f"Total signals: {len(labels)}")
    print(f"Example:\n{label_df.head()}")


In [9]:
class ModelEvaluator: 
    """
    Evaluates a trained model on a dataset and computes inversion accuracy metrics.  
    Assumes that targets are 2D: (mua, mus)
    """

    def __init__(self, model, device):
        """Store the model and device, move the model there and switch it to eval mode."""
        self.model = model 
        self.device = device
        self.model.to(device)
        self.model.eval() # evaluation mode
    
    def evaluate(self, data_loader): 
        """
        Run the model over ``data_loader`` and compute per-target error metrics.

        Args:
            data_loader: iterable of (signals, labels) batches.

        Returns:
            dict with numpy arrays:
                "MAE":    mean absolute error per target column
                "RMSE":   root-mean-square error per target column
                "preds":  all predictions, reshaped like the labels
                "labels": all ground-truth labels
                "lables": same array as "labels" (misspelled key kept for
                          backward compatibility)

        Raises:
            ValueError: if ``data_loader`` yields no batches.
        """
        # The model may have been put back in train mode since __init__
        # (e.g. by further training), so re-assert eval mode here.
        self.model.eval()

        all_preds = []
        all_labels = []

        with torch.no_grad(): 
            for signals, labels in data_loader: 
                signals = signals.to(self.device)
                labels = labels.to(self.device).float()
                
                preds = self.model(signals)
                preds = preds.view_as(labels)

                all_preds.append(preds.cpu())
                all_labels.append(labels.cpu())

        if not all_preds:
            raise ValueError("data_loader yielded no batches.")

        all_preds = torch.cat(all_preds, dim = 0)
        all_labels = torch.cat(all_labels, dim = 0)

        # Compute errors 
        residual = all_preds - all_labels
        abs_err = torch.abs(residual) # (N, 2)
        sq_err = residual ** 2

        mae = abs_err.mean(dim = 0)
        rmse = torch.sqrt(sq_err.mean(dim = 0))

        labels_np = all_labels.numpy()
        metrics = {
            "MAE": mae.numpy(), 
            "RMSE": rmse.numpy(),  
            "preds": all_preds.numpy(), 
            "labels": labels_np,
            "lables": labels_np,  # legacy misspelling, kept so existing lookups still work
        }

        return metrics

INITIAL DATA: Converting .mat file containing DTOFs into 2 .csv files: labels + DTOF data

Data description: the labels are saved as DTOFs_Homo_labels.csv and the raw DTOF data as DTOFs_Homo_raw.csv.

In [10]:
# Extract the labels from csv_path to label_csv_path
# (parses the "mua: ... mus: ..." column headers of the raw DTOF CSV and writes them to a labels CSV)

csv_path = r"/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/Year 3/Research Project in Biomedical Engineering/Code/Pre-obtained data/DTOFs_Homo_raw.csv"
label_csv_path = r"/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/Year 3/Research Project in Biomedical Engineering/Code/Pre-obtained data/DTOFs_Homo_labels.csv"
extract_labels_from_dtof_csv(csv_path, label_csv_path)

# Savitzky-Golay window length and polynomial order (same values as the parameter cell near the top)
frame_length =21 
order = 1

# Load labels as a numpy array 
label_df = pd.read_csv(label_csv_path) # columns: ["mua", "mus"]
labels_arr = label_df.values.astype(np.float32)

# Create the DTOF dataset with labels
# NOTE(review): the DTOFDataset class defined earlier in this notebook has the
# signature __init__(self, mat_path, cfg). The keyword arguments used here
# (csv_path, labels, window_length, polyorder, eps) do not match it, so on a
# fresh Restart & Run All this call raises TypeError. The recorded output was
# presumably produced by an earlier CSV-based version of the class -- confirm
# and either restore that class or switch this cell to the .mat-based API.

dataset = DTOFDataset(
    csv_path= csv_path,
    labels = labels_arr, 
    window_length= frame_length, 
    polyorder= order, 
    eps = 1e-8, 
)

# Pull one shuffled batch as a shape check; `signals` / `labels` are reused by a later cell
loader = DataLoader(dataset, batch_size = 32, shuffle = True)
signals, labels = next(iter(loader))
print("signals:", signals.shape)
print("labels:", labels.shape)

Labels extracted and saved to: /Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/Year 3/Research Project in Biomedical Engineering/Code/Pre-obtained data/DTOFs_Homo_labels.csv
Total signals: 400
Example:
        mua  mus
0  0.005000  2.0
1  0.005644  2.0
2  0.006371  2.0
3  0.007192  2.0
4  0.008119  2.0
signals: torch.Size([32, 3, 3000])
labels: torch.Size([32, 2])


In [None]:
# NOTE(review): this unexecuted cell repeats the code of the preceding cell
# verbatim; it looks like a leftover duplicate -- confirm and delete one copy.

# Extract the labels from csv_path to label_csv_path

csv_path = r"/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/Year 3/Research Project in Biomedical Engineering/Code/Pre-obtained data/DTOFs_Homo_raw.csv"
label_csv_path = r"/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/Year 3/Research Project in Biomedical Engineering/Code/Pre-obtained data/DTOFs_Homo_labels.csv"
extract_labels_from_dtof_csv(csv_path, label_csv_path)

# Savitzky-Golay window length and polynomial order
frame_length =21 
order = 1

# Load labels as a numpy array 
label_df = pd.read_csv(label_csv_path) # columns: ["mua", "mus"]
labels_arr = label_df.values.astype(np.float32)

# Create the DTOF dataset with labels
# NOTE(review): these keyword arguments do not match the DTOFDataset defined
# earlier in this notebook, whose signature is __init__(self, mat_path, cfg);
# with that definition this call raises TypeError -- confirm which dataset
# class this cell is meant to use.

dataset = DTOFDataset(
    csv_path= csv_path,
    labels = labels_arr, 
    window_length= frame_length, 
    polyorder= order, 
    eps = 1e-8, 
)

# Pull one shuffled batch as a shape check
loader = DataLoader(dataset, batch_size = 32, shuffle = True)
signals, labels = next(iter(loader))
print("signals:", signals.shape)
print("labels:", labels.shape)

FINAL DATA: Converting .mat file containing DTOFs into 2 .csv files: labels + DTOF data

In [11]:
# ---- Train / validation split ----
train_frac = 0.8
batch_size = 32

n_total = len(dataset)
n_train = int(train_frac * n_total)
n_val = n_total - n_train  # remainder goes to validation, so the two sizes always sum to n_total

# Seeded generator: the same samples land in each subset on every run
generator = torch.Generator().manual_seed(42)
train_dataset, val_dataset = random_split(dataset, [n_train, n_val], generator=generator)

print("Total samples:", n_total)
print("Train samples:", len(train_dataset))
print("Val samples:  ", len(val_dataset))

# ---- Data loaders ----
# Training batches are reshuffled each epoch; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

Total samples: 400
Train samples: 320
Val samples:   80


In [None]:
# 1. Device 
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# 2. Model: instantiating the CNN and moving the model to the device before training.
# The input size is read from an actual sample, so the network matches whatever
# channel layout the dataset produces instead of assuming 3 channels.
n_channels, n_timepoints = dataset[0][0].shape  # each sample is (C, T)
model = Net(
    in_channels=n_channels, 
    input_length=n_timepoints,
    output_dim=2 # predicting both (mua, mus)
).to(device)

# 3. Loss + optimizer 
loss_fn = torch.nn.MSELoss() # MSE error 
# Learning rate of 0.001; decrease to 1e-4 if the training is unstable and
# increase to 3e-3 if training is too slow.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

# 4. Train 
num_epochs = 20 

train_model(
    model=model, 
    train_loader=train_loader, 
    val_loader=val_loader, 
    loss_fn=loss_fn, 
    optimizer=optimizer, 
    num_epochs=num_epochs, 
    device=device, 
    save_path="best_dtof_cnn.pth"
)

# Forward pass with real DTOF batch, to verify the model pipeline from input -> output runs without shape errors.
# The batch must be moved to the model's device: `signals` was produced on the
# CPU, and feeding it to a CUDA model raises a device-mismatch error.
model.eval()
with torch.no_grad():
    outputs = model(signals.to(device))

# Instantiate evaluator 
evaluator = ModelEvaluator(model, device)

# Run evaluation on validation loader 
metrics = evaluator.evaluate(val_loader)

print("MAE:", metrics["MAE"])
print("RMSE:", metrics["RMSE"])