In [1]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import RefineFormer3D
from losses import RefineFormer3DLoss
from optimizer import get_optimizer, get_scheduler
from metrics import dice_coefficient, iou_score, precision_recall_f1, hausdorff_distance
from dataset import BraTSDataset
from augmentation import Compose3D, RandomFlip3D, RandomRotation3D, RandomNoise3D
from config import DEVICE, IN_CHANNELS, NUM_CLASSES, BASE_LR, WEIGHT_DECAY, NUM_EPOCHS
from torch.cuda.amp import autocast, GradScaler

def train_one_epoch(model, dataloader, optimizer, criterion, device, scaler):
    model.train()
    running_loss = 0.0

    for inputs, targets in tqdm(dataloader, desc="Training"):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        with autocast():  # Mixed precision
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

    epoch_loss = running_loss / len(dataloader)
    return epoch_loss

def validate_one_epoch(model, dataloader, criterion, device, num_classes):
    model.eval()
    running_loss = 0.0

    dice_all, iou_all, hausdorff_all = [], [], []

    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Validation"):
            inputs = inputs.to(device)
            targets = targets.to(device)

            with autocast():  # Optional mixed precision inference
                outputs = model(inputs)
                loss = criterion(outputs, targets)

            preds = outputs["main"]

            dice_scores, _ = dice_coefficient(preds, targets, num_classes)
            iou_scores, _ = iou_score(preds, targets, num_classes)
            hausdorff = hausdorff_distance(preds, targets)

            dice_all.append(np.mean(dice_scores))
            iou_all.append(np.mean(iou_scores))
            hausdorff_all.append(hausdorff)

            running_loss += loss.item()

    epoch_loss = running_loss / len(dataloader)
    mean_dice = np.nanmean(dice_all)
    mean_iou = np.nanmean(iou_all)
    mean_hausdorff = np.nanmean(hausdorff_all)

    return epoch_loss, mean_dice, mean_iou, mean_hausdorff

def save_checkpoint(state, save_dir, filename="best_model.pth"):
    os.makedirs(save_dir, exist_ok=True)
    torch.save(state, os.path.join(save_dir, filename))

def main():
    # ==================== CONFIG ====================
    device = torch.device(DEVICE)
    save_dir = "./checkpoints"
    best_dice = 0.0
    scaler = GradScaler()

    # ==================== MODEL ====================
    model = RefineFormer3D(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES)
    model = model.to(device)

    # ==================== OPTIMIZER + LOSS ====================
    optimizer = get_optimizer(model, base_lr=BASE_LR, weight_decay=WEIGHT_DECAY)
    scheduler = get_scheduler(optimizer, mode="cosine", T_max=NUM_EPOCHS)
    criterion = RefineFormer3DLoss()

    # ==================== DATASET + DATALOADER ====================
    train_transform = Compose3D([
        RandomFlip3D(p=0.5),
        RandomRotation3D(p=0.5),
        RandomNoise3D(p=0.3),
    ])

    train_dataset = BraTSDataset(
    root_dirs=[
        "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17TrainingData/HGG",   # 🔥 Replace this
        "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17TrainingData/LGG"    # 🔥 Replace this
    ],
    transform=train_transform,
    )   

    val_dataset = BraTSDataset(
    root_dirs=[
        "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17ValidationData"      # 🔥 Replace this
    ],
    transform=None,
    )


    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    # ==================== TRAINING LOOP ====================
    for epoch in range(1, NUM_EPOCHS + 1):
        print(f"\n--- Epoch [{epoch}/{NUM_EPOCHS}] ---")

        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler)
        val_loss, val_dice, val_iou, val_hd = validate_one_epoch(model, val_loader, criterion, device, NUM_CLASSES)

        scheduler.step()

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f} | Val IoU: {val_iou:.4f} | Hausdorff: {val_hd:.4f}")

        # Save best model
        if val_dice > best_dice:
            best_dice = val_dice
            save_checkpoint({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_dice': best_dice,
            }, save_dir)
            print(f"✅ Saved Best Model (Dice: {best_dice:.4f})")




In [2]:
if __name__ == "__main__":
    main()

  scaler = GradScaler()



--- Epoch [1/300] ---


  with autocast():  # Mixed precision
Training:   0%|          | 0/285 [00:03<?, ?it/s]


TypeError: forward() takes 2 positional arguments but 5 were given

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm

from model import RefineFormer3D
from dataset import BraTSDataset
from optimizer import get_optimizer, get_scheduler
from metrics import dice_coefficient, iou_score, hausdorff_distance
from augmentation import Compose3D, RandomFlip3D, RandomRotation3D, RandomNoise3D
from losses import RefineFormer3DLoss   # ✅ Correct loss import

from config import DEVICE, IN_CHANNELS, NUM_CLASSES, BASE_LR, WEIGHT_DECAY, NUM_EPOCHS


def train_one_epoch(model, dataloader, optimizer, criterion, device, scaler):
    model.train()
    running_loss = 0.0

    for inputs, targets in tqdm(dataloader, desc="Training", leave=False):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device, non_blocking=True)

        optimizer.zero_grad()
        
        with autocast():   # ✅ Specify dtype explicitly
            outputs = model(inputs)['main']
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(dataloader.dataset)
    return epoch_loss


def validate_one_epoch(model, dataloader, criterion, device, num_classes):
    model.eval()
    running_loss = 0.0
    dice_all, iou_all, hd_all = [], [], []

    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Validation", leave=False):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device, non_blocking=True)

            outputs = model(inputs)['main']

            loss = criterion(outputs, targets)
            running_loss += loss.item() * inputs.size(0)

            preds = torch.argmax(outputs, dim=1)
            dice_all.append(dice_coefficient(preds, targets, num_classes))
            iou_all.append(iou_score(preds, targets, num_classes))
            hd_all.append(hausdorff_distance(preds, targets))

    val_loss = running_loss / len(dataloader.dataset)
    avg_dice = torch.mean(torch.stack(dice_all)).item()
    avg_iou = torch.mean(torch.stack(iou_all)).item()
    avg_hd = torch.mean(torch.stack(hd_all)).item()

    return val_loss, avg_dice, avg_iou, avg_hd


def main():
    device = DEVICE

    # Model
    model = RefineFormer3D(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES)
    model = model.to(device)

    # Optimizer, Scheduler, Loss
    optimizer = get_optimizer(model, base_lr=BASE_LR, weight_decay=WEIGHT_DECAY)
    scheduler = get_scheduler(optimizer)
    criterion = RefineFormer3DLoss()   # ✅ Correct
    scaler = GradScaler()  # ✅ Fixed warning

    # Data
    train_transform = Compose3D([
        RandomFlip3D(p=0.5),
        RandomRotation3D(p=0.5),
        RandomNoise3D(p=0.3),
    ])

    train_dataset = BraTSDataset(
        root_dirs=[
            "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17TrainingData/HGG",
            "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17TrainingData/LGG",
        ],
        transform=train_transform,
    )

    val_dataset = BraTSDataset(
        root_dirs=[
            "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17ValidationData",
        ],
        transform=None,
    )

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    # Training loop
    for epoch in range(1, NUM_EPOCHS + 1):
        print(f"\n--- Epoch [{epoch}/{NUM_EPOCHS}] ---")

        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler)
        val_loss, val_dice, val_iou, val_hd = validate_one_epoch(model, val_loader, criterion, device, NUM_CLASSES)

        scheduler.step()

        print(f"Train Loss: {train_loss:.4f}  |  Val Loss: {val_loss:.4f}")
        print(f"Val Dice: {val_dice:.4f} | Val IOU: {val_iou:.4f} | Val Hausdorff: {val_hd:.4f}")


if __name__ == "__main__":
    main()





--- Epoch [1/300] ---


  with autocast():   # ✅ Specify dtype explicitly
                                                 

RuntimeError: shape '[1, 32, 52, 52, 52]' is invalid for input of size 4492800

In [6]:
import os

def check_missing_seg(root_dir):
    missing_patients = []
    for dirpath, dirnames, filenames in os.walk(root_dir):
        for dirname in dirnames:
            seg_path = os.path.join(dirpath, dirname, f"{dirname}_seg.nii")
            if not os.path.exists(seg_path):
                missing_patients.append(dirname)
    return missing_patients

if __name__ == "__main__":
    data_dir = "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17TrainingData/HGG/"
    missing_hgg = check_missing_seg(data_dir)
    
    data_dir = "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17TrainingData/LGG/"
    missing_lgg = check_missing_seg(data_dir)

    print("Missing in HGG:", missing_hgg)
    print("Missing in LGG:", missing_lgg)


Missing in HGG: ['Brats17_2013_2_1_flair.nii']
Missing in LGG: []


In [2]:
import torch

def check_environment():
    print("\n🔵 Checking environment:\n")

    # PyTorch Version
    print(f"Torch Version        : {torch.__version__}")

    # CUDA Version
    if torch.version.cuda:
        print(f"CUDA Version         : {torch.version.cuda}")
    else:
        print("CUDA Version         : None (CPU Only)")

    # cuDNN Version
    if torch.backends.cudnn.is_available():
        print(f"cuDNN Version        : {torch.backends.cudnn.version()}")
    else:
        print("cuDNN Version        : None")

    # GPU Availability
    gpu_available = torch.cuda.is_available()
    print(f"GPU Available        : {gpu_available}")

    if gpu_available:
        print(f"Device Count         : {torch.cuda.device_count()}")
        print(f"Current Device       : {torch.cuda.current_device()}")
        print(f"Device Name          : {torch.cuda.get_device_name(0)}")

    print("\n🟡 Checking autocast compatibility:")

    try:
        with torch.cuda.amp.autocast(device_type='cuda', dtype=torch.float16):
            print("✅ autocast(device_type='cuda', dtype=torch.float16) works!")
    except Exception as e:
        print(f"❌ autocast(device_type='cuda') not supported: {e}")
        print("✅ You must use: autocast() only.")

    print("\n🟡 Checking GradScaler compatibility:")

    try:
        scaler = torch.cuda.amp.GradScaler()
        print("✅ GradScaler() works (no device_type needed).")
    except Exception as e:
        print(f"❌ GradScaler() failed: {e}")

    print("\n✅ Environment check complete.\n")

if __name__ == "__main__":
    check_environment()



🔵 Checking environment:

Torch Version        : 2.4.1+cu118
CUDA Version         : 11.8
cuDNN Version        : 90100
GPU Available        : True
Device Count         : 1
Current Device       : 0
Device Name          : NVIDIA GeForce RTX 3060 Laptop GPU

🟡 Checking autocast compatibility:
❌ autocast(device_type='cuda') not supported: __init__() got an unexpected keyword argument 'device_type'
✅ You must use: autocast() only.

🟡 Checking GradScaler compatibility:
✅ GradScaler() works (no device_type needed).

✅ Environment check complete.



  with torch.cuda.amp.autocast(device_type='cuda', dtype=torch.float16):
  scaler = torch.cuda.amp.GradScaler()
