#### Alternate Backbone Check VNet

In [2]:
import torch
print("Torch:", torch.__version__)
print("CUDA:", torch.version.cuda)
print("GPU available:", torch.cuda.is_available())


Torch: 1.12.1+cu113
CUDA: 11.3
GPU available: True


#### VNet Model Class

In [17]:
import torch
import torch.nn as nn

class VNetBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv3d(in_channels, out_channels, 3, padding=1)
        self.in1 = nn.InstanceNorm3d(out_channels)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv3d(out_channels, out_channels, 3, padding=1)
        self.in2 = nn.InstanceNorm3d(out_channels)

    def forward(self, x):
        x = self.relu(self.in1(self.conv1(x)))
        x = self.relu(self.in2(self.conv2(x)))
        return x

# ...existing code...
# UPDATE: Remove sigmoid from VNet forward (return raw logits)
class VNet(nn.Module):
    def __init__(self, in_channels=1, out_channels=1, base_channels=16):
        super().__init__()
        c1, c2, c3, c4 = base_channels, base_channels*2, base_channels*4, base_channels*8
        self.enc1 = VNetBlock(in_channels, c1)
        self.pool1 = nn.MaxPool3d(2)
        self.enc2 = VNetBlock(c1, c2)
        self.pool2 = nn.MaxPool3d(2)
        self.enc3 = VNetBlock(c2, c3)
        self.pool3 = nn.MaxPool3d(2)
        self.bottleneck = VNetBlock(c3, c4)
        self.up3 = nn.ConvTranspose3d(c4, c3, 2, stride=2)
        self.dec3 = VNetBlock(c3 + c3, c3)
        self.up2 = nn.ConvTranspose3d(c3, c2, 2, stride=2)
        self.dec2 = VNetBlock(c2 + c2, c2)
        self.up1 = nn.ConvTranspose3d(c2, c1, 2, stride=2)
        self.dec1 = VNetBlock(c1 + c1, c1)
        self.final = nn.Conv3d(c1, out_channels, 1)

    def forward(self, x):
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        b  = self.bottleneck(self.pool3(e3))
        d3 = self.up3(b)
        d3 = self.dec3(torch.cat([d3, e3], dim=1))
        d2 = self.up2(d3)
        d2 = self.dec2(torch.cat([d2, e2], dim=1))
        d1 = self.up1(d2)
        d1 = self.dec1(torch.cat([d1, e1], dim=1))
        return self.final(d1)  
    

%pip install torchinfo
import torch
from torchinfo import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
sampleModel = VNet(in_channels=1, out_channels=1).to(device)

# N, C, D, H, W = 1, 1, 64, 64, 96
summary(sampleModel, input_size=(1, 1, 64, 64, 96), device=str(device))





[notice] A new release of pip is available: 25.1.1 -> 25.2
[notice] To update, run: python.exe -m pip install --upgrade pip


Layer (type:depth-idx)                   Output Shape              Param #
VNet                                     [1, 1, 64, 64, 96]        --
├─VNetBlock: 1-1                         [1, 16, 64, 64, 96]       --
│    └─Conv3d: 2-1                       [1, 16, 64, 64, 96]       448
│    └─InstanceNorm3d: 2-2               [1, 16, 64, 64, 96]       --
│    └─ReLU: 2-3                         [1, 16, 64, 64, 96]       --
│    └─Conv3d: 2-4                       [1, 16, 64, 64, 96]       6,928
│    └─InstanceNorm3d: 2-5               [1, 16, 64, 64, 96]       --
│    └─ReLU: 2-6                         [1, 16, 64, 64, 96]       --
├─MaxPool3d: 1-2                         [1, 16, 32, 32, 48]       --
├─VNetBlock: 1-3                         [1, 32, 32, 32, 48]       --
│    └─Conv3d: 2-7                       [1, 32, 32, 32, 48]       13,856
│    └─InstanceNorm3d: 2-8               [1, 32, 32, 32, 48]       --
│    └─ReLU: 2-9                         [1, 32, 32, 32, 48]       --
│    └─

##### Resuable functions 

In [5]:
import torch
import numpy as np
from scipy.spatial.distance import directed_hausdorff
from scipy.ndimage import morphology

def under_segmentation_rate(y_true, y_pred, smooth=1):
    y_pred_labels = (y_pred >= 0.5).float()
    y_true_labels = (y_true >= 0.5).float()
    false_neg = torch.sum(y_true_labels * (1 - y_pred_labels))
    ground_truth_area = torch.sum(y_true_labels)
    prediction_area = torch.sum(y_pred_labels)
    return (false_neg + smooth) / (ground_truth_area + prediction_area + smooth)

def over_segmentation_ratio(y_true, y_pred, smooth=1e-5):
    y_pred_labels = (y_pred >= 0.5).float()
    y_true_labels = (y_true >= 0.5).float()
    false_pos = torch.sum((y_pred_labels == 1) & (y_true_labels == 0))
    ground_truth_area = torch.sum(y_true_labels)
    prediction_area = torch.sum(y_pred_labels)
    return (false_pos + smooth) / (ground_truth_area + prediction_area + smooth)

def average_surface_distance(y_true, y_pred, spacing=None):
    y_true_np = y_true.detach().cpu().numpy().astype(np.bool_)
    y_pred_np = y_pred.detach().cpu().numpy().astype(np.bool_)
    conn = morphology.generate_binary_structure(y_true_np.ndim, 1)
    true_surface = np.logical_xor(y_true_np, morphology.binary_erosion(y_true_np, conn))
    pred_surface = np.logical_xor(y_pred_np, morphology.binary_erosion(y_pred_np, conn))
    if np.sum(true_surface) == 0 or np.sum(pred_surface) == 0:
        return float('inf')
    true_distances = morphology.distance_transform_edt(~true_surface, sampling=spacing)
    pred_distances = morphology.distance_transform_edt(~pred_surface, sampling=spacing)
    dist_pred_to_true = np.mean(true_distances[pred_surface != 0])
    dist_true_to_pred = np.mean(pred_distances[true_surface != 0])
    return (dist_pred_to_true + dist_true_to_pred) / 2.0

def hausdorff_distance(y_true, y_pred):
    y_true_labels = (y_true >= 0.5).cpu().numpy().astype(np.bool_)
    y_pred_labels = (y_pred >= 0.5).cpu().numpy().astype(np.bool_)
    distances = []
    for slice_idx in range(y_true_labels.shape[0]):
        ts = y_true_labels[slice_idx, :, :]
        ps = y_pred_labels[slice_idx, :, :]
        if np.sum(ts) == 0 or np.sum(ps) == 0:
            continue
        true_pts = np.argwhere(ts)
        pred_pts = np.argwhere(ps)
        hd_f = directed_hausdorff(true_pts, pred_pts)[0]
        hd_b = directed_hausdorff(pred_pts, true_pts)[0]
        distances.append(max(hd_f, hd_b))
    return np.mean(distances) if distances else float('inf')

def dice_coefficient(y_true, y_pred, smooth=1):
    y_pred_labels = (y_pred >= 0.5).float()
    inter = torch.sum(y_true * y_pred_labels)
    return (2 * inter + smooth) / (torch.sum(y_true) + torch.sum(y_pred_labels) + smooth)

def jaccard_index(y_true, y_pred, smooth=1):
    y_pred_labels = (y_pred >= 0.5).float()
    inter = torch.sum(y_true * y_pred_labels)
    union = torch.sum(y_true) + torch.sum(y_pred_labels) - inter
    return (inter + smooth) / (union + smooth)

def mean_iou(y_true, y_pred, smooth=1):
    return jaccard_index(y_true, y_pred, smooth)

def f1_score(y_true, y_pred, smooth=1):
    y_pred_labels = (y_pred >= 0.5).float()
    tp = torch.sum(y_true * y_pred_labels)
    precision = tp / (torch.sum(y_pred_labels) + smooth)
    recall = tp / (torch.sum(y_true) + smooth)
    return (2 * precision * recall + smooth) / (precision + recall + smooth)

def recall_score(y_true, y_pred, smooth=1):
    y_pred_labels = (y_pred >= 0.5).float()
    tp = torch.sum(y_true * y_pred_labels)
    return (tp + smooth) / (torch.sum(y_true) + smooth)

def calculate_accuracy(y_true, y_pred):
    y_pred_labels = (y_pred >= 0.5).float()
    return (y_pred_labels == y_true).float().sum() / y_true.numel()

# Order chosen to ensure desired printed metrics appear first
METRIC_FUNCS = {
    "dice": dice_coefficient,
    "f1": f1_score,
    "recall": recall_score,
    "jaccard": jaccard_index,
    "hausdorff": hausdorff_distance,
    "usr": under_segmentation_rate,
    "osr": over_segmentation_ratio,
    "asd": average_surface_distance,
    "accuracy": calculate_accuracy,
    "mean_iou": mean_iou
}

class MetricAggregator:
    def __init__(self, metric_funcs=None):
        self.metric_funcs = metric_funcs or METRIC_FUNCS
        self.reset()
    def reset(self):
        self.storage = {k: [] for k in self.metric_funcs.keys()}
    def update(self, y_true, y_pred):
        if y_true.dim() == 5:
            for b in range(y_true.size(0)):
                self._update_single(y_true[b, 0], y_pred[b, 0])
        elif y_true.dim() == 4:
            for b in range(y_true.size(0)):
                self._update_single(y_true[b], y_pred[b])
        else:
            self._update_single(y_true, y_pred)
    def _update_single(self, yt, yp):
        for name, fn in self.metric_funcs.items():
            try:
                val = fn(yt, yp)
                if isinstance(val, torch.Tensor):
                    val = val.item()
            except Exception:
                val = float('nan')
            self.storage[name].append(val)
    def compute(self):
        out = {}
        for k, arr in self.storage.items():
            arr = [a for a in arr if a == a]
            out[k] = float(np.mean(arr)) if len(arr) else 0.0
        return out


##### data preparation, splitting, saving , checkpoints

In [8]:
import os
import torch
from torch.utils.data import Dataset, DataLoader, random_split
import nibabel as nib
import numpy as np
from typing import Tuple, Dict

DATA_CONFIG = {
    "HARP_NP": {
        "img_dir": r"D:\Faizaan\AlzhimersData\ADNI\HippoCampus_labels_data\Task520_HarP\imageTr",
        "lbl_dir": r"D:\Faizaan\AlzhimersData\ADNI\HippoCampus_labels_data\Task520_HarP\labelTr"
    },
    "HARP_FP": {
        "img_dir": r"D:\Faizaan\AlzhimersData\ADNI\HippoCampus_labels_data\Task520_Harp_Preprocessed",
        "lbl_dir": r"D:\Faizaan\AlzhimersData\ADNI\HippoCampus_labels_data\Task520_HarP\labelTr"
    }
}

class HarPDataset(Dataset):
    def __init__(self, img_dir, label_dir, normalize=True):
        self.img_paths = sorted([
            os.path.join(img_dir, f) for f in os.listdir(img_dir)
            if f.lower().endswith((".nii", ".nii.gz"))
        ])
        self.lbl_paths = sorted([
            os.path.join(label_dir, f) for f in os.listdir(label_dir)
            if f.lower().endswith((".nii", ".nii.gz"))
        ])
        assert len(self.img_paths) == len(self.lbl_paths), "Image/label count mismatch"
        self.normalize = normalize

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        img = nib.load(self.img_paths[idx]).get_fdata().astype(np.float32)
        lbl = nib.load(self.lbl_paths[idx]).get_fdata().astype(np.float32)

        if self.normalize:
            m, M = img.min(), img.max()
            if M > m:
                img = (img - m) / (M - m)

        # Binarize label (in case of multi-values)
        lbl = (lbl > 0).astype(np.float32)

        img = np.expand_dims(img, 0)
        lbl = np.expand_dims(lbl, 0)

        return torch.from_numpy(img), torch.from_numpy(lbl)

def get_dataloaders(dataset_key: str,
                    batch_size: int = 2,
                    val_ratio: float = 0.1,
                    test_ratio: float = 0.1,
                    seed: int = 42) -> Tuple[DataLoader, DataLoader, DataLoader]:
    cfg = DATA_CONFIG[dataset_key]
    ds = HarPDataset(cfg["img_dir"], cfg["lbl_dir"])
    n = len(ds)
    val_size = int(n * val_ratio)
    test_size = int(n * test_ratio)
    train_size = n - val_size - test_size
    torch.manual_seed(seed)
    train_ds, val_ds, test_ds = random_split(ds, [train_size, val_size, test_size])
    return (
        DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=0),
        DataLoader(val_ds, batch_size=1, shuffle=False, num_workers=0),
        DataLoader(test_ds, batch_size=1, shuffle=False, num_workers=0),
    )

In [9]:
import os, json, time
from pathlib import Path
import torch
import torch.nn as nn
from torch.cuda.amp import autocast, GradScaler



def set_seed(seed=42):
    import random, numpy as np
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
set_seed(42)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

DATASET_KEY = "HARP_NP"  # change to "HARP_FP" later for preprocessed
MODEL_TAG = "NP" if DATASET_KEY.endswith("NP") else "FP"
MODEL_NAME_BASE = f"VNet_{MODEL_TAG}"
OUTPUT_ROOT = Path("VNETModels") / DATASET_KEY
OUTPUT_ROOT.mkdir(parents=True, exist_ok=True)
(OUTPUT_ROOT / "checkpoints").mkdir(exist_ok=True)
CONFIG = {
    "dataset": DATASET_KEY,
    "model_tag": MODEL_TAG,
    "model_name": MODEL_NAME_BASE,
    "batch_size": 2,
    "epochs": 250,
    "lr": 1e-3,
    "weight_decay": 1e-4,
    "loss": "BCE",
    "threshold": 0.5,
    "early_stopping_patience":7,
    "mixed_precision": True
}
with open(OUTPUT_ROOT / "config.json", "w") as f:
    json.dump(CONFIG, f, indent=2)

Device: cuda


In [10]:
def get_loss(name="BCE"):
    if name.upper() == "BCE":
        return nn.BCEWithLogitsLoss()
    elif name.upper() == "DICEBCE":
        class DiceBCELoss(nn.Module):
            def __init__(self):
                super().__init__()
                self.bce = nn.BCEWithLogitsLoss()
            def forward(self, pred_logits, target, smooth=1.):
                bce = self.bce(pred_logits, target)
                probs = torch.sigmoid(pred_logits)
                p = probs.view(-1)
                t = target.view(-1)
                inter = (p * t).sum()
                dice = (2*inter + smooth)/(p.sum() + t.sum() + smooth)
                return bce + (1 - dice)
        return DiceBCELoss()
    else:
        raise ValueError("Unsupported loss")

        
def save_checkpoint(state, is_best=False):
    if not is_best:
        return  # only act when a new best is found
    tag = CONFIG["model_tag"]
    best_path = OUTPUT_ROOT / f"best_model_{tag}.pth"
    torch.save(state, best_path)
    (OUTPUT_ROOT / f"trained_{tag}.flag").write_text(str(state["epoch"]))

def save_history(history_rows, header, path):
    import csv
    write_header = not path.exists()
    with open(path, "a", newline="") as f:
        w = csv.writer(f)
        if write_header:
            w.writerow(header)
        w.writerow(history_rows)

def binarize(pred, thr):
    return (pred >= thr).float()

##### training, testing (dataset key was HARP_NP)

In [11]:
train_loader, val_loader, test_loader = get_dataloaders(
    CONFIG["dataset"],
    batch_size=CONFIG["batch_size"]
)

model = VNet(in_channels=1, out_channels=1).to(device)
criterion = get_loss(CONFIG["loss"])
optimizer = torch.optim.Adam(model.parameters(), lr=CONFIG["lr"], weight_decay=CONFIG["weight_decay"])
scaler = GradScaler(enabled=CONFIG["mixed_precision"])

print("Train batches:", len(train_loader), "Val:", len(val_loader), "Test:", len(test_loader))

Train batches: 55 Val: 13 Test: 13


In [12]:
#NP training
history_csv = OUTPUT_ROOT / "history.csv"
best_model_path = OUTPUT_ROOT / f"best_model_{CONFIG['model_tag']}.pth"

if best_model_path.exists():
    print(f"Detected existing model {best_model_path.name}. Skip training (delete it to retrain).")
else:
    best_dice = -1
    epochs_no_improve = 0
    metric_names = list(METRIC_FUNCS.keys())
    display_metrics = ["dice","f1","recall","jaccard","hausdorff","usr","osr"]  # keeping original order

    for epoch in range(1, CONFIG["epochs"] + 1):
        model.train()
        train_loss = 0.0
        train_acc_sum = 0.0  # added
        start_time = time.time()
        for imgs, lbls in train_loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            optimizer.zero_grad(set_to_none=True)
            with autocast(enabled=CONFIG["mixed_precision"]):
                logits = model(imgs)
                loss = criterion(logits, lbls)
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
            # accumulate loss
            train_loss += loss.item() * imgs.size(0)
            # accumulate training accuracy (thresholded at 0.5 inside calculate_accuracy)
            with torch.no_grad():
                probs_tr = torch.sigmoid(logits)
                batch_acc = calculate_accuracy(lbls, probs_tr)
            train_acc_sum += batch_acc.item() * imgs.size(0)
        train_loss /= len(train_loader.dataset)
        train_acc = train_acc_sum / len(train_loader.dataset)  # added

        model.eval()
        val_loss = 0.0
        agg = MetricAggregator()
        with torch.no_grad():
            for imgs, lbls in val_loader:
                imgs, lbls = imgs.to(device), lbls.to(device)
                logits = model(imgs)
                loss = criterion(logits, lbls)
                val_loss += loss.item() * imgs.size(0)
                probs = torch.sigmoid(logits)
                agg.update(lbls, probs)
        val_loss /= len(val_loader.dataset)
        val_metrics = agg.compute()
        epoch_dice = val_metrics["dice"]
        val_acc = val_metrics["accuracy"]  # added

        elapsed = time.time() - start_time
        is_best = epoch_dice > best_dice
        if is_best:
            best_dice = epoch_dice
            epochs_no_improve = 0
        else:
            epochs_no_improve += 1

        save_checkpoint({
            "epoch": epoch,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "val_dice": epoch_dice,
            "config": CONFIG
        }, is_best=is_best)

        row = [epoch, train_loss, val_loss, *[val_metrics[m] for m in metric_names], best_dice, elapsed, train_acc, val_acc]
        header = ["epoch","train_loss","val_loss", *metric_names,"best_dice","seconds","train_accuracy","val_accuracy"]
        save_history(row, header, history_csv)

        disp = " | ".join([f"{m} {val_metrics[m]:.4f}" for m in display_metrics])
        print(f"Epoch {epoch:03d} | TrainLoss {train_loss:.4f} | ValLoss {val_loss:.4f} | "
              f"TrainAcc {train_acc:.4f} | ValAcc {val_acc:.4f} | {disp} | "
              f"BestDice {best_dice:.4f} | Pat {epochs_no_improve} | {elapsed:.1f}s")

        if epochs_no_improve >= CONFIG["early_stopping_patience"]:
            print("Early stopping triggered.")
            break

  conn = morphology.generate_binary_structure(y_true_np.ndim, 1)
  true_surface = np.logical_xor(y_true_np, morphology.binary_erosion(y_true_np, conn))
  pred_surface = np.logical_xor(y_pred_np, morphology.binary_erosion(y_pred_np, conn))
  true_distances = morphology.distance_transform_edt(~true_surface, sampling=spacing)
  pred_distances = morphology.distance_transform_edt(~pred_surface, sampling=spacing)


Epoch 001 | TrainLoss 0.5740 | ValLoss 0.4799 | TrainAcc 0.8290 | ValAcc 0.9731 | dice 0.4590 | f1 0.7064 | recall 0.8758 | jaccard 0.2989 | hausdorff 40.3212 | usr 0.0350 | osr 0.5062 | BestDice 0.4590 | Pat 0 | 21.6s
Epoch 002 | TrainLoss 0.4247 | ValLoss 0.3782 | TrainAcc 0.9853 | ValAcc 0.9923 | dice 0.7408 | f1 0.8451 | recall 0.8620 | jaccard 0.5916 | hausdorff 14.0804 | usr 0.0595 | osr 0.1998 | BestDice 0.7408 | Pat 0 | 18.2s
Epoch 003 | TrainLoss 0.3368 | ValLoss 0.2942 | TrainAcc 0.9926 | ValAcc 0.9908 | dice 0.7245 | f1 0.8339 | recall 0.9383 | jaccard 0.5707 | hausdorff 7.3961 | usr 0.0241 | osr 0.2515 | BestDice 0.7408 | Pat 1 | 18.5s
Epoch 004 | TrainLoss 0.2565 | ValLoss 0.2214 | TrainAcc 0.9933 | ValAcc 0.9953 | dice 0.8196 | f1 0.8880 | recall 0.8128 | jaccard 0.6952 | hausdorff 7.0300 | usr 0.0949 | osr 0.0856 | BestDice 0.8196 | Pat 0 | 19.0s
Epoch 005 | TrainLoss 0.1982 | ValLoss 0.1751 | TrainAcc 0.9944 | ValAcc 0.9953 | dice 0.8292 | f1 0.8934 | recall 0.8898 | ja

In [13]:
# --- Evaluate best saved model on validation and test sets (prints metrics table) ---
from collections import OrderedDict

def evaluate_loader(model, loader, criterion, metric_funcs):
    model.eval()
    agg = MetricAggregator(metric_funcs)
    total_loss = 0.0
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            logits = model(imgs)
            loss = criterion(logits, lbls)
            total_loss += loss.item() * imgs.size(0)
            probs = torch.sigmoid(logits)
            agg.update(lbls, probs)
    avg_loss = total_loss / len(loader.dataset)
    metrics = agg.compute()
    metrics = OrderedDict(metrics)  # preserve order
    metrics["loss"] = avg_loss
    return metrics

best_model_path = OUTPUT_ROOT / f"best_model_{CONFIG['model_tag']}.pth"
assert best_model_path.exists(), f"Best model not found: {best_model_path}"

ckpt = torch.load(best_model_path, map_location=device)
model.load_state_dict(ckpt["model_state"])
print(f"Loaded best model ({best_model_path.name}) from epoch {ckpt.get('epoch')} (val_dice={ckpt.get('val_dice'):.4f})")

metric_names = list(METRIC_FUNCS.keys())  # ordered as defined
val_metrics  = evaluate_loader(model, val_loader,  criterion, METRIC_FUNCS)
test_metrics = evaluate_loader(model, test_loader, criterion, METRIC_FUNCS)

# Build table rows (ensure same column order)
columns = metric_names + ["loss"]
def row_str(split, metrics):
    return [split] + [f"{metrics[m]:.4f}" for m in metric_names] + [f"{metrics['loss']:.4f}"]

header = ["split"] + columns
rows = [
    row_str("validation", val_metrics),
    row_str("test",       test_metrics),
]

# Pretty print table
col_widths = [max(len(header[i]), *(len(r[i]) for r in rows)) for i in range(len(header))]
def fmt_line(cells):
    return " | ".join(c.ljust(col_widths[i]) for i,c in enumerate(cells))
sep = "-+-".join("-"*w for w in col_widths)

print("\nEvaluation Metrics (best model)")
print(fmt_line(header))
print(sep)
for r in rows:
    print(fmt_line(r))

# Optionally save to CSV
summary_csv = OUTPUT_ROOT / "eval_summary.csv"
import csv
write_header = not summary_csv.exists()
with open(summary_csv, "a", newline="") as f:
    w = csv.writer(f)
    if write_header:
        w.writerow(header)
    for r in rows:
        w.writerow(r)

print(f"\nSaved evaluation summary to {summary_csv}")

Loaded best model (best_model_NP.pth) from epoch 23 (val_dice=0.8848)


  conn = morphology.generate_binary_structure(y_true_np.ndim, 1)
  true_surface = np.logical_xor(y_true_np, morphology.binary_erosion(y_true_np, conn))
  pred_surface = np.logical_xor(y_pred_np, morphology.binary_erosion(y_pred_np, conn))
  true_distances = morphology.distance_transform_edt(~true_surface, sampling=spacing)
  pred_distances = morphology.distance_transform_edt(~pred_surface, sampling=spacing)



Evaluation Metrics (best model)
split      | dice   | f1     | recall | jaccard | hausdorff | usr    | osr    | asd     | accuracy | mean_iou | loss  
-----------+--------+--------+--------+---------+-----------+--------+--------+---------+----------+----------+-------
validation | 0.8848 | 0.9264 | 0.8832 | 0.7939  | 4.1683    | 0.0589 | 0.0564 | 22.1161 | 0.9970   | 0.7939   | 0.0257
test       | 0.8833 | 0.9254 | 0.8948 | 0.7914  | 4.3425    | 0.0523 | 0.0645 | 22.4111 | 0.9972   | 0.7914   | 0.0254

Saved evaluation summary to VNETModels\HARP_NP\eval_summary.csv


##### training, testing (dataset key = HARP_FP)

In [14]:
# --- FP dataset training & evaluation 
from pathlib import Path
# Configure FP run
DATASET_KEY_FP = "HARP_FP"
MODEL_TAG_FP = "FP"
CONFIG_FP = {
    "dataset": DATASET_KEY_FP,
    "model_tag": MODEL_TAG_FP,
    "model_name": f"VNet_{MODEL_TAG_FP}",
    "batch_size": 2,
    "epochs": 250,
    "lr": 1e-3,
    "weight_decay": 1e-4,
    "loss": "BCE",
    "threshold": 0.5,
    "early_stopping_patience": 7,
    "mixed_precision": True
}

OUTPUT_ROOT_FP = Path("VNETModels") / DATASET_KEY_FP
OUTPUT_ROOT_FP.mkdir(parents=True, exist_ok=True)
(OUTPUT_ROOT_FP / "checkpoints").mkdir(exist_ok=True)

with open(OUTPUT_ROOT_FP / "config.json", "w") as f:
    import json; json.dump(CONFIG_FP, f, indent=2)

# Dataloaders
train_loader_fp, val_loader_fp, test_loader_fp = get_dataloaders(
    CONFIG_FP["dataset"],
    batch_size=CONFIG_FP["batch_size"]
)

# Model / optim / loss
model_fp = VNet(in_channels=1, out_channels=1).to(device)
criterion_fp = get_loss(CONFIG_FP["loss"])
optimizer_fp = torch.optim.Adam(model_fp.parameters(),
                                lr=CONFIG_FP["lr"],
                                weight_decay=CONFIG_FP["weight_decay"])
scaler_fp = GradScaler(enabled=CONFIG_FP["mixed_precision"])

print("FP Train batches:", len(train_loader_fp),
      "Val:", len(val_loader_fp),
      "Test:", len(test_loader_fp))

# Checkpoint helper (only best)
def save_checkpoint_fp(state, is_best=False):
    if not is_best:
        return
    best_path = OUTPUT_ROOT_FP / f"best_model_{CONFIG_FP['model_tag']}.pth"
    torch.save(state, best_path)
    (OUTPUT_ROOT_FP / f"trained_{CONFIG_FP['model_tag']}.flag").write_text(str(state["epoch"]))

# History helper
def save_history_fp(history_rows, header, path):
    import csv
    write_header = not path.exists()
    with open(path, "a", newline="") as f:
        w = csv.writer(f)
        if write_header:
            w.writerow(header)
        w.writerow(history_rows)

history_csv_fp = OUTPUT_ROOT_FP / "history.csv"
best_model_path_fp = OUTPUT_ROOT_FP / f"best_model_{CONFIG_FP['model_tag']}.pth"

# Training
if best_model_path_fp.exists():
    print(f"Detected existing FP model {best_model_path_fp.name}. Skip training (delete it to retrain).")
else:
    best_dice_fp = -1
    epochs_no_improve_fp = 0
    metric_names_fp = list(METRIC_FUNCS.keys())
    display_metrics_fp = ["dice","f1","recall","jaccard","hausdorff","usr","osr"]

    for epoch in range(1, CONFIG_FP["epochs"] + 1):
        model_fp.train()
        train_loss_fp = 0.0
        train_acc_sum_fp = 0.0
        start_time_fp = time.time()

        for imgs, lbls in train_loader_fp:
            imgs, lbls = imgs.to(device), lbls.to(device)
            optimizer_fp.zero_grad(set_to_none=True)
            with autocast(enabled=CONFIG_FP["mixed_precision"]):
                logits_fp = model_fp(imgs)
                loss_fp = criterion_fp(logits_fp, lbls)
            scaler_fp.scale(loss_fp).backward()
            scaler_fp.step(optimizer_fp)
            scaler_fp.update()

            train_loss_fp += loss_fp.item() * imgs.size(0)
            with torch.no_grad():
                probs_tr_fp = torch.sigmoid(logits_fp)
                batch_acc_fp = calculate_accuracy(lbls, probs_tr_fp)
            train_acc_sum_fp += batch_acc_fp.item() * imgs.size(0)

        train_loss_fp /= len(train_loader_fp.dataset)
        train_acc_fp = train_acc_sum_fp / len(train_loader_fp.dataset)

        # Validation
        model_fp.eval()
        val_loss_fp = 0.0
        agg_fp = MetricAggregator()
        with torch.no_grad():
            for imgs, lbls in val_loader_fp:
                imgs, lbls = imgs.to(device), lbls.to(device)
                logits_fp = model_fp(imgs)
                loss_fp = criterion_fp(logits_fp, lbls)
                val_loss_fp += loss_fp.item() * imgs.size(0)
                probs_fp = torch.sigmoid(logits_fp)
                agg_fp.update(lbls, probs_fp)

        val_loss_fp /= len(val_loader_fp.dataset)
        val_metrics_fp = agg_fp.compute()
        epoch_dice_fp = val_metrics_fp["dice"]
        val_acc_fp = val_metrics_fp["accuracy"]

        elapsed_fp = time.time() - start_time_fp
        is_best_fp = epoch_dice_fp > best_dice_fp
        if is_best_fp:
            best_dice_fp = epoch_dice_fp
            epochs_no_improve_fp = 0
        else:
            epochs_no_improve_fp += 1

        save_checkpoint_fp({
            "epoch": epoch,
            "model_state": model_fp.state_dict(),
            "optimizer_state": optimizer_fp.state_dict(),
            "val_dice": epoch_dice_fp,
            "config": CONFIG_FP
        }, is_best=is_best_fp)

        row_fp = [epoch, train_loss_fp, val_loss_fp,
                  *[val_metrics_fp[m] for m in metric_names_fp],
                  best_dice_fp, elapsed_fp, train_acc_fp, val_acc_fp]
        header_fp = ["epoch","train_loss","val_loss", *metric_names_fp,
                     "best_dice","seconds","train_accuracy","val_accuracy"]
        save_history_fp(row_fp, header_fp, history_csv_fp)

        disp_fp = " | ".join([f"{m} {val_metrics_fp[m]:.4f}" for m in display_metrics_fp])
        print(f"[FP] Epoch {epoch:03d} | TrainLoss {train_loss_fp:.4f} | ValLoss {val_loss_fp:.4f} | "
              f"TrainAcc {train_acc_fp:.4f} | ValAcc {val_acc_fp:.4f} | {disp_fp} | "
              f"BestDice {best_dice_fp:.4f} | Pat {epochs_no_improve_fp} | {elapsed_fp:.1f}s")

        if epochs_no_improve_fp >= CONFIG_FP["early_stopping_patience"]:
            print("[FP] Early stopping triggered.")
            break


FP Train batches: 55 Val: 13 Test: 13


  conn = morphology.generate_binary_structure(y_true_np.ndim, 1)
  true_surface = np.logical_xor(y_true_np, morphology.binary_erosion(y_true_np, conn))
  pred_surface = np.logical_xor(y_pred_np, morphology.binary_erosion(y_pred_np, conn))
  true_distances = morphology.distance_transform_edt(~true_surface, sampling=spacing)
  pred_distances = morphology.distance_transform_edt(~pred_surface, sampling=spacing)


[FP] Epoch 001 | TrainLoss 0.5775 | ValLoss 0.4790 | TrainAcc 0.8092 | ValAcc 0.9727 | dice 0.4688 | f1 0.7065 | recall 0.9223 | jaccard 0.3077 | hausdorff 41.0587 | usr 0.0205 | osr 0.5108 | BestDice 0.4688 | Pat 0 | 25.9s
[FP] Epoch 002 | TrainLoss 0.4263 | ValLoss 0.3798 | TrainAcc 0.9856 | ValAcc 0.9923 | dice 0.7294 | f1 0.8398 | recall 0.7988 | jaccard 0.5775 | hausdorff 9.2206 | usr 0.0920 | osr 0.1786 | BestDice 0.7294 | Pat 0 | 18.6s
[FP] Epoch 003 | TrainLoss 0.3430 | ValLoss 0.3066 | TrainAcc 0.9910 | ValAcc 0.9920 | dice 0.7336 | f1 0.8405 | recall 0.8489 | jaccard 0.5812 | hausdorff 8.3706 | usr 0.0667 | osr 0.1997 | BestDice 0.7336 | Pat 0 | 18.6s
[FP] Epoch 004 | TrainLoss 0.2728 | ValLoss 0.2394 | TrainAcc 0.9922 | ValAcc 0.9942 | dice 0.7775 | f1 0.8648 | recall 0.7780 | jaccard 0.6376 | hausdorff 7.5145 | usr 0.1119 | osr 0.1107 | BestDice 0.7775 | Pat 0 | 21.5s
[FP] Epoch 005 | TrainLoss 0.2141 | ValLoss 0.1890 | TrainAcc 0.9933 | ValAcc 0.9947 | dice 0.8033 | f1 0.8

In [15]:

# --- Evaluation (FP) ---
from collections import OrderedDict

def evaluate_loader_fp(model_eval, loader, criterion, metric_funcs):
    model_eval.eval()
    agg_eval = MetricAggregator(metric_funcs)
    total_loss_eval = 0.0
    with torch.no_grad():
        for imgs, lbls in loader:
            imgs, lbls = imgs.to(device), lbls.to(device)
            logits = model_eval(imgs)
            loss = criterion(logits, lbls)
            total_loss_eval += loss.item() * imgs.size(0)
            probs = torch.sigmoid(logits)
            agg_eval.update(lbls, probs)
    avg_loss_eval = total_loss_eval / len(loader.dataset)
    metrics_eval = agg_eval.compute()
    metrics_eval = OrderedDict(metrics_eval)
    metrics_eval["loss"] = avg_loss_eval
    return metrics_eval

assert best_model_path_fp.exists(), "Best FP model not found (train stage may have been skipped unexpectedly)."
ckpt_fp = torch.load(best_model_path_fp, map_location=device)
model_fp.load_state_dict(ckpt_fp["model_state"])
print(f"\n[FP] Loaded best model epoch {ckpt_fp.get('epoch')} (val_dice={ckpt_fp.get('val_dice'):.4f})")

metric_names_fp = list(METRIC_FUNCS.keys())
val_metrics_fp  = evaluate_loader_fp(model_fp, val_loader_fp,  criterion_fp, METRIC_FUNCS)
test_metrics_fp = evaluate_loader_fp(model_fp, test_loader_fp, criterion_fp, METRIC_FUNCS)

columns_fp = metric_names_fp + ["loss"]
def row_fp_str(split, mets):
    return [split] + [f"{mets[m]:.4f}" for m in metric_names_fp] + [f"{mets['loss']:.4f}"]

header_fp_tbl = ["split"] + columns_fp
rows_fp = [
    row_fp_str("validation_FP", val_metrics_fp),
    row_fp_str("test_FP",       test_metrics_fp),
]

widths_fp = [max(len(header_fp_tbl[i]), *(len(r[i]) for r in rows_fp)) for i in range(len(header_fp_tbl))]
def fmt_fp(cells): return " | ".join(c.ljust(widths_fp[i]) for i,c in enumerate(cells))
sep_fp = "-+-".join("-"*w for w in widths_fp)

print("\n[FP] Evaluation Metrics (best model)")
print(fmt_fp(header_fp_tbl))
print(sep_fp)
for r in rows_fp:
    print(fmt_fp(r))

# Save summary
summary_fp_csv = OUTPUT_ROOT_FP / "eval_summary.csv"
import csv
write_header_fp = not summary_fp_csv.exists()
with open(summary_fp_csv, "a", newline="") as f:
    w = csv.writer(f)
    if write_header_fp:
        w.writerow(header_fp_tbl)
    for r in rows_fp:
        w.writerow(r)
print(f"\n[FP] Saved evaluation summary to {summary_fp_csv}")


[FP] Loaded best model epoch 20 (val_dice=0.8696)


  conn = morphology.generate_binary_structure(y_true_np.ndim, 1)
  true_surface = np.logical_xor(y_true_np, morphology.binary_erosion(y_true_np, conn))
  pred_surface = np.logical_xor(y_pred_np, morphology.binary_erosion(y_pred_np, conn))
  true_distances = morphology.distance_transform_edt(~true_surface, sampling=spacing)
  pred_distances = morphology.distance_transform_edt(~pred_surface, sampling=spacing)



[FP] Evaluation Metrics (best model)
split         | dice   | f1     | recall | jaccard | hausdorff | usr    | osr    | asd     | accuracy | mean_iou | loss  
--------------+--------+--------+--------+---------+-----------+--------+--------+---------+----------+----------+-------
validation_FP | 0.8696 | 0.9172 | 0.8875 | 0.7702  | 4.3783    | 0.0555 | 0.0750 | 22.1161 | 0.9966   | 0.7702   | 0.0295
test_FP       | 0.8683 | 0.9163 | 0.8944 | 0.7676  | 4.6690    | 0.0515 | 0.0803 | 22.4111 | 0.9968   | 0.7676   | 0.0291

[FP] Saved evaluation summary to VNETModels\HARP_FP\eval_summary.csv
