# Multiclass Brain Tumor Classifier — Runner Notebook

> **Purpose:** This notebook is the execution script for the **multiclass brain tumor classification model**.  
> All core components (model architecture, data processing, training pipeline) are defined in separate Python modules to keep the project structure clean and modular.


## Overview

This notebook is part of the **Brain Tumor AI** project, focusing on **multiclass classification** of medical images (notumor, pituitary, meningioma, glioma).  
It is designed to:
- Load and configure the modular components (model, data module, transforms, helpers, callbacks, loggers).
- Execute the training process using **PyTorch Lightning**.
- Save the trained model for inference.

By separating logic into `.py` files, the project ensures:
- **Reusability:** Components can be reused across multiple experiments.
- **Maintainability:** Easier debugging and updates.
- **Clarity:** The notebook focuses on workflow and results, not implementation details.


> **Note:** This project is for learning and portfolio purposes only — not for clinical use.


## 1. Install Dependencies & Import Libraries

### 1.1 Install Dependencies
Install the required packages to ensure the notebook runs without missing dependencies.

- **`datasets`** — Dataset handling and loading utilities.  
- **`fsspec`** — File system interface for remote/local storage.  
- **`pytorch-lightning`** — High-level PyTorch framework for training.  
- **`albumentations`** — Advanced image augmentation library.  
- **`torchmetrics`** — Standardized metrics for PyTorch.

> Skip this step if the environment already has these packages installed.


In [None]:
!pip install -q -U datasets fsspec pytorch-lightning albumentations torchmetrics

### 1.2 Import Required Libraries

Below are the required libraries and modules used in this notebook:

- **os, sys** — For file and system path handling.
- **torch** — PyTorch core library for deep learning operations.
- **pytorch_lightning** — High-level wrapper for PyTorch to simplify training loops.
- **scikit-learn (train_test_split, compute_class_weight)** — For dataset splitting and class weight computation.
- **google.colab.drive** — To mount Google Drive and access stored datasets/models.
- **datasets.load_dataset** — To load datasets in various formats from the Hugging Face Datasets library.


In [None]:
import os
import sys
import torch
import numpy as np
import pytorch_lightning as pl

from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight


from google.colab import drive

from datasets import load_dataset

In [None]:
drive.mount('/content/drive')

In [None]:
# Root folder of the project on Google Drive; the other paths are derived from it.
PROJECT_PATH = "/content/drive/MyDrive/MyProject/brain-tumor-ai/Models/2D_Classifier_Multiclass"

CHECKPOINT_PATH = os.path.join(PROJECT_PATH, "checkpoint")

SAVE_PATH = os.path.join(PROJECT_PATH, "save_models")

In [None]:
# Make the project's .py modules importable from this notebook.
if PROJECT_PATH not in sys.path:
    sys.path.append(PROJECT_PATH)

In [None]:
from module import DenseNetClassifierMulticlass
from datamodule import BrainTumorDataModule
from callbacks import get_callbacks
from logger import get_logger
from utils import set_seed, hf_dataset_to_tuple

In [None]:
set_seed(42)
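
The body of `set_seed` lives in `utils.py` and is not shown in this notebook. As a rough sketch of what such a helper usually does, the hypothetical `set_seed_sketch` below seeds Python, NumPy, and PyTorch and then checks that two runs with the same seed produce the same random draws. The function name and its exact behavior are assumptions, not the project's implementation.

```python
import random

import numpy as np
import torch


def set_seed_sketch(seed: int) -> None:
    """Hypothetical stand-in for utils.set_seed; the project's helper may differ."""
    random.seed(seed)        # Python's built-in RNG
    np.random.seed(seed)     # NumPy's global RNG
    torch.manual_seed(seed)  # PyTorch RNG


# Seeding twice with the same value should reproduce the same draws.
set_seed_sketch(42)
first = (random.random(), np.random.rand(), torch.rand(1).item())
set_seed_sketch(42)
second = (random.random(), np.random.rand(), torch.rand(1).item())
print(first == second)
```

PyTorch Lightning also ships `pl.seed_everything(seed, workers=True)`, which covers the same libraries plus dataloader workers.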

In [None]:
ds = load_dataset("Cayanaaa/BrainTumorDatasets", name="multiclass")

In [None]:
print(ds['train'].features['label'].names)

In [None]:
train_data = ds['train']

In [None]:
images, labels = hf_dataset_to_tuple(train_data, image_key='image', label_key='label')
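
`hf_dataset_to_tuple` is defined in `utils.py`, so its implementation is not visible here. A minimal sketch of the idea, assuming the helper simply walks the dataset row by row and collects the two columns, could look like the following. `dataset_to_tuple_sketch` and the toy rows are hypothetical; a Hugging Face `Dataset` yields one dict per row when iterated, which is what the loop relies on.

```python
from typing import Any, Iterable, List, Mapping, Tuple


def dataset_to_tuple_sketch(
    dataset: Iterable[Mapping[str, Any]],
    image_key: str = "image",
    label_key: str = "label",
) -> Tuple[List[Any], List[int]]:
    """Hypothetical stand-in for utils.hf_dataset_to_tuple."""
    images, labels = [], []
    for example in dataset:  # each row is a dict of column name to value
        images.append(example[image_key])
        labels.append(int(example[label_key]))
    return images, labels


# Toy rows standing in for dataset examples (strings instead of real images).
toy_rows = [
    {"image": "img_a", "label": 0},
    {"image": "img_b", "label": 3},
    {"image": "img_c", "label": 1},
]
toy_images, toy_labels = dataset_to_tuple_sketch(toy_rows)
print(toy_images, toy_labels)
```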

In [None]:
train_imgs, val_imgs, train_labels, val_labels = train_test_split(
    images, labels,
    test_size = 0.2,
    random_state = 42,
    stratify = labels
)
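
The split above passes `stratify = labels`, and `compute_class_weight` is imported but not yet used. The sketch below uses made-up labels (not the real dataset) to show what stratification preserves and how balanced class weights could be derived from the training split. Whether `DenseNetClassifierMulticlass` accepts such weights is not shown in this notebook, so treat the last step as an optional idea.

```python
from collections import Counter

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils.class_weight import compute_class_weight

# Synthetic, imbalanced labels for four classes (illustration only).
toy_labels = np.array([0] * 50 + [1] * 25 + [2] * 25 + [3] * 100)
toy_images = np.arange(len(toy_labels))  # placeholders for images

tr_x, va_x, tr_y, va_y = train_test_split(
    toy_images, toy_labels,
    test_size=0.2,
    random_state=42,
    stratify=toy_labels,
)

# Stratification keeps the class proportions identical in both splits.
train_counts = Counter(tr_y.tolist())
val_counts = Counter(va_y.tolist())
print(sorted(train_counts.items()), sorted(val_counts.items()))

# "balanced" weights are n_samples / (n_classes * count_per_class).
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=np.unique(tr_y),
    y=tr_y,
)
print(class_weights.tolist())
```

Weights like these can be converted to a tensor and passed to `torch.nn.CrossEntropyLoss(weight=...)` so that rare classes contribute more to the loss.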

In [None]:
data_module = BrainTumorDataModule(
    train_data = (train_imgs, train_labels),
    val_data = (val_imgs, val_labels),
    batch_size = 64,
    num_worker = 4
)

In [None]:
model_warmup = DenseNetClassifierMulticlass(
    learning_rate = 1e-3,
    weight_decay = 1e-5,
    unfreeze_layers = None
)

In [None]:
callbacks_warmup = get_callbacks(
    dirpath = CHECKPOINT_PATH,
    monitor = 'val_loss',
    mode = 'min',
    patience = 3
)

In [None]:
logger_warmup = get_logger(
    log_dir = os.path.join(PROJECT_PATH, "logs"),
    name = "best_warmup_model_checkpoint"
)

In [None]:
trainer_warmup = pl.Trainer(
    max_epochs = 200,
    accelerator = 'gpu',
    precision = '16-mixed',
    callbacks = callbacks_warmup,
    logger = logger_warmup,
    log_every_n_steps = 10,
    devices = 1
)

In [None]:
trainer_warmup.fit(model_warmup, datamodule = data_module)

In [None]:
model_finetune = DenseNetClassifierMulticlass(
    learning_rate = 1e-5,
    weight_decay = 1e-6,
    unfreeze_layers = ["features.denseblock4", "features.norm5"]
)
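
The warm-up model uses `unfreeze_layers = None`, while the fine-tuning model names two backbone blocks. How `DenseNetClassifierMulticlass` handles this argument is defined in `module.py` and not shown here. One common approach, sketched below on a toy network, is to freeze every backbone parameter and re-enable gradients only for parameters whose names start with one of the given prefixes, while the classifier head always stays trainable. `ToyBackbone` and `set_trainable` are hypothetical names used for illustration.

```python
from collections import OrderedDict

import torch.nn as nn


class ToyBackbone(nn.Module):
    """Tiny stand-in whose parameter names mimic a DenseNet layout."""

    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(OrderedDict([
            ("denseblock3", nn.Linear(8, 8)),
            ("denseblock4", nn.Linear(8, 8)),
            ("norm5", nn.BatchNorm1d(8)),
        ]))
        self.classifier = nn.Linear(8, 4)


def set_trainable(model, unfreeze_layers=None):
    """Freeze the backbone except for parameters matching the given prefixes."""
    prefixes = tuple(unfreeze_layers or ())
    for name, param in model.named_parameters():
        if name.startswith("features."):
            # str.startswith with an empty tuple is False, so None freezes all.
            param.requires_grad = name.startswith(prefixes)
        else:
            param.requires_grad = True  # classifier head is always trained


def trainable_names(model):
    return [n for n, p in model.named_parameters() if p.requires_grad]


toy = ToyBackbone()
set_trainable(toy, None)
warmup_names = trainable_names(toy)
set_trainable(toy, ["features.denseblock4", "features.norm5"])
finetune_names = trainable_names(toy)
print(warmup_names)
print(finetune_names)
```

In this sketch the warm-up phase trains only the head, and the fine-tuning phase additionally updates the last dense block and the final norm layer, which is why the fine-tuning cell uses a much smaller learning rate.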

In [None]:
callbacks_finetune = get_callbacks(
    dirpath = CHECKPOINT_PATH,
    monitor = 'val_loss',
    mode = 'min',
    patience = 3
)

In [None]:
logger_finetune = get_logger(
    log_dir = os.path.join(PROJECT_PATH, "logs"),
    name = "best_finetune_model_checkpoint"
)

In [None]:
trainer_finetune = pl.Trainer(
    max_epochs = 200,
    accelerator = 'gpu',
    precision = '16-mixed',
    callbacks = callbacks_finetune,
    logger = logger_finetune,
    log_every_n_steps = 10,
    devices = 1
)

In [None]:
trainer_finetune.fit(model_finetune, datamodule = data_module)

In [None]:
print(callbacks_finetune[0].best_model_path)

In [None]:
best_checkpoint_model_path = callbacks_finetune[0].best_model_path

In [None]:
best_model = DenseNetClassifierMulticlass.load_from_checkpoint(best_checkpoint_model_path)

In [None]:
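# Save only the weights (state_dict) into the dedicated SAVE_PATH folder.
os.makedirs(SAVE_PATH, exist_ok=True)
torch.save(best_model.state_dict(), f"{SAVE_PATH}/best_ft_braTS_multiclass.pth")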
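
The saved `.pth` file holds only a `state_dict`, not the model class, so inference code has to build the model first and then load the weights into it. The sketch below shows that round trip on a toy `nn.Linear` layer and a temporary file. For the real model, the same steps would apply to an instance of `DenseNetClassifierMulticlass`, whose constructor arguments for inference are an assumption not covered in this notebook.

```python
import os
import tempfile

import torch
import torch.nn as nn

# Toy model standing in for the trained classifier.
toy_model = nn.Linear(4, 2)
toy_path = os.path.join(tempfile.mkdtemp(), "toy_weights.pth")
torch.save(toy_model.state_dict(), toy_path)

# Rebuild the same architecture, then load the saved weights into it.
restored = nn.Linear(4, 2)
state = torch.load(toy_path, map_location="cpu")  # load on CPU regardless of where it was saved
restored.load_state_dict(state)
restored.eval()  # switch off dropout / batch-norm updates for inference

same_weights = all(
    torch.equal(a, b)
    for a, b in zip(toy_model.state_dict().values(), restored.state_dict().values())
)
print(same_weights)
```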