In [None]:
# Install the notebook's third-party dependencies into the active kernel's environment
# (%pip, unlike !pip, targets the environment the kernel is running in).
# NOTE(review): versions are unpinned — pin them (e.g. `torch==<version>`) for reproducible runs.
%pip install torch torchvision matplotlib scikit-learn tqdm

In [2]:
import os
import time
from pathlib import Path
import copy
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models

from sklearn.metrics import confusion_matrix, classification_report
from PIL import Image

# Configuration: every value a reader might want to tune lives in this cell.
DATA_DIR = "Data"         # expects Data/<class_name>/<images> — one sub-folder per emotion class
BATCH_SIZE = 32
NUM_EPOCHS = 12
LR = 1e-4                 # optimizer learning rate
# DataLoader worker processes; never request more workers than the machine has CPUs.
NUM_WORKERS = min(4, os.cpu_count() or 1)
MODEL_PATH = "best_emotion_resnet50.pth"
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
VAL_SPLIT = 0.2           # fraction of all images held out for validation
SEED = 42                 # global seed for reproducible runs

# Seed the global RNGs used by augmentation, shuffling and weight initialisation so
# that Restart & Run All reproduces the same results.
np.random.seed(SEED)
torch.manual_seed(SEED)

In [3]:
# Data transformations
# Per-channel mean / std shared by both pipelines (the widely used ImageNet statistics).
NORM_MEAN = [0.485, 0.456, 0.406]
NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: random resized crop, horizontal flip and colour jitter for
# augmentation, then conversion to a tensor and per-channel normalisation.
train_transforms = transforms.Compose(
    [
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)

# Validation pipeline: deterministic resize + centre crop to the same 224 crop size,
# then the identical tensor conversion and normalisation as training.
val_transforms = transforms.Compose(
    [
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        transforms.Normalize(NORM_MEAN, NORM_STD),
    ]
)

In [None]:
# Dataset split
# `random_split` returns Subsets that all wrap the SAME underlying dataset object, so
# assigning `val_ds.dataset.transform` would also replace the training transform and
# silently disable augmentation. Instead, load two ImageFolder views of the same
# directory (one per transform) and index both with a single seeded permutation, so
# the splits are disjoint, reproducible and each keeps its own transform.
SPLIT_SEED = 42  # fixed so the train/val partition is identical across runs

full_dataset = datasets.ImageFolder(DATA_DIR, transform=val_transforms)   # deterministic view, backs val_ds
train_view = datasets.ImageFolder(DATA_DIR, transform=train_transforms)   # augmented view, backs train_ds
# Both views scan the same directory; their sample lists must line up index-for-index.
assert train_view.samples == full_dataset.samples, "train/val views disagree on the file list"

num_val = int(len(full_dataset) * VAL_SPLIT)
num_train = len(full_dataset) - num_val

split_generator = torch.Generator().manual_seed(SPLIT_SEED)
shuffled_idx = torch.randperm(len(full_dataset), generator=split_generator).tolist()
train_ds = torch.utils.data.Subset(train_view, shuffled_idx[:num_train])
val_ds = torch.utils.data.Subset(full_dataset, shuffled_idx[num_train:])
assert len(train_ds) + len(val_ds) == len(full_dataset)

train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

class_names = full_dataset.classes
num_classes = len(class_names)
print(f"Classes: {class_names}")
print(f"Number of training samples: {len(train_ds)}")
print(f"Number of validation samples: {len(val_ds)}")

Classes: ['Angry', 'Fear', 'Happy', 'Sad', 'Suprise']
Number of training samples: 47280
Number of validation samples: 11819
