In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from tqdm import tqdm

from model import RefineFormer3D
from losses import RefineFormer3DLoss
from optimizer import get_optimizer, get_scheduler
from metrics import dice_coefficient, iou_score, precision_recall_f1, hausdorff_distance
from dataset import BraTSDataset
from augmentation import Compose3D, RandomFlip3D, RandomRotation3D, RandomNoise3D
from config import DEVICE, IN_CHANNELS, NUM_CLASSES, BASE_LR, WEIGHT_DECAY, NUM_EPOCHS
from torch.cuda.amp import autocast, GradScaler

def train_one_epoch(model, dataloader, optimizer, criterion, device, scaler):
    model.train()
    running_loss = 0.0

    for inputs, targets in tqdm(dataloader, desc="Training"):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()

        with autocast():  # Mixed precision
            outputs = model(inputs)
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item()

    epoch_loss = running_loss / len(dataloader)
    return epoch_loss

def validate_one_epoch(model, dataloader, criterion, device, num_classes):
    model.eval()
    running_loss = 0.0

    dice_all, iou_all, hausdorff_all = [], [], []

    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Validation"):
            inputs = inputs.to(device)
            targets = targets.to(device)

            with autocast():  # Optional mixed precision inference
                outputs = model(inputs)
                loss = criterion(outputs, targets)

            preds = outputs["main"]

            dice_scores, _ = dice_coefficient(preds, targets, num_classes)
            iou_scores, _ = iou_score(preds, targets, num_classes)
            hausdorff = hausdorff_distance(preds, targets)

            dice_all.append(np.mean(dice_scores))
            iou_all.append(np.mean(iou_scores))
            hausdorff_all.append(hausdorff)

            running_loss += loss.item()

    epoch_loss = running_loss / len(dataloader)
    mean_dice = np.nanmean(dice_all)
    mean_iou = np.nanmean(iou_all)
    mean_hausdorff = np.nanmean(hausdorff_all)

    return epoch_loss, mean_dice, mean_iou, mean_hausdorff

def save_checkpoint(state, save_dir, filename="best_model.pth"):
    os.makedirs(save_dir, exist_ok=True)
    torch.save(state, os.path.join(save_dir, filename))

def main():
    # ==================== CONFIG ====================
    device = torch.device(DEVICE)
    save_dir = "./checkpoints"
    best_dice = 0.0
    scaler = GradScaler()

    # ==================== MODEL ====================
    model = RefineFormer3D(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES)
    model = model.to(device)

    # ==================== OPTIMIZER + LOSS ====================
    optimizer = get_optimizer(model, base_lr=BASE_LR, weight_decay=WEIGHT_DECAY)
    scheduler = get_scheduler(optimizer, mode="cosine", T_max=NUM_EPOCHS)
    criterion = RefineFormer3DLoss()

    # ==================== DATASET + DATALOADER ====================
    train_transform = Compose3D([
        RandomFlip3D(p=0.5),
        RandomRotation3D(p=0.5),
        RandomNoise3D(p=0.3),
    ])

    train_dataset = BraTSDataset(
    root_dirs=[
        "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17TrainingData/HGG",   # 🔥 Replace this
        "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17TrainingData/LGG"    # 🔥 Replace this
    ],
    transform=train_transform,
    )   

    val_dataset = BraTSDataset(
    root_dirs=[
        "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17ValidationData"      # 🔥 Replace this
    ],
    transform=None,
    )


    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    # ==================== TRAINING LOOP ====================
    for epoch in range(1, NUM_EPOCHS + 1):
        print(f"\n--- Epoch [{epoch}/{NUM_EPOCHS}] ---")

        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler)
        val_loss, val_dice, val_iou, val_hd = validate_one_epoch(model, val_loader, criterion, device, NUM_CLASSES)

        scheduler.step()

        print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Dice: {val_dice:.4f} | Val IoU: {val_iou:.4f} | Hausdorff: {val_hd:.4f}")

        # Save best model
        if val_dice > best_dice:
            best_dice = val_dice
            save_checkpoint({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'scheduler_state_dict': scheduler.state_dict(),
                'best_dice': best_dice,
            }, save_dir)
            print(f"✅ Saved Best Model (Dice: {best_dice:.4f})")




In [None]:
if __name__ == "__main__":
    main()

In [1]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.cuda.amp import GradScaler, autocast
from tqdm import tqdm
from model2 import RefineFormer3D
from dataset import BraTSDataset
from optimizer import get_optimizer, get_scheduler
from metrics import dice_coefficient, iou_score, hausdorff_distance
from augmentation import Compose3D, RandomFlip3D, RandomRotation3D, RandomNoise3D
from losses import RefineFormer3DLoss
from config import DEVICE, IN_CHANNELS, NUM_CLASSES, BASE_LR, WEIGHT_DECAY, NUM_EPOCHS


def train_one_epoch(model, dataloader, optimizer, criterion, device, scaler):
    model.train()
    running_loss = 0.0

    for inputs, targets in tqdm(dataloader, desc="Training", leave=False):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device=device, dtype=torch.long, non_blocking=True)

        optimizer.zero_grad()

        with autocast():
            outputs = model(inputs)['main']
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * inputs.size(0)

    epoch_loss = running_loss / len(dataloader.dataset)
    return epoch_loss


def validate_one_epoch(model, dataloader, criterion, device, num_classes):
    model.eval()
    running_loss = 0.0
    dice_all, iou_all, hd_all = [], [], []

    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Validation", leave=False):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device=device, dtype=torch.long, non_blocking=True)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            running_loss += loss.item() * inputs.size(0)

            preds = torch.argmax(outputs, dim=1)
            dice_all.append(dice_coefficient(preds, targets, num_classes))
            iou_all.append(iou_score(preds, targets, num_classes))
            hd_all.append(hausdorff_distance(preds, targets))

    val_loss = running_loss / len(dataloader.dataset)
    avg_dice = torch.mean(torch.stack(dice_all)).item()
    avg_iou = torch.mean(torch.stack(iou_all)).item()
    avg_hd = torch.mean(torch.stack(hd_all)).item()

    return val_loss, avg_dice, avg_iou, avg_hd


def main():
    device = DEVICE
    model = RefineFormer3D(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES).to(device)

    optimizer = get_optimizer(model, base_lr=BASE_LR, weight_decay=WEIGHT_DECAY)
    scheduler = get_scheduler(optimizer)
    criterion = RefineFormer3DLoss()
    scaler = GradScaler()

    train_transform = Compose3D([
        RandomFlip3D(p=0.5),
        RandomRotation3D(p=0.5),
        RandomNoise3D(p=0.3),
    ])

    # Use new split structure
    train_dataset = BraTSDataset(
        root_dirs=["/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BRATS_SPLIT/train"],
        transform=train_transform,
    )

    val_dataset = BraTSDataset(
        root_dirs=["/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BRATS_SPLIT/val"],
        transform=None,
    )

    train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=4, pin_memory=True)
    val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=4, pin_memory=True)

    try:
        for epoch in range(1, NUM_EPOCHS + 1):
            print(f"\n--- Epoch [{epoch}/{NUM_EPOCHS}] ---")
            train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler)
            val_loss, val_dice, val_iou, val_hd = validate_one_epoch(model, val_loader, criterion, device, NUM_CLASSES)
            scheduler.step()
            print(f"Train Loss: {train_loss:.4f}  |  Val Loss: {val_loss:.4f}")
            print(f"Val Dice: {val_dice:.4f} | Val IOU: {val_iou:.4f} | Val Hausdorff: {val_hd:.4f}")
            torch.save(model.state_dict(), f"checkpoint_epoch_{epoch}.pt")
    except Exception as e:
        print(f"🚨 Training crashed due to: {e}")
        torch.save(model.state_dict(), "crashed_model2.pt")


if __name__ == "__main__":
    main()


  scaler = GradScaler()


✅ Total valid patient directories found: 411
✅ Total valid patient directories found: 73

--- Epoch [1/300] ---


  with autocast():
                                                 

🚨 Training crashed due to: only one dimension can be inferred


In [None]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast 
from tqdm import tqdm
from model4 import RefineFormer3D
from dataset import BraTSDataset
from optimizer import get_optimizer, get_scheduler
from metrics import dice_coefficient, iou_score, hausdorff_distance
from augmentation import Compose3D, RandomFlip3D, RandomRotation3D, RandomNoise3D
from losses import RefineFormer3DLoss
from config import DEVICE, IN_CHANNELS, NUM_CLASSES, BASE_LR, WEIGHT_DECAY, NUM_EPOCHS

def train_one_epoch(model, dataloader, optimizer, criterion, device, scaler):
    model.train()
    running_loss = 0.0

    for inputs, targets in tqdm(dataloader, desc="Training", leave=False):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device=device, dtype=torch.long, non_blocking=True)

        optimizer.zero_grad()

        with autocast(device_type='cuda'):
            outputs = model(inputs)['main']
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * inputs.size(0)

    return running_loss / len(dataloader.dataset)

def validate_one_epoch(model, dataloader, criterion, device, num_classes):
    model.eval()
    running_loss = 0.0
    dice_all, iou_all, hd_all = [], [], []

    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Validation", leave=False):
            inputs = inputs.to(device, non_blocking=True)
            targets = targets.to(device=device, dtype=torch.long, non_blocking=True)

            outputs = model(inputs)
            loss = criterion(outputs, targets)
            running_loss += loss.item() * inputs.size(0)

            preds = torch.argmax(outputs['main'], dim=1)
            dice_all.append(dice_coefficient(preds, targets, num_classes))
            iou_all.append(iou_score(preds, targets, num_classes))
            hd_all.append(hausdorff_distance(preds, targets))

    val_loss = running_loss / len(dataloader.dataset)
    return val_loss, torch.mean(torch.stack(dice_all)).item(), torch.mean(torch.stack(iou_all)).item(), torch.mean(torch.stack(hd_all)).item()

def main():
    device = DEVICE
    model = RefineFormer3D(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES).to(device)

    optimizer = get_optimizer(model, base_lr=BASE_LR, weight_decay=WEIGHT_DECAY)
    scheduler = get_scheduler(optimizer)
    criterion = RefineFormer3DLoss()
    scaler = GradScaler(device='cuda')  # Updated for deprecation

    train_transform = Compose3D([
        RandomFlip3D(p=0.5),
        RandomRotation3D(p=0.5),
        RandomNoise3D(p=0.3),
    ])

    train_dataset = BraTSDataset(
        root_dirs=["/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BRATS_SPLIT/train"],
        transform=train_transform,
    )
    val_dataset = BraTSDataset(
        root_dirs=["/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BRATS_SPLIT/val"],
        transform=None,
    )

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=4, pin_memory=True)
    val_loader   = DataLoader(val_dataset,   batch_size=1, shuffle=False, num_workers=4, pin_memory=True)

    for epoch in range(1, NUM_EPOCHS + 1):

        print(f"\n--- Epoch [{epoch}/{NUM_EPOCHS}] ---")
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler)
        val_loss, val_dice, val_iou, val_hd = validate_one_epoch(model, val_loader, criterion, device, NUM_CLASSES)
        scheduler.step()

        print(f"Train Loss: {train_loss:.4f}  |  Val Loss: {val_loss:.4f}")
        print(f"Val Dice: {val_dice:.4f} | Val IOU: {val_iou:.4f} | Val Hausdorff: {val_hd:.4f}")
        torch.save(model.state_dict(), f"checkpoint_epoch_{epoch}.pt")

if __name__ == "__main__":
    main()

In [1]:
import os
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch.amp import GradScaler, autocast
from tqdm import tqdm
from model4 import RefineFormer3D
from dataset import BraTSDataset
from optimizer import get_optimizer, get_scheduler
from augmentation import Compose3D, RandomFlip3D, RandomRotation3D, RandomNoise3D
from losses import RefineFormer3DLoss
from config import DEVICE, IN_CHANNELS, NUM_CLASSES, BASE_LR, WEIGHT_DECAY, NUM_EPOCHS


def train_one_epoch(model, dataloader, optimizer, criterion, device, scaler):
    model.train()
    running_loss = 0.0

    for inputs, targets in tqdm(dataloader, desc="Training", leave=False):
        inputs = inputs.to(device, non_blocking=True)
        targets = targets.to(device=device, dtype=torch.long, non_blocking=True)

        optimizer.zero_grad()

        with autocast(device_type='cuda'):
            outputs = model(inputs)['main']
            loss = criterion(outputs, targets)

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        running_loss += loss.item() * inputs.size(0)

    return running_loss / len(dataloader.dataset)


def main():
    device = DEVICE
    model = RefineFormer3D(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES).to(device)

    optimizer = get_optimizer(model, base_lr=BASE_LR, weight_decay=WEIGHT_DECAY)
    scheduler = get_scheduler(optimizer)
    criterion = RefineFormer3DLoss()
    scaler = GradScaler(device='cuda')

    train_transform = Compose3D([
        RandomFlip3D(p=0.5),
        RandomRotation3D(p=0.5),
        RandomNoise3D(p=0.3),
    ])    

    train_dataset = BraTSDataset(
        root_dirs=["/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BRATS_SPLIT/train"],
        transform=train_transform,
    )

    train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=12, pin_memory=True)

    for epoch in range(1, NUM_EPOCHS + 1):
        print(f"\n--- Epoch [{epoch}/{NUM_EPOCHS}] ---")
        train_loss = train_one_epoch(model, train_loader, optimizer, criterion, device, scaler)
        scheduler.step()

        print(f"Train Loss: {train_loss:.4f}")
        torch.save(model.state_dict(), f"checkpoint_epoch_{epoch}.pt")

    torch.save(model.state_dict(), "final_model.pt")
    print("✅ Final model saved as 'final_model.pt'")


if __name__ == "__main__":
    main()

✅ Total valid patient directories found: 411

--- Epoch [1/300] ---


                                                           

Train Loss: 1.7553

--- Epoch [2/300] ---


                                                            

Train Loss: 1.2305

--- Epoch [3/300] ---


                                                           

Train Loss: 0.8960

--- Epoch [4/300] ---


                                                           

Train Loss: 0.6988

--- Epoch [5/300] ---


                                                            

Train Loss: 0.5947

--- Epoch [6/300] ---


                                                            

Train Loss: 0.5332

--- Epoch [7/300] ---


                                                           

Train Loss: 0.4961

--- Epoch [8/300] ---


                                                           

Train Loss: 0.4596

--- Epoch [9/300] ---


                                                           

Train Loss: 0.4406

--- Epoch [10/300] ---


                                                            

Train Loss: 0.4233

--- Epoch [11/300] ---


                                                           

Train Loss: 0.4061

--- Epoch [12/300] ---


                                                           

Train Loss: 0.3961

--- Epoch [13/300] ---


                                                           

Train Loss: 0.3993

--- Epoch [14/300] ---


                                                           

Train Loss: 0.3823

--- Epoch [15/300] ---


                                                           

Train Loss: 0.3806

--- Epoch [16/300] ---


                                                           

Train Loss: 0.3733

--- Epoch [17/300] ---


                                                           

Train Loss: 0.3662

--- Epoch [18/300] ---


                                                           

Train Loss: 0.3590

--- Epoch [19/300] ---


                                                           

Train Loss: 0.3612

--- Epoch [20/300] ---


                                                           

Train Loss: 0.3492

--- Epoch [21/300] ---


                                                           

Train Loss: 0.3471

--- Epoch [22/300] ---


                                                           

Train Loss: 0.3457

--- Epoch [23/300] ---


                                                           

Train Loss: 0.3443

--- Epoch [24/300] ---


                                                           

Train Loss: 0.3454

--- Epoch [25/300] ---


                                                            

Train Loss: 0.3361

--- Epoch [26/300] ---


                                                           

Train Loss: 0.3270

--- Epoch [27/300] ---


                                                            

Train Loss: 0.3375

--- Epoch [28/300] ---


                                                           

Train Loss: 0.3232

--- Epoch [29/300] ---


                                                            

Train Loss: 0.3288

--- Epoch [30/300] ---


                                                           

: 

In [1]:
import torch
from torch.utils.data import DataLoader
from model4 import RefineFormer3D
from dataset import BraTSDataset
from metrics import dice_coefficient, iou_score, hausdorff_distance
from config import IN_CHANNELS, NUM_CLASSES

DEVICE = torch.device('cpu')

# ✅ Load test dataset
test_dataset = BraTSDataset(
    root_dirs=["/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BRATS_SPLIT/val"],
    transform=None,
)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, num_workers=0, pin_memory=False)

# ✅ Initialize model and load weights
model = RefineFormer3D(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES)
model.load_state_dict(torch.load("checkpoint_epoch_29.pt", map_location=DEVICE))
model.to(DEVICE)
model.eval()

# ✅ Metric containers
dice_all, iou_all, hd_all = [], [], []

with torch.no_grad():
    for idx, (inputs, targets) in enumerate(test_loader):
        inputs = inputs.to(DEVICE)
        targets = targets.to(DEVICE, dtype=torch.long)

        outputs = model(inputs)
        raw_preds = outputs['main']
        preds = torch.argmax(raw_preds, dim=1)

        # Dice + IoU
        _, mean_dice = dice_coefficient(raw_preds, targets, NUM_CLASSES)
        _, mean_iou = iou_score(raw_preds, targets, NUM_CLASSES)

        # Prepare shapes for Hausdorff
        pred_input = preds.unsqueeze(0) if preds.dim() == 3 else preds
        if targets.dim() == 4 and targets.shape[0] == 1:
            target_input = targets.squeeze(0).unsqueeze(0)
        elif targets.dim() == 3:
            target_input = targets.unsqueeze(0)
        else:
            target_input = targets

        try:
            assert pred_input.shape == target_input.shape, f"Shape mismatch: pred {pred_input.shape}, target {target_input.shape}"
            hd = hausdorff_distance(pred_input, target_input)
        except Exception as e:
            print(f"⚠️ HD skipped for sample {idx} | pred shape: {pred_input.shape}, target shape: {target_input.shape} | Error: {str(e)}")
            hd = float('nan')

        dice_all.append(torch.tensor(mean_dice))
        iou_all.append(torch.tensor(mean_iou))
        hd_all.append(torch.tensor(hd))

# ✅ Final average results
skipped_hd = torch.isnan(torch.tensor(hd_all)).sum().item()
print("\n📊 Test Results:")
print(f"Avg Dice Score: {torch.mean(torch.stack(dice_all)).item():.4f}")
print(f"Avg IOU Score: {torch.mean(torch.stack(iou_all)).item():.4f}")
print(f"Avg Hausdorff Distance: {torch.nanmean(torch.stack(hd_all)).item():.4f}")
print(f"ℹ️ Skipped Hausdorff for {int(skipped_hd)} samples due to errors.")


✅ Total valid patient directories found: 73


  model.load_state_dict(torch.load("checkpoint_epoch_29.pt", map_location=DEVICE))


Sample 0 pred shape: (128, 128), target shape: (128, 128, 128)
⚠️ Skipping sample_0: shape mismatch (pred (1137, 2), target (19345, 3))
ℹ️ Skipped 1 samples due to issues. Writing details to 'skipped_samples.txt'


  mean_hd = np.nanmean(batch_hd)


Sample 0 pred shape: (128, 128), target shape: (128, 128, 128)
⚠️ Skipping sample_0: shape mismatch (pred (2812, 2), target (61746, 3))
ℹ️ Skipped 1 samples due to issues. Writing details to 'skipped_samples.txt'
Sample 0 pred shape: (128, 128), target shape: (128, 128, 128)
⚠️ Skipping sample_0: shape mismatch (pred (970, 2), target (18890, 3))
ℹ️ Skipped 1 samples due to issues. Writing details to 'skipped_samples.txt'
Sample 0 pred shape: (128, 128), target shape: (128, 128, 128)
⚠️ Skipping sample_0: shape mismatch (pred (5358, 2), target (224736, 3))
ℹ️ Skipped 1 samples due to issues. Writing details to 'skipped_samples.txt'
Sample 0 pred shape: (128, 128), target shape: (128, 128, 128)
⚠️ Skipping sample_0: shape mismatch (pred (2829, 2), target (88007, 3))
ℹ️ Skipped 1 samples due to issues. Writing details to 'skipped_samples.txt'
Sample 0 pred shape: (128, 128), target shape: (128, 128, 128)
⚠️ Skipping sample_0: shape mismatch (pred (1346, 2), target (37014, 3))
ℹ️ Skipped 1

In [3]:
import torch
from fvcore.nn import FlopCountAnalysis, parameter_count_table
from model4 import RefineFormer3D
from config import IN_CHANNELS, NUM_CLASSES

# Set device to CPU for analysis
DEVICE = torch.device("cpu")

# Initialize the model
model = RefineFormer3D(in_channels=IN_CHANNELS, num_classes=NUM_CLASSES).to(DEVICE)
model.eval()

# Dummy 3D input: [Batch, Channels, Depth, Height, Width]
dummy_input = torch.randn(1, IN_CHANNELS, 128, 128, 128).to(DEVICE)

# Compute FLOPs
with torch.no_grad():
    flops = FlopCountAnalysis(model, dummy_input)
    flops_result = flops.total() / 1e9  # Convert to GFLOPs

# Print results
print(f"✅ GFLOPs: {flops_result:.2f} G")
print("✅ Parameter Count:")
print(parameter_count_table(model))


Unsupported operator aten::mul encountered 24 time(s)
Unsupported operator aten::add encountered 24 time(s)
Unsupported operator aten::softmax encountered 8 time(s)
Unsupported operator aten::gelu encountered 15 time(s)
Unsupported operator aten::upsample_trilinear3d encountered 4 time(s)


✅ GFLOPs: 203.12 G
✅ Parameter Count:
| name                         | #elements or shape   |
|:-----------------------------|:---------------------|
| model                        | 15.7M                |
|  encoder                     |  15.3M               |
|   encoder.stages             |   9.5M               |
|    encoder.stages.0          |    0.1M              |
|    encoder.stages.1          |    0.4M              |
|    encoder.stages.2          |    2.5M              |
|    encoder.stages.3          |    6.4M              |
|   encoder.embeds             |   5.8M               |
|    encoder.embeds.0          |    7.1K              |
|    encoder.embeds.1          |    0.2M              |
|    encoder.embeds.2          |    1.1M              |
|    encoder.embeds.3          |    4.4M              |
|   encoder.norms              |   2.0K               |
|    encoder.norms.0           |    0.1K              |
|    encoder.norms.1           |    0.3K              |
|    encod

In [None]:
import torch
from model4 import RefineFormer3D
from thop import profile
from torchinfo import summary

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Initialize model with your config
model = RefineFormer3D(
    in_channels=4,
    num_classes=3,
    embed_dims=[64, 128, 320, 512],
    depths=[2, 2, 2, 2],
    num_heads=[1, 2, 4, 8],
    window_sizes=[(4, 4, 4), (2, 4, 4), (2, 2, 2), (1, 2, 2)],
    mlp_ratios=[4, 4, 4, 4],
    decoder_channels=[256, 128, 64, 32]
).to(device)

# Dummy input
dummy_input = torch.randn(1, 4, 128, 128, 128).to(device)

# Check output shape
model.eval()
with torch.no_grad():
    output = model(dummy_input)
    print("✅ Output shape:", output["main"].shape)

# Param count
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"📊 Total Parameters: {total_params / 1e6:.2f}M")

# THOP for GFLOPs
macs, params = profile(model, inputs=(dummy_input,), verbose=False)
print(f"⚙️ Estimated MACs: {macs / 1e9:.2f} GFLOPs")

# Optional: torchinfo summary
summary(model, input_size=(1, 4, 128, 128, 128), depth=4, device=device.type)


In [1]:
import os

def find_missing_seg_files(validation_dir):
    missing = []
    for case in os.listdir(validation_dir):
        case_path = os.path.join(validation_dir, case)
        if os.path.isdir(case_path):
            seg_file = os.path.join(case_path, f"{case}_seg.nii")
            if not os.path.isfile(seg_file):
                missing.append(seg_file)
    return missing

# Example usage
validation_dir = "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17ValidationData"
missing_files = find_missing_seg_files(validation_dir)

if missing_files:
    print("⚠️ Missing segmentation files:")
    for f in missing_files:
        print(f)
else:
    print("✅ All segmentation files are present.")


⚠️ Missing segmentation files:
/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17ValidationData/Brats17_TCIA_195_1/Brats17_TCIA_195_1_seg.nii
/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17ValidationData/Brats17_CBICA_AAM_1/Brats17_CBICA_AAM_1_seg.nii
/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17ValidationData/Brats17_CBICA_ABT_1/Brats17_CBICA_ABT_1_seg.nii
/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17ValidationData/Brats17_CBICA_ALA_1/Brats17_CBICA_ALA_1_seg.nii
/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17ValidationData/Brats17_CBICA_ALT_1/Brats17_CBICA_ALT_1_seg.nii
/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for

In [None]:
import os

def check_missing_seg(root_dir):
    missing_patients = []
    for dirpath, dirnames, filenames in os.walk(root_dir):
        for dirname in dirnames:
            seg_path = os.path.join(dirpath, dirname, f"{dirname}_seg.nii")
            if not os.path.exists(seg_path):
                missing_patients.append(dirname)
    return missing_patients

if __name__ == "__main__":
    data_dir = "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17TrainingData/HGG/"
    missing_hgg = check_missing_seg(data_dir)
    
    data_dir = "/mnt/m2ssd/research project/Lightweight 3D Vision Transformers for Medical Imaging/dataset/BraTs2017/BRATS2017/Brats17TrainingData/LGG/"
    missing_lgg = check_missing_seg(data_dir)

    print("Missing in HGG:", missing_hgg)
    print("Missing in LGG:", missing_lgg)


In [1]:
import torch

def check_environment():
    print("\n🔵 Checking environment:\n")

    # PyTorch Version
    print(f"Torch Version        : {torch.__version__}")

    # CUDA Version
    if torch.version.cuda:
        print(f"CUDA Version         : {torch.version.cuda}")
    else:
        print("CUDA Version         : None (CPU Only)")

    # cuDNN Version
    if torch.backends.cudnn.is_available():
        print(f"cuDNN Version        : {torch.backends.cudnn.version()}")
    else:
        print("cuDNN Version        : None")

    # GPU Availability
    gpu_available = torch.cuda.is_available()
    print(f"GPU Available        : {gpu_available}")

    if gpu_available:
        print(f"Device Count         : {torch.cuda.device_count()}")
        print(f"Current Device       : {torch.cuda.current_device()}")
        print(f"Device Name          : {torch.cuda.get_device_name(0)}")

    print("\n🟡 Checking autocast compatibility:")

    try:
        with torch.cuda.amp.autocast(device_type='cuda', dtype=torch.float16):
            print("✅ autocast(device_type='cuda', dtype=torch.float16) works!")
    except Exception as e:
        print(f"❌ autocast(device_type='cuda') not supported: {e}")
        print("✅ You must use: autocast() only.")

    print("\n🟡 Checking GradScaler compatibility:")

    try:
        scaler = torch.cuda.amp.GradScaler()
        print("✅ GradScaler() works (no device_type needed).")
    except Exception as e:
        print(f"❌ GradScaler() failed: {e}")

    print("\n✅ Environment check complete.\n")

if __name__ == "__main__":
    check_environment()



🔵 Checking environment:

Torch Version        : 2.4.1+cu118
CUDA Version         : 11.8
cuDNN Version        : 90100
GPU Available        : True
Device Count         : 1
Current Device       : 0
Device Name          : NVIDIA GeForce RTX 3060 Laptop GPU

🟡 Checking autocast compatibility:
❌ autocast(device_type='cuda') not supported: __init__() got an unexpected keyword argument 'device_type'
✅ You must use: autocast() only.

🟡 Checking GradScaler compatibility:
✅ GradScaler() works (no device_type needed).

✅ Environment check complete.



  with torch.cuda.amp.autocast(device_type='cuda', dtype=torch.float16):
  scaler = torch.cuda.amp.GradScaler()
