In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from torchvision.transforms import InterpolationMode
from timm.models.vision_transformer import vit_base_patch16_224

# Seed torch's global RNG (CPU and every CUDA device) so that initialisation of
# the newly created layers, DataLoader shuffling and torchvision's random
# augmentations repeat across runs. Bit-exact reproducibility on GPU would
# additionally require deterministic cuDNN kernels.
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when one is visible; later cells move the model and batches here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)



Using device: cuda


In [None]:
# Kaggle "brain-tumor-mri-dataset": one sub-folder per class inside each split.
data_root = "/kaggle/input/brain-tumor-mri-dataset"
train_data_path = f"{data_root}/Training"
val_data_path   = f"{data_root}/Testing"
# NOTE(review): the "Testing" split doubles as the validation set used to pick
# the best checkpoint, so accuracy reported on it is optimistically biased.
# Hold out part of "Training" for model selection if an unbiased number is needed.

# Per-channel normalisation expected by the pretrained ViT branch. timm's default
# pretrained config for vit_base_patch16_224 uses mean = std = 0.5 (i.e. inputs
# in [-1, 1]); confirm against `model.vit.pretrained_cfg` if the weights change.
vit_mean = (0.5, 0.5, 0.5)
vit_std = (0.5, 0.5, 0.5)

# Data transforms: augmentation only on the training split.
train_transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=InterpolationMode.BILINEAR),
    transforms.RandomHorizontalFlip(),
    transforms.RandomVerticalFlip(),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
    transforms.ToTensor(),
    transforms.Normalize(mean=vit_mean, std=vit_std),
])

val_transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=InterpolationMode.BILINEAR),
    transforms.ToTensor(),
    transforms.Normalize(mean=vit_mean, std=vit_std),
])

train_dataset = datasets.ImageFolder(train_data_path, transform=train_transform)
val_dataset   = datasets.ImageFolder(val_data_path, transform=val_transform)

# ImageFolder assigns label indices from the class-folder names of each split;
# the two splits must agree or the evaluation labels are silently wrong.
assert train_dataset.classes == val_dataset.classes, (
    f"class folders differ: {train_dataset.classes} vs {val_dataset.classes}"
)

train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
val_loader   = DataLoader(val_dataset, batch_size=16, shuffle=False, num_workers=2)

num_classes = len(train_dataset.classes)
print("No of classes:", num_classes)
print("No of train images:", len(train_dataset))
print("No of test images:", len(val_dataset))


No of classes: 4
No of train images: 5712
No of test images: 1311


In [None]:
class HighPerfNestedModel(nn.Module):
    """Hybrid CNN + ViT image classifier with a learned, element-wise gated fusion.

    Both branches read the same image batch and each produces a
    ``fusion_dim``-wide feature vector. A sigmoid gate computed from the
    concatenation of the two vectors mixes them element-wise, and a linear
    layer maps the fused vector to class logits.

    Args:
        num_classes: number of output logits.
        fusion_dim: width of each branch's feature vector and of the fused
            representation (default 256, the previously hard-coded width).
        pretrained: load timm's pretrained weights for the ViT backbone.
            Pass ``False`` to skip the download, e.g. when a saved
            ``state_dict`` is loaded right afterwards.
    """

    def __init__(self, num_classes, fusion_dim=256, pretrained=True):
        super().__init__()
        # --- Slow module: CNN backbone (randomly initialised) ---
        # Four Conv/BN/ReLU/MaxPool stages going 3 -> 256 channels; each
        # MaxPool2d(2) halves the spatial size.
        self.cnn = nn.Sequential(
            nn.Conv2d(3, 32, 3, padding=1), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(32, 64, 3, padding=1), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, padding=1), nn.BatchNorm2d(128), nn.ReLU(), nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU(), nn.MaxPool2d(2)
        )
        # Global average pool to 1x1 so the CNN feature width is independent of H, W.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))
        cnn_channels = 256  # output channels of the last conv stage above
        self.cnn_fc = nn.Linear(cnn_channels, fusion_dim)

        # --- Fast module: ViT backbone ---
        self.vit = vit_base_patch16_224(pretrained=pretrained)
        # Replace the ImageNet classification head so the ViT emits a
        # fusion_dim-wide feature vector instead of class logits.
        self.vit.head = nn.Linear(self.vit.head.in_features, fusion_dim)

        # --- Gated Fusion ---
        # Per-feature weights in (0, 1) computed from both branches' features.
        self.gate = nn.Sequential(
            nn.Linear(2 * fusion_dim, fusion_dim),
            nn.Sigmoid()
        )

        # --- Final classifier ---
        self.classifier = nn.Linear(fusion_dim, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes) for the image batch ``x``.

        NOTE(review): ``x`` is assumed to be (batch, 3, 224, 224), as produced by
        the notebook's transforms; the ViT variant used is the 224-px one.
        """
        # CNN branch -> (batch, fusion_dim)
        cnn_feat = self.adaptive_pool(self.cnn(x))
        cnn_feat = self.cnn_fc(torch.flatten(cnn_feat, 1))

        # ViT branch -> (batch, fusion_dim) through the replaced head
        vit_feat = self.vit(x)

        # Gated fusion: gate -> 1 favours the CNN feature, gate -> 0 the ViT feature.
        gate = self.gate(torch.cat([cnn_feat, vit_feat], dim=1))
        fused = gate * cnn_feat + (1 - gate) * vit_feat

        return self.classifier(fused)

# Instantiate the hybrid model (this fetches the pretrained ViT weights) and move it to `device`.
model = HighPerfNestedModel(num_classes=num_classes)
model = model.to(device)

model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

In [None]:
# Learning-rate and update-frequency settings for the two "nested" levels.
learning_rate = 1e-4
slow_update_freq = 5  # the slow optimizer is meant to update once per this many batches

# Fast level: the pretrained ViT and the final classifier, at the base learning rate.
optimizer_fast = optim.Adam(
    params=[*model.vit.parameters(), *model.classifier.parameters()],
    lr=learning_rate,
)

# Slow level: the CNN branch, its projection and the fusion gate, at 1/10 of the base rate.
optimizer_slow = optim.Adam(
    params=[*model.cnn.parameters(), *model.cnn_fc.parameters(), *model.gate.parameters()],
    lr=learning_rate * 0.1,
)

# Cosine LR decay over 30 scheduler steps; the training loop steps both
# schedulers once per epoch, so keep T_max equal to its `num_epochs`.
scheduler_fast = optim.lr_scheduler.CosineAnnealingLR(optimizer_fast, T_max=30)
scheduler_slow = optim.lr_scheduler.CosineAnnealingLR(optimizer_slow, T_max=30)

# Single shared loss for both levels, with 0.1 label smoothing.
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)

In [None]:
num_epochs = 30
best_val_acc = 0.0
max_grad_norm = 5.0  # gradient-norm clipping threshold, applied per level

# NOTE(review): `val_loader` reads the dataset's "Testing" split and the
# checkpoint below is selected on it, so the best accuracy printed here is an
# optimistic, test-set-selected estimate.

# Flatten each optimizer's parameter groups so that each level is clipped using
# only the gradients that its own optimizer is about to apply.
fast_params = [p for group in optimizer_fast.param_groups for p in group["params"]]
slow_params = [p for group in optimizer_slow.param_groups for p in group["params"]]


def evaluate_accuracy(net, loader):
    """Return top-1 accuracy of `net` over every batch of `loader`, in percent."""
    net.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            predicted = net(images).argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total


for epoch in range(num_epochs):
    model.train()
    running_loss = 0.0
    num_batches = len(train_loader)
    # Start each epoch with clean gradients on both levels.
    optimizer_fast.zero_grad()
    optimizer_slow.zero_grad()

    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()

        # Fast level (ViT + classifier): clip and apply this batch's gradient
        # on every step, then clear it.
        torch.nn.utils.clip_grad_norm_(fast_params, max_grad_norm)
        optimizer_fast.step()
        optimizer_fast.zero_grad()

        # Slow level (CNN + cnn_fc + gate): gradients keep summing over a window
        # of `slow_update_freq` batches and are cleared only after being applied.
        # That way no batch's gradient is discarded, and pending slow gradients
        # never enter the fast level's clipping norm. The window is averaged
        # before clipping, and a trailing partial window at the end of the epoch
        # is applied as well.
        batches_in_window = batch_idx % slow_update_freq + 1
        if batches_in_window == slow_update_freq or batch_idx + 1 == num_batches:
            for p in slow_params:
                if p.grad is not None:
                    p.grad.div_(batches_in_window)
            torch.nn.utils.clip_grad_norm_(slow_params, max_grad_norm)
            optimizer_slow.step()
            optimizer_slow.zero_grad()

        running_loss += loss.item()

    # Both cosine schedules advance once per epoch.
    scheduler_fast.step()
    scheduler_slow.step()

    # Validation
    val_acc = evaluate_accuracy(model, val_loader)

    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/max(num_batches, 1):.4f}, "
          f"Validation Accuracy: {val_acc:.2f}%")

    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), "best_highperf_nested_model.pth")

print(f"Training Finished. Best Validation Accuracy: {best_val_acc:.2f}%")


Epoch [1/30], Loss: 0.6553, Validation Accuracy: 88.63%
Epoch [2/30], Loss: 0.5072, Validation Accuracy: 93.14%
Epoch [3/30], Loss: 0.4664, Validation Accuracy: 94.05%
Epoch [4/30], Loss: 0.4569, Validation Accuracy: 95.27%
Epoch [5/30], Loss: 0.4361, Validation Accuracy: 93.97%
Epoch [6/30], Loss: 0.4406, Validation Accuracy: 93.29%
Epoch [7/30], Loss: 0.4245, Validation Accuracy: 94.97%
Epoch [8/30], Loss: 0.4141, Validation Accuracy: 93.14%
Epoch [9/30], Loss: 0.4145, Validation Accuracy: 97.48%
Epoch [10/30], Loss: 0.3965, Validation Accuracy: 97.25%
Epoch [11/30], Loss: 0.3959, Validation Accuracy: 96.11%
Epoch [12/30], Loss: 0.3934, Validation Accuracy: 96.03%
Epoch [13/30], Loss: 0.3833, Validation Accuracy: 97.10%
Epoch [14/30], Loss: 0.3841, Validation Accuracy: 96.49%
Epoch [15/30], Loss: 0.3723, Validation Accuracy: 98.70%
Epoch [16/30], Loss: 0.3707, Validation Accuracy: 99.01%
Epoch [17/30], Loss: 0.3683, Validation Accuracy: 99.01%
Epoch [18/30], Loss: 0.3630, Validation 