In [None]:
import torch
import torchvision
from torch import nn
import torch.nn.functional as F
from torch.utils.data import TensorDataset, DataLoader
import pandas as pd
import numpy as np
from tqdm import tqdm

In [None]:
# Mount Google Drive so files under 'My Drive' are readable from this Colab runtime.
from google.colab import drive

GDRIVE_MOUNT_POINT = '/content/gdrive'
drive.mount(GDRIVE_MOUNT_POINT)

Mounted at /content/gdrive


In [None]:
# Seed torch's global RNG (all devices) so weight init, dropout and the DataLoader's
# shuffle order are repeatable across runs (cuDNN kernels may still be nondeterministic).
SEED = 42
torch.manual_seed(SEED)

# Prefer the GPU but fall back to CPU: torch.device('cuda') on its own constructs fine on a
# CPU-only runtime, and then every later .to(device) call fails.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
DATA_PATH = '/content/gdrive/My Drive/fer2013.csv'
IMG_SIZE = 48       # each row's pixel string is reshaped to IMG_SIZE x IMG_SIZE, 1 channel
NUM_CLASSES = 7


def parse_pixels(pixel_str):
    """Parse one space-separated pixel string into a float32 vector.

    Raises ValueError on a non-numeric token, unlike np.fromstring(..., sep=' '),
    which stops at the first unparseable token (only emitting a warning) and
    returns a silently truncated array.
    """
    return np.array(pixel_str.split(), dtype=np.float32)


def frame_to_tensors(frame):
    """Convert rows with parsed `pixels` and integer `emotion` into model tensors.

    Returns:
        images: float32 tensor of shape (N, 1, IMG_SIZE, IMG_SIZE), pixel values
            divided by 255 (so in [0, 1] assuming raw values are 0-255).
        targets: float32 tensor of shape (N, NUM_CLASSES), one-hot rows of `emotion`.
    """
    pixels = np.stack(frame['pixels'].to_numpy())  # (N, IMG_SIZE * IMG_SIZE)
    expected = IMG_SIZE * IMG_SIZE
    assert pixels.shape[1] == expected, f'expected {expected} pixels per row, got {pixels.shape[1]}'
    # Dividing by a float32 scalar keeps the array float32 (matching the original tensors).
    images = torch.from_numpy(pixels / np.float32(255)).reshape(-1, 1, IMG_SIZE, IMG_SIZE)
    labels = torch.as_tensor(frame['emotion'].to_numpy(), dtype=torch.long)
    targets = F.one_hot(labels, num_classes=NUM_CLASSES).float()
    return images, targets


df = pd.read_csv(DATA_PATH)
df['pixels'] = df['pixels'].apply(parse_pixels)

# 'Training' rows train the model; every other Usage value (presumably PublicTest and
# PrivateTest in the standard FER2013 CSV) is pooled into a single held-out split.
train_df = df[df['Usage'] == 'Training']
test_df = df[df['Usage'] != 'Training']

X_train, y_train = frame_to_tensors(train_df)
X_test, y_test = frame_to_tensors(test_df)

In [None]:
BATCH_SIZE = 32

# Training batches are reshuffled every epoch; evaluation order stays fixed.
train_dataset = TensorDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)

test_dataset = TensorDataset(X_test, y_test)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Draw one training batch to sanity-check image and target shapes.
img, label = next(iter(train_loader))
print(img.shape, label.shape)

torch.Size([32, 1, 48, 48]) torch.Size([32, 7])


In [None]:
# Position i holds the display name for `emotion` label i.
classes = [
    'Angry',     # 0
    'Disgust',   # 1
    'Fear',      # 2
    'Happy',     # 3
    'Sad',       # 4
    'Surprise',  # 5
    'Neutral',   # 6
]

In [None]:
# CNN mapping (batch, 1, 48, 48) inputs to 7 class logits.
# Each MaxPool2d(2) halves the spatial size, 48 -> 24 -> 12 -> 6, which is where the
# 256 * 6 * 6 input width of the first Linear layer comes from.
model = nn.Sequential(
    # Conv block 1: 1 -> 64 channels, 48x48 -> 24x24
    nn.Conv2d(in_channels=1, out_channels=64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),

    # Conv block 2: 64 -> 128 channels, 24x24 -> 12x12
    nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),

    # Conv block 3: 128 -> 256 channels, 12x12 -> 6x6
    nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, padding=1),
    nn.BatchNorm2d(256),
    nn.ReLU(),
    nn.MaxPool2d(kernel_size=2, stride=2),

    nn.Flatten(),

    # Classifier head
    nn.Linear(256 * 6 * 6, 1024),
    nn.BatchNorm1d(1024),
    nn.ReLU(),
    nn.Dropout(0.5),

    nn.Linear(1024, 512),
    nn.BatchNorm1d(512),
    nn.ReLU(),
    nn.Dropout(0.5),

    # Deliberately no Softmax: nn.CrossEntropyLoss applies log-softmax itself and expects
    # raw logits; a softmax here would be applied twice and flatten the gradients.
    # argmax predictions are unchanged; use torch.softmax(model(x), dim=1) at inference
    # when probabilities are needed.
    nn.Linear(512, 7),
)

model.to(device)

Sequential(
  (0): Conv2d(1, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (2): ReLU()
  (3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (4): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (5): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (6): ReLU()
  (7): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (8): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (9): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (10): ReLU()
  (11): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (12): Flatten(start_dim=1, end_dim=-1)
  (13): Linear(in_features=9216, out_features=1024, bias=True)
  (14): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (15): ReLU()
  (16

In [None]:
optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=1e-4)
loss_fn = nn.CrossEntropyLoss()
n_epochs = 30
best_val_acc = 0.0


def run_epoch(model, loader, loss_fn, optimizer=None):
    """Make one pass over `loader`; train when `optimizer` is given, otherwise evaluate.

    Args:
        model: network to run; switched to train() or eval() mode here.
        loader: yields (images, labels) batches. Labels are class-score rows
            (one-hot in this notebook), so the true class is labels.argmax(1).
        loss_fn: criterion applied to (model output, labels); assumed to use
            reduction='mean' (the nn.CrossEntropyLoss default).
        optimizer: if not None, one optimizer step is taken per batch.

    Returns:
        (mean per-sample loss, accuracy) as Python floats.

    Raises:
        ValueError: if `loader` yields no batches.
    """
    training = optimizer is not None
    model.train(training)  # train(False) == eval(): BatchNorm uses running stats, Dropout is off
    model_device = next(model.parameters()).device

    loss_sum = torch.zeros((), device=model_device)
    n_correct = torch.zeros((), dtype=torch.long, device=model_device)
    n_seen = 0
    with torch.set_grad_enabled(training):
        for images, labels in tqdm(loader):
            images, labels = images.to(model_device), labels.to(model_device)

            logits = model(images)
            loss = loss_fn(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            # Accumulate on-device; .item() once per epoch avoids a host sync per batch.
            batch_size = labels.size(0)
            loss_sum += loss.detach() * batch_size
            n_correct += (logits.argmax(1) == labels.argmax(1)).sum()
            n_seen += batch_size

    if n_seen == 0:
        raise ValueError('loader yielded no batches')
    return loss_sum.item() / n_seen, n_correct.item() / n_seen


for epoch in range(n_epochs):
    train_loss, train_acc = run_epoch(model, train_loader, loss_fn, optimizer)
    test_loss, test_acc = run_epoch(model, test_loader, loss_fn)

    print(f'epoch [{epoch + 1}/{n_epochs}], '
          f'train loss: {train_loss:.4f}, train acc: {train_acc * 100:.1f}%, '
          f'test loss: {test_loss:.4f}, test acc: {test_acc * 100:.1f}%')

    # NOTE(review): the checkpoint is chosen by accuracy on the same split whose accuracy is
    # reported, so the saved model's test score is optimistically biased. Select on a separate
    # validation split and report on untouched held-out data instead.
    if test_acc >= best_val_acc:
        best_val_acc = test_acc
        # Keeps the original whole-module pickle format so existing loaders still work.
        # NOTE: torch.load defaults to weights_only=True from PyTorch 2.6, so reloading this file
        # needs weights_only=False; saving model.state_dict() would avoid that.
        torch.save(model, 'emotion_model.pth')
        print('best model saved')

100%|██████████| 898/898 [00:09<00:00, 95.56it/s] 
100%|██████████| 225/225 [00:00<00:00, 405.84it/s]


epoch [1/30], train acc: 39.0%, test acc: 44.9%
best model saved


100%|██████████| 898/898 [00:08<00:00, 111.71it/s]
100%|██████████| 225/225 [00:00<00:00, 414.80it/s]


epoch [2/30], train acc: 47.5%, test acc: 49.1%
best model saved


100%|██████████| 898/898 [00:07<00:00, 113.29it/s]
100%|██████████| 225/225 [00:00<00:00, 406.30it/s]


epoch [3/30], train acc: 50.7%, test acc: 51.3%
best model saved


100%|██████████| 898/898 [00:07<00:00, 113.64it/s]
100%|██████████| 225/225 [00:00<00:00, 389.65it/s]


epoch [4/30], train acc: 53.5%, test acc: 50.9%


100%|██████████| 898/898 [00:07<00:00, 114.49it/s]
100%|██████████| 225/225 [00:00<00:00, 407.48it/s]


epoch [5/30], train acc: 55.5%, test acc: 53.0%
best model saved


100%|██████████| 898/898 [00:08<00:00, 106.22it/s]
100%|██████████| 225/225 [00:00<00:00, 405.98it/s]


epoch [6/30], train acc: 57.2%, test acc: 53.7%
best model saved


100%|██████████| 898/898 [00:07<00:00, 113.23it/s]
100%|██████████| 225/225 [00:00<00:00, 385.33it/s]


epoch [7/30], train acc: 59.4%, test acc: 53.8%
best model saved


100%|██████████| 898/898 [00:07<00:00, 112.36it/s]
100%|██████████| 225/225 [00:00<00:00, 407.89it/s]


epoch [8/30], train acc: 60.4%, test acc: 54.4%
best model saved


100%|██████████| 898/898 [00:08<00:00, 111.50it/s]
100%|██████████| 225/225 [00:00<00:00, 405.76it/s]


epoch [9/30], train acc: 61.8%, test acc: 55.9%
best model saved


100%|██████████| 898/898 [00:08<00:00, 106.70it/s]
100%|██████████| 225/225 [00:00<00:00, 380.27it/s]


epoch [10/30], train acc: 63.8%, test acc: 54.5%


100%|██████████| 898/898 [00:08<00:00, 110.81it/s]
100%|██████████| 225/225 [00:00<00:00, 400.89it/s]


epoch [11/30], train acc: 65.0%, test acc: 55.7%


100%|██████████| 898/898 [00:08<00:00, 110.00it/s]
100%|██████████| 225/225 [00:00<00:00, 396.47it/s]


epoch [12/30], train acc: 66.1%, test acc: 56.1%
best model saved


100%|██████████| 898/898 [00:08<00:00, 109.30it/s]
100%|██████████| 225/225 [00:00<00:00, 226.81it/s]


epoch [13/30], train acc: 67.3%, test acc: 56.6%
best model saved


100%|██████████| 898/898 [00:08<00:00, 108.66it/s]
100%|██████████| 225/225 [00:00<00:00, 393.67it/s]


epoch [14/30], train acc: 68.2%, test acc: 56.8%
best model saved


100%|██████████| 898/898 [00:08<00:00, 107.95it/s]
100%|██████████| 225/225 [00:00<00:00, 394.02it/s]


epoch [15/30], train acc: 69.7%, test acc: 56.4%


100%|██████████| 898/898 [00:08<00:00, 110.17it/s]
100%|██████████| 225/225 [00:00<00:00, 378.41it/s]


epoch [16/30], train acc: 70.8%, test acc: 58.5%
best model saved


100%|██████████| 898/898 [00:08<00:00, 107.26it/s]
100%|██████████| 225/225 [00:00<00:00, 400.05it/s]


epoch [17/30], train acc: 72.1%, test acc: 57.8%


100%|██████████| 898/898 [00:08<00:00, 109.69it/s]
100%|██████████| 225/225 [00:00<00:00, 396.75it/s]


epoch [18/30], train acc: 72.6%, test acc: 56.9%


100%|██████████| 898/898 [00:08<00:00, 111.26it/s]
100%|██████████| 225/225 [00:00<00:00, 399.63it/s]


epoch [19/30], train acc: 74.0%, test acc: 58.2%


100%|██████████| 898/898 [00:08<00:00, 109.60it/s]
100%|██████████| 225/225 [00:00<00:00, 398.00it/s]


epoch [20/30], train acc: 74.8%, test acc: 58.5%


100%|██████████| 898/898 [00:08<00:00, 107.23it/s]
100%|██████████| 225/225 [00:00<00:00, 381.25it/s]


epoch [21/30], train acc: 76.0%, test acc: 56.8%


100%|██████████| 898/898 [00:08<00:00, 106.86it/s]
100%|██████████| 225/225 [00:00<00:00, 397.85it/s]


epoch [22/30], train acc: 76.2%, test acc: 58.6%
best model saved


100%|██████████| 898/898 [00:08<00:00, 109.35it/s]
100%|██████████| 225/225 [00:00<00:00, 396.32it/s]


epoch [23/30], train acc: 77.3%, test acc: 58.3%


100%|██████████| 898/898 [00:09<00:00, 92.49it/s]
100%|██████████| 225/225 [00:00<00:00, 267.85it/s]


epoch [24/30], train acc: 78.2%, test acc: 58.5%


100%|██████████| 898/898 [00:08<00:00, 110.38it/s]
100%|██████████| 225/225 [00:00<00:00, 398.41it/s]


epoch [25/30], train acc: 79.0%, test acc: 58.5%


100%|██████████| 898/898 [00:08<00:00, 109.55it/s]
100%|██████████| 225/225 [00:00<00:00, 396.11it/s]


epoch [26/30], train acc: 79.7%, test acc: 58.6%
best model saved


100%|██████████| 898/898 [00:08<00:00, 109.65it/s]
100%|██████████| 225/225 [00:00<00:00, 382.14it/s]


epoch [27/30], train acc: 80.3%, test acc: 57.8%


100%|██████████| 898/898 [00:08<00:00, 105.71it/s]
100%|██████████| 225/225 [00:00<00:00, 397.74it/s]


epoch [28/30], train acc: 80.7%, test acc: 58.4%


100%|██████████| 898/898 [00:08<00:00, 109.34it/s]
100%|██████████| 225/225 [00:00<00:00, 395.81it/s]


epoch [29/30], train acc: 81.4%, test acc: 59.1%
best model saved


100%|██████████| 898/898 [00:08<00:00, 109.98it/s]
100%|██████████| 225/225 [00:00<00:00, 371.91it/s]

epoch [30/30], train acc: 82.3%, test acc: 58.2%



