In [None]:
# Import all required libraries
import os
import numpy as np
import torch
import torch.nn as nn
from tqdm import tqdm
import torch.optim as optim
from torch.utils.data import DataLoader, Subset, ConcatDataset
from torchvision import datasets, transforms
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score,
                             confusion_matrix, classification_report, matthews_corrcoef)
import matplotlib.pyplot as plt

In [None]:
# Run on the GPU (CUDA) when one is available; otherwise fall back to the CPU.

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [None]:
# Training hyperparameters (transfer learning: pretrained backbone, last blocks fine-tuned)
lr = 1.96277e-4        # learning rate passed to the Adam optimizer
batch_size = 64        # mini-batch size used by both data loaders
dropout = 0.11857      # dropout probability in the classifier head
unfreeze_layers = 4    # number of trailing feature blocks left trainable
k_folds = 5            # number of stratified cross-validation folds
patience = 5           # non-improving epochs tolerated before early stopping

In [None]:
# Dataset preparation (transformation)

data_root = "brats-path-2025"
# NOTE(review): no ImageNet mean/std normalisation is applied here although the model
# is initialised from ImageNet-pretrained weights — confirm this is intentional.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor()
])

# Load the three on-disk splits (one sub-folder per class inside each split).
train_dataset = datasets.ImageFolder(os.path.join(data_root, 'train'), transform=transform)
val_dataset = datasets.ImageFolder(os.path.join(data_root, 'val'), transform=transform)
test_dataset = datasets.ImageFolder(os.path.join(data_root, 'test'), transform=transform)

# Each ImageFolder derives its label indices from its own sub-folder names, so the
# splits may only be pooled when they all agree on the class -> index mapping;
# otherwise the same integer label would mean different classes in different splits.
if not (train_dataset.class_to_idx == val_dataset.class_to_idx == test_dataset.class_to_idx):
    raise ValueError(
        "train/val/test folders do not define the same classes; "
        "their labels cannot be combined for K-fold cross-validation"
    )

# Pool all splits; the K-fold procedure re-partitions the pooled data itself.
full_dataset = ConcatDataset([train_dataset, val_dataset, test_dataset])

# Read the labels from the index ImageFolder builds at construction time (`.targets`)
# instead of iterating the datasets, which would decode and resize every image just
# to discard it. The order (train, val, test) matches the ConcatDataset above.
targets = np.concatenate([
    np.asarray(split.targets) for split in (train_dataset, val_dataset, test_dataset)
])
class_names = train_dataset.classes

100%|██████████| 262238/262238 [04:02<00:00, 1083.36it/s]
100%|██████████| 56189/56189 [00:52<00:00, 1076.93it/s]
100%|██████████| 56204/56204 [00:51<00:00, 1090.77it/s]


In [None]:
# Early stopping

class EarlyStopping:
    """Track validation loss and flag when it has stopped improving.

    Call the instance once per epoch with the current validation loss.
    A loss strictly lower than the best one seen so far resets the counter;
    any other value increments it. Once the counter reaches `patience`,
    `early_stop` is set to True (and is never reset afterwards).
    """

    def __init__(self, patience=5):
        self.patience = patience
        self.best_loss = float('inf')
        self.counter = 0
        self.early_stop = False

    def __call__(self, val_loss):
        improved = val_loss < self.best_loss
        if improved:
            # New best: remember it and start counting from zero again.
            self.best_loss, self.counter = val_loss, 0
            return
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True

In [None]:
# MobileNetV2 Set Up

def build_model(num_classes, dropout, unfreeze_layers):
    """Create a MobileNetV2 classifier for transfer learning.

    The network is loaded with the default pretrained weights and all of its
    parameters are frozen. The last `unfreeze_layers` children of
    `model.features` are then made trainable again, and the classifier is
    replaced by a fresh Dropout + Linear head.

    Args:
        num_classes: number of output units of the new linear layer.
        dropout: dropout probability used in the new head.
        unfreeze_layers: how many trailing feature blocks to fine-tune.

    Returns:
        The model, moved to the module-level `device`.
    """
    model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)

    # Start from a fully frozen network.
    for weight in model.parameters():
        weight.requires_grad = False

    # Walk the feature blocks from the output side and re-enable gradients
    # for the first `unfreeze_layers` of them.
    for position, block in enumerate(reversed(list(model.features.children()))):
        if position >= unfreeze_layers:
            break
        for weight in block.parameters():
            weight.requires_grad = True

    # Replace the original classifier with a head sized for this task.
    head = nn.Sequential(
        nn.Dropout(p=dropout),
        nn.Linear(model.last_channel, num_classes)
    )
    model.classifier = head
    return model.to(device)


In [None]:
# Training & Validation graphs

def plot_metrics(train_losses, val_losses, train_accuracies, val_accuracies, fold):
    """Save the per-epoch loss and accuracy curves of one fold as a PNG.

    Draws a 1x2 figure (loss on the left, accuracy on the right), each panel
    showing the training and validation curve, writes it to
    ``fold_<fold+1>_metrics.png`` in the working directory and closes it.

    Args:
        train_losses, val_losses: per-epoch loss values.
        train_accuracies, val_accuracies: per-epoch accuracy values.
        fold: zero-based fold index (shown and saved as ``fold + 1``).
    """
    fold_no = fold + 1
    # (panel title suffix, (values, legend label) for each curve)
    panels = (
        ("Loss", (train_losses, "Train Loss"), (val_losses, "Val Loss")),
        ("Accuracy", (train_accuracies, "Train Acc"), (val_accuracies, "Val Acc")),
    )

    plt.figure(figsize=(12, 5))
    for slot, (metric, *curves) in enumerate(panels, start=1):
        plt.subplot(1, 2, slot)
        for values, label in curves:
            plt.plot(values, label=label)
        plt.title(f"Fold {fold_no} - {metric}")
        plt.legend()

    plt.tight_layout()
    plt.savefig(f"fold_{fold_no}_metrics.png")
    # Close the figure so figures do not accumulate across folds.
    plt.close()

In [None]:
# K-Fold Cross-Validation

# Upper bound on training epochs per fold.
# NOTE(review): as long as this is not larger than the StepLR step_size used below and
# the early-stopping patience, neither the LR decay nor early stopping can take effect.
num_epochs = 4
seed = 42

skf = StratifiedKFold(n_splits=k_folds, shuffle=True, random_state=seed)
# Pin the label set so the confusion matrix always has one row/column per class and
# the report's target_names always line up, even if a class is missing in a fold.
label_ids = list(range(len(class_names)))
metrics_all = []

for fold, (train_idx, val_idx) in enumerate(skf.split(np.zeros(len(targets)), targets)):
    print(f"\n--- Fold {fold+1}/{k_folds} ---")

    # Seed torch per fold so head initialisation, dropout and batch shuffling are
    # repeatable between runs (as far as the compute backend is deterministic).
    torch.manual_seed(seed + fold)

    train_subset = Subset(full_dataset, train_idx)
    val_subset = Subset(full_dataset, val_idx)

    train_loader = DataLoader(train_subset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_subset, batch_size=batch_size)

    # Fresh model, optimizer and schedulers for every fold; only parameters that
    # build_model left trainable are handed to the optimizer.
    model = build_model(len(class_names), dropout, unfreeze_layers)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=lr)
    scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
    early_stopper = EarlyStopping(patience=patience)

    train_losses, val_losses = [], []
    train_accuracies, val_accuracies = [], []

    for epoch in range(num_epochs):
        # ---- Training ----
        model.train()
        total, correct, loss_total = 0, 0, 0.0
        for x, y in tqdm(train_loader):
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            outputs = model(x)
            loss = criterion(outputs, y)
            loss.backward()
            optimizer.step()
            # criterion returns the batch mean; weight it by the batch size so the
            # epoch loss is a true per-sample mean even when the last batch is smaller.
            loss_total += loss.item() * y.size(0)
            _, preds = outputs.max(1)
            correct += preds.eq(y).sum().item()
            total += y.size(0)
        train_losses.append(loss_total / total)
        train_accuracies.append(correct / total)

        # ---- Validation ----
        model.eval()
        total, correct, val_loss_total = 0, 0, 0.0
        # Reset every epoch: the metrics below describe the last epoch that was run.
        y_true, y_pred = [], []
        with torch.no_grad():
            for x, y in tqdm(val_loader):
                x, y = x.to(device), y.to(device)
                outputs = model(x)
                loss = criterion(outputs, y)
                val_loss_total += loss.item() * y.size(0)
                _, preds = outputs.max(1)
                y_true.extend(y.cpu().numpy())
                y_pred.extend(preds.cpu().numpy())
                correct += preds.eq(y).sum().item()
                total += y.size(0)
        val_losses.append(val_loss_total / total)
        val_accuracies.append(correct / total)

        scheduler.step()
        early_stopper(val_losses[-1])
        if early_stopper.early_stop:
            print("Early stopping triggered.")
            break

    # ---- Evaluation metrics on this fold's validation split (last epoch) ----
    acc = accuracy_score(y_true, y_pred)
    prec = precision_score(y_true, y_pred, average='macro', zero_division=0)
    rec = recall_score(y_true, y_pred, average='macro', zero_division=0)
    f1 = f1_score(y_true, y_pred, average='macro', zero_division=0)
    mcc = matthews_corrcoef(y_true, y_pred)
    cm = confusion_matrix(y_true, y_pred, labels=label_ids)

    # Macro-averaged one-vs-rest specificity, TN / (TN + FP), from the confusion matrix
    # (rows = true class, columns = predicted class).
    tp = np.diag(cm)
    fp = cm.sum(axis=0) - tp
    fn = cm.sum(axis=1) - tp
    tn = cm.sum() - (tp + fp + fn)
    spec = np.mean(tn / (tn + fp + 1e-6))  # epsilon guards against division by zero

    print(f"Accuracy: {acc:.4f}, Precision: {prec:.4f}, Recall: {rec:.4f}, F1: {f1:.4f}, MCC: {mcc:.4f}, Specificity: {spec:.4f}")
    print("Classification Report:\n", classification_report(y_true, y_pred, labels=label_ids, target_names=class_names))
    print("Confusion Matrix:\n", cm)

    metrics_all.append([acc, prec, rec, f1, mcc, spec])
    plot_metrics(train_losses, val_losses, train_accuracies, val_accuracies, fold)

Downloading: "https://download.pytorch.org/models/mobilenet_v2-7ebf99e0.pth" to /root/.cache/torch/hub/checkpoints/mobilenet_v2-7ebf99e0.pth



--- Fold 1/5 ---


100%|██████████| 13.6M/13.6M [00:00<00:00, 156MB/s]
100%|██████████| 9366/9366 [10:20<00:00, 15.10it/s]
100%|██████████| 2342/2342 [02:07<00:00, 18.41it/s]
100%|██████████| 9366/9366 [10:13<00:00, 15.27it/s]
100%|██████████| 2342/2342 [02:11<00:00, 17.81it/s]
100%|██████████| 9366/9366 [09:59<00:00, 15.64it/s]
100%|██████████| 2342/2342 [02:13<00:00, 17.52it/s]
100%|██████████| 9366/9366 [09:49<00:00, 15.90it/s]
100%|██████████| 2342/2342 [02:03<00:00, 18.98it/s]


Accuracy: 0.9849, Precision: 0.9602, Recall: 0.9619, F1: 0.9606, MCC: 0.9795, Specificity: 0.9980
Classification Report:
               precision    recall  f1-score   support

          CT       0.99      0.98      0.98     25037
          DM       0.92      0.97      0.94       436
          IC       0.99      1.00      0.99     12677
          LI       1.00      0.90      0.95       642
          MP       0.94      0.91      0.92      2135
          NC       1.00      0.99      1.00     25933
          PL       0.92      0.94      0.93       200
          PN       0.90      0.98      0.94      3882
          WM       0.99      1.00      0.99      3985

    accuracy                           0.98     74927
   macro avg       0.96      0.96      0.96     74927
weighted avg       0.99      0.98      0.98     74927

Confusion Matrix:
 [[24575     2    59     0   121     4     3   239    34]
 [    2   421     0     0     0     6     2     5     0]
 [   37     0 12618     0     1     2   

100%|██████████| 9366/9366 [09:30<00:00, 16.42it/s]
100%|██████████| 2342/2342 [02:13<00:00, 17.60it/s]
100%|██████████| 9366/9366 [09:34<00:00, 16.31it/s]
100%|██████████| 2342/2342 [02:07<00:00, 18.43it/s]
100%|██████████| 9366/9366 [09:42<00:00, 16.07it/s]
100%|██████████| 2342/2342 [02:09<00:00, 18.04it/s]
100%|██████████| 9366/9366 [09:30<00:00, 16.43it/s]
100%|██████████| 2342/2342 [02:01<00:00, 19.30it/s]


Accuracy: 0.9886, Precision: 0.9624, Recall: 0.9673, F1: 0.9644, MCC: 0.9845, Specificity: 0.9984
Classification Report:
               precision    recall  f1-score   support

          CT       0.99      0.99      0.99     25037
          DM       0.87      0.98      0.92       436
          IC       0.99      0.99      0.99     12677
          LI       0.99      0.99      0.99       642
          MP       0.96      0.92      0.94      2135
          NC       1.00      1.00      1.00     25933
          PL       0.93      0.88      0.91       199
          PN       0.95      0.97      0.96      3882
          WM       0.98      1.00      0.99      3985

    accuracy                           0.99     74926
   macro avg       0.96      0.97      0.96     74926
weighted avg       0.99      0.99      0.99     74926

Confusion Matrix:
 [[24696    13    60     5    63    15     4   137    44]
 [    1   426     0     0     0     7     0     2     0]
 [   31     0 12605     3     7     2   

100%|██████████| 9366/9366 [09:34<00:00, 16.30it/s]
100%|██████████| 2342/2342 [02:01<00:00, 19.33it/s]
100%|██████████| 9366/9366 [09:25<00:00, 16.56it/s]
100%|██████████| 2342/2342 [02:03<00:00, 18.92it/s]
100%|██████████| 9366/9366 [09:42<00:00, 16.07it/s]
100%|██████████| 2342/2342 [02:03<00:00, 18.97it/s]
100%|██████████| 9366/9366 [09:27<00:00, 16.52it/s]
100%|██████████| 2342/2342 [02:03<00:00, 18.95it/s]


Accuracy: 0.9861, Precision: 0.9604, Recall: 0.9729, F1: 0.9663, MCC: 0.9811, Specificity: 0.9981
Classification Report:
               precision    recall  f1-score   support

          CT       0.99      0.98      0.98     25037
          DM       0.97      0.96      0.96       436
          IC       0.99      0.99      0.99     12677
          LI       0.91      1.00      0.95       642
          MP       0.95      0.90      0.92      2135
          NC       1.00      1.00      1.00     25932
          PL       0.94      0.95      0.95       200
          PN       0.95      0.97      0.96      3882
          WM       0.96      1.00      0.98      3985

    accuracy                           0.99     74926
   macro avg       0.96      0.97      0.97     74926
weighted avg       0.99      0.99      0.99     74926

Confusion Matrix:
 [[24561     1    79    30    98    11     1   142   114]
 [    1   418     1     0     0     5     5     6     0]
 [   46     0 12545    32    12     1   

100%|██████████| 9366/9366 [09:25<00:00, 16.55it/s]
100%|██████████| 2342/2342 [02:04<00:00, 18.88it/s]
100%|██████████| 9366/9366 [09:32<00:00, 16.37it/s]
100%|██████████| 2342/2342 [02:07<00:00, 18.44it/s]
100%|██████████| 9366/9366 [09:36<00:00, 16.24it/s]
100%|██████████| 2342/2342 [02:09<00:00, 18.13it/s]
100%|██████████| 9366/9366 [09:37<00:00, 16.20it/s]
100%|██████████| 2342/2342 [02:07<00:00, 18.36it/s]


Accuracy: 0.9878, Precision: 0.9618, Recall: 0.9669, F1: 0.9643, MCC: 0.9833, Specificity: 0.9983
Classification Report:
               precision    recall  f1-score   support

          CT       0.99      0.99      0.99     25037
          DM       0.94      0.97      0.96       436
          IC       0.99      1.00      0.99     12677
          LI       0.97      0.98      0.98       642
          MP       0.90      0.92      0.91      2134
          NC       1.00      1.00      1.00     25933
          PL       0.90      0.90      0.90       200
          PN       0.98      0.95      0.97      3882
          WM       0.99      1.00      0.99      3985

    accuracy                           0.99     74926
   macro avg       0.96      0.97      0.96     74926
weighted avg       0.99      0.99      0.99     74926

Confusion Matrix:
 [[24668     0   112     7   184     7     1    26    32]
 [    1   425     1     0     0     9     0     0     0]
 [   28     0 12631     2     1     0   

100%|██████████| 9366/9366 [09:20<00:00, 16.71it/s]
100%|██████████| 2342/2342 [01:59<00:00, 19.59it/s]
100%|██████████| 9366/9366 [09:14<00:00, 16.88it/s]
100%|██████████| 2342/2342 [02:02<00:00, 19.18it/s]
100%|██████████| 9366/9366 [09:22<00:00, 16.64it/s]
100%|██████████| 2342/2342 [02:02<00:00, 19.15it/s]
100%|██████████| 9366/9366 [09:25<00:00, 16.56it/s]
100%|██████████| 2342/2342 [02:01<00:00, 19.24it/s]


Accuracy: 0.9885, Precision: 0.9690, Recall: 0.9701, F1: 0.9695, MCC: 0.9843, Specificity: 0.9984
Classification Report:
               precision    recall  f1-score   support

          CT       0.99      0.98      0.99     25037
          DM       0.96      0.98      0.97       435
          IC       0.99      0.99      0.99     12678
          LI       0.99      0.99      0.99       641
          MP       0.94      0.90      0.92      2135
          NC       1.00      1.00      1.00     25933
          PL       0.91      0.91      0.91       200
          PN       0.96      0.98      0.97      3882
          WM       0.98      1.00      0.99      3985

    accuracy                           0.99     74926
   macro avg       0.97      0.97      0.97     74926
weighted avg       0.99      0.99      0.99     74926

Confusion Matrix:
 [[24659     3    90     1   103     8     0   126    47]
 [    2   425     0     0     0     5     0     3     0]
 [   31     0 12613     0     0     1   

In [None]:
# Average metrics

# One row per fold, columns in the order the K-fold loop appended them.
metrics_all = np.array(metrics_all)
metric_labels = ("Accuracy:", "Precision:", "Recall:", "F1 Score:", "MCC:", "Specificity:")

print("\n===== AVERAGE METRICS OVER ALL FOLDS =====")
for column, label in enumerate(metric_labels):
    # Left-align the label in a 14-character field, then the mean across folds.
    print(f"{label:<14}{metrics_all[:, column].mean():.4f}")



===== AVERAGE METRICS OVER ALL FOLDS =====
Accuracy:     0.9872
Precision:    0.9628
Recall:       0.9678
F1 Score:     0.9650
MCC:          0.9825
Specificity:  0.9982
