In [None]:
# Put the course repository on sys.path so the `from utils import ...` lines below resolve.
# Set the AI_ADVANCED_COURSE_ROOT environment variable to your checkout, or edit the fallback.
import os
import sys

course_root = os.environ.get('AI_ADVANCED_COURSE_ROOT', 'path/of/AI-Advanced-Course')
if course_root not in sys.path:  # avoid duplicate entries when the cell is re-run
    sys.path.append(course_root)

In [None]:
%matplotlib inline
from copy import deepcopy
from matplotlib import pyplot as plt

# Install jupyterplot on first use, then import its live training plot.
# `imp` was removed in Python 3.12, so look the package up with importlib instead.
import importlib
import importlib.util
import subprocess
import sys

if importlib.util.find_spec('jupyterplot') is None:
    # Install with the kernel's own interpreter so the package lands in this environment.
    # TODO: pin a known-good version for reproducibility.
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'jupyterplot'])
    # Let the import system notice the freshly installed package in this session.
    importlib.invalidate_caches()
from jupyterplot import ProgressPlot

import torch
from torch import nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets as D
from torchvision import transforms as T

from utils import train_step, test_step
from utils import simulate_scheduler
from utils import BaselineModel
from utils import get_cifar10_dataset, make_dataloader
from utils import fetch_data, sample_random_data, show_images

In [None]:
# Training configuration shared by the cells below.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
num_epochs = 50
batch_size = 64
learning_rate = 0.01
momentum = 0.9                 # SGD momentum
phases = ['train', 'test']     # names of the two progress-plot panels

# Seed torch's global RNG so model initialisation (and any torch-RNG-driven shuffling or
# augmentation) repeats across runs. Bit-exact CUDA determinism needs extra cuDNN settings.
seed = 0
torch.manual_seed(seed)

In [None]:
# Build CIFAR-10 twice, once as-is and once with `random_crop=True`, and wrap each in loaders.
# Both the datasets and the loaders are indexed by split name ('train' / 'test') further down.
normal_dataset, random_augment_dataset = (
    get_cifar10_dataset(),
    get_cifar10_dataset(random_crop=True),
)
loader, rloader = (
    make_dataloader(dataset, batch_size)
    for dataset in (normal_dataset, random_augment_dataset)
)

# 일반 이미지

In [None]:
# Preview a fixed handful of training images from the un-augmented dataset.
sample_idx = [10, 100, 1000, 10000, 52]
images, target = fetch_data(normal_dataset['train'], sample_idx)
class_names = normal_dataset['train'].classes
titles = [class_names[label] for label in target]
# permute(0, 2, 3, 1) moves axis 1 to the end; presumably CHW -> HWC for plotting (confirm in utils.show_images).
show_images(images.permute(0, 2, 3, 1), titles)

# 무작위 확대 / 이동(random crop)한 이미지

In [None]:
# Preview the same training indices, this time drawn from the random-crop dataset.
images, target = fetch_data(random_augment_dataset['train'], [10, 100, 1000, 10000, 52])
# Take class names from the dataset being displayed, not from `normal_dataset`.
titles = [random_augment_dataset['train'].classes[idx] for idx in target]
# permute(0, 2, 3, 1) moves axis 1 to the end; presumably CHW -> HWC for plotting (confirm in utils.show_images).
show_images(images.permute(0, 2, 3, 1), titles)

In [None]:
# Train two identically initialised CNNs side by side: one on plain CIFAR-10 batches and one
# on batches from the random-crop dataset. Compare their test accuracy after every epoch.
conditions = ['without_augmentation', 'with_augmentation']

# Deep-copy one freshly initialised model so both runs start from exactly the same weights.
base_cnn = BaselineModel()
nets = [deepcopy(base_cnn).to(device) for _ in conditions]

# Use the configured `momentum` rather than a duplicated literal.
optimizers = [
    torch.optim.SGD(net.parameters(), lr=learning_rate, momentum=momentum)
    for net in nets
]
# Multiply each run's learning rate by 0.9 once per epoch (sched.step() at the end of the epoch loop).
scheds = [
    torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.9)
    for optimizer in optimizers
]
criterion = nn.CrossEntropyLoss()


def _test_accuracy(models, test_loader, num_samples):
    """Return each model's top-1 accuracy, in percent, over `test_loader`.

    Args:
        models: networks to evaluate, in the same order as `conditions`.
        test_loader: iterable of (inputs, target) batches.
        num_samples: total number of test samples (the accuracy denominator).

    Assumes `test_step(...)[0]` is the batch of logits (inferred from its use here; confirm in utils).
    """
    corrects = [0] * len(models)
    for inputs, target in test_loader:
        for i, net in enumerate(models):
            output = test_step(net, inputs, target, device=device)[0]
            corrects[i] += (output.argmax(1).cpu() == target).sum().item()
    return [correct / num_samples * 100 for correct in corrects]


pp = ProgressPlot(
    plot_names=phases,
    line_names=conditions,
    x_lim=[0, num_epochs * len(loader['train'])],
    x_label='Iteration',
    y_lim=[[0, 3], [50, 100]],  # [train-loss range, test-accuracy-% range]
)

# Most recent test accuracy per run; plotted as 0 until the first evaluation finishes.
accs = [0] * len(nets)
try:
    for epoch in range(num_epochs):
        # Each iteration feeds one plain batch to nets[0] and one augmented batch to nets[1].
        for nbatch, rbatch in zip(loader['train'], rloader['train']):
            losses = [
                train_step(net, *batch, optimizer, criterion, device)
                for net, optimizer, batch in zip(nets, optimizers, (nbatch, rbatch))
            ]
            pp.update([losses, accs])

        # Both runs are scored on the un-augmented test split.
        accs = _test_accuracy(nets, loader['test'], len(normal_dataset['test']))

        summary = ' '.join(f'{cond}: {acc:.2f}' for cond, acc in zip(conditions, accs))
        print(f'Epoch: {epoch+1} accuracy {summary}')

        for sched in scheds:
            sched.step()
finally:
    # Finalize the live plot even if training is interrupted.
    pp.finalize()