In [33]:
import torch, torchvision
from torch import nn
from torchvision import transforms

from PIL import Image, ImageOps

import numpy as np

import matplotlib.pyplot as plt

In [50]:
import copy

In [34]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
print(DEVICE)

cuda


In [35]:
import os

###################################################################
# This data preprocess does not assume 48x48 grayscale images.
# An image passed through it will be turned gray and scaled to 
#  48x48 pixels then turned into a pytorch tensor.
###################################################################

label_map = {
    'angry': 0,
    'disgust': 1,
    'fear': 2,
    'happy': 3,
    'neutral': 4,
    'sad': 5,
    'surprise': 6
}

predict_map = {
    0: 'angry',
    1: 'disgust',
    2: 'fear',
    3: 'happy',
    4: 'neutral',
    5: 'sad',
    6: 'surprise'
}

directory = './data/train'

width = 48

file_count = 0
for dir in os.listdir(directory):
    curr_dir = os.path.join(directory, dir)
    file_count += len(os.listdir(curr_dir))

train_labels = torch.zeros((file_count, 1, 7))
train_data = torch.zeros((file_count, 1, width, width))

curr_data_point1 = 0
last = 0
print("Extracting Train Data")
convert_tensor = transforms.ToTensor()
for dir in os.listdir(directory):
    curr_dir = os.path.join(directory, dir)
    for filename in os.listdir(curr_dir):
      f = os.path.join(curr_dir, filename)
      if os.path.isfile(f):
          img = ImageOps.grayscale(Image.open(f)).resize((width, width))
          train_data[curr_data_point1][0] = convert_tensor(img)
          train_labels[curr_data_point1][0][label_map[dir]] = 1
          curr_data_point1 += 1
    print(f"Emotion: {dir}, File Count: {curr_data_point1-last}")
    last = curr_data_point1

print()
print("Extracting Test Data")

directory = './data/test'

file_count = 0
for dir in os.listdir(directory):
    curr_dir = os.path.join(directory, dir)
    file_count += len(os.listdir(curr_dir))

test_labels = torch.zeros((file_count, 1, 7))
test_data = torch.zeros((file_count, 1, width, width))
curr_data_point2 = 0
last = 0
for dir in os.listdir(directory):
    curr_dir = os.path.join(directory, dir)
    for filename in os.listdir(curr_dir):
      f = os.path.join(curr_dir, filename)
      if os.path.isfile(f):
          img = ImageOps.grayscale(Image.open(f)).resize((width, width))
          test_data[curr_data_point2][0] = convert_tensor(img)
          test_labels[curr_data_point2][0][label_map[dir]] = 1
          curr_data_point2 += 1
    print(f"Emotion: {dir}, File Count: {curr_data_point2-last}")
    last = curr_data_point2

Extracting Train Data
Emotion: angry, File Count: 3995
Emotion: disgust, File Count: 436
Emotion: fear, File Count: 4097
Emotion: happy, File Count: 7215
Emotion: neutral, File Count: 4965
Emotion: sad, File Count: 4830
Emotion: surprise, File Count: 3171

Extracting Test Data
Emotion: angry, File Count: 958
Emotion: disgust, File Count: 111
Emotion: fear, File Count: 1024
Emotion: happy, File Count: 1774
Emotion: neutral, File Count: 1233
Emotion: sad, File Count: 1247
Emotion: surprise, File Count: 831


In [94]:
class Block(nn.Module):
    def __init__(self, channel_in, pass_on, channel_out, device):
        super().__init__()
        self.conv_1x1 = nn.Sequential(
            nn.Conv2d(channel_in, channel_out["1x1"], kernel_size=1),
            nn.BatchNorm2d(channel_out["1x1"]),
            nn.PReLU()
        )
    
        # 3x3 branch, we padding 1 in the 3x3 convolution layer to keep same size of image
        self.conv_3x3 = nn.Sequential(
            nn.Conv2d(channel_in, pass_on["3x3"], kernel_size=1),
            nn.BatchNorm2d(pass_on["3x3"]),
            nn.PReLU(),
            nn.Conv2d(pass_on["3x3"], channel_out["3x3"], kernel_size=3, padding=1),
            nn.BatchNorm2d(channel_out["3x3"]),
            nn.PReLU()
        )
        
        # 5x5 branch, we padding 2 in the 5x5 convolution layer to keep same size of image
        self.conv_5x5 = nn.Sequential(
            nn.Conv2d(channel_in, pass_on["5x5"], kernel_size=1),
            nn.BatchNorm2d(pass_on["5x5"]),
            nn.PReLU(),
            nn.Conv2d(pass_on["5x5"], channel_out["5x5"], kernel_size=5, padding=2),
            nn.BatchNorm2d(channel_out["5x5"]),
            nn.PReLU()
        ) 
        # Max pooling branch
        self.max_pool = nn.Sequential(
            nn.MaxPool2d(kernel_size=3, padding=1, stride=1),
            nn.Conv2d(channel_in, channel_out["max"], kernel_size=1),
            nn.BatchNorm2d(channel_out["max"]),
            nn.PReLU()
        )

        self._initialize_weights()

    def forward(self, x):
        return torch.cat(
            [
                self.conv_1x1(x), self.conv_3x3(x),
                self.conv_5x5(x), self.max_pool(x)
            ], dim=1 # concatenate along channels
        )

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='leaky_relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.1)
                nn.init.constant_(m.bias, 0)

In [97]:
class LeNET(nn.Module):
    def __init__(self, n_classes, epochs, device):
        super().__init__()

        self.seq = nn.Sequential(
            # input layer
            nn.Conv2d(1, 64, kernel_size=7, padding=3),
            nn.PReLU(),
            nn.LocalResponseNorm(128),
            nn.Conv2d(64, 112, kernel_size=1),
            nn.PReLU(),
            nn.Conv2d(112, 196, kernel_size=3, padding=1),
            nn.PReLU(),
            nn.LocalResponseNorm(128),
            # pass through blocks
            Block(
                196, 
                pass_on={"3x3": 96, "5x5": 16}, 
                channel_out={"1x1": 64, "3x3": 128, "5x5": 32, "max": 32},
                device=device
            ),
            Block(
                256, 
                pass_on={"3x3": 128, "5x5": 32}, 
                channel_out={"1x1": 128, "3x3": 192, "5x5": 96, "max": 64},
                device=device
            ),
            # reduce dimensions
            nn.MaxPool2d(3, stride=2, padding=1),  
            # pass through blocks
            Block(
                480, 
                pass_on={"3x3": 96, "5x5": 16}, 
                channel_out={"1x1": 192, "3x3": 208, "5x5": 48, "max": 64}, 
                device=device
            ),
            Block(
                512, 
                pass_on={"3x3": 112, "5x5": 24}, 
                channel_out={"1x1": 176, "3x3": 224, "5x5": 64, "max": 64}, 
                device=device
            ),
            # reduce dimensions
            nn.MaxPool2d(3, stride=2, padding=1),
            # pass through last blocks
            Block(
                528, 
                pass_on={"3x3": 160, "5x5": 32}, 
                channel_out={"1x1": 256, "3x3": 320, "5x5": 128, "max": 128},
                device=device
            ),
            # pool
            nn.AdaptiveAvgPool2d((1, 1)),
            # classification head
            nn.Dropout(0.4),
            nn.Flatten(),
            nn.Linear(832, n_classes),
            nn.Softmax(1)
        )

        self.optimizer = torch.optim.Adam(self.parameters(), lr=0.001)
        self.criterion = nn.CrossEntropyLoss()

        self.epochs = epochs

        self.device = device

        self.augment = transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(p=0.5),
                transforms.RandomCrop(32, padding=4),
            ]
        )


    def forward(self, X):
        return self.seq(X)

    def train(self, X_train, y_train, X_test, y_test):
        batch_size = 50
        epoch_losses = []
        validation_losses = []

        state_dic = None
        min_val_loss = 9e15
        break_count = 0
        max_break_count = 10

        for epoch in range(self.epochs):
            train_loss = 0
            indices = torch.randperm(X_train.shape[0])
            X_train = X_train[indices]
            y_train = y_train[indices]
            
            for i in range(0, len(X_train), batch_size):
                x = self.augment(X_train[i:i+batch_size]).to(self.device)
                y = y_train[i:i+batch_size,0,:].to(self.device)
                
                self.optimizer.zero_grad()

                pred = self.forward(x)

                loss = self.criterion(pred, y) 
                train_loss += loss.item()

                loss.backward()
                self.optimizer.step()

            epoch_losses.append(train_loss)

            acc = 0
            valid_loss = 0
            with torch.no_grad():
                for i in range(0, len(X_test), batch_size):
                    x = X_test[i:i+batch_size].to(self.device)
                    y = y_test[i:i+batch_size,0,:].to(self.device)
                    
                    preds = self.forward(x)

                    valid_loss += self.criterion(preds, y).item()

                    acc += torch.sum(
                        torch.argmax(preds, dim=1) == torch.argmax(y, dim=1)
                    )
                
            validation_losses.append(valid_loss)

            if (valid_loss < min_val_loss):
                min_val_loss = valid_loss
                break_count = 0
                state_dict = copy.deepcopy(self.state_dict())
            
            print(f"[{epoch + 1}]")
            print(f"   Training loss:       {train_loss}")
            print(f"   Validation Acc:      {acc / X_test.shape[0]}")
            print(f"   Validation Loss:     {valid_loss}")
            print(f"   Min Validation Loss: {min_val_loss}")

            if (valid_loss > min_val_loss):
                break_count += 1
                if (break_count >= max_break_count):
                    self.load_state_dict(state_dict)
                    print(f"Validation loss not improved in {break_count} epochs.")
                    print(f"Ending Training Early.")
                    break
        
        return epoch_losses, validation_losses

    def predict(self, x):
        pass

In [99]:
model = LeNET(7, 10, DEVICE).to(DEVICE)

e_loss, v_loss = model.train(train_data, train_labels, test_data, test_labels)

[1]
   Training loss:       1079.1714961528778
   Validation Acc:      0.20325997471809387
   Validation Loss:     277.0945565700531
   Min Validation Loss: 277.0945565700531
[2]
   Training loss:       1043.564003109932
   Validation Acc:      0.22373922169208527
   Validation Loss:     276.0327534675598
   Min Validation Loss: 276.0327534675598
[3]
   Training loss:       1018.0593881607056
   Validation Acc:      0.21955977380275726
   Validation Loss:     276.81475353240967
   Min Validation Loss: 276.0327534675598
[4]
   Training loss:       1000.4273828268051
   Validation Acc:      0.22875453531742096
   Validation Loss:     275.6122786998749
   Min Validation Loss: 275.6122786998749
[5]
   Training loss:       985.389434337616
   Validation Acc:      0.22722208499908447
   Validation Loss:     275.90408170223236
   Min Validation Loss: 275.6122786998749
[6]
   Training loss:       975.9644130468369
   Validation Acc:      0.23223739862442017
   Validation Loss:     275.54132699