In [1]:
# Importing Necessary libraries
import os
import glob
import numpy as np
import nibabel as nib
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as transforms

#Swin Transformer and Model
from timm.models.swin_transformer import swin_base_patch4_window7_224 as SwinTransformer
from tqdm import tqdm
import random
import warnings


warnings.filterwarnings("ignore")

In [3]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("awsaf49/brats20-dataset-training-validation")

print("Path to dataset files:", path)

Path to dataset files: /root/.cache/kagglehub/datasets/awsaf49/brats20-dataset-training-validation/versions/1


In [8]:
ls /root/.cache/kagglehub/datasets/awsaf49/brats20-dataset-training-validation/versions/1/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/

[0m[01;34mBraTS20_Training_001[0m/  [01;34mBraTS20_Training_094[0m/  [01;34mBraTS20_Training_187[0m/  [01;34mBraTS20_Training_280[0m/
[01;34mBraTS20_Training_002[0m/  [01;34mBraTS20_Training_095[0m/  [01;34mBraTS20_Training_188[0m/  [01;34mBraTS20_Training_281[0m/
[01;34mBraTS20_Training_003[0m/  [01;34mBraTS20_Training_096[0m/  [01;34mBraTS20_Training_189[0m/  [01;34mBraTS20_Training_282[0m/
[01;34mBraTS20_Training_004[0m/  [01;34mBraTS20_Training_097[0m/  [01;34mBraTS20_Training_190[0m/  [01;34mBraTS20_Training_283[0m/
[01;34mBraTS20_Training_005[0m/  [01;34mBraTS20_Training_098[0m/  [01;34mBraTS20_Training_191[0m/  [01;34mBraTS20_Training_284[0m/
[01;34mBraTS20_Training_006[0m/  [01;34mBraTS20_Training_099[0m/  [01;34mBraTS20_Training_192[0m/  [01;34mBraTS20_Training_285[0m/
[01;34mBraTS20_Training_007[0m/  [01;34mBraTS20_Training_100[0m/  [01;34mBraTS20_Training_193[0m/  [01;34mBraTS20_Training_286[0m/
[01;34mBraTS20_Trainin

In [9]:
import os

DATASET_PATH = r"/root/.cache/kagglehub/datasets/awsaf49/brats20-dataset-training-validation/versions/1/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/"


print("🔍 Checking Dataset Path...")
print(os.listdir(DATASET_PATH)[:5])

🔍 Checking Dataset Path...
['BraTS20_Training_350', 'BraTS20_Training_047', 'BraTS20_Training_100', 'BraTS20_Training_366', 'BraTS20_Training_068']


In [10]:
for patient in os.listdir(DATASET_PATH)[:5]:  # Check first 5 patients
    patient_path = os.path.join(DATASET_PATH, patient)

    print(f"\n🔍 Checking {patient}...")

    for mod in ["flair", "t1", "t1ce", "t2", "seg"]:
        files = [f for f in os.listdir(patient_path) if mod in f]
        if files:
            print(f"✅ {mod}: {files}")
        else:
            print(f"❌ Missing {mod}!")


🔍 Checking BraTS20_Training_350...
✅ flair: ['BraTS20_Training_350_flair.nii']
✅ t1: ['BraTS20_Training_350_t1ce.nii', 'BraTS20_Training_350_t1.nii']
✅ t1ce: ['BraTS20_Training_350_t1ce.nii']
✅ t2: ['BraTS20_Training_350_t2.nii']
✅ seg: ['BraTS20_Training_350_seg.nii']

🔍 Checking BraTS20_Training_047...
✅ flair: ['BraTS20_Training_047_flair.nii']
✅ t1: ['BraTS20_Training_047_t1ce.nii', 'BraTS20_Training_047_t1.nii']
✅ t1ce: ['BraTS20_Training_047_t1ce.nii']
✅ t2: ['BraTS20_Training_047_t2.nii']
✅ seg: ['BraTS20_Training_047_seg.nii']

🔍 Checking BraTS20_Training_100...
✅ flair: ['BraTS20_Training_100_flair.nii']
✅ t1: ['BraTS20_Training_100_t1.nii', 'BraTS20_Training_100_t1ce.nii']
✅ t1ce: ['BraTS20_Training_100_t1ce.nii']
✅ t2: ['BraTS20_Training_100_t2.nii']
✅ seg: ['BraTS20_Training_100_seg.nii']

🔍 Checking BraTS20_Training_366...
✅ flair: ['BraTS20_Training_366_flair.nii']
✅ t1: ['BraTS20_Training_366_t1ce.nii', 'BraTS20_Training_366_t1.nii']
✅ t1ce: ['BraTS20_Training_366_t1ce.

In [12]:
import nibabel as nib
import numpy as np
import os
import torch


DATASET_PATH = r"/root/.cache/kagglehub/datasets/awsaf49/brats20-dataset-training-validation/versions/1/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/"

PATIENT = "BraTS20_Training_001"
PATIENT_PATH = os.path.join(DATASET_PATH, PATIENT)

# Function to Load NIfTI
def load_nii(file_path):
    print(f"✅ Loading: {file_path}")
    return nib.load(file_path).get_fdata()

# Load Image Modalities
flair = load_nii(os.path.join(PATIENT_PATH, f"{PATIENT}_flair.nii"))
t1 = load_nii(os.path.join(PATIENT_PATH, f"{PATIENT}_t1.nii"))
t1ce = load_nii(os.path.join(PATIENT_PATH, f"{PATIENT}_t1ce.nii"))
t2 = load_nii(os.path.join(PATIENT_PATH, f"{PATIENT}_t2.nii"))
seg = load_nii(os.path.join(PATIENT_PATH, f"{PATIENT}_seg.nii"))

print(f"\n🧠 **Image Shapes:**")
print(f"Flair: {flair.shape}, T1: {t1.shape}, T1ce: {t1ce.shape}, T2: {t2.shape}")
print(f"Segmentation Mask: {seg.shape}")


image = np.stack([flair, t1, t1ce, t2], axis=0)
image_tensor = torch.tensor(image, dtype=torch.float32)
mask_tensor = torch.tensor(seg, dtype=torch.long)

print(f"\n📌 **Converted to Tensor:**")
print(f"Image Tensor Shape: {image_tensor.shape}")
print(f"Mask Tensor Shape: {mask_tensor.shape}")


✅ Loading: /root/.cache/kagglehub/datasets/awsaf49/brats20-dataset-training-validation/versions/1/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_flair.nii
✅ Loading: /root/.cache/kagglehub/datasets/awsaf49/brats20-dataset-training-validation/versions/1/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_t1.nii
✅ Loading: /root/.cache/kagglehub/datasets/awsaf49/brats20-dataset-training-validation/versions/1/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_t1ce.nii
✅ Loading: /root/.cache/kagglehub/datasets/awsaf49/brats20-dataset-training-validation/versions/1/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_t2.nii
✅ Loading: /root/.cache/kagglehub/datasets/awsaf49/brats20-dataset-training-validation/versions/1/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_se

In [18]:
from torch.utils.data import Dataset, DataLoader
class BraTSDataset(Dataset):
    def __init__(self, dataset_path):
        self.dataset_path = dataset_path
        self.patients = sorted(os.listdir(dataset_path))
        self.modalities = ["flair", "t1", "t1ce", "t2"]
        self.transform = transforms.Compose([
            transforms.ToTensor(),
        ])

    def __len__(self):
        return len(self.patients)

    def load_nii(self, patient_folder, modality):
        possible_files = glob.glob(os.path.join(patient_folder, f"*{modality}*.nii*"))

        if not possible_files and modality == "t1":
            possible_files = glob.glob(os.path.join(patient_folder, f"*ROI*{modality}*.nii*"))

        if not possible_files:
            raise FileNotFoundError(f"❌ No file found for modality '{modality}' in {patient_folder}")

        possible_files.sort(key=lambda x: x.endswith(".gz"), reverse=True)
        file_path = possible_files[0]

        print(f"✅ Loading: {file_path}")
        return nib.load(file_path).get_fdata()

    def __getitem__(self, idx):
        patient = self.patients[idx]
        patient_path = os.path.join(self.dataset_path, patient)

        try:
            images = [self.load_nii(patient_path, mod) for mod in self.modalities]
            images = np.stack(images, axis=0)

            mask = self.load_nii(patient_path, "seg")

            slice_idx = mask.shape[-1] // 2
            image = images[:, :, :, slice_idx]
            mask = mask[:, :, slice_idx]

            image_tensor = torch.tensor(image, dtype=torch.float32)
            mask_tensor = torch.tensor(mask, dtype=torch.long)


            image_tensor = F.interpolate(image_tensor.unsqueeze(0), size=(224, 224), mode="bilinear", align_corners=False).squeeze(0)
            mask_tensor = mask_tensor.unsqueeze(0).unsqueeze(0).float()
            mask_tensor = F.interpolate(mask_tensor, size=(224, 224), mode="nearest")
            mask_tensor = mask_tensor.squeeze(0).squeeze(0).long()

            print(f"✅ Loaded {patient}: Image {image_tensor.shape}, Mask {mask_tensor.shape}")
            return image_tensor, mask_tensor

        except FileNotFoundError as e:
            print(f"⚠️ Skipping {patient} due to missing files: {e}")
            return None

In [19]:

import glob
import torch.nn.functional as F
#DATASET_PATH = r"C:\Users\Dell\.cache\kagglehub\datasets\sanglequang\brats2018\versions\10\MICCAI_BraTS_2018_Data_Training\HGG"

# Load Dataset
train_dataset = BraTSDataset(DATASET_PATH)


valid_samples = [d for d in train_dataset if d is not None]


if len(valid_samples) == 0:
    raise ValueError("❌ No valid samples found. Check dataset paths and file integrity.")


train_loader = DataLoader(valid_samples, batch_size=2, shuffle=True, num_workers=0)

print(f"✅ Loaded {len(train_loader.dataset)} valid patients from BraTS 2020.")

✅ Loading: /root/.cache/kagglehub/datasets/awsaf49/brats20-dataset-training-validation/versions/1/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_flair.nii
✅ Loading: /root/.cache/kagglehub/datasets/awsaf49/brats20-dataset-training-validation/versions/1/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_t1ce.nii
✅ Loading: /root/.cache/kagglehub/datasets/awsaf49/brats20-dataset-training-validation/versions/1/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_t1ce.nii
✅ Loading: /root/.cache/kagglehub/datasets/awsaf49/brats20-dataset-training-validation/versions/1/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_t2.nii
✅ Loading: /root/.cache/kagglehub/datasets/awsaf49/brats20-dataset-training-validation/versions/1/BraTS2020_TrainingData/MICCAI_BraTS2020_TrainingData/BraTS20_Training_001/BraTS20_Training_001_

In [None]:
# import torch
# import torch.nn as nn
# import torch.nn.functional as F
# from timm.models.swin_transformer import swin_tiny_patch4_window7_224 as SwinTransformer

# class SwinUNet(nn.Module):
#     def __init__(self, img_size=224, in_channels=4, out_channels=4):
#         super(SwinUNet, self).__init__()


#         self.swin = SwinTransformer(pretrained=True)
#         self.swin.head = nn.Identity()


#         self.swin.patch_embed.proj = nn.Conv2d(in_channels, 96, kernel_size=4, stride=4, padding=0)


#         self.up1 = nn.ConvTranspose2d(768, 384, kernel_size=2, stride=2)
#         self.up2 = nn.ConvTranspose2d(384, 192, kernel_size=2, stride=2)
#         self.up3 = nn.ConvTranspose2d(192, 96, kernel_size=2, stride=2)
#         self.up4 = nn.ConvTranspose2d(96, out_channels, kernel_size=2, stride=2) # output

#     def forward(self, x):
#         B, C, H, W = x.shape  # Batch, Channels, Height, Width

#         # Swin Transformer Forward Pass
#         x = self.swin.patch_embed(x)
#         x = self.swin.patch_embed.norm(x)
#         x = self.swin.layers[0](x)
#         x1 = x
#         x = self.swin.layers[1](x)
#         x2 = x
#         x = self.swin.layers[2](x)
#         x3 = x
#         x = self.swin.layers[3](x)
#         x4 = x


#         print(f"Feature Maps from Swin Transformer:")
#         print(f"Level 0: {x1.shape}")
#         print(f"Level 1: {x2.shape}")
#         print(f"Level 2: {x3.shape}")
#         print(f"Level 3: {x4.shape}")


#         x4 = x4.permute(0, 3, 1, 2)


#         x = self.up1(x4)
#         x = self.up2(x)
#         x = self.up3(x)
#         x = self.up4(x)

#         return x

# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = SwinUNet().to(device)

# # Create a dummy input to check dimensions
# dummy_input = torch.randn(2, 4, 224, 224).to(device)
# output = model(dummy_input)

# print("✅ Swin-UNet Forward Pass Successful! Output shape:", output.shape)

In [15]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from timm.models.swin_transformer import swin_tiny_patch4_window7_224 as SwinTransformer


class MultiScaleAttentionFusion(nn.Module):
    def __init__(self, in_channels):
        super(MultiScaleAttentionFusion, self).__init__()

        # Channel Attention (Global)
        self.channel_attn = nn.Sequential(
            nn.AdaptiveAvgPool2d(1),  # Global Pooling
            nn.Conv2d(in_channels, in_channels // 4, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(in_channels // 4, in_channels, 1, bias=False),
            nn.Sigmoid()
        )

        # Spatial Attention
        self.spatial_attn = nn.Sequential(
            nn.Conv2d(in_channels, 1, kernel_size=7, padding=3, bias=False),
            nn.Sigmoid()
        )

    def forward(self, x):
        # Apply channel attention
        channel_weights = self.channel_attn(x)
        x = x * channel_weights

        # Apply spatial attention
        spatial_weights = self.spatial_attn(x)
        x = x * spatial_weights

        return x


class SwinUNet(nn.Module):
    def __init__(self, img_size=224, in_channels=4, out_channels=4):
        super(SwinUNet, self).__init__()

        self.swin = SwinTransformer(pretrained=True)
        self.swin.head = nn.Identity()

        self.swin.patch_embed.proj = nn.Conv2d(in_channels, 96, kernel_size=4, stride=4, padding=0)

        # Decoder with Attention Fusion
        self.attn1 = MultiScaleAttentionFusion(768)
        self.attn2 = MultiScaleAttentionFusion(384)
        self.attn3 = MultiScaleAttentionFusion(192)
        self.attn4 = MultiScaleAttentionFusion(96)

        self.up1 = nn.ConvTranspose2d(768, 384, kernel_size=2, stride=2)
        self.up2 = nn.ConvTranspose2d(384, 192, kernel_size=2, stride=2)
        self.up3 = nn.ConvTranspose2d(192, 96, kernel_size=2, stride=2)
        self.up4 = nn.ConvTranspose2d(96, out_channels, kernel_size=2, stride=2)  # Output

    def forward(self, x):
        B, C, H, W = x.shape

        # Swin Transformer Forward Pass
        x = self.swin.patch_embed(x)
        x = self.swin.patch_embed.norm(x)
        x = self.swin.layers[0](x)
        x1 = x
        x = self.swin.layers[1](x)
        x2 = x
        x = self.swin.layers[2](x)
        x3 = x
        x = self.swin.layers[3](x)
        x4 = x

        # Multi-Scale Attention Fusion
        x4 = self.attn1(x4.permute(0, 3, 1, 2))
        x = self.up1(x4)

        x3 = self.attn2(x3.permute(0, 3, 1, 2))
        x = self.up2(x + x3)

        x2 = self.attn3(x2.permute(0, 3, 1, 2))
        x = self.up3(x + x2)

        x1 = self.attn4(x1.permute(0, 3, 1, 2))
        x = self.up4(x + x1)

        return x


device = "cuda" if torch.cuda.is_available() else "cpu"
model = SwinUNet().to(device)

dummy_input = torch.randn(2, 4, 224, 224).to(device)
output = model(dummy_input)

print("✅ Swin-UNet with Multi-Scale Attention Fusion Forward Pass Successful! Output shape:", output.shape)

model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

✅ Swin-UNet with Multi-Scale Attention Fusion Forward Pass Successful! Output shape: torch.Size([2, 4, 112, 112])


In [16]:
from torch.utils.data import random_split, DataLoader

# Split Dataset into 80% Train, 20% Validation
dataset_size = len(train_dataset)
train_size = int(0.8 * dataset_size)
val_size = dataset_size - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])


train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True, num_workers=0)
val_loader = DataLoader(val_dataset, batch_size=2, shuffle=False, num_workers=0)

print(f"✅ Train Set: {len(train_dataset)} samples | Val Set: {len(val_dataset)} samples")


✅ Train Set: 296 samples | Val Set: 75 samples


In [None]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
torch.cuda.is_available = lambda : False

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm import tqdm

# Define Dice Loss (Fixed)
class DiceLoss(nn.Module):
    def __init__(self, smooth=1e-6):
        super(DiceLoss, self).__init__()
        self.smooth = smooth

    def forward(self, pred, target):
        pred = torch.sigmoid(pred)


        target_one_hot = F.one_hot(target, num_classes=4).permute(0, 3, 1, 2).float()

        intersection = (pred * target_one_hot).sum()
        return 1 - (2. * intersection + self.smooth) / (pred.sum() + target_one_hot.sum() + self.smooth)

device = torch.device("cpu")

# Initialize Model, Losses & Optimizer
model = SwinUNet().to(device)
criterion_ce = nn.CrossEntropyLoss()
criterion_dice = DiceLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=1e-4)

# Training Loop
def train_model(model, train_loader, epochs=5):
    model.to("cpu")
    for epoch in range(epochs):
        model.train()
        train_loss = 0
        loop = tqdm(train_loader, desc=f"Epoch [{epoch+1}/{epochs}]", leave=False)

        for images, masks in loop:
            images, masks = images.cpu(), masks.cpu()

            masks = torch.clamp(masks, 0, 3)
            masks_resized = F.interpolate(masks.unsqueeze(1).float(), size=(112, 112), mode="nearest").squeeze(1).long()

            optimizer.zero_grad()
            outputs = model(images)

            loss_ce = criterion_ce(outputs, masks_resized)
            loss_dice = criterion_dice(outputs, masks_resized)
            loss = loss_ce + loss_dice

            loss.backward()
            optimizer.step()

            train_loss += loss.item()
            loop.set_postfix({"Loss": loss.item()})

        print(f"\n✅ Epoch {epoch+1}: Train Loss: {train_loss/len(train_loader):.4f}")


train_model(model, train_loader, epochs=5)




✅ Epoch 1: Train Loss: 0.2708





✅ Epoch 2: Train Loss: 0.1317


Epoch [3/5]:  64%|██████▍   | 118/184 [03:27<01:57,  1.78s/it, Loss=0.0918]

In [None]:
import torch.nn.functional as F

def evaluate_model(model, val_loader):
    """Compute accuracy of the model after training."""
    model.eval()
    correct_pixels = 0
    total_pixels = 0

    with torch.no_grad():
        for images, masks in val_loader:
            images, masks = images.to(device), masks.to(device)
            outputs = model(images)


            preds = torch.argmax(outputs, dim=1)


            masks_resized = F.interpolate(masks.unsqueeze(1).float(), size=preds.shape[1:], mode="nearest").squeeze(1).long()


            correct_pixels += (preds == masks_resized).sum().item()
            total_pixels += masks_resized.numel()

    final_accuracy = correct_pixels / total_pixels
    print(f"\n🎯 **Final Model Accuracy on Validation Set: {final_accuracy:.4f}**")

    return final_accuracy

In [None]:
evaluate_model(model, val_loader)

✅ Loading: C:\Users\Dell\.cache\kagglehub\datasets\sanglequang\brats2018\versions\10\MICCAI_BraTS_2018_Data_Training\HGG\Brats18_CBICA_AMH_1\Brats18_CBICA_AMH_1_flair.nii
✅ Loading: C:\Users\Dell\.cache\kagglehub\datasets\sanglequang\brats2018\versions\10\MICCAI_BraTS_2018_Data_Training\HGG\Brats18_CBICA_AMH_1\Brats18_CBICA_AMH_1_t1.nii
✅ Loading: C:\Users\Dell\.cache\kagglehub\datasets\sanglequang\brats2018\versions\10\MICCAI_BraTS_2018_Data_Training\HGG\Brats18_CBICA_AMH_1\Brats18_CBICA_AMH_1_t1ce.nii
✅ Loading: C:\Users\Dell\.cache\kagglehub\datasets\sanglequang\brats2018\versions\10\MICCAI_BraTS_2018_Data_Training\HGG\Brats18_CBICA_AMH_1\Brats18_CBICA_AMH_1_t2.nii
✅ Loading: C:\Users\Dell\.cache\kagglehub\datasets\sanglequang\brats2018\versions\10\MICCAI_BraTS_2018_Data_Training\HGG\Brats18_CBICA_AMH_1\Brats18_CBICA_AMH_1_seg.nii
✅ Loaded Brats18_CBICA_AMH_1: Image torch.Size([4, 224, 224]), Mask torch.Size([224, 224])
✅ Loading: C:\Users\Dell\.cache\kagglehub\datasets\sanglequang\

0.9778626852526725

In [None]:
torch.save(model,'swin2018.pt')