In [1]:
# 📌 Step 1: Install dependencies
!pip install torch torchvision scikit-learn tqdm




In [5]:
# ==============================================================================
# CELL 2: DOWNLOAD AND PREPARE DATASET
# ==============================================================================
# --- Fetch the dataset through Kaggle Hub ---
import shutil
import kagglehub

print("Downloading dataset from Kaggle Hub...")
path = kagglehub.dataset_download("mrnotalent/braint")
print(f"Path to dataset files: {path}")

# --- Mirror the download into a writable working directory ---
# The Kaggle input directory is read-only, hence the copy.
# dirs_exist_ok=True lets this cell be re-run over an existing copy.
src, dst = path, "/content/braint_original"
shutil.copytree(src, dst, dirs_exist_ok=True)
print(f"Dataset copied to {dst}")

Downloading dataset from Kaggle Hub...
Path to dataset files: /kaggle/input/braint
Dataset copied to /content/braint_original


In [6]:
# 📌 Step 2: Import libraries
import os, json, random
from pathlib import Path
from typing import Tuple

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import random_split, DataLoader, Subset
from torchvision import datasets, transforms, models
from sklearn.metrics import classification_report, confusion_matrix
from tqdm import tqdm


In [7]:
# 📌 Step 3: Reproducibility
def set_seed(seed: int = 42):
    """Seed the random number generators used in this notebook.

    Seeds Python's ``random`` module and PyTorch (via ``torch.manual_seed``
    and ``torch.cuda.manual_seed_all``), then sets the cuDNN flags to
    deterministic mode with benchmarking switched off.

    Args:
        seed: Value handed to every seeding call.
    """
    for seeder in (random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False


set_seed(42)


In [8]:
# 📌 Step 4: Data transforms
def build_transforms(image_size: int = 224, grayscale_to_rgb: bool = True):
    """Create the training and validation preprocessing pipelines.

    Both pipelines optionally start with a 3-output-channel ``Grayscale``
    step, resize to ``image_size`` x ``image_size``, convert to a tensor and
    normalize. The training pipeline additionally applies random horizontal
    and vertical flips (each with p=0.5) right after the resize.

    Args:
        image_size: Side length passed to ``Resize`` for both dimensions.
        grayscale_to_rgb: When true, prepend
            ``Grayscale(num_output_channels=3)`` to both pipelines.

    Returns:
        A ``(train_transform, val_transform)`` tuple of ``transforms.Compose``.
    """
    # Per-channel normalization constants (the standard ImageNet statistics).
    norm_mean = [0.485, 0.456, 0.406]
    norm_std = [0.229, 0.224, 0.225]

    def _pipeline(augment: bool):
        # Every call creates fresh transform objects, so the two pipelines
        # never share an instance.
        steps = []
        if grayscale_to_rgb:
            steps.append(transforms.Grayscale(num_output_channels=3))
        steps.append(transforms.Resize((image_size, image_size)))
        if augment:
            steps.append(transforms.RandomHorizontalFlip(p=0.5))
            steps.append(transforms.RandomVerticalFlip(p=0.5))
        steps.append(transforms.ToTensor())
        steps.append(transforms.Normalize(mean=norm_mean, std=norm_std))
        return transforms.Compose(steps)

    return _pipeline(augment=True), _pipeline(augment=False)


In [9]:
# 📌 Step 5: Build datasets (automatic train/val split if no folders)
def build_datasets(root_dir: str, image_size=224, grayscale_to_rgb=True, val_split=0.15, seed: int = 42):
    """Build the training and validation datasets from an image-folder tree.

    If ``root_dir`` contains both ``train`` and ``val`` sub-directories, each
    is loaded as its own ``ImageFolder``. Otherwise the whole of ``root_dir``
    is loaded and split at random into train/validation subsets.

    Args:
        root_dir: Directory holding either ``train``/``val`` folders or the
            class folders directly.
        image_size: Forwarded to ``build_transforms``.
        grayscale_to_rgb: Forwarded to ``build_transforms``.
        val_split: Fraction of samples held out for validation. Only used
            when no ``train``/``val`` folders exist.
        seed: Seed for the generator that drives the random split
            (default 42, the value that was previously hard-coded).

    Returns:
        ``(train_ds, val_ds)``. With pre-made splits these are ``ImageFolder``
        objects; otherwise they are ``Subset`` objects whose ``.dataset`` is an
        ``ImageFolder`` over ``root_dir``.

    Raises:
        ValueError: If a random split is needed and ``val_split`` is outside
            ``[0, 1]``.
    """
    root = Path(root_dir)
    train_dir, val_dir = root / "train", root / "val"
    tfm_train, tfm_val = build_transforms(image_size, grayscale_to_rgb)

    if train_dir.exists() and val_dir.exists():
        train_ds = datasets.ImageFolder(train_dir, transform=tfm_train)
        val_ds = datasets.ImageFolder(val_dir, transform=tfm_val)
        return train_ds, val_ds

    # Without this check a fraction outside [0, 1] produces a negative length
    # that still passes random_split's sum check and silently yields a wrong
    # split instead of an error.
    if not 0.0 <= val_split <= 1.0:
        raise ValueError(f"val_split must be within [0, 1], got {val_split!r}")

    # Two ImageFolder views of the same directory: identical samples, but the
    # validation view applies the augmentation-free transform.
    train_base = datasets.ImageFolder(root, transform=tfm_train)
    val_base = datasets.ImageFolder(root, transform=tfm_val)

    n_total = len(train_base)
    n_val = int(val_split * n_total)
    n_train = n_total - n_val

    # Split positions rather than samples so the same partition can be
    # applied to both views.
    train_part, val_part = random_split(
        range(n_total),
        [n_train, n_val],
        generator=torch.Generator().manual_seed(seed),
    )
    # The split source is range(n_total), so each part's .indices are already
    # the sample positions; using them directly avoids a nested Subset.
    train_ds = Subset(train_base, train_part.indices)
    val_ds = Subset(val_base, val_part.indices)
    return train_ds, val_ds


In [10]:
# 📌 Step 6: Build ResNet50 model
def build_model(num_classes: int, pretrained=True):
    """Create a ResNet50 whose final layer outputs ``num_classes`` logits.

    Args:
        num_classes: Output size of the replacement ``fc`` layer.
        pretrained: When truthy, load the ``IMAGENET1K_V1`` weights;
            otherwise build the network with ``weights=None``.

    Returns:
        The ResNet50 module with its ``fc`` layer swapped for a new
        ``nn.Linear`` of the same input width.
    """
    if pretrained:
        net = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    else:
        net = models.resnet50(weights=None)

    # Keep the backbone and replace only the classification head.
    net.fc = nn.Linear(net.fc.in_features, num_classes)
    return net


In [11]:
# 📌 Step 7: Training & validation loop
def step_epoch(model, loader, criterion, optimizer=None, device="cpu"):
    """Run one pass over ``loader``, training only if an optimizer is given.

    Args:
        model: Module to run. It is switched to train mode when ``optimizer``
            is provided and to eval mode otherwise.
        loader: Iterable yielding ``(imgs, labels)`` batches.
        criterion: Loss callable invoked as ``criterion(logits, labels)``.
        optimizer: Optimizer used for the update step; ``None`` selects an
            evaluation pass with gradients disabled.
        device: Device that each batch is moved to.

    Returns:
        ``(mean_loss, accuracy)`` over all samples seen, where accuracy is the
        fraction of samples whose arg-max logit equals the label.

    Raises:
        ValueError: If ``loader`` yields no samples (previously this surfaced
            as a bare ``ZeroDivisionError``).
    """
    is_train = optimizer is not None
    model.train(is_train)

    running_loss, correct, total = 0.0, 0, 0
    # Label the bar so the train and validation passes are distinguishable.
    for imgs, labels in tqdm(loader, desc="train" if is_train else "val"):
        imgs, labels = imgs.to(device), labels.to(device)
        with torch.set_grad_enabled(is_train):
            logits = model(imgs)
            loss = criterion(logits, labels)
        if is_train:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        batch_size = imgs.size(0)
        # Weight by batch size so a smaller final batch does not skew the
        # mean; assumes the criterion returns a per-batch mean.
        running_loss += loss.item() * batch_size
        correct += (logits.argmax(1) == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError("step_epoch received a loader that yielded no samples")
    return running_loss / total, correct / total


In [13]:
# 📌 Step 8: Load dataset
# NOTE(review): hard-coded location; the download cell stores the same
# location in `path` on Kaggle — confirm whether that variable should be used.
data_dir = "/kaggle/input/braint"
train_ds, val_ds = build_datasets(
    data_dir,
    image_size=224,
    grayscale_to_rgb=True,
)

# A random split returns Subset wrappers, which expose the class list on
# `.dataset`; a plain ImageFolder carries `.classes` itself.
class_names = getattr(train_ds, "dataset", train_ds).classes
print(f"Detected {len(class_names)} classes.")

Detected 44 classes.


In [14]:
# 📌 Step 9: DataLoaders
# Settings shared by both loaders; only the training loader shuffles.
loader_kwargs = dict(batch_size=16, num_workers=2)
train_loader = DataLoader(train_ds, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_ds, shuffle=False, **loader_kwargs)


In [15]:
# 📌 Step 10: Initialize model
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Pretrained ResNet50 with one output per detected class.
model = build_model(num_classes=len(class_names), pretrained=True)
model = model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)
# mode="max" because the scheduler is stepped with validation accuracy.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer,
    mode="max",
    factor=0.5,
    patience=2,
)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /root/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth


100%|██████████| 97.8M/97.8M [00:00<00:00, 154MB/s]


In [16]:
# 📌 Step 11: Train the model
epochs = 10
best_acc = 0.0
best_path = "resnet50_44cls_best.pth"

for epoch in range(1, epochs + 1):
    print(f"\nEpoch {epoch}/{epochs}")

    # Passing the optimizer selects a training pass; None selects evaluation.
    tr_loss, tr_acc = step_epoch(model, train_loader, criterion, optimizer, device)
    va_loss, va_acc = step_epoch(model, val_loader, criterion, None, device)
    print(f"Train: loss={tr_loss:.4f}, acc={tr_acc:.4f} | Val: loss={va_loss:.4f}, acc={va_acc:.4f}")

    # The plateau scheduler tracks validation accuracy.
    scheduler.step(va_acc)

    # Checkpoint only on a strict improvement in validation accuracy.
    if not va_acc > best_acc:
        continue

    best_acc = va_acc
    checkpoint = {
        "arch": "resnet50",
        "num_classes": len(class_names),
        "classes": class_names,
        "state_dict": model.state_dict(),
    }
    torch.save(checkpoint, best_path)
    print(f"✅ Saved best model to {best_path} (val_acc={best_acc:.4f})")



Epoch 1/10


100%|██████████| 863/863 [02:37<00:00,  5.46it/s]
100%|██████████| 153/153 [00:19<00:00,  7.74it/s]


Train: loss=3.3745, acc=0.0972 | Val: loss=2.9406, acc=0.1602
✅ Saved best model to resnet50_44cls_best.pth (val_acc=0.1602)

Epoch 2/10


100%|██████████| 863/863 [02:37<00:00,  5.48it/s]
100%|██████████| 153/153 [00:19<00:00,  7.75it/s]


Train: loss=2.6053, acc=0.2464 | Val: loss=2.4442, acc=0.3055
✅ Saved best model to resnet50_44cls_best.pth (val_acc=0.3055)

Epoch 3/10


100%|██████████| 863/863 [02:37<00:00,  5.47it/s]
100%|██████████| 153/153 [00:19<00:00,  7.91it/s]


Train: loss=2.0276, acc=0.3909 | Val: loss=1.8677, acc=0.4357
✅ Saved best model to resnet50_44cls_best.pth (val_acc=0.4357)

Epoch 4/10


100%|██████████| 863/863 [02:37<00:00,  5.48it/s]
100%|██████████| 153/153 [00:19<00:00,  7.98it/s]


Train: loss=1.5936, acc=0.5176 | Val: loss=1.4309, acc=0.5491
✅ Saved best model to resnet50_44cls_best.pth (val_acc=0.5491)

Epoch 5/10


100%|██████████| 863/863 [02:37<00:00,  5.48it/s]
100%|██████████| 153/153 [00:18<00:00,  8.07it/s]


Train: loss=1.3019, acc=0.5980 | Val: loss=1.2534, acc=0.6062
✅ Saved best model to resnet50_44cls_best.pth (val_acc=0.6062)

Epoch 6/10


100%|██████████| 863/863 [02:37<00:00,  5.49it/s]
100%|██████████| 153/153 [00:18<00:00,  8.06it/s]


Train: loss=1.0714, acc=0.6650 | Val: loss=0.9839, acc=0.6953
✅ Saved best model to resnet50_44cls_best.pth (val_acc=0.6953)

Epoch 7/10


100%|██████████| 863/863 [02:37<00:00,  5.48it/s]
100%|██████████| 153/153 [00:18<00:00,  8.17it/s]


Train: loss=0.8856, acc=0.7241 | Val: loss=0.7877, acc=0.7515
✅ Saved best model to resnet50_44cls_best.pth (val_acc=0.7515)

Epoch 8/10


100%|██████████| 863/863 [02:37<00:00,  5.49it/s]
100%|██████████| 153/153 [00:18<00:00,  8.07it/s]


Train: loss=0.7461, acc=0.7638 | Val: loss=0.7411, acc=0.7717
✅ Saved best model to resnet50_44cls_best.pth (val_acc=0.7717)

Epoch 9/10


100%|██████████| 863/863 [02:36<00:00,  5.50it/s]
100%|██████████| 153/153 [00:18<00:00,  8.23it/s]


Train: loss=0.6411, acc=0.7947 | Val: loss=0.6132, acc=0.8045
✅ Saved best model to resnet50_44cls_best.pth (val_acc=0.8045)

Epoch 10/10


100%|██████████| 863/863 [02:36<00:00,  5.51it/s]
100%|██████████| 153/153 [00:18<00:00,  8.13it/s]


Train: loss=0.5413, acc=0.8260 | Val: loss=0.6165, acc=0.8119
✅ Saved best model to resnet50_44cls_best.pth (val_acc=0.8119)


In [17]:
# 📌 Step 12: Save classes.txt for Streamlit app
# Write one class label per line, in the same order as `class_names`.
# The encoding is explicit so the file's bytes do not depend on the
# platform's default locale when labels contain non-ASCII characters.
with open("classes.txt", "w", encoding="utf-8") as f:
    f.write("\n".join(class_names))
print("Saved classes.txt with class labels.")


Saved classes.txt with class labels.
