# CIFAR-10 classification

## Setup

In [1]:
# !pip install -r requirements.txt

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import os
import math

import hydra
import torch
import torchvision
import wandb

import pandas as pd
import seaborn as sn
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms

from dataclasses import dataclass
from omegaconf import DictConfig, OmegaConf
from pytorch_lightning import LightningModule, Trainer, seed_everything
from pytorch_lightning.callbacks import LearningRateMonitor, ModelCheckpoint
from pytorch_lightning.callbacks.progress import TQDMProgressBar
from pytorch_lightning.loggers import CSVLogger, WandbLogger
from pytorch_lightning.utilities.model_summary import ModelSummary

from torch.utils.data import random_split
from torch.optim.lr_scheduler import OneCycleLR, _LRScheduler, ExponentialLR, MultiStepLR, StepLR
from torch.optim.swa_utils import AveragedModel, update_bn
from torchmetrics.functional import accuracy

In [4]:
# Experiment configuration
@dataclass
class Config:
    """Run-wide settings shared by the data pipeline and the trainer."""

    seed: int = 69  # seed passed to `seed_everything`
    # Use a larger batch only when a CUDA device is available.
    batch_size: int = 256 if torch.cuda.is_available() else 64
    n_workers: int = 8  # DataLoader worker processes

    # Annotated so that these are real dataclass fields (settable through the
    # constructor and included in repr/eq) rather than bare class attributes.
    n_epochs: int = 30  # passed to the Trainer as `max_epochs`
    val_size: int = 5000  # samples held out of the training set for validation


# Instantiate the config instead of aliasing the class itself.
config = Config()
# Seed the global RNGs so that runs are repeatable.
seed_everything(config.seed)
# 'medium' permits reduced internal precision for float32 matmuls in exchange for speed.
torch.set_float32_matmul_precision('medium')

Global seed set to 69


## Data preparation

In [5]:
class DatasetFromSubset(torch.utils.data.Dataset):
    """Wrap an indexable collection of ``(x, y)`` pairs and transform ``x`` on access.

    Lets each part of a split dataset carry its own transform, even though
    all parts share one underlying untransformed dataset.

    Args:
        subset: object supporting ``subset[index] -> (x, y)`` and ``len(subset)``.
        transform: optional callable applied to ``x``; ``None`` disables it.
    """

    def __init__(self, subset, transform=None):
        self.subset = subset
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(transform(x), y)`` for the item at ``index``."""
        x, y = self.subset[index]
        # Compare against None explicitly: a valid callable may still be falsy
        # (e.g. a container-like transform whose __len__ is 0) and must not
        # be skipped silently.
        if self.transform is not None:
            x = self.transform(x)
        return x, y

    def __len__(self):
        """Number of items in the wrapped subset."""
        return len(self.subset)
    
# Human-readable class names; presumably ordered by CIFAR-10 label index — TODO confirm.
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

In [6]:
# Full CIFAR-10 training set, left untransformed; it is split into train/val later.
dataset = torchvision.datasets.CIFAR10(
    root='./data',
    train=True,
    download=True,
    transform=None,
)

# Per-channel normalization constants: statistics over every image and pixel
# (axes 0, 1, 2), rescaled from the 0-255 range to 0-1.
mean_values = dataset.data.mean(axis=(0, 1, 2)) / 255
std_values = dataset.data.std(axis=(0, 1, 2)) / 255

# Show the raw array shape followed by the two statistics, one per line.
print(dataset.data.shape, mean_values, std_values, sep='\n')

Files already downloaded and verified
(50000, 32, 32, 3)
[0.49139968 0.48215841 0.44653091]
[0.24703223 0.24348513 0.26158784]


In [7]:
## Transforms

# Training pipeline: tensor conversion, random crop (32x32 after 4px padding),
# random horizontal flip, then per-channel normalization.
train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.Normalize(mean_values, std_values),
])

# Validation pipeline: no augmentation, only tensor conversion and normalization.
val_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean_values, std_values),
])

# Test pipeline: same steps as validation, kept as its own object.
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean_values, std_values),
])

In [8]:
## Datasets

# Split with a dedicated, explicitly seeded generator so the train/val
# partition is the same every time this cell runs, regardless of how much of
# the global RNG stream earlier (or re-executed) cells have consumed.
split_generator = torch.Generator().manual_seed(config.seed)
trainset, valset = random_split(
    dataset,
    [len(dataset) - config.val_size, config.val_size],
    generator=split_generator,
)

# Both parts share the untransformed base dataset; each gets its own transform.
trainset = DatasetFromSubset(trainset, transform=train_transform)

valset = DatasetFromSubset(valset, transform=val_transform)

# The official test split applies its transform directly.
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)

Files already downloaded and verified


In [15]:
# Sanity check: sizes of the train / validation / test sets.
len(trainset), len(valset), len(testset)

(45000, 5000, 10000)

In [9]:
## DataLoaders

# Only the training loader shuffles; validation and test iterate in a fixed order.
trainloader = torch.utils.data.DataLoader(
    trainset,
    batch_size=config.batch_size,
    shuffle=True,
    num_workers=config.n_workers,
    pin_memory=True,
)

valloader = torch.utils.data.DataLoader(
    valset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.n_workers,
    pin_memory=True,
)

testloader = torch.utils.data.DataLoader(
    testset,
    batch_size=config.batch_size,
    shuffle=False,
    num_workers=config.n_workers,
    pin_memory=True,
)

## Modelling

In [10]:
def create_model():
    """Build a randomly initialised ResNet-18 with a stem adapted to 32x32 images.

    The first convolution is replaced by a 3x3, stride-1, padding-1 layer and
    the max-pool by an identity, so neither stem layer reduces the spatial
    resolution of the input.

    Returns:
        The modified ``torchvision`` ResNet-18 with 10 output classes.
    """
    # `weights=None` is the current spelling of "no pretrained weights";
    # the boolean `pretrained=` argument is deprecated in torchvision.
    model = torchvision.models.resnet18(weights=None, num_classes=10)
    model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    model.maxpool = nn.Identity()

    return model

### Lightning Module

In [11]:
class LitModule(LightningModule):
    """Lightning wrapper that trains and evaluates the model from ``create_model``.

    Args:
        config: run settings object; stored on the module and recorded through
            ``save_hyperparameters``.
    """

    def __init__(self, config):
        super().__init__()

        self.config = config
        self.save_hyperparameters()
        self.model = create_model()
        # Dummy batch that lets Lightning report input/output sizes in the summary.
        self.example_input_array = torch.zeros(2, 3, 32, 32)

    def forward(self, x):
        """Return per-class log-probabilities for a batch of images."""
        out = self.model(x)
        return F.log_softmax(out, dim=1)

    def training_step(self, batch, batch_idx):
        """Compute and log the negative log-likelihood loss for one training batch."""
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        self.log("train_loss", loss)
        return loss

    def evaluate(self, batch, stage=None):
        """Compute loss and accuracy for a batch; log them as ``{stage}_loss`` / ``{stage}_acc``.

        Nothing is logged when ``stage`` is falsy.
        """
        x, y = batch
        log_probs = self(x)
        loss = F.nll_loss(log_probs, y)
        preds = torch.argmax(log_probs, dim=1)
        acc = accuracy(preds, y, task='multiclass', num_classes=10)

        if stage:
            self.log(f"{stage}_loss", loss, prog_bar=True)
            self.log(f"{stage}_acc", acc, prog_bar=True)

    def validation_step(self, batch, batch_idx):
        self.evaluate(batch, "val")

    def test_step(self, batch, batch_idx):
        self.evaluate(batch, "test")

    def configure_optimizers(self):
        """Create an Adam optimizer driven by a per-step one-cycle LR schedule."""
        # OneCycleLR overrides the optimizer's learning rate: it starts at
        # max_lr / div_factor, so the `lr` given here only acts as a placeholder.
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=1e-3,
            weight_decay=5e-4,
        )

        # Take the schedule length from the trainer instead of a hard-coded
        # dataset size and the notebook-global config, so it stays correct
        # for any dataloader, batch size or epoch count.
        total_steps = self.trainer.estimated_stepping_batches
        scheduler_dict = {
            "scheduler": OneCycleLR(
                optimizer,
                max_lr=0.01,
                total_steps=total_steps,
                pct_start=0.2,
                div_factor=25,
                final_div_factor=5e4,
            ),
            "interval": "step",  # advance the schedule after every optimizer step
            "name": "lr",
        }

        return {"optimizer": optimizer, "lr_scheduler": scheduler_dict}

In [12]:
# Authenticate with Weights & Biases, then open the run that will receive
# this experiment's logs.
wandb.login()
wandb.init(project="img_classification_cifar10", name="resnet18_cyclelr")

[34m[1mwandb[0m: Currently logged in as: [33mwhatislove[0m. Use [1m`wandb login --relogin`[0m to force relogin


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.016668401450018185, max=1.0…

In [13]:
# Build the Lightning module from the notebook-wide config.
model = LitModule(config)

# Uncomment to print a per-layer summary of the model.
# ModelSummary(model, -1)



In [14]:
trainer = Trainer(
    max_epochs=config.n_epochs,
    accelerator="auto",
    devices=1,
    # Log to local CSV files and to the active W&B run.
    logger=[CSVLogger(save_dir="logs/"), WandbLogger()],
    callbacks=[
        LearningRateMonitor(logging_interval="step"),
        TQDMProgressBar(refresh_rate=1),
        # Keep only the single checkpoint with the best monitored `val_loss`.
        ModelCheckpoint(dirpath='checkpoints/' + wandb.run.name, save_top_k=1,
                        filename='best', monitor="val_loss"),
    ],
    # profiler="simple",
)

# Close the W&B run even when training or testing fails or is interrupted,
# so the run is not left open and its data gets a chance to sync.
try:
    trainer.fit(model,
                train_dataloaders=trainloader,
                val_dataloaders=valloader)

    trainer.test(model, dataloaders=testloader)
finally:
    wandb.finish()

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params | In sizes       | Out sizes
--------------------------------------------------------------
0 | model | ResNet | 11.2 M | [2, 3, 32, 32] | [2, 10]  
--------------------------------------------------------------
11.2 M    Trainable params
0         Non-trainable params
11.2 M    Total params
44.696    Total estimated model params size (MB)


Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Testing: 0it [00:00, ?it/s]

[34m[1mwandb[0m: [32m[41mERROR[0m Control-C detected -- Run data was not synced
