## CIFAR-10 training workflow (lite version)

This is a small training workflow containing only the minimum set of functions required, kept deliberately lean for debugging. The goal is to reach the standard 90%+ accuracy on the CIFAR-10 test set.

In [None]:
# IMPORTS

import os
import sys

from typing import Any, Callable, List, Optional, Type, Union, Tuple, Dict
from torch import Tensor

import os
import sys
import torch
from torch import nn, optim, mps
from torch.utils.data import DataLoader, Dataset
import torchvision
from torchvision import transforms

import numpy as np
import matplotlib.pyplot as plt
import copy
from tqdm import tqdm

from models import resnet
from data import get_dataset
from visualization import show_cifar_images
from training import (
    AccuracyMetric,
    get_confusion_matrix,
    parse_loss,
    parse_scheduler,
    parse_optimizer,
    get_dataset_classes,
)
from utils import (
    bayes_eval,
    bayes_forward,
    count_FLOPS,
    count_parameters,
    calculate_storage_in_mb,
)

In [None]:
# Select the compute device. CUDA takes precedence over Apple MPS, and the
# CPU is the fallback when neither accelerator is available.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f'Using Device {device}')

In [None]:
classes = get_dataset_classes("cifar10")

# Per-channel normalization statistics of the CIFAR-10 training split, shared
# by the train and test pipelines so the two can never drift apart.
# NOTE: the frequently copied std (0.2023, 0.1994, 0.2010) is not the
# per-channel pixel std of CIFAR-10; (0.2470, 0.2435, 0.2616) is.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Training pipeline: standard CIFAR augmentation (random 32x32 crop from a
# 4-pixel zero-padded image, random horizontal flip), then normalization.
transform_train = transforms.Compose(
    [
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

# Evaluation pipeline: no augmentation, identical normalization.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
    ]
)

trainset = torchvision.datasets.CIFAR10(
    root="../data", train=True, download=True, transform=transform_train
)

testset = torchvision.datasets.CIFAR10(
    root="../data", train=False, download=True, transform=transform_test
)

# Shuffle only the training data; evaluation order does not matter.
trainloader = DataLoader(trainset, batch_size=128, shuffle=True, num_workers=2)

testloader = DataLoader(testset, batch_size=100, shuffle=False, num_workers=2)

In [None]:
# Define the resnet and attach predictor.
# Weights are passed by keyword; the positional form is deprecated in
# torchvision >= 0.13.
net = torchvision.models.resnet18(
    weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1
)

# Adapt the stem to 32x32 CIFAR images. The stock ImageNet stem (7x7 stride-2
# conv followed by a stride-2 max-pool) shrinks a 32x32 input to 8x8 before
# the first residual stage, discarding most spatial detail. CIFAR-style
# ResNets use a 3x3 stride-1 conv and no max-pool instead. Only the new conv1
# loses its pretrained weights; every residual stage keeps them.
net.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
net.maxpool = nn.Identity()

# Replace the 1000-way ImageNet head with a 10-way CIFAR-10 classifier.
net.fc = nn.Linear(in_features=net.fc.in_features, out_features=10)

# Assigning the result (Module.to returns the same module) keeps the notebook
# from echoing the full model repr, without a dummy print.
net = net.to(device)

In [None]:
# Training objective and optimization schedule:
# cross-entropy loss, SGD with momentum and L2 weight decay, and a cosine
# learning-rate decay whose period (T_max) is 200 scheduler steps.
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(
    net.parameters(),
    lr=0.1,
    momentum=0.9,
    weight_decay=5e-4,
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=200)

In [None]:
# A helper progressbar
def progress_bar(current, total, msg=None, bar_length=65, term_width=None):
    """Draw a one-line text progress bar on stdout, redrawing it in place.

    Call once per step with ``current`` running from 0 to ``total - 1``. Every
    step but the last ends with a carriage return, so the next call overwrites
    the line; the last step ends with a newline.

    Args:
        current: Zero-based index of the step being reported.
        total: Total number of steps.
        msg: Optional extra text appended after the timing fields.
        bar_length: Width of the ``[====>....]`` section in characters.
        term_width: Terminal width used for padding and cursor positioning;
            defaults to ``shutil.get_terminal_size().columns``.
    """
    global last_time, begin_time
    # Local imports keep the helper self-contained: ``time`` is not imported
    # by the notebook's import cell.
    import shutil
    import time

    if term_width is None:
        term_width = shutil.get_terminal_size().columns

    now = time.time()
    # Lazily create the timing globals so the very first call cannot hit a
    # NameError; begin_time is additionally reset at the start of every bar.
    if "last_time" not in globals():
        last_time = now
    if current == 0 or "begin_time" not in globals():
        begin_time = now

    fraction = current / total if total else 1.0
    cur_len = int(bar_length * fraction)
    rest_len = int(bar_length - cur_len) - 1
    sys.stdout.write(" [" + "=" * cur_len + ">" + "." * rest_len + "]")

    step_time = now - last_time
    last_time = now
    tot_time = now - begin_time

    parts = [
        "  Step: %s" % _format_duration(step_time),
        " | Tot: %s" % _format_duration(tot_time),
    ]
    if msg:
        parts.append(" | " + msg)
    text = "".join(parts)
    sys.stdout.write(text)
    # Pad with spaces so a shorter line fully overwrites the previous one.
    sys.stdout.write(" " * (term_width - int(bar_length) - len(text) - 3))

    # Move the cursor back to the middle of the bar and print the counter.
    sys.stdout.write("\b" * (term_width - int(bar_length / 2) + 2))
    sys.stdout.write(" %d/%d " % (current + 1, total))

    sys.stdout.write("\r" if current < total - 1 else "\n")
    sys.stdout.flush()


def _format_duration(seconds):
    """Format a duration using its two largest non-zero units, e.g. ``1m5s``."""
    remaining = int(round(seconds * 1000))
    units = (
        ("D", 86_400_000),
        ("h", 3_600_000),
        ("m", 60_000),
        ("s", 1_000),
        ("ms", 1),
    )
    pieces = []
    for suffix, size in units:
        count, remaining = divmod(remaining, size)
        if count:
            pieces.append("%d%s" % (count, suffix))
        if len(pieces) == 2:
            break
    return "".join(pieces) or "0ms"

In [None]:
def train(epoch):
    """Run one training epoch over ``trainloader`` and print loss and accuracy.

    Relies on the notebook globals ``net``, ``trainloader``, ``criterion``,
    ``optimizer`` and ``device``.

    Args:
        epoch: Epoch index; only used in the printed header.
    """
    print("\nEpoch: %d" % epoch)
    net.train()
    train_loss = 0.0
    correct = 0
    total = 0
    # Iterate the augmented *training* split. Iterating ``testloader`` here
    # would fit the model to the test set and make the reported test
    # accuracy meaningless.
    for inputs, targets in tqdm(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

    # max(..., 1) guards against an empty loader instead of dividing by zero.
    num_batches = max(len(trainloader), 1)
    print(
        "Loss: %.3f | Acc: %.3f%% (%d/%d)"
        % (train_loss / num_batches, 100.0 * correct / max(total, 1), correct, total),
    )

In [None]:
# evaluate model
def test(epoch):
    """Evaluate ``net`` on ``testloader`` and return top-1 accuracy in percent.

    Relies on the notebook globals ``net``, ``testloader``, ``criterion`` and
    ``device``. Prints the average per-batch loss and the accuracy.

    Args:
        epoch: Epoch index; accepted for symmetry with ``train`` but unused.

    Returns:
        Test accuracy as a float percentage (0.0 for an empty loader).
    """
    net.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    # No gradients are needed for evaluation; saves memory and time.
    with torch.no_grad():
        for inputs, targets in tqdm(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # max(..., 1) guards against an empty loader instead of dividing by zero.
    acc = 100.0 * correct / max(total, 1)
    num_batches = max(len(testloader), 1)
    print(
        "Loss: %.3f | Acc: %.3f%% (%d/%d)"
        % (test_loss / num_batches, acc, correct, total),
    )
    return acc

In [None]:
# Main training loop: one scheduler step per epoch. The epoch count is taken
# from the cosine scheduler's period, so training covers exactly one full
# annealing cycle instead of relying on a second hard-coded copy of T_max.
history = []
start_epoch = 0
num_epochs = scheduler.T_max
for epoch in range(start_epoch, start_epoch + num_epochs):
    train(epoch)
    test_acc = test(epoch)
    history.append(test_acc)
    scheduler.step()

print("Best test accuracy: %.3f%%" % max(history, default=0.0))