In [1]:
import os

os.environ['OMP_NUM_THREADS'] = '2'
import sys
sys.path.append('..')
from utils.dataset import FerDataset
from utils.tools import init_logger

import numpy as np
from tqdm import tqdm_notebook as tqdm
import logging
import time

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
class RunningAverage(object):
    """Mean of the most recent ``window`` values passed to :meth:`update`.

    Until ``window`` values have been seen, the mean covers every value seen
    so far. Calling the instance returns the current mean (``0`` before the
    first update).

    Args:
        window: Maximum number of recent values that contribute to the mean.
            Must be at least 1.

    Raises:
        ValueError: If ``window`` is smaller than 1.
    """

    def __init__(self, window):
        # A window of 0 would divide by zero on the first update, and a
        # negative window would silently keep the mean at 0 forever.
        if window < 1:
            raise ValueError('window must be >= 1, got {}'.format(window))
        self.window = window
        self.values = []  # the retained values, oldest first
        self.mean = 0

    def update(self, value):
        """Add ``value`` to the window and refresh the mean."""
        self.values.append(value)
        if len(self.values) > self.window:
            # Window is full: swap the oldest value's contribution for the
            # newest one instead of re-summing the whole buffer.
            self.mean += (value - self.values.pop(0)) / self.window
        else:
            # Still filling up: average over what has been seen so far.
            self.mean = sum(self.values) / len(self.values)

    def __call__(self):
        """Return the current running mean."""
        return self.mean
    

class LeNet5(nn.Module):
    """LeNet-style CNN customized to FER.

    Four 5x5 convolutions (the first two each followed by 2x2 max-pooling)
    feed a two-layer classifier that emits 10 logits per sample. The
    classifier expects exactly 240 flattened features, so the convolutional
    stack must reduce the input to 240 x 1 x 1 (which a 1 x 48 x 48 input
    does).
    """

    def __init__(self):
        super().__init__()

        # Feature extractor: single-channel input -> 240 feature maps.
        conv_layers = [
            nn.Conv2d(1, 6, kernel_size=(5, 5)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
            nn.Conv2d(6, 16, kernel_size=(5, 5)),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=(2, 2), stride=2),
            nn.Conv2d(16, 120, kernel_size=(5, 5)),
            nn.ReLU(),
            nn.Conv2d(120, 240, kernel_size=(5, 5)),
            nn.ReLU(),
        ]
        self.convnet = nn.Sequential(*conv_layers)

        # Classifier head: 240 features -> 84 hidden units -> 10 logits.
        head_layers = [
            nn.Linear(240, 84),
            nn.ReLU(),
            nn.Linear(84, 10),
        ]
        self.fc = nn.Sequential(*head_layers)

        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialize weights: Kaiming-normal for conv/linear layers.

        BatchNorm2d layers, if any are present, get weight 1 and bias 0.
        Conv/linear biases are left as they are.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)

    def forward(self, x):
        """Return the raw class logits (no softmax applied) for batch ``x``."""
        features = self.convnet(x)
        # Collapse everything but the batch dimension for the linear layers.
        flat = features.view(features.size(0), -1)
        return self.fc(flat)

In [3]:
# Directory that receives the training log and the checkpoints of this model.
model_dir = '../models/lenet'
# exist_ok=True removes the check-then-create race and keeps the cell
# safely re-runnable.
os.makedirs(model_dir, exist_ok=True)

# Hyperparameters of this run. They are written to the log and stored in
# every checkpoint further down in the notebook.
hparams = {
    'n_epochs': 1,
    'data': 'ferplus',
    'label': 'ferplus_votes',
    'batch_size': 24,
    'lr': 1e-3,
    'wd': 0,                     # weight decay passed to the optimizer
    'init': 'kaiming_he',        # recorded only; not read by any code visible here
    'scheduler_patience': 10,
    'scheduler_factor': 0.1,
}

In [4]:
# Prepare dataloaders
data = hparams['data']
label = hparams['label']
batch_size = hparams['batch_size']

# Dataset root. Set the FER_BASE_PATH environment variable to run on another
# machine; the default is the location this notebook was written against.
base_path = os.environ.get('FER_BASE_PATH', '/Users/lennard/data/project/fer')

train_dataset = FerDataset(base_path=base_path, data=data, mode='train', label=label)
eval_dataset = FerDataset(base_path=base_path, data=data, mode='eval', label=label)

# drop_last=True keeps every training batch at exactly batch_size samples;
# the eval loader keeps its final (possibly smaller) batch so that every
# eval sample is scored.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0, drop_last=True)
eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

# Prepare the network
model = LeNet5()
softmax = nn.Softmax(dim=-1)
log_softmax = nn.LogSoftmax(dim=-1)
# KLDivLoss takes log-probabilities as input and probabilities as target.
# reduction='sum' is the non-deprecated spelling of size_average=False:
# the loss is summed over all elements of the batch.
loss_fn = nn.KLDivLoss(reduction='sum')

optimizer = torch.optim.Adam(model.parameters(), lr=hparams['lr'], weight_decay=hparams['wd'])

# Multiply the learning rate by `scheduler_factor` once the monitored value
# has not decreased for `scheduler_patience` scheduler steps.
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, 'min', factor=hparams['scheduler_factor'],
    patience=hparams['scheduler_patience'])

# One evaluation per pass over the training data.
hparams['eval_steps'] = len(train_dataloader)



In [None]:
# Every run starts from an empty log file inside the model directory:
# a log left behind by an earlier run is deleted before logging is set up.
log_path = os.path.join(model_dir, 'train_test.log')
if os.path.exists(log_path):
    os.remove(log_path)
# NOTE(review): init_logger is a project helper; presumably it routes the
# `logging` calls below into log_path without echoing them to the console.
init_logger(log_path, to_console=False)

# Record the complete run configuration at the top of the log.
logging.info('### Model ###\n' + repr(model))
logging.info('### Optimizer ###\n' + repr(optimizer))
logging.info('### Scheduler ###\n' + repr(scheduler))

# Hyperparameters as a single "key: value ; key: value" line.
hparams_str = " ; ".join(["{}: {}".format(name, value) for name, value in hparams.items()])
logging.info('### HParams ###\n' + hparams_str)

In [None]:
# Train for `n_epochs` passes over the training data. After every
# `eval_steps` optimizer steps the model is scored on the eval set, the
# LR scheduler is stepped on the eval loss and a checkpoint is written.
n_epochs = hparams['n_epochs']
eval_steps = hparams['eval_steps']
n_batches = len(train_dataloader)

checkpoint_dir = os.path.join(model_dir, 'checkpoints')
os.makedirs(checkpoint_dir, exist_ok=True)

step = 0
loss_avg = RunningAverage(window=20)  # smoothed loss shown in the progress bar

model.train()

with tqdm(total=n_epochs * n_batches) as t:
    # Training accuracy is cumulative over the whole run, not per epoch.
    train_samples_correct = 0
    train_samples = 0
    for epoch in range(n_epochs):
        for x_batch, y_batch in train_dataloader:

            # Forward pass
            logits = model(x_batch)
            log_probs = log_softmax(logits)
            loss = loss_fn(log_probs, y_batch)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Logging
            train_loss = loss.item()  # plain float: safe to format and to store
            # Count the samples actually seen rather than assuming batch_size.
            train_samples += y_batch.size(0)
            # Softmax is monotonic, so the arg-max of the logits is the
            # predicted class; no need to compute probabilities here.
            train_samples_correct += logits.argmax(dim=-1).eq(y_batch.argmax(dim=-1)).sum().item()
            acc = train_samples_correct / train_samples * 100
            loss_avg.update(train_loss)
            logging.info('[TRAIN] Step: {} ; Loss: {:05.3f} ; Acc: {:02.3f}'.format(step, train_loss, acc))
            t.set_postfix(loss='{:05.3f}'.format(loss_avg()), acc='{:02.3f}'.format(acc))
            t.update()

            step += 1

            # Evaluate
            if step % eval_steps == 0:
                model.eval()
                eval_samples_correct = 0
                eval_loss = 0.0
                # no_grad: evaluation needs no autograd graph. Accumulating
                # floats (not tensors) also keeps the graph out of the
                # running total and out of the checkpoint.
                with torch.no_grad():
                    for x_eval, y_eval in eval_dataloader:
                        eval_logits = model(x_eval)
                        # Same log_softmax as in training; log(softmax(x))
                        # can underflow to -inf and turn the loss into NaN.
                        eval_loss += loss_fn(log_softmax(eval_logits), y_eval).item()
                        eval_samples_correct += eval_logits.argmax(dim=-1).eq(y_eval.argmax(dim=-1)).sum().item()

                # loss_fn returns a per-batch sum, so this is the mean summed
                # loss per eval batch (the final batch may be smaller).
                eval_loss /= len(eval_dataloader)
                eval_acc = eval_samples_correct / len(eval_dataset) * 100
                eval_summary = '[EVAL] Step: {} ; Loss: {:05.3f} ; Acc: {:02.3f}'.format(step, eval_loss, eval_acc)
                scheduler.step(eval_loss)
                logging.info(eval_summary)
                t.write(eval_summary)
                checkpoint = {'model': model.state_dict(),
                              'optimizer': optimizer.state_dict(),
                              'step': step,
                              'eval_loss': eval_loss,
                              'eval_acc': eval_acc,
                              'train_run_loss': loss_avg(),
                              'train_acc': acc,
                              'hparams': hparams,
                              'scheduler': scheduler.state_dict()}
                # Checkpoints are named after the wall-clock time of the save.
                filename = time.strftime("%Y%m%d-%H%M%S") + '.pth.tar'
                torch.save(checkpoint, os.path.join(checkpoint_dir, filename))
                model.train()

## Learning-rate sweep with standard initialization

| lr   | train_loss   | train_acc  | eval_loss    | eval_acc    |
| ---- | -------------| -----------| -------------| ------------| 
| 1e-6 |      35.583  |     26.355 |     0.348    |     0.251   |
| 1e-6 |      34.847  |     27.292 |     0.349    |     0.364   |
| 1e-6 |      35.332  |     33.239 |     0.369    |     0.252   |
| 5e-6 |      27.606  |     34.576 |     0.299    |     0.390   |
| 5e-6 |      29.451  |     36.896 |     0.301    |     0.386   |
| 5e-6 |      29.301  |     35.816 |     0.299    |     0.398   |
| 1e-5 |      27.841  |     38.057 |     0.295    |     0.417   |
| 1e-5 |      28.357  |     36.383 |     0.295    |     0.419   |
| 1e-5 |      28.064  |     37.773 |     0.294    |     0.402   |
| 5e-5 |      25.665  |     41.961 |     0.266    |     0.496   |
| 5e-5 |      24.497  |     40.445 |     0.262    |     0.497   |
| 5e-5 |      25.369  |     41.859 |     0.264    |     0.506   |
| 1e-4 |      24.956  |     45.816 |     0.249    |     0.527   |
| 1e-4 |      22.715  |     47.353 |     0.242    |     0.540   |
| 1e-4 |      23.632  |     45.564 |     0.251    |     0.507   |
| 5e-4 |      19.524  |     53.284 |     0.205    |     0.592   |
| 5e-4 |      19.176  |     52.157 |     0.202    |     0.602   |
| 5e-4 |      20.239  |     51.593 |     0.208    |     0.594   |
| 1e-3 |      19.353  |     52.959 |     0.196    |     0.620   |
| 1e-3 |      18.125  |     53.323 |     0.197    |     0.607   |
| 1e-3 |      17.125  |     55.462 |     0.180    |     0.641   |
| 5e-3 |      29.174  |     35.162 |     0.302    |     0.373   |
| 5e-3 |      21.614  |     48.512 |     0.226    |     0.565   |
| 5e-3 |      28.610  |     35.203 |     0.301    |     0.373   |
| 1e-2 |      27.846  |     34.632 |     0.301    |     0.373   |
| 1e-2 |      28.591  |     35.032 |     0.299    |     0.373   |
| 1e-2 |      28.436  |     34.821 |     0.300    |     0.373   |
| 5e-2 |      29.182  |     34.111 |     0.301    |     0.373   |
| 5e-2 |      28.222  |     34.261 |     0.300    |     0.373   |
| 5e-2 |      28.773  |     34.167 |     0.301    |     0.373   |

## Learning-rate sweep with kaiming_he initialization

| lr   | train_loss   | train_acc  | eval_loss    | eval_acc    |
| ---- | -------------| -----------| -------------| ------------|
| 1e-4 |      20.808  |     51.845 |     0.219    |     0.568   |
| 1e-4 |      19.747  |     51.029 |     0.222    |     0.564   |
| 1e-4 |      20.764  |     51.866 |     0.215    |     0.579   |
| 5e-4 |      17.359  |     56.243 |     0.186    |     0.645   |
| 5e-4 |      17.894  |     54.160 |     0.195    |     0.615   |
| 5e-4 |      17.719  |     55.651 |     0.183    |     0.642   |
| 1e-3 |      17.779  |     57.416 |     0.182    |     0.640   |
| 1e-3 |      17.412  |     57.871 |     0.174    |     0.651   |
| 1e-3 |      16.678  |     57.784 |     0.179    |     0.637   |
| 5e-3 |      20.817  |     50.070 |     0.237    |     0.557   |
| 5e-3 |      20.591  |     49.842 |     0.221    |     0.574   |
| 5e-3 |      24.766  |     50.525 |     0.238    |     0.529   |