In [1]:
import os
import sys
sys.path.append('..')
from utils.dataset import FerDataset
from utils.tools import init_logger

import numpy as np
from tqdm import tqdm_notebook as tqdm
import logging
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
class RunningAverage(object):
    
    def __init__(self, window):
        self.window = window
        self.values = []
        self.mean = 0
        
    def update(self, value):
        self.values.append(value)
        if len(self.values) > self.window:
            self.mean += (value - self.values.pop(0)) / self.window
        else:
            self.mean = sum(self.values) / len(self.values)

    def __call__(self):
        return self.mean
    

class LeNet5(nn.Module):
    """ LeNet customized to FER."""
    
    def __init__(self):
        super(LeNet5, self).__init__()
        
        self.convnet = nn.Sequential(
            nn.Conv2d(1, 6, kernel_size=(5, 5)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
            nn.Conv2d(6, 16, kernel_size=(5, 5)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
            nn.Conv2d(16, 120, kernel_size=(5, 5)),
            nn.ReLU(),
            nn.Conv2d(120, 240, kernel_size=(5, 5)),
            nn.ReLU())

        self.fc = nn.Sequential(
            nn.Linear(240, 84),
            nn.ReLU(),
            nn.Linear(84, 10))

    def forward(self, x):
        x = self.convnet(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x    

In [3]:
model_dir = '../models/lenet'
if not os.path.exists(model_dir):
    os.makedirs(model_dir)
    
hparams = {}
hparams['n_epochs'] = 10
hparams['data'] = 'ferplus'
hparams['label'] = 'ferplus_votes'
hparams['batch_size'] = 24

In [4]:
# Prepare dataloaders
data = hparams['data']
label = hparams['label']
batch_size = hparams['batch_size']

train_dataset = FerDataset(base_path='/Users/lennard/data/project/fer',
                           data=data, mode='train', label=label)
eval_dataset = FerDataset(base_path='/Users/lennard/data/project/fer',
                           data=data, mode='eval', label=label)

train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# Prepare the network
model = LeNet5()
softmax = nn.Softmax(dim=-1)
log_softmax = nn.LogSoftmax(dim=-1)
loss_fn = nn.KLDivLoss(size_average=False)

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.1)

hparams['eval_steps'] = len(train_dataloader)



In [5]:
log_path = os.path.join(model_dir, 'train.log')
if os.path.exists(log_path):
    os.remove(log_path)
init_logger(log_path, to_console=False)

logging.info('### Model ###\n' + model.__repr__())
logging.info('### Optimizer ###\n' + optimizer.__repr__())

hparams_str = " ; ".join("{}: {}".format(k, v) for k, v in hparams.items())
logging.info('### HParams ###\n' + hparams_str)

In [6]:
n_epochs = hparams['n_epochs']
eval_steps = hparams['eval_steps']
n_batches = len(train_dataloader)

checkpoint_dir = os.path.join(model_dir, 'checkpoints')
if not os.path.exists(checkpoint_dir):
    os.makedirs(checkpoint_dir)

step = 0
loss_avg = RunningAverage(window=20)

with tqdm(total=n_epochs * n_batches) as t:
    train_samples_correct = 0
    train_samples = 0
    for epoch in range(n_epochs):
        for x_batch, y_batch in train_dataloader:
            
            # Forward pass
            logits = model(x_batch)
            log_probs = log_softmax(logits)
            loss = loss_fn(log_probs, y_batch)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Logging
            train_samples += batch_size
            probs = softmax(logits)
            train_samples_correct += probs.argmax(dim=-1).eq(y_batch.argmax(dim=-1)).sum().item()
            acc = train_samples_correct / train_samples * 100
            loss_avg.update(loss.item())
            logging.info('[TRAIN] Step: {} ; Loss: {:05.3f} ; Acc: {:02.3f}'.format(step, loss, acc))
            t.set_postfix(loss='{:05.3f}'.format(loss_avg()), acc='{:02.3f}'.format(acc))
            t.update()

            step += 1
            

            # Evaluate
            if step > 0 and step % eval_steps == 0:
                model.eval()
                eval_samples_correct = 0
                eval_loss = 0
                for x_batch, y_batch in eval_dataloader:
                    logits = model(x_batch)
                    probs = softmax(logits)
                    log_probs = torch.log(probs)
                    eval_loss += loss_fn(log_probs, y_batch)
                    eval_samples_correct += probs.argmax(dim=-1).eq(y_batch.argmax(dim=-1)).sum().item()
                    
                eval_loss /= len(eval_dataloader) * 100
                eval_acc = eval_samples_correct / len(eval_dataset)
                eval_summary = '[EVAL] Step: {} ; Loss: {:05.3f} ; Acc: {:02.3f}'.format(step, eval_loss, eval_acc)
                logging.info(eval_summary)
                t.write(eval_summary)
                checkpoint = {'model': model.state_dict(),
                              'optimizer': optimizer.state_dict(),
                              'step': step,
                              'eval_loss': eval_loss,
                              'eval_acc': eval_acc,
                              'train_run_loss': loss_avg(),
                              'train_acc' : acc,
                              'hparams': hparams}
                filename = time.strftime("%Y%m%d-%H%M%S") + '.pth.tar'
                torch.save(checkpoint, os.path.join(checkpoint_dir, filename))
                model.train()

HBox(children=(IntProgress(value=0, max=11900), HTML(value='')))

[EVAL] Step: 1190 ; Loss: 0.239 ; Acc: 0.548
[EVAL] Step: 2380 ; Loss: 0.210 ; Acc: 0.588
[EVAL] Step: 3570 ; Loss: 0.193 ; Acc: 0.622
[EVAL] Step: 4760 ; Loss: 0.187 ; Acc: 0.642
[EVAL] Step: 5950 ; Loss: 0.177 ; Acc: 0.648
[EVAL] Step: 7140 ; Loss: 0.169 ; Acc: 0.666
[EVAL] Step: 8330 ; Loss: 0.163 ; Acc: 0.675
[EVAL] Step: 9520 ; Loss: 0.161 ; Acc: 0.689
[EVAL] Step: 10710 ; Loss: 0.155 ; Acc: 0.686
[EVAL] Step: 11900 ; Loss: 0.158 ; Acc: 0.697

