In [1]:
import os
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import RefineFormer3D
from dataset import BraTSDataset
from losses import RefineFormer3DLoss
from metrics import dice_coefficient, iou_score, hausdorff_distance
from config import DEVICE, IN_CHANNELS, NUM_CLASSES
from augmentation import Compose3D  # Optional, not used here

def evaluate(model, dataloader, criterion, device, num_classes):
    model.eval()
    running_loss = 0.0
    dice_all, iou_all, hd_all = [], [], []
    total_correct, total_voxels = 0, 0

    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Evaluating", leave=False):
            inputs = inputs.to(device)
            targets = targets.to(device, dtype=torch.long)

            outputs = model(inputs)
            if isinstance(outputs, dict):  # If model returns dict
                outputs = outputs['main']

            loss = criterion(outputs, targets)
            running_loss += loss.item() * inputs.size(0)

            preds = torch.argmax(outputs, dim=1)

            # Accuracy
            total_correct += (preds == targets).sum().item()
            total_voxels += torch.numel(targets)

            # Per-sample metrics
            for b in range(preds.shape[0]):
                dice_all.append(dice_coefficient(preds[b], targets[b], num_classes))
                iou_all.append(iou_score(preds[b], targets[b], num_classes))
                hd_all.append(hausdorff_distance(preds[b], targets[b]))

    avg_loss = running_loss / len(dataloader.dataset)
    avg_dice = torch.mean(torch.stack(dice_all)).item()
    avg_iou = torch.mean(torch.stack(iou_all)).item()
    avg_hd = torch.mean(torch.stack(hd_all)).item()
    accuracy = total_correct / total_voxels

    return avg_loss, avg_dice, avg_iou, avg_hd, accuracy


def main():
    print("📁 Loading validation dataset...")
    val_dataset = BraTSDataset(
        root_dirs=["/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BRATS_SPLIT/val"],
        transform=None,
    )
    print(f"✅ Total validation samples: {len(val_dataset)}")

    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=4, pin_memory=True)

    model = RefineFormer3D(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES).to(DEVICE)

    checkpoint_path = "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/my implementations/segformer3d_upgraded/crashed_model.pt"
    model.load_state_dict(torch.load(checkpoint_path))
    print(f"✅ Loaded model from: {checkpoint_path}")

    criterion = RefineFormer3DLoss()

    print("\n🚀 Running evaluation...")
    loss, dice, iou, hd, acc = evaluate(model, val_loader, criterion, DEVICE, NUM_CLASSES)

    print("\n📊 Evaluation Results:")
    print(f"Validation Loss: {loss:.4f}")
    print(f"Dice Coefficient: {dice:.4f}")
    print(f"IoU Score: {iou:.4f}")
    print(f"Hausdorff Distance: {hd:.4f}")
    print(f"Voxel Accuracy: {acc:.4f}")


if __name__ == "__main__":
    main()


📁 Loading validation dataset...
✅ Total valid patient directories found: 73
✅ Total validation samples: 73


  model.load_state_dict(torch.load(checkpoint_path))


✅ Loaded model from: /mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/my implementations/segformer3d_upgraded/crashed_model.pt

🚀 Running evaluation...


                                                  

ValueError: operands could not be broadcast together with shapes (3,2) (3,) 

In [3]:
import os
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import RefineFormer3D
from dataset import BraTSDataset
from losses import RefineFormer3DLoss
from metrics import dice_coefficient, iou_score, hausdorff_distance
from config import DEVICE, IN_CHANNELS, NUM_CLASSES

def dice_per_region(preds, targets):
    eps = 1e-5

    def binary_dice(pred_mask, gt_mask):
        intersection = torch.sum(pred_mask * gt_mask)
        union = torch.sum(pred_mask) + torch.sum(gt_mask)
        return (2 * intersection + eps) / (union + eps)

    results = {
        'Dice_WT': [],
        'Dice_TC': [],
        'Dice_ET': []
    }

    for b in range(preds.size(0)):
        pred = preds[b]
        gt = targets[b]

        pred_wt = (pred > 0).float()
        gt_wt = (gt > 0).float()
        results['Dice_WT'].append(binary_dice(pred_wt, gt_wt))

        pred_tc = ((pred == 1) | (pred == 4)).float()
        gt_tc = ((gt == 1) | (gt == 4)).float()
        results['Dice_TC'].append(binary_dice(pred_tc, gt_tc))

        pred_et = (pred == 4).float()
        gt_et = (gt == 4).float()
        results['Dice_ET'].append(binary_dice(pred_et, gt_et))

    return {k: torch.mean(torch.stack(v)).item() for k, v in results.items()}

def evaluate(model, dataloader, criterion, device, num_classes):
    model.eval()
    running_loss = 0.0
    dice_scores, iou_scores, hd_scores = [], [], []
    wt_scores, tc_scores, et_scores = [], [], []
    total_correct = 0
    total_voxels = 0

    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Evaluating", leave=False):
            inputs = inputs.to(device)
            targets = targets.to(device, dtype=torch.long)

            out_dict = model(inputs)
            logits = out_dict["main"] if isinstance(out_dict, dict) else out_dict

            gt_shape = targets.shape[1:]
            logits = F.interpolate(logits, size=gt_shape, mode="trilinear", align_corners=False)

            loss = criterion(logits, targets)
            running_loss += loss.item() * inputs.size(0)

            _, mean_dice = dice_coefficient(logits, targets, num_classes)
            _, mean_iou = iou_score(logits, targets, num_classes)
            mean_hd = hausdorff_distance(logits, targets)

            dice_scores.append(mean_dice)
            iou_scores.append(mean_iou)
            hd_scores.append(mean_hd)

            preds = torch.argmax(logits, dim=1)
            total_correct += (preds == targets).sum().item()
            total_voxels += preds.numel()

            region_dices = dice_per_region(preds, targets)
            wt_scores.append(region_dices['Dice_WT'])
            tc_scores.append(region_dices['Dice_TC'])
            et_scores.append(region_dices['Dice_ET'])

    N = len(dataloader.dataset)
    avg_loss = running_loss / N
    avg_dice = sum(dice_scores) / N
    avg_iou = sum(iou_scores) / N
    avg_hd = sum(hd_scores) / N
    accuracy = total_correct / total_voxels

    avg_wt = sum(wt_scores) / N
    avg_tc = sum(tc_scores) / N
    avg_et = sum(et_scores) / N

    return avg_loss, avg_dice, avg_iou, avg_hd, accuracy, avg_wt, avg_tc, avg_et

def main():
    val_ds = BraTSDataset(
        root_dirs=["/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BRATS_SPLIT/val"],
        transform=None
    )
    print(f"✅ Total validation samples: {len(val_ds)}")
    val_loader = DataLoader(val_ds, batch_size=2, shuffle=False, num_workers=4, pin_memory=True)

    model = RefineFormer3D(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES).to(DEVICE)
    ckpt = "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/my implementations/segformer3d_upgraded/crashed_model.pt"
    model.load_state_dict(torch.load(ckpt, map_location=DEVICE))
    print(f"✅ Loaded model from: {ckpt}")

    criterion = RefineFormer3DLoss()

    print("\n🚀 Running evaluation…")
    loss, dice, iou, hd, acc, wt, tc, et = evaluate(model, val_loader, criterion, DEVICE, NUM_CLASSES)

    print("\n📊 Evaluation Results:")
    print(f"Validation Loss      : {loss:.4f}")
    print(f"Mean Dice Coefficient: {dice:.4f}")
    print(f"Mean IoU Score       : {iou:.4f}")
    print(f"Mean Hausdorff Dist. : {hd:.4f}")
    print(f"Voxel Accuracy       : {acc:.4f}")
    print(f"Dice - Whole Tumor   : {wt:.4f}")
    print(f"Dice - Tumor Core    : {tc:.4f}")
    print(f"Dice - Enhancing Tumor: {et:.4f}")

if __name__ == "__main__":
    main()


✅ Total valid patient directories found: 73
✅ Total validation samples: 73


  model.load_state_dict(torch.load(ckpt, map_location=DEVICE))


✅ Loaded model from: /mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/my implementations/segformer3d_upgraded/crashed_model.pt

🚀 Running evaluation…


                                                           


📊 Evaluation Results:
Validation Loss      : 1.4467
Mean Dice Coefficient: 0.3202
Mean IoU Score       : 0.2684
Mean Hausdorff Dist. : 28.5248
Voxel Accuracy       : 0.9558
Dice - Whole Tumor   : 0.3134
Dice - Tumor Core    : 0.2075
Dice - Enhancing Tumor: 0.5068


