In [23]:
# ================== SETUP ==================
import os, random, numpy as np, torch, timm
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Subset
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import torch.nn as nn
from tqdm import tqdm
import matplotlib.pyplot as plt
import pandas as pd

In [12]:
# Reproducibility
seed = 42
random.seed(seed); np.random.seed(seed); torch.manual_seed(seed)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

Using device: cpu


In [13]:
# ================== PATHS & CONSTANTS ==================
data_dir = "Brain_cancer"   # your folder path
IMG_SIZE = 224
BATCH_SIZE = 16
EPOCHS = 10
LR = 3e-4

In [14]:
# ================== TRANSFORMS ==================
train_transform = transforms.Compose([
    transforms.RandomResizedCrop(IMG_SIZE),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225])
])
val_test_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE,IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485,0.456,0.406],
                         std=[0.229,0.224,0.225])
])


In [15]:
# ================== DATASET & SPLITS ==================
full_dataset = datasets.ImageFolder(root=data_dir)
print("Classes:", full_dataset.classes)

targets = np.array(full_dataset.targets)
indices = np.arange(len(full_dataset))

# 70% train, 15% val, 15% test
train_idx, temp_idx, y_train, y_temp = train_test_split(
    indices, targets, test_size=0.3, stratify=targets, random_state=seed
)
val_idx, test_idx, y_val, y_test = train_test_split(
    temp_idx, y_temp, test_size=0.5, stratify=y_temp, random_state=seed
)

full_dataset.transform = train_transform
train_ds = Subset(full_dataset, train_idx)
full_dataset.transform = val_test_transform
val_ds = Subset(full_dataset, val_idx)
test_ds = Subset(full_dataset, test_idx)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

print(f"Train: {len(train_ds)} | Val: {len(val_ds)} | Test: {len(test_ds)}")

Classes: ['brain_glioma', 'brain_menin', 'brain_tumor']
Train: 4239 | Val: 908 | Test: 909


In [16]:
# ================== MODEL (ViT) ==================
num_classes = len(full_dataset.classes)
model_name = "vit_base_patch16_224"  # can change to vit_tiny_patch16_224

model = timm.create_model(model_name, pretrained=True, num_classes=num_classes)

# Freeze all layers
for p in model.parameters():
    p.requires_grad = False

# Replace classifier head
if hasattr(model, "reset_classifier"):
    model.reset_classifier(num_classes=num_classes)
else:
    model.head = nn.Linear(model.head.in_features, num_classes)

model = model.to(device)

# ================== TRAINING SETUP ==================
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=LR)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='max', patience=3, factor=0.5)

In [17]:
# ================== TRAINING LOOP ==================
best_val_acc = 0.0
save_path = "best_vit_brain_tumor.pth"

for epoch in range(EPOCHS):
    model.train()
    total_loss, preds_all, labels_all = 0.0, [], []
    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Train]"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * imgs.size(0)
        preds_all.extend(outputs.argmax(1).cpu().numpy())
        labels_all.extend(labels.cpu().numpy())
    train_acc = accuracy_score(labels_all, preds_all)
    train_loss = total_loss / len(train_ds)

    # Validation
    model.eval()
    val_preds, val_labels, val_loss = [], [], 0.0
    with torch.no_grad():
        for imgs, labels in tqdm(val_loader, desc=f"Epoch {epoch+1}/{EPOCHS} [Val]"):
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            loss = criterion(outputs, labels)
            val_loss += loss.item() * imgs.size(0)
            val_preds.extend(outputs.argmax(1).cpu().numpy())
            val_labels.extend(labels.cpu().numpy())
    val_acc = accuracy_score(val_labels, val_preds)
    val_loss /= len(val_ds)
    scheduler.step(val_acc)

    print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f} Acc={train_acc:.4f} | Val Loss={val_loss:.4f} Acc={val_acc:.4f}")

    if val_acc > best_val_acc:
        best_val_acc = val_acc
        torch.save(model.state_dict(), save_path)
        print("✅ Saved best model")

Epoch 1/10 [Train]: 100%|█████████████████████| 265/265 [03:04<00:00,  1.44it/s]
Epoch 1/10 [Val]: 100%|█████████████████████████| 57/57 [00:57<00:00,  1.01s/it]


Epoch 1: Train Loss=0.5197 Acc=0.7941 | Val Loss=0.3580 Acc=0.8711
✅ Saved best model


Epoch 2/10 [Train]: 100%|█████████████████████| 265/265 [03:03<00:00,  1.45it/s]
Epoch 2/10 [Val]: 100%|█████████████████████████| 57/57 [00:57<00:00,  1.01s/it]


Epoch 2: Train Loss=0.2841 Acc=0.9023 | Val Loss=0.2646 Acc=0.9130
✅ Saved best model


Epoch 3/10 [Train]: 100%|█████████████████████| 265/265 [03:09<00:00,  1.40it/s]
Epoch 3/10 [Val]: 100%|█████████████████████████| 57/57 [00:57<00:00,  1.02s/it]


Epoch 3: Train Loss=0.2292 Acc=0.9210 | Val Loss=0.2259 Acc=0.9196
✅ Saved best model


Epoch 4/10 [Train]: 100%|█████████████████████| 265/265 [03:02<00:00,  1.45it/s]
Epoch 4/10 [Val]: 100%|█████████████████████████| 57/57 [00:58<00:00,  1.02s/it]


Epoch 4: Train Loss=0.1987 Acc=0.9325 | Val Loss=0.2067 Acc=0.9240
✅ Saved best model


Epoch 5/10 [Train]: 100%|█████████████████████| 265/265 [03:05<00:00,  1.42it/s]
Epoch 5/10 [Val]: 100%|█████████████████████████| 57/57 [00:58<00:00,  1.03s/it]


Epoch 5: Train Loss=0.1733 Acc=0.9446 | Val Loss=0.1831 Acc=0.9350
✅ Saved best model


Epoch 6/10 [Train]: 100%|█████████████████████| 265/265 [03:03<00:00,  1.44it/s]
Epoch 6/10 [Val]: 100%|█████████████████████████| 57/57 [00:58<00:00,  1.02s/it]


Epoch 6: Train Loss=0.1591 Acc=0.9469 | Val Loss=0.1713 Acc=0.9328


Epoch 7/10 [Train]: 100%|█████████████████████| 265/265 [03:05<00:00,  1.43it/s]
Epoch 7/10 [Val]: 100%|█████████████████████████| 57/57 [00:58<00:00,  1.02s/it]


Epoch 7: Train Loss=0.1441 Acc=0.9540 | Val Loss=0.1656 Acc=0.9493
✅ Saved best model


Epoch 8/10 [Train]: 100%|█████████████████████| 265/265 [03:04<00:00,  1.44it/s]
Epoch 8/10 [Val]: 100%|█████████████████████████| 57/57 [00:57<00:00,  1.01s/it]


Epoch 8: Train Loss=0.1320 Acc=0.9597 | Val Loss=0.1624 Acc=0.9493


Epoch 9/10 [Train]: 100%|█████████████████████| 265/265 [03:08<00:00,  1.41it/s]
Epoch 9/10 [Val]: 100%|█████████████████████████| 57/57 [00:58<00:00,  1.02s/it]


Epoch 9: Train Loss=0.1250 Acc=0.9601 | Val Loss=0.1456 Acc=0.9438


Epoch 10/10 [Train]: 100%|████████████████████| 265/265 [03:06<00:00,  1.42it/s]
Epoch 10/10 [Val]: 100%|████████████████████████| 57/57 [00:58<00:00,  1.02s/it]


Epoch 10: Train Loss=0.1152 Acc=0.9641 | Val Loss=0.1393 Acc=0.9548
✅ Saved best model


In [18]:
# ================== EVALUATION ON TEST SET ==================
model.load_state_dict(torch.load(save_path, map_location=device))
model.eval()

y_true, y_pred = [], []
with torch.no_grad():
    for imgs, labels in tqdm(test_loader, desc="Testing"):
        imgs = imgs.to(device)
        outputs = model(imgs)
        preds = outputs.argmax(1).cpu().numpy()
        y_pred.extend(preds)
        y_true.extend(labels.numpy())

acc = accuracy_score(y_true, y_pred)
prec, rec, f1, _ = precision_recall_fscore_support(y_true, y_pred, average='macro')
cm = confusion_matrix(y_true, y_pred)

print("\n=== FINAL TEST RESULTS ===")
print(f"Accuracy:  {acc:.4f}")
print(f"Precision: {prec:.4f}")
print(f"Recall:    {rec:.4f}")
print(f"F1-score:  {f1:.4f}")
print("Confusion Matrix:\n", cm)

Testing: 100%|██████████████████████████████████| 57/57 [00:58<00:00,  1.02s/it]


=== FINAL TEST RESULTS ===
Accuracy:  0.9472
Precision: 0.9470
Recall:    0.9470
F1-score:  0.9469
Confusion Matrix:
 [[291  10   0]
 [ 11 272  17]
 [  1   9 298]]



