# Imports

In [None]:
import torch, torch.nn as nn, torch.optim as optim, torch.nn.functional as F
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, WeightedRandomSampler, random_split
from torchvision.datasets import ImageFolder
from torch.optim.lr_scheduler import ReduceLROnPlateau
from collections import Counter

import numpy as np, matplotlib.pyplot as plt, seaborn as sns
from sklearn.utils.class_weight import compute_class_weight
from sklearn.metrics import confusion_matrix
from sklearn.manifold import TSNE
from sklearn.decomposition import PCA

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

# Dataset

In [None]:
import os

# Location of the FER-2013 dataset as downloaded by kagglehub into the current
# user's cache. Expanding "~" keeps the notebook portable instead of
# hard-coding one user's home directory.
path = os.path.expanduser("~/.cache/kagglehub/datasets/msambare/fer2013/versions/1")
train = os.path.join(path, "train")  # ImageFolder root: one sub-folder per class
test = os.path.join(path, "test")

# Data Preprocessing

In [None]:
# Training-time preprocessing and augmentation for the 48x48 grayscale faces.
# Only horizontal flips are used as augmentation. A mirrored face is still an
# upright face with the same expression. A vertical flip would produce
# upside-down faces that never occur at inference time (the inference
# transform applies no flips), widening the train/inference gap.
transform = transforms.Compose([
    transforms.Resize((48, 48)),
    transforms.Grayscale(num_output_channels=1),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    # ToTensor yields values in [0, 1]; this maps them to [-1, 1].
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

In [None]:
# Data loaders
# The test split must be evaluated deterministically. It therefore reuses the
# preprocessing steps of `transform` but drops any random flip augmentation;
# otherwise test metrics would be computed on randomly flipped images.
eval_transform = transforms.Compose([
    step for step in transform.transforms
    if not isinstance(step, (transforms.RandomHorizontalFlip, transforms.RandomVerticalFlip))
])

train_dataset = datasets.ImageFolder(train, transform=transform)
test_dataset = datasets.ImageFolder(test, transform=eval_transform)

# Only the training data is shuffled; test order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [None]:
# Class index of every training image, then a tally of images per class index.
labels = [cls_idx for _img_path, cls_idx in train_dataset.samples]
class_counts = Counter()
class_counts.update(labels)

print(class_counts)

In [None]:
# Inverse-frequency class weights: weight[i] = 1 / (number of training images
# in class i). They are passed to the weighted loss so rarer classes count more.
weights = [0] * len(class_counts)
for cls_idx, n_images in sorted(class_counts.items()):
    weights[cls_idx] = 1.0 / n_images

class_weights = torch.tensor(weights, dtype=torch.float32, device=device)
print(class_weights)

# Model Architecture

In [None]:
class CNN(nn.Module):
    """Convolutional classifier for single-channel 48x48 images.

    Four convolutional stages, each ending in BatchNorm, a 2x2 max-pool and
    dropout, reduce a (N, 1, 48, 48) input to (N, 512, 3, 3). Two fully
    connected layers with BatchNorm and dropout then map the flattened
    features to ``num_classes`` outputs.

    Args:
        num_classes: Number of output classes (default 7, the class count this
            notebook was written for).

    ``forward`` returns per-class log-probabilities of shape (N, num_classes).
    """

    def __init__(self, num_classes=7):
        super().__init__()

        # Stage 1: two 3x3 convs, 48x48 -> 24x24 after pooling.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)  # input_shape=(1,48,48)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(64)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.dropout1 = nn.Dropout(0.25)

        # Stage 2: 5x5 conv, 24x24 -> 12x12.
        self.conv3 = nn.Conv2d(64, 128, kernel_size=5, padding=2)  # padding=2 keeps size with kernel=5
        self.bn2 = nn.BatchNorm2d(128)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.dropout2 = nn.Dropout(0.25)

        # Stage 3: 12x12 -> 6x6.
        self.conv4 = nn.Conv2d(128, 512, kernel_size=3, padding=1)
        self.bn3 = nn.BatchNorm2d(512)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.dropout3 = nn.Dropout(0.25)

        # Stage 4: 6x6 -> 3x3.
        self.conv5 = nn.Conv2d(512, 512, kernel_size=3, padding=1)
        self.bn4 = nn.BatchNorm2d(512)
        self.pool4 = nn.MaxPool2d(2, 2)
        self.dropout4 = nn.Dropout(0.25)

        # Feature size after four 2x2 max-pools on a 48x48 input: 48->24->12->6->3.
        # Inputs of any other spatial size are rejected by fc1.
        self.flatten_dim = 512 * 3 * 3

        self.fc1 = nn.Linear(self.flatten_dim, 256)
        self.bn5 = nn.BatchNorm1d(256)
        self.dropout5 = nn.Dropout(0.25)

        self.fc2 = nn.Linear(256, 512)
        self.bn6 = nn.BatchNorm1d(512)
        self.dropout6 = nn.Dropout(0.25)

        self.fc3 = nn.Linear(512, num_classes)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = self.bn1(x)
        x = self.pool1(x)
        x = self.dropout1(x)

        x = F.relu(self.conv3(x))
        x = self.bn2(x)
        x = self.pool2(x)
        x = self.dropout2(x)

        x = F.relu(self.conv4(x))
        x = self.bn3(x)
        x = self.pool3(x)
        x = self.dropout3(x)

        x = F.relu(self.conv5(x))
        x = self.bn4(x)
        x = self.pool4(x)
        x = self.dropout4(x)

        # Flatten each sample separately, keeping the batch dimension.
        # view(-1, flatten_dim) could regroup one sample's features into
        # several rows for a mis-sized input; this fails loudly in fc1 instead.
        x = x.flatten(1)

        x = F.relu(self.fc1(x))
        x = self.bn5(x)
        x = self.dropout5(x)

        x = F.relu(self.fc2(x))
        x = self.bn6(x)
        x = self.dropout6(x)

        x = self.fc3(x)
        # Log-probabilities. nn.CrossEntropyLoss applies log_softmax again
        # internally, which is harmless: log_softmax is idempotent because
        # its outputs already log-sum-exp to 0.
        return F.log_softmax(x, dim=1)

# Model, Loss, Optimizer, and Scheduler Setup

In [None]:
model = CNN().to(device)

# Weighted cross-entropy (using the per-class weights computed above), Adam,
# and a scheduler that multiplies the LR by 0.5 once the monitored loss has
# not improved for 3 epochs (never going below 1e-6).
criterion = nn.CrossEntropyLoss(weight=class_weights)
optimizer = optim.Adam(model.parameters(), lr=0.001)
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=3,
    min_lr=1e-6,
)

# Parameter counts: everything vs. what will receive gradient updates.
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"Total parameters: {total_params}")
print(f"Trainable parameters: {trainable_params}")

# Training Loop

In [None]:
num_epochs = 50

for epoch in range(num_epochs):
    # ---- Training pass ----
    model.train()
    total_loss, correct, total = 0.0, 0, 0

    # Batch targets are named `targets` so the `labels` list built from
    # train_dataset.samples is not overwritten by a tensor.
    for inputs, targets in train_loader:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        _, predicted = torch.max(outputs, 1)
        correct += (predicted == targets).sum().item()
        total += targets.size(0)

    train_acc = correct / total
    # Mean loss per batch, on the same scale as avg_val_loss below.
    # The raw sum grows with the number of batches and is not comparable.
    train_loss = total_loss / len(train_loader)

    # ---- Evaluation pass ----
    # NOTE(review): this evaluates on the test split, and its loss drives the
    # LR scheduler. The reported test metrics are therefore not a fully
    # held-out estimate; consider carving a validation split from the
    # training data (e.g. with random_split).
    model.eval()
    val_loss, val_correct, val_total = 0.0, 0, 0
    with torch.no_grad():
        for inputs, targets in test_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            val_loss += loss.item()

            _, predicted = torch.max(outputs, 1)
            val_correct += (predicted == targets).sum().item()
            val_total += targets.size(0)

    val_acc = val_correct / val_total
    avg_val_loss = val_loss / len(test_loader)

    # Lower the learning rate when the evaluation loss stops improving.
    scheduler.step(avg_val_loss)

    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch {epoch+1:>2}/{num_epochs} | Train Loss: {train_loss:.4f} | LR: {current_lr:.6f} | Train Acc: {train_acc:.4f} | Val Loss: {avg_val_loss:.4f} | Val Acc: {val_acc:.4f}")


In [None]:
import os

# torch.save fails if the target directory is missing. Create "models/" first
# so a finished training run is not lost at the very last step.
os.makedirs("models", exist_ok=True)
torch.save(model.state_dict(), "models/model.pth")
print("✅ Training complete. Model saved.")

# Test Confusion Matrix

In [None]:
def model_device(model):
    """Return the device of ``model``'s first parameter, or CPU if it has none."""
    first_param = next(model.parameters(), None)
    return first_param.device if first_param is not None else torch.device("cpu")


def collect_predictions(model, loader, device):
    """Run ``model`` in eval mode over ``loader`` and gather labels and predictions.

    Args:
        model: Classifier whose output has one score per class along dim 1.
        loader: Iterable of ``(inputs, labels)`` batches.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        ``(true_labels, predicted_labels)`` as 1-D numpy integer arrays.
    """
    preds, targets = [], []
    model.eval()
    with torch.no_grad():
        for inputs, labels in loader:
            outputs = model(inputs.to(device))
            preds.append(outputs.argmax(dim=1).cpu())
            targets.append(labels.cpu())
    if not preds:
        empty = np.array([], dtype=np.int64)
        return empty, empty.copy()
    return torch.cat(targets).numpy(), torch.cat(preds).numpy()


def plot_confusion_matrix(model, test_loader, class_names, device=None):
    """Plot a confusion-matrix heatmap of ``model``'s predictions on ``test_loader``.

    Rows are true classes and columns are predicted classes, both in the order
    of ``class_names`` (class index i <-> ``class_names[i]``).

    Args:
        model: Trained classifier.
        test_loader: Iterable of ``(inputs, labels)`` batches.
        class_names: Display names indexed by class index.
        device: Device for inference; defaults to the device of the model's
            parameters.
    """
    if device is None:
        device = model_device(model)

    all_labels, all_preds = collect_predictions(model, test_loader, device)

    # Fix the label set explicitly. Without it, a class that never appears in
    # either array is dropped, and the matrix no longer lines up with the tick labels.
    cm = confusion_matrix(all_labels, all_preds, labels=list(range(len(class_names))))
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names, cmap="Blues")
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

# Heatmap of test-set predictions, labelled with the ImageFolder class names.
plot_confusion_matrix(model, test_loader, class_names=train_dataset.classes)


In [None]:
import torch
from PIL import Image
from torchvision import transforms

# Deterministic inference preprocessing. It applies the same resize, grayscale
# and normalization as training, but none of the random flip augmentation.
transform = transforms.Compose([
    transforms.Resize((48, 48)),
    transforms.Grayscale(num_output_channels=1),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

# Load the saved weights into a fresh model on the current device.
model = CNN().to(device)
model.load_state_dict(torch.load("models/model.pth", map_location=device))
model.eval()

# Load and preprocess the image. The context manager closes the file handle
# PIL keeps open; the transform runs fully inside the block.
img_path = "extracted_faces/example_face_1.jpg"
with Image.open(img_path) as pil_img:
    img = transform(pil_img).unsqueeze(0).to(device)  # add a batch dimension of 1

# Predict
with torch.no_grad():
    output = model(img)  # per-class log-probabilities, shape (1, num_classes)
    _, predicted = torch.max(output, 1)
    print(f"Predicted class index: {predicted.item()}")
    print(f"Output {output}")
    print(f"Predicted class: {train_dataset.classes[predicted.item()]}")
