In [1]:
import os
os.environ['OMP_NUM_THREADS'] = '4'
import sys
sys.path.append("..")
from utils.dataset import FerDataset
from utils.resnet import *
from utils.tools import init_logger, RunningAverage

import numpy as np
from tqdm import tqdm_notebook as tqdm
import logging
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Root directory for everything this notebook writes (log files, checkpoints).
model_dir = '../models/resnet'
# exist_ok=True keeps the cell idempotent on re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(model_dir, exist_ok=True)

In [None]:
# Run index; it is appended to the log file name so repeated runs with the
# same settings do not share a log.
i = 0

# Hyper-parameters of this run, declared in one place. Insertion order is kept
# as-is because the dict is later serialised to the log in iteration order.
hparams = {
    'n_epochs': 10,             # passes over the training dataloader
    'data': 'ferplus',          # `data` argument of FerDataset
    'label': 'ferplus_votes',   # `label` argument of FerDataset
    'batch_size': 24,           # used for both train and eval dataloaders
    'wd': 0,                    # Adam weight_decay
    'lr': 1e-3,                 # Adam learning rate
    'adaptive': False,          # forwarded to ResNet(adaptive=...)
    'batchnorm': False,         # forwarded to ResNet(batchnorm=...)
    'scheduler_patience': 5,    # ReduceLROnPlateau patience
    'scheduler_factor': 0.5,    # ReduceLROnPlateau factor
}

# Prepare dataloaders
data, label, batch_size = hparams['data'], hparams['label'], hparams['batch_size']

# Training split: reshuffled by the loader, and an incomplete final batch is
# dropped so every training batch has exactly `batch_size` samples.
train_dataset = FerDataset(base_path='../fer', data=data, mode='train', label=label)
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)

# Evaluation split: fixed order, every sample kept (the last batch may be short).
eval_dataset = FerDataset(base_path='../fer', data=data, mode='eval', label=label)
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

# Prepare the network
model = ResNet(num_classes=train_dataset.n_classes,
               adaptive=hparams['adaptive'],
               batchnorm=hparams['batchnorm'])
softmax = nn.Softmax(dim=-1)
log_softmax = nn.LogSoftmax(dim=-1)
# `size_average=False` is a deprecated spelling; `reduction='sum'` is the
# equivalent supported argument (loss summed over all elements, no averaging).
loss_fn = nn.KLDivLoss(reduction='sum')


def criterion(logits, labels):
    """Summed KL-divergence loss between the model prediction and the targets.

    Args:
        logits: raw (un-normalised) model outputs; log-softmax is applied over
            the last dimension before they are handed to ``nn.KLDivLoss``.
        labels: target tensor passed unchanged as the ``target`` of
            ``nn.KLDivLoss`` (presumably per-class vote distributions given the
            'ferplus_votes' label setting -- TODO confirm against FerDataset).

    Returns:
        Scalar tensor: the loss summed over every element of the batch.
    """
    return loss_fn(log_softmax(logits), labels)


optimizer = torch.optim.Adam(model.parameters(), lr=hparams['lr'], weight_decay=hparams['wd'])
# Reduce the learning rate by `scheduler_factor` once the monitored quantity
# ('min' mode) has not improved for `scheduler_patience` scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=hparams['scheduler_factor'],
    patience=hparams['scheduler_patience'])


# Evaluation interval in optimisation steps: one evaluation per full pass over
# the training dataloader. Stored in hparams so it is logged with the rest.
hparams['eval_steps'] = len(train_dataloader)


# One log file per (lr, wd, run index) combination.
log_path = os.path.join(model_dir, 'small_resnet_10epoch_lr{}_wd{}_{}.log'.format(hparams['lr'], hparams['wd'], i))
# init_logger is a project helper; presumably it routes `logging` output to
# log_path only (to_console=False) -- confirm in utils.tools.
init_logger(log_path, to_console=False)

# Record the full experimental setup at the top of the log.
logging.info('### Model ###\n' + repr(model))
logging.info('### Optimizer ###\n' + repr(optimizer))

hparams_str = " ; ".join("{}: {}".format(k, v) for k, v in hparams.items())
logging.info('### HParams ###\n' + hparams_str)


# Plain-name aliases read by the training loop.
n_epochs = hparams['n_epochs']
eval_steps = hparams['eval_steps']
n_batches = len(train_dataloader)
wd = hparams['wd']  # not read in the visible part of this cell; kept in case other cells use it

checkpoint_dir = os.path.join(model_dir, 'checkpoints')
# exist_ok=True: safe to re-run the cell, no separate existence check needed.
os.makedirs(checkpoint_dir, exist_ok=True)

# Main training loop: one optimisation step per training batch, with a full
# evaluation pass, LR-scheduler update and checkpoint every `eval_steps` steps.
step = 0
loss_avg = RunningAverage(window=20)
model.train()
with tqdm(total=n_epochs * n_batches) as t:
    # NOTE: these counters are never reset, so the reported training accuracy
    # is a running figure over the whole run, not a per-epoch one.
    train_samples_correct = 0
    train_samples = 0
    for epoch in range(n_epochs):
        for x_batch, y_batch in train_dataloader:

            # Forward pass
            logits = model(x_batch)
            log_probs = log_softmax(logits)
            loss = loss_fn(log_probs, y_batch)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Logging
            loss_value = loss.item()  # plain float for formatting / averaging
            # Count the samples actually present in the batch instead of
            # assuming every batch holds exactly `batch_size` items.
            train_samples += y_batch.size(0)
            # softmax is monotonic, so the arg-max of the logits is the
            # predicted class; the label class is the arg-max of the target.
            train_samples_correct += logits.argmax(dim=-1).eq(y_batch.argmax(dim=-1)).sum().item()
            acc = train_samples_correct / train_samples * 100
            loss_avg.update(loss_value)
            logging.info('[TRAIN] Step: {} ; Loss: {:05.3f} ; Acc: {:02.3f}'.format(step, loss_value, acc))
            t.set_postfix(loss='{:05.3f}'.format(loss_avg()), acc='{:02.3f}'.format(acc))
            t.update()

            step += 1


            # Evaluate
            if step > 0 and step % eval_steps == 0:
                model.eval()
                eval_samples_correct = 0
                eval_loss = 0.0
                # No gradients are needed for evaluation; without no_grad()
                # every batch would build (and the accumulated loss would
                # retain) an autograd graph.
                with torch.no_grad():
                    for x_eval, y_eval in eval_dataloader:
                        eval_logits = model(x_eval)
                        # Use log_softmax exactly as the training step does.
                        # log(softmax(x)) underflows to -inf for strongly
                        # negative logits and then yields an infinite KL loss
                        # (likely the source of the `inf` eval losses seen in
                        # some high-learning-rate runs).
                        eval_loss += loss_fn(log_softmax(eval_logits), y_eval).item()
                        eval_samples_correct += eval_logits.argmax(dim=-1).eq(y_eval.argmax(dim=-1)).sum().item()

                # Mean of the per-batch summed losses (the final eval batch may
                # be smaller than the others); kept as in earlier runs so the
                # numbers stay comparable.
                eval_loss /= len(eval_dataloader)
                eval_acc = eval_samples_correct / len(eval_dataset) * 100
                eval_summary = '[EVAL] Step: {} ; Loss: {:05.3f} ; Acc: {:02.3f}'.format(step, eval_loss, eval_acc)
                scheduler.step(eval_loss)
                logging.info(eval_summary)
                t.write(eval_summary)
                # eval_loss is stored as a plain Python float.
                checkpoint = {'model': model.state_dict(),
                              'optimizer': optimizer.state_dict(),
                              'step': step,
                              'eval_loss': eval_loss,
                              'eval_acc': eval_acc,
                              'train_run_loss': loss_avg(),
                              'train_acc' : acc,
                              'hparams': hparams}
                # Timestamped file name: one checkpoint file per evaluation.
                filename = time.strftime("%Y%m%d-%H%M%S") + '.pth.tar'
                torch.save(checkpoint, os.path.join(checkpoint_dir, filename))
                model.train()



HBox(children=(IntProgress(value=0, max=11900), HTML(value='')))

[EVAL] Step: 1190 ; Loss: 21.594 ; Acc: 58.815
[EVAL] Step: 2380 ; Loss: 17.289 ; Acc: 66.806
[EVAL] Step: 3570 ; Loss: 16.208 ; Acc: 68.064
[EVAL] Step: 4760 ; Loss: 15.657 ; Acc: 68.259
[EVAL] Step: 5950 ; Loss: 15.543 ; Acc: 68.930



| lr   | wd   | train_acc | eval_acc | train_loss | eval_loss |
| ---- | ---- | --------- | -------- | ---------- | --------- | 
| 1e-6 |    0 |    26.800 |   31.400 |     58.826 |    59.300 |
| 1e-5 |    0 |    35.550 |   36.700 |     30.607 |    36.700 |
| 1e-4 |    0 |    47.349 |   56.900 |     20.955 |    22.300 |
| 1e-3 |    0 |    51.852 |   43.532 |     18.456 |    34.191 |
| 1e-3 |    0 |    51.341 |   63.621 |     17.933 |    17.973 |


| lr   | wd   | train_acc | eval_acc | train_loss | eval_loss |
| ---- | ---- | --------- | -------- | ---------- | --------- |
| 1e-4 |    0 |    45.872 |   50.210 |     21.797 |    25.170 |
| 1e-4 |    0 |    45.221 |   55.407 |     22.171 |    22.802 |
| 1e-4 |    0 |    46.278 |   34.395 |     21.933 |    36.593 |
| 1e-4 | 1e-4 |    45.077 |   55.909 |     21.930 |    21.193 |
| 1e-4 | 1e-4 |    45.882 |   51.299 |     23.002 |    26.200 |
| 1e-4 | 1e-4 |    45.998 |   56.077 |     21.235 |    22.152 |
| 1e-3 |    0 |    48.697 |   58.871 |     20.307 |    19.524 |
| 1e-3 |    0 |    51.320 |   **61.665** |     19.448 |    **18.989** |
| 1e-3 |    0 |    48.127 |   55.658 |     20.774 |    22.751 |
| 1e-3 | 1e-4 |    51.870 |   59.681 |     19.578 |    20.818 |
| 1e-3 | 1e-4 |    **52.903** |   57.949 |     **18.403** |    21.709 |
| 1e-3 | 1e-4 |    50.571 |   61.106 |     19.436 |    19.615 |
| 1e-2 |    0 |    31.999 |   40.291 |     28.803 |     00inf |
| 1e-2 |    0 |    33.232 |   34.674 |     33.566 |     00inf |
| 1e-2 |    0 |    34.489 |   40.542 |     28.645 |    28.988 |
| 1e-2 | 1e-4 |    30.686 |   35.736 |     29.858 |    32.322 |
| 1e-2 | 1e-4 |    31.268 |   35.317 |     32.325 |    31.525 |
| 1e-2 | 1e-4 |    31.628 |   35.233 |     29.007 |    30.953 |