In [1]:
from google.colab import drive
drive.mount('/content/gdrive')
%cd gdrive/My Drive/Colab Notebooks/mixup

ModuleNotFoundError: No module named 'google.colab'

In [16]:
import torch
import torch.nn as nn
from torch.nn.functional import one_hot
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
import os

from tqdm import tqdm

from utils import progress_bar
from resnet import ResNet18

In [17]:
# Use the first CUDA device when one is available, otherwise fall back to CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    # Report which GPU was picked up.
    print(torch.cuda.get_device_name())

In [18]:
# Training pipeline: random crop (with 4px padding) and horizontal flip for
# augmentation, then tensor conversion and per-channel normalization.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
# Test pipeline: no augmentation, same normalization constants as training.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# CIFAR-10 training split, reshuffled every epoch.
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True, download=True, transform=transform_train)
train_loader = torch.utils.data.DataLoader(
    trainset, batch_size=128, shuffle=True)

# CIFAR-10 test split, iterated in a fixed order.
testset = torchvision.datasets.CIFAR10(
    root='./data', train=False, download=True, transform=transform_test)
test_loader = torch.utils.data.DataLoader(
    testset, batch_size=128, shuffle=False)

# Human-readable class names, indexed by label id.
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

Files already downloaded and verified
Files already downloaded and verified


In [5]:
def imshow(img):
    """Display a normalized image tensor with matplotlib.

    Undoes the per-channel normalization (same mean/std constants as the
    dataset transforms in this notebook) and shows the result.

    The input tensor is left untouched: the denormalization is done on a
    copy, so calling this twice on the same tensor shows the same picture.

    Args:
        img: image tensor indexed as (channel, row, column) with 3 channels.
    """
    mean = torch.tensor([0.4914, 0.4822, 0.4465]).view(3, 1, 1)
    std = torch.tensor([0.2023, 0.1994, 0.2010]).view(3, 1, 1)

    # detach()/cpu() let tensors that live on the GPU or track gradients be
    # converted to numpy; the arithmetic below allocates a new tensor, so the
    # caller's data is never modified.
    restored = img.detach().cpu() * std + mean
    # Keep values inside the [0, 1] range matplotlib expects for float RGB.
    restored = restored.clamp(0, 1)

    npimg = restored.numpy()
    # matplotlib wants (row, column, channel) ordering.
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

In [6]:
def mixup(batch, label, alpha=1.0):
    """Blend each sample of a batch with a randomly chosen partner sample.

    Args:
        batch: tensor whose first dimension indexes the samples.
        label: tensor of labels, indexed the same way as ``batch``.
        alpha: parameter of the Beta(alpha, alpha) distribution the mixing
            weight is drawn from. A value <= 0 disables mixing (weight 1).

    Returns:
        A tuple ``(mixup_batch, (label, new_label), lam)`` where
        ``mixup_batch = lam * batch + (1 - lam) * batch[perm]``,
        ``new_label = label[perm]`` for the same random permutation, and
        ``lam`` is the mixing weight as a Python float.
    """
    # np.random.beta rejects non-positive parameters, so treat alpha <= 0 as
    # "no mixing" rather than raising.
    lam = float(np.random.beta(alpha, alpha)) if alpha > 0 else 1.0

    # One permutation is shared by inputs and labels so the pairs stay aligned.
    index = torch.randperm(batch.size(0))
    new_batch = batch[index]
    new_label = label[index]

    mixup_batch = lam * batch + (1 - lam) * new_batch
    return mixup_batch, (label, new_label), lam

def mix_label(label, new_label, lam, output, loss_function):
    """Combine the losses against two label sets using the mixing weight.

    Evaluates ``loss_function(output, label)`` and
    ``loss_function(output, new_label)`` and returns
    ``lam * first + (1 - lam) * second``.
    """
    weighted_original = lam * loss_function(output, label)
    weighted_partner = (1 - lam) * loss_function(output, new_label)
    return weighted_original + weighted_partner

In [7]:
def train(model, mix=False):
    """Run one training pass over the module-level ``train_loader``.

    Uses the module-level ``optimizer``, ``loss_function`` and ``device``,
    and reports running loss/accuracy through ``progress_bar`` after every
    batch.

    Args:
        model: network to optimize; it is switched to training mode.
        mix: when truthy, each batch goes through ``mixup`` and the loss is
            computed with ``mix_label``. The running accuracy then weights
            matches against both label sets by ``lam`` / ``1 - lam`` rather
            than scoring only against the first label set.
    """
    model.train()
    train_loss = 0
    total = 0
    correct = 0
    for i, (batch, label) in enumerate(train_loader):
        batch, label = batch.to(device), label.to(device)

        if mix:
            batch, (label, new_label), lam = mixup(batch, label)

        optimizer.zero_grad()
        output = model(batch)
        if mix:
            loss = mix_label(label, new_label, lam, output, loss_function)
        else:
            loss = loss_function(output, label)
        loss.backward()
        optimizer.step()

        train_loss += loss.item()
        _, predicted = output.max(1)
        total += label.size(0)
        if mix:
            # The target is a blend of two labels, so a prediction earns
            # credit for each label in proportion to its mixing weight.
            correct += (lam * predicted.eq(label).sum().item()
                        + (1 - lam) * predicted.eq(new_label).sum().item())
        else:
            correct += predicted.eq(label).sum().item()

        acc = 100. * correct / total
        progress_bar(i, len(train_loader), 'Loss: %.3f | Acc: %.3f%% (%d/%d)'%(train_loss/(i + 1), acc, correct, total))

In [8]:
def test(model):
    """Evaluate ``model`` on the module-level ``test_loader``.

    Switches the model to eval mode, disables gradient tracking, and reports
    running loss/accuracy through ``progress_bar`` after every batch.

    Returns:
        The accuracy (in percent) accumulated over all batches.
    """
    model.eval()
    running_loss, hits, seen = 0, 0, 0
    with torch.no_grad():
        for step, (inputs, targets) in enumerate(test_loader):
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            running_loss += loss_function(logits, targets).item()

            # Index of the largest score along dim 1 is the predicted class.
            predicted = logits.max(1)[1]
            seen += targets.size(0)
            hits += predicted.eq(targets).sum().item()

            acc = 100. * hits / seen
            status = 'Loss: %.3f | Acc: %.3f%% (%d/%d)' % (
                running_loss / (step + 1), acc, hits, seen)
            progress_bar(step, len(test_loader), status)
    return acc

In [30]:
def save_model(epoch, acc):
    """Checkpoint the module-level ``model`` when ``acc`` beats ``best_acc``.

    On improvement, writes the model's state dict together with ``acc`` and
    ``epoch`` to ``./checkpoint/model_{epoch}.pth`` (creating the directory
    if needed) and updates the module-level ``best_acc``. Otherwise does
    nothing.
    """
    global best_acc
    if not (acc > best_acc):
        return

    print('Saving..')
    state = dict(model=model.state_dict(), acc=acc, epoch=epoch)
    os.makedirs('checkpoint', exist_ok=True)
    torch.save(state, f'./checkpoint/model_{epoch}.pth')
    best_acc = acc

def load_model(name):
    """Load the saved model state dict from ``./checkpoint/{name}.pth``.

    Args:
        name: checkpoint file name without the ``.pth`` extension.

    Returns:
        The object stored under the ``'model'`` key of the checkpoint, with
        tensors mapped onto the module-level ``device``.
    """
    # NOTE(review): torch.load unpickles the file; only load checkpoints from
    # a trusted source.
    checkpoint = torch.load(f'./checkpoint/{name}.pth', map_location=device)
    return checkpoint['model']

In [10]:
def decaying_learning_rate(optimizer, epoch):
    """Set a step-decayed learning rate on every parameter group.

    The rate starts at 0.1 and is divided by 10 once ``epoch`` reaches 100
    and again once it reaches 150.
    """
    lr = 1e-1
    for milestone in (100, 150):
        if epoch >= milestone:
            lr /= 10
    for group in optimizer.param_groups:
        group['lr'] = lr

In [19]:
# Training run with mixup enabled: 200 epochs of SGD on a fresh ResNet18.
# `model`, `loss_function`, `optimizer` and `best_acc` are read as globals by
# train(), test() and save_model().
best_acc = 0
loss_function = nn.CrossEntropyLoss()

model = ResNet18()
model.to(device)
optimizer = optim.SGD(
    model.parameters(),
    lr=1e-1,
    momentum=0.9,
    weight_decay=1e-4,
)

for epoch in range(200):
    print(epoch)
    train(model, True)
    # Evaluate, keep a checkpoint if the test accuracy improved, then
    # refresh the learning rate from the schedule.
    test_acc = test(model)
    save_model(epoch, test_acc)
    decaying_learning_rate(optimizer, epoch)

0
 [>.............................]  Step: 5s811ms | Tot: 5s811ms | Loss: 2.361 | Acc: 14.062% (36/256) 2/391 

KeyboardInterrupt: 

In [31]:
# Evaluate the saved 'mixup' checkpoint: rebuild the network, restore its
# weights, and report test accuracy (the value of the last expression).
loss_function = nn.CrossEntropyLoss()
model = ResNet18()
model.to(device)
saved_weights = load_model('mixup')
model.load_state_dict(saved_weights)
test(model)



96.06