In [None]:
# Install necessary libraries
# NOTE(review): the torch wheels are requested from the CPU-only index, while a later
# cell selects CUDA whenever it is available — confirm this index URL is intended.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu
# timm is imported by later cells to build the Swin Transformer backbone.
!pip install timm
# NOTE(review): cv2 is not imported in the cells shown here — presumably used further down; confirm.
!pip install opencv-python

In [None]:
# --- STEP 0: Google Colab Setup ---
# This is the first cell you should run.
import os
import torch
from google.colab import drive
import timm

# Mount Google Drive at /content/drive so the dataset folder below becomes readable
drive.mount('/content/drive')

# Verify the mount by listing the contents of your dataset folder.
# The later dataset cell expects this folder to contain the class folders.
!ls "/content/drive/MyDrive/brain_tumor_dataset/train"
print("\n✅ Setup Complete. Proceed with the following cells.")

Mounted at /content/drive
glioma	meningioma  notumor  pituitary

✅ Setup Complete. Proceed with the following cells.


In [None]:
# --- PART 1: Imports and Dataset Setup ---
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
import numpy as np
import os
from PIL import Image

# Device setup: use the GPU when the runtime exposes one, otherwise fall back to CPU
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"\n✅ Using device: {DEVICE}")

# Correct path to your dataset on Google Drive.
# This must point to the folder that contains your class folders.
root_dir = "/content/drive/MyDrive/brain_tumor_dataset/train"

IMG_SIZE = (224, 224)

# Seed for the train/val split. Without a fixed seed the split membership changes on
# every kernel restart, so a checkpoint reloaded from Drive would be validated on
# images it was trained on.
SPLIT_SEED = 42

# Data transforms: resize, convert to a tensor, then normalize each channel with
# mean 0.5 and std 0.5
data_transforms = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
])

# Load full dataset (ImageFolder derives one class per sub-folder of root_dir)
full_dataset = datasets.ImageFolder(root=root_dir, transform=data_transforms)

# Auto-split into train/val (80/20); the validation part takes the remainder so the
# two sizes always add up to len(full_dataset)
train_size = int(0.8 * len(full_dataset))
val_size = len(full_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_dataset,
    [train_size, val_size],
    generator=torch.Generator().manual_seed(SPLIT_SEED),  # reproducible split
)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

class_names = full_dataset.classes
print(f"📊 Classes found: {class_names}")
print(f"📂 Training samples: {len(train_dataset)} | Validation samples: {len(val_dataset)}")


✅ Using device: cuda
📊 Classes found: ['glioma', 'meningioma', 'notumor', 'pituitary']
📂 Training samples: 4569 | Validation samples: 1143


In [None]:
# --- PART 2: Hybrid Model with Grad-CAM Support ---
import torch.nn.functional as F
import timm

# ---------------------------
# 1. Attention Block
# ---------------------------
class AttentionBlock(nn.Module):
    """Additive attention gate that rescales a skip feature map with a learned mask.

    Args:
        F_g: number of channels of the gating input ``g``.
        F_l: number of channels of the skip input ``x``.
        F_int: number of channels both inputs are projected to before being summed.
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()
        # Each input gets its own 1x1 conv + batch norm projection to F_int channels.
        self.W_g = self._project(F_g, F_int)
        self.W_x = self._project(F_l, F_int)
        # Reduces the fused projection to one channel; the sigmoid keeps it in (0, 1).
        self.psi = nn.Sequential(
            nn.Conv2d(F_int, 1, kernel_size=1),
            nn.BatchNorm2d(1),
            nn.Sigmoid(),
        )
        self.relu = nn.ReLU(inplace=True)

    @staticmethod
    def _project(in_channels, out_channels):
        """Build the 1x1 convolution + batch norm pair used for both input projections."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=1),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, g, x):
        """Return ``x`` multiplied by the attention mask computed from ``g`` and ``x``.

        The two projections are summed, so ``g`` and ``x`` must have compatible
        batch and spatial dimensions.
        """
        gate = self.psi(self.relu(self.W_g(g) + self.W_x(x)))
        # The single-channel gate broadcasts across every channel of x.
        return x * gate

# ---------------------------
# 2. Attention U-Net
# ---------------------------
class AttentionUNet(nn.Module):
    """U-Net style encoder/decoder whose skip connections pass through attention gates.

    Args:
        in_channels: number of channels of the input image.
        out_channels: number of channels of the predicted mask.
        features: channel widths of the four encoder stages; the bottleneck uses
            twice the last width. The list is only read, never modified.

    ``forward`` returns ``(mask, bottleneck_features)``.
    """

    def __init__(self, in_channels=3, out_channels=1, features=[64, 128, 256, 512]):
        super(AttentionUNet, self).__init__()
        self.encoder1 = self._block(in_channels, features[0])
        self.encoder2 = self._block(features[0], features[1])
        self.encoder3 = self._block(features[1], features[2])
        self.encoder4 = self._block(features[2], features[3])

        self.pool = nn.MaxPool2d(2)
        self.bottleneck = self._block(features[3], features[3] * 2)

        # Every up block after the first consumes the concatenation of a gated skip
        # and the previous decoder output, hence the doubled input channel counts.
        self.up4 = self._up_block(features[3] * 2, features[3])
        self.att4 = AttentionBlock(F_g=features[3], F_l=features[3], F_int=features[3] // 2)
        self.up3 = self._up_block(features[3] + features[3], features[2])
        self.att3 = AttentionBlock(F_g=features[2], F_l=features[2], F_int=features[2] // 2)
        self.up2 = self._up_block(features[2] + features[2], features[1])
        self.att2 = AttentionBlock(F_g=features[1], F_l=features[1], F_int=features[1] // 2)
        self.up1 = self._up_block(features[1] + features[1], features[0])
        self.final_conv = nn.Conv2d(features[0], out_channels, kernel_size=1)

    def _block(self, in_channels, out_channels):
        """Two 3x3 conv + batch norm + ReLU layers; padding=1 keeps the spatial size."""
        return nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_channels, out_channels, 3, padding=1),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def _up_block(self, in_channels, out_channels):
        """Transposed conv (kernel 2, stride 2) + batch norm + ReLU; doubles H and W."""
        return nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    @staticmethod
    def _match_size(tensor, reference):
        """Resize ``tensor`` to the spatial size of ``reference`` when they differ.

        MaxPool2d(2) floors odd sizes, so doubling on the way up can come out one
        pixel short of the matching encoder stage. When the sizes already agree the
        tensor is returned untouched.
        """
        if tensor.shape[-2:] != reference.shape[-2:]:
            tensor = F.interpolate(
                tensor, size=reference.shape[-2:], mode="bilinear", align_corners=False
            )
        return tensor

    def forward(self, x):
        """Run the encoder/decoder and return ``(mask, bottleneck_features)``.

        ``mask`` is the sigmoid of the final 1x1 convolution. Its height and width
        are twice those of the second encoder stage, which equals the input size
        when the input height and width are even. ``bottleneck_features`` is the
        output of the bottleneck block.
        """
        # Encoder path: each stage runs on the pooled output of the previous one.
        x1 = self.encoder1(x)
        x2 = self.encoder2(self.pool(x1))
        x3 = self.encoder3(self.pool(x2))
        x4 = self.encoder4(self.pool(x3))
        x5 = self.bottleneck(self.pool(x4))

        # Decoder path: upsample, align to the skip, gate the skip, concatenate.
        d4 = self._match_size(self.up4(x5), x4)
        x4 = self.att4(d4, x4)
        d4 = torch.cat([x4, d4], dim=1)

        d3 = self._match_size(self.up3(d4), x3)
        x3 = self.att3(d3, x3)
        d3 = torch.cat([x3, d3], dim=1)

        d2 = self._match_size(self.up2(d3), x2)
        x2 = self.att2(d2, x2)
        d2 = torch.cat([x2, d2], dim=1)

        # The last up block has no skip connection from encoder1.
        d1 = self.up1(d2)
        mask = torch.sigmoid(self.final_conv(d1))
        return mask, x5

# ---------------------------
# 3. Hybrid Model with Grad-CAM Support (Corrected)
# ---------------------------
class HybridAttentionSwin(nn.Module):
    """Attention U-Net segmentation branch plus a Swin Transformer classification branch.

    Both branches receive the same input batch. Hooks on the Swin model's ``norm``
    layer record its output and the gradient flowing into that output, so they can
    be read back through ``get_features_and_gradients`` (e.g. for Grad-CAM).

    Args:
        num_classes: number of output logits of the classification head.
    """

    def __init__(self, num_classes):
        super(HybridAttentionSwin, self).__init__()
        self.segmentation_net = AttentionUNet(in_channels=3)
        # num_classes=0 is passed so that the Linear layer below acts as the head.
        # NOTE(review): presumably timm then returns pooled features of width
        # swin.num_features — confirm against the installed timm version.
        self.swin = timm.create_model("swin_tiny_patch4_window7_224", pretrained=True, num_classes=0)
        self.classifier = nn.Linear(self.swin.num_features, num_classes)

        # Filled in by the hooks below on every forward / backward pass.
        self.feature_maps = None
        self.gradients = None

        def save_features_hook(module, input, output):
            self.feature_maps = output

        def save_gradients_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0]

        # Register hooks on the last layer of the Swin Transformer.
        # register_full_backward_hook is used instead of the deprecated
        # register_backward_hook, which triggers a non-full-backward-hook warning.
        self.swin.norm.register_forward_hook(save_features_hook)
        self.swin.norm.register_full_backward_hook(save_gradients_hook)

    def forward(self, x):
        """Return ``(classification_logits, segmentation_mask)`` for the batch ``x``."""
        # The ORIGINAL 3-channel input goes to both the segmentation network and the
        # Swin Transformer; the U-Net's bottleneck features are not used here.
        seg_mask, _ = self.segmentation_net(x)
        swin_features = self.swin(x)

        classification = self.classifier(swin_features)
        return classification, seg_mask

    def get_features_and_gradients(self):
        """Return the most recently hooked ``(feature_maps, gradients)`` pair.

        Either entry is ``None`` until the corresponding forward / backward pass ran.
        """
        return self.feature_maps, self.gradients


# ---------------------------
# Initialize Model
# ---------------------------
# One output logit per dataset class; the module is then moved to the selected device.
model = HybridAttentionSwin(num_classes=len(class_names))
model = model.to(DEVICE)
print(f"✅ Hybrid Segmentation + Classification model created with {len(class_names)} classes.")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/114M [00:00<?, ?B/s]

✅ Hybrid Segmentation + Classification model created with 4 classes.


In [None]:
# --- PART 3: Training (Classification + Segmentation Multi-task) ---
import json
from pathlib import Path
from tqdm.notebook import tqdm

# ---------------------------
# CONFIG
# ---------------------------
MASK_ROOT = None  # Change to your mask path if you have one.
SAVE_DIR = "/content/drive/MyDrive/checkpoints"
os.makedirs(SAVE_DIR, exist_ok=True)

# Optimisation hyper-parameters
NUM_EPOCHS, BATCH_SIZE = 12, 8
LR, WEIGHT_DECAY = 1e-4, 1e-5

# Image size and per-channel normalisation (must match Part 1)
IMG_SIZE = (224, 224)
_CHANNEL_MEAN = [0.5, 0.5, 0.5]
_CHANNEL_STD = [0.5, 0.5, 0.5]

# Validation pipeline: deterministic resize -> tensor -> normalise
val_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

# Training pipeline: same steps plus a random horizontal flip before tensor conversion
train_transform = transforms.Compose([
    transforms.Resize(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(_CHANNEL_MEAN, _CHANNEL_STD),
])

# ---------------------------
# Custom Dataset supporting optional masks
# ---------------------------
class MultiTaskImageDataset(torch.utils.data.Dataset):
    """Dataset yielding ``(image, class_index, mask, image_path)`` tuples.

    Args:
        image_folder_dataset: either an object exposing ``samples`` (a sequence of
            ``(path, class_index)`` pairs) or a wrapper exposing ``dataset`` and
            ``indices`` around such an object.
        mask_root: optional directory searched for a mask with the image's file
            name, first directly and then inside a sub-folder named after the
            image's parent folder.
        transform: optional callable applied to the RGB image; without it the
            image is only converted to a tensor.
        mask_transform: optional callable applied to the grayscale mask; defaults
            to resize to ``IMG_SIZE`` followed by tensor conversion.

    Raises:
        RuntimeError: if ``image_folder_dataset`` matches neither supported shape.

    NOTE(review): the image and mask transforms run independently, so a random
    augmentation inside ``transform`` is not mirrored onto the mask — confirm this
    is acceptable before enabling masks.
    """

    def __init__(self, image_folder_dataset, mask_root=None, transform=None, mask_transform=None):
        super().__init__()
        self.base = image_folder_dataset
        self.mask_root = Path(mask_root) if mask_root else None
        self.transform = transform
        if mask_transform:
            self.mask_transform = mask_transform
        else:
            self.mask_transform = transforms.Compose([
                transforms.Resize(IMG_SIZE),
                transforms.ToTensor(),
            ])

        wrapped = self.base
        if hasattr(wrapped, "dataset") and hasattr(wrapped, "indices"):
            # Index-based wrapper: keep only the selected samples, in index order.
            all_samples = wrapped.dataset.samples
            self.samples = [all_samples[i] for i in wrapped.indices]
        elif hasattr(wrapped, "samples"):
            self.samples = wrapped.samples
        else:
            raise RuntimeError("Unsupported dataset type for MultiTaskImageDataset")

    def __len__(self):
        """Number of ``(path, class_index)`` samples."""
        return len(self.samples)

    def _load_mask(self, img_path):
        """Return the mask tensor for ``img_path``, or an all-zero mask if none is found."""
        if self.mask_root:
            source = Path(img_path)
            candidates = (
                self.mask_root / source.name,
                self.mask_root / source.parent.name / source.name,
            )
            for candidate in candidates:
                if candidate.exists():
                    mask_t = self.mask_transform(Image.open(candidate).convert("L"))
                    # Rescale only when the transform left values above 1.
                    if mask_t.max() > 1:
                        mask_t = mask_t / 255.0
                    return mask_t
        # No mask root configured, or no matching file: single-channel zeros.
        return torch.zeros((1, IMG_SIZE[0], IMG_SIZE[1]), dtype=torch.float32)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, class_index, mask_tensor, path)``."""
        img_path, class_idx = self.samples[idx]
        rgb = Image.open(img_path).convert("RGB")
        if self.transform:
            image_t = self.transform(rgb)
        else:
            image_t = transforms.ToTensor()(rgb)

        mask_t = self._load_mask(img_path)
        return image_t, class_idx, mask_t, img_path

# ---------------------------
# Build train/val MultiTask datasets & loaders
# ---------------------------
# Wrap the Part 1 train/val splits so every sample also carries a mask and its path
mt_train_ds = MultiTaskImageDataset(train_dataset, mask_root=MASK_ROOT, transform=train_transform)
mt_val_ds = MultiTaskImageDataset(val_dataset, mask_root=MASK_ROOT, transform=val_transform)

# Both loaders share every setting except shuffling (training only)
_loader_kwargs = dict(batch_size=BATCH_SIZE, num_workers=2, pin_memory=True)
train_loader = DataLoader(mt_train_ds, shuffle=True, **_loader_kwargs)
val_loader = DataLoader(mt_val_ds, shuffle=False, **_loader_kwargs)

print(f"Train samples: {len(mt_train_ds)} | Val samples: {len(mt_val_ds)}")
print(f"Masks enabled: {MASK_ROOT is not None}")

# ---------------------------
# Losses, Optimizer, Scheduler
# ---------------------------
# Cross-entropy for the class logits, binary cross-entropy for the sigmoid mask
criterion_cls = nn.CrossEntropyLoss()
criterion_seg = nn.BCELoss()

optimizer = optim.Adam(model.parameters(), lr=LR, weight_decay=WEIGHT_DECAY)
# mode='max' because the scheduler is stepped with validation accuracy
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode='max',
    factor=0.5,
    patience=2,
)

# ---------------------------
# Training Loop
# ---------------------------
def _segmentation_loss(pred_masks, masks):
    """Return the segmentation loss for one batch.

    With MASK_ROOT unset the result is a zero scalar on DEVICE, so only the
    classification loss contributes. Otherwise the prediction is resized to the
    target's spatial size when the shapes differ and scored with criterion_seg.
    """
    if not MASK_ROOT:
        return torch.tensor(0.0, device=DEVICE)
    if pred_masks.shape != masks.shape:
        pred_masks = nn.functional.interpolate(
            pred_masks, size=masks.shape[-2:], mode='bilinear', align_corners=False
        )
    return criterion_seg(pred_masks, masks.float())


best_val_acc = 0.0
best_checkpoint_path = os.path.join(SAVE_DIR, "best_checkpoint.pth")
save_model_path = os.path.join(SAVE_DIR, "best_model_state.pth")

for epoch in range(NUM_EPOCHS):
    # ----- Training pass -----
    model.train()
    running_loss = running_cls_loss = running_seg_loss = 0.0
    total = correct = 0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} Train")
    for batch_idx, (imgs, labels, masks, _) in enumerate(pbar, start=1):
        imgs, labels, masks = imgs.to(DEVICE), labels.to(DEVICE), masks.to(DEVICE)
        optimizer.zero_grad()

        outputs, pred_masks = model(imgs)
        loss_cls = criterion_cls(outputs, labels)
        loss_seg = _segmentation_loss(pred_masks, masks)

        # Equal weighting of the classification and segmentation terms
        loss = loss_cls + 1.0 * loss_seg
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        running_cls_loss += loss_cls.item()
        running_seg_loss += loss_seg.item()
        _, preds = torch.max(outputs, 1)
        correct += (preds == labels).sum().item()
        total += labels.size(0)

        # running_loss sums per-batch means, so the running average divides by the
        # number of batches seen so far (not by samples / BATCH_SIZE, which is
        # fractional once the last, smaller batch is reached).
        pbar.set_postfix({
            "loss": f"{running_loss/batch_idx:.4f}",
            "acc": f"{100*correct/total:.2f}%"
        })

    train_acc = 100 * correct / total
    print(f"Epoch {epoch+1} train_loss: {running_loss/len(train_loader):.4f} train_acc: {train_acc:.2f}%")

    # ----- Validation pass (no gradients) -----
    model.eval()
    val_total = val_correct = 0
    val_cls_loss = val_seg_loss = 0.0

    with torch.no_grad():
        for imgs, labels, masks, _ in tqdm(val_loader, desc=f"Epoch {epoch+1}/{NUM_EPOCHS} Val"):
            imgs, labels, masks = imgs.to(DEVICE), labels.to(DEVICE), masks.to(DEVICE)
            outputs, pred_masks = model(imgs)
            loss_cls = criterion_cls(outputs, labels)
            loss_seg = _segmentation_loss(pred_masks, masks)

            val_cls_loss += loss_cls.item()
            val_seg_loss += loss_seg.item()
            _, preds = torch.max(outputs, 1)
            val_correct += (preds == labels).sum().item()
            val_total += labels.size(0)

    val_acc = 100 * val_correct / val_total
    print(f"Epoch {epoch+1} VAL cls_loss: {val_cls_loss/len(val_loader):.4f} seg_loss: {val_seg_loss/len(val_loader):.4f} val_acc: {val_acc:.2f}%")

    # The scheduler tracks validation accuracy and lowers the LR when it plateaus
    scheduler.step(val_acc)
    current_lr = optimizer.param_groups[0]['lr']
    print(f"📉 Current LR: {current_lr:.6f}")

    # Persist only strict improvements: a full checkpoint plus the bare state dict
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        checkpoint = {
            "epoch": epoch+1,
            "model_state": model.state_dict(),
            "optimizer_state": optimizer.state_dict(),
            "val_acc": val_acc,
            "class_names": class_names
        }
        torch.save(checkpoint, best_checkpoint_path)
        torch.save(model.state_dict(), save_model_path)
        print(f"💾 New best model saved with val_acc: {val_acc:.2f}%")

# Store the class order next to the checkpoints
with open(os.path.join(SAVE_DIR, "class_names.json"), "w") as f:
    json.dump(class_names, f)

print(f"✅ Training finished. Best val acc: {best_val_acc:.2f}%")
print(f"Checkpoints saved to: {SAVE_DIR}")

Train samples: 4569 | Val samples: 1143
Masks enabled: False


Epoch 1/12 Train:   0%|          | 0/572 [00:00<?, ?it/s]

  self._maybe_warn_non_full_backward_hook(args, result, grad_fn)


Epoch 1 train_loss: 0.3074 train_acc: 89.08%


Epoch 1/12 Val:   0%|          | 0/143 [00:00<?, ?it/s]

Epoch 1 VAL cls_loss: 0.1576 seg_loss: 0.0000 val_acc: 95.36%
📉 Current LR: 0.000100
💾 New best model saved with val_acc: 95.36%


Epoch 2/12 Train:   0%|          | 0/572 [00:00<?, ?it/s]

Epoch 2 train_loss: 0.1033 train_acc: 96.50%


Epoch 2/12 Val:   0%|          | 0/143 [00:00<?, ?it/s]

Epoch 2 VAL cls_loss: 0.1087 seg_loss: 0.0000 val_acc: 96.59%
📉 Current LR: 0.000100
💾 New best model saved with val_acc: 96.59%


Epoch 3/12 Train:   0%|          | 0/572 [00:00<?, ?it/s]

Epoch 3 train_loss: 0.0754 train_acc: 97.57%


Epoch 3/12 Val:   0%|          | 0/143 [00:00<?, ?it/s]

Epoch 3 VAL cls_loss: 0.1630 seg_loss: 0.0000 val_acc: 95.10%
📉 Current LR: 0.000100


Epoch 4/12 Train:   0%|          | 0/572 [00:00<?, ?it/s]

Epoch 4 train_loss: 0.0674 train_acc: 97.96%


Epoch 4/12 Val:   0%|          | 0/143 [00:00<?, ?it/s]

Epoch 4 VAL cls_loss: 0.1216 seg_loss: 0.0000 val_acc: 96.59%
📉 Current LR: 0.000100


Epoch 5/12 Train:   0%|          | 0/572 [00:00<?, ?it/s]

Epoch 5 train_loss: 0.0434 train_acc: 98.77%


Epoch 5/12 Val:   0%|          | 0/143 [00:00<?, ?it/s]

Epoch 5 VAL cls_loss: 0.1004 seg_loss: 0.0000 val_acc: 97.64%
📉 Current LR: 0.000100
💾 New best model saved with val_acc: 97.64%


Epoch 6/12 Train:   0%|          | 0/572 [00:00<?, ?it/s]

Epoch 6 train_loss: 0.0478 train_acc: 98.45%


Epoch 6/12 Val:   0%|          | 0/143 [00:00<?, ?it/s]

Epoch 6 VAL cls_loss: 0.0484 seg_loss: 0.0000 val_acc: 98.78%
📉 Current LR: 0.000100
💾 New best model saved with val_acc: 98.78%


Epoch 7/12 Train:   0%|          | 0/572 [00:00<?, ?it/s]

Epoch 7 train_loss: 0.0545 train_acc: 98.42%


Epoch 7/12 Val:   0%|          | 0/143 [00:00<?, ?it/s]

Epoch 7 VAL cls_loss: 0.1608 seg_loss: 0.0000 val_acc: 95.54%
📉 Current LR: 0.000100


Epoch 8/12 Train:   0%|          | 0/572 [00:00<?, ?it/s]

Epoch 8 train_loss: 0.0548 train_acc: 98.42%


Epoch 8/12 Val:   0%|          | 0/143 [00:00<?, ?it/s]

Epoch 8 VAL cls_loss: 0.1217 seg_loss: 0.0000 val_acc: 95.28%
📉 Current LR: 0.000100


Epoch 9/12 Train:   0%|          | 0/572 [00:00<?, ?it/s]

Epoch 9 train_loss: 0.0361 train_acc: 99.02%


Epoch 9/12 Val:   0%|          | 0/143 [00:00<?, ?it/s]

Epoch 9 VAL cls_loss: 0.1008 seg_loss: 0.0000 val_acc: 97.55%
📉 Current LR: 0.000050


Epoch 10/12 Train:   0%|          | 0/572 [00:00<?, ?it/s]

Epoch 10 train_loss: 0.0131 train_acc: 99.69%


Epoch 10/12 Val:   0%|          | 0/143 [00:00<?, ?it/s]

Epoch 10 VAL cls_loss: 0.0602 seg_loss: 0.0000 val_acc: 98.69%
📉 Current LR: 0.000050


Epoch 11/12 Train:   0%|          | 0/572 [00:00<?, ?it/s]

Epoch 11 train_loss: 0.0082 train_acc: 99.67%


Epoch 11/12 Val:   0%|          | 0/143 [00:00<?, ?it/s]

Epoch 11 VAL cls_loss: 0.0691 seg_loss: 0.0000 val_acc: 98.43%
📉 Current LR: 0.000050


Epoch 12/12 Train:   0%|          | 0/572 [00:00<?, ?it/s]

Epoch 12 train_loss: 0.0121 train_acc: 99.65%


Epoch 12/12 Val:   0%|          | 0/143 [00:00<?, ?it/s]

Epoch 12 VAL cls_loss: 0.0836 seg_loss: 0.0000 val_acc: 98.43%
📉 Current LR: 0.000025
✅ Training finished. Best val acc: 98.78%
Checkpoints saved to: /content/drive/MyDrive/checkpoints
