In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from src.dataloader import FER2013Dataset
from src.const import device, batch_size
import copy

In [None]:
from model.arch1 import CNN as CNN1
from model.arch2 import CNN as CNN2
from model.arch3 import CNN as CNN3

In [3]:
# FER2013 CSV file, resolved relative to the notebook's current working directory.
fer2013_csv_path = 'fer2013.csv'

In [4]:
# Common tail of both pipelines: ToTensor scales pixel values to [0, 1], then
# Normalize(0.5, 0.5) maps them to [-1, 1]. A single mean/std value means a
# single channel is expected (assumes grayscale images — TODO confirm in FER2013Dataset).
normalize_steps = [
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
]

# Training adds random horizontal flips as light augmentation; validation stays deterministic.
train_transform = transforms.Compose(
    [transforms.Resize((48, 48)), transforms.RandomHorizontalFlip(), *normalize_steps]
)
val_transform = transforms.Compose([transforms.Resize((48, 48)), *normalize_steps])

# The 'Training' split is used for fitting, 'PublicTest' for model selection.
train_dataset = FER2013Dataset(
    csv_file=fer2013_csv_path,
    usage='Training',
    transform=train_transform,
)
val_dataset = FER2013Dataset(
    csv_file=fer2013_csv_path,
    usage='PublicTest',
    transform=val_transform,
)

# Only the training loader shuffles; validation order is fixed.
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

Transforming data: 100%|██████████| 28709/28709 [00:19<00:00, 1442.71it/s]


Dataset 'Training' class distribution:
Class 0: 3995 samples
Class 1: 436 samples
Class 2: 4097 samples
Class 3: 7215 samples
Class 4: 4830 samples
Class 5: 3171 samples
Class 6: 4965 samples
Total: 28709 samples


Transforming data: 100%|██████████| 3589/3589 [00:02<00:00, 1559.94it/s]

Dataset 'PublicTest' class distribution:
Class 0: 467 samples
Class 1: 56 samples
Class 2: 496 samples
Class 3: 895 samples
Class 4: 653 samples
Class 5: 415 samples
Class 6: 607 samples
Total: 3589 samples





In [None]:
# Note: the dataset is imbalanced (e.g. 436 training samples for class 1 vs 7215
# for class 3), so the loss is weighted by inverse class frequency.
distribution = train_dataset.get_distribution()
# NOTE(review): assumes get_distribution() is indexable by class id 0..K-1 — confirm.
class_counts = torch.tensor(
    [distribution[i] for i in range(len(distribution))], dtype=torch.float
)

# Inverse-frequency weights. A class with no samples gets weight 0 rather than
# inf, which would otherwise make every weight NaN after the normalisation below.
weights = torch.where(
    class_counts > 0, 1.0 / class_counts, torch.zeros_like(class_counts)
)
# Rescale so the weights average to 1. nn.CrossEntropyLoss with reduction='mean'
# divides by the summed weights of the targets, so this rescaling does not change
# the mean loss; it only keeps the raw weight values easy to interpret.
weights = weights / weights.sum() * len(class_counts)

def train(model, num_epochs=25, patience=5, *, train_data=None, val_data=None,
          class_weights=None, lr=0.001, checkpoint_path='best_model.pth'):
    """Train ``model`` with class-weighted cross-entropy and early stopping on val accuracy.

    Each epoch makes one Adam pass over the training batches, then measures top-1
    accuracy (in percent) on the validation batches. Whenever validation accuracy
    strictly improves, the weights are snapshotted (and written to
    ``checkpoint_path`` if one is given). Training stops after ``patience``
    consecutive epochs without improvement. The best snapshot is always loaded
    back into ``model`` before returning.

    Args:
        model: ``nn.Module`` trained in place. Batches are moved to the device of
            its first parameter.
        num_epochs: maximum number of epochs.
        patience: consecutive non-improving epochs tolerated before stopping.
        train_data: iterable of ``(inputs, labels)`` batches; defaults to the
            notebook-level ``train_loader``.
        val_data: iterable of ``(inputs, labels)`` batches; defaults to the
            notebook-level ``val_loader``.
        class_weights: per-class weight tensor for ``nn.CrossEntropyLoss``;
            defaults to the notebook-level ``weights``.
        lr: Adam learning rate.
        checkpoint_path: file the best ``state_dict`` is saved to on each
            improvement. Use a distinct path per model so runs do not overwrite
            each other, or ``None`` to skip writing to disk.

    Returns:
        The best validation accuracy (percent) observed; 0.0 if no epoch beat 0%.
    """
    # Fall back to the notebook globals so existing calls keep working unchanged.
    if train_data is None:
        train_data = train_loader
    if val_data is None:
        val_data = val_loader
    if class_weights is None:
        class_weights = weights

    # Put batches and loss weights on the model's device. Previously batches were
    # never moved, which fails when the model is on a GPU but batches are on CPU.
    model_device = next(model.parameters()).device
    criterion = nn.CrossEntropyLoss(weight=class_weights.to(model_device))
    optimizer = optim.Adam(model.parameters(), lr=lr)

    best_model_wts = copy.deepcopy(model.state_dict())
    best_acc = 0.0
    epochs_no_improve = 0

    for epoch in range(num_epochs):
        # ---- training pass ----
        model.train()
        running_loss = 0.0
        num_batches = 0
        for inputs, labels in train_data:
            inputs, labels = inputs.to(model_device), labels.to(model_device)
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()
            num_batches += 1
        # Mean of per-batch losses; max() avoids ZeroDivisionError on an empty loader.
        avg_loss = running_loss / max(num_batches, 1)

        # ---- validation pass: top-1 accuracy in percent ----
        model.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for inputs, labels in val_data:
                inputs, labels = inputs.to(model_device), labels.to(model_device)
                predicted = model(inputs).argmax(dim=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()
        val_acc = 100 * correct / total if total else 0.0

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {avg_loss:.4f}, Val Accuracy: {val_acc:.2f}%")

        # Early stopping: keep the best snapshot (strict improvement only) and stop
        # after `patience` consecutive epochs without one.
        if val_acc > best_acc:
            best_acc = val_acc
            best_model_wts = copy.deepcopy(model.state_dict())
            epochs_no_improve = 0
            if checkpoint_path is not None:
                torch.save(best_model_wts, checkpoint_path)
        else:
            epochs_no_improve += 1
            if epochs_no_improve >= patience:
                print("Early stopping triggered.")
                break

    # Restore the best weights seen (the initial weights if no epoch improved).
    model.load_state_dict(best_model_wts)
    return best_acc


Dataset 'Training' class distribution:
Class 0: 3995 samples
Class 1: 436 samples
Class 2: 4097 samples
Class 3: 7215 samples
Class 4: 4830 samples
Class 5: 3171 samples
Class 6: 4965 samples
Total: 28709 samples


In [16]:
# One instance per candidate architecture, with 7 outputs to match class ids 0-6.
# Built in the order CNN1, CNN2, CNN3, so parameter initialisation draws from the RNG as before.
model1, model2, model3 = [arch(num_classes=7).to(device) for arch in (CNN1, CNN2, CNN3)]

In [17]:
# Train architecture 1; the trailing ';' stops Jupyter from echoing the call's result.
train(model1, num_epochs=25, patience=5);

Epoch [1/25], Loss: 1.8331, Val Accuracy: 30.01%
Epoch [2/25], Loss: 1.6723, Val Accuracy: 32.85%
Epoch [3/25], Loss: 1.5759, Val Accuracy: 42.41%
Epoch [4/25], Loss: 1.4960, Val Accuracy: 44.36%
Epoch [5/25], Loss: 1.4371, Val Accuracy: 46.14%
Epoch [6/25], Loss: 1.3854, Val Accuracy: 45.39%
Epoch [7/25], Loss: 1.3313, Val Accuracy: 45.33%
Epoch [8/25], Loss: 1.2945, Val Accuracy: 47.90%
Epoch [9/25], Loss: 1.2579, Val Accuracy: 47.95%
Epoch [10/25], Loss: 1.2326, Val Accuracy: 48.48%
Epoch [11/25], Loss: 1.1935, Val Accuracy: 48.37%
Epoch [12/25], Loss: 1.1724, Val Accuracy: 49.18%
Epoch [13/25], Loss: 1.1522, Val Accuracy: 48.54%
Epoch [14/25], Loss: 1.1175, Val Accuracy: 49.32%
Epoch [15/25], Loss: 1.1107, Val Accuracy: 49.18%
Epoch [16/25], Loss: 1.1033, Val Accuracy: 49.51%
Epoch [17/25], Loss: 1.0693, Val Accuracy: 50.26%
Epoch [18/25], Loss: 1.0442, Val Accuracy: 50.40%
Epoch [19/25], Loss: 1.0403, Val Accuracy: 50.35%
Epoch [20/25], Loss: 1.0250, Val Accuracy: 50.40%
Epoch [21

In [19]:
# Train architecture 2; the trailing ';' stops Jupyter from echoing the call's result.
train(model2, num_epochs=25, patience=5);

Epoch [1/25], Loss: 1.2482, Val Accuracy: 52.05%
Epoch [2/25], Loss: 1.1477, Val Accuracy: 54.58%
Epoch [3/25], Loss: 1.0366, Val Accuracy: 55.08%
Epoch [4/25], Loss: 0.9520, Val Accuracy: 56.37%
Epoch [5/25], Loss: 0.8589, Val Accuracy: 58.48%
Epoch [6/25], Loss: 0.7577, Val Accuracy: 58.46%
Epoch [7/25], Loss: 0.6752, Val Accuracy: 59.65%
Epoch [8/25], Loss: 0.5835, Val Accuracy: 59.40%
Epoch [9/25], Loss: 0.5178, Val Accuracy: 58.04%
Epoch [10/25], Loss: 0.4439, Val Accuracy: 59.60%
Epoch [11/25], Loss: 0.3856, Val Accuracy: 60.16%
Epoch [12/25], Loss: 0.3212, Val Accuracy: 58.43%
Epoch [13/25], Loss: 0.2795, Val Accuracy: 59.57%
Epoch [14/25], Loss: 0.2789, Val Accuracy: 59.71%
Epoch [15/25], Loss: 0.2408, Val Accuracy: 59.74%
Epoch [16/25], Loss: 0.2143, Val Accuracy: 58.71%
Early stopping triggered.


In [20]:
# Train architecture 3; the trailing ';' stops Jupyter from echoing the call's result.
train(model3, num_epochs=25, patience=5);

Epoch [1/25], Loss: 1.8932, Val Accuracy: 26.41%
Epoch [2/25], Loss: 1.6627, Val Accuracy: 34.91%
Epoch [3/25], Loss: 1.4437, Val Accuracy: 45.70%
Epoch [4/25], Loss: 1.2988, Val Accuracy: 48.29%
Epoch [5/25], Loss: 1.1773, Val Accuracy: 52.13%
Epoch [6/25], Loss: 1.0699, Val Accuracy: 55.45%
Epoch [7/25], Loss: 0.9799, Val Accuracy: 57.12%
Epoch [8/25], Loss: 0.8971, Val Accuracy: 58.46%
Epoch [9/25], Loss: 0.7890, Val Accuracy: 59.71%
Epoch [10/25], Loss: 0.7385, Val Accuracy: 60.35%
Epoch [11/25], Loss: 0.6467, Val Accuracy: 61.33%
Epoch [12/25], Loss: 0.5576, Val Accuracy: 59.77%
Epoch [13/25], Loss: 0.4561, Val Accuracy: 60.24%
Epoch [14/25], Loss: 0.3631, Val Accuracy: 59.85%
Epoch [15/25], Loss: 0.2745, Val Accuracy: 59.68%
Epoch [16/25], Loss: 0.2056, Val Accuracy: 58.60%
Early stopping triggered.
