## Data Augmentation For MNIST Dataset

This experiment follows Section 5.4 in the [paper](http://proceedings.mlr.press/v108/lorraine20a.html).

In [5]:
import sys
sys.path.append("..")

import torch
import torch.nn as nn
from torch.optim import Adam

from network import MLP
from model import UNetAugmentHyperOptModel
from hyper_opt import FixedPointHyperOptimizer
from example.data_utils import InfiniteDataLoader, load_mnist

In [6]:
# some setup parameters
batch_size = 128
lr = 0.01
base_optimizer = 'SGD'
hyper_lr = 0.01 if base_optimizer == 'RMSprop' else 0.01
device = "cuda:5"

device = torch.device(device)


In [7]:
# get data
train_loader, val_loader, test_loader = load_mnist(batch_size=batch_size)
train_iter = InfiniteDataLoader(train_loader, device)
val_iter = InfiniteDataLoader(val_loader, device)

Create model and a hyper model which wraps the main model.
Here,
- main model is just a multi-layer neural network
- hyper model is a Unet $f_\lambda(\mathbf{x}, \epsilon)$

In [8]:
# main model
model = MLP(num_layers=5, input_shape=(28, 28))
criterion = nn.CrossEntropyLoss()
# hyper model
h_model = UNetAugmentHyperOptModel(model, criterion, in_channels=1, num_classes=1, padding=True)
h_model.to(device=device)  # h_model contains model so that no need to load model to cuda

# optimizer for neural network
nn_optimizer = Adam(h_model.parameters, lr=lr)
hyper_optimizer = FixedPointHyperOptimizer(
    h_model.parameters,
    h_model.hyper_parameters,
    base_optimizer=base_optimizer,
    default=dict(lr=hyper_lr, momentum=0.9),
    use_gauss_newton=True,
    stochastic=True)
hyper_optimizer.set_kwargs(inner_lr=lr, K=10)

In [9]:
# evaluate function on test set
def evaluate():
    model.eval()
    with torch.no_grad():
        total_loss, correct = 0., 0
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            logit = model(x)
            loss = criterion(logit, y)
            total_loss += loss.item()
            pred = torch.argmax(logit, dim=1).float()
            correct += (y == pred).float().sum().item()
    model.train()
    acc = float(correct) / len(test_loader.dataset)
    return acc, total_loss

In [10]:
def train(n_epochs=20, n_warmup_epochs=1, log_freq=20):

    T = 10
    counter = 0

    # a closure computing train loss. Data will be drawn stochastically
    def train_loss_func():
        x, y = train_iter.next_batch()
        train_loss, train_logit = h_model.train_loss(x, y)
        return train_loss, train_logit
    
    while train_iter.epoch_elapsed <= n_epochs:
        for _ in range(T):
            model.train()
            train_x, train_y = train_iter.next_batch()
            train_loss, _ = h_model.train_loss(train_x, train_y)
            nn_optimizer.zero_grad()
            train_loss.backward()
            nn_optimizer.step()
        
        val_x, val_y = val_iter.next_batch()
        val_loss = h_model.validation_loss(val_x, val_y)

        if train_iter.epoch_elapsed > n_warmup_epochs:
            hyper_optimizer.step(train_loss_func, val_loss, verbose=False)
        

        # print 
        if counter % 10 == 0 and counter > 0:
            eval_acc, eval_loss = evaluate()
            train_loss, val_loss = train_loss.item(), val_loss.item()
            print(f"Iter {counter:5d} | train loss {train_loss:5.2f} | val loss {val_loss:5.2f} | test loss {eval_loss:5.2f} | test acc {eval_acc:2.4f}")

        # counter increment
        counter += 1

In [11]:
train()

Iter    10 | train loss  0.30 | val loss  0.18 | test loss 24.30 | test acc 0.9115
Iter    20 | train loss  0.21 | val loss  0.20 | test loss 19.72 | test acc 0.9302
Iter    30 | train loss  0.30 | val loss  0.09 | test loss 17.47 | test acc 0.9360
Iter    40 | train loss  0.18 | val loss  0.16 | test loss 16.01 | test acc 0.9475
Iter    50 | train loss  0.10 | val loss  0.16 | test loss 13.38 | test acc 0.9541
Iter    60 | train loss  0.17 | val loss  0.17 | test loss 14.97 | test acc 0.9461
Iter    70 | train loss  0.41 | val loss  0.20 | test loss 14.37 | test acc 0.9494
Iter    80 | train loss  0.20 | val loss  0.13 | test loss 13.20 | test acc 0.9566
Iter    90 | train loss  0.16 | val loss  0.15 | test loss 17.56 | test acc 0.9437
Iter   100 | train loss  0.20 | val loss  0.13 | test loss 13.40 | test acc 0.9524
Iter   110 | train loss  0.27 | val loss  0.36 | test loss 13.68 | test acc 0.9570
Iter   120 | train loss  0.18 | val loss  0.20 | test loss 14.83 | test acc 0.9520
Iter

Let's train without augmentation


In [12]:
no_aug_model = MLP(num_layers=5, input_shape=(28,28))
no_aug_model.to(device)

criterion = nn.CrossEntropyLoss()
optimizer = Adam(list(no_aug_model.parameters()), lr=lr)

counter = 0
for i in range(15):
    for x, y in train_loader:
        x, y = x.to(device), y.to(device)
        logit = no_aug_model(x)
        train_loss = criterion(logit, y)
        optimizer.zero_grad()
        train_loss.backward()
        optimizer.step()
    
        if counter % 20 == 0 and counter > 0:
            correct = 0
            with torch.no_grad():
                for x, y in test_loader:
                    x, y = x.to(device), y.to(device)
                    logit = no_aug_model(x)
                    pred = torch.argmax(logit, dim=1).float()
                    correct += (y == pred).float().sum().item()
            acc = float(correct) / len(test_loader.dataset)
            print(f"Iter {counter:5d} | train loss {train_loss.item():5.2f} | test acc: {acc: 2.4f}")
    
        counter += 1


Iter    20 | train loss  0.54 | test acc:  0.7888
Iter    40 | train loss  0.49 | test acc:  0.8765
Iter    60 | train loss  0.42 | test acc:  0.9067
Iter    80 | train loss  0.39 | test acc:  0.9157
Iter   100 | train loss  0.30 | test acc:  0.9104
Iter   120 | train loss  0.31 | test acc:  0.9202
Iter   140 | train loss  0.25 | test acc:  0.9218
Iter   160 | train loss  0.26 | test acc:  0.9280
Iter   180 | train loss  0.17 | test acc:  0.9257
Iter   200 | train loss  0.27 | test acc:  0.9363
Iter   220 | train loss  0.28 | test acc:  0.9343
Iter   240 | train loss  0.20 | test acc:  0.9346
Iter   260 | train loss  0.17 | test acc:  0.9418
Iter   280 | train loss  0.17 | test acc:  0.9365
Iter   300 | train loss  0.25 | test acc:  0.9358
Iter   320 | train loss  0.21 | test acc:  0.9359
Iter   340 | train loss  0.29 | test acc:  0.9504
Iter   360 | train loss  0.26 | test acc:  0.9415
Iter   380 | train loss  0.29 | test acc:  0.9390
Iter   400 | train loss  0.25 | test acc:  0.9436
