1. Develop an image classification model based on the transformer architecture without relying on pre-implemented transformer or self-attention modules such as `torch.nn.Transformer` or `torch.nn.MultiheadAttention`.

In [1]:
from modules.config import ViTConfig, TrainingConfig, DataConfig
from modules.ViT import VisionTransformer
import lightning as L
import torch
import torch.nn as nn
from torchvision import datasets
import torchvision.transforms as transforms
from torchvision.models import resnet152  # For comparison
from torch.utils.data import DataLoader, Subset
from torchmetrics import AUROC
from torchmetrics.classification import MulticlassAccuracy, MulticlassROC, MulticlassF1Score
from lightning.pytorch.callbacks import Timer  # GPUStatsMonitor

import numpy as np

from dataclasses import asdict
import time

In [2]:
# Prefer a CUDA GPU when one is present; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [3]:
class ClassificationLightningModel(L.LightningModule):
    """Lightning wrapper that trains and evaluates an image classifier.

    Wraps an arbitrary ``nn.Module`` and adds cross-entropy training with Adam,
    per-stage accuracy, and test-time metrics (accuracy, F1, AUROC, per-batch
    inference latency and parameter count).

    Args:
        model: Module mapping an input batch to class logits. Assumed to return
            ``(batch, num_classes)`` logits — torchmetrics raises otherwise.
        num_classes: Number of target classes used to size every metric.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        lr_scheduler: Scheduler class (or factory) called with the optimizer.
            ``CosineAnnealingLR`` (and subclasses) additionally receive
            ``T_max=trainer.max_epochs`` because they cannot be built from the
            optimizer alone. Pass ``None`` to train without a scheduler.
    """

    def __init__(
        self,
        model,
        num_classes,
        lr,
        weight_decay=0.0,
        lr_scheduler=torch.optim.lr_scheduler.CosineAnnealingLR,
    ):
        super().__init__()
        self.model = model

        self.lr_scheduler = lr_scheduler
        self.lr = lr
        self.weight_decay = weight_decay

        # One accuracy metric per stage so their accumulated state never mixes.
        self.train_acc = MulticlassAccuracy(num_classes=num_classes, average="micro")
        self.val_acc = MulticlassAccuracy(num_classes=num_classes, average="micro")
        self.test_acc = MulticlassAccuracy(num_classes=num_classes, average="micro")

        # Additional test-time classification metrics.
        # NOTE(review): test_roc is kept for compatibility but is never updated here.
        self.test_roc = MulticlassROC(num_classes=num_classes)
        self.test_f1 = MulticlassF1Score(num_classes=num_classes)
        self.test_auroc = AUROC(task="multiclass", num_classes=num_classes)

        # Wall-clock seconds of each test-batch forward pass; cleared at the
        # start of every test epoch so repeated trainer.test() calls don't mix.
        self.inference_times = []

        # Record constructor arguments (except the heavy model) for the logger.
        self.save_hyperparameters(ignore=["model"])

    def forward(self, x):
        """Return the wrapped model's logits for ``x``."""
        return self.model(x)

    def training_step(self, batch, batch_idx):
        """Compute cross-entropy loss on one batch and log loss/accuracy."""
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.train_acc(logits, y)

        self.log("train_loss", loss, prog_bar=True)
        self.log("train_acc", self.train_acc, prog_bar=True)
        return loss

    def validation_step(self, batch, batch_idx):
        """Log validation loss/accuracy for one batch."""
        x, y = batch
        logits = self(x)
        loss = nn.functional.cross_entropy(logits, y)
        self.val_acc(logits, y)

        self.log("val_loss", loss, prog_bar=True)
        self.log("val_acc", self.val_acc, prog_bar=True)

    def _synchronize(self):
        """Block until queued accelerator work finishes so timings are real.

        CUDA and MPS execute asynchronously; without a sync the timer would
        only measure kernel launch overhead, not the forward pass itself.
        """
        if self.device.type == "cuda":
            torch.cuda.synchronize(self.device)
        elif self.device.type == "mps":
            torch.mps.synchronize()

    def on_test_epoch_start(self):
        """Discard latency samples from any previous test run."""
        self.inference_times = []

    def test_step(self, batch, batch_idx):
        """Time the forward pass and update test loss, accuracy, F1 and AUROC."""
        x, y = batch

        # Measure per-batch inference latency with a monotonic clock, bracketing
        # the forward pass with device syncs.
        self._synchronize()
        start_time = time.perf_counter()
        logits = self(x)
        self._synchronize()
        self.inference_times.append(time.perf_counter() - start_time)

        loss = nn.functional.cross_entropy(logits, y)

        self.test_acc(logits, y)
        self.test_f1(logits, y)
        self.test_auroc(logits, y)

        self.log("test_loss", loss)
        self.log("test_acc", self.test_acc)
        self.log("test_f1", self.test_f1)
        self.log("test_auroc", self.test_auroc)

        return {"loss": loss, "preds": logits, "targets": y}

    def on_test_epoch_end(self):
        """Log mean per-batch inference time and the wrapped model's parameter count."""
        if self.inference_times:  # empty when the test loader yielded no batches
            avg_inference_time = sum(self.inference_times) / len(self.inference_times)
            self.log("avg_inference_time", avg_inference_time)

        model_size = sum(p.numel() for p in self.model.parameters())
        self.log("model_size", model_size)

    def _build_scheduler(self, optimizer):
        """Instantiate ``self.lr_scheduler`` for ``optimizer``."""
        scheduler_cls = self.lr_scheduler
        if isinstance(scheduler_cls, type) and issubclass(
            scheduler_cls, torch.optim.lr_scheduler.CosineAnnealingLR
        ):
            # CosineAnnealingLR has a required T_max; anneal over the whole run
            # (the scheduler steps once per epoch, see configure_optimizers).
            max_epochs = self.trainer.max_epochs
            t_max = max_epochs if max_epochs and max_epochs > 0 else 1
            return scheduler_cls(optimizer, T_max=t_max)
        return scheduler_cls(optimizer)

    def configure_optimizers(self):
        """Create Adam and, unless disabled, an epoch-stepped LR scheduler."""
        optimizer = torch.optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        if self.lr_scheduler is None:
            return optimizer

        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": self._build_scheduler(optimizer),
                "interval": "epoch",
                "frequency": 1,
            },
        }


# Setup training with metrics monitoring
def train_and_evaluate_model(model, num_classes, train_loader, val_loader, test_loader, **kwargs):
    """Train ``model`` with Lightning, evaluate it on ``test_loader`` and summarize.

    Recognized keyword arguments (any other keys in ``kwargs`` are ignored):
        epochs: maximum number of training epochs (default 20).
        lr: Adam learning rate (default 1e-3).
        lr_scheduler: scheduler class/factory (default CosineAnnealingLR).
        weight_decay: Adam weight decay (default 0.0).

    Returns:
        dict with ``test_accuracy``, ``training_time`` (as reported by the
        Lightning ``Timer`` for the train stage), ``model_size`` (parameter
        count), ``avg_inference_time`` (mean seconds per test batch),
        ``test_f1`` and ``test_auroc``.
    """
    epochs = kwargs.get("epochs", 20)
    lr = kwargs.get("lr", 1e-3)
    lr_scheduler = kwargs.get("lr_scheduler", torch.optim.lr_scheduler.CosineAnnealingLR)
    weight_decay = kwargs.get("weight_decay", 0.0)

    lit_model = ClassificationLightningModel(
        model,
        num_classes=num_classes,
        lr=lr,
        lr_scheduler=lr_scheduler,
        weight_decay=weight_decay,
    )

    # Timer records wall-clock time per stage; queried after fit below.
    timer = Timer()

    trainer = L.Trainer(
        max_epochs=epochs,
        accelerator="auto",
        devices=1,
        callbacks=[timer],
        enable_progress_bar=True,
        enable_model_summary=True,
    )

    trainer.fit(lit_model, train_loader, val_loader)
    test_results = trainer.test(lit_model, test_loader)

    # Read the epoch-level values Lightning aggregated during testing. The
    # metric objects themselves are reset by Lightning at epoch end, so
    # calling .compute() on them here would not return the test scores.
    # trainer.test returns one dict per test dataloader; only one is passed.
    logged = test_results[0] if test_results else {}
    nan = float("nan")

    inference_times = lit_model.inference_times
    avg_inference_time = sum(inference_times) / len(inference_times) if inference_times else nan

    return {
        "test_accuracy": float(logged.get("test_acc", nan)),
        "training_time": timer.time_elapsed("train"),
        "model_size": sum(p.numel() for p in model.parameters()),
        "avg_inference_time": avg_inference_time,
        "test_f1": float(logged.get("test_f1", nan)),
        "test_auroc": float(logged.get("test_auroc", nan)),
    }

In [4]:
# Prepare data
# Project data settings; later cells read img_size, batch_size, num_workers and num_classes from it.
data_config = DataConfig.base()

In [5]:
# Shared normalization for every split: maps ToTensor's [0, 1] range to [-1, 1].
normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])

# Training pipeline with light augmentation. ColorJitter() with no arguments
# is an identity transform (every strength defaults to 0), so explicit
# strengths are required for it to have any effect.
train_transform = transforms.Compose(
    [
        transforms.Resize(data_config.img_size),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        normalize,
    ]
)

# Deterministic pipeline for validation and test data: resize + normalize only.
val_transform = transforms.Compose(
    [
        transforms.Resize(data_config.img_size),
        transforms.ToTensor(),
        normalize,
    ]
)

In [6]:
# Two views of the official CIFAR-10 training set: one augmented (for
# training) and one deterministic (for the held-out validation split).
trainset = datasets.CIFAR10(root="./data", train=True, download=True, transform=train_transform)
valset = datasets.CIFAR10(root="./data", train=True, download=True, transform=val_transform)

# 80/20 train/validation split of the training set.
train_size = int(0.8 * len(trainset))
val_size = len(trainset) - train_size

# Seeded permutation so the split is reproducible across notebook runs.
split_seed = 0
indices = np.random.default_rng(split_seed).permutation(len(trainset)).tolist()
train_indices = indices[:train_size]
val_indices = indices[train_size:]

# Disjoint subsets: training never sees validation images.
train_data = Subset(trainset, train_indices)
val_data = Subset(valset, val_indices)

testset = datasets.CIFAR10(root="./data", train=False, download=True, transform=val_transform)

classes = ("plane", "car", "bird", "cat", "deer", "dog", "frog", "horse", "ship", "truck")

# Loaders must iterate the split subsets, not the full datasets; otherwise
# validation runs on the training images and its metrics are meaningless.
train_loader = DataLoader(
    train_data,
    batch_size=data_config.batch_size,
    shuffle=True,
    num_workers=data_config.num_workers,
)
val_loader = DataLoader(
    val_data, batch_size=data_config.batch_size, shuffle=False, num_workers=data_config.num_workers
)
test_loader = DataLoader(
    testset, batch_size=data_config.batch_size, shuffle=False, num_workers=data_config.num_workers
)

# Backward-compatible aliases for the earlier loader names.
trainloader = train_loader
testloader = test_loader

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [7]:
# (row label shown in the table, key in the dict from train_and_evaluate_model)
_COMPARISON_ROWS = (
    ("Test Accuracy", "test_accuracy"),
    ("Training Time (s)", "training_time"),
    ("Model Size", "model_size"),
    ("Avg Inference Time (s)", "avg_inference_time"),
    ("F1 Score", "test_f1"),
    ("AUROC", "test_auroc"),
)


def build_comparison_table(vit_metrics, cnn_metrics):
    """Arrange two metric dicts into a column-oriented ``{header: [values]}`` table.

    Raises:
        KeyError: if either dict lacks one of the keys in ``_COMPARISON_ROWS``.
    """
    return {
        "Metric": [label for label, _ in _COMPARISON_ROWS],
        "ViT": [vit_metrics[key] for _, key in _COMPARISON_ROWS],
        "CNN": [cnn_metrics[key] for _, key in _COMPARISON_ROWS],
    }


def compare_models(
    vit_model,
    cnn_model,
    num_classes,
    train_loader,
    val_loader,
    test_loader,
    vit_train_config: dict[str, object],
    cnn_train_config: dict[str, object],
    **kwargs,
):
    """Train and test a ViT and a CNN on the same data, then print a comparison.

    Args:
        vit_model, cnn_model: classifiers to evaluate.
        num_classes: number of target classes for the metrics.
        train_loader, val_loader, test_loader: data shared by both runs.
        vit_train_config, cnn_train_config: keyword arguments forwarded to
            ``train_and_evaluate_model`` (e.g. epochs, lr, weight_decay).
        **kwargs: accepted for interface compatibility; currently unused.

    Returns:
        The comparison table (``{"Metric": [...], "ViT": [...], "CNN": [...]}``).
    """
    print("Evaluating ViT Model...")
    vit_metrics = train_and_evaluate_model(
        vit_model, num_classes, train_loader, val_loader, test_loader, **vit_train_config
    )

    print("Evaluating CNN Model...")
    cnn_metrics = train_and_evaluate_model(
        cnn_model, num_classes, train_loader, val_loader, test_loader, **cnn_train_config
    )

    comparison = build_comparison_table(vit_metrics, cnn_metrics)

    # Local import: tabulate is only needed for this pretty-print.
    from tabulate import tabulate

    print("\nModel Comparison:")
    print(tabulate(comparison, headers="keys", tablefmt="grid"))
    return comparison

In [8]:
vit_config = ViTConfig.base()
vit_model = VisionTransformer(**asdict(vit_config))

# torchvision's resnet152() defaults to a 1000-way ImageNet head; size it to
# the dataset so its logits match the num_classes the metrics are built with.
cnn_model = resnet152(num_classes=data_config.num_classes)

comparison = compare_models(
    vit_model,
    cnn_model,
    data_config.num_classes,
    train_loader,
    val_loader,
    test_loader,
    vit_train_config=asdict(TrainingConfig.base()),
    cnn_train_config=asdict(TrainingConfig.resnet152()),
)

GPU available: True (mps), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/Users/juyoungshin/Documents/code_repo/vit-assignment/.conda/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/logger_connector/logger_connector.py:75: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `lightning.pytorch` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default

  | Name       | Type               | Params | Mode 
----------------------------------------------------------
0 | model      | VisionTransformer  | 85.8 M | train
1 | train_acc  | MulticlassAccuracy | 0      | train
2 | val_acc    | MulticlassAccuracy | 0      | train
3 | test_acc   | MulticlassAccuracy | 0

Evaluating ViT Model...
Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/Users/juyoungshin/Documents/code_repo/vit-assignment/.conda/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:   0%|          | 0/2 [00:00<?, ?it/s]


Detected KeyboardInterrupt, attempting graceful shutdown ...


NameError: name 'exit' is not defined