In [1]:
import os
os.environ['OMP_NUM_THREADS'] = '4'
import sys
sys.path.append("..")
from utils.dataset import FerDataset
from utils.resnet import *
from utils.tools import init_logger, RunningAverage

import numpy as np
from tqdm import tqdm_notebook as tqdm
import logging
import time

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
# Folder that receives the log file and checkpoints of this model;
# it is created the first time the notebook runs.
model_dir = '../models/resnet'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)

# Every tunable setting of the run lives in this one dict so that it can be
# written to the log and stored inside each checkpoint.
hparams = {
    'n_epochs': 1,             # passes over the training loader
    'data': 'ferplus',         # forwarded to FerDataset(data=...)
    'label': 'ferplus_votes',  # forwarded to FerDataset(label=...)
    'batch_size': 24,
    'wd': 0,                   # Adam weight_decay
    'lr': 1e-3,                # Adam learning rate
    'adaptive': False,         # forwarded to ResNet(adaptive=...)
}

In [3]:
# Prepare dataloaders
data, label, batch_size = hparams['data'], hparams['label'], hparams['batch_size']

# Both splits are read from the same base folder with the same label scheme;
# the train split is built first, then the eval split.
train_dataset, eval_dataset = [
    FerDataset(base_path='../fer', data=data, mode=split, label=label)
    for split in ('train', 'eval')
]

# Training batches are reshuffled and the trailing partial batch is dropped;
# evaluation keeps dataset order and every sample.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
    num_workers=0,
    drop_last=True,
)
eval_dataloader = DataLoader(
    eval_dataset,
    batch_size=batch_size,
    shuffle=False,
    num_workers=0,
)

# Prepare the network
model = ResNet(
    BasicBlock,
    [2, 2, 2, 2],
    num_classes=train_dataset.n_classes,
    adaptive=hparams['adaptive'],
)
softmax = nn.Softmax(dim=-1)
log_softmax = nn.LogSoftmax(dim=-1)
# KL divergence summed over the batch rather than averaged.
# NOTE(review): `size_average` is deprecated in newer PyTorch releases in
# favour of `reduction='sum'` -- confirm the installed version before switching.
loss_fn = nn.KLDivLoss(size_average=False)


def criterion(logits, labels):
    """Return the summed KL-divergence loss for a batch.

    Args:
        logits: raw network outputs; log-softmax is applied over the last dim.
        labels: target distribution handed to ``loss_fn`` unchanged.

    Returns:
        The value produced by ``loss_fn(log_softmax(logits), labels)``.
    """
    log_probs = log_softmax(logits)
    return loss_fn(log_probs, labels)


optimizer = torch.optim.Adam(
    model.parameters(),
    lr=hparams['lr'],
    weight_decay=hparams['wd'],
)

# Number of training batches per epoch, stored as the evaluation interval.
hparams['eval_steps'] = len(train_dataloader)



In [4]:
# Start every run from an empty log file.
log_path = os.path.join(model_dir, 'resnet_train_04.log')
if os.path.exists(log_path):
    os.remove(log_path)
# NOTE(review): init_logger comes from utils.tools; presumably it attaches a
# file handler for log_path to the root logger -- confirm there.
init_logger(log_path, to_console=False)

# Write the experiment setup (model, optimizer, hyper-parameters) at the top
# of the log, one titled section each.
hparams_str = " ; ".join(["{}: {}".format(name, value) for name, value in hparams.items()])
sections = [
    ('### Model ###\n', repr(model)),
    ('### Optimizer ###\n', repr(optimizer)),
    ('### HParams ###\n', hparams_str),
]
for header, body in sections:
    logging.info(header + body)

In [None]:
# Training loop.
#
# Runs `n_epochs` passes over `train_dataloader`. After every optimizer step it
# logs the batch loss and the cumulative training accuracy. Every `eval_steps`
# steps it evaluates on `eval_dataloader`, logs the result and writes a
# checkpoint (model + optimizer state, metrics and hparams) to `checkpoint_dir`.
n_epochs = hparams['n_epochs']
eval_steps = hparams['eval_steps']
n_batches = len(train_dataloader)
wd = hparams['wd']

checkpoint_dir = os.path.join(model_dir, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

step = 0
loss_avg = RunningAverage(window=20)
model.train()
with tqdm(total=n_epochs * n_batches) as t:
    train_samples_correct = 0
    train_samples = 0
    for epoch in range(n_epochs):
        for x_batch, y_batch in train_dataloader:

            # Forward pass
            logits = model(x_batch)
            loss = criterion(logits, y_batch)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Logging
            # Pull the scalar out once so logging/formatting never touches the
            # tensor (and its autograd graph).
            loss_value = loss.item()
            # Count the samples actually seen instead of assuming every batch
            # is full; stays correct if drop_last is ever turned off.
            train_samples += y_batch.size(0)
            # Softmax is monotonic, so the argmax over logits is the predicted
            # class; the target class is the argmax of the label distribution.
            train_samples_correct += logits.argmax(dim=-1).eq(y_batch.argmax(dim=-1)).sum().item()
            acc = train_samples_correct / train_samples * 100
            loss_avg.update(loss_value)
            logging.info('[TRAIN] Step: {} ; Loss: {:05.3f} ; Acc: {:02.3f}'.format(step, loss_value, acc))
            t.set_postfix(loss='{:05.3f}'.format(loss_avg()), acc='{:02.3f}'.format(acc))
            t.update()

            step += 1

            # Evaluate
            if step % eval_steps == 0:
                model.eval()
                eval_samples_correct = 0
                eval_loss = 0.0
                # No gradients are needed here; without no_grad every batch's
                # graph would be kept alive by the accumulated loss tensor.
                with torch.no_grad():
                    for x_eval, y_eval in eval_dataloader:
                        eval_logits = model(x_eval)
                        # Same log-softmax based loss as in training; avoids
                        # log(softmax(.)) underflowing to -inf.
                        eval_loss += criterion(eval_logits, y_eval).item()
                        eval_samples_correct += eval_logits.argmax(dim=-1).eq(y_eval.argmax(dim=-1)).sum().item()

                # Mean of the per-batch summed losses, i.e. on the same scale
                # as the per-batch training loss.
                eval_loss /= len(eval_dataloader)
                eval_acc = eval_samples_correct / len(eval_dataset) * 100
                eval_summary = '[EVAL] Step: {} ; Loss: {:05.3f} ; Acc: {:02.3f}'.format(step, eval_loss, eval_acc)
                logging.info(eval_summary)
                t.write(eval_summary)
                # Metrics are stored as plain Python floats.
                checkpoint = {'model': model.state_dict(),
                              'optimizer': optimizer.state_dict(),
                              'step': step,
                              'eval_loss': eval_loss,
                              'eval_acc': eval_acc,
                              'train_run_loss': loss_avg(),
                              'train_acc' : acc,
                              'hparams': hparams}
                filename = time.strftime("%Y%m%d-%H%M%S") + '.pth.tar'
                torch.save(checkpoint, os.path.join(checkpoint_dir, filename))
                model.train()

HBox(children=(IntProgress(value=0, max=1190), HTML(value='')))


| lr   | train_acc   | eval_acc | train_loss | eval_loss |
| ---- | ----------- | -------- | ---------- | --------- | 
| 1e-6 |      26.800 |   31.400 |     58.826 |    59.300 |
| 1e-5 |      35.550 |   36.700 |     30.607 |    36.700 |
| 1e-4 |      47.349 |   56.900 |     20.955 |    22.300 |
| 1e-3 |      xx.xxx |   xx.xxx |     xx.xxx |    xx.xxx |
| 1e-2 |      xx.xxx |   xx.xxx |     xx.xxx |    xx.xxx |
| 1e-1 |      xx.xxx |   xx.xxx |     xx.xxx |    xx.xxx |