## NOTE: the final deliverable must be a single runnable .py file.

# %%
import math
import pickle

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch.utils.data import random_split
from torchvision import datasets, transforms
from tqdm import tqdm

# 0. Data Pre-processing

# %%
# All splits share one preprocessing pipeline: collapse to a single grayscale
# channel, resize to the 227x227 input this AlexNet-style model is built for,
# convert to a tensor and normalize.
data_transforms = transforms.Compose(
    [
        transforms.Grayscale(num_output_channels=1),
        transforms.Resize((227, 227)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485], std=[0.229]),
    ]
)

full_train_dataset = datasets.ImageFolder('../dataset/train', transform=data_transforms)

# Hold out 15% of the training images for validation. The global RNG is
# seeded first so the split (and every random draw after it) is reproducible.
torch.manual_seed(42)
train_size = int(0.85 * len(full_train_dataset))
val_size = len(full_train_dataset) - train_size
train_dataset, val_dataset = random_split(full_train_dataset, [train_size, val_size])

test_dataset = datasets.ImageFolder('../dataset/test', transform=data_transforms)

# Worker / pinned-memory settings shared by every loader; only training data is shuffled.
loader_kwargs = {'num_workers': 8, 'pin_memory': True}
train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_dataset, batch_size=512, shuffle=False, **loader_kwargs)

# Number of batches per split.
print(len(train_loader), len(val_loader), len(test_loader))

# 1. Model

# %%
# TODO: compare this ("optimal") version against the earlier model versions.
# version optimal
# reference: AlexNet
class EmotionCNN(nn.Module):
    """AlexNet-style CNN for facial-emotion classification.

    Five convolution stages (each Conv2d -> BatchNorm2d -> ReLU, with
    max-pooling after the 1st, 2nd and 5th) feed an adaptive 6x6 average pool
    and a fully connected classifier built from Linear -> BatchNorm1d -> ReLU
    layers with Dropout in between.

    Args:
        num_classes: number of output classes (7 for FER-2013).
        in_channels: number of channels in the input image. Defaults to 1 to
            match grayscale input; pass 3 for RGB images.

    ``forward`` returns raw, unnormalized class scores (logits) of shape
    ``(batch_size, num_classes)``, which is what ``nn.CrossEntropyLoss``
    expects. Apply ``softmax`` to obtain probabilities.
    """

    def __init__(self, num_classes, in_channels=1):
        super().__init__()
        # Layer order/indices are unchanged so saved state_dicts stay loadable.
        self.features = nn.Sequential(
            # out_channels is the number of filters; the batch dimension never
            # appears in layer definitions and is unrelated to in_channels.
            nn.Conv2d(in_channels=in_channels, out_channels=96, kernel_size=11, stride=4, padding=2),
            nn.BatchNorm2d(96),
            nn.ReLU(inplace=True),  # inplace: overwrite the input tensor to save memory
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(96, 256, kernel_size=5, padding=2),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),

            nn.Conv2d(256, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 384, kernel_size=3, padding=1),
            nn.BatchNorm2d(384),
            nn.ReLU(inplace=True),

            nn.Conv2d(384, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(kernel_size=3, stride=2),
            # for a 227x227 input the output is (batch_size, 256, 6, 6)
        )
        # Forces a 6x6 map for any input size, so the classifier input is fixed.
        self.avgpool = nn.AdaptiveAvgPool2d((6, 6))
        # [batch_size, channels, height, width] -> [batch_size, channels * height * width]
        self.flatten = nn.Flatten(1)
        self.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.BatchNorm1d(4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, num_classes),
        )

    def forward(self, x):
        """Map a batch of images ``(N, in_channels, H, W)`` to logits ``(N, num_classes)``."""
        x = self.features(x)
        x = self.avgpool(x)
        x = self.flatten(x)
        x = self.classifier(x)
        return x  # logits (not probabilities), one score per emotion class

# FER-2013 labels every face with one of 7 emotion classes.
NUM_EMOTION_CLASSES = 7

# Build the network, the loss (which consumes the network's raw logits) and
# an SGD-with-momentum optimizer over all of the network's parameters.
model = EmotionCNN(num_classes=NUM_EMOTION_CLASSES)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

# %%
# Pick the compute device: the first CUDA GPU when present, otherwise the CPU.
# (An Apple-silicon "mps" option was tried and is intentionally left disabled.)
cuda_ok = torch.cuda.is_available()
device = torch.device("cuda:0" if cuda_ok else "cpu")
print("using cuda" if cuda_ok else "using cpu")

if cuda_ok:
    # Let cuDNN benchmark and cache the fastest convolution algorithms. This
    # helps when input sizes are fixed and can hurt when they vary.
    torch.backends.cudnn.enabled = True
    torch.backends.cudnn.benchmark = True

# %%
class EarlyStop:
    """Track validation loss, checkpoint the best model, and signal when to stop.

    A loss counts as an improvement only if it is at least ``diff`` below the
    best loss seen so far. Each improvement saves the model's state_dict to
    ``'ES_' + filename`` and resets the patience counter. Each non-improvement
    increments the counter, and once it reaches ``stop_count`` the
    ``early_stop`` flag becomes True.

    Args:
        filename: base checkpoint name; the file written is ``'ES_' + filename``.
        stop_count: number of consecutive non-improving checks tolerated.
        diff: minimum decrease in loss that counts as an improvement.
    """

    def __init__(self, filename, stop_count, diff):
        self.save_file_name = 'ES_' + filename
        self.stop_count = stop_count
        self.counter = 0
        self.different = diff
        self.min_val_loss = float('inf')  # loss of the most recently saved checkpoint
        self.early_stop = False
        self.best = None  # best loss so far; None until the first check

    def save_model(self, validation_loss, new_model):
        """Write ``new_model``'s weights to ``save_file_name`` and remember its loss."""
        torch.save(new_model.state_dict(), self.save_file_name)
        self.min_val_loss = validation_loss

    def _is_improvement(self, validation_loss):
        """Return True if ``validation_loss`` should replace the current best.

        A NaN loss is never an improvement over a real one. This stops a
        diverged epoch from overwriting the best checkpoint and from
        resetting the counter forever, since every comparison with NaN is
        False.
        """
        if self.best is None:
            return True  # the first check always establishes a baseline checkpoint
        if math.isnan(self.best):
            # A NaN baseline cannot be compared; any real loss beats it.
            return not math.isnan(validation_loss)
        return validation_loss <= self.best - self.different

    def check_status(self, validation_loss, new_model):
        """Record one validation loss: save on improvement, otherwise count towards stopping."""
        if self._is_improvement(validation_loss):
            self.best = validation_loss
            self.save_model(validation_loss, new_model)
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.stop_count:
                self.early_stop = True

# %%
# Training curves, filled in by the training loop. Each name gets its own
# fresh, empty list.
(
    loss_history_per_epoch,        # average training loss per epoch
    correct_prediction_pre_epoch,  # counts of correct training predictions
    accuracy_per_epoch,            # training accuracy per epoch
    val_loss_per_epoch,            # average validation loss per epoch
    val_accuracy_per_epoch,        # validation accuracy per epoch
) = ([] for _ in range(5))

# %%
# training model
num_epochs = 1000
model.to(device)

# early stopping variables
save_filename = 'model.pth'
stopping_count = 20
different = 0.001
early_stopping = EarlyStop(save_filename, stopping_count, different)


def _run_epoch(net, loader, loss_fn, opt=None):
    """Make one pass over ``loader`` and return ``(mean_loss, accuracy, num_correct)``.

    With an optimizer ``opt`` the pass trains: ``net`` is put in train mode
    and ``opt`` steps after every batch. Without one the pass only evaluates:
    ``net`` is put in eval mode and autograd is disabled. In eval mode
    BatchNorm uses its running statistics without updating them, and Dropout
    is off, so the pass cannot alter the model.

    Loss and accuracy are averaged per sample rather than per batch, so a
    smaller final batch is not over-weighted. This assumes ``loss_fn``
    returns the batch mean, which is the ``nn.CrossEntropyLoss`` default.

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = opt is not None
    net.train(training)
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0
    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = net(inputs)
            loss = loss_fn(outputs, labels)
            if training:
                opt.zero_grad()
                loss.backward()
                opt.step()
            batch_size = labels.size(0)
            loss_sum += loss.item() * batch_size
            num_correct += (outputs.argmax(dim=1) == labels).sum().item()
            num_seen += batch_size
    if num_seen == 0:
        raise ValueError("data loader yielded no samples")
    return loss_sum / num_seen, num_correct / num_seen, num_correct


# progress bar
process = tqdm(range(num_epochs), bar_format='{l_bar}{bar:20}{r_bar}{bar:-20b}', colour='green', ascii='░▒█', unit='epoch')

for epoch in process:
    # one training pass; record its statistics
    train_loss, train_accuracy, train_correct = _run_epoch(model, train_loader, criterion, optimizer)
    loss_history_per_epoch.append(train_loss)
    accuracy_per_epoch.append(train_accuracy)
    correct_prediction_pre_epoch.append(train_correct)

    # validation runs in eval mode under no_grad, so it cannot change the model
    val_loss, val_accuracy, _ = _run_epoch(model, val_loader, criterion)
    val_loss_per_epoch.append(val_loss)
    val_accuracy_per_epoch.append(val_accuracy)

    # Early stopping sees the same mean validation loss that is recorded above,
    # so `different` is a threshold on the mean loss, not on a per-batch sum.
    early_stopping.check_status(val_loss, model)

    # display the 5 most recent average training losses
    process.set_description(f"avg loss[-5:] = {loss_history_per_epoch[-5:]}; "
                            f"best loss = {early_stopping.min_val_loss}, val loss = {val_loss}; "
                            f"Stop Counter = {early_stopping.counter}/{stopping_count}")

    if early_stopping.early_stop:
        print('\nTrigger Early Stopping\n')
        break

# %%
# Save the weights as they are after the final training epoch ("MS_" = manual stop).
torch.save(model.state_dict(), 'MS_' + save_filename)

# Dump each training curve to its own pickle file, in a fixed order.
histories_to_save = (
    ('loss_history.pkl', loss_history_per_epoch),
    ('accuracy_history.pkl', accuracy_per_epoch),
    ('val_loss_history.pkl', val_loss_per_epoch),
    ('val_accuracy_history.pkl', val_accuracy_per_epoch),
)
for pkl_path, history in histories_to_save:
    with open(pkl_path, 'wb') as pkl_file:
        pickle.dump(history, pkl_file)

# %%
# evaluate model
model = EmotionCNN(num_classes=7)


def test_data(path):
    """Load the checkpoint at ``path`` into ``model`` and report test-set accuracy.

    The weights are mapped onto ``device`` while loading. A checkpoint written
    on a GPU can therefore be evaluated on a CPU-only machine, and vice versa.

    Args:
        path: file written by ``torch.save(model.state_dict(), ...)``.

    Returns:
        Test accuracy as a percentage. Returns ``None`` (and prints a notice)
        if ``test_loader`` yields no samples, instead of dividing by zero.
    """
    state_dict = torch.load(path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    model.eval()

    correct = 0
    total = 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            predicted = model(inputs).argmax(dim=1)  # index of the highest-scoring emotion
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    if total == 0:
        print(f"Test Accuracy ({path}): n/a, the test set is empty")
        return None

    accuracy = 100 * correct / total
    print(f"Test Accuracy ({path}): {accuracy}%")
    return accuracy


test_data('ES_' + save_filename)  # best-validation checkpoint (early stop)
test_data('MS_' + save_filename)  # weights after the final epoch (manual stop)