# ResNet50 with Pytorch Lightning

## Data

In [1]:
import torch
import pytorch_lightning as pl
from torchvision.datasets import CIFAR10
from torchvision import transforms
import torch.utils.data as data


DATASET_PATH = "../data/"

def get_data_loader(args):
    train_dataset = CIFAR10(root=DATASET_PATH, train=True, download=True)
    DATA_MEANS = (train_dataset.data / 255.0).mean(axis=(0, 1, 2))
    DATA_STD = (train_dataset.data / 255.0).std(axis=(0, 1, 2))
    print("Data mean", DATA_MEANS)
    print("Data std", DATA_STD)

    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
    ])

    train_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=train_transform, download=True)
    val_dataset = CIFAR10(root=DATASET_PATH, train=True, transform=test_transform, download=True)
    train_set, _ = torch.utils.data.random_split(train_dataset, [45000, 5000])
    _, val_set = torch.utils.data.random_split(val_dataset, [45000, 5000])

    test_set = CIFAR10(root=DATASET_PATH, train=False, transform=test_transform, download=True)

    train_loader = data.DataLoader(train_set, batch_size=args.batch_size, shuffle=True, drop_last=False, pin_memory=True, num_workers=8,persistent_workers=True)
    val_loader = data.DataLoader(val_set, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=8)
    test_loader = data.DataLoader(test_set, batch_size=args.batch_size, shuffle=False, drop_last=False, num_workers=8,persistent_workers=True)

    return train_loader, val_loader, test_loader

## Model

In [None]:
import torch
import torch.nn as nn
from torchvision.models import resnet50
import torch.optim as optim
from transformers import get_linear_schedule_with_warmup


class ResNet50(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = resnet50(pretrained=True, progress=True)
        self.classifier = nn.Sequential(
            nn.Linear(1000, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )
        for param in self.model.parameters():
            param.requires_grad = False
    
    def forward(self, imgs):
        return self.classifier(self.model(imgs))


class CIFARModule(pl.LightningModule):
    def __init__(self, args) -> None:
        super().__init__()
        self.save_hyperparameters()
        self.args = args
        self.model = ResNet50()
        self.loss = nn.CrossEntropyLoss()
        self.example_input_array = torch.zeros((1, 3, 32, 32), dtype=torch.float32)

    def forward(self, imgs):
        return self.model(imgs)

    def configure_optimizers(self):
        optimizer = None
        scheduler = None
        if self.args.optimizer_name == "Adamw":
            optimizer = optim.AdamW(self.parameters(), lr=self.args.lr)
        elif self.args.optimizer_name == "SGD":
            optimizer = optim.SGD(self.parameters(), lr=self.args.lr, momentum=0.9)
        
        if self.args.scheduler_name == "lr_schedule":
            scheduler = get_linear_schedule_with_warmup(
                optimizer=optimizer, num_warmup_steps=self.args.warmup_step,
                num_training_steps=self.args.total_steps)

        if optimizer and scheduler:
            return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
        elif optimizer:
            return [optimizer]

    # Training Step is called for each batch
    def training_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = self.loss(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        self.log("acc/train", acc, on_step=True)
        self.log("loss/tain", loss, on_step=True)
        return loss

    # Validation Step is called for each batch
    def validation_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = self.loss(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        self.log("acc/val", acc, on_step=True)
        self.log("loss/val", loss, on_step=True)

    def test_step(self, batch, batch_idx):
        imgs, labels = batch
        preds = self.model(imgs)
        loss = self.loss(preds, labels)
        acc = (preds.argmax(dim=-1) == labels).float().mean()
        self.log("acc/test", acc, on_step=True)
        self.log("loss/test", loss, on_step=True)


In [6]:
from datetime import datetime
from dataclasses import dataclass
from pytorch_lightning.callbacks import LearningRateMonitor
pl.seed_everything(42)
dt = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
save_dir = f"lightning_logs/{dt}"

@dataclass
class Args:
    batch_size: int = 64
    lr: float = 1e-3
    optimizer_name: str = "Adamw"
    scheduler_name: str = "lr_schedule"
    warmup_step: int = 500
    total_steps: int = 10000

args = Args()
train_loader, val_loader, test_loader = get_data_loader(args)
model = CIFARModule(args)
lr_monitor = LearningRateMonitor(logging_interval="step")
trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[lr_monitor]
)

Seed set to 42


Files already downloaded and verified
Data mean [0.49139968 0.48215841 0.44653091]
Data std [0.24703223 0.24348513 0.26158784]
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [7]:
trainer.fit(model, train_loader, val_loader)

You are using a CUDA device ('NVIDIA GeForce RTX 4060 Laptop GPU') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type             | Params | Mode  | In sizes       | Out sizes
--------------------------------------------------------------------------------
0 | model | ResNet50         | 26.1 M | train | [1, 3, 32, 32] | [1, 10]  
1 | loss  | CrossEntropyLoss | 0      | train | ?              | ?        
--------------------------------------------------------------------------------
517 K     Trainable params
25.6 M    Non-trainable params
26.1 M    Total params
104.299   Total estimated model params size (MB)
157       Modules in train mode
0         Modules in eval mode


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

c:\Users\runze\.conda\envs\llm\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'val_dataloader' to speed up the dataloader worker initialization.
c:\Users\runze\.conda\envs\llm\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'train_dataloader' to speed up the dataloader worker initialization.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=1` reached.


In [8]:
trainer.test(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
c:\Users\runze\.conda\envs\llm\Lib\site-packages\pytorch_lightning\trainer\connectors\data_connector.py:419: Consider setting `persistent_workers=True` in 'test_dataloader' to speed up the dataloader worker initialization.


Testing: |          | 0/? [00:00<?, ?it/s]

────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     test_acc_epoch         0.7652000188827515
     test_loss_epoch        0.6833807826042175
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'test_acc_epoch': 0.7652000188827515, 'test_loss_epoch': 0.6833807826042175}]