<a href="https://colab.research.google.com/github/AshraqtTamer/BrainTumorUsingVIT/blob/main/BrainTumorUsingVIT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import kagglehub

# Kaggle dataset identifier ("owner/dataset"); no version suffix is given, so the
# latest published version is fetched.
KAGGLE_DATASET = "denizkavi1/brain-tumor"

# Download the files and keep the local directory they were extracted to in `path`.
path = kagglehub.dataset_download(KAGGLE_DATASET)
print("Path to dataset files:", path)

Downloading from https://www.kaggle.com/api/v1/datasets/download/denizkavi1/brain-tumor?dataset_version_number=2...


100%|██████████| 700M/700M [00:31<00:00, 23.0MB/s]

Extracting files...





Path to dataset files: /root/.cache/kagglehub/datasets/denizkavi1/brain-tumor/versions/2


In [None]:
!pip install lightning timm


Collecting lightning
  Downloading lightning-2.5.5-py3-none-any.whl.metadata (39 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.15.2-py3-none-any.whl.metadata (5.7 kB)
Collecting torchmetrics<3.0,>0.7.0 (from lightning)
  Downloading torchmetrics-1.8.2-py3-none-any.whl.metadata (22 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.5.5-py3-none-any.whl.metadata (20 kB)
Downloading lightning-2.5.5-py3-none-any.whl (828 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m828.5/828.5 kB[0m [31m53.1 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading lightning_utilities-0.15.2-py3-none-any.whl (29 kB)
Downloading torchmetrics-1.8.2-py3-none-any.whl (983 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m983.2/983.2 kB[0m [31m65.2 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading pytorch_lightning-2.5.5-py3-none-any.whl (832 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
import os
import torch
import timm
import pytorch_lightning as pl
import torch.nn as nn
import torch.optim as optim
import torchmetrics

from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from pytorch_lightning import LightningModule, Trainer, seed_everything


In [None]:
# Reproducibility: seed_everything seeds Python's `random`, NumPy and torch RNGs.
SEED = 42
seed_everything(SEED)

IMAGE_SIZE = 224   # input resolution expected by vit_base_patch16_224
BATCH_SIZE = 16
N_CLASSES = 3
EPOCHS = 20

# Use the directory kagglehub actually returned instead of a hard-coded cache path.
# A literal ".../versions/2" goes stale as soon as the dataset publishes a new version
# (the download cell always fetches the latest one), and it breaks wherever the kagglehub
# cache does not live under /root/.cache.
train_path = path


INFO:lightning_fabric.utilities.seed:Seed set to 42


In [None]:
# Image preprocessing.
# Unlike Keras' samplewise_center / samplewise_std_normalization, no per-image statistics
# are used: every image gets the same fixed per-channel mean/std of 0.5, which maps
# ToTensor's [0, 1] range onto [-1, 1].
# NOTE(review): presumably these constants match what the pretrained timm ViT expects —
# confirm with timm.data.resolve_data_config(model=...) or swap in dataset statistics.
_normalize = transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3)

# Training pipeline: resize, then light augmentation (random horizontal flip and a random
# rotation within ±10 degrees), then tensor conversion and normalization.
train_transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.RandomHorizontalFlip(),
        transforms.RandomRotation(10),
        transforms.ToTensor(),
        _normalize,
    ]
)

# Validation pipeline: deterministic — resize and normalize only, no augmentation.
val_transform = transforms.Compose(
    [
        transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
        transforms.ToTensor(),
        _normalize,
    ]
)


In [None]:
from torch.utils.data import Subset, random_split

# Two views of the same image folder that differ only in their transform: augmentation for
# training, deterministic preprocessing for validation. Overwriting
# `val_dataset.dataset.transform` on a single ImageFolder would mutate the object that BOTH
# subsets wrap, silently disabling augmentation for the training subset as well.
dataset_full = datasets.ImageFolder(train_path, transform=train_transform)
dataset_full_eval = datasets.ImageFolder(train_path, transform=val_transform)

# Shared indices are only valid if both views enumerate identical (file, label) pairs, and
# the classifier head is sized by N_CLASSES, so the folder must hold exactly that many classes.
assert dataset_full.samples == dataset_full_eval.samples, "train/eval views disagree"
assert len(dataset_full.classes) == N_CLASSES, (
    f"expected {N_CLASSES} class folders, found {dataset_full.classes}"
)

n_val = int(0.2 * len(dataset_full))
n_train = len(dataset_full) - n_val

# Split *indices* once, with a dedicated seeded generator: re-running this cell reproduces the
# same split instead of drawing from the (advancing) global RNG, which could otherwise move
# validation images into the training set between runs.
split_generator = torch.Generator().manual_seed(42)
train_split, val_split = random_split(
    range(len(dataset_full)), [n_train, n_val], generator=split_generator
)

# Indices into range(n) are indices into the dataset; attach each to its own transform view.
train_dataset = Subset(dataset_full, train_split.indices)
val_dataset = Subset(dataset_full_eval, val_split.indices)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=2)


In [None]:
class ViTLightningModule(pl.LightningModule):
    """LightningModule that fine-tunes a pretrained timm Vision Transformer classifier.

    Args:
        lr: Learning rate passed to ``torch.optim.Adam``.
        n_classes: Number of target classes; ``timm.create_model`` builds the
            classification head with this many outputs.
        model_name: timm architecture name. Defaults to ``"vit_base_patch16_224"``
            (ViT-B/16 at 224x224 input), the model this module has always used.

    Logs ``train_loss``/``train_acc``, ``val_loss``/``val_acc`` and
    ``test_loss``/``test_acc`` for the respective stages.
    """

    def __init__(self, lr=1e-4, n_classes=3, model_name="vit_base_patch16_224"):
        super().__init__()
        # Records lr, n_classes and model_name in self.hparams (and in checkpoints).
        self.save_hyperparameters()

        # pretrained=True loads pretrained weights; num_classes replaces the head.
        self.model = timm.create_model(
            model_name,
            pretrained=True,
            num_classes=n_classes,
        )

        self.criterion = nn.CrossEntropyLoss()

        # One metric instance per stage so their accumulated state never mixes.
        # MulticlassAccuracy's default `average` is "macro" (per-class accuracy, averaged).
        self.train_acc = torchmetrics.classification.MulticlassAccuracy(num_classes=n_classes)
        self.val_acc = torchmetrics.classification.MulticlassAccuracy(num_classes=n_classes)
        # Required by test_step; without it trainer.test() raises AttributeError.
        self.test_acc = torchmetrics.classification.MulticlassAccuracy(num_classes=n_classes)

    def forward(self, x):
        """Return the raw class logits produced by the wrapped timm model."""
        return self.model(x)

    def _common_step(self, batch, stage, metric):
        """Shared forward/loss/metric/logging logic for one ``(images, labels)`` batch.

        Args:
            batch: ``(images, labels)`` pair as yielded by the DataLoader.
            stage: Prefix for the logged keys (``"train"``, ``"val"`` or ``"test"``).
            metric: The torchmetrics instance belonging to ``stage``.

        Returns:
            The cross-entropy loss tensor for the batch.
        """
        images, labels = batch
        logits = self(images)
        loss = self.criterion(logits, labels)

        # Calling the metric updates its running state and caches the batch value.
        # Logging the metric *object* (not the returned batch tensor) lets Lightning use the
        # batch value for step-level logs and call compute()/reset() for epoch-level logs
        # (the default for val/test), so val_acc is the accuracy over the whole epoch rather
        # than a mean of per-batch macro accuracies.
        metric(logits, labels)
        self.log(f"{stage}_loss", loss, prog_bar=True)
        self.log(f"{stage}_acc", metric, prog_bar=True)

        return loss

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss/accuracy; the loss is what gets optimized."""
        return self._common_step(batch, "train", self.train_acc)

    def validation_step(self, batch, batch_idx):
        """Compute and log validation loss/accuracy."""
        return self._common_step(batch, "val", self.val_acc)

    def test_step(self, batch, batch_idx):
        """Compute and log test loss/accuracy."""
        return self._common_step(batch, "test", self.test_acc)

    def configure_optimizers(self):
        """Adam over all parameters, using the ``lr`` saved in hparams."""
        return optim.Adam(self.parameters(), lr=self.hparams.lr)


In [None]:
# Instantiate the ViT fine-tuning module and a Trainer that stops after EPOCHS epochs.
# Accelerator and device selection are left at the Trainer defaults (auto-detected);
# mixed precision can be enabled by additionally passing e.g. precision="16-mixed".
model = ViTLightningModule(n_classes=N_CLASSES, lr=1e-4)

trainer = Trainer(max_epochs=EPOCHS)



The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


model.safetensors:   0%|          | 0.00/346M [00:00<?, ?B/s]

INFO:pytorch_lightning.utilities.rank_zero:💡 Tip: For seamless cloud uploads and versioning, try installing [litmodels](https://pypi.org/project/litmodels/) to enable LitModelCheckpoint, which syncs automatically with the Lightning model registry.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:HPU available: False, using: 0 HPUs


In [None]:
# Train on the training split, running validation at the end of every epoch.
trainer.fit(model, train_dataloaders=train_loader, val_dataloaders=val_loader)


INFO:pytorch_lightning.accelerators.cuda:LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
INFO:pytorch_lightning.callbacks.model_summary:
  | Name      | Type               | Params | Mode 
---------------------------------------------------------
0 | model     | VisionTransformer  | 85.8 M | train
1 | criterion | CrossEntropyLoss   | 0      | train
2 | train_acc | MulticlassAccuracy | 0      | train
3 | val_acc   | MulticlassAccuracy | 0      | train
---------------------------------------------------------
85.8 M    Trainable params
0         Non-trainable params
85.8 M    Total params
343.204   Total estimated model params size (MB)
279       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=20` reached.
