Image Classification

In [None]:
%matplotlib inline

import os
import sys
from pathlib import Path
from functools import partial

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torchvision.utils import make_grid

from models.resnet import ResNet, ResNetConfig, resnet18, resnet34, resnet50, resnet101, resnet152
from models.efficientnet import EfficientNet, EfficientNetConfig, efficientnet_b0, efficientnetv2_s

# Seed PyTorch's global RNG (weight init, DataLoader shuffling, random transforms)
# so runs are more repeatable; CUDA kernels may still be nondeterministic.
seed = 0
torch.manual_seed(seed)

Choose dataset.

In [None]:
# available datasets: 'cifar10', 'cifar100', 'mnist', 'emnist'
_dataset = 'cifar10'

# Each entry is a dataset constructor called later as
# dataset(root=..., train=..., transform=..., download=...).
_dataset_factories = {
    'cifar10': datasets.CIFAR10,
    'cifar100': datasets.CIFAR100,
    'mnist': datasets.MNIST,
    # EMNIST requires a split; this notebook uses the 'letters' split.
    'emnist': partial(datasets.EMNIST, split='letters'),
}

if _dataset not in _dataset_factories:
    raise ValueError('dataset only supports {cifar10|cifar100|mnist|emnist}')
dataset = _dataset_factories[_dataset]

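Compute the per-channel mean and standard deviation of the training set (used to normalize the inputs), and preview a few training images.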

In [None]:
batch_size = 128
num_workers = 8

# Plain ToTensor (no augmentation / normalization), so the statistics below
# describe the same [0, 1] pixel values that Normalize will later receive.
trainset = dataset(root='./data', train=True, transform=transforms.ToTensor(), download=True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=False)

classes = trainset.classes
num_classes = len(classes)

# Keep the first batch under its own names: it supplies the image geometry and the
# preview at the end of the cell, and is not overwritten by the statistics loop.
preview_images, preview_labels = next(iter(trainloader))

# Batches are (N, C, H, W); `size` is H, later used as the RandomCrop side length.
size = preview_images.size(dim=2)
in_channels = preview_images.size(dim=1)

# Per-channel running sums of x and x^2. Each batch is converted to float64 before
# summing so rounding error stays negligible across the whole training set.
s = torch.zeros(in_channels, dtype=torch.float64)
ss = torch.zeros(in_channels, dtype=torch.float64)
total = 0
for images, labels in trainloader:
    images = images.double()
    s += torch.sum(images, dim=(0, 2, 3))
    ss += torch.sum(torch.square(images), dim=(0, 2, 3))
    total += images.size(dim=0) * images.size(dim=2) * images.size(dim=3)

# std = sqrt(E[x^2] - E[x]^2); the clamp stops a tiny negative value caused by
# floating-point cancellation from turning into NaN.
mean = torch.div(s, total)
std = torch.sqrt(torch.clamp(torch.sub(torch.div(ss, total), torch.square(mean)), min=0.0))

print(mean)
print(std)

# Preview up to four images of the first batch together with their class names.
num_preview = min(4, preview_images.size(dim=0))
fig, ax = plt.subplots()
ax.imshow(torch.permute(make_grid(preview_images[:num_preview]), dims=(1, 2, 0)).numpy())
ax.axis('off')
plt.tight_layout()
print(' '.join(f'{classes[preview_labels[j]]}' for j in range(num_preview)))

Choose the batch size and number of workers, then build the training (augmented) and test dataloaders.

In [None]:
batch_size = 128
num_workers = 8

# Random crops (with reflect padding) are used for every dataset. Horizontal flips
# preserve the label for natural images (CIFAR), but a mirrored handwritten digit or
# letter is a different or invalid glyph, so flips are enabled only for CIFAR.
augmentations = [transforms.RandomCrop(size, padding=4, padding_mode='reflect')]
if _dataset in ('cifar10', 'cifar100'):
    augmentations.append(transforms.RandomHorizontalFlip())
augment = transforms.Compose(augmentations)

# Convert to tensors and standardize with the training-set statistics computed above.
normalize = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean, std)
])

# Training data: augmentation followed by normalization, shuffled every epoch.
trainset = dataset(root='./data', train=True, transform=transforms.Compose([augment, normalize]), download=True)
trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=num_workers, drop_last=False)

# Test data: normalization only, in a fixed order.
testset = dataset(root='./data', train=False, transform=normalize, download=True)
testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=num_workers, drop_last=False)

Choose device, model.

In [None]:
# Prefer CUDA, but fall back to the CPU so the notebook still runs on machines
# without a GPU (a hard-coded 'cuda' raises there).
_device = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(_device)

# Input channels and class count come from the dataset inspected earlier.
model = resnet18(in_channels=in_channels, num_classes=num_classes, dropout=0.2)
model.to(device)

print(model)

Choose the number of epochs, the logging interval in batches (`steps`), the learning rate, weight decay, optimizer, and learning-rate scheduler.

In [None]:
# Training length, and how often (in batches) progress is printed.
epochs, steps = 500, 50

# AdamW hyperparameters.
lr, weight_decay = 1e-3, 3e-4

# Cosine schedule with warm restarts: the first cycle lasts T_0 units and each later
# cycle is T_mult times longer. The unit is whatever the training loop feeds to
# `scheduler.step()`.
T_0, T_mult = 1, 2

criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(params=model.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer=optimizer, T_0=T_0, T_mult=T_mult)

Make a checkpoint directory. Set file paths.

In [None]:
checkpoint_dir = './checkpoints'

# The model checkpoint and the training log live side by side in the same directory.
checkpoint, log = (os.path.join(checkpoint_dir, filename) for filename in ('model.pt', 'model.log'))

# Create the directory (and any missing parents); no error if it already exists.
Path(checkpoint_dir).mkdir(exist_ok=True, parents=True)

Optionally resume from the saved checkpoint (set `load = True`).

In [None]:
load = False

# After this cell, `load_epoch` is the 1-based number of the next epoch to train
# (the training loop iterates over range(load_epoch - 1, epochs)).
if load:
    # The checkpoint is the best-test-accuracy state written by the training cell.
    # map_location puts its tensors on `device`, so a checkpoint saved on a GPU can
    # also be restored on a CPU-only machine.
    state = torch.load(checkpoint, map_location=device)
    model.load_state_dict(state['model'])
    optimizer.load_state_dict(state['optimizer'])
    scheduler.load_state_dict(state['scheduler'])

    # 'epoch' is the 1-based number of the last *completed* epoch, so resume at the
    # following one instead of repeating it.
    load_epoch = state['epoch'] + 1
    best_accuracy = state['best_accuracy']
else:
    load_epoch = 1
    best_accuracy = 0.0

Optionally redirect stdout to the log file (set `redirect = False` to print in the notebook instead).

In [None]:
redirect = True

if redirect:
    # buffering=1 makes the file line-buffered: each printed line reaches the log right
    # away, so it can be followed during training and already-written lines are kept
    # if the kernel dies. When resuming from a checkpoint, append rather than
    # truncating the earlier run's log.
    # NOTE: if the training cell raises, stdout stays redirected. Run the
    # "Restore stdout" cell before re-running this one; otherwise `stdout` would
    # capture the log file instead of the notebook's stream.
    f = open(log, 'a' if load else 'w', buffering=1, encoding='utf-8')
    stdout = sys.stdout
    sys.stdout = f

Train the model, evaluating on the test set after every epoch and saving a checkpoint whenever test accuracy improves.

In [None]:
# CosineAnnealingWarmRestarts measures T_0 / T_mult in whatever unit step() is given.
# Calling step() with no argument once per batch would make the restart periods count
# optimizer steps (with T_0 = 1: restarts after 1, 2, 4, ... batches). Passing the
# fractional epoch `epoch + step / num_batches`, as the PyTorch docs recommend, keeps
# the periods in epochs.
num_batches = len(trainloader)

for epoch in range(load_epoch - 1, epochs):
    # ---- training pass ----
    running_loss = 0.0
    correct = 0
    total = 0
    model.train()
    for step, (images, labels) in enumerate(trainloader):
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        scheduler.step(epoch + step / num_batches)

        running_loss += loss.item()
        correct += torch.sum(torch.argmax(outputs, dim=-1) == labels).item()
        total += labels.size(dim=0)

        # Running mean loss and accuracy over the epoch so far, every `steps` batches.
        if (step + 1) % steps == 0:
            print(f'[Epoch {epoch + 1:03d}] [Step {step + 1:04d}] Loss: {running_loss / (step + 1):.4f}, Accuracy: {correct / total * 100:.4f} %')

    # ---- evaluation pass on the test set ----
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            correct += torch.sum(torch.argmax(outputs, dim=-1) == labels).item()
            total += labels.size(dim=0)

    accuracy = correct / total * 100
    # The "Best accuracy" shown is the best before this epoch.
    print(f'[Epoch {epoch + 1:03d}] Accuracy: {accuracy:.4f} %, Best accuracy: {best_accuracy:.4f} %')

    # Only the best-so-far state is kept, written to the configured checkpoint path.
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        state = {
            'model': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'scheduler': scheduler.state_dict(),
            'epoch': epoch + 1,  # 1-based number of the epoch just completed
            'best_accuracy': best_accuracy,
        }
        torch.save(state, checkpoint)
        print('New best accuracy, saved model.')

Restore stdout.

In [None]:
# Undo the redirection made in the "Redirect stdout" cell: restore the saved stream
# first, then close the log file (closing flushes any buffered output to disk).
if redirect:
    sys.stdout = stdout
    f.close()

Evaluate the best saved checkpoint on the test set.

In [None]:
# Evaluate the best checkpoint saved during training on the test set.
# map_location puts the saved tensors on `device`, so a checkpoint written on a GPU
# can also be evaluated on a CPU-only machine.
state = torch.load(checkpoint, map_location=device)
model.load_state_dict(state['model'])

model.eval()
with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in testloader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        correct += torch.sum(torch.argmax(outputs, dim=-1) == labels).item()
        total += labels.size(dim=0)

accuracy = correct / total * 100
print(f'Accuracy: {accuracy:.4f} %')