In [None]:
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import random_split, DataLoader

import kagglehub
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA
from sklearn.decomposition import PCA

In [None]:
# Fetch the FER-2013 dataset via kagglehub and locate its split folders.
path = kagglehub.dataset_download("msambare/fer2013")
print("Path to dataset files:", path)

# The download provides separate train/ and test/ folders, each read with ImageFolder below.
train = f"{path}/train"
test = f"{path}/test"

In [None]:
# class data 
# Inspect the class balance of the training split. ImageFolder treats every
# sub-folder of `train` as one class; no transform is needed just to count files.
dataset = ImageFolder(root=train)

class_counts = {}
for class_idx in dataset.targets:
    name = dataset.classes[class_idx]
    class_counts[name] = 1 + class_counts.get(name, 0)

for name, n_images in class_counts.items():
    print(f"{name}: {n_images}")

In [None]:
# "balanced" weights scale each class inversely to its frequency in the training
# targets, so the loss does not simply favour the majority classes.
class_indices = np.unique(dataset.targets)
class_weights = compute_class_weight(
    class_weight="balanced",
    classes=class_indices,
    y=dataset.targets,
)
# float32 tensor, ready to hand to nn.CrossEntropyLoss(weight=...).
class_weights_tensor = torch.from_numpy(np.asarray(class_weights, dtype=np.float32))

In [None]:
# Use a CUDA device when PyTorch can see one, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Training-time preprocessing: resize to the 48x48 input the CNN is built for,
# apply light random augmentation, then convert to a float tensor.
# NOTE(review): ImageFolder's default loader converts images to RGB; if the source
# images are grayscale, the saturation/hue jitter has no visible effect — confirm.
_train_augmentations = [
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(5),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
]
transform = transforms.Compose(
    [transforms.Resize((48, 48)), *_train_augmentations, transforms.ToTensor()]
)

In [None]:
class CNN(nn.Module):
    """Small three-stage convolutional classifier for 3-channel 48x48 images.

    Each stage is a 3x3 conv (padding 1) -> ReLU -> 2x2 max-pool, so the spatial
    size goes 48 -> 24 -> 12 -> 6 while channels go 3 -> 32 -> 64 -> 128. The
    resulting 128x6x6 map is flattened into a 256-unit hidden layer with dropout,
    followed by a linear layer producing raw class logits.

    Args:
        num_classes: number of output logits.
    """

    def __init__(self, num_classes):
        super().__init__()
        layers = []
        in_channels = 3
        for out_channels in (32, 64, 128):
            layers += [
                nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ]
            in_channels = out_channels
        # Classifier head; the flatten size assumes 48x48 inputs (three halvings -> 6x6).
        layers += [
            nn.Flatten(),
            nn.Linear(128 * 6 * 6, 256),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(256, num_classes),
        ]
        self.network = nn.Sequential(*layers)

    def forward(self, x):
        """Return the class logits for the input batch ``x``."""
        return self.network(x)

In [None]:
# Seed torch's global RNG so weight init and loader shuffling are repeatable
# (GPU kernels may still introduce small nondeterminism).
SEED = 42
torch.manual_seed(SEED)

train_dataset = datasets.ImageFolder(train, transform=transform)

# Evaluation must be deterministic: only resize + tensor conversion, none of the
# random flip/rotation/colour jitter that `transform` applies for training.
eval_transform = transforms.Compose([
    transforms.Resize((48, 48)),
    transforms.ToTensor(),
])
test_dataset = datasets.ImageFolder(test, transform=eval_transform)

# ImageFolder derives label indices from the sorted folder names, so both splits
# must expose the same classes for label i to mean the same emotion in each.
assert test_dataset.classes == train_dataset.classes, "train/test class folders differ"

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

num_classes = len(train_dataset.classes)
# CrossEntropyLoss requires exactly one weight per class.
assert len(class_weights_tensor) == num_classes, "class weights do not match num_classes"

model = CNN(num_classes=num_classes).to(device)

criterion = nn.CrossEntropyLoss(weight=class_weights_tensor.to(device))
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Halve the LR when the monitored loss fails to improve for more than 3 consecutive epochs.
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=3)

In [None]:
num_epochs = 30


def _run_epoch(model, loader, criterion, optimizer=None):
    """Make one pass over `loader`, training when `optimizer` is given.

    Args:
        model: network to run; put in train mode when training, eval mode otherwise.
        loader: iterable of ``(inputs, labels)`` batches supporting ``len()``.
        criterion: loss function applied to ``(logits, labels)``.
        optimizer: if not ``None``, back-propagate and step on every batch; if
            ``None``, run with autograd disabled (evaluation).

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch losses and the fraction
        of correctly classified samples.
    """
    training = optimizer is not None
    model.train(training)
    device = next(model.parameters()).device
    total_loss, correct, total = 0.0, 0, 0

    with torch.set_grad_enabled(training):
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, labels)

            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            total_loss += loss.item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)

    # Average per batch for both splits so train and val losses are on the same scale.
    return total_loss / len(loader), correct / total


for epoch in range(num_epochs):
    train_loss, train_acc = _run_epoch(model, train_loader, criterion, optimizer)

    # NOTE(review): the test split doubles as the validation set here (it drives the
    # LR scheduler), so final test metrics are not a fully held-out estimate.
    # Consider carving a validation split from the training data (e.g. random_split).
    avg_val_loss, val_acc = _run_epoch(model, test_loader, criterion)

    scheduler.step(avg_val_loss)

    print(f"Epoch {epoch+1}/{num_epochs} | Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.4f} | LR: {optimizer.param_groups[0]['lr']:.6f}")


In [None]:
# Persist the trained weights (state_dict only). torch.save does not create
# missing parent directories, so make sure `models/` exists first.
MODEL_PATH = Path("models") / "model.pth"
MODEL_PATH.parent.mkdir(parents=True, exist_ok=True)
torch.save(model.state_dict(), MODEL_PATH)
print("✅ Training complete. Model saved.")

In [None]:
def evaluate_model(model, test_loader, criterion=None):
    """Evaluate `model` on `test_loader`, print and return average loss and accuracy.

    Args:
        model: classifier returning per-class logits.
        test_loader: iterable of ``(inputs, labels)`` batches (e.g. a DataLoader).
        criterion: loss function; defaults to an unweighted ``nn.CrossEntropyLoss()``.
            The training loop in this notebook uses a class-weighted loss, so pass
            that criterion if you want numbers comparable to the training log.

    Returns:
        ``(avg_loss, accuracy)`` where ``avg_loss`` is the mean of the per-batch losses.

    Raises:
        ValueError: if `test_loader` yields no samples.
    """
    if criterion is None:
        criterion = nn.CrossEntropyLoss()
    # Follow the model's own device instead of relying on a notebook global.
    device = next(model.parameters()).device
    model.eval()
    correct, total, total_loss, n_batches = 0, 0, 0.0, 0

    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            total_loss += criterion(outputs, labels).item()
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += labels.size(0)
            n_batches += 1

    if total == 0:
        raise ValueError("test_loader yielded no samples")

    avg_loss = total_loss / n_batches
    acc = correct / total
    print(f"Test Loss: {avg_loss:.4f}, Test Accuracy: {acc:.4f}")
    return avg_loss, acc

# Score the trained model on the test split.
evaluate_model(model=model, test_loader=test_loader)

In [None]:
def plot_confusion_matrix(model, test_loader, class_names):
    """Plot a confusion matrix of `model`'s predictions over `test_loader`.

    Rows are true classes and columns predicted classes, both ordered like
    `class_names` (label index ``i`` corresponds to ``class_names[i]``).

    Args:
        model: classifier returning per-class logits.
        test_loader: iterable of ``(inputs, labels)`` batches.
        class_names: display names indexed by class index.
    """
    device = next(model.parameters()).device
    all_preds, all_labels = [], []

    model.eval()
    with torch.no_grad():
        for inputs, labels in test_loader:
            outputs = model(inputs.to(device))
            all_preds.extend(outputs.argmax(dim=1).cpu().tolist())
            all_labels.extend(labels.tolist())

    # Pin the label set so the matrix is always len(class_names) square and the
    # tick labels line up, even if some class never occurs in labels or predictions.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))

    fig, ax = plt.subplots(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap="Blues", ax=ax)
    ax.set_xlabel('Predicted')
    ax.set_ylabel('True')
    ax.set_title('Confusion Matrix')
    plt.show()


plot_confusion_matrix(model, test_loader, train_dataset.classes)


In [None]:
def visualize_predictions(model, test_loader, class_names, num_images=8, mean=None, std=None):
    """Plot the first `num_images` images from `test_loader` with true and predicted labels.

    Args:
        model: classifier returning per-class logits.
        test_loader: iterable of ``(inputs, labels)`` batches of CHW image tensors.
        class_names: display names indexed by class index.
        num_images: number of images to show, laid out on two rows.
        mean, std: optional per-channel statistics used to undo a
            ``transforms.Normalize`` before display. Leave as ``None`` when the
            tensors were not normalized (the transforms in this notebook are not).
    """
    device = next(model.parameters()).device
    # Round up so an odd number of images still fits the two-row grid.
    n_cols = max(1, (num_images + 1) // 2)
    model.eval()
    images_shown = 0
    plt.figure(figsize=(15, 8))

    with torch.no_grad():
        for inputs, labels in test_loader:
            predicted = model(inputs.to(device)).argmax(dim=1).cpu()

            for image, true_idx, pred_idx in zip(inputs, labels, predicted):
                if images_shown >= num_images:
                    break
                img = image.cpu().permute(1, 2, 0).numpy()  # CHW -> HWC for imshow
                if mean is not None and std is not None:
                    img = np.asarray(std) * img + np.asarray(mean)  # undo Normalize
                img = np.clip(img, 0, 1)

                plt.subplot(2, n_cols, images_shown + 1)
                plt.imshow(img)
                plt.title(f"True: {class_names[int(true_idx)]}\nPred: {class_names[int(pred_idx)]}")
                plt.axis("off")
                images_shown += 1

            if images_shown >= num_images:
                break
    plt.tight_layout()
    plt.show()


visualize_predictions(model, test_loader, train_dataset.classes)


In [None]:
def _flatten_features(model, loader, max_samples, seed):
    """Return ``(features, labels)`` numpy arrays for a seeded random subset of ``loader.dataset``.

    Draws up to `max_samples` indices without replacement, then runs each image
    through ``model.network`` up to and including its first ``nn.Flatten`` layer.
    Assumes ``loader.dataset`` is map-style (supports ``len`` and indexing).
    """
    dataset = loader.dataset
    n = min(max_samples, len(dataset))
    if n < 2:
        raise ValueError("need at least 2 samples to build a 2-D embedding")

    # A random subset matters: an unshuffled ImageFolder loader yields samples
    # grouped by class, so its first `max_samples` items can cover a single class.
    gen = torch.Generator().manual_seed(seed)
    indices = torch.randperm(len(dataset), generator=gen)[:n].tolist()
    subset_loader = torch.utils.data.DataLoader(
        torch.utils.data.Subset(dataset, indices),
        batch_size=loader.batch_size or 32,
        shuffle=False,
        collate_fn=loader.collate_fn,
    )

    device = next(model.parameters()).device
    feature_chunks, label_chunks = [], []
    model.eval()
    with torch.no_grad():
        for inputs, lbls in subset_loader:
            x = inputs.to(device)
            for layer in model.network:
                x = layer(x)
                if isinstance(layer, nn.Flatten):
                    break
            feature_chunks.append(x.reshape(x.size(0), -1).cpu())
            label_chunks.append(torch.as_tensor(lbls))
    return torch.cat(feature_chunks).numpy(), torch.cat(label_chunks).numpy()


def visualize_feature_space(model, loader, method='tsne', max_samples=500, class_names=None, seed=0):
    """Scatter a 2-D t-SNE or PCA embedding of the model's flattened conv features.

    Args:
        model: network whose ``network`` attribute is an ``nn.Sequential`` containing
            an ``nn.Flatten`` layer.
        loader: DataLoader over a map-style dataset; a random subset of its dataset is used.
        method: ``'tsne'`` for t-SNE; any other value uses PCA.
        max_samples: maximum number of points to embed.
        class_names: legend names indexed by label; defaults to ``loader.dataset.classes``
            when present, otherwise the numeric label is shown.
        seed: seed for the subset draw and for t-SNE.

    Raises:
        ValueError: if fewer than 2 samples are available.
    """
    features, labels = _flatten_features(model, loader, max_samples, seed)

    if method == 'tsne':
        # sklearn requires perplexity < n_samples.
        reducer = TSNE(n_components=2, perplexity=min(30, len(features) - 1), random_state=seed)
    else:
        reducer = PCA(n_components=2)
    reduced = reducer.fit_transform(features)

    if class_names is None:
        class_names = getattr(loader.dataset, "classes", None)

    fig, ax = plt.subplots(figsize=(10, 6))
    cmap = plt.get_cmap('tab10')
    # One scatter per class present, so each legend entry is tied to the right colour.
    for cls in np.unique(labels):
        cls = int(cls)
        mask = labels == cls
        name = class_names[cls] if class_names is not None else str(cls)
        ax.scatter(reduced[mask, 0], reduced[mask, 1], color=cmap(cls % 10), label=name, alpha=0.7)
    ax.legend()
    ax.set_xlabel('component 1')
    ax.set_ylabel('component 2')
    ax.set_title(f'{method.upper()} Visualization of Feature Space')
    plt.show()


visualize_feature_space(model, test_loader, method='tsne', class_names=train_dataset.classes)