# Multiclass Brain Tumor Classifier — Runner Notebook

> **Purpose:** This notebook serves as the execution script for the **multiclass brain tumor classification model**.  
> All core components (model architecture, data processing, training pipeline) are defined in separate Python modules to maintain a clean and modular project structure.

## Overview

This notebook is part of the **Brain Tumor AI** project, focusing on **multiclass classification** of medical images (**glioma**, **meningioma**, **no tumor**, **pituitary**).  
It is designed to:

- **Load and configure** the modular components: model, data module, transforms, helpers, callbacks, and loggers.
- **Execute** the training process using **PyTorch Lightning**.
- **Save** the trained model for later inference.

By separating the logic into `.py` files, the project ensures:

- **Reusability** — Components can be reused across different experiments without rewriting code.
- **Maintainability** — Easier debugging, testing, and incremental updates.
- **Clarity** — The notebook focuses on workflow execution and results, not on low-level implementation.

> **Disclaimer:** This project is developed for **learning and portfolio purposes only** — it is **not intended for clinical use**.


## 1. Install Dependencies & Import Libraries

### 1.1 Install Dependencies
Install the required packages to ensure the notebook runs without missing dependencies.

- **`datasets`** — Dataset handling and loading utilities.  
- **`fsspec`** — File system interface for remote/local storage.  
- **`pytorch-lightning`** — High-level PyTorch framework for training.  

> Skip this step if the environment already has these packages installed.


In [None]:
!pip install -q -U datasets fsspec pytorch-lightning
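
Before reinstalling, it can help to check which packages are already importable. The following is a minimal sketch using only the standard library; the `find_missing` helper and the `REQUIRED_MODULES` list are illustrative names, not part of the project.

```python
from importlib.util import find_spec

# Import names, which can differ from pip names:
# the pip package `pytorch-lightning` is imported as `pytorch_lightning`.
REQUIRED_MODULES = ["datasets", "fsspec", "pytorch_lightning"]


def find_missing(modules):
    """Return the top-level module names that cannot be located."""
    return [name for name in modules if find_spec(name) is None]


missing = find_missing(REQUIRED_MODULES)
print("Missing modules:", missing or "none")
```

If the printed list is empty, the install cell can be skipped.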

### 1.2 Import Required Libraries

Below are the required libraries and modules used in this notebook:

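- **os, sys**: File and system path handling.
- **numpy**: Array utilities, used here to list the unique class labels.
- **torch, torch.nn, torch.optim**: PyTorch core library, layers, and optimizers.
- **torch.optim.lr_scheduler.ReduceLROnPlateau**: Learning rate scheduler driven by a monitored metric.
- **torchvision.models**: Pretrained model architectures (DenseNet-121 in this notebook).
- **pytorch_lightning**: High-level wrapper for PyTorch to simplify training loops.
- **scikit-learn (train_test_split, compute_class_weight)**: Dataset splitting and class weight computation.
- **google.colab.drive**: Mounts Google Drive to access stored datasets/models.
- **datasets.load_dataset**: Loads datasets from the Hugging Face Datasets library.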


In [2]:
import os
import sys
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torchvision import models

import pytorch_lightning as pl

from sklearn.utils.class_weight import compute_class_weight
from sklearn.model_selection import train_test_split

from google.colab import drive

from datasets import load_dataset

- **Mount Google Drive**  
  `drive.mount('/content/drive')` connects the Colab environment to your Google Drive, so files can be saved and accessed persistently.



In [3]:
drive.mount('/content/drive')

Mounted at /content/drive


## 2. Define Project Paths and Import Custom Modules

### 2.1 Configure Directory Paths  
Here, we define the key directory paths used throughout the project:  

- **`CHECKPOINT_PATH`** — Location where model checkpoints will be saved and loaded from.  
- **`PROJECT_PATH`** — Root path of the project, used as a base reference for file operations.  
- **`SAVE_PATH`** — Directory for storing final outputs, such as trained models.  

The `PROJECT_PATH` is appended to `sys.path` to make sure Python can locate and import the project modules without issues.


In [4]:
CHECKPOINT_PATH = "/content/drive/MyDrive/MyProject/brain-tumor-ai/Models/2D_Classifier_Multiclass/checkpoint"

PROJECT_PATH = "/content/drive/MyDrive/MyProject/brain-tumor-ai/Models/2D_Classifier_Multiclass"

SAVE_PATH = "/content/drive/MyDrive/MyProject/brain-tumor-ai/Models/2D_Classifier_Multiclass/save_models"

In [5]:
if PROJECT_PATH not in sys.path:
    sys.path.append(PROJECT_PATH)
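
Checkpoint and save directories may not exist yet on a fresh Drive. A minimal sketch of creating them up front is shown here; `ensure_dirs` is a hypothetical helper, and the demo uses a temporary directory rather than the real Drive paths.

```python
import os
import tempfile


def ensure_dirs(*paths):
    """Create each directory if needed and return the paths that now exist."""
    for path in paths:
        os.makedirs(path, exist_ok=True)
    return [p for p in paths if os.path.isdir(p)]


# Demo on a temporary directory; in the notebook the call would be
# ensure_dirs(CHECKPOINT_PATH, SAVE_PATH).
with tempfile.TemporaryDirectory() as root:
    created = ensure_dirs(
        os.path.join(root, "checkpoint"),
        os.path.join(root, "save_models"),
    )
    demo_count = len(created)

print(demo_count)
```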

### 2.2 Import Custom Modules

This section imports the custom Python modules that define the model architecture, data pipeline, training callbacks, and helper functions.  
By keeping these components in separate files, the project maintains a clean and modular structure.

- **NN_PLmodule** → Custom PyTorch Lightning model for multiclass brain tumor classification.  
- **DataModule** → Handles data loading, preprocessing, and batching using PyTorch Lightning's DataModule structure.  
- **get_callbacks** → Retrieves predefined training callbacks such as model checkpointing and early stopping.  
- **get_logger** → Creates the logger used to record training metrics.  
- **set_seed** → Utility function to ensure reproducibility across runs.  
- **hf_dataset_to_tuple** → Converts a Hugging Face dataset split into separate image and label collections.


In [6]:
from module import NN_PLmodule
from datamodule import DataModule
from callbacks import get_callbacks
from logger import get_logger
from utils import set_seed, hf_dataset_to_tuple

## 3. Define Seed, Load and Prepare Raw Dataset


### 3.1 Set Random Seed

To ensure reproducibility of results, a fixed random seed is set at the beginning of the data preparation process.  
By setting the seed, all operations involving randomness (such as data shuffling, train-test splitting, and weight initialization) will produce the same outcome each time the notebook is executed. This step is crucial for debugging and for achieving consistent experimental results.


In [7]:
set_seed(42)
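
The project's `set_seed` lives in `utils.py` and is not shown in this notebook. As a rough idea of what such a helper usually covers, here is a sketch under the assumption that it seeds Python, NumPy, and PyTorch; `set_seed_sketch` is a hypothetical name, and the real implementation may differ.

```python
import os
import random


def set_seed_sketch(seed):
    """Seed common RNG sources; NumPy and PyTorch are seeded only if installed."""
    random.seed(seed)
    # Only affects Python processes started after this point.
    os.environ["PYTHONHASHSEED"] = str(seed)
    try:
        import numpy as np
        np.random.seed(seed)
    except ImportError:
        pass
    try:
        import torch
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)  # silently ignored without CUDA
    except ImportError:
        pass


set_seed_sketch(42)
first = random.random()
set_seed_sketch(42)
second = random.random()
print(first == second)
```

PyTorch Lightning also provides `pl.seed_everything(seed)`, which covers the same sources in one call.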

### 3.2 Load Dataset from Hugging Face

We load the dataset directly using the `datasets` library. The dataset contains labeled 2D brain MRI scans across four classes: **glioma**, **meningioma**, **no tumor**, **pituitary**.

In [None]:
ds = load_dataset("Cayanaaa/BrainTumorDatasets", name="multiclass")

### 3.3 View Class Label Mapping

This command reveals the label names and their corresponding integer encodings used internally by the dataset.


In [None]:
print(ds['train'].features['label'].names)
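
The list of names can be turned into lookup tables for decoding predictions later. This is a small sketch; the label order below is an assumption for illustration, and the real order is whatever the cell above prints.

```python
# Hypothetical label order; the real one comes from
# ds['train'].features['label'].names.
label_names = ["glioma", "meningioma", "no_tumor", "pituitary"]

# Integer id -> name, and the reverse mapping.
id2label = dict(enumerate(label_names))
label2id = {name: idx for idx, name in id2label.items()}

print(id2label[0], label2id["pituitary"])
```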

### 3.4 Extract the Training Split from the Hugging Face Dataset

We take the `train` split of the dataset for further processing. The validation set is carved out of this split in a later step.


In [10]:
train_data = ds['train']

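### 3.5 Convert the Hugging Face Dataset to a Tuple

The `hf_dataset_to_tuple` helper separates the dataset into `images` and `labels`, which can then be passed to `train_test_split`.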

In [11]:
images, labels = hf_dataset_to_tuple(train_data, image_key='image', label_key='label')
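
The helper itself is defined in `utils.py` and is not shown here. A plausible minimal version, assuming it simply walks the records and collects the two fields, could look like the sketch below; `dataset_to_tuple_sketch` and the toy records are illustrative only.

```python
def dataset_to_tuple_sketch(records, image_key="image", label_key="label"):
    """Collect images and labels from an iterable of dict-like records."""
    images, labels = [], []
    for record in records:
        images.append(record[image_key])
        labels.append(record[label_key])
    return images, labels


# Toy stand-in for dataset rows; real rows would hold image objects.
toy_records = [
    {"image": "img_a", "label": 0},
    {"image": "img_b", "label": 2},
    {"image": "img_c", "label": 1},
]

toy_images, toy_labels = dataset_to_tuple_sketch(toy_records)
print(toy_images, toy_labels)
```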

### 3.6 Stratified Train-Validation Split

To keep the class proportions the same in the training and validation sets, we perform a stratified 80/20 split. Stratification does not balance the classes; it prevents either subset from under-representing a class by chance.


In [12]:
train_imgs, val_imgs, train_labels, val_labels = train_test_split(
    images, labels,
    test_size = 0.2,
    random_state = 42,
    stratify = labels
)
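
One quick way to confirm that the split is stratified is to compare per-class proportions in both subsets. The sketch below uses toy labels; `class_proportions` is a hypothetical helper that could be called on `train_labels` and `val_labels` in the notebook.

```python
from collections import Counter


def class_proportions(labels):
    """Return the fraction of samples belonging to each class."""
    counts = Counter(labels)
    total = len(labels)
    return {cls: counts[cls] / total for cls in sorted(counts)}


# Toy labels with the same class ratios in both subsets.
toy_train = [0] * 8 + [1] * 4 + [2] * 4
toy_val = [0] * 2 + [1] * 1 + [2] * 1

train_props = class_proportions(toy_train)
val_props = class_proportions(toy_val)
print(train_props)
print(val_props)
```

With a stratified split the two dictionaries should be nearly identical, up to rounding caused by the subset sizes.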

### 3.7 Initialize Data Module

We instantiate the **`DataModule`** with the prepared training and validation datasets.  
This module handles data loading, preprocessing, and batching automatically during training and validation.

**Parameters:**
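- **`train_data`** & **`val_data`**: Tuples containing the images and their corresponding labels.
- **`batch_size`**: Number of samples per batch during training/validation (64 here).
- **`num_workers`**: Number of subprocesses used for data loading to speed up I/O operations (2 here).
- **`use_sampler`**: Set to `False` here; judging by its name, it toggles a custom sampler inside `DataModule`.
- **`img_size`**: Target spatial size for resizing images before feeding them into the model. It is not passed in the cell below, so the module's default applies.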

By using a **`LightningDataModule`**, we ensure a clean separation between the **data pipeline** and the **model logic**, improving code maintainability and reusability.


In [13]:
data_module = DataModule(
    train_data = (train_imgs, train_labels),
    val_data = (val_imgs, val_labels),
    batch_size = 64,
    num_workers = 2,
    use_sampler = False,
)

## 4. Device Setup & Core Model Configuration (Pre-Training Preparation)


### 4.1 Set Computational Device

This step determines the **hardware device** to be used for training and inference.  
The code automatically selects **GPU** (`cuda`) if it is available; otherwise, it defaults to **CPU**.

By explicitly setting the device, we ensure:
- **Compatibility** across different hardware environments.
- **Optimal performance** when GPU acceleration is available.
- **Clarity** in debugging by confirming which device is being used.

Printing the selected device helps verify that PyTorch has correctly identified and assigned the computational resource before proceeding with model initialization and training.


In [None]:
manual_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"use device: {manual_device}")

### 4.2 Define Model Backbone

In this step, we configure the **DenseNet-121** architecture as the model backbone.  
The network is initialized with **ImageNet-pretrained weights** to leverage **transfer learning**, allowing the model to benefit from features learned on a large-scale dataset.

Key steps:
- **Load Pretrained Model:** `models.densenet121(pretrained=True)` loads a DenseNet-121 with pretrained weights. Recent torchvision releases deprecate `pretrained=True` in favor of the `weights=` argument, so this call may emit a warning.
- **Retrieve Feature Size:** `backbone.classifier.in_features` extracts the number of input features of the final layer.
- **Modify Output Layer:** The classifier is replaced with a new `nn.Linear` layer to match the **four output classes** required for multiclass classification.

This approach significantly speeds up convergence and improves performance, especially when working with limited training data.


In [None]:
backbone = models.densenet121(pretrained=True)
num_features = backbone.classifier.in_features
backbone.classifier = nn.Linear(num_features, 4)

### 4.3 Define Loss Function with Class Weights

We use **CrossEntropyLoss** as the criterion for training the model.  
Since the dataset is **class-imbalanced**, we compute **class weights** using `compute_class_weight` from scikit-learn, so that errors on minority classes contribute proportionally more to the loss.

Steps:
- **Compute Class Weights:** The `balanced` mode automatically assigns higher weights to minority classes based on their frequency in the training labels.
- **Convert to Tensor:** The computed weights are converted to a PyTorch tensor and moved to the appropriate computation device (`manual_device`).
- **Initialize Loss Function:** The weights are passed to `nn.CrossEntropyLoss` to apply class-specific penalties during training.

This weighting strategy improves the model's ability to correctly classify underrepresented classes, leading to more balanced performance across all categories.


In [16]:
class_weights = compute_class_weight(
    class_weight = "balanced",
    classes = np.unique(train_labels),
    y = train_labels
)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(manual_device)

criterion = nn.CrossEntropyLoss(weight=class_weights)
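
For intuition, the `balanced` mode corresponds to the formula `n_samples / (n_classes * count_of_class)`. The following pure-Python sketch reproduces that arithmetic on toy labels; `balanced_class_weights` is an illustrative helper, not the scikit-learn function.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Weight per class: n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * counts[cls]) for cls in sorted(counts)}


# Toy labels: class 0 is common, class 2 is rare.
toy_labels = [0, 0, 0, 0, 0, 0, 1, 1, 1, 2]
weights = balanced_class_weights(toy_labels)
print(weights)
```

The rarest class receives the largest weight, so its misclassifications are penalized most heavily.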

## 5. Warm-up Training Phases

### 5.1 Freeze Backbone & Enable Classifier Training

In the warmup phase, we freeze all parameters of the **pre-trained backbone** to prevent them from being updated during backpropagation.  
This ensures that the model retains the learned feature representations from the original dataset (e.g., ImageNet) while allowing only the **classifier layer** to adapt to the new task.

**Implementation Details:**
- `p.requires_grad = False` — Disables gradient computation for all backbone parameters, effectively freezing them.
- Conditional check for `"classifier"` in parameter names — Re-enables training for the classifier layer by setting `p.requires_grad = True`.

By doing this, the training process focuses solely on optimizing the final classification head during the warmup stage, reducing the risk of overfitting and speeding up convergence.


In [17]:
for name, p in backbone.named_parameters():
    p.requires_grad = False

    if "classifier" in name:
        p.requires_grad = True

### 5.2 Initializing a Warm-Up Model with a Lightning Module

In this step, we instantiate **`NN_PLmodule`**, a custom PyTorch Lightning module designed for this project.
This module encapsulates the **model architecture**, **training/validation/testing steps**, and **metric tracking** in a single, organized class, while allowing flexible injection of key components.

**Main Steps:**
- **`backbone`**: A pre-trained feature extractor (here DenseNet-121) with its classifier replaced by a four-class output layer.
- **`manual_device`**: Explicit device assignment (`CPU` or `GPU`) for full control over where the computation takes place.
- **`num_classes`**: Specifies the number of classes (used in metrics).
- **`set_criterion(criterion)`**: Injects the chosen loss function (weighted CrossEntropy) into the module, allowing for proper handling of class imbalance.

By leveraging PyTorch Lightning, the training pipeline remains **clean, modular, and scalable**, while the injection approach ensures that critical components (e.g., backbone, loss function) can be swapped or reconfigured without rewriting the core training logic.

In [19]:
warmup_model = NN_PLmodule(backbone=backbone, manual_device=manual_device, num_classes=4)
warmup_model.set_criterion(criterion)

### 5.3 Configure Optimizer and Scheduler for Warmup

In this step, we define and attach the optimizer and learning rate scheduler for the **warmup phase** of training.

**Key Components:**
- **`AdamW` Optimizer** — Chosen for its ability to decouple weight decay from gradient updates, which often leads to better generalization compared to standard Adam.  
  - **Parameters**:  
    - **`filter(lambda p: p.requires_grad, backbone.parameters())`** ensures that only the trainable layers (here, the classifier) are updated.  
    - **`lr=1e-3`** — Relatively higher learning rate for warmup to quickly adapt the newly initialized classifier layer.  
    - **`weight_decay=1e-5`** — Helps regularize the model and reduce overfitting.

- **`ReduceLROnPlateau` Scheduler** — Monitors the validation loss and reduces the learning rate when progress plateaus.  
  - **`factor=0.1`** — Reduces the learning rate by 90% when triggered.  
  - **`patience=2`** — Waits two epochs without improvement before adjusting the LR.

- **Current LR Logging** — Prints the active learning rate to provide immediate feedback during the warmup setup.

Finally, the optimizer and scheduler are injected into the **`NN_PLmodule`** instance using **`set_optimizer`**, keeping the Lightning training loop clean while allowing full control over optimization behavior.


In [None]:
warmup_optimizer = optim.AdamW(filter(lambda p: p.requires_grad, backbone.parameters()),
                               lr=1e-3,
                               weight_decay=1e-5
                            )
warmup_scheduler = ReduceLROnPlateau(warmup_optimizer,
                                     mode='min',
                                     factor=0.1,
                                     patience=2
                                    )
current_lr = warmup_optimizer.param_groups[0]['lr']
print(f"Active learning rate: {current_lr}")

warmup_model.set_optimizer(warmup_optimizer, warmup_scheduler)
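
To make the `factor` and `patience` settings concrete, the sketch below simulates the plateau rule in plain Python. It is a simplification that ignores the scheduler's `threshold` and `cooldown` options; `simulate_plateau_schedule` and the loss values are illustrative.

```python
def simulate_plateau_schedule(val_losses, lr, factor=0.1, patience=2):
    """Return the learning rate in effect after each epoch's validation loss."""
    best = float("inf")
    bad_epochs = 0
    history = []
    for loss in val_losses:
        if loss < best:
            best = loss
            bad_epochs = 0
        else:
            bad_epochs += 1
        # The rate drops once the number of bad epochs exceeds `patience`.
        if bad_epochs > patience:
            lr *= factor
            bad_epochs = 0
        history.append(lr)
    return history


toy_losses = [1.00, 0.90, 0.91, 0.92, 0.93, 0.80]
lr_history = simulate_plateau_schedule(toy_losses, lr=1e-3)
print(lr_history)
```

In this toy run the rate stays at `1e-3` through the first two non-improving epochs and drops by a factor of ten on the third.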

### 5.4 Configuring Callbacks

In this step, we set up the **callbacks** that will be used during model training.  
Callbacks in PyTorch Lightning provide a mechanism to inject custom behavior at various stages of the training loop — such as saving checkpoints, early stopping, or scheduling learning rates.

Here, we use the custom function `get_callbacks()` to create and configure the following:

- **Model Checkpointing**  
  Automatically saves the model's weights whenever the monitored metric (`val_loss`) improves.  
  - **`dirpath`**: Path to store checkpoint files.  
  - **`monitor`**: Metric used to decide if a new checkpoint should be saved (`val_loss` in this case).  
  - **`mode`**: Set to `"min"` so that lower values of `val_loss` are considered better.  

- **Early Stopping**  
  Stops training early if the monitored metric does not improve after a defined patience period (`patience=3` here), preventing overfitting and saving time.

> *By modularizing callbacks into a separate function (`get_callbacks()`), we maintain cleaner code and make it easier to reuse and adjust the configuration across multiple experiments.*


In [21]:
warmup_callbacks = get_callbacks(
    dirpath = CHECKPOINT_PATH,
    monitor = 'val_loss',
    mode = 'min',
    patience = 3,
    filename="best-warmup-{epoch:02d}-{val_loss:.2f}",
)
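
The effect of `patience=3` can be illustrated with a small simulation. This sketch assumes `get_callbacks` forwards `patience` to a standard early-stopping rule with no minimum improvement threshold; `epochs_until_stop` and the loss values are illustrative.

```python
def epochs_until_stop(val_losses, patience=3):
    """Return the 0-based epoch at which training would stop, or None."""
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best:
            best = loss
            wait = 0
        else:
            wait += 1
            # Stop after `patience` consecutive checks without improvement.
            if wait >= patience:
                return epoch
    return None


stalling = [0.90, 0.70, 0.71, 0.72, 0.73, 0.60]
improving = [0.90, 0.80, 0.70, 0.60]
stop_epoch = epochs_until_stop(stalling)
no_stop = epochs_until_stop(improving)
print(stop_epoch, no_stop)
```

Note that the stalling run stops before it ever sees the later, better value, which is the usual trade-off of a short patience window.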

### 5.5 Setup TensorBoard Logger for Warm-Up Phase

In this step, we initialize the **TensorBoard logger** to track and visualize training metrics during the warm-up phase.

- **`log_dir`** specifies the root directory where logs will be stored (`{PROJECT_PATH}/logs`).  
- **`name`** is set to `"multiclass_classifier_warmup"`, which keeps warm-up training logs organized separately from other phases.  
- This setup enables detailed monitoring of key metrics such as loss, accuracy, and learning rate using TensorBoard’s interactive web interface.

> **Note:** The magic command `%load_ext tensorboard` is executed once to enable TensorBoard integration in this notebook session.  
> After that, the `%tensorboard` command can be run multiple times to launch the TensorBoard UI pointing to the appropriate log directory without needing to reload the extension.


In [22]:
warmup_logger = get_logger(
    log_dir = f"{PROJECT_PATH}/logs",
    name = 'multiclass_classifier_warmup',
)

### 5.6 Configure Trainer for Warm-Up Phase

This cell sets up the **PyTorch Lightning Trainer** which orchestrates the training loop.

**Key parameters:**
- **`max_epochs`**: The maximum number of training epochs (200 here; early stopping usually ends training sooner).
- **`accelerator`** and **`devices`**: Set to `'gpu'` and `1`, so training runs on a single GPU. This cell does not fall back to CPU automatically.
- **`precision`**: Set to `32` for full-precision training.
- **`callbacks`**: Includes checkpointing and early stopping to optimize training.
- **`logger`**: Enables logging of metrics to TensorBoard.
- **`log_every_n_steps`**: Logs training metrics every 10 batches for timely monitoring.


In [None]:
warmup_trainer = pl.Trainer(
    max_epochs = 200,
    accelerator = 'gpu',
    precision=32,
    callbacks = warmup_callbacks,
    logger = warmup_logger,
    log_every_n_steps = 10,
    devices= 1
)

### 5.7 Execute Warmup Training

We initiate the **warmup training phase** using the configured PyTorch Lightning `Trainer` instance.  
During this stage, only the **classifier layer** of the model is trainable (all backbone parameters remain frozen), allowing the newly added classification head to learn without disrupting the pretrained feature extractor.

**Process Overview:**
- **Model** — The `warmup_model`, already initialized with the backbone, loss function, optimizer, and scheduler.
- **Data Module** — The `data_module` that handles batch loading, preprocessing, and shuffling for both training and validation sets.
- **Callbacks & Loggers** — Integrated within the trainer to automatically save checkpoints, monitor validation metrics, and handle early stopping.

This phase serves as a foundation for **subsequent fine-tuning**, ensuring that the classifier head is adequately trained before unfreezing and updating the deeper layers of the backbone.


In [None]:
warmup_trainer.fit(warmup_model, data_module)

## 6. Finetune Training Phases

### 6.1 Selective Layer Unfreezing for Fine-Tuning

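In the fine-tuning phase, we selectively **unfreeze** the deeper layers of the backbone while keeping earlier layers frozen.  
Specifically, parameters whose names contain **`features.denseblock4`** or **`features.norm5`** are set to `requires_grad=True`, and every other parameter is set to `requires_grad=False`. As written, that includes the `classifier` head trained during warm-up, because its parameter names match neither pattern.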

**Rationale:**
- **Early layers** of CNN backbones typically capture low-level features (edges, textures) that are general and transferable across domains. These are kept frozen to preserve their learned representations.
- **Deeper layers** capture high-level, domain-specific patterns. By unfreezing them, the model can adapt these representations to the specific nuances of the brain tumor dataset.

This selective unfreezing strategy strikes a balance between:
- **Preserving general feature extraction** from pretraining.
- **Adapting high-level features** to the target task without overfitting.


In [25]:
for name, p in backbone.named_parameters():
    if "features.denseblock4" in name or "features.norm5" in name:
        p.requires_grad = True
    else:
        p.requires_grad = False
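
The name-based selection used in both phases can be checked without loading the model. The sketch below applies the same substring matching to a handful of parameter names in the style torchvision uses for DenseNet-121; `select_trainable` is a hypothetical helper and the name list is only a sample.

```python
def select_trainable(names, patterns):
    """Return the parameter names that contain at least one pattern."""
    return [n for n in names if any(p in n for p in patterns)]


# A small sample of DenseNet-121 style parameter names.
sample_names = [
    "features.conv0.weight",
    "features.denseblock1.denselayer1.conv1.weight",
    "features.denseblock4.denselayer16.conv2.weight",
    "features.norm5.weight",
    "classifier.weight",
    "classifier.bias",
]

warmup_trainable = select_trainable(sample_names, ["classifier"])
finetune_trainable = select_trainable(
    sample_names, ["features.denseblock4", "features.norm5"]
)
print(warmup_trainable)
print(finetune_trainable)
```

The second list contains no `classifier` entries, so adding `"classifier"` to the patterns would be needed if the head should keep training during fine-tuning.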

### 6.2 Load Warm-Up Checkpoint for Fine-Tuning

Before starting fine-tuning, the best-performing model from the **warm-up phase** is loaded as the starting point.  

**Key Steps:**
- **Retrieve Best Checkpoint Path**: Extracted from the first callback in `warmup_callbacks` (usually the `ModelCheckpoint` instance).
- **Load Model Weights**: Using `NN_PLmodule.load_from_checkpoint()`, the saved weights are restored into a new `finetune_model` instance.
- **Preserve Backbone Configuration**: The same backbone and `manual_device` settings are passed during loading.
- **Flexible Loading (`strict=False`)**: Tolerates missing or unexpected keys in the checkpoint's state dict, which can happen when the module changes between training phases. Shape mismatches on matching keys still raise an error.

This approach ensures that fine-tuning starts from a **well-initialized model** rather than from scratch, leveraging the representations learned during warm-up to achieve faster convergence and potentially better generalization.


In [None]:
best_warmup_model = warmup_callbacks[0].best_model_path

finetune_model = NN_PLmodule.load_from_checkpoint(
    best_warmup_model,
    backbone=backbone,
    manual_device=manual_device,
    num_classes=4,
    strict=False

)

### 6.3 Configure Fine-Tuning Criterion, Optimizer, and Scheduler

In the fine-tuning phase, the model’s training configuration is adjusted to enable **more precise weight updates**:

- **Criterion Assignment**  
   The same loss function (`criterion`) used during the warm-up phase is assigned to the fine-tuning model via `set_criterion()`.  
   This maintains consistency in optimization objectives between training stages.

- **Optimizer — AdamW**  
   - **Learning Rate:** Set to `1e-5` (significantly lower than in warm-up) to avoid large weight updates that could disrupt previously learned representations.  
   - **Weight Decay:** Reduced to `1e-6` to provide mild regularization while minimizing the risk of underfitting.  
   - **Parameter Filtering:** Only parameters with `requires_grad=True` (unfrozen layers) are passed to the optimizer.

- **Scheduler — ReduceLROnPlateau**  
   - Monitors a validation metric (e.g., `val_loss`).  
   - Reduces the learning rate by a factor of `0.1` if no improvement is observed for `2` consecutive epochs.  
   - This adaptive adjustment helps the model converge more effectively.

- **Logging Active Learning Rate**  
   The initial learning rate is printed for confirmation before training begins.

By lowering the learning rate and targeting only the newly unfrozen layers, fine-tuning focuses on **refining high-level features** without disrupting the stable lower-level representations learned during the warm-up phase.


In [None]:
finetune_model.set_criterion(criterion)

finetune_optimizer = optim.AdamW(filter(lambda p: p.requires_grad, backbone.parameters()),
                                 lr=1e-5,
                                 weight_decay=1e-6
                                )
finetune_scheduler = ReduceLROnPlateau(finetune_optimizer,
                                       mode='min',
                                       factor=0.1,
                                       patience=2
                                    )
current_lr = finetune_optimizer.param_groups[0]['lr']
print(f"Active learning rate: {current_lr}")

finetune_model.set_optimizer(finetune_optimizer, finetune_scheduler)

### 6.4 Configuring Callbacks for Fine-Tuning

In this step, we set up the **callbacks** that will be used during the **fine-tuning** phase.  
Callbacks in PyTorch Lightning provide a mechanism to inject custom behavior at various stages of the training loop — such as saving checkpoints, early stopping, or scheduling learning rates.

Here, we use the custom function `get_callbacks()` to create and configure the following:

- **Model Checkpointing**  
  Automatically saves the model's weights whenever the monitored metric (`val_loss`) improves.  
  - **`dirpath`**: Path to store checkpoint files (defined by `CHECKPOINT_PATH`).  
  - **`monitor`**: Metric used to decide if a new checkpoint should be saved (`val_loss` in this case).  
  - **`mode`**: Set to `"min"` so that lower values of `val_loss` are considered better.

- **Early Stopping**  
  Stops training early if the monitored metric does not improve after a defined patience period (`patience=3` here), preventing overfitting and saving time.

> *By modularizing callbacks into a separate function (`get_callbacks()`), we maintain cleaner code and make it easier to reuse and adjust the configuration across multiple experiments.*


In [29]:
finetune_callbacks = get_callbacks(
    dirpath = CHECKPOINT_PATH,
    monitor = 'val_loss',
    mode = 'min',
    patience = 3,
    filename="best-finetune-{epoch:02d}-{val_loss:.2f}"
)
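
The `filename` argument is a format template filled in with the epoch number and the monitored metric. Plain Python formatting shows how the placeholders resolve; the values below are made up for illustration.

```python
# Same template as in the cell above, filled with example values.
template = "best-finetune-{epoch:02d}-{val_loss:.2f}"
example_name = template.format(epoch=7, val_loss=0.3456)
print(example_name)
```

Assuming `get_callbacks` passes the template to Lightning's `ModelCheckpoint` with default settings, the saved file would also include the metric names, along the lines of `best-finetune-epoch=07-val_loss=0.35.ckpt`.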

### 6.5 Setup TensorBoard Logger for Fine-Tuning Phase

This step initializes the **TensorBoard logger** to track and visualize training metrics during the fine-tuning phase.  

- **`log_dir`** specifies the root directory where logs will be stored (`{PROJECT_PATH}/logs`).  
- **`name`** is set to `"multiclass_classifier_finetune"`, which keeps fine-tuning logs organized separately from other phases.

Using TensorBoard enables easy monitoring of key metrics such as loss, accuracy, and learning rate through an interactive web interface.


In [30]:
finetune_logger = get_logger(
    log_dir = f"{PROJECT_PATH}/logs",
    name = 'multiclass_classifier_finetune',
)

### 6.6 Configure Trainer for Fine-Tuning Phase

This cell sets up the **PyTorch Lightning Trainer** which orchestrates the training loop.

**Key parameters:**
- **`max_epochs`**: The maximum number of training epochs (200 here; early stopping usually ends training sooner).
- **`accelerator`** and **`devices`**: Set to `'gpu'` and `1`, so training runs on a single GPU. This cell does not fall back to CPU automatically.
- **`precision`**: Set to `32` for full-precision training.
- **`callbacks`**: Includes checkpointing and early stopping to optimize training.
- **`logger`**: Enables logging of metrics to TensorBoard.
- **`log_every_n_steps`**: Logs training metrics every 10 batches for timely monitoring.


In [None]:
finetune_trainer = pl.Trainer(
    max_epochs = 200,
    accelerator = 'gpu',
    precision=32,
    callbacks = finetune_callbacks,
    logger = finetune_logger,
    log_every_n_steps = 10,
    devices = 1
)

### 6.7 Execute Finetune Training

With the fine-tuning configuration in place, the model is now trained using the **PyTorch Lightning** `Trainer` instance (`finetune_trainer`).  

This step builds upon the **pre-trained weights** refined during the warm-up phase, allowing the model to adapt more specifically to the brain tumor classification task while minimizing the risk of overfitting.


In [None]:
finetune_trainer.fit(finetune_model, data_module)

## 7. Save and Export Fine-Tuned Model


### 7.1 Load Best Checkpoint After Fine-Tuning

After completing the fine-tuning process, we retrieve the path of the **best checkpoint** automatically saved by the `ModelCheckpoint` callback.

- `finetune_callbacks[0].best_model_path` accesses the best checkpoint based on the monitored metric (`val_loss` in this case).
- Using PyTorch Lightning's `load_from_checkpoint` method, the model is reloaded with weights from this best checkpoint.
- This loaded model can then be saved as a `.pth` file for easy storage and future use.
- The saved `.pth` model file can be used later for **evaluation** and **inference** without needing to retrain or reload the entire checkpoint.

This workflow ensures a clean separation between training, model saving, and later deployment or analysis.


In [33]:
best_checkpoint_path = finetune_callbacks[0].best_model_path

In [None]:
best_model = NN_PLmodule.load_from_checkpoint(best_checkpoint_path, backbone=backbone, num_classes=4, strict=False)

### 7.2 Save Fine-Tuned Model

In this step, we save the model weights:

- **Save Model Weights**  
  The fine-tuned model's parameters (weights) are saved as a `.pth` file using `torch.save()`.  
  - `best_model.state_dict()` extracts the model's state dictionary containing all learnable parameters and buffers.  
  - The file is saved in `SAVE_PATH` with the name `best_ft_braTS_multiclass.pth`.  
  - Saving the model weights separately allows lightweight storage and easy loading for future inference or evaluation without the full training checkpoint overhead.


In [39]:
torch.save(best_model.state_dict(), f"{SAVE_PATH}/best_ft_braTS_multiclass.pth")
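
Loading the weights back for inference follows the usual state-dict round trip. The sketch below demonstrates it on a tiny stand-in layer and a temporary file; `build_model` is a hypothetical helper, and for the real file the same module structure used in training must be rebuilt first so that the state-dict keys match.

```python
import os
import tempfile

import torch
import torch.nn as nn


def build_model():
    """Tiny stand-in; the real model must be rebuilt as it was for training."""
    return nn.Linear(8, 4)


trained = build_model()

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "weights.pth")
    torch.save(trained.state_dict(), path)

    # Rebuild the architecture, then load the saved tensors into it.
    restored = build_model()
    state = torch.load(path, map_location="cpu")
    restored.load_state_dict(state)

restored.eval()  # switch to inference mode
weights_match = all(
    torch.equal(a, b)
    for a, b in zip(trained.state_dict().values(), restored.state_dict().values())
)
print(weights_match)
```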

# Conclusion

> **This notebook marks an important milestone** in my personal learning journey as I transition from *vanilla PyTorch* to **PyTorch Lightning**, now adapted for **multiclass classification**.


Through this experience, I have gained valuable insights into:

- **Modular Project Structure**: Leveraging PyTorch Lightning to organize code for improved clarity, maintainability, and scalability.  
- **Reusable Training Runner**: A single Lightning module can handle both binary and multiclass tasks with only minor modifications, thanks to Lightning’s flexible API.  
- **Metric Adaptation**: Integrating multiclass metrics such as `Accuracy` and `AUROC` with minimal changes, while keeping logging and aggregation clean via `on_*_epoch_end` hooks.  
- **Custom Injection Approach**: Directly injecting backbone, optimizer, scheduler, and device configurations. This allowed me to take an existing binary classification runner and adapt it to multiclass by only editing targeted components.  
- **Loss Function Flexibility**: While this project uses **Cross-Entropy Loss (CE)** for multiclass, the design still supports **Binary Cross-Entropy with Logits Loss (BCEWithLogitsLoss)** for other use cases.


> This project is also a **proof of how modular PyTorch Lightning can be** — I only needed to copy an existing binary runner, make targeted adjustments for multiclass, and the rest of the training pipeline worked without major refactoring.

Next, I plan to **extend this modular Lightning approach to segmentation models**, exploring how similar design principles can be applied to more complex architectures and tasks.

## Technical Highlights

- **Stage-Based `shared_step` Implementation**  
  A single `shared_step` method handles **train**, **validation**, and **test** logic, reducing code duplication while still allowing stage-specific metrics and logging.

- **Metrics with Epoch Aggregation**  
  Accuracy and AUROC are updated per batch using `.update()` and aggregated at the end of each epoch via `.compute()` inside `on_*_epoch_end` hooks.  
  This ensures stable metric reporting and avoids misleading per-batch variations.

- **Probability Conversion**  
  Raw logits from the model are converted to probabilities using `torch.softmax(outputs, dim=1)` before passing to metrics like Accuracy and AUROC.

- **Loss Function**  
  Uses `torch.nn.CrossEntropyLoss` for multiclass classification, which combines `LogSoftmax` and `NLLLoss` internally, making it both numerically stable and efficient.

- **Device Handling**  
  Custom `batch_to_device` ensures that both inputs and labels are correctly moved to the GPU or CPU, avoiding mismatched device errors.

- **Modularity & Reusability**  
  The runner can be adapted to new datasets or architectures simply by replacing:
  1. The `backbone` (model architecture)
  2. The `criterion` (loss function)
  3. The metrics initialization  
  This makes it ideal for scaling to new tasks like segmentation.
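
The probability-conversion and loss highlights can be illustrated with a few lines of plain Python for a single sample. This is a sketch for intuition only, without class weights or batching; the logits are made-up values.

```python
import math


def softmax(logits):
    """Convert raw logits to probabilities, shifting by the max for stability."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def cross_entropy(logits, target):
    """Negative log-probability of the target class (log-softmax + NLL)."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


toy_logits = [2.0, 0.5, -1.0, 0.0]
probs = softmax(toy_logits)
loss = cross_entropy(toy_logits, target=0)
print(probs, loss)
```

The loss equals `-log(probs[target])`, which is why the loss function takes raw logits directly while the metrics receive the softmax output.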
