In [6]:
import os
import json
import re
from datetime import datetime

import numpy as np
import pandas as pd
from scipy.signal import savgol_filter
from scipy.optimize import curve_fit

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader, random_split

import matplotlib.pyplot as plt

from pathlib import Path
from typing import List, Dict, Any, Iterable



In [None]:
from training_core import DTOFDataset, Net, ModelEvaluator, get_in_channels

In [9]:
# Notebook-wide configuration.
# JSON_PATH: the JSON file holding the logged DTOF runs.
# NOTE(review): this is an absolute, machine-specific path; anyone else running
# the notebook has to edit it by hand.
JSON_PATH = Path(
    "/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/"
    "Year 3/Research Project in Biomedical Engineering/Code/JSON logs/dtof_runs_log.json"
)

# DEVICE: use the GPU when PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    DEVICE = torch.device("cuda")
else:
    DEVICE = torch.device("cpu")
print("Using device:", DEVICE)


Using device: cpu


In [3]:
# Model Evaluation and visualisation of performance (MAE / RMSE, percentage error)
class ModelEvaluator:
    """
    Evaluate a trained model on a DataLoader and compute MAE/RMSE for (μa, μs'),
    plus percentage error plots vs actual values.

    NOTE(review): this notebook also imports a ``ModelEvaluator`` from
    ``training_core``; whichever definition is executed last is the one in
    effect. Confirm which is intended and keep a single definition.
    """

    def __init__(self, model: nn.Module, device: torch.device):
        """Store ``model``/``device``, move the model to the device and put it in eval mode."""
        self.model = model
        self.device = device
        self.model.to(device)
        self.model.eval()

    @staticmethod
    def _save_pct_error_plot(true_vals, abs_pct, out_path, run_name, symbol, p0, fit_formula):
        """
        Scatter absolute % error against the true value, overlay a best-effort
        exponential fit ``a * exp(-b * x) + c``, and save the figure to ``out_path``.

        Args:
            true_vals: 1-D array of true target values (x axis).
            abs_pct: 1-D array of absolute percentage errors (y axis).
            out_path: destination file for the saved figure.
            run_name: run identifier shown in the plot title.
            symbol: display name of the target, used in labels and warnings.
            p0: initial guess ``[a, b, c]`` passed to ``curve_fit``.
            fit_formula: formula text shown in the legend entry of the fit.
        """
        def _exp_model(x, a, b, c):
            return a * np.exp(-b * x) + c

        fig, ax = plt.subplots()
        try:
            ax.scatter(true_vals, abs_pct, s=10, alpha=0.6, label="Absolute % error")

            # The fit is decorative: if it fails, warn and still save the scatter.
            try:
                popt, _ = curve_fit(_exp_model, true_vals, abs_pct, p0=p0, maxfev=5000)
                x_fit = np.linspace(true_vals.min(), true_vals.max(), 300)
                y_fit = _exp_model(x_fit, *popt)
                ax.plot(
                    x_fit, y_fit, "r-", linewidth=2,
                    label=f"Fit: {fit_formula}\n"
                          f"a={popt[0]:.2f}, b={popt[1]:.2f}, c={popt[2]:.2f}"
                )
            except Exception as e:
                print(f"[WARN] {symbol} exponential fit failed:", e)

            ax.set_xlabel(f"True {symbol}")
            ax.set_ylabel("Absolute % error")
            ax.set_title(f"Percentage error vs Actual {symbol}: {run_name}")
            ax.grid(True)
            ax.legend()
            fig.tight_layout()
            fig.savefig(out_path, dpi=150)
        finally:
            # Always release the figure, even when saving fails.
            plt.close(fig)

    def evaluate(self, data_loader: DataLoader, cfg: dict):
        """
        Run the model over ``data_loader`` and compute error metrics and plots.

        The model is forced into eval mode for the duration of the call and
        every sub-module's previous train/eval flag is restored afterwards, so
        the result does not depend on what happened to the model since
        ``__init__`` and a surrounding training loop is left undisturbed.

        Args:
            data_loader: iterable yielding ``(signals, labels)`` batches.
            cfg: must contain ``"save_dir"`` (output directory, created if
                missing) and ``"run_name"`` (used in file names and titles).

        Returns:
            dict with keys ``"MAE"`` and ``"RMSE"`` (per-column arrays),
            ``"preds"``, ``"labels"``, ``"pct_error"`` (signed %),
            ``"abs_pct_error"`` and ``"pct_error_plots"`` (paths of the two
            saved PNG files under ``"mua"`` and ``"mus"``).

        Raises:
            KeyError: if ``cfg`` lacks ``"save_dir"`` or ``"run_name"``
                (raised before any inference is done).

        NOTE(review): a loader that yields no batches reaches ``torch.cat``
        with an empty list, which presumably raises; confirm that is the
        desired behaviour.
        """
        # Read the config first so a bad cfg fails before the inference pass.
        save_dir = cfg["save_dir"]
        run_name = cfg["run_name"]

        pred_batches = []
        label_batches = []

        prior_modes = [(module, module.training) for module in self.model.modules()]
        self.model.eval()
        try:
            with torch.no_grad():
                for signals, labels in data_loader:
                    signals = signals.to(self.device)
                    labels = labels.to(self.device).float()

                    preds = self.model(signals)

                    pred_batches.append(preds.cpu())
                    label_batches.append(labels.cpu())
        finally:
            for module, was_training in prior_modes:
                module.training = was_training

        all_preds = torch.cat(pred_batches, dim=0)    # expected (N, 2): [μa, μs']
        all_labels = torch.cat(label_batches, dim=0)  # expected (N, 2)

        # Basic metrics: MAE and RMSE, reduced over the sample axis
        residual = all_preds - all_labels
        mae = residual.abs().mean(dim=0)
        rmse = torch.sqrt((residual ** 2).mean(dim=0))

        # ---------- Percentage error vs Actual ----------
        preds_np = all_preds.numpy()
        labels_np = all_labels.numpy()

        # Clamp the denominator so labels at (or extremely near) zero cannot
        # cause a division by zero.
        eps = 1e-8
        denom = np.maximum(np.abs(labels_np), eps)

        pct_error = 100.0 * (preds_np - labels_np) / denom  # signed %
        abs_pct_error = np.abs(pct_error)                   # absolute %

        os.makedirs(save_dir, exist_ok=True)
        fig_mua = os.path.join(save_dir, f"{run_name}_pct_error_mua.png")
        fig_mus = os.path.join(save_dir, f"{run_name}_pct_error_mus.png")

        # Column 0 is μa, column 1 is μs'
        self._save_pct_error_plot(
            labels_np[:, 0], abs_pct_error[:, 0], fig_mua, run_name,
            symbol="μa", p0=[100, 1.0, 0.0], fit_formula="a·exp(-b·x) + c",
        )
        self._save_pct_error_plot(
            labels_np[:, 1], abs_pct_error[:, 1], fig_mus, run_name,
            symbol="μs'", p0=[100, 0.5, 0.0], fit_formula="a·e^(-b·x) + c",
        )

        metrics = {
            "MAE": mae.numpy(),          # [MAE_mua, MAE_mus]
            "RMSE": rmse.numpy(),        # [RMSE_mua, RMSE_mus]
            "preds": preds_np,
            "labels": labels_np,
            "pct_error": pct_error,      # signed %
            "abs_pct_error": abs_pct_error,
            "pct_error_plots": {
                "mua": fig_mua,
                "mus": fig_mus,
            }
        }
        return metrics


In [14]:
# Location of the JSON run log read by main() below.
# NOTE(review): this repeats, with the same value, the JSON_PATH assignment made
# in an earlier cell of this notebook; presumably it is re-declared so this cell
# can run on its own. Confirm and keep a single definition. The path is also
# absolute and machine-specific.
JSON_PATH = Path(
    "/Users/lydialichen/Library/CloudStorage/OneDrive-UniversityCollegeLondon/"
    "Year 3/Research Project in Biomedical Engineering/Code/JSON logs/dtof_runs_log.json"
)


def load_run_logs(json_path: Path) -> List[Dict[str, Any]]:
    """
    Load DTOF run logs from a JSON file and normalise them into a list of dicts.

    Accepted top-level JSON layouts:
      - a list: returned as-is;
      - a dict with a list under the key "runs": that list is returned;
      - any other dict: its dict-valued entries are returned in file order
        (non-dict values are ignored).

    Args:
        json_path: location of the log file. A ``Path`` is expected; a plain
            string (or other path-like object) is accepted as well.

    Returns:
        The list of run records.

    Raises:
        ValueError: if the JSON has none of the layouts above (this includes
            a dict without any dict-valued entry).
        OSError / json.JSONDecodeError: propagated unchanged when the file
            cannot be opened or is not valid JSON.
    """
    # Coerce so string paths work too, and read as UTF-8 instead of relying on
    # the platform's default encoding.
    with Path(json_path).open("r", encoding="utf-8") as f:
        data = json.load(f)

    if isinstance(data, list):
        return data

    if isinstance(data, dict):
        if isinstance(data.get("runs"), list):
            return data["runs"]

        runs = [v for v in data.values() if isinstance(v, dict)]
        if runs:
            return runs

    raise ValueError("Unrecognised JSON structure for run logs.")


def print_global_best_metrics(runs: List[Dict[str, Any]]) -> None:
    """
    Find and print:
      - Best config by best_val (MSE)
      - Best config by MAE_μa, MAE_μs'
      - Best config by RMSE_μa, RMSE_μs'

    "Best" means the smallest value; on ties the first run in ``runs`` wins.
    Each run record is read with these keys: "best_val" (scalar), "MAE" and
    "RMSE" (sequences with at least two entries, index 0 = μa and index 1 =
    μs'), plus the descriptive fields "run_name", "channel_mode",
    "sg_window", "sg_order" and "lr".

    Records with a missing, ``None`` or malformed metric are skipped for that
    metric instead of aborting the whole summary, and a missing descriptive
    field is printed as ``None``.

    Args:
        runs: run records, e.g. as returned by ``load_run_logs``.
    """

    def _metric_value(run, key, index):
        """Return the metric as a float, or None when it is absent or unusable."""
        if not isinstance(run, dict):
            return None
        raw = run.get(key)
        if raw is None:
            return None
        try:
            if index is not None:
                # Both components must be present for the entry to count.
                if len(raw) < 2:
                    return None
                raw = raw[index]
            return float(raw)
        except (TypeError, ValueError, KeyError, IndexError):
            return None

    # slot name -> (key in the run record, component index or None for a scalar)
    slots = {
        "mse": ("best_val", None),
        "mae_mua": ("MAE", 0),
        "mae_mus": ("MAE", 1),
        "rmse_mua": ("RMSE", 0),
        "rmse_mus": ("RMSE", 1),
    }
    best = {slot: (float("inf"), None) for slot in slots}

    for run in runs:
        for slot, (key, index) in slots.items():
            value = _metric_value(run, key, index)
            # Strict "<" keeps the earliest run on ties and never selects NaN.
            if value is not None and value < best[slot][0]:
                best[slot] = (value, run)

    print("=" * 80)
    print("GLOBAL BEST CONFIGS BY METRIC")
    print("=" * 80)

    # Best MSE: the only section that also reports the hyper-parameters
    mse_val, mse_run = best["mse"]
    if mse_run is not None:
        print("\nBest by MSE (best_val):")
        print(f"  run_name: {mse_run.get('run_name')}")
        print(f"  channel_mode: {mse_run.get('channel_mode')}")
        print(f"  sg_window: {mse_run.get('sg_window')}, "
              f"sg_order: {mse_run.get('sg_order')}")
        print(f"  lr: {mse_run.get('lr')}")
        print(f"  best_val (MSE): {mse_val:.6f}")
    else:
        print("\nBest by MSE: (no runs found)")

    # MAE / RMSE sections share one layout: (slot, printed label, metric family)
    sections = (
        ("mae_mua", "MAE(μa)", "MAE"),
        ("mae_mus", "MAE(μs')", "MAE"),
        ("rmse_mua", "RMSE(μa)", "RMSE"),
        ("rmse_mus", "RMSE(μs')", "RMSE"),
    )
    for slot, label, family in sections:
        value, run = best[slot]
        if run is not None:
            print(f"\nBest by {label}:")
            print(f"  run_name: {run.get('run_name')}")
            print(f"  channel_mode: {run.get('channel_mode')}")
            print(f"  {label}: {value:.6f}")
        else:
            print(f"\nBest by {label}: (no runs with {family})")


def main():
    """Load the run records from JSON_PATH, report how many there are, and print the per-metric best runs."""
    run_records = load_run_logs(JSON_PATH)
    n_runs = len(run_records)
    print(f"Loaded {n_runs} runs from log.")
    print_global_best_metrics(run_records)


# Entry point: run the summary when this code is executed as the main module
# (the output recorded under this cell shows that it also fires when the cell
# is run in the notebook).
if __name__ == "__main__":
    main()


Loaded 192 runs from log.
GLOBAL BEST CONFIGS BY METRIC

Best by MSE (best_val):
  run_name: early_mid_late_w11_o4_lr0.001
  channel_mode: early_mid_late
  sg_window: 11, sg_order: 4
  lr: 0.001
  best_val (MSE): 0.020709

Best by MAE(μa):
  run_name: early_mid_late_w31_o2_lr0.0003
  channel_mode: early_mid_late
  MAE(μa): 0.011018

Best by MAE(μs'):
  run_name: hybrid_4ch_w11_o1_lr0.001
  channel_mode: hybrid_4ch
  MAE(μs'): 0.150076

Best by RMSE(μa):
  run_name: early_mid_late_w31_o2_lr0.0003
  channel_mode: early_mid_late
  RMSE(μa): 0.013267

Best by RMSE(μs'):
  run_name: hybrid_4ch_w31_o3_lr0.0003
  channel_mode: hybrid_4ch
  RMSE(μs'): 0.228337
