In [4]:
import torchvision.transforms as transforms
import torchvision
from pathlib import Path
import torch
from tqdm import tqdm
import torch.nn as nn
import os
import sys
sys.path.insert(0, os.path.abspath(os.path.join(os.getcwd(), '../')))
from optimizers.AdaFisher import AdaFisher
from models.resnet_cifar import ResNet18

In [None]:
# Output directories used by this notebook, resolved relative to the
# current working directory (expanduser only matters for "~" paths).
H_bar_dir = Path("H_bar_resnet").expanduser()
S_dir = Path("S_resnet").expanduser()

# Create each directory on first use, announcing it when that happens.
for _data_dir in (H_bar_dir, S_dir):
    if _data_dir.exists():
        continue
    print(f"Info: Data dir {_data_dir} does not exist, building")
    _data_dir.mkdir(exist_ok=True, parents=True)

In [None]:
def get_network(type):
    """Instantiate the network identified by ``type``.

    Args:
        type: Name of the architecture. Only ``"resnet18"`` is supported;
            it is built with ``ResNet18()`` using its default arguments.
            (The parameter shadows the ``type`` builtin but is kept because
            callers pass it by keyword.)

    Returns:
        A freshly constructed network instance.

    Raises:
        NotImplementedError: if ``type`` is not a supported architecture
            name. The message names the rejected value.
    """
    if type == "resnet18":
        return ResNet18()
    # Name the offending value so a typo is obvious from the traceback.
    raise NotImplementedError(
        f"Unsupported network type {type!r}; expected 'resnet18'")

def get_data(batch_size, test_batch_size=100, num_workers=4, root='./data'):
    """Build the CIFAR-10 training and test data loaders.

    Training images get a random 32x32 crop (padding 4) and a random
    horizontal flip before being converted to tensors and normalised; test
    images are only converted and normalised. Both datasets are created with
    ``download=True``.

    Args:
        batch_size: Batch size of the (shuffled) training loader.
        test_batch_size: Batch size of the (unshuffled) test loader.
            Defaults to 100.
        num_workers: Worker count passed to both loaders. Defaults to 4.
        root: Directory where the dataset is stored. Defaults to ``'./data'``.

    Returns:
        Tuple ``(train_loader, test_loader)``.
    """
    # Same per-channel (mean, std) normalisation for both splits.
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465),
                                     (0.2023, 0.1994, 0.2010))

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    trainset = torchvision.datasets.CIFAR10(
        root=root, train=True, download=True,
        transform=transform_train)
    train_loader = torch.utils.data.DataLoader(
        trainset, batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers)

    testset = torchvision.datasets.CIFAR10(
        root=root, train=False,
        download=True, transform=transform_test)
    test_loader = torch.utils.data.DataLoader(
        testset, batch_size=test_batch_size, shuffle=False,
        num_workers=num_workers)
    return train_loader, test_loader

In [None]:
def train_epoch(epoch, model, dataloader, optimizer, criterion, device):
    """Run one training pass over ``dataloader`` and report its metrics.

    Args:
        epoch: Epoch index; only used in the progress message.
        model: Network to optimise; it is switched to training mode.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        optimizer: Optimizer providing ``zero_grad()`` and ``step()``.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(loss, accuracy)`` where ``loss`` is the average loss per
        sample over the epoch and ``accuracy`` is the fraction of correctly
        classified samples.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no samples.
    """
    model.train()
    # The optimizer argument is an object, not a string, so match on its
    # class name: an optimizer called AdaHessian gets a backward pass run
    # with create_graph=True, every other optimizer the default backward.
    needs_graph = type(optimizer).__name__ == "AdaHessian"
    # A criterion with reduction="sum" already returns the batch total; any
    # other setting is treated as a batch mean and re-weighted by batch size
    # so the epoch figure is a true per-sample average.
    sum_reduced = getattr(criterion, "reduction", "mean") == "sum"

    train_loss, correct, total = 0.0, 0, 0
    for inputs, targets in tqdm(dataloader, desc="Training"):
        inputs = inputs.to(device)
        targets = targets.to(device)

        optimizer.zero_grad()
        output_train = model(inputs)
        loss_train = criterion(output_train, targets)
        loss_train.backward(create_graph=needs_graph)
        optimizer.step()

        batch_size = targets.size(0)
        batch_loss = loss_train.item()
        train_loss += batch_loss if sum_reduced else batch_loss * batch_size
        _, predicted = output_train.max(1)
        total += batch_size
        correct += predicted.eq(targets).sum().item()

    train_loss_epoch = train_loss / total
    train_accuracy_epoch = correct / total
    tqdm.write(f"== [TRAIN] Epoch: {epoch}, Loss: {train_loss_epoch:.3f}, Accuracy: {train_accuracy_epoch:.3f} ==>")
    return train_loss_epoch, train_accuracy_epoch
def valid_epoch(epoch, model, dataloader, criterion, device):
    """Evaluate ``model`` on ``dataloader`` without computing gradients.

    Args:
        epoch: Epoch index; only used in the progress message.
        model: Network to evaluate; it is switched to evaluation mode.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable invoked as ``criterion(outputs, targets)``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(loss, accuracy)`` where ``loss`` is the average loss per
        sample and ``accuracy`` is the fraction of correctly classified
        samples.

    Raises:
        ZeroDivisionError: if ``dataloader`` yields no samples.
    """
    model.eval()
    # A criterion with reduction="sum" already returns the batch total; any
    # other setting is treated as a batch mean and re-weighted by batch size
    # so the reported figure is a true per-sample average.
    sum_reduced = getattr(criterion, "reduction", "mean") == "sum"

    test_loss, correct, total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, targets in tqdm(dataloader, desc="Validation"):
            inputs = inputs.to(device)
            targets = targets.to(device)
            output_test = model(inputs)
            loss_test = criterion(output_test, targets)

            batch_size = targets.size(0)
            batch_loss = loss_test.item()
            test_loss += batch_loss if sum_reduced else batch_loss * batch_size
            _, predicted = output_test.max(1)
            total += batch_size
            correct += predicted.eq(targets).sum().item()

    test_loss_epoch = test_loss / total
    test_accuracy_epoch = correct / total
    tqdm.write(f"== [VALID] Epoch: {epoch}, Loss: {test_loss_epoch:.3f}, Accuracy: {test_accuracy_epoch:.3f} ==>")
    return test_loss_epoch, test_accuracy_epoch

In [None]:
def train_model(num_epochs, model, train_dataloader, test_dataloader, criterion, optimizer, device):
    """Alternate training and validation for ``num_epochs`` epochs.

    Each epoch calls ``train_epoch`` on ``train_dataloader`` followed by
    ``valid_epoch`` on ``test_dataloader`` and records the metrics they
    return.

    Args:
        num_epochs: Number of epochs to run.
        model: Network being trained and evaluated.
        train_dataloader: Batches used for the training pass.
        test_dataloader: Batches used for the validation pass.
        criterion: Loss callable forwarded to both passes.
        optimizer: Optimizer forwarded to the training pass.
        device: Device forwarded to both passes.

    Returns:
        Tuple ``(train_loss, test_loss, train_accuracy, test_accuracy)`` of
        lists holding one value per epoch, so the collected history is
        available to the caller instead of being discarded.
    """
    train_loss, test_loss, train_accuracy, test_accuracy = [], [], [], []
    for epoch in range(num_epochs):
        train_loss_epoch, train_accuracy_epoch = train_epoch(
            epoch, model, train_dataloader, optimizer, criterion, device)
        train_loss.append(train_loss_epoch)
        train_accuracy.append(train_accuracy_epoch)

        test_loss_epoch, test_accuracy_epoch = valid_epoch(
            epoch, model, test_dataloader, criterion, device)
        test_loss.append(test_loss_epoch)
        test_accuracy.append(test_accuracy_epoch)
    return train_loss, test_loss, train_accuracy, test_accuracy

In [None]:
def main(num_epochs):
    """Train a ResNet-18 with AdaFisher for ``num_epochs`` epochs.

    Picks the compute device (CUDA first, then MPS, otherwise CPU with a
    warning), builds the network, the data loaders (training batch size
    256), the AdaFisher optimizer and a cross-entropy criterion, and hands
    everything to ``train_model``.
    """
    # Device preference: CUDA > MPS > CPU.
    if torch.cuda.is_available():
        device = 'cuda'
    else:
        mps_ready = torch.backends.mps.is_available()
        device = 'mps' if mps_ready else 'cpu'
        if not mps_ready:
            print("Warning: CPU will be slow when running")

    network = get_network(type="resnet18")
    model = network.to(device)
    train_loader, test_loader = get_data(batch_size=256)
    optimizer = AdaFisher(
        model,
        lr=0.001,
        gammas=[0.92, 0.008],
        Lambda=1e-3,
        weight_decay=5e-4,
    )
    loss_fn = nn.CrossEntropyLoss()
    train_model(num_epochs, model, train_loader, test_loader, loss_fn, optimizer, device)

In [None]:
# Number of epochs for this run; passed straight to main().
EPOCHS = 50
main(num_epochs=EPOCHS)