In [1]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from tqdm import tqdm
from model5_upgrade import RefineFormer3D
from dataset import BraTSDataset
from optimizer import get_optimizer, get_scheduler
from augmentation import Compose3D, RandomFlip3D, RandomRotation3D, RandomNoise3D
from losses import RefineFormer3DLoss
import torch.nn.functional as F
from config import DEVICE, IN_CHANNELS, NUM_CLASSES, BASE_LR, WEIGHT_DECAY, NUM_EPOCHS

def pad_input_for_windows(x, window_size=(2, 2, 2)):
    # x: [B, C, D, H, W]
    _, _, D, H, W = x.shape
    pad_d = (window_size[0] - D % window_size[0]) % window_size[0]
    pad_h = (window_size[1] - H % window_size[1]) % window_size[1]
    pad_w = (window_size[2] - W % window_size[2]) % window_size[2]

    return F.pad(x, (0, pad_w, 0, pad_h, 0, pad_d))  # (W_left, W_right, H_top, H_bottom, D_front, D_back)


def train_one_epoch(model, dataloader, optimizer, criterion, device, scaler):
    model.train()
    running_loss = 0.0

    for inputs, targets in tqdm(dataloader, desc="Training", leave=False):
        inputs = inputs.to(device, non_blocking=True)
        inputs = pad_input_for_windows(inputs, window_size=(2, 2, 2))

        targets = targets.to(device=device, dtype=torch.long, non_blocking=True)


        optimizer.zero_grad()

        with autocast(device_type='cuda'):
            outputs = model(inputs)
            if torch.isnan(outputs["main"]).any():
                print(" Model output contains NaNs ")

            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * inputs.size(0)

    return running_loss / len(dataloader.dataset)


def main():
    device = DEVICE
    model = RefineFormer3D(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES).to(device)

    optimizer = get_optimizer(model, base_lr=BASE_LR, weight_decay=WEIGHT_DECAY)
    scheduler = get_scheduler(optimizer)
    criterion = RefineFormer3DLoss()
    scaler = GradScaler(device='cuda')

    train_transform = Compose3D([
        RandomFlip3D(p=0.5),
        RandomRotation3D(p=0.5),
        RandomNoise3D(p=0.3),
    ])    

    train_dataset = BraTSDataset(
        root_dirs=["/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BRATS_SPLIT/train"],
        transform=train_transform,
    )

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=14, pin_memory=True)

    for epoch in range(1, NUM_EPOCHS + 1):
        print(f"\n--- Epoch [{epoch}/{NUM_EPOCHS}] ---")
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler)
        scheduler.step()

        print(f"Train Loss: {train_loss:.4f}")
        torch.save(model.state_dict(), f"checkpoint_epoch_{epoch}.pt")

    torch.save(model.state_dict(), "final_model.pt")
    print(" Final model saved as 'final_model.pt'")


if __name__ == "__main__":
    main()



                                                          

KeyboardInterrupt: 

In [None]:
epoch = 90  # or whatever epoch you just finished

torch.save({
    'epoch': epoch,
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scaler_state_dict': scaler.state_dict(),
}, f"checkpoint_epoch_{epoch}.pt")

print(f"✅ Saved checkpoint for epoch {epoch}")


NameError: name 'epoch' is not defined

In [1]:
import torch
from torch.utils.data import DataLoader
from model3 import RefineFormer3D
from dataset import BraTSDataset
from config import IN_CHANNELS, NUM_CLASSES
import numpy as np

DEVICE = torch.device('cpu')

# Dataset & loader
test_dataset = BraTSDataset(
    root_dirs=["/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BRATS_SPLIT/val"],
    transform=None,
)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False, num_workers=14)

# Model
model = RefineFormer3D(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES)
model.load_state_dict(torch.load("checkpoint_epoch_124.pt", map_location=DEVICE))
model.to(DEVICE)
model.eval()

# Metrics
dice_list, iou_list = [], []
wt_dice_list, tc_dice_list, et_dice_list = [], [], []
total_correct, total_voxels = 0, 0

def dice_score(pred, target):
    intersection = (pred & target).sum()
    union = pred.sum() + target.sum()
    return (2 * intersection + 1e-5) / (union + 1e-5)

with torch.no_grad():
    for inputs, targets in test_loader:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE, dtype=torch.long)

        outputs = model(inputs)['main']
        preds = torch.argmax(outputs, dim=1)

        # Overall Accuracy
        total_correct += (preds == targets).sum().item()
        total_voxels += torch.numel(preds)

        for i in range(preds.shape[0]):
            pred = preds[i].cpu().numpy()
            gt = targets[i].cpu().numpy()

            # One-hot masks for dice per class
            for c in range(NUM_CLASSES):
                dice = dice_score((pred == c), (gt == c))
                dice_list.append(dice)

            # WT = labels 1, 2, 3
            wt_dice = dice_score((pred > 0), (gt > 0))
            wt_dice_list.append(wt_dice)

            # TC = labels 1 and 3
            tc_dice = dice_score(((pred == 1) | (pred == 3)), ((gt == 1) | (gt == 3)))
            tc_dice_list.append(tc_dice)

            # ET = label 3
            et_dice = dice_score((pred == 3), (gt == 3))
            et_dice_list.append(et_dice)

# Print final results
print("\n📊 Evaluation Results:")
print(f"Avg Dice Score (mean of all classes): {np.mean(dice_list):.4f}")
print(f"WT Dice (Whole Tumor): {np.mean(wt_dice_list)*100:.2f}%")
print(f"TC Dice (Tumor Core): {np.mean(tc_dice_list)*100:.2f}%")
print(f"ET Dice (Enhancing Tumor): {np.mean(et_dice_list)*100:.2f}%")
print(f"Overall Accuracy: {(total_correct / total_voxels)*100:.2f}%")


✅ Total valid patient directories found: 73


  model.load_state_dict(torch.load("checkpoint_epoch_124.pt", map_location=DEVICE))


RuntimeError: Error(s) in loading state_dict for RefineFormer3D:
	Missing key(s) in state_dict: "patch_embed.proj.primary_conv.weight", "patch_embed.proj.cheap_operation.weight", "patch_embed.proj.se.fc.0.weight", "patch_embed.proj.se.fc.0.bias", "patch_embed.proj.se.fc.2.weight", "patch_embed.proj.se.fc.2.bias", "patch_embed.dw.weight", "patch_embed.dw.bias", "patch_embed.norm.weight", "patch_embed.norm.bias", "blocks.0.norm.weight", "blocks.0.norm.bias", "blocks.0.attn.depth.in_proj_weight", "blocks.0.attn.depth.in_proj_bias", "blocks.0.attn.depth.out_proj.weight", "blocks.0.attn.depth.out_proj.bias", "blocks.0.attn.height.in_proj_weight", "blocks.0.attn.height.in_proj_bias", "blocks.0.attn.height.out_proj.weight", "blocks.0.attn.height.out_proj.bias", "blocks.0.attn.width.in_proj_weight", "blocks.0.attn.width.in_proj_bias", "blocks.0.attn.width.out_proj.weight", "blocks.0.attn.width.out_proj.bias", "blocks.1.norm.weight", "blocks.1.norm.bias", "blocks.1.attn.depth.in_proj_weight", "blocks.1.attn.depth.in_proj_bias", "blocks.1.attn.depth.out_proj.weight", "blocks.1.attn.depth.out_proj.bias", "blocks.1.attn.height.in_proj_weight", "blocks.1.attn.height.in_proj_bias", "blocks.1.attn.height.out_proj.weight", "blocks.1.attn.height.out_proj.bias", "blocks.1.attn.width.in_proj_weight", "blocks.1.attn.width.in_proj_bias", "blocks.1.attn.width.out_proj.weight", "blocks.1.attn.width.out_proj.bias", "blocks.2.norm.weight", "blocks.2.norm.bias", "blocks.2.attn.depth.in_proj_weight", "blocks.2.attn.depth.in_proj_bias", "blocks.2.attn.depth.out_proj.weight", "blocks.2.attn.depth.out_proj.bias", "blocks.2.attn.height.in_proj_weight", "blocks.2.attn.height.in_proj_bias", "blocks.2.attn.height.out_proj.weight", "blocks.2.attn.height.out_proj.bias", "blocks.2.attn.width.in_proj_weight", "blocks.2.attn.width.in_proj_bias", "blocks.2.attn.width.out_proj.weight", "blocks.2.attn.width.out_proj.bias", "blocks.3.norm.weight", "blocks.3.norm.bias", "blocks.3.attn.depth.in_proj_weight", "blocks.3.attn.depth.in_proj_bias", "blocks.3.attn.depth.out_proj.weight", "blocks.3.attn.depth.out_proj.bias", "blocks.3.attn.height.in_proj_weight", "blocks.3.attn.height.in_proj_bias", "blocks.3.attn.height.out_proj.weight", "blocks.3.attn.height.out_proj.bias", "blocks.3.attn.width.in_proj_weight", "blocks.3.attn.width.in_proj_bias", "blocks.3.attn.width.out_proj.weight", "blocks.3.attn.width.out_proj.bias", "blocks.4.norm.weight", "blocks.4.norm.bias", "blocks.4.attn.depth.in_proj_weight", "blocks.4.attn.depth.in_proj_bias", "blocks.4.attn.depth.out_proj.weight", "blocks.4.attn.depth.out_proj.bias", "blocks.4.attn.height.in_proj_weight", "blocks.4.attn.height.in_proj_bias", "blocks.4.attn.height.out_proj.weight", "blocks.4.attn.height.out_proj.bias", "blocks.4.attn.width.in_proj_weight", "blocks.4.attn.width.in_proj_bias", "blocks.4.attn.width.out_proj.weight", "blocks.4.attn.width.out_proj.bias", "blocks.5.norm.weight", "blocks.5.norm.bias", "blocks.5.attn.depth.in_proj_weight", "blocks.5.attn.depth.in_proj_bias", "blocks.5.attn.depth.out_proj.weight", "blocks.5.attn.depth.out_proj.bias", "blocks.5.attn.height.in_proj_weight", "blocks.5.attn.height.in_proj_bias", "blocks.5.attn.height.out_proj.weight", "blocks.5.attn.height.out_proj.bias", "blocks.5.attn.width.in_proj_weight", "blocks.5.attn.width.in_proj_bias", "blocks.5.attn.width.out_proj.weight", "blocks.5.attn.width.out_proj.bias", "decoder.conv.primary_conv.weight", "decoder.conv.cheap_operation.weight", "decoder.conv.se.fc.0.weight", "decoder.conv.se.fc.0.bias", "decoder.conv.se.fc.2.weight", "decoder.conv.se.fc.2.bias", "decoder.norm.weight", "decoder.norm.bias", "decoder.head.weight", "decoder.head.bias", "aux_head.weight", "aux_head.bias", "boundary_head.conv.weight", "boundary_head.conv.bias". 
	Unexpected key(s) in state_dict: "encoder.embeds.0.proj.primary_conv.weight", "encoder.embeds.0.proj.cheap_operation.weight", "encoder.embeds.0.norm.weight", "encoder.embeds.0.norm.bias", "encoder.embeds.1.proj.primary_conv.weight", "encoder.embeds.1.proj.cheap_operation.weight", "encoder.embeds.1.norm.weight", "encoder.embeds.1.norm.bias", "encoder.embeds.2.proj.primary_conv.weight", "encoder.embeds.2.proj.cheap_operation.weight", "encoder.embeds.2.norm.weight", "encoder.embeds.2.norm.bias", "encoder.embeds.3.proj.primary_conv.weight", "encoder.embeds.3.proj.cheap_operation.weight", "encoder.embeds.3.norm.weight", "encoder.embeds.3.norm.bias", "encoder.stages.0.0.norm1.weight", "encoder.stages.0.0.norm1.bias", "encoder.stages.0.0.attn.relative_bias", "encoder.stages.0.0.attn.rel_idx", "encoder.stages.0.0.attn.qkv.weight", "encoder.stages.0.0.attn.qkv.bias", "encoder.stages.0.0.attn.proj.weight", "encoder.stages.0.0.attn.proj.bias", "encoder.stages.0.0.norm2.weight", "encoder.stages.0.0.norm2.bias", "encoder.stages.0.0.mlp.fc1.weight", "encoder.stages.0.0.mlp.fc1.bias", "encoder.stages.0.0.mlp.dwconv.weight", "encoder.stages.0.0.mlp.dwconv.bias", "encoder.stages.0.0.mlp.fc2.weight", "encoder.stages.0.0.mlp.fc2.bias", "encoder.stages.0.1.norm1.weight", "encoder.stages.0.1.norm1.bias", "encoder.stages.0.1.attn.relative_bias", "encoder.stages.0.1.attn.rel_idx", "encoder.stages.0.1.attn.qkv.weight", "encoder.stages.0.1.attn.qkv.bias", "encoder.stages.0.1.attn.proj.weight", "encoder.stages.0.1.attn.proj.bias", "encoder.stages.0.1.norm2.weight", "encoder.stages.0.1.norm2.bias", "encoder.stages.0.1.mlp.fc1.weight", "encoder.stages.0.1.mlp.fc1.bias", "encoder.stages.0.1.mlp.dwconv.weight", "encoder.stages.0.1.mlp.dwconv.bias", "encoder.stages.0.1.mlp.fc2.weight", "encoder.stages.0.1.mlp.fc2.bias", "encoder.stages.1.0.norm1.weight", "encoder.stages.1.0.norm1.bias", "encoder.stages.1.0.attn.relative_bias", "encoder.stages.1.0.attn.rel_idx", "encoder.stages.1.0.attn.qkv.weight", "encoder.stages.1.0.attn.qkv.bias", "encoder.stages.1.0.attn.proj.weight", "encoder.stages.1.0.attn.proj.bias", "encoder.stages.1.0.norm2.weight", "encoder.stages.1.0.norm2.bias", "encoder.stages.1.0.mlp.fc1.weight", "encoder.stages.1.0.mlp.fc1.bias", "encoder.stages.1.0.mlp.dwconv.weight", "encoder.stages.1.0.mlp.dwconv.bias", "encoder.stages.1.0.mlp.fc2.weight", "encoder.stages.1.0.mlp.fc2.bias", "encoder.stages.1.1.norm1.weight", "encoder.stages.1.1.norm1.bias", "encoder.stages.1.1.attn.relative_bias", "encoder.stages.1.1.attn.rel_idx", "encoder.stages.1.1.attn.qkv.weight", "encoder.stages.1.1.attn.qkv.bias", "encoder.stages.1.1.attn.proj.weight", "encoder.stages.1.1.attn.proj.bias", "encoder.stages.1.1.norm2.weight", "encoder.stages.1.1.norm2.bias", "encoder.stages.1.1.mlp.fc1.weight", "encoder.stages.1.1.mlp.fc1.bias", "encoder.stages.1.1.mlp.dwconv.weight", "encoder.stages.1.1.mlp.dwconv.bias", "encoder.stages.1.1.mlp.fc2.weight", "encoder.stages.1.1.mlp.fc2.bias", "encoder.stages.2.0.norm1.weight", "encoder.stages.2.0.norm1.bias", "encoder.stages.2.0.attn.relative_bias", "encoder.stages.2.0.attn.rel_idx", "encoder.stages.2.0.attn.qkv.weight", "encoder.stages.2.0.attn.qkv.bias", "encoder.stages.2.0.attn.proj.weight", "encoder.stages.2.0.attn.proj.bias", "encoder.stages.2.0.norm2.weight", "encoder.stages.2.0.norm2.bias", "encoder.stages.2.0.mlp.fc1.weight", "encoder.stages.2.0.mlp.fc1.bias", "encoder.stages.2.0.mlp.dwconv.weight", "encoder.stages.2.0.mlp.dwconv.bias", "encoder.stages.2.0.mlp.fc2.weight", "encoder.stages.2.0.mlp.fc2.bias", "encoder.stages.2.1.norm1.weight", "encoder.stages.2.1.norm1.bias", "encoder.stages.2.1.attn.relative_bias", "encoder.stages.2.1.attn.rel_idx", "encoder.stages.2.1.attn.qkv.weight", "encoder.stages.2.1.attn.qkv.bias", "encoder.stages.2.1.attn.proj.weight", "encoder.stages.2.1.attn.proj.bias", "encoder.stages.2.1.norm2.weight", "encoder.stages.2.1.norm2.bias", "encoder.stages.2.1.mlp.fc1.weight", "encoder.stages.2.1.mlp.fc1.bias", "encoder.stages.2.1.mlp.dwconv.weight", "encoder.stages.2.1.mlp.dwconv.bias", "encoder.stages.2.1.mlp.fc2.weight", "encoder.stages.2.1.mlp.fc2.bias", "encoder.stages.3.0.norm1.weight", "encoder.stages.3.0.norm1.bias", "encoder.stages.3.0.attn.relative_bias", "encoder.stages.3.0.attn.rel_idx", "encoder.stages.3.0.attn.qkv.weight", "encoder.stages.3.0.attn.qkv.bias", "encoder.stages.3.0.attn.proj.weight", "encoder.stages.3.0.attn.proj.bias", "encoder.stages.3.0.norm2.weight", "encoder.stages.3.0.norm2.bias", "encoder.stages.3.0.mlp.fc1.weight", "encoder.stages.3.0.mlp.fc1.bias", "encoder.stages.3.0.mlp.dwconv.weight", "encoder.stages.3.0.mlp.dwconv.bias", "encoder.stages.3.0.mlp.fc2.weight", "encoder.stages.3.0.mlp.fc2.bias", "encoder.stages.3.1.norm1.weight", "encoder.stages.3.1.norm1.bias", "encoder.stages.3.1.attn.relative_bias", "encoder.stages.3.1.attn.rel_idx", "encoder.stages.3.1.attn.qkv.weight", "encoder.stages.3.1.attn.qkv.bias", "encoder.stages.3.1.attn.proj.weight", "encoder.stages.3.1.attn.proj.bias", "encoder.stages.3.1.norm2.weight", "encoder.stages.3.1.norm2.bias", "encoder.stages.3.1.mlp.fc1.weight", "encoder.stages.3.1.mlp.fc1.bias", "encoder.stages.3.1.mlp.dwconv.weight", "encoder.stages.3.1.mlp.dwconv.bias", "encoder.stages.3.1.mlp.fc2.weight", "encoder.stages.3.1.mlp.fc2.bias", "encoder.norms.0.weight", "encoder.norms.0.bias", "encoder.norms.1.weight", "encoder.norms.1.bias", "encoder.norms.2.weight", "encoder.norms.2.bias", "encoder.norms.3.weight", "encoder.norms.3.bias", "decoder.decode3.attn_fuse.0.primary_conv.weight", "decoder.decode3.attn_fuse.0.cheap_operation.weight", "decoder.decode3.attn_fuse.1.weight", "decoder.decode3.attn_fuse.1.bias", "decoder.decode3.routing_mlp.2.weight", "decoder.decode3.routing_mlp.2.bias", "decoder.decode3.expert_convs.0.0.primary_conv.weight", "decoder.decode3.expert_convs.0.0.cheap_operation.weight", "decoder.decode3.expert_convs.0.1.weight", "decoder.decode3.expert_convs.0.1.bias", "decoder.decode3.expert_convs.1.0.primary_conv.weight", "decoder.decode3.expert_convs.1.0.cheap_operation.weight", "decoder.decode3.expert_convs.1.1.weight", "decoder.decode3.expert_convs.1.1.bias", "decoder.decode3.expert_convs.2.0.primary_conv.weight", "decoder.decode3.expert_convs.2.0.cheap_operation.weight", "decoder.decode3.expert_convs.2.1.weight", "decoder.decode3.expert_convs.2.1.bias", "decoder.decode3.expert_convs.3.0.primary_conv.weight", "decoder.decode3.expert_convs.3.0.cheap_operation.weight", "decoder.decode3.expert_convs.3.1.weight", "decoder.decode3.expert_convs.3.1.bias", "decoder.decode2.attn_fuse.0.primary_conv.weight", "decoder.decode2.attn_fuse.0.cheap_operation.weight", "decoder.decode2.attn_fuse.1.weight", "decoder.decode2.attn_fuse.1.bias", "decoder.decode2.routing_mlp.2.weight", "decoder.decode2.routing_mlp.2.bias", "decoder.decode2.expert_convs.0.0.primary_conv.weight", "decoder.decode2.expert_convs.0.0.cheap_operation.weight", "decoder.decode2.expert_convs.0.1.weight", "decoder.decode2.expert_convs.0.1.bias", "decoder.decode2.expert_convs.1.0.primary_conv.weight", "decoder.decode2.expert_convs.1.0.cheap_operation.weight", "decoder.decode2.expert_convs.1.1.weight", "decoder.decode2.expert_convs.1.1.bias", "decoder.decode2.expert_convs.2.0.primary_conv.weight", "decoder.decode2.expert_convs.2.0.cheap_operation.weight", "decoder.decode2.expert_convs.2.1.weight", "decoder.decode2.expert_convs.2.1.bias", "decoder.decode2.expert_convs.3.0.primary_conv.weight", "decoder.decode2.expert_convs.3.0.cheap_operation.weight", "decoder.decode2.expert_convs.3.1.weight", "decoder.decode2.expert_convs.3.1.bias", "decoder.decode1.attn_fuse.0.primary_conv.weight", "decoder.decode1.attn_fuse.0.cheap_operation.weight", "decoder.decode1.attn_fuse.1.weight", "decoder.decode1.attn_fuse.1.bias", "decoder.decode1.routing_mlp.2.weight", "decoder.decode1.routing_mlp.2.bias", "decoder.decode1.expert_convs.0.0.primary_conv.weight", "decoder.decode1.expert_convs.0.0.cheap_operation.weight", "decoder.decode1.expert_convs.0.1.weight", "decoder.decode1.expert_convs.0.1.bias", "decoder.decode1.expert_convs.1.0.primary_conv.weight", "decoder.decode1.expert_convs.1.0.cheap_operation.weight", "decoder.decode1.expert_convs.1.1.weight", "decoder.decode1.expert_convs.1.1.bias", "decoder.decode1.expert_convs.2.0.primary_conv.weight", "decoder.decode1.expert_convs.2.0.cheap_operation.weight", "decoder.decode1.expert_convs.2.1.weight", "decoder.decode1.expert_convs.2.1.bias", "decoder.decode1.expert_convs.3.0.primary_conv.weight", "decoder.decode1.expert_convs.3.0.cheap_operation.weight", "decoder.decode1.expert_convs.3.1.weight", "decoder.decode1.expert_convs.3.1.bias", "decoder.final_up.1.primary_conv.weight", "decoder.final_up.1.cheap_operation.weight", "decoder.final_up.2.weight", "decoder.final_up.2.bias", "decoder.seg_head.weight", "decoder.seg_head.bias". 

In [2]:
import torch
from torch.utils.data import DataLoader
from model3 import RefineFormer3D
from dataset import BraTSDataset
from config import IN_CHANNELS, NUM_CLASSES
import numpy as np

# Device
DEVICE = torch.device('cpu')

# Dataset & loader
test_dataset = BraTSDataset(
    root_dirs=["/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BRATS_SPLIT/val"],
    transform=None,
)
test_loader = DataLoader(
    test_dataset,
    batch_size=2,
    shuffle=False,
    num_workers=14,
    pin_memory=True,
)

# Model
model = RefineFormer3D(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES)

# Robustly load checkpoint: only load matching keys
ckpt = torch.load("checkpoint_epoch_124.pt", map_location=DEVICE)
model_dict = model.state_dict()
# filter out unmatched keys
pretrained_dict = {k: v for k, v in ckpt.items() if k in model_dict}
print(f"Loading {len(pretrained_dict)}/{len(model_dict)} layers from checkpoint")
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

model.to(DEVICE)
model.eval()

# Dice / accuracy helpers
def dice_score(pred, target):
    intersection = (pred & target).sum()
    union = pred.sum() + target.sum()
    return (2 * intersection + 1e-5) / (union + 1e-5)

# Accumulators
dice_list = []
wt_dice_list = []
tc_dice_list = []
et_dice_list = []
total_correct = 0
total_voxels = 0

with torch.no_grad():
    for inputs, targets in test_loader:
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE, dtype=torch.long)

        outputs = model(inputs)["main"]
        preds = torch.argmax(outputs, dim=1)

        # Overall Accuracy
        total_correct += (preds == targets).sum().item()
        total_voxels += preds.numel()

        # Per‐sample Dice
        for i in range(preds.shape[0]):
            p = preds[i].cpu().numpy()
            g = targets[i].cpu().numpy()
            # per‐class dice
            for c in range(NUM_CLASSES):
                dice_list.append(dice_score(p == c, g == c))
            # WT (classes>0)
            wt_dice_list.append(dice_score(p > 0, g > 0))
            # TC (classes 1 or 3)
            tc_dice_list.append(dice_score((p == 1) | (p == 3), (g == 1) | (g == 3)))
            # ET (class 3)
            et_dice_list.append(dice_score(p == 3, g == 3))

# Print final results
print("\n📊 Evaluation Results:")
print(f"Avg Dice (mean over all classes): {np.mean(dice_list):.4f}")
print(f"WT Dice (Whole Tumor):           {np.mean(wt_dice_list)*100:.2f}%")
print(f"TC Dice (Tumor Core):            {np.mean(tc_dice_list)*100:.2f}%")
print(f"ET Dice (Enhancing Tumor):       {np.mean(et_dice_list)*100:.2f}%")
print(f"Overall Accuracy:                {(total_correct/total_voxels)*100:.2f}%")


✅ Total valid patient directories found: 73
Loading 0/108 layers from checkpoint


  ckpt = torch.load("checkpoint_epoch_124.pt", map_location=DEVICE)
  return fn(*args, **kwargs)



📊 Evaluation Results:
Avg Dice (mean over all classes): 0.2089
WT Dice (Whole Tumor):           8.38%
TC Dice (Tumor Core):            5.30%
ET Dice (Enhancing Tumor):       100.00%
Overall Accuracy:                37.71%


MAIN SCRIPT ------- TRAIN + VAL

In [1]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from tqdm import tqdm
import numpy as np
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint

from model5_ghost_upg import RefineFormer3D
from dataset import BraTSDataset
from optimizer import get_optimizer, get_scheduler
from augmentation import Compose3D, RandomFlip3D, RandomRotation3D, RandomNoise3D, TumorCoreContrast
from losses import RefineFormer3DLoss
from config import DEVICE, IN_CHANNELS, NUM_CLASSES, BASE_LR, WEIGHT_DECAY, NUM_EPOCHS, ATTN_DROP_RATE

def pad_input_for_windows(x, window_size=(2, 2, 2)):
    _, _, D, H, W = x.shape
    pad_d = (window_size[0] - D % window_size[0]) % window_size[0]
    pad_h = (window_size[1] - H % window_size[1]) % window_size[1]
    pad_w = (window_size[2] - W % window_size[2]) % window_size[2]
    return F.pad(x, (0, pad_w, 0, pad_h, 0, pad_d))

def dice_score(pred, target):
    intersection = (pred & target).sum()
    union = pred.sum() + target.sum()
    return (2 * intersection + 1e-5) / (union + 1e-5)

def validate(model, dataloader, criterion, device):
    model.eval()
    val_loss = 0.0
    dice_list, wt_dice_list, tc_dice_list, et_dice_list = [], [], [], []
    total_correct, total_voxels = 0, 0

    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Validation", leave=False):
            inputs = inputs.to(device)
            inputs = pad_input_for_windows(inputs, window_size=(2, 2, 2))
            targets = targets.to(device=device, dtype=torch.long)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item() * inputs.size(0)

            preds = torch.argmax(outputs["main"], dim=1)

            total_correct += (preds == targets).sum().item()
            total_voxels += preds.numel()

            for i in range(preds.shape[0]):
                p = preds[i].cpu().numpy()
                g = targets[i].cpu().numpy()
                for c in range(NUM_CLASSES):
                    dice_list.append(dice_score(p == c, g == c))
                wt_dice_list.append(dice_score(p > 0, g > 0))
                tc_dice_list.append(dice_score((p == 1) | (p == 3), (g == 1) | (g == 3)))
                et_dice_list.append(dice_score(p == 3, g == 3))

    avg_loss = val_loss / len(dataloader.dataset)
    avg_dice = np.mean(dice_list)
    acc = total_correct / total_voxels
    print("--- Evaluation Results:")
    print(f"Avg Dice (mean over all classes): {avg_dice:.4f}")
    print(f"WT Dice (Whole Tumor):           {np.mean(wt_dice_list)*100:.2f}%")
    print(f"TC Dice (Tumor Core):            {np.mean(tc_dice_list)*100:.2f}%")
    print(f"ET Dice (Enhancing Tumor):       {np.mean(et_dice_list)*100:.2f}%")
    print(f"Per-Class Dice: {[f'{d*100:.2f}%' for d in dice_list[:NUM_CLASSES]]}")
    print(f"Overall Accuracy:                {acc*100:.2f}%")
    return avg_loss, avg_dice

def train_one_epoch(model, dataloader, optimizer, criterion, device, scaler):
    model.train()
    running_loss = 0.0

    for inputs, targets in tqdm(dataloader, desc="Training", leave=False):
        inputs = inputs.to(device, non_blocking=True)
        inputs = pad_input_for_windows(inputs, window_size=(2, 2, 2))
        targets = targets.to(device=device, dtype=torch.long, non_blocking=True)

        optimizer.zero_grad()
        with autocast(device_type='cuda'):

        # --- Test-Time Augmentation (TTA) ---
            flipped_inputs = torch.flip(inputs, dims=[2])  # flip along depth
            logits_orig = model(inputs)["main"]
            logits_flip = model(flipped_inputs)["main"]
            logits_flip = torch.flip(logits_flip, dims=[2])  # flip back

            # Average predictions
            logits = (logits_orig + logits_flip) / 2.0
            outputs = {"main": logits}


            if torch.isnan(outputs["main"]).any():
                print(" Model output contains NaNs ")
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        running_loss += loss.item() * inputs.size(0)

    return running_loss / len(dataloader.dataset)

def main():
    device = DEVICE
    model = RefineFormer3D(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES, attn_drop_rate=ATTN_DROP_RATE).to(device)
    optimizer = get_optimizer(model, base_lr=BASE_LR, weight_decay=WEIGHT_DECAY)
    scheduler = get_scheduler(optimizer)
    criterion = RefineFormer3DLoss()
    scaler = GradScaler(device='cuda')

    train_transform = Compose3D([
        RandomFlip3D(p=0.5),
        RandomRotation3D(p=0.5),
        RandomNoise3D(p=0.3),
        TumorCoreContrast(p=0.5, scale=1.5),
    ])

    train_dataset = BraTSDataset(
        root_dirs=["/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BRATS_SPLIT/train"],
        transform=train_transform,
    )
    val_dataset = BraTSDataset(
        root_dirs=["/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BRATS_SPLIT/val"],
        transform=None,
    )

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=14, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=12, pin_memory=True)

    for epoch in range(1, NUM_EPOCHS + 1):
        print(f"\n--- Epoch [{epoch}/{NUM_EPOCHS}] ---")
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler)
        val_loss, val_dice = validate(model, val_loader, criterion, device)
        scheduler.step()

        print(f"Train Loss: {train_loss:.4f} | Validation Loss: {val_loss:.4f} | Avg Dice: {val_dice:.4f}")
        torch.save(model.state_dict(), f"checkpoint_epoch_{epoch}.pt")

    torch.save(model.state_dict(), "final_model.pt")
    print(" Final model saved as 'final_model.pt'")

if __name__ == "__main__":
    main()

✅ Total valid patient directories found: 411
✅ Total valid patient directories found: 73

--- Epoch [1/300] ---


  return fn(*args, **kwargs)
                                                 

OutOfMemoryError: CUDA out of memory. Tried to allocate 2.00 GiB. GPU 0 has a total capacity of 5.78 GiB of which 1.75 GiB is free. Including non-PyTorch memory, this process has 3.27 GiB memory in use. Of the allocated memory 3.10 GiB is allocated by PyTorch, and 31.99 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from tqdm import tqdm
import numpy as np

from model5_ghost_upg import RefineFormer3D
from dataset import BraTSDataset
from optimizer import get_optimizer, get_scheduler
from augmentation import Compose3D, RandomFlip3D, RandomRotation3D, RandomNoise3D, TumorCoreContrast
from losses import RefineFormer3DLoss
from config import DEVICE, IN_CHANNELS, NUM_CLASSES, BASE_LR, WEIGHT_DECAY, NUM_EPOCHS, ATTN_DROP_RATE

def pad_input_for_windows(x, window_size=(2, 2, 2)):
    _, _, D, H, W = x.shape
    pad_d = (window_size[0] - D % window_size[0]) % window_size[0]
    pad_h = (window_size[1] - H % window_size[1]) % window_size[1]
    pad_w = (window_size[2] - W % window_size[2]) % window_size[2]
    return F.pad(x, (0, pad_w, 0, pad_h, 0, pad_d))

def dice_score(pred, target):
    intersection = (pred & target).sum()
    union = pred.sum() + target.sum()
    return (2 * intersection + 1e-5) / (union + 1e-5)

def validate(model, dataloader, criterion, device):
    model.eval()
    val_loss = 0.0
    dice_list, wt_dice_list, tc_dice_list, et_dice_list = [], [], [], []
    total_correct, total_voxels = 0, 0

    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Validation", leave=False):
            inputs = inputs.to(device)
            inputs = pad_input_for_windows(inputs, window_size=(2, 2, 2))
            targets = targets.to(device=device, dtype=torch.long)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item() * inputs.size(0)

            preds = torch.argmax(outputs["main"], dim=1)
            total_correct += (preds == targets).sum().item()
            total_voxels += preds.numel()

            for i in range(preds.shape[0]):
                p = preds[i].cpu().numpy()
                g = targets[i].cpu().numpy()
                for c in range(NUM_CLASSES):
                    dice_list.append(dice_score(p == c, g == c))
                wt_dice_list.append(dice_score(p > 0, g > 0))
                tc_dice_list.append(dice_score((p == 1) | (p == 3), (g == 1) | (g == 3)))
                et_dice_list.append(dice_score(p == 3, g == 3))

    avg_loss = val_loss / len(dataloader.dataset)
    avg_dice = np.mean(dice_list)
    acc = total_correct / total_voxels

    print("--- Evaluation Results:")
    print(f"Avg Dice (mean over all classes): {avg_dice:.4f}")
    print(f"WT Dice (Whole Tumor):           {np.mean(wt_dice_list)*100:.2f}%")
    print(f"TC Dice (Tumor Core):            {np.mean(tc_dice_list)*100:.2f}%")
    print(f"ET Dice (Enhancing Tumor):       {np.mean(et_dice_list)*100:.2f}%")
    print(f"Per-Class Dice: {[f'{d*100:.2f}%' for d in dice_list[:NUM_CLASSES]]}")
    print(f"Overall Accuracy:                {acc*100:.2f}%")

    return avg_loss, avg_dice

def train_one_epoch(model, dataloader, optimizer, criterion, device, scaler):
    model.train()
    running_loss = 0.0

    for inputs, targets in tqdm(dataloader, desc="Training", leave=False):
        inputs = inputs.to(device, non_blocking=True)
        inputs = pad_input_for_windows(inputs, window_size=(2, 2, 2))
        targets = targets.to(device=device, dtype=torch.long, non_blocking=True)

        optimizer.zero_grad()
        with autocast(device_type='cuda'):
            outputs = model(inputs)
            if torch.isnan(outputs["main"]).any():
                print("⚠️ Model output contains NaNs")
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * inputs.size(0)

    return running_loss / len(dataloader.dataset)

def main():
    device = DEVICE
    model = RefineFormer3D(
        in_channels=IN_CHANNELS,
        num_classes=NUM_CLASSES,
        attn_drop_rate=ATTN_DROP_RATE
    ).to(device)

    optimizer = get_optimizer(model, base_lr=BASE_LR, weight_decay=WEIGHT_DECAY)
    scheduler = get_scheduler(optimizer)
    criterion = RefineFormer3DLoss()
    scaler = GradScaler(device='cuda')

    train_transform = Compose3D([
        RandomFlip3D(p=0.5),
        RandomRotation3D(p=0.5),
        RandomNoise3D(p=0.3),
        TumorCoreContrast(p=0.5, scale=1.5),
    ])

    train_dataset = BraTSDataset(
        root_dirs=["/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BRATS_SPLIT/train"],
        transform=train_transform,
    )
    val_dataset = BraTSDataset(
        root_dirs=["/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BRATS_SPLIT/val"],
        transform=None,
    )

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=14, pin_memory=True)
    val_loader   = DataLoader(val_dataset,   batch_size=2, shuffle=False, num_workers=12, pin_memory=True)

    best_dice = 0.0
    for epoch in range(1, NUM_EPOCHS + 1):
        print(f"\n--- Epoch [{epoch}/{NUM_EPOCHS}] ---")
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler)
        val_loss, val_dice = validate(model, val_loader, criterion, device)
        scheduler.step()

        print(f"Train Loss: {train_loss:.4f} | Validation Loss: {val_loss:.4f} | Avg Dice: {val_dice:.4f}")
        torch.save(model.state_dict(), f"checkpoint_epoch_{epoch}.pt")

        # Save best model
        if val_dice > best_dice:
            best_dice = val_dice
            torch.save(model.state_dict(), "best_model.pt")
            print(f"✅ New best model saved (best Dice: {best_dice:.4f})")

    torch.save(model.state_dict(), "final_model.pt")
    print("Final model saved as 'final_model.pt'")

if __name__ == "__main__":
    main()


✅ Total valid patient directories found: 411
✅ Total valid patient directories found: 73

--- Epoch [1/300] ---


  return fn(*args, **kwargs)
  with torch.enable_grad(), device_autocast_ctx, torch.cpu.amp.autocast(**ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
                                                           

--- Evaluation Results:
Avg Dice (mean over all classes): 0.6544
WT Dice (Whole Tumor):           73.49%
TC Dice (Tumor Core):            55.58%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.40%', '49.97%', '0.00%']
Overall Accuracy:                96.66%
Train Loss: 2.6926 | Validation Loss: 2.3231 | Avg Dice: 0.6544
✅ New best model saved (best Dice: 0.6544)

--- Epoch [2/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.7360
WT Dice (Whole Tumor):           81.02%
TC Dice (Tumor Core):            65.38%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.77%', '64.36%', '32.34%']
Overall Accuracy:                98.03%
Train Loss: 2.1635 | Validation Loss: 2.0300 | Avg Dice: 0.7360
✅ New best model saved (best Dice: 0.7360)

--- Epoch [3/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.7697
WT Dice (Whole Tumor):           83.99%
TC Dice (Tumor Core):            66.70%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.82%', '78.22%', '63.94%']
Overall Accuracy:                98.16%
Train Loss: 1.9029 | Validation Loss: 1.7812 | Avg Dice: 0.7697
✅ New best model saved (best Dice: 0.7697)

--- Epoch [4/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.7739
WT Dice (Whole Tumor):           83.51%
TC Dice (Tumor Core):            69.75%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.68%', '69.84%', '59.25%']
Overall Accuracy:                98.20%
Train Loss: 1.7057 | Validation Loss: 1.5782 | Avg Dice: 0.7739
✅ New best model saved (best Dice: 0.7739)

--- Epoch [5/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.7902
WT Dice (Whole Tumor):           84.49%
TC Dice (Tumor Core):            70.27%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.78%', '74.79%', '65.39%']
Overall Accuracy:                98.33%
Train Loss: 1.5447 | Validation Loss: 1.4514 | Avg Dice: 0.7902
✅ New best model saved (best Dice: 0.7902)

--- Epoch [6/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.7772
WT Dice (Whole Tumor):           81.49%
TC Dice (Tumor Core):            68.39%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.58%', '62.77%', '59.51%']
Overall Accuracy:                98.11%
Train Loss: 1.4015 | Validation Loss: 1.3072 | Avg Dice: 0.7772

--- Epoch [7/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8097
WT Dice (Whole Tumor):           86.40%
TC Dice (Tumor Core):            72.39%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.79%', '76.77%', '73.15%']
Overall Accuracy:                98.53%
Train Loss: 1.2764 | Validation Loss: 1.1592 | Avg Dice: 0.8097
✅ New best model saved (best Dice: 0.8097)

--- Epoch [8/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.7995
WT Dice (Whole Tumor):           84.04%
TC Dice (Tumor Core):            70.42%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.44%', '53.89%', '68.59%']
Overall Accuracy:                98.38%
Train Loss: 1.1509 | Validation Loss: 1.0617 | Avg Dice: 0.7995

--- Epoch [9/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8153
WT Dice (Whole Tumor):           86.84%
TC Dice (Tumor Core):            73.55%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.86%', '79.76%', '75.97%']
Overall Accuracy:                98.57%
Train Loss: 1.0729 | Validation Loss: 0.9759 | Avg Dice: 0.8153
✅ New best model saved (best Dice: 0.8153)

--- Epoch [10/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8144
WT Dice (Whole Tumor):           86.32%
TC Dice (Tumor Core):            72.39%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.84%', '79.98%', '76.63%']
Overall Accuracy:                98.52%
Train Loss: 0.9734 | Validation Loss: 0.9032 | Avg Dice: 0.8144

--- Epoch [11/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8308
WT Dice (Whole Tumor):           87.44%
TC Dice (Tumor Core):            74.79%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.78%', '76.00%', '74.54%']
Overall Accuracy:                98.68%
Train Loss: 0.9024 | Validation Loss: 0.8012 | Avg Dice: 0.8308
✅ New best model saved (best Dice: 0.8308)

--- Epoch [12/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8261
WT Dice (Whole Tumor):           87.27%
TC Dice (Tumor Core):            74.68%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.66%', '67.29%', '70.72%']
Overall Accuracy:                98.71%
Train Loss: 0.8252 | Validation Loss: 0.7567 | Avg Dice: 0.8261

--- Epoch [13/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8289
WT Dice (Whole Tumor):           88.15%
TC Dice (Tumor Core):            74.50%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.85%', '80.56%', '77.02%']
Overall Accuracy:                98.72%
Train Loss: 0.7814 | Validation Loss: 0.7144 | Avg Dice: 0.8289

--- Epoch [14/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8321
WT Dice (Whole Tumor):           88.28%
TC Dice (Tumor Core):            76.24%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.81%', '75.47%', '73.93%']
Overall Accuracy:                98.77%
Train Loss: 0.7553 | Validation Loss: 0.6759 | Avg Dice: 0.8321
✅ New best model saved (best Dice: 0.8321)

--- Epoch [15/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8363
WT Dice (Whole Tumor):           88.21%
TC Dice (Tumor Core):            75.75%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.79%', '74.65%', '77.93%']
Overall Accuracy:                98.78%
Train Loss: 0.6908 | Validation Loss: 0.6486 | Avg Dice: 0.8363
✅ New best model saved (best Dice: 0.8363)

--- Epoch [16/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8392
WT Dice (Whole Tumor):           88.35%
TC Dice (Tumor Core):            75.89%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.87%', '81.37%', '81.07%']
Overall Accuracy:                98.76%
Train Loss: 0.6803 | Validation Loss: 0.6327 | Avg Dice: 0.8392
✅ New best model saved (best Dice: 0.8392)

--- Epoch [17/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8297
WT Dice (Whole Tumor):           87.05%
TC Dice (Tumor Core):            74.28%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.89%', '81.83%', '79.76%']
Overall Accuracy:                98.68%
Train Loss: 0.6592 | Validation Loss: 0.6417 | Avg Dice: 0.8297

--- Epoch [18/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8424
WT Dice (Whole Tumor):           88.29%
TC Dice (Tumor Core):            76.06%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.87%', '82.09%', '82.26%']
Overall Accuracy:                98.78%
Train Loss: 0.6219 | Validation Loss: 0.6039 | Avg Dice: 0.8424
✅ New best model saved (best Dice: 0.8424)

--- Epoch [19/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8365
WT Dice (Whole Tumor):           88.59%
TC Dice (Tumor Core):            74.83%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.89%', '83.13%', '82.28%']
Overall Accuracy:                98.78%
Train Loss: 0.6147 | Validation Loss: 0.6052 | Avg Dice: 0.8365

--- Epoch [20/300] ---


                                                           

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8298
WT Dice (Whole Tumor):           88.98%
TC Dice (Tumor Core):            74.15%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.84%', '79.21%', '73.50%']
Overall Accuracy:                98.75%
Train Loss: 0.5995 | Validation Loss: 0.6245 | Avg Dice: 0.8298

--- Epoch [21/300] ---


                                                           

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8363
WT Dice (Whole Tumor):           88.44%
TC Dice (Tumor Core):            74.77%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.86%', '79.53%', '69.96%']
Overall Accuracy:                98.78%
Train Loss: 0.6073 | Validation Loss: 0.5939 | Avg Dice: 0.8363

--- Epoch [22/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8413
WT Dice (Whole Tumor):           89.39%
TC Dice (Tumor Core):            77.61%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.86%', '80.31%', '75.51%']
Overall Accuracy:                98.86%
Train Loss: 0.5871 | Validation Loss: 0.5669 | Avg Dice: 0.8413

--- Epoch [23/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8413
WT Dice (Whole Tumor):           87.64%
TC Dice (Tumor Core):            75.52%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.89%', '82.25%', '78.05%']
Overall Accuracy:                98.77%
Train Loss: 0.5745 | Validation Loss: 0.5704 | Avg Dice: 0.8413

--- Epoch [24/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8444
WT Dice (Whole Tumor):           89.23%
TC Dice (Tumor Core):            77.86%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.80%', '75.51%', '72.73%']
Overall Accuracy:                98.85%
Train Loss: 0.5650 | Validation Loss: 0.5718 | Avg Dice: 0.8444
✅ New best model saved (best Dice: 0.8444)

--- Epoch [25/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8519
WT Dice (Whole Tumor):           89.17%
TC Dice (Tumor Core):            77.68%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.85%', '79.87%', '80.90%']
Overall Accuracy:                98.90%
Train Loss: 0.5587 | Validation Loss: 0.5345 | Avg Dice: 0.8519
✅ New best model saved (best Dice: 0.8519)

--- Epoch [26/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8529
WT Dice (Whole Tumor):           88.85%
TC Dice (Tumor Core):            77.51%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.81%', '76.83%', '82.37%']
Overall Accuracy:                98.85%
Train Loss: 0.5567 | Validation Loss: 0.5431 | Avg Dice: 0.8529
✅ New best model saved (best Dice: 0.8529)

--- Epoch [27/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8569
WT Dice (Whole Tumor):           89.63%
TC Dice (Tumor Core):            77.94%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.80%', '76.40%', '78.15%']
Overall Accuracy:                98.91%
Train Loss: 0.5259 | Validation Loss: 0.5342 | Avg Dice: 0.8569
✅ New best model saved (best Dice: 0.8569)

--- Epoch [28/300] ---


                                                            

--- Evaluation Results:
Avg Dice (mean over all classes): 0.8516
WT Dice (Whole Tumor):           89.40%
TC Dice (Tumor Core):            77.67%
ET Dice (Enhancing Tumor):       100.00%
Per-Class Dice: ['99.83%', '78.79%', '80.76%']
Overall Accuracy:                98.88%
Train Loss: 0.5384 | Validation Loss: 0.5422 | Avg Dice: 0.8516

--- Epoch [29/300] ---


Training:  34%|███▍      | 140/411 [07:03<03:53,  1.16it/s] 

In [2]:
model.eval()
with torch.no_grad():
    dummy = torch.randn(1, 4, 128, 128, 128)  # BraTS-like shape
    output = model(dummy)
    print(f"Input shape: {dummy.shape}")
    print(f"Output shape: {output['main'].shape}")


  return fn(*args, **kwargs)


AssertionError: Nw=2, expected=8. Mismatch in window token count.

In [1]:
import torch
from fvcore.nn import FlopCountAnalysis, parameter_count_table
from model5_upgrade import RefineFormer3D
from config import IN_CHANNELS, NUM_CLASSES

# Set device to CPU for analysis
DEVICE = torch.device("cpu")

# Initialize the model
model = RefineFormer3D(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES).to(DEVICE)
model.eval()

# Dummy 3D input: [Batch, Channels, Depth, Height, Width]
dummy_input = torch.randn(1, IN_CHANNELS, 128, 128, 128).to(DEVICE)

# Compute FLOPs
with torch.no_grad():
    flops = FlopCountAnalysis(model, dummy_input)
    flops_result = flops.total() / 1e9  # Convert to GFLOPs

# Print results
print(f"✅ GFLOPs: {flops_result:.2f} G")
print("✅ Parameter Count:")
print(parameter_count_table(model))


  return fn(*args, **kwargs)


RuntimeError: _Map_base::at

In [None]:
import torch
from model4 import RefineFormer3D
from thop import profile
from torchinfo import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model with your config
model = RefineFormer3D(
    in_channels=4,
    num_classes=3,
    embed_dims=[64, 128, 320, 512],
    depths=[2, 2, 2, 2],
    num_heads=[1, 2, 4, 8],
    window_sizes=[(4, 4, 4), (2, 4, 4), (2, 2, 2), (1, 2, 2)],
    mlp_ratios=[4, 4, 4, 4],
    decoder_channels=[256, 128, 64, 32]
).to(device)

# Dummy input
dummy_input = torch.randn(1, 4, 128, 128, 128).to(device)

# Check output shape
model.eval()
with torch.no_grad():
    output = model(dummy_input)
    print("✅ Output shape:", output["main"].shape)

# Param count
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"📊 Total Parameters: {total_params / 1e6:.2f}M")

# THOP for GFLOPs
macs, params = profile(model, inputs=(dummy_input,), verbose=False)
print(f"⚙️ Estimated MACs: {macs / 1e9:.2f} GFLOPs")

# Optional: torchinfo summary
summary(model, input_size=(1, 4, 128, 128, 128), depth=4, device=device.type)


In [1]:
import os

def find_missing_seg_files(validation_dir):
    missing = []
    for case in os.listdir(validation_dir):
        case_path = os.path.join(validation_dir, case)
        if os.path.isdir(case_path):
            seg_file = os.path.join(case_path, f"{case}_seg.nii")
            if not os.path.isfile(seg_file):
                missing.append(seg_file)
    return missing

# Example usage
validation_dir = "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17ValidationData"
missing_files = find_missing_seg_files(validation_dir)

if missing_files:
    print("⚠️ Missing segmentation files:")
    for f in missing_files:
        print(f)
else:
    print("✅ All segmentation files are present.")


⚠️ Missing segmentation files:
/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17ValidationData/Brats17_TCIA_195_1/Brats17_TCIA_195_1_seg.nii
/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17ValidationData/Brats17_CBICA_AAM_1/Brats17_CBICA_AAM_1_seg.nii
/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17ValidationData/Brats17_CBICA_ABT_1/Brats17_CBICA_ABT_1_seg.nii
/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17ValidationData/Brats17_CBICA_ALA_1/Brats17_CBICA_ALA_1_seg.nii
/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17ValidationData/Brats17_CBICA_ALT_1/Brats17_CBICA_ALT_1_seg.nii
/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for

In [None]:
import os

def check_missing_seg(root_dir):
    missing_patients = []
    for dirpath, dirnames, filenames in os.walk(root_dir):
        for dirname in dirnames:
            seg_path = os.path.join(dirpath, dirname, f"{dirname}_seg.nii")
            if not os.path.exists(seg_path):
                missing_patients.append(dirname)
    return missing_patients

if __name__ == "__main__":
    data_dir = "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17TrainingData/HGG/"
    missing_hgg = check_missing_seg(data_dir)
    
    data_dir = "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17TrainingData/LGG/"
    missing_lgg = check_missing_seg(data_dir)

    print("Missing in HGG:", missing_hgg)
    print("Missing in LGG:", missing_lgg)


In [1]:
import torch

def check_environment():
    print("\n🔵 Checking environment:\n")

    # PyTorch Version
    print(f"Torch Version        : {torch.__version__}")

    # CUDA Version
    if torch.version.cuda:
        print(f"CUDA Version         : {torch.version.cuda}")
    else:
        print("CUDA Version         : None (CPU Only)")

    # cuDNN Version
    if torch.backends.cudnn.is_available():
        print(f"cuDNN Version        : {torch.backends.cudnn.version()}")
    else:
        print("cuDNN Version        : None")

    # GPU Availability
    gpu_available = torch.cuda.is_available()
    print(f"GPU Available        : {gpu_available}")

    if gpu_available:
        print(f"Device Count         : {torch.cuda.device_count()}")
        print(f"Current Device       : {torch.cuda.current_device()}")
        print(f"Device Name          : {torch.cuda.get_device_name(0)}")

    print("\n🟡 Checking autocast compatibility:")

    try:
        with torch.cuda.amp.autocast(device_type='cuda', dtype=torch.float16):
            print("✅ autocast(device_type='cuda', dtype=torch.float16) works!")
    except Exception as e:
        print(f"❌ autocast(device_type='cuda') not supported: {e}")
        print("✅ You must use: autocast() only.")

    print("\n🟡 Checking GradScaler compatibility:")

    try:
        scaler = torch.cuda.amp.GradScaler()
        print("✅ GradScaler() works (no device_type needed).")
    except Exception as e:
        print(f"❌ GradScaler() failed: {e}")

    print("\n✅ Environment check complete.\n")

if __name__ == "__main__":
    check_environment()



🔵 Checking environment:

Torch Version        : 2.4.1+cu118
CUDA Version         : 11.8
cuDNN Version        : 90100
GPU Available        : True
Device Count         : 1
Current Device       : 0
Device Name          : NVIDIA GeForce RTX 3060 Laptop GPU

🟡 Checking autocast compatibility:
❌ autocast(device_type='cuda') not supported: __init__() got an unexpected keyword argument 'device_type'
✅ You must use: autocast() only.

🟡 Checking GradScaler compatibility:
✅ GradScaler() works (no device_type needed).

✅ Environment check complete.



  with torch.cuda.amp.autocast(device_type='cuda', dtype=torch.float16):
  scaler = torch.cuda.amp.GradScaler()


In [None]:
✅ Phase 1: Prune Redundant Model Components (Baseline Compression)

Objective: Identify parts of RefineFormer3D that are heavy but not critical for performance.

Steps:

    Use torchprofile or fvcore to inspect each block’s contribution to total FLOPs.

    Replace heavier MLP/FFN blocks with:

        Depthwise separable convolutions.

        Low-rank approximated linear layers (e.g., nn.LinearLowRank).

    Remove any redundant skip connections or high-resolution fusion blocks unless justified.

Evaluation:

    Re-train with minimal epochs (10–15) on a smaller subset.

    Check Dice/IOU drop < 2%. If yes, continue pruning.

✅ Phase 2: Switch to Efficient Attention

Objective: Replace traditional attention modules with efficient alternatives.

What to try:

    Linear Attention (e.g., Linformer, Performer).

    Axial Attention (for better locality).

    Grouped Self-Attention (token grouping across z-axis).

Integration:

    Replace global self-attention in transformer encoder blocks with these.

    Test combinations like axial in encoder, linear in bottleneck.

Fallback:

    If accuracy drops >3%, roll back to original attention in critical stages only.

✅ Phase 3: Token Downsampling / Early Pooling

Objective: Reduce spatial dimensions early in the encoder.

How:

    Add an aggressive Conv3d(stride=2) or pooling block right after the input.

    Maintain hierarchical features through skip-connections.

Tip: Pair with a lightweight decoder like SegNeXt3D or DeepLab3D.
✅ Phase 4: Knowledge Distillation (KD)

Objective: Train a compact student model to mimic your best performing RefineFormer3D.

Steps:

    Train the student using Dice + KL divergence from teacher’s soft logits.

    Freeze teacher; use teacher’s intermediate attention maps (optional).

Bonus: Try structured KD from only ET, TC, WT prediction heads.
✅ Phase 5: Quantization & Mixed Precision Tricks

Objective: Compress without architectural changes.

Ideas:

    Apply Post-Training Quantization (PTQ).

    Try Quantization-Aware Training (QAT) for better retention of metrics.

    Use torch.amp consistently.

Fallback:

    If quantization drops ET Dice, exclude final layer from quantization.

✅ Phase 6: Neural Architecture Search (Optional but Novel)

Objective: Auto-discover low-FLOP configurations.

Tools:

    Use nn.MetaConv3d with NASLib, Nni, or AutoFormer.

    Constraints: max GFLOPs < SegFormer3D but ≥ 87% Dice.

🚨 Contingency Plan

    If performance improves: Save logs, metrics, configs, and prepare ablation table.

    If performance drops severely:

        Revert one phase back and try a lighter version of the current idea.

        Use early stopping and hyperparameter search for the new config.

🧪 Tracking

    Maintain a spreadsheet:

        Columns: Version, GFLOPs, Params, Dice, IOU, Accuracy, Training Time, Notes.

    Use wandb or TensorBoard to compare training behavior.