# Brain Tumor Multiclass Classifier (2D MRI)

This notebook presents a deep learning pipeline for classifying 2D brain MRI scans into four categories: **Glioma**, **Meningioma**, **Pituitary**, and **No Tumor**.

The model is built using **PyTorch** and leverages a pre-trained **DenseNet121** backbone with a custom classifier head. Training is performed in two phases: a warm-up stage where only the head is trained, followed by fine-tuning of the final dense block (`denseblock4`) and the final normalization layer (`norm5`).

The dataset used was originally published on [Kaggle](#) and is re-hosted on the **Hugging Face Hub** to simplify access and integration.

This project is part of a personal initiative to explore medical imaging with computer vision, with emphasis on **transfer learning**, **training strategy**, and **model generalization**.

> **Disclaimer:** This project is for **educational and research purposes only**. It is *not* intended for medical or clinical use.


## 1. Setup & Dependencies

This section installs and imports the required libraries for data handling, model development, training, and evaluation.


### 1.1 Install Dependencies (Colab Only)

If you're using Google Colab, install the required packages using the command below.


In [None]:
# Install Hugging Face datasets library (for Colab users)
!pip install -U datasets fsspec

Collecting fsspec
  Using cached fsspec-2025.5.1-py3-none-any.whl.metadata (11 kB)


### 1.2 Import Libraries

We import all necessary libraries including PyTorch, Albumentations, scikit-learn, and Hugging Face datasets.


In [None]:
# PyTorch core modules
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader, Dataset

# Torchvision for pre-trained model architectures
import torchvision.models as models

# Albumentations for data augmentation
import albumentations as A
from albumentations.pytorch import ToTensorV2

# Scikit-learn 
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

# Standard libraries
import numpy as np
from PIL import Image
import cv2
from tqdm import tqdm
import copy

# Hugging Face datasets
from datasets import load_dataset

## 2. Load and Prepare Raw Dataset

We load a multiclass brain MRI dataset from the Hugging Face Hub using the `load_dataset` function. The dataset contains four classes (three tumor types plus a no-tumor class) and was originally sourced from Kaggle. Downloaded files are cached locally, so later calls reuse them.

The training split is then divided into training and validation sets using **stratified sampling**, which keeps each class's proportion the same in both subsets.


### 2.1 Load Dataset from Hugging Face

We load the dataset directly using the `datasets` library. The dataset contains labeled 2D brain MRI scans across four classes: **glioma**, **meningioma**, **pituitary**, and **no tumor**.


In [None]:
# Load brain tumor dataset from Hugging Face (auto-cached locally)
ds = load_dataset("Cayanaaa/BrainTumorDatasets", name="multiclass")
ds


DatasetDict({
    train: Dataset({
        features: ['image', 'label'],
        num_rows: 5712
    })
    test: Dataset({
        features: ['image', 'label'],
        num_rows: 1311
    })
})

### 2.2 View Class Label Mapping

This prints the class names in index order: the position of each name in the list is the integer label the dataset uses for that class.


In [None]:
# Display class labels and their corresponding integer indices
print(ds['train'].features['label'].names)

['glioma', 'meningioma', 'notumor', 'pituitary']
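The printed list is enough to build explicit lookup tables between names and indices, which are handy when decoding predictions later. A minimal sketch, with the names copied from the output above (the dictionary names `id2label` and `label2id` are our own choice; the `ClassLabel` feature also provides `int2str` and `str2int` for the same purpose):

```python
# Class names in index order, as printed by ds['train'].features['label'].names
class_names = ['glioma', 'meningioma', 'notumor', 'pituitary']

# Integer index -> class name, and the reverse lookup
id2label = {idx: name for idx, name in enumerate(class_names)}
label2id = {name: idx for idx, name in id2label.items()}

print(id2label)               # {0: 'glioma', 1: 'meningioma', 2: 'notumor', 3: 'pituitary'}
print(label2id['pituitary'])  # 3
```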


### 2.3 Extract Images and Labels from Dataset

We extract the raw image and label pairs from the dataset for further processing.


In [None]:
# Extract image-label pairs from the training split
train_data = ds['train']
images = train_data['image']
labels = train_data['label']

### 2.4 Stratified Train-Validation Split

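A stratified split keeps the class proportions of the training and validation sets equal to those of the full training split, so the validation set is representative of the data the model learns from. Stratification does not correct class imbalance itself; that is handled with class weights in Section 4.3.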


In [None]:
# Split data into training and validation sets while preserving class distribution
train_imgs, val_imgs, train_labels, val_labels = train_test_split(images, labels,
                                                                  test_size=0.2,
                                                                  stratify=labels,
                                                                  random_state=42
                                                                  )
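To see what `stratify` does, the sketch below splits a hypothetical, deliberately imbalanced label array (not the real dataset) with the same `test_size` and `random_state`, then compares per-class proportions before and after the split.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical imbalanced labels: 400, 300, 200, 100 samples for classes 0..3
y = np.repeat([0, 1, 2, 3], [400, 300, 200, 100])
X = np.arange(len(y))  # stand-in for the images

X_tr, X_val, y_tr, y_val = train_test_split(X, y, test_size=0.2,
                                            stratify=y, random_state=42)

def proportions(labels):
    """Fraction of samples belonging to each of the 4 classes."""
    return np.bincount(labels, minlength=4) / len(labels)

print("full :", proportions(y))
print("train:", proportions(y_tr))
print("val  :", proportions(y_val))
```

All three rows show the same 40/30/20/10 split; without `stratify`, the validation proportions would drift with the random draw.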

## 3. Dataset Preparation

In this section, we prepare the image dataset by applying preprocessing and augmentation techniques, defining a custom PyTorch `Dataset` class, and creating `DataLoaders` for both training and validation phases.


### 3.1 Define Transformation Pipelines

We define image preprocessing and augmentation pipelines using **Albumentations** to improve generalization and performance.

- The **training pipeline** includes resizing, flipping, brightness/contrast jitter, grid distortion, Gaussian noise, and normalization.
- The **validation pipeline** includes only resizing and normalization to ensure consistent evaluation.


In [None]:
# Define preprocessing & augmentation for training set
train_T = A.Compose([
    A.Resize(224, 224), # Resize to model input size
    A.HorizontalFlip(p=0.5), # Random horizontal flip
    A.VerticalFlip(p=0.5),  # Random vertical flip
    A.RandomBrightnessContrast(p=0.2),  # Slight brightness/contrast variation
    A.GridDistortion(num_steps=5, distort_limit=0.03, p=0.2),  # Grid-based distortion
    A.GaussNoise(p=0.1), # Add Gaussian noise
    A.Normalize(mean=[0.485, 0.456, 0.406], # Normalize using ImageNet stats
                std=[0.229, 0.224, 0.225]),
    ToTensorV2() # Convert to PyTorch tensor
])

# Define preprocessing for validation set (no augmentation)
val_T = A.Compose([
    A.Resize(224, 224),
    A.Normalize(mean=[0.485, 0.456, 0.406],
                std=[0.229, 0.224, 0.225]),
    ToTensorV2()
])
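`A.Normalize` with its default `max_pixel_value=255.0` first scales pixels to [0, 1] and then standardizes each channel with the ImageNet mean and standard deviation; `ToTensorV2` then moves the channel axis first (HWC to CHW). A NumPy-only sketch of the same arithmetic on a hypothetical 2×2 image:

```python
import numpy as np

mean = np.array([0.485, 0.456, 0.406], dtype=np.float32)
std = np.array([0.229, 0.224, 0.225], dtype=np.float32)

# Hypothetical 2x2 RGB image (HWC, uint8): one white pixel, three black pixels
img = np.zeros((2, 2, 3), dtype=np.uint8)
img[0, 0] = 255

# Scale to [0, 1], then standardize per channel (broadcasts over the last axis)
normalized = (img.astype(np.float32) / 255.0 - mean) / std

# Channel-first layout, as expected by PyTorch convolutions
chw = np.transpose(normalized, (2, 0, 1))

print(chw.shape)     # (3, 2, 2)
print(chw[:, 0, 0])  # white pixel: (1 - mean) / std per channel
print(chw[:, 1, 1])  # black pixel: -mean / std per channel
```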

### 3.2 Define Custom Dataset Class

We define a custom PyTorch `Dataset` class to:

- Apply the appropriate transformations.
- Return each image and its label in tensor format.


In [None]:
# Custom Dataset class to load image-label pairs and apply transforms
class LoadDataset(Dataset):
  def __init__(self, images, labels, transform=None):
    self.images = images
    self.labels = labels
    self.transform = transform

  def __len__(self):
    return len(self.images)

  def __getitem__(self, idx):
    img = self.images[idx]
    img = img.convert('RGB') # Ensure image is in RGB format
    img = np.array(img)

    label = self.labels[idx]

    if self.transform:
      img = self.transform(image=img)['image']

    return img, torch.tensor(label, dtype=torch.long)
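The `img.convert('RGB')` call matters because MRI slices are often stored as single-channel grayscale, while DenseNet121 expects three input channels. Converting from mode `'L'` to `'RGB'` copies the gray value into all three channels. A small check on a synthetic image (not taken from the dataset):

```python
import numpy as np
from PIL import Image

# Synthetic 4x5 grayscale image; a 2D uint8 array is loaded as mode 'L'
gray = Image.fromarray(np.arange(20, dtype=np.uint8).reshape(4, 5))

# Same conversion as in LoadDataset.__getitem__
rgb = np.array(gray.convert('RGB'))

print(gray.mode, np.array(gray).shape)  # L (4, 5)
print(rgb.shape)                        # (4, 5, 3)
```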

### 3.3 Create Dataset & DataLoader

We wrap the image-label pairs using our custom `Dataset` class, and prepare `DataLoaders` to efficiently feed data during training and evaluation.


In [None]:
# Wrap image and label arrays into Dataset objects
train_dataset = LoadDataset(train_imgs, train_labels, train_T)
val_dataset = LoadDataset(val_imgs, val_labels, val_T)

# Create DataLoaders for batching and shuffling
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True) # shuffle for training
val_loader = DataLoader(val_dataset, batch_size=64, shuffle=False) # no shuffle for validation
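The number of iterations per epoch follows from the split sizes. `train_test_split` rounds the test share up, so the 5,712 training images become 4,569 for training and 1,143 for validation; with `batch_size=64` and the default `drop_last=False`, that gives 72 and 18 batches, matching the `72/72` and `18/18` counts in the training progress bars.

```python
import math

n_total = 5712     # rows in the Hugging Face 'train' split
test_size = 0.2
batch_size = 64

# train_test_split rounds the test share up and gives the remainder to training
n_val = math.ceil(test_size * n_total)
n_train = n_total - n_val

# DataLoader with drop_last=False keeps the final partial batch
train_batches = math.ceil(n_train / batch_size)
val_batches = math.ceil(n_val / batch_size)

print(n_train, n_val)              # 4569 1143
print(train_batches, val_batches)  # 72 18
```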

## 4. Model, Optimizer, and Training Setup

We adopt a **transfer learning** approach using a pre-trained **DenseNet121** model. To preserve the visual features learned from ImageNet, the entire backbone is **frozen** and we train **only the classifier head**. This initial setup focuses on **feature extraction**; the last dense block is unfrozen for fine-tuning in a later stage (Section 6).


### 4.1 Load Pre-trained Model

We load **DenseNet121** with ImageNet weights to leverage powerful low-level feature extraction learned from large-scale natural images.


In [None]:
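# Load DenseNet121 pre-trained on ImageNet (torchvision >= 0.13 `weights` API; replaces the deprecated `pretrained=True`)
model = models.densenet121(weights=models.DenseNet121_Weights.IMAGENET1K_V1)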

# Freeze all layers in the feature extractor to retain pre-trained representations
for param in model.parameters():
  param.requires_grad = False

# Replace the classifier head to match the number of output classes (4)
model.classifier = nn.Linear(model.classifier.in_features, 4)

Downloading: "https://download.pytorch.org/models/densenet121-a639ec97.pth" to /root/.cache/torch/hub/checkpoints/densenet121-a639ec97.pth
100%|██████████| 30.8M/30.8M [00:00<00:00, 111MB/s]


### 4.2 Define Optimizer, Scheduler, and Device

We use the Adam optimizer (learning rate `1e-3`) to update only the classifier head. A `ReduceLROnPlateau` scheduler multiplies the learning rate by 0.1 once the validation loss has failed to improve for more than `patience=2` consecutive epochs. The GPU is used if available.


In [None]:
# Configure optimizer to update only the classifier head
early_optimizer = torch.optim.Adam(model.classifier.parameters(), lr=1e-3)

# Reduce the LR by 10x when validation loss has not improved for more than 2 epochs
# (`verbose` is omitted because it is deprecated for LR schedulers in recent PyTorch releases)
scheduler_early = ReduceLROnPlateau(early_optimizer, mode='min', factor=0.1, patience=2)

# Automatically use GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Move model to the selected device
model.to(device)
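To make the scheduler settings concrete, the sketch below feeds a hypothetical sequence of validation losses to a `ReduceLROnPlateau` with the same `factor=0.1, patience=2`, using a dummy parameter in place of the model. The learning rate drops only on the third consecutive epoch without improvement, because the scheduler waits until the bad-epoch count exceeds `patience`.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

# A single dummy parameter stands in for the classifier head
param = torch.nn.Parameter(torch.zeros(1))
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=2)

# Hypothetical validation losses: improves twice, then plateaus
val_losses = [1.00, 0.90, 0.95, 0.95, 0.95]
lrs = []
for loss in val_losses:
    scheduler.step(loss)
    lrs.append(optimizer.param_groups[0]['lr'])

# LR stays at 1e-3 for the first four epochs, then drops to about 1e-4
print(lrs)
```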



### 4.3 Define Weighted Loss Function

To address class imbalance in the training data, we compute class weights and apply them to the cross-entropy loss function.
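With `class_weight='balanced'`, scikit-learn assigns each class the weight n_samples / (n_classes × count_c), so rarer classes get proportionally larger weights. A sketch on hypothetical label counts (not the real dataset's) that checks the formula against `compute_class_weight`:

```python
import numpy as np
from sklearn.utils.class_weight import compute_class_weight

# Hypothetical imbalanced labels: 500, 250, 125, 125 samples for classes 0..3
y = np.repeat([0, 1, 2, 3], [500, 250, 125, 125])

classes = np.unique(y)
weights = compute_class_weight(class_weight='balanced', classes=classes, y=y)

# Same result by hand: n_samples / (n_classes * per-class count)
manual = len(y) / (len(classes) * np.bincount(y))

print(weights)
```

Classes 2 and 3, each with a quarter of class 0's count, receive four times its weight (2.0 versus 0.5).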


In [None]:
# Compute class weights to handle imbalance and reduce bias toward frequent classes
class_weight = compute_class_weight(
    class_weight='balanced',
    classes=np.unique(train_labels),
    y=train_labels
)
class_weight = torch.tensor(class_weight, dtype=torch.float32).to(device)

# Define weighted cross-entropy loss
crieterion = nn.CrossEntropyLoss(weight=class_weight)
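One detail worth knowing: with `weight` set and the default `reduction='mean'`, `nn.CrossEntropyLoss` divides by the sum of the weights of the targets in the batch, not by the batch size. The sketch below checks this against a manual computation on random logits and hypothetical weights (variable names are prefixed with `toy_` so they do not overwrite the notebook's `class_weight`).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
toy_logits = torch.randn(6, 4)                    # 6 samples, 4 classes
toy_targets = torch.tensor([0, 1, 2, 3, 3, 0])
toy_weight = torch.tensor([0.5, 1.0, 2.0, 2.0])

toy_loss = nn.CrossEntropyLoss(weight=toy_weight)(toy_logits, toy_targets)

# Manual weighted mean: per-sample negative log-likelihood scaled by its class
# weight, normalized by the total weight of the targets present in the batch
nll = -F.log_softmax(toy_logits, dim=1)[torch.arange(6), toy_targets]
w = toy_weight[toy_targets]
manual = (w * nll).sum() / w.sum()

print(toy_loss.item(), manual.item())
```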

### 4.4 Define Early Stopping

We implement a custom early stopping mechanism that halts training once the monitored metric (here, validation loss) has failed to improve by more than `delta` for `patience` consecutive epochs.


In [None]:
# Custom early stopping class to monitor validation performance
# Stops training if no improvement is observed over 'patience' epochs
class EarlyStopping:
    def __init__(self, monitor='val_loss', mode='min', patience=3, delta=0.0, verbose=True):
        """
        Args:
            monitor (str): Metric to monitor ('val_loss' or 'val_acc')
            mode (str): 'min' → lower is better, 'max' → higher is better
            patience (int): # of epochs with no improvement before stopping
            delta (float): Minimum change to qualify as improvement
            verbose (bool): Print status each epoch if True
        """
        self.monitor = monitor
        self.mode = mode
        self.patience = patience
        self.delta = delta
        self.verbose = verbose

        self.best_score = None
        self.counter = 0
        self.early_stop = False

        # Set comparison function and initial best value
        if self.mode == 'min':
            self.monitor_op = lambda current, best: current < best - self.delta
            self.best_score = np.inf
        elif self.mode == 'max':
            self.monitor_op = lambda current, best: current > best + self.delta
            self.best_score = -np.inf
        else:
            raise ValueError("mode must be 'min' or 'max'")

    def __call__(self, current_score):
        # best_score starts at +inf ('min') or -inf ('max') in __init__,
        # so the first call always counts as an improvement
        if self.monitor_op(current_score, self.best_score):
            self.best_score = current_score
            self.counter = 0
            if self.verbose:
                print(f"[EarlyStopping] Improved {self.monitor}: {self.best_score:.4f}")
        else:
            self.counter += 1
            if self.verbose:
                print(f"[EarlyStopping] No improvement in {self.monitor} for {self.counter}/{self.patience} epochs.")
            # Stop training if performance has not improved for 'patience' epochs
            if self.counter >= self.patience:
                if self.verbose:
                    print(f"[EarlyStopping] Stopping training. Best {self.monitor}: {self.best_score:.4f}")
                self.early_stop = True

In [None]:
# Create an EarlyStopping instance to monitor validation loss
early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=3, verbose=True)
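To see how `patience=3` plays out, the sketch below replays the validation losses of warm-up epochs 22 to 26 (copied from the Section 5 training log) through a stripped-down version of the same counter logic (`mode='min'`, `delta=0`). It is a simplified re-implementation for illustration, not the class above.

```python
# Validation losses from warm-up epochs 22-26 (copied from the training log)
val_losses = {22: 0.2684, 23: 0.2611, 24: 0.2862, 25: 0.2645, 26: 0.2660}

patience, best, counter = 3, float('inf'), 0
stopped_at = None

for epoch, loss in val_losses.items():
    if loss < best:      # mode='min', delta=0: any decrease counts as improvement
        best, counter = loss, 0
    else:
        counter += 1
        if counter >= patience:
            stopped_at = epoch
            break

print(f"Stopped at epoch {stopped_at}, best val_loss {best:.4f}")  # Stopped at epoch 26, best val_loss 0.2611
```

Epoch 25's loss of 0.2645 does not reset the counter: it is compared with the best value so far (0.2611), not with the previous epoch's 0.2862.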

## 5. Train Classifier Head (Warm-up Phase)

In this phase, we train only the classifier head (a single fully connected layer) while keeping the backbone frozen. This **warm-up strategy** lets the new head adapt to the brain MRI data without changing the backbone weights learned from ImageNet.

The goal is to let the head specialize on our dataset before the last dense block is unfrozen and fine-tuned in Section 6.


In [None]:
# Save initial model weights and set best validation loss to infinity
best_model_wts = copy.deepcopy(model.state_dict())
best_val_loss = np.inf

num_epoch = 100

for epoch in range(num_epoch):
  print("-" * 50)
  print(f"Epoch {epoch+1}/{num_epoch}")
  print("-" * 50)

  # --- training Phase ---
  model.train()
  train_loss, correct, total = 0.0, 0, 0

  for images, labels in tqdm(train_loader, desc="Training"):
    images, labels = images.to(device), labels.to(device)

    early_optimizer.zero_grad()
    outputs = model(images)
    loss = crieterion(outputs, labels)
    loss.backward()
    early_optimizer.step()

    train_loss += loss.item() * images.size(0)
    _, predicted = torch.max(outputs, dim=1)
    correct += (predicted == labels).sum().item()
    total += labels.size(0)

  avg_train_loss = train_loss / total
  train_acc = correct / total

  # --- Validation Phase ---
  model.eval()
  val_loss, correct, total = 0.0, 0, 0

  with torch.no_grad():
    for images, labels in tqdm(val_loader, desc="Validation"):
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = crieterion(outputs, labels)

        val_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs, dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

  avg_val_loss = val_loss / total
  val_acc = correct / total

  print(f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}")
  print(f"Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.2f}")

  # Step the learning rate scheduler and update early stopping
  scheduler_early.step(avg_val_loss)
  early_stopping(avg_val_loss)

  # Save model weights if validation loss improves
  if avg_val_loss < best_val_loss:
    best_val_loss = avg_val_loss
    best_model_wts = copy.deepcopy(model.state_dict())
    torch.save(model.state_dict(), 'mct_best_model.pth')
    print(f"[INFO]: Best model updated")

  # Stop training if early stopping is triggered
  if early_stopping.early_stop:
    print(f"[INFO]: Training Stopped by EarlyStopping")
    break

# Load best model weights after training
model.load_state_dict(best_model_wts)
print("[INFO]: Best Model Loaded")



Epoch 1/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:29<00:00,  2.47it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.85it/s]


Train Loss: 0.9025 | Train Acc: 0.67
Val Loss: 0.5652 | Val Acc: 0.83
[EarlyStopping] Improved val_loss: 0.5652
[INFO]: Best model updated
Epoch 2/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.59it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.85it/s]


Train Loss: 0.5926 | Train Acc: 0.80
Val Loss: 0.4685 | Val Acc: 0.84
[EarlyStopping] Improved val_loss: 0.4685
[INFO]: Best model updated
Epoch 3/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.58it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.08it/s]


Train Loss: 0.5114 | Train Acc: 0.82
Val Loss: 0.4032 | Val Acc: 0.86
[EarlyStopping] Improved val_loss: 0.4032
[INFO]: Best model updated
Epoch 4/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:28<00:00,  2.51it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.09it/s]


Train Loss: 0.4567 | Train Acc: 0.84
Val Loss: 0.3715 | Val Acc: 0.87
[EarlyStopping] Improved val_loss: 0.3715
[INFO]: Best model updated
Epoch 5/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:28<00:00,  2.56it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.94it/s]


Train Loss: 0.4312 | Train Acc: 0.86
Val Loss: 0.3478 | Val Acc: 0.88
[EarlyStopping] Improved val_loss: 0.3478
[INFO]: Best model updated
Epoch 6/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.58it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.83it/s]


Train Loss: 0.4306 | Train Acc: 0.85
Val Loss: 0.3342 | Val Acc: 0.88
[EarlyStopping] Improved val_loss: 0.3342
[INFO]: Best model updated
Epoch 7/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:28<00:00,  2.54it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.86it/s]


Train Loss: 0.4107 | Train Acc: 0.85
Val Loss: 0.3257 | Val Acc: 0.88
[EarlyStopping] Improved val_loss: 0.3257
[INFO]: Best model updated
Epoch 8/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:28<00:00,  2.56it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.06it/s]


Train Loss: 0.3889 | Train Acc: 0.86
Val Loss: 0.3175 | Val Acc: 0.89
[EarlyStopping] Improved val_loss: 0.3175
[INFO]: Best model updated
Epoch 9/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.59it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.10it/s]


Train Loss: 0.3896 | Train Acc: 0.86
Val Loss: 0.3050 | Val Acc: 0.89
[EarlyStopping] Improved val_loss: 0.3050
[INFO]: Best model updated
Epoch 10/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.63it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.85it/s]


Train Loss: 0.3717 | Train Acc: 0.87
Val Loss: 0.3130 | Val Acc: 0.88
[EarlyStopping] No improvement in val_loss for 1/3 epochs.
Epoch 11/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.57it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.94it/s]


Train Loss: 0.3667 | Train Acc: 0.87
Val Loss: 0.2993 | Val Acc: 0.90
[EarlyStopping] Improved val_loss: 0.2993
[INFO]: Best model updated
Epoch 12/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.58it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.12it/s]


Train Loss: 0.3517 | Train Acc: 0.88
Val Loss: 0.2937 | Val Acc: 0.89
[EarlyStopping] Improved val_loss: 0.2937
[INFO]: Best model updated
Epoch 13/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.58it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.10it/s]


Train Loss: 0.3416 | Train Acc: 0.88
Val Loss: 0.2886 | Val Acc: 0.90
[EarlyStopping] Improved val_loss: 0.2886
[INFO]: Best model updated
Epoch 14/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:26<00:00,  2.68it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.99it/s]


Train Loss: 0.3559 | Train Acc: 0.87
Val Loss: 0.2860 | Val Acc: 0.90
[EarlyStopping] Improved val_loss: 0.2860
[INFO]: Best model updated
Epoch 15/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.62it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.02it/s]


Train Loss: 0.3408 | Train Acc: 0.88
Val Loss: 0.2827 | Val Acc: 0.90
[EarlyStopping] Improved val_loss: 0.2827
[INFO]: Best model updated
Epoch 16/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.58it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.12it/s]


Train Loss: 0.3500 | Train Acc: 0.88
Val Loss: 0.2801 | Val Acc: 0.90
[EarlyStopping] Improved val_loss: 0.2801
[INFO]: Best model updated
Epoch 17/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.59it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.99it/s]


Train Loss: 0.3386 | Train Acc: 0.88
Val Loss: 0.2753 | Val Acc: 0.90
[EarlyStopping] Improved val_loss: 0.2753
[INFO]: Best model updated
Epoch 18/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.62it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.91it/s]


Train Loss: 0.3360 | Train Acc: 0.87
Val Loss: 0.2752 | Val Acc: 0.90
[EarlyStopping] Improved val_loss: 0.2752
[INFO]: Best model updated
Epoch 19/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.60it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.98it/s]


Train Loss: 0.3264 | Train Acc: 0.88
Val Loss: 0.2755 | Val Acc: 0.90
[EarlyStopping] No improvement in val_loss for 1/3 epochs.
Epoch 20/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:28<00:00,  2.56it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.07it/s]


Train Loss: 0.3331 | Train Acc: 0.88
Val Loss: 0.2717 | Val Acc: 0.90
[EarlyStopping] Improved val_loss: 0.2717
[INFO]: Best model updated
Epoch 21/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.58it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.93it/s]


Train Loss: 0.3282 | Train Acc: 0.88
Val Loss: 0.2743 | Val Acc: 0.90
[EarlyStopping] No improvement in val_loss for 1/3 epochs.
Epoch 22/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.59it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.90it/s]


Train Loss: 0.3295 | Train Acc: 0.88
Val Loss: 0.2684 | Val Acc: 0.91
[EarlyStopping] Improved val_loss: 0.2684
[INFO]: Best model updated
Epoch 23/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.58it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.95it/s]


Train Loss: 0.3173 | Train Acc: 0.88
Val Loss: 0.2611 | Val Acc: 0.90
[EarlyStopping] Improved val_loss: 0.2611
[INFO]: Best model updated
Epoch 24/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.59it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.10it/s]


Train Loss: 0.3173 | Train Acc: 0.88
Val Loss: 0.2862 | Val Acc: 0.90
[EarlyStopping] No improvement in val_loss for 1/3 epochs.
Epoch 25/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:28<00:00,  2.57it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.04it/s]


Train Loss: 0.3052 | Train Acc: 0.89
Val Loss: 0.2645 | Val Acc: 0.90
[EarlyStopping] No improvement in val_loss for 2/3 epochs.
Epoch 26/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:27<00:00,  2.61it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.87it/s]

Train Loss: 0.3027 | Train Acc: 0.89
Val Loss: 0.2660 | Val Acc: 0.91
[EarlyStopping] No improvement in val_loss for 3/3 epochs.
[EarlyStopping] Stopping training. Best val_loss: 0.2611
[INFO]: Training Stopped by EarlyStopping
[INFO]: Best Model Loaded
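One subtlety in the warm-up loop: `requires_grad=False` freezes the backbone's weights, but `model.train()` still lets its BatchNorm layers update their running mean and variance. The sketch below shows this on a standalone `nn.BatchNorm2d`, plus one common workaround: switching frozen BatchNorm modules back to eval mode after `model.train()`. The helper `freeze_bn_stats` is hypothetical and not used by this notebook; whether to keep the statistics fixed is a design choice.

```python
import torch
import torch.nn as nn

def freeze_bn_stats(model: nn.Module) -> None:
    """Hypothetical helper: put every BatchNorm2d whose affine parameters are all
    frozen into eval mode, so its running statistics stop updating."""
    for module in model.modules():
        if isinstance(module, nn.BatchNorm2d):
            if all(not p.requires_grad for p in module.parameters()):
                module.eval()

torch.manual_seed(0)
bn = nn.BatchNorm2d(3)
for p in bn.parameters():
    p.requires_grad = False          # frozen, like the DenseNet backbone

x = torch.randn(8, 3, 4, 4) + 5.0    # batch mean far from the initial running_mean (0)

bn.train()
bn(x)
print(bn.running_mean)               # moved toward the batch mean despite requires_grad=False

before = bn.running_mean.clone()
freeze_bn_stats(bn)                  # modules() includes bn itself
bn(x)
print(torch.equal(before, bn.running_mean))  # True: statistics no longer change
```

In a training loop, such a helper would be called right after `model.train()`.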





## 6. Fine-Tuning Setup

In this phase, we fine-tune the deeper parts of the model to better adapt to the brain tumor classification task. Instead of unfreezing the entire backbone, we selectively unfreeze the final convolutional block and normalization layer to balance adaptability and generalization.

Fine-tuning allows the model to refine high-level features learned from ImageNet in a domain-specific context.



### 6.1 Unfreeze Selected Layers

Here, we unfreeze the `denseblock4` and `norm5` layers of the backbone while keeping all earlier layers frozen. The classifier head trained in Section 5 stays trainable as well. This selective unfreezing is intended to limit overfitting and reduce the risk of catastrophic forgetting.



In [None]:
# Unfreeze the last DenseBlock, the final batch norm layer (norm5), and the classifier head;
# every other parameter stays frozen
trainable_keys = ('denseblock4', 'norm5', 'classifier')
for name, param in model.named_parameters():
  param.requires_grad = any(key in name for key in trainable_keys)

### 6.2 Fine-Tuning Optimizer & Callbacks

We define a new optimizer and learning rate scheduler for the fine-tuning phase. Only the parameters marked as trainable (i.e., from `denseblock4`, `norm5`, and the classifier head) are updated during this phase.

An `EarlyStopping` callback is also set up to prevent overfitting by halting training when the validation loss no longer improves.

> **Note**: We print the active learning rate after optimizer setup to verify that the new learning rate is properly configured.


In [None]:
# Define optimizer for fine-tuning (only trainable parameters)
ft_optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)

# Print current learning rate (for verification)
current_lr = ft_optimizer.param_groups[0]['lr']
print(f"Active learning rate: {current_lr}")

# Define scheduler for fine-tuning
scheduler_ft = ReduceLROnPlateau(ft_optimizer, mode='min', factor=0.1, patience=2)

# Define early stopping callback
early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=3, verbose=True)


## 7. Fine-Tune Backbone (Training Loop)

In this section, we perform **fine-tuning** by training the previously unfrozen layers (`denseblock4` and `norm5`) along with the classifier head. Unlike the warm-up phase, this step allows the model to adjust higher-level convolutional features to the specific patterns present in brain MRI images.

The training loop here follows the same structure as the warm-up phase (Section 5), with updated optimizer and scheduler settings defined in Section 6.2. We continue to monitor validation loss and apply **early stopping** to prevent overfitting.
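Since both phases use the same loop body, it could be factored into one function. The sketch below is a hypothetical refactor (`run_epoch` is our own name), shown on a toy linear model and random tensors so it runs standalone; all toy objects use a `toy_` prefix so they do not overwrite the notebook's `model`, `device`, or optimizers. Passing no optimizer runs an evaluation pass with gradients disabled.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

def run_epoch(model, loader, criterion, device, optimizer=None):
    """One pass over `loader`: trains if an optimizer is given, otherwise evaluates.
    Returns (average loss, accuracy), both weighted by batch size."""
    training = optimizer is not None
    model.train(training)
    total_loss, correct, total = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            total_loss += loss.item() * images.size(0)
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
    return total_loss / total, correct / total

# Toy stand-ins for the real model and data
torch.manual_seed(0)
toy_device = torch.device('cpu')
toy_data = TensorDataset(torch.randn(100, 8), torch.randint(0, 4, (100,)))
toy_loader = DataLoader(toy_data, batch_size=32, shuffle=True)
toy_model = nn.Linear(8, 4)
toy_criterion = nn.CrossEntropyLoss()
toy_optimizer = torch.optim.Adam(toy_model.parameters(), lr=1e-2)

train_loss, train_acc = run_epoch(toy_model, toy_loader, toy_criterion, toy_device, toy_optimizer)
val_loss, val_acc = run_epoch(toy_model, toy_loader, toy_criterion, toy_device)
print(f"Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.2f}")
```

With this helper, each phase would differ only in the optimizer and scheduler passed in.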


In [None]:
# Save initial model weights and set best validation loss to infinity
best_model_wts = copy.deepcopy(model.state_dict())
best_val_loss = np.inf

num_epoch = 100

for epoch in range(num_epoch):
  print(f"Epoch {epoch+1}/{num_epoch}")
  print("-" * 50)

  # --- training Phase ---
  model.train()
  train_loss, correct, total = 0.0, 0, 0

  for images, labels in tqdm(train_loader, desc="Training"):
    images, labels = images.to(device), labels.to(device)

    ft_optimizer.zero_grad()
    outputs = model(images)
    loss = crieterion(outputs, labels)
    loss.backward()
    ft_optimizer.step()

    train_loss += loss.item() * images.size(0)
    _, predicted = torch.max(outputs, dim=1)
    correct += (predicted == labels).sum().item()
    total += labels.size(0)

  avg_train_loss = train_loss / total
  train_acc = correct / total

  # --- Validation Phase ---
  model.eval()
  val_loss, correct, total = 0.0, 0, 0

  with torch.no_grad():
    for images, labels in tqdm(val_loader, desc="Validation"):
        images, labels = images.to(device), labels.to(device)

        outputs = model(images)
        loss = crieterion(outputs, labels)

        val_loss += loss.item() * images.size(0)
        _, predicted = torch.max(outputs, dim=1)
        correct += (predicted == labels).sum().item()
        total += labels.size(0)

  avg_val_loss = val_loss / total
  val_acc = correct / total

  print(f"Train Loss: {avg_train_loss:.4f} | Train Acc: {train_acc:.2f}")
  print(f"Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.2f}")

  scheduler_ft.step(avg_val_loss)
  early_stopping(avg_val_loss)

  if avg_val_loss < best_val_loss:
    best_val_loss = avg_val_loss
    best_model_wts = copy.deepcopy(model.state_dict())
    torch.save(model.state_dict(), 'mct_best_model.pth')
    print(f"[INFO]: Best model updated")

  if early_stopping.early_stop:
    print(f"[INFO]: Training Stopped by EarlyStopping")
    break

model.load_state_dict(best_model_wts)
print("[INFO]: Best Model Loaded")



Epoch 1/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.37it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.00it/s]


Train Loss: 0.3080 | Train Acc: 0.89
Val Loss: 0.2512 | Val Acc: 0.91
[EarlyStopping] Improved val_loss: 0.2512
[INFO]: Best model updated
Epoch 2/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:29<00:00,  2.40it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.97it/s]


Train Loss: 0.2912 | Train Acc: 0.90
Val Loss: 0.2425 | Val Acc: 0.91
[EarlyStopping] Improved val_loss: 0.2425
[INFO]: Best model updated
Epoch 3/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:29<00:00,  2.40it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.87it/s]


Train Loss: 0.2694 | Train Acc: 0.91
Val Loss: 0.2288 | Val Acc: 0.92
[EarlyStopping] Improved val_loss: 0.2288
[INFO]: Best model updated
Epoch 4/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.39it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.88it/s]


Train Loss: 0.2538 | Train Acc: 0.91
Val Loss: 0.2167 | Val Acc: 0.92
[EarlyStopping] Improved val_loss: 0.2167
[INFO]: Best model updated
Epoch 5/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.39it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.84it/s]


Train Loss: 0.2479 | Train Acc: 0.91
Val Loss: 0.2108 | Val Acc: 0.92
[EarlyStopping] Improved val_loss: 0.2108
[INFO]: Best model updated
Epoch 6/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.40it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.82it/s]


Train Loss: 0.2299 | Train Acc: 0.92
Val Loss: 0.2074 | Val Acc: 0.93
[EarlyStopping] Improved val_loss: 0.2074
[INFO]: Best model updated
Epoch 7/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:29<00:00,  2.40it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.87it/s]


Train Loss: 0.2273 | Train Acc: 0.92
Val Loss: 0.1999 | Val Acc: 0.93
[EarlyStopping] Improved val_loss: 0.1999
[INFO]: Best model updated
Epoch 8/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:29<00:00,  2.41it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.88it/s]


Train Loss: 0.2193 | Train Acc: 0.92
Val Loss: 0.1960 | Val Acc: 0.93
[EarlyStopping] Improved val_loss: 0.1960
[INFO]: Best model updated
Epoch 9/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.40it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.83it/s]


Train Loss: 0.2100 | Train Acc: 0.93
Val Loss: 0.1892 | Val Acc: 0.93
[EarlyStopping] Improved val_loss: 0.1892
[INFO]: Best model updated
Epoch 10/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.40it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.95it/s]


Train Loss: 0.2085 | Train Acc: 0.93
Val Loss: 0.1833 | Val Acc: 0.94
[EarlyStopping] Improved val_loss: 0.1833
[INFO]: Best model updated
Epoch 11/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.38it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.04it/s]


Train Loss: 0.1908 | Train Acc: 0.93
Val Loss: 0.1781 | Val Acc: 0.94
[EarlyStopping] Improved val_loss: 0.1781
[INFO]: Best model updated
Epoch 12/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.38it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.03it/s]


Train Loss: 0.1935 | Train Acc: 0.94
Val Loss: 0.1765 | Val Acc: 0.94
[EarlyStopping] Improved val_loss: 0.1765
[INFO]: Best model updated
Epoch 13/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.38it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.11it/s]


Train Loss: 0.1906 | Train Acc: 0.93
Val Loss: 0.1712 | Val Acc: 0.94
[EarlyStopping] Improved val_loss: 0.1712
[INFO]: Best model updated
Epoch 14/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.37it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.07it/s]


Train Loss: 0.1868 | Train Acc: 0.93
Val Loss: 0.1691 | Val Acc: 0.94
[EarlyStopping] Improved val_loss: 0.1691
[INFO]: Best model updated
Epoch 15/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.37it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.08it/s]


Train Loss: 0.1836 | Train Acc: 0.94
Val Loss: 0.1662 | Val Acc: 0.95
[EarlyStopping] Improved val_loss: 0.1662
[INFO]: Best model updated
Epoch 16/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.37it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.12it/s]


Train Loss: 0.1689 | Train Acc: 0.94
Val Loss: 0.1590 | Val Acc: 0.95
[EarlyStopping] Improved val_loss: 0.1590
[INFO]: Best model updated
Epoch 17/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.36it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.10it/s]


Train Loss: 0.1630 | Train Acc: 0.94
Val Loss: 0.1593 | Val Acc: 0.94
[EarlyStopping] No improvement in val_loss for 1/3 epochs.
Epoch 18/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.37it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.08it/s]


Train Loss: 0.1652 | Train Acc: 0.94
Val Loss: 0.1572 | Val Acc: 0.94
[EarlyStopping] Improved val_loss: 0.1572
[INFO]: Best model updated
Epoch 19/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.34it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.04it/s]


Train Loss: 0.1590 | Train Acc: 0.94
Val Loss: 0.1502 | Val Acc: 0.95
[EarlyStopping] Improved val_loss: 0.1502
[INFO]: Best model updated
Epoch 20/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.37it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.05it/s]


Train Loss: 0.1641 | Train Acc: 0.94
Val Loss: 0.1482 | Val Acc: 0.95
[EarlyStopping] Improved val_loss: 0.1482
[INFO]: Best model updated
Epoch 21/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.39it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.95it/s]


Train Loss: 0.1628 | Train Acc: 0.94
Val Loss: 0.1443 | Val Acc: 0.95
[EarlyStopping] Improved val_loss: 0.1443
[INFO]: Best model updated
Epoch 22/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:29<00:00,  2.41it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.89it/s]


Train Loss: 0.1408 | Train Acc: 0.96
Val Loss: 0.1443 | Val Acc: 0.95
[EarlyStopping] Improved val_loss: 0.1443
[INFO]: Best model updated
Epoch 23/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.40it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.85it/s]


Train Loss: 0.1425 | Train Acc: 0.95
Val Loss: 0.1417 | Val Acc: 0.95
[EarlyStopping] Improved val_loss: 0.1417
[INFO]: Best model updated
Epoch 24/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:29<00:00,  2.42it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.86it/s]


Train Loss: 0.1412 | Train Acc: 0.95
Val Loss: 0.1384 | Val Acc: 0.96
[EarlyStopping] Improved val_loss: 0.1384
[INFO]: Best model updated
Epoch 25/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:29<00:00,  2.46it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.97it/s]


Train Loss: 0.1451 | Train Acc: 0.95
Val Loss: 0.1373 | Val Acc: 0.95
[EarlyStopping] Improved val_loss: 0.1373
[INFO]: Best model updated
Epoch 26/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:29<00:00,  2.40it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.89it/s]


Train Loss: 0.1373 | Train Acc: 0.95
Val Loss: 0.1353 | Val Acc: 0.96
[EarlyStopping] Improved val_loss: 0.1353
[INFO]: Best model updated
Epoch 27/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.39it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.88it/s]


Train Loss: 0.1388 | Train Acc: 0.95
Val Loss: 0.1314 | Val Acc: 0.96
[EarlyStopping] Improved val_loss: 0.1314
[INFO]: Best model updated
Epoch 28/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.38it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.00it/s]


Train Loss: 0.1361 | Train Acc: 0.96
Val Loss: 0.1303 | Val Acc: 0.96
[EarlyStopping] Improved val_loss: 0.1303
[INFO]: Best model updated
Epoch 29/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.39it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.03it/s]


Train Loss: 0.1270 | Train Acc: 0.96
Val Loss: 0.1302 | Val Acc: 0.96
[EarlyStopping] Improved val_loss: 0.1302
[INFO]: Best model updated
Epoch 30/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.39it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.09it/s]


Train Loss: 0.1268 | Train Acc: 0.96
Val Loss: 0.1279 | Val Acc: 0.96
[EarlyStopping] Improved val_loss: 0.1279
[INFO]: Best model updated
Epoch 31/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.38it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.08it/s]


Train Loss: 0.1304 | Train Acc: 0.96
Val Loss: 0.1258 | Val Acc: 0.96
[EarlyStopping] Improved val_loss: 0.1258
[INFO]: Best model updated
Epoch 32/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.39it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.09it/s]


Train Loss: 0.1197 | Train Acc: 0.96
Val Loss: 0.1256 | Val Acc: 0.96
[EarlyStopping] Improved val_loss: 0.1256
[INFO]: Best model updated
Epoch 33/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.38it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.05it/s]


Train Loss: 0.1287 | Train Acc: 0.95
Val Loss: 0.1210 | Val Acc: 0.96
[EarlyStopping] Improved val_loss: 0.1210
[INFO]: Best model updated
Epoch 34/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.38it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.08it/s]


Train Loss: 0.1206 | Train Acc: 0.96
Val Loss: 0.1226 | Val Acc: 0.96
[EarlyStopping] No improvement in val_loss for 1/3 epochs.
Epoch 35/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.39it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.01it/s]


Train Loss: 0.1188 | Train Acc: 0.96
Val Loss: 0.1183 | Val Acc: 0.96
[EarlyStopping] Improved val_loss: 0.1183
[INFO]: Best model updated
Epoch 36/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.40it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.92it/s]


Train Loss: 0.1148 | Train Acc: 0.96
Val Loss: 0.1183 | Val Acc: 0.97
[EarlyStopping] No improvement in val_loss for 1/3 epochs.
Epoch 37/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:29<00:00,  2.43it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.88it/s]


Train Loss: 0.1156 | Train Acc: 0.96
Val Loss: 0.1194 | Val Acc: 0.96
[EarlyStopping] No improvement in val_loss for 2/3 epochs.
Epoch 38/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:29<00:00,  2.41it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.86it/s]


Train Loss: 0.1147 | Train Acc: 0.96
Val Loss: 0.1170 | Val Acc: 0.96
[EarlyStopping] Improved val_loss: 0.1170
[INFO]: Best model updated
Epoch 39/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:29<00:00,  2.41it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.85it/s]


Train Loss: 0.1112 | Train Acc: 0.97
Val Loss: 0.1155 | Val Acc: 0.96
[EarlyStopping] Improved val_loss: 0.1155
[INFO]: Best model updated
Epoch 40/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.39it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.85it/s]


Train Loss: 0.1169 | Train Acc: 0.96
Val Loss: 0.1141 | Val Acc: 0.96
[EarlyStopping] Improved val_loss: 0.1141
[INFO]: Best model updated
Epoch 41/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.40it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.88it/s]


Train Loss: 0.1055 | Train Acc: 0.97
Val Loss: 0.1145 | Val Acc: 0.97
[EarlyStopping] No improvement in val_loss for 1/3 epochs.
Epoch 42/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:29<00:00,  2.40it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.90it/s]


Train Loss: 0.1028 | Train Acc: 0.97
Val Loss: 0.1131 | Val Acc: 0.97
[EarlyStopping] Improved val_loss: 0.1131
[INFO]: Best model updated
Epoch 43/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.37it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.91it/s]


Train Loss: 0.1059 | Train Acc: 0.97
Val Loss: 0.1116 | Val Acc: 0.97
[EarlyStopping] Improved val_loss: 0.1116
[INFO]: Best model updated
Epoch 44/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.39it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.98it/s]


Train Loss: 0.1034 | Train Acc: 0.97
Val Loss: 0.1103 | Val Acc: 0.97
[EarlyStopping] Improved val_loss: 0.1103
[INFO]: Best model updated
Epoch 45/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.38it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.98it/s]


Train Loss: 0.1054 | Train Acc: 0.96
Val Loss: 0.1097 | Val Acc: 0.97
[EarlyStopping] Improved val_loss: 0.1097
[INFO]: Best model updated
Epoch 46/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.37it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.06it/s]


Train Loss: 0.1022 | Train Acc: 0.97
Val Loss: 0.1086 | Val Acc: 0.97
[EarlyStopping] Improved val_loss: 0.1086
[INFO]: Best model updated
Epoch 47/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.35it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.09it/s]


Train Loss: 0.0969 | Train Acc: 0.97
Val Loss: 0.1085 | Val Acc: 0.97
[EarlyStopping] Improved val_loss: 0.1085
[INFO]: Best model updated
Epoch 48/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.38it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.07it/s]


Train Loss: 0.0986 | Train Acc: 0.97
Val Loss: 0.1088 | Val Acc: 0.97
[EarlyStopping] No improvement in val_loss for 1/3 epochs.
Epoch 49/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.37it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.09it/s]


Train Loss: 0.0968 | Train Acc: 0.97
Val Loss: 0.1102 | Val Acc: 0.97
[EarlyStopping] No improvement in val_loss for 2/3 epochs.
Epoch 50/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.36it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.09it/s]


Train Loss: 0.0989 | Train Acc: 0.97
Val Loss: 0.1068 | Val Acc: 0.97
[EarlyStopping] Improved val_loss: 0.1068
[INFO]: Best model updated
Epoch 51/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.37it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.06it/s]


Train Loss: 0.0948 | Train Acc: 0.97
Val Loss: 0.1076 | Val Acc: 0.97
[EarlyStopping] No improvement in val_loss for 1/3 epochs.
Epoch 52/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.36it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.08it/s]


Train Loss: 0.0944 | Train Acc: 0.97
Val Loss: 0.1046 | Val Acc: 0.97
[EarlyStopping] Improved val_loss: 0.1046
[INFO]: Best model updated
Epoch 53/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:30<00:00,  2.36it/s]
Validation: 100%|██████████| 18/18 [00:05<00:00,  3.03it/s]


Train Loss: 0.1001 | Train Acc: 0.97
Val Loss: 0.1050 | Val Acc: 0.97
[EarlyStopping] No improvement in val_loss for 1/3 epochs.
Epoch 54/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:29<00:00,  2.40it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.97it/s]


Train Loss: 0.0907 | Train Acc: 0.97
Val Loss: 0.1053 | Val Acc: 0.97
[EarlyStopping] No improvement in val_loss for 2/3 epochs.
Epoch 55/100
--------------------------------------------------


Training: 100%|██████████| 72/72 [00:29<00:00,  2.42it/s]
Validation: 100%|██████████| 18/18 [00:06<00:00,  2.87it/s]

Train Loss: 0.0903 | Train Acc: 0.97
Val Loss: 0.1047 | Val Acc: 0.97
[EarlyStopping] No improvement in val_loss for 3/3 epochs.
[EarlyStopping] Stopping training. Best val_loss: 0.1046
[INFO]: Training Stopped by EarlyStopping
[INFO]: Best Model Loaded
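The log follows a patience-based early-stopping rule. Whenever the validation loss drops below the best value seen so far, the counter resets and the best model is updated. After three consecutive epochs without improvement, training stops and the best checkpoint is reloaded. The notebook's own `EarlyStopping` class is defined earlier and may differ in detail. The sketch below assumes `patience=3` and `min_delta=0` and replays the validation losses printed for epochs 46 to 55:

```python
import copy


class EarlyStopping:
    """Minimal patience-based early stopping (sketch; min_delta=0 assumed)."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best_loss = float("inf")
        self.counter = 0
        self.best_state = None
        self.early_stop = False

    def step(self, val_loss: float, state=None) -> bool:
        """Record one epoch's val_loss; return True if it beat the best so far."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            # Keep an independent copy of the weights (e.g. model.state_dict()).
            self.best_state = copy.deepcopy(state)
            return True
        # No improvement: count it and flag a stop once patience is exhausted.
        self.counter += 1
        if self.counter >= self.patience:
            self.early_stop = True
        return False


# Validation losses printed for epochs 46-55 in the log above.
val_losses = [0.1086, 0.1085, 0.1088, 0.1102, 0.1068,
              0.1076, 0.1046, 0.1050, 0.1053, 0.1047]

stopper = EarlyStopping(patience=3)
stopped_at = None
for epoch, loss in enumerate(val_losses, start=46):
    stopper.step(loss)
    if stopper.early_stop:
        stopped_at = epoch
        break

print(stopped_at, stopper.best_loss)  # 55 0.1046
```

The replay stops at epoch 55 with a best loss of 0.1046, which matches the log. In epoch 22, "Improved val_loss: 0.1443" is printed even though the previous epoch also shows 0.1443. This suggests the real comparison runs on unrounded losses, while the log prints four decimals.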





## 8. Save Final Model

After fine-tuning, the best-performing model is saved using `torch.save()`. This is the checkpoint with the lowest validation loss, which EarlyStopping restores before training ends. Saving its `state_dict` preserves the weights that generalized best to the validation set, ready for deployment or further evaluation.

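In this notebook, the weights are written to a mounted Google Drive folder, shown below as a placeholder path. For privacy and reproducibility, the published model is hosted on the Hugging Face Hub rather than at a local path. The download link or model reference is provided in the project source files.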

**Model location**: Refer to the model card or config file in the `src/` directory.
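The Hub upload itself is not shown in this notebook. The sketch below uses the `huggingface_hub` client and assumes you are already authenticated, either via `huggingface-cli login` or the `HF_TOKEN` environment variable. The helper names, repository id, and filename are hypothetical placeholders, not the project's actual values:

```python
from huggingface_hub import HfApi, hf_hub_download


def upload_weights(local_path: str, repo_id: str, path_in_repo: str = "model.pth") -> None:
    """Hypothetical helper: push a saved state_dict file to a Hub model repo."""
    api = HfApi()  # picks up the cached login token or HF_TOKEN
    # Create the repository if it does not exist yet (no error if it does).
    api.create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)
    api.upload_file(
        path_or_fileobj=local_path,
        path_in_repo=path_in_repo,
        repo_id=repo_id,
        repo_type="model",
    )


def download_weights(repo_id: str, filename: str = "model.pth") -> str:
    """Hypothetical helper: fetch the weights file and return its local cache path."""
    return hf_hub_download(repo_id=repo_id, filename=filename)


# Example usage (requires network access and write permission on the repo):
# upload_weights("/path/to/your/drive/model.pth", "your-username/brain-tumor-densenet121")
# local_file = download_weights("your-username/brain-tumor-densenet121")
```

Keeping the weights on the Hub decouples them from any personal Drive path. The file returned by `hf_hub_download` can be passed directly to `torch.load`.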


In [None]:
# Mount Google Drive (Colab only) so the saved weights persist after the session ends
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
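# `model` holds the best weights restored by EarlyStopping.
# Placeholder: replace with a full file path (e.g. ending in .pth) on the mounted Drive;
# saving to a bare directory path raises an error.
torch.save(model.state_dict(), '/path/to/your/drive')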

## 📌 Conclusion

Expanding on the previous binary classification task, this notebook addresses a more nuanced challenge: **classifying brain MRI scans into four distinct tumor types**.

The underlying architecture and training strategy remain consistent with the earlier approach: **transfer learning**, **data augmentation**, and **fine-tuning**. However, the multiclass setting introduces additional challenges, including more intricate **decision boundaries**, **inter-class imbalance**, and the need for more **robust evaluation**.

Working through this project strengthened my understanding of **model scalability**, **multiclass loss handling**, and the practical limits of generalization in medical image classification.

This notebook reflects an important step in moving from foundational experimentation toward more sophisticated deep learning pipelines that are structured, reproducible, and scalable.

> 💡 Together with the binary version, this project forms a continuous learning track that builds intuition, confidence, and capability in applying computer vision to real-world medical challenges.
