## !! The final result should be only a runnable .py file !!

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import pickle
from torch.utils.data import random_split
from models import model_1 as m
from training_early_stop import EarlyStop
import utility

# 0. Data Pre-processing

In [2]:
# Preprocessing applied to every image of every split.
preprocessing_steps = [
    # collapse the image to a single colour channel
    transforms.Grayscale(num_output_channels=1),
    # 227 x 227 input size, chosen because the network is AlexNet-based
    transforms.Resize((227, 227)),
    transforms.ToTensor(),
    # normalize the single channel
    transforms.Normalize(mean=[0.485], std=[0.229]),
]
data_transforms = transforms.Compose(preprocessing_steps)

full_train_dataset = datasets.ImageFolder(
    '../dataset/train', transform=data_transforms)

# Carve a validation set out of the training images (85 % / 15 %).
# The global seed is fixed right before the split so the partition is
# reproducible between runs.
torch.manual_seed(42)
train_size = int(0.85 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(
    full_train_dataset, [train_size, val_size])

test_dataset = datasets.ImageFolder(
    '../dataset/test', transform=data_transforms)

# Options shared by all three loaders.
shared_loader_options = {'num_workers': 8}

# Only the training loader shuffles; validation and test keep a fixed order.
train_loader = DataLoader(
    train_dataset, batch_size=512, shuffle=True,
    pin_memory=True, **shared_loader_options)
val_loader = DataLoader(
    val_dataset, batch_size=1024, shuffle=False,
    pin_memory=True, **shared_loader_options)
test_loader = DataLoader(
    test_dataset, batch_size=512, shuffle=False,
    pin_memory=False, **shared_loader_options)

# 1. Model

In [3]:
# Build the network, the loss function and the optimizer.
# FER-2013 has 7 emotion classes, hence num_classes=7.
model = m.EmotionCNN(num_classes=7)
criterion = nn.CrossEntropyLoss()
# Plain SGD with momentum over all model parameters.
optimizer = optim.SGD(params=model.parameters(), lr=1e-4, momentum=0.9)

In [4]:
# select device
# `utility.select_devices` is a project helper that is not shown in this file;
# presumably it prefers CUDA when available -- verify in `utility`.
# The returned value is later passed to `model.to(...)` and `tensor.to(...)`.
device = utility.select_devices()

using CUDA + cudnn


In [5]:
# Records filled in by the training loop. Each name gets its own,
# independent empty list:
#   loss_history_per_epoch       - average training loss of each epoch
#   correct_prediction_pre_epoch - correct-prediction counts seen in training
#   accuracy_per_epoch           - training accuracy of each epoch
#   val_loss_per_epoch           - validation loss of each validated epoch
#   val_accuracy_per_epoch       - validation accuracy of each validated epoch
(
    loss_history_per_epoch,
    correct_prediction_pre_epoch,
    accuracy_per_epoch,
    val_loss_per_epoch,
    val_accuracy_per_epoch,
) = ([] for _ in range(5))

In [None]:
# Train the model, periodically validating it and feeding the validation
# loss to the early-stopping helper.
num_epochs = 1000
model.to(device)

# early stopping variables
stopping_count = 100  # given to EarlyStop; also the length of a validation window
different = 0.001     # third EarlyStop argument (presumably the minimum improvement -- confirm)
interval = 5          # a validation window may start when epoch % (interval - 1) == 0
counter = 0           # number of validations run in the current window
is_always = False     # set to True to validate after every epoch
is_exe = False        # True while a validation window is active
early_stopping = EarlyStop(m.pth_save_path, stopping_count, different)

# progress bar
process = tqdm(range(num_epochs), bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}', colour='green', ascii='░▒█', unit='epoch')

for epoch in process:
    # Validation below switches the model to eval mode, so training mode has
    # to be (re-)enabled at the start of every epoch.
    model.train()

    running_loss = 0.0
    accuracy = 0.0
    epoch_correct = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        # forward pass
        outputs = model(inputs)
        # loss of this batch
        loss = criterion(outputs, labels)
        # backward propagation and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # record training status
        running_loss += loss.item()
        prediction = outputs.argmax(dim=1)
        num_correct_prediction = (prediction == labels).sum().item()
        epoch_correct += num_correct_prediction
        accuracy += num_correct_prediction / inputs.shape[0]
    # save training status: every record gets exactly one entry per epoch
    loss_history_per_epoch.append((running_loss / len(train_loader)))
    accuracy_per_epoch.append((accuracy / len(train_loader)))
    correct_prediction_pre_epoch.append(epoch_correct)

    # training validation + early stopping
    if is_always or is_exe or (epoch!=0 and epoch%(interval-1)==0):
        val_loss = 0.0
        val_accuracy = 0.0

        # NOTE(review): this reset also fires while a validation window is
        # already active, so the early-stopping state is cleared every
        # (interval - 1) epochs and early_stopping.counter may never
        # accumulate up to stopping_count -- confirm this is intended.
        if epoch%(interval-1)==0:
            early_stopping.counter = 0
            early_stopping.best_loss = None
            is_exe = True

        counter += 1

        # close the validation window after stopping_count validations
        if counter >= stopping_count:
            counter = 0
            is_exe = False

        # Evaluate in eval mode and without autograd: layers such as dropout
        # or batch-norm must not behave as in training, and no graph needs to
        # be built for the validation batches.
        model.eval()
        with torch.no_grad():
            for inputs, labels in val_loader:
                inputs, labels = inputs.to(device), labels.to(device)
                outputs = model(inputs)
                loss = criterion(outputs, labels)
                val_loss += loss.item()
                prediction = outputs.argmax(dim=1)
                num_correct_prediction = (prediction == labels).sum().item()
                val_accuracy += num_correct_prediction / inputs.shape[0]
        val_loss_per_epoch.append((val_loss / len(val_loader)))
        val_accuracy_per_epoch.append((val_accuracy / len(val_loader)))

        # The early-stopping helper receives the summed (not averaged)
        # validation loss of this epoch.
        early_stopping.check_status(model, val_loss)

        # show the last 5 epochs of each record on the progress bar
        # NOTE(review): `min_val_loss` is read here while `best_loss` is the
        # attribute reset above -- check which one EarlyStop really defines.
        process.set_description(f"avg loss[-5:] = {loss_history_per_epoch[-5:]}\t"
                                f"accuracy[-5:] = {accuracy_per_epoch[-5:]}\t"
                                f"best loss = {early_stopping.min_val_loss}, val loss = {val_loss}\t"
                                f"val accuracy[-5] = {val_accuracy_per_epoch[-5:]}\t"
                                f"Stop Counter = {early_stopping.counter}/{stopping_count}\t")
    else:
        process.set_description(f"avg loss[-5:] = {loss_history_per_epoch[-5:]}\t"
                                f"accuracy[-5:] = {accuracy_per_epoch[-5:]}\t")

    if early_stopping.early_stop:
        print('\nTrigger Early Stopping\n')
        break

In [None]:
# Save the weights the model holds at the end of the training loop.
# (EarlyStop was given m.pth_save_path; presumably it writes its own
# checkpoint there -- confirm in training_early_stop.)
torch.save(model.state_dict(), m.pth_manual_save_path)

# Persist the training records, one pickle file per record, inside
# m.record_save_path. All four files use a single '.pkl' extension.
history_files = {
    '/loss_history.pkl': loss_history_per_epoch,
    '/accuracy_history.pkl': accuracy_per_epoch,
    '/val_loss_history.pkl': val_loss_per_epoch,
    '/val_accuracy_history.pkl': val_accuracy_per_epoch,
}
for history_file, history in history_files.items():
    utility.save_pickle_files(history, m.record_save_path + history_file)

In [None]:
# Evaluate on the test set with a freshly constructed network, once per saved
# checkpoint: first the file at m.pth_save_path, then the manually saved one.
# `utility.model_validation` is a project helper; presumably it loads the
# weights found at the given path into `model` -- confirm in `utility`.
model = m.EmotionCNN(num_classes=7)
for checkpoint_path in (m.pth_save_path, m.pth_manual_save_path):
    utility.model_validation(model, device, test_loader, checkpoint_path)