In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

import torchvision
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
import numpy as np
import os
import sys

from tqdm import tqdm

from resnet import ResNet18

In [25]:
# Use the first CUDA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
    # Print the name of the current CUDA device so the run records its hardware.
    print(torch.cuda.get_device_name())
else:
    device = torch.device('cpu')
device

GeForce RTX 2070 SUPER


device(type='cuda', index=0)

In [24]:
# Mini-batch size: used for both DataLoaders and to scale the progress-bar counts.
batch_size = 128

# Human-readable class names (presumably in CIFAR-10 label order; not referenced
# by the other cells visible here).
classes = (
    'plane', 'car', 'bird', 'cat', 'deer',
    'dog', 'frog', 'horse', 'ship', 'truck',
)

def load_data():
    """Build the CIFAR-10 train and test DataLoaders.

    The training pipeline applies a padded random 32x32 crop and a random
    horizontal flip before tensor conversion and per-channel normalisation;
    the test pipeline only converts and normalises. Both datasets live under
    ``./data`` (downloaded on demand) and are batched with the module-level
    ``batch_size``. Only the training loader shuffles.

    Returns:
        dict: ``{'train': DataLoader, 'test': DataLoader}``.
    """
    # Per-channel statistics shared by both pipelines.
    channel_mean = (0.4914, 0.4822, 0.4465)
    channel_std = (0.2023, 0.1994, 0.2010)

    train_pipeline = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])
    test_pipeline = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(channel_mean, channel_std),
    ])

    loaders = {}
    # (split name, train flag, transform); the train flag also decides shuffling.
    split_specs = (('train', True, train_pipeline), ('test', False, test_pipeline))
    for split, is_train, pipeline in split_specs:
        dataset = torchvision.datasets.CIFAR10(root='./data', train=is_train, download=True, transform=pipeline)
        loaders[split] = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=is_train)
    return loaders

def load_iter(data_loader, data_type):
    """Wrap one split's DataLoader in a tqdm progress bar.

    Args:
        data_loader: Mapping with ``'train'`` and ``'test'`` entries, as
            returned by ``load_data``.
        data_type: Which split to iterate: ``'train'`` or ``'test'``.

    Returns:
        A tqdm iterator over ``enumerate(loader)``, i.e. it yields
        ``(batch_index, batch)`` pairs and exposes ``set_description``.

    Raises:
        ValueError: If ``data_type`` is neither ``'train'`` nor ``'test'``.
    """
    # Fail loudly here instead of returning None, which would only surface
    # later as an unrelated TypeError when the caller tries to iterate.
    if data_type not in ('train', 'test'):
        raise ValueError(f"Data Error!!! Unknown data_type {data_type!r}; expected 'train' or 'test'")

    bar_format = '{bar:30} {n_fmt}/{total_fmt} [{elapsed}<{remaining} {rate_fmt}] {desc}'
    loader = data_loader[data_type]
    # unit_scale multiplies the displayed counters by the module-level
    # batch_size, so the bar counts samples (rounded up to whole batches).
    return tqdm(enumerate(loader), total=len(loader), unit_scale=batch_size, bar_format=bar_format)

In [4]:
def imshow(img):
    """Display a normalised channel-first RGB image tensor with matplotlib.

    The per-channel normalisation (mean/std constants matching those in
    ``load_data``) is undone on a copy, so the caller's tensor is left
    untouched and the function can be called repeatedly on the same image.

    Args:
        img: Tensor indexed as ``(channel, height, width)`` with 3 channels.
            It may live on any device and may track gradients.
    """
    mean = torch.tensor((0.4914, 0.4822, 0.4465)).view(3, 1, 1)
    std = torch.tensor((0.2023, 0.1994, 0.2010)).view(3, 1, 1)

    # detach()/cpu() make .numpy() legal for CUDA or grad-tracking tensors; the
    # arithmetic allocates a new tensor, so the input is never modified.
    restored = img.detach().cpu() * std + mean
    # Keep values inside the [0, 1] range matplotlib expects for float RGB data
    # (it would otherwise clip them itself and emit a warning).
    npimg = restored.clamp(0, 1).numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

In [20]:
def train(model):
    """Run one training epoch over the training split.

    Uses the notebook-level ``data_loader``, ``device``, ``optimizer`` and
    ``loss_function``. The tqdm description is refreshed after every batch
    with the running accuracy.

    Args:
        model: Network to optimise; it is switched to training mode.

    Returns:
        float: Running accuracy in percent over all samples seen during the
        epoch (``0.0`` if the loader yielded no batches).
    """
    model.train()
    total = 0
    correct = 0
    acc = 0.
    train_iter = load_iter(data_loader, 'train')
    for _, (batch, label) in train_iter:
        batch, label = batch.to(device), label.to(device)

        # Clear gradients left over from the previous step before backprop.
        optimizer.zero_grad()
        output = model(batch)
        loss = loss_function(output, label)
        loss.backward()
        optimizer.step()

        # The index of the largest logit along dim 1 is the predicted class.
        _, predicted = output.max(1)
        total += label.size(0)
        correct += predicted.eq(label).sum().item()

        acc = 100. * correct / total
        train_iter.set_description(f'[{acc:.2f}% ({correct}/{total})]', True)
    return acc

In [19]:
def test(model):
    model.eval()
    test_loss = 0
    correct = 0
    total = 0
    acc = 0.
    test_iter = load_iter(data_loader, 'test')

    for i, (batch, label) in test_iter:
        batch, label = batch.to(device), label.to(device)
        output = model(batch)
        loss = loss_function(output, label)

        test_loss += loss.item()
        _, predicted = output.max(1)
        total += label.size(0)
        correct += predicted.eq(label).sum().item()

        acc = 100. * correct / total
        test_iter.set_description(f'[{acc:.2f}%({correct}/{total})]', True)
    return acc

In [7]:
def save_model(epoch, acc, optimizer):
    """Checkpoint the notebook-level ``model`` when ``acc`` sets a new best.

    Nothing happens unless ``acc`` is strictly greater than the global
    ``best_acc``. Otherwise the model weights, the accuracy, the epoch and the
    optimizer state are written to ``./models/model_{epoch}.pth`` and
    ``best_acc`` is raised to ``acc``.

    Args:
        epoch: Epoch number; stored in the checkpoint and used in the filename.
        acc: Accuracy achieved this epoch.
        optimizer: Optimizer whose ``state_dict()`` is stored.
    """
    global best_acc
    # Written as "not greater" (rather than "<=") so a NaN accuracy is skipped.
    if not acc > best_acc:
        return

    checkpoint = dict(
        model=model.state_dict(),
        acc=acc,
        epoch=epoch,
        optimizer=optimizer.state_dict(),
    )
    os.makedirs('models', exist_ok=True)
    torch.save(checkpoint, f'./models/model_{epoch}.pth')
    best_acc = acc
    print('Saving Model...')

def load_model(name):
    """Rebuild a ResNet18 from ``./models/{name}.pth`` with a fresh SGD optimizer.

    Args:
        name: Checkpoint file stem (no directory, no ``.pth`` suffix).

    Returns:
        tuple: ``(model, optimizer)``. The model is on ``device`` with the
        checkpoint's ``'model'`` weights loaded. The optimizer is newly
        constructed; any optimizer state stored in the checkpoint is not
        restored.
    """
    # NOTE(review): torch.load unpickles the file, so only load trusted checkpoints.
    checkpoint = torch.load(f'./models/{name}.pth', map_location=device)

    net = ResNet18().to(device)
    net.load_state_dict(checkpoint['model'])

    sgd = optim.SGD(net.parameters(), lr=1e-1, momentum=0.9, weight_decay=1e-4)
    # Restoring the saved optimizer state is currently disabled:
    #     sgd.load_state_dict(checkpoint['optimizer'])
    return net, sgd

In [8]:
def decaying_learning_rate(optimizer, epoch, base_lr=1e-1, milestones=(100, 150), factor=10):
    """Set a step-decayed learning rate on every parameter group.

    The rate starts at ``base_lr`` and is divided by ``factor`` once for each
    milestone that ``epoch`` has reached (``epoch >= milestone``). With the
    defaults this gives 1e-1, then a tenth of that from epoch 100 and a
    hundredth from epoch 150.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts whose
            ``'lr'`` entry is overwritten.
        epoch: Epoch number the schedule is evaluated at.
        base_lr: Learning rate before any decay.
        milestones: Epoch thresholds; each one reached applies one division.
        factor: Divisor applied per reached milestone.

    Returns:
        float: The learning rate written to the parameter groups.
    """
    lr = base_lr
    for milestone in milestones:
        if epoch >= milestone:
            # Divide (rather than multiply by 1/factor) so the default
            # schedule keeps producing exactly the same float values.
            lr /= factor
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
    return lr

In [21]:
# Training
# Fresh ResNet18 on the selected device, optimised with SGD against cross-entropy.
model = ResNet18().to(device)
loss_function = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-1, momentum=0.9, weight_decay=1e-4)
best_acc = 0  # best test accuracy so far; read and updated by save_model()
data_loader = load_data()

# Per epoch: train, evaluate, checkpoint on improvement, then set the learning
# rate for the following epoch. Only epochs 1 and 2 run here, so the decay
# thresholds (100 and 150) are never reached in this cell.
for epoch in range(1, 3):
    print(f'Epoch {epoch}')
    train(model)
    test_acc = test(model)
    save_model(epoch, test_acc, optimizer)
    decaying_learning_rate(optimizer, epoch)

Files already downloaded and verified
Files already downloaded and verified


                               0/50048 [00:00<? ?it/s] 

Epoch 1


██████████████████████████████ 50048/50048 [01:23<00:00 597.45it/s] [35.65% (17824/50000)]: 
██████████████████████████████ 10112/10112 [00:05<00:00 1834.13it/s] [45.24%(4524/10000)]: 
                               128/50048 [00:00<01:08 727.11it/s] [39.84% (51/128)]: 

Saving Model...
Epoch 2


██████████████████████████████ 50048/50048 [01:24<00:00 594.72it/s] [55.29% (27644/50000)]: 
██████████████████████████████ 10112/10112 [00:05<00:00 1896.76it/s] [60.49%(6049/10000)]: 


Saving Model...


In [23]:
# Validation
# Evaluate the saved 'baseline' checkpoint on the test split. test() reads the
# notebook-level data_loader, so the cell that calls load_data() must run first.
loss_function = nn.CrossEntropyLoss()
model, optimizer = load_model('baseline')
test(model)

██████████████████████████████ 10112/10112 [00:03<00:00 2980.21it/s] [94.53%(9453/10000)]: 


94.53