In [None]:
import torch
from torch import autocast, nn, optim
from torchvision import datasets, models, transforms

In [None]:
# Experiment switches: mixed-precision autocast and torch.compile.
AUTOCAST_FLAG = True
COMPILE_FLAG = True
# Data loading and training schedule.
num_workers = 2
epochs = 5
batch_size = 100
eval_batch_size = 1000
# Seed torch's RNG so weight init, shuffling and augmentation are repeatable
# across runs, which keeps AUTOCAST/COMPILE comparisons fair. (Bitwise
# determinism on GPU is not guaranteed by seeding alone.)
SEED = 0
torch.manual_seed(SEED);

In [None]:
# Pick the highest-indexed visible GPU when CUDA is present, otherwise run on CPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{torch.cuda.device_count() - 1}')
else:
    device = torch.device('cpu')
# Download / cache location for torchvision datasets.
root = '~/.pytorch/datasets/'
print(f'Device: {device}, Type: {device.type}')

In [None]:
# Per-channel normalisation statistics (presumably CIFAR-100 training-set
# values), written on the 0-255 pixel scale and divided by 255 to match the
# [0, 1] range produced by ToTensor().
mean = torch.tensor([129.3, 124.1, 112.4]) / 255
std = torch.tensor([68.2, 65.4, 70.4]) / 255

# Both pipelines tensorise and normalise first. The training pipeline then
# augments, so RandomCrop's constant padding is applied in normalised space.
transform = dict(
    train=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
        transforms.RandomCrop(32, padding=4, padding_mode='constant'),
        transforms.RandomHorizontalFlip(),
    ]),
    eval=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean, std),
    ]),
)

# One CIFAR-100 split per key; `train` selects the split, and the matching
# pipeline is looked up in `transform`.
dataset = {
    split: datasets.CIFAR100(root=root, train=is_train, download=True, transform=transform[pipeline])
    for split, is_train, pipeline in (('train', True, 'train'), ('test', False, 'eval'))
}

# Only the training loader shuffles; evaluation uses the larger eval batch size.
dataloader = {
    'train': torch.utils.data.DataLoader(
        dataset['train'],
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
    ),
    'test': torch.utils.data.DataLoader(
        dataset['test'],
        batch_size=eval_batch_size,
        shuffle=False,
        num_workers=num_workers,
    ),
}

In [None]:
# ResNet-18 adapted to small (32x32, CIFAR-sized) inputs: the default stem is
# replaced with a 3x3 stride-1 convolution, and the stem max-pool is turned into
# a no-op so early feature maps keep their resolution.
model = models.resnet18()
model.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
model.maxpool = nn.Identity()
# 100-way classifier head; reuse the backbone's feature width instead of
# hard-coding it, so the head stays correct if the backbone is swapped.
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=100, bias=True)
model = model.to(device)

In [None]:
# Gradient scaling is only enabled for CUDA runs with autocast on; otherwise the
# scaler is constructed disabled.
_use_grad_scaler = device.type == 'cuda' and AUTOCAST_FLAG
if hasattr(torch.amp, 'GradScaler'):
    # Newer PyTorch: device-generic GradScaler (torch.cuda.amp.GradScaler is
    # deprecated there and emits a warning).
    scaler = torch.amp.GradScaler('cuda', enabled=_use_grad_scaler)
else:
    # Older PyTorch only provides the CUDA-specific class.
    scaler = torch.cuda.amp.GradScaler(enabled=_use_grad_scaler)
# compile_mode: 'default', 'reduce-overhead', 'max-autotune'
# disable=True turns torch.compile into a no-op when COMPILE_FLAG is off.
model = torch.compile(model, mode='default', fullgraph=True, disable=not COMPILE_FLAG)

In [None]:
# Classification loss plus SGD with momentum and L2 weight decay.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    params=model.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=1e-4,
)
# Multiply the learning rate by 0.2 every 25 scheduler steps (one step per epoch).
# NOTE(review): with the current `epochs` setting (5) this decay never triggers.
scheduler = optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=25, gamma=0.2)

In [None]:
def train_step(model, dataset, dataloader, AUTOCAST_FLAG=False):
    """Train ``model`` for one pass over ``dataloader``.

    Uses the notebook-level ``device``, ``criterion``, ``optimizer`` and
    ``scaler`` globals rather than taking them as arguments.

    Args:
        model: network to optimise; switched to training mode here.
        dataset: dataset backing ``dataloader``; ``len(dataset)`` is the
            per-sample denominator for both returned metrics.
        dataloader: iterable of batches whose ``[0]`` is the input tensor and
            ``[1]`` the label tensor.
        AUTOCAST_FLAG: if True, the forward pass and loss run under
            ``torch.autocast`` for ``device.type``.

    Returns:
        ``(mean_loss, accuracy)``, both averaged per sample over the epoch.
    """
    loss_sum, n_correct = 0.0, 0
    model.train()
    for batch in dataloader:
        inputs = batch[0].to(device)
        labels = batch[1].to(device)
        optimizer.zero_grad()
        with autocast(device.type, enabled=AUTOCAST_FLAG):
            outputs = model(inputs)
            loss = criterion(outputs, labels)
        # Backward/step go through the scaler so mixed-precision loss scaling
        # is applied whenever the scaler is enabled.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # `loss` is assumed to be a batch mean (criterion's default reduction).
        # Weighting it by batch size keeps a short final batch from being
        # over-weighted, matching the per-sample accuracy below.
        loss_sum += loss.item() * labels.size(0)
        n_correct += (outputs.argmax(dim=1) == labels).sum().item()
    n_samples = len(dataset)
    return loss_sum / n_samples, n_correct / n_samples

def eval_step(model, dataset, dataloader, AUTOCAST_FLAG=False):
    """Evaluate ``model`` over ``dataloader`` without computing gradients.

    Uses the notebook-level ``device`` and ``criterion`` globals.

    Args:
        model: network to evaluate; switched to eval mode here.
        dataset: dataset backing ``dataloader``; ``len(dataset)`` is the
            per-sample denominator for both returned metrics.
        dataloader: iterable of batches whose ``[0]`` is the input tensor and
            ``[1]`` the label tensor.
        AUTOCAST_FLAG: if True, the forward pass and loss run under
            ``torch.autocast`` for ``device.type``.

    Returns:
        ``(mean_loss, accuracy)``, both averaged per sample.
    """
    loss_sum, n_correct = 0.0, 0
    model.eval()
    with torch.no_grad():
        for batch in dataloader:
            inputs = batch[0].to(device)
            labels = batch[1].to(device)
            with autocast(device.type, enabled=AUTOCAST_FLAG):
                outputs = model(inputs)
                loss = criterion(outputs, labels)
            # `loss` is assumed to be a batch mean (criterion's default
            # reduction). Weight it by batch size so a short final batch
            # counts per sample, like the accuracy.
            loss_sum += loss.item() * labels.size(0)
            n_correct += (outputs.argmax(dim=1) == labels).sum().item()
    n_samples = len(dataset)
    return loss_sum / n_samples, n_correct / n_samples

In [None]:
def timed(fn, device=None):
    """Call ``fn()`` and return ``(result, elapsed_seconds)``.

    On CUDA, timing uses CUDA events so asynchronous kernels launched by ``fn``
    are included. Without CUDA, or when ``device`` is a non-CUDA device, it
    falls back to wall-clock ``time.perf_counter``.

    Args:
        fn: zero-argument callable to time.
        device: optional device (``torch.device`` or string) the work runs on.
            For CUDA, events are recorded and synchronised on that device
            rather than on the current default device. ``None`` keeps the
            original behaviour: the current CUDA device if CUDA is available.

    Returns:
        ``(fn(), elapsed time in seconds as a float)``.
    """
    import time

    dev = torch.device(device) if device is not None else None
    use_cuda_events = torch.cuda.is_available() and (dev is None or dev.type == 'cuda')
    if not use_cuda_events:
        # CPU path: torch.cuda.Event / synchronize would fail here.
        t0 = time.perf_counter()
        result = fn()
        return result, time.perf_counter() - t0

    # Scope the events and the sync to the device doing the work
    # (None -> current device).
    with torch.cuda.device(dev):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        start.record()
        result = fn()
        end.record()
        torch.cuda.synchronize()
    # elapsed_time() reports milliseconds.
    return result, start.elapsed_time(end) / 1000

In [None]:
for epoch in range(epochs):
    # Train for one epoch; only this phase is timed.
    result, time_cost = timed(
        lambda: train_step(model, dataset['train'], dataloader['train'], AUTOCAST_FLAG)
    )
    train_loss, train_acc = result

    # Evaluate on the held-out split.
    test_loss, test_acc = eval_step(model, dataset['test'], dataloader['test'], AUTOCAST_FLAG)

    # Per-epoch report; the learning rate shown is the one used this epoch.
    summary = [
        '----',
        f'epoch {epoch}',
        f'AUTOCAST: {AUTOCAST_FLAG}, COMPILE: {COMPILE_FLAG}',
        f'time_cost: {time_cost}',
        f'batch_size: {batch_size}',
        f'learning_rate: {scheduler.get_last_lr()}',
        f'train_loss: {train_loss}, train_acc: {train_acc}',
        f'test_loss: {test_loss}, test_acc: {test_acc}',
        '----',
    ]
    print('\n'.join(summary))

    # Advance the LR schedule once per epoch, after reporting.
    scheduler.step()