In [None]:
# import required packages
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torch.backends.cudnn as cudnn
import torchvision
import torchvision.transforms as transforms
import os
import argparse
from models import *
from torchsummary import summary
import matplotlib.pyplot as plt
import json
from IPython.display import FileLink, display
%matplotlib inline

In [None]:
# Hyperparameters to adjust.
# `model_name` and `optimizer_name` are labels recorded in `filename` only; the
# model cell constructs the network and the optimizer itself.
model_name = "EfficientNetB0"
lr = 0.1
epochs = 100
batch_size = 512
# Deliberately not called `optimizer`: the model cell binds that name to the
# optim.SGD object used by train(), and re-running this cell must not clobber it.
optimizer_name = "SGD"
has_data_aug = 1  # 1 = random crop + horizontal flip on the training set, 0 = no augmentation

# Run tag used for checkpoint / weight / history file names,
# e.g. EfficientNetB0_0.1_100_512_SGD_1
filename = "_".join(
    str(part)
    for part in (model_name, lr, epochs, batch_size, optimizer_name, has_data_aug)
)
print(filename)

In [None]:
# Pick the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(device)

# Training-progress state: the loop starts at `start_epoch`, and `best_acc`
# (test accuracy in percent) is raised by test() whenever it checkpoints.
start_epoch = 0
best_acc = 0

In [None]:
# Data: CIFAR-10 train/test datasets and their loaders.

# Per-channel mean/std shared by every Normalize below.
# NOTE(review): these std values are the ones widely copied from the
# pytorch-cifar reference code; the per-pixel std over the CIFAR-10 training
# set is usually quoted as ~(0.247, 0.243, 0.261) -- confirm which is intended.
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2023, 0.1994, 0.2010)

# Optional augmentation: random 32x32 crop of the 4-pixel-padded image, plus a
# random horizontal flip. Applied to the training set only.
augmentation = (
    [transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip()]
    if has_data_aug
    else []
)
transform_train = transforms.Compose(
    augmentation
    + [transforms.ToTensor(), transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)]
)
transform_test = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD)]
)

trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train
)
trainloader = torch.utils.data.DataLoader(
    trainset, batch_size=batch_size, shuffle=True, num_workers=5
)

testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test
)
testloader = torch.utils.data.DataLoader(
    testset, batch_size=100, shuffle=False, num_workers=5
)

# Human-readable names for the 10 label indices.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

In [None]:
# Model, loss, optimizer and LR schedule.
# NOTE(review): the model / optimizer hyperparameters in the config cell only
# label `filename`; the architecture and optimizer here are fixed.
net = EfficientNetB0()

net = net.to(device)
if device == 'cuda':
    # DataParallel wraps the model, so net.state_dict() keys gain a "module."
    # prefix in every checkpoint saved later.
    net = torch.nn.DataParallel(net)
    cudnn.benchmark = True

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=lr, momentum=0.9, weight_decay=5e-4)
# The training loop calls scheduler.step() once per epoch, so the cosine period
# must equal the number of epochs (it was previously hard-coded to 100).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

In [None]:
# training
from tqdm import tqdm
def train(epoch):
    """Run one training epoch over ``trainloader``, updating ``net`` in place.

    Reads the notebook globals ``net``, ``trainloader``, ``optimizer``,
    ``criterion``, ``device`` and ``tqdm``.

    Args:
        epoch: index of the current epoch. Unused here; kept so the call site
            mirrors ``test(epoch)``.

    Returns:
        (mean_loss, accuracy): the per-sample mean training loss over the epoch
        and the fraction (0..1) of training samples classified correctly.
        Raises ZeroDivisionError if the loader yields no samples.
    """
    net.train()
    loss_sum = 0.0
    correct = 0
    total = 0
    for inputs, targets in tqdm(trainloader):
        inputs, targets = inputs.to(device), targets.to(device)
        optimizer.zero_grad()
        outputs = net(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        n = targets.size(0)
        # criterion averages over the batch (nn.CrossEntropyLoss default
        # reduction='mean'); weight by batch size so the smaller final batch
        # does not skew the epoch average.
        loss_sum += loss.item() * n
        _, predicted = outputs.max(1)
        total += n
        correct += predicted.eq(targets).sum().item()

    return loss_sum / total, correct / total

In [None]:
# testing
def test(epoch):
    global best_acc
    net.eval()
    test_loss = 0
    correct = 0
    total = 0
    with torch.no_grad():
        for  (inputs, targets) in tqdm(testloader):
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = net(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

    # Save checkpoint.
    acc = 100.*correct/total
    if acc > best_acc:
        print('Saving..')
        state = {
            'net': net.state_dict(),
            'acc': acc,
            'epoch': epoch,
        }
        if not os.path.isdir('checkpoint'):
            os.mkdir('checkpoint')
        torch.save(state, './checkpoint/' + filename + '.pth')
        best_acc = acc

    return test_loss/len(testloader), correct/total

In [None]:
def _plot_history(hist, series, ylabel, title):
    """Draw one line per (history key, legend label) pair in ``series`` against epoch index."""
    if hist is None:
        hist = history
    for key, label in series:
        values = hist[key]
        # Derive the x-axis from the recorded data (not the configured epoch
        # count) so an interrupted or partial run still plots.
        plt.plot(range(len(values)), values, '-', linewidth=3, label=label)
    plt.title(title)
    plt.xlabel('epoch')
    plt.ylabel(ylabel)
    plt.grid(True)
    plt.legend()
    plt.show()


def plotLoss(hist=None):
    """Plot train and test loss per epoch.

    Args:
        hist: metrics dict with 'train_loss' / 'test_loss' lists; defaults to
            the notebook-global ``history``.
    """
    _plot_history(
        hist,
        (('train_loss', 'Train Loss'), ('test_loss', 'Test Loss')),
        'loss',
        'Loss per epoch',
    )


def plotAcc(hist=None):
    """Plot train and test accuracy per epoch.

    Args:
        hist: metrics dict with 'train_acc' / 'test_acc' lists; defaults to
            the notebook-global ``history``.
    """
    _plot_history(
        hist,
        (('train_acc', 'Train Acc'), ('test_acc', 'Test Acc')),
        'accuracy',
        'Accuracy per epoch',
    )

In [None]:
# Per-epoch metrics (loss = mean loss, acc = fraction correct), appended once
# per epoch below and consumed by the plotting and saving cells.
history = {
    'train_loss': [],
    'test_loss': [],
    'train_acc': [],
    'test_acc': [],
}

for epoch in range(start_epoch, start_epoch + epochs):
    print('\nEpoch: %d' % epoch)

    train_loss, train_acc = train(epoch)
    print("Train \tLoss: %.3f | Acc: %.3f" % (train_loss, train_acc))

    test_loss, test_acc = test(epoch)
    print("Test \tLoss: %.3f | Acc: %.3f" % (test_loss, test_acc))

    history['train_loss'].append(train_loss)
    history['test_loss'].append(test_loss)
    history['train_acc'].append(train_acc)
    history['test_acc'].append(test_acc)

    # Advance the LR schedule once per epoch.
    scheduler.step()


In [None]:
# Save the final (last-epoch) weights; the best-accuracy checkpoint is written
# separately by test() under ./checkpoint/.
os.makedirs("weights", exist_ok=True)  # torch.save does not create missing directories
torch.save(net.state_dict(), "weights/"+filename)
print(f"Model weights saved to: weights/{filename}")

In [None]:
# Train vs. test accuracy (fraction of samples classified correctly) per epoch.
plotAcc()

In [None]:
# Train vs. test mean loss per epoch.
plotLoss()

In [None]:
import pickle

# Persist the per-epoch metrics so the curves can be re-plotted without retraining.
# NOTE(review): history holds only lists of floats, so json (already imported)
# would also work and avoids pickle's code-execution risk when loading.
os.makedirs("results", exist_ok=True)  # open() fails if the directory is missing
history_path = "results/history_" + filename
with open(history_path, "wb") as f:
    pickle.dump(history, f)
print(f"Dictionary saved to {history_path}")