In [1]:
# Install the extra dependencies into the environment of the running kernel.
# Versions are pinned to the ones this notebook was last executed with.
%pip install -q lightning==2.4.0 einops==0.8.0 torchmetrics==1.4.1

Collecting lightning
  Downloading lightning-2.4.0-py3-none-any.whl.metadata (38 kB)
Collecting einops
  Downloading einops-0.8.0-py3-none-any.whl.metadata (12 kB)
Collecting torchmetrics
  Downloading torchmetrics-1.4.1-py3-none-any.whl.metadata (20 kB)
Collecting lightning-utilities<2.0,>=0.10.0 (from lightning)
  Downloading lightning_utilities-0.11.6-py3-none-any.whl.metadata (5.2 kB)
Collecting pytorch-lightning (from lightning)
  Downloading pytorch_lightning-2.4.0-py3-none-any.whl.metadata (21 kB)
Downloading lightning-2.4.0-py3-none-any.whl (810 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m811.0/811.0 kB[0m [31m11.0 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25hDownloading einops-0.8.0-py3-none-any.whl (43 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.2/43.2 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25hDownloading torchmetrics-1.4.1-py3-none-any.whl (866 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [3

In [2]:
import pytorch_lightning as pl                    # PyTorch Lightning framework
import torch                                      # Core PyTorch library
import torch.nn as nn                             # PyTorch neural network modules
import torch.optim                                # PyTorch optimization
from torch.utils.data import DataLoader, Subset   # PyTorch data utilities
from torchvision import datasets, transforms      # Datasets and transforms from torchvision
from sklearn.model_selection import train_test_split  # Utility for train-validation splitting
import matplotlib.pyplot as plt                   # Matplotlib for plotting
import seaborn as sns                             # Seaborn for visualization aesthetics
from einops import repeat                         # Einops for tensor operations
from einops.layers.torch import Rearrange         # Einops layers for tensor rearranging
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping  # Callback utilities
from torchmetrics.classification import (         # Metrics for classification tasks
    MulticlassAccuracy, 
    MulticlassF1Score, 
    MulticlassPrecision, 
    MulticlassRecall
)


In [3]:
# Report the hardware visible to this kernel: CUDA availability and CPU core count.
import multiprocessing

import torch

gpu_available = torch.cuda.is_available()
num_cores = multiprocessing.cpu_count()

print("GPU Available:", gpu_available)
print(f"Number of available CPU cores: {num_cores}")

GPU Available: True
Number of available CPU cores: 8


In [4]:
class FashionMNISTDataModule(pl.LightningDataModule):
    """Lightning data module that serves FashionMNIST train/val/test loaders.

    The official training set is divided into a training and a validation
    subset with a stratified, seeded split; the official test set is used
    unchanged.

    Args:
        batch_size: Number of samples per batch for every loader.
        num_workers: Worker processes used by each DataLoader.
        data_dir: Root directory passed to ``datasets.FashionMNIST``.
        val_fraction: Fraction of the training set held out for validation.
        seed: ``random_state`` for the train/validation split, so that the
            split is identical on every ``setup`` call and every run.
    """

    def __init__(self, batch_size=64, num_workers=8, data_dir='', val_fraction=0.1, seed=42):
        super().__init__()
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.data_dir = data_dir
        self.val_fraction = val_fraction
        self.seed = seed
        self.transform = transforms.Compose([
            transforms.ToTensor(),  # PIL image -> float tensor scaled to [0, 1]
        ])

    def prepare_data(self):
        """Download both FashionMNIST splits into ``data_dir`` if needed."""
        datasets.FashionMNIST(root=self.data_dir, train=True, download=True)
        datasets.FashionMNIST(root=self.data_dir, train=False, download=True)

    def setup(self, stage=None):
        """Build the train, validation and test datasets.

        ``stage`` is accepted for the Lightning hook signature; all three
        datasets are built regardless of its value.
        """
        full_train_dataset = datasets.FashionMNIST(
            root=self.data_dir, train=True, transform=self.transform
        )
        # Stratify on the labels so both subsets keep the class balance, and
        # fix random_state so repeated setup() calls yield the same split.
        train_idx, val_idx = train_test_split(
            range(len(full_train_dataset)),
            test_size=self.val_fraction,
            shuffle=True,
            stratify=full_train_dataset.targets,
            random_state=self.seed,
        )

        self.train_dataset = Subset(full_train_dataset, train_idx)
        self.val_dataset   = Subset(full_train_dataset, val_idx)
        self.test_dataset  = datasets.FashionMNIST(
            root=self.data_dir, train=False, transform=self.transform
        )

    def _make_loader(self, dataset, shuffle=False):
        """Create a DataLoader with the module-wide batch/worker settings."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.num_workers,
            pin_memory=True,
        )

    def train_dataloader(self):
        """Return the shuffled loader over the training subset."""
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        """Return the (unshuffled) loader over the validation subset."""
        return self._make_loader(self.val_dataset)

    def test_dataloader(self):
        """Return the (unshuffled) loader over the official test set."""
        return self._make_loader(self.test_dataset)

In [5]:
from pytorch_lightning import Callback
import time

class MetricsCallback(Callback):
    """Collect per-epoch train/validation metrics from ``trainer.callback_metrics``.

    After every training epoch and every validation epoch the logged values
    for loss, accuracy, F1, precision and recall are appended to per-metric
    lists, which ``get_metrics`` exposes as a dict.
    """

    # (list-attribute suffix, logged-key suffix) for every tracked metric.
    _TRACKED = (
        ("losses", "loss"),
        ("acc", "acc"),
        ("f1", "f1"),
        ("precision", "precision"),
        ("recall", "recall"),
    )

    def __init__(self):
        super().__init__()
        self.reset_metrics()

    def reset_metrics(self):
        """Discard all recorded history and start with empty lists."""
        self.train_losses = []
        self.val_losses = []
        self.train_acc = []
        self.val_acc = []
        self.train_f1 = []
        self.val_f1 = []
        self.train_precision = []
        self.val_precision = []
        self.train_recall = []
        self.val_recall = []

    def _record(self, trainer, stage):
        """Append the current ``<stage>_*`` logged values to the history lists.

        Raises:
            KeyError: if one of the expected keys has not been logged.
        """
        logged = trainer.callback_metrics
        for list_suffix, key_suffix in self._TRACKED:
            history = getattr(self, f"{stage}_{list_suffix}")
            history.append(logged[f"{stage}_{key_suffix}"].item())

    def on_train_epoch_end(self, trainer, pl_module):
        """Record the training metrics logged for the epoch that just ended."""
        self._record(trainer, "train")

    def on_validation_epoch_end(self, trainer, pl_module):
        """Record the validation metrics logged for the epoch that just ended.

        The pre-training sanity-check pass also triggers this hook; it is
        skipped so the validation lists stay aligned, one entry per epoch,
        with the training lists.
        """
        if trainer.sanity_checking:
            return
        self._record(trainer, "val")

    def get_metrics(self):
        """Return the recorded histories keyed by ``model_monitor`` argument name."""
        return {
            "train_losses": self.train_losses,
            "val_losses": self.val_losses,
            "train_acc": self.train_acc,
            "val_acc": self.val_acc,
            "train_f1": self.train_f1,
            "val_f1": self.val_f1,
            "train_precision": self.train_precision,
            "val_precision": self.val_precision,
            "train_recall": self.train_recall,
            "val_recall": self.val_recall,
        }

In [6]:
import matplotlib.pyplot as plt
import seaborn as sns
sns.set_style("darkgrid")


def model_monitor(train_losses, val_losses, train_acc, val_acc, train_f1, val_f1, train_precision, val_precision, train_recall, val_recall):
    """Draw a 2x2 grid of training-vs-validation curves and show the figure.

    Panels: loss (top-left), accuracy (top-right), F1 score (bottom-left),
    precision and recall together (bottom-right). Each argument is a sequence
    of per-epoch values; the x axis numbers the entries starting at 1.
    """
    red, orange, blue = '#FF6347', "orange", '#4682B4'
    green, purple, cyan = "#55A868", "#800080", "#00FFFF"

    fig, ax = plt.subplots(2, 2, figsize=(20, 12))

    # One entry per panel: (axes, title, y-axis label, [(values, legend label, colour), ...])
    panels = [
        (ax[0, 0], 'Training vs. Validation Loss', 'Loss', [
            (train_losses, 'Train Loss', blue),
            (val_losses, 'Validation Loss', green),
        ]),
        (ax[0, 1], 'Training vs. Validation Accuracy', 'Accuracy', [
            (train_acc, 'Train Accuracy', orange),
            (val_acc, 'Validation Accuracy', purple),
        ]),
        (ax[1, 0], 'Training vs. Validation F1 Score', 'F1 Score', [
            (train_f1, 'Train F1 Score', cyan),
            (val_f1, 'Validation F1 Score', red),
        ]),
        (ax[1, 1], 'Training vs. Validation Precision and Recall', 'Score', [
            (train_precision, 'Train Precision', blue),
            (val_precision, 'Validation Precision', green),
            (train_recall, 'Train Recall', orange),
            (val_recall, 'Validation Recall', purple),
        ]),
    ]

    for axes, title, y_label, curves in panels:
        for values, legend_label, colour in curves:
            epochs = range(1, len(values) + 1)
            sns.lineplot(x=epochs, y=values, label=legend_label, color=colour, ax=axes)
        axes.set_title(title, fontsize=16, fontweight='bold')
        axes.set_xlabel('Epochs', fontsize=12, fontweight='bold')
        axes.set_ylabel(y_label, fontsize=12, fontweight='bold')
        axes.legend(facecolor='white', fontsize=12, title_fontsize='11', edgecolor='black')
        axes.grid(True)

    plt.tight_layout()
    plt.show()

In [9]:
import torch
import torch.nn as nn
import pytorch_lightning as pl
from einops import repeat
from einops.layers.torch import Rearrange

class PatchEmbedding(nn.Module):
    """Cut an image batch into square patches and embed each patch.

    Two interchangeable projections are supported:

    * ``'linear'``: rearrange ``b c (h p1) (w p2)`` into flattened patches
      ``b (h w) (p1 p2 c)`` and apply a Linear layer to ``emb_size``.
    * ``'conv'``: a Conv2d whose kernel size and stride both equal
      ``patch_size``, followed by flattening the ``h x w`` grid into a
      sequence ``b (h w) e``.

    Args:
        in_channels: Number of image channels.
        patch_size: Side length of each square patch, in pixels.
        emb_size: Size of the embedding produced per patch.
        img_size: Kept for interface compatibility; not used by this module.
        embedding_type: ``'linear'`` or ``'conv'``.

    Raises:
        ValueError: if ``embedding_type`` is not one of the supported values.
    """

    def __init__(self, in_channels=1, patch_size=4, emb_size=128, img_size=28, embedding_type='linear'):
        super().__init__()
        # Fail at construction time: previously an unknown type left
        # `self.projection` undefined and only broke later inside forward().
        if embedding_type not in ('linear', 'conv'):
            raise ValueError(
                f"embedding_type must be 'linear' or 'conv', got {embedding_type!r}"
            )
        self.patch_size = patch_size
        self.embedding_type = embedding_type

        if embedding_type == 'linear':
            self.projection = nn.Sequential(
                Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size),
                nn.Linear(patch_size * patch_size * in_channels, emb_size)
            )
        else:  # 'conv'
            self.projection = nn.Sequential(
                nn.Conv2d(in_channels, emb_size, kernel_size=patch_size, stride=patch_size),
                Rearrange('b e (h) (w) -> b (h w) e')
            )

    def forward(self, x):
        """Return the sequence of patch embeddings for image batch ``x``."""
        return self.projection(x)

class Attention(nn.Module):
    """Multi-head self-attention over the token axis of a batch-first input.

    The input is passed through separate query/key/value Linear layers and
    then through ``nn.MultiheadAttention``.

    Args:
        dim: Embedding size of every token.
        n_heads: Number of attention heads.
        dropout: Dropout probability passed to ``nn.MultiheadAttention``.
    """

    def __init__(self, dim, n_heads, dropout):
        super().__init__()
        self.n_heads = n_heads
        # batch_first=True: callers pass (batch, tokens, dim). With the default
        # (batch_first=False) the first axis is treated as the sequence, so
        # attention would mix different samples instead of different tokens.
        self.att = nn.MultiheadAttention(
            embed_dim=dim, num_heads=n_heads, dropout=dropout, batch_first=True
        )
        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (batch, tokens, dim).

        Returns a tensor of the same shape as ``x``.
        """
        q, k, v = self.q(x), self.k(x), self.v(x)
        # The attention weights are never used, so skip computing them.
        attn_output, _ = self.att(q, k, v, need_weights=False)
        return attn_output

class FeedForward(nn.Sequential):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output feature size.
        hidden_dim: Feature size between the two Linear layers.
        dropout: Probability used by both Dropout layers.
    """

    def __init__(self, dim, hidden_dim, dropout = 0.1):
        # Build the stages in order, then hand them to nn.Sequential so they
        # are registered under the indices 0..4.
        stages = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        super().__init__(*stages)

class TransformerBlock(nn.Module):
    """Encoder block with two residual branches, each preceded by LayerNorm.

    Branch 1: LayerNorm -> Attention -> Dropout, added back to the input.
    Branch 2: LayerNorm -> FeedForward, added back to the result of branch 1.

    Args:
        dim: Token embedding size.
        n_heads: Number of attention heads.
        mlp_ratio: FeedForward hidden size is ``int(dim * mlp_ratio)``.
        dropout: Dropout probability used in every sub-module.
    """

    def __init__(self, dim, n_heads, mlp_ratio=1.0, dropout=0.1):
        super().__init__()
        hidden_dim = int(dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(dim)
        self.attn = Attention(dim, n_heads=n_heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.ffn = FeedForward(dim, hidden_dim, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Run both residual branches on ``x`` and return the result."""
        attended = self.attn(self.norm1(x))
        x = x + self.dropout(attended)
        mlp_out = self.ffn(self.norm2(x))
        return x + mlp_out

class ViT(pl.LightningModule):
    """Vision Transformer classifier packaged as a LightningModule.

    An image batch is turned into patch embeddings, a learnable class token
    is prepended, learnable positional embeddings are added, and the sequence
    runs through ``n_layers`` TransformerBlocks. The class-token output is
    passed through LayerNorm + Linear to produce ``num_classes`` logits.

    Training, validation and test steps log cross-entropy loss plus accuracy,
    F1, precision and recall under ``<stage>_loss``, ``<stage>_acc``,
    ``<stage>_f1``, ``<stage>_precision`` and ``<stage>_recall``.

    Args:
        ch: Number of image channels.
        img_size: Image side length in pixels.
        patch_size: Patch side length in pixels.
        emb_dim: Token embedding size.
        n_layers: Number of TransformerBlocks.
        num_classes: Number of output classes.
        dropout: Dropout probability used throughout the model.
        heads: Number of attention heads per block.
        learning_rate: Learning rate for the Adam optimizer.
        embedding_type: Patch projection type forwarded to PatchEmbedding.
        mlp_ratio: Hidden-size multiplier forwarded to each TransformerBlock.
    """

    def __init__(self, ch=1, img_size=28, patch_size=4, emb_dim=128,
                 n_layers=6, num_classes=10, dropout=0.1, heads=8,
                 learning_rate=1e-3, embedding_type='linear', mlp_ratio=1.0):
        super().__init__()
        self.save_hyperparameters()

        # Patch Embedding
        self.patch_embedding = PatchEmbedding(in_channels=ch, patch_size=patch_size,
                                              emb_size=emb_dim, img_size=img_size,
                                              embedding_type=embedding_type)

        # Learnable parameters: one position per patch plus one for the class token.
        num_patches = (img_size // patch_size) ** 2
        self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, emb_dim))
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_dim))

        # Transformer Encoder
        self.layers = nn.ModuleList(
            [TransformerBlock(emb_dim, heads, mlp_ratio, dropout=dropout) for _ in range(n_layers)]
        )

        # Classification head
        self.head = nn.Sequential(
            nn.LayerNorm(emb_dim),
            nn.Linear(emb_dim, num_classes)
        )

        self.dropout = nn.Dropout(dropout)
        self.criterion = nn.CrossEntropyLoss()

        # Metrics: a separate instance per stage so their accumulated state never mixes.
        self.train_accuracy = MulticlassAccuracy(num_classes=num_classes)
        self.test_accuracy = MulticlassAccuracy(num_classes=num_classes)
        self.val_accuracy = MulticlassAccuracy(num_classes=num_classes)
        self.train_f1 = MulticlassF1Score(num_classes=num_classes)
        self.test_f1 = MulticlassF1Score(num_classes=num_classes)
        self.val_f1 = MulticlassF1Score(num_classes=num_classes)
        self.train_precision = MulticlassPrecision(num_classes=num_classes)
        self.test_precision = MulticlassPrecision(num_classes=num_classes)
        self.val_precision = MulticlassPrecision(num_classes=num_classes)
        self.train_recall = MulticlassRecall(num_classes=num_classes)
        self.test_recall = MulticlassRecall(num_classes=num_classes)
        self.val_recall = MulticlassRecall(num_classes=num_classes)

    def forward(self, img):
        """Return class logits of shape (batch, num_classes) for image batch ``img``."""
        x = self.patch_embedding(img)
        b, n, _ = x.shape

        # Prepend one copy of the class token to every sample's patch sequence.
        cls_tokens = repeat(self.cls_token, '1 1 d -> b 1 d', b=b)
        x = torch.cat([cls_tokens, x], dim=1)

        # Add positional embedding (out-of-place, so no tensor is mutated).
        x = x + self.pos_embedding[:, :(n + 1)]

        x = self.dropout(x)

        for layer in self.layers:
            x = layer(x)

        # Classify from the class-token position only.
        return self.head(x[:, 0])

    def _shared_step(self, batch, stage):
        """Compute the loss for ``batch`` and log loss and metrics for ``stage``.

        ``stage`` is ``'train'``, ``'val'`` or ``'test'`` and selects both the
        metric objects and the logged key prefix. The Metric objects themselves
        are logged (not their per-batch values) so that the epoch value is
        computed from the state accumulated over all batches, instead of being
        a mean of per-batch scores.
        """
        x, y = batch
        y_hat = self(x)
        loss = self.criterion(y_hat, y)

        self.log(f'{stage}_loss', loss, on_step=False, on_epoch=True, prog_bar=True)

        stage_metrics = (
            ('acc', getattr(self, f'{stage}_accuracy'), True),
            ('f1', getattr(self, f'{stage}_f1'), False),
            ('precision', getattr(self, f'{stage}_precision'), False),
            ('recall', getattr(self, f'{stage}_recall'), False),
        )
        for suffix, metric, on_prog_bar in stage_metrics:
            metric.update(y_hat, y)
            self.log(f'{stage}_{suffix}', metric, on_step=False, on_epoch=True, prog_bar=on_prog_bar)

        return loss

    def training_step(self, batch, batch_idx):
        """Run one training batch; returns the loss to optimize."""
        return self._shared_step(batch, 'train')

    def validation_step(self, batch, batch_idx):
        """Run one validation batch; returns the loss."""
        return self._shared_step(batch, 'val')

    def test_step(self, batch, batch_idx):
        """Run one test batch; returns the loss."""
        return self._shared_step(batch, 'test')

    def configure_optimizers(self):
        """Return a single Adam optimizer; no learning-rate scheduler is attached."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.hparams.learning_rate)
        return [optimizer]

In [10]:
# Example usage:
# Seed the random number generators (including dataloader workers) so the run is repeatable.
pl.seed_everything(42, workers=True)

data_module = FashionMNISTDataModule(batch_size=64)
model = ViT(ch=1, img_size=28, patch_size=4, emb_dim=64, n_layers=6, num_classes=10, dropout=0.1, heads=8, learning_rate=1e-3, embedding_type='linear')

# Callbacks
metrics_callback = MetricsCallback()
checkpoint_callback = ModelCheckpoint(
    monitor='val_loss',
    dirpath='checkpoints',
    filename='vit-{epoch:02d}-{val_loss:.2f}',
    save_top_k=3,
    mode='min',
)
early_stop_callback = EarlyStopping(
    monitor='val_loss',
    min_delta=0.00,
    patience=10,
    verbose=False,
    mode='min'
)

trainer = pl.Trainer(
    max_epochs=50,
    precision="16-mixed",
    devices=1,  # a single device whether running on GPU or CPU
    accelerator='gpu' if torch.cuda.is_available() else 'cpu',
    # checkpoint_callback and early_stop_callback are configured above but
    # deliberately not registered; add them to this list to enable them.
    callbacks=[metrics_callback],
)

trainer.fit(model, datamodule=data_module)

Using 16bit Automatic Mixed Precision (AMP)
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

   | Name            | Type                | Params | Mode 
-----------------------------------------------------------------
0  | patch_embedding | PatchEmbedding      | 1.1 K  | train
1  | layers          | ModuleList          | 226 K  | train
2  | head            | Sequential          | 778    | train
3  | dropout         | Dropout             | 0      | train
4  | criterion       | CrossEntropyLoss    | 0      | train
5  | train_accuracy  | MulticlassAccuracy  | 0      | train
6  | test_accuracy   | MulticlassAccuracy  | 0      | train
7  | val_accuracy    | MulticlassAccuracy  | 0      | train
8  | train_f1        | MulticlassF1Score   | 0      | train
9  | test_f1         | MulticlassF1Score   | 0      | train
10 | val_f1          | MulticlassF1Score   | 0      | train
11 | train_

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

Process Process-317:
Process Process-307:

Detected KeyboardInterrupt, attempting graceful shutdown ...
Process Process-308:
Traceback (most recent call last):
  File "/usr/lib/python3.11/multiprocessing/process.py", line 317, in _bootstrap
    util._exit_function()
  File "/usr/lib/python3.11/multiprocessing/util.py", line 360, in _exit_function
    _run_finalizers()
  File "/usr/lib/python3.11/multiprocessing/util.py", line 300, in _run_finalizers
    finalizer()
  File "/usr/lib/python3.11/multiprocessing/util.py", line 224, in __call__
    res = self._callback(*self._args, **self._kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/usr/lib/python3.11/multiprocessing/util.py", line 133, in _remove_temp_dir
    rmtree(tempdir)
  File "/usr/lib/python3.11/shutil.py", line 715, in rmtree
    if isinstance(path, bytes):
       ^^^^^^^^^^^^^^^^^^^^^^^
KeyboardInterrupt


NameError: name 'exit' is not defined

In [None]:
# Evaluate the model on the test loader provided by the data module.
trainer.test(model, datamodule=data_module)

In [None]:
# Plot the per-epoch curves that the metrics callback collected during training.
recorded_metrics = metrics_callback.get_metrics()
model_monitor(**recorded_metrics)