In [None]:
# Imports
import logging
import os
import warnings

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from PIL import Image
from sklearn.metrics import classification_report, confusion_matrix, ConfusionMatrixDisplay
from torch import nn, optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from tqdm import tqdm
from transformers import ViTModel

# Silence noisy output for the rest of the session; the three settings below
# are independent of one another.
# NOTE(review): these run after the imports above — if one of them already
# loaded TensorFlow, TF_CPP_MIN_LOG_LEVEL may be set too late; confirm.
warnings.filterwarnings(action='ignore')
logging.getLogger(name='tensorflow').setLevel(level=logging.ERROR)
os.environ.update({'TF_CPP_MIN_LOG_LEVEL': '3'})

In [None]:
# Config
# --- Pretrained checkpoint and filesystem locations ---
MODEL_NAME = "google/vit-base-patch16-224"  # identifier handed to ViTModel.from_pretrained
DATA_ROOT = "../emotions_data"  # parent of the "Training" and "PublicTest" folders
MODEL_SAVE_PATH = "saved_model/vit_emotion_head.pth"  # where the best state_dict is written

# Position in this list is the integer label assigned to each class folder.
EMOTION_CLASSES = [
    'Angry', 'Disgust', 'Fear', 'Happy', 'Sad', 'Surprise', 'Neutral',
]

# --- Optimisation hyper-parameters ---
EPOCHS = 20
BATCH_SIZE = 64
LR = 1e-4
PATIENCE = 3  # non-improving epochs tolerated before early stopping

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# Custom Dataset
class FERImageDataset(Dataset):
    """Image dataset laid out as one sub-directory per emotion class.

    Every file in ``root_dir/<class name>/`` whose extension is one of
    ``IMAGE_EXTENSIONS`` (compared case-insensitively) becomes one sample.
    The label of a sample is the index of its class in ``emotion_classes``.

    Args:
        root_dir: Directory that contains one folder per class.
        emotion_classes: Class folder names; list order defines the labels.
        transform: Optional callable applied to the resized RGB PIL image.

    Raises:
        FileNotFoundError: If a class folder does not exist (from ``os.listdir``).
    """

    IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')

    def __init__(self, root_dir, emotion_classes, transform=None):
        self.transform = transform
        self.samples, self.labels = [], []
        for idx, cls in enumerate(emotion_classes):
            class_dir = os.path.join(root_dir, cls)
            # os.listdir returns entries in arbitrary order; sorting keeps the
            # index -> sample mapping identical from run to run.
            for img_name in sorted(os.listdir(class_dir)):
                # Lower-case before matching so ".JPG" / ".PNG" are not skipped.
                if img_name.lower().endswith(self.IMAGE_EXTENSIONS):
                    self.samples.append(os.path.join(class_dir, img_name))
                    self.labels.append(idx)

    def __len__(self):
        """Return the number of image files that were indexed."""
        return len(self.samples)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to RGB and resized to 224x224 before the
        optional ``transform`` is applied.
        """
        img_path = self.samples[idx]
        # The context manager releases the file handle as soon as the pixel
        # data has been copied by convert().
        with Image.open(img_path) as raw_img:
            img = raw_img.convert("RGB").resize((224, 224))
        if self.transform:
            img = self.transform(img)
        return img, self.labels[idx]

In [None]:
# Transforms
# Training transform: random flip/rotation augmentation, then tensor
# conversion and normalisation to roughly [-1, 1].
transform = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# Evaluation transform: same tensor conversion and normalisation, but no
# random augmentation, so validation metrics are deterministic and are
# measured on unmodified images.
eval_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

# Datasets and Loaders
train_dataset = FERImageDataset(os.path.join(DATA_ROOT, "Training"), EMOTION_CLASSES, transform)
val_dataset = FERImageDataset(os.path.join(DATA_ROOT, "PublicTest"), EMOTION_CLASSES, eval_transform)

# Only the training data is shuffled; validation order stays fixed.
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Model
class EmotionClassifier(nn.Module):
    """Backbone plus a two-layer MLP head that scores emotion classes.

    Args:
        vit: Backbone module. It must expose ``config.hidden_size`` and, when
            called with ``pixel_values=...``, return an object that has a
            ``last_hidden_state`` tensor.
        num_classes: Number of output logits.
        dropout: Dropout probability used both before and inside the head.
    """

    def __init__(self, vit, num_classes=7, dropout=0.4):
        super().__init__()
        feature_dim = vit.config.hidden_size
        self.vit = vit
        self.dropout = nn.Dropout(p=dropout)
        head_layers = [
            nn.Linear(feature_dim, 512),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(512, num_classes),
        ]
        self.emotion_head = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return class logits for a batch of pixel tensors ``x``."""
        backbone_out = self.vit(pixel_values=x)
        # Keep only sequence position 0 of the final hidden states
        # (the [CLS] slot in Hugging Face's ViT) as the image representation.
        first_token = backbone_out.last_hidden_state[:, 0]
        return self.emotion_head(self.dropout(first_token))

In [None]:
# Load ViT
vit = ViTModel.from_pretrained(MODEL_NAME)

# Freeze the whole backbone, then re-enable gradients for the last encoder
# blocks only.
for param in vit.parameters():
    param.requires_grad = False
TRAINABLE_LAYERS = (10, 11)
for name, param in vit.named_parameters():
    # The trailing dot prevents a prefix such as "encoder.layer.1" from also
    # matching "encoder.layer.10" if the layer list is ever changed.
    if any(f"encoder.layer.{i}." in name for i in TRAINABLE_LAYERS):
        param.requires_grad = True

# Class balancing
# Derive the class count from the label list instead of hard-coding 7.
num_classes = len(EMOTION_CLASSES)
class_counts = torch.bincount(
    torch.tensor(train_dataset.labels, dtype=torch.long), minlength=num_classes
).tolist()
if sum(class_counts) == 0:
    raise ValueError("train_dataset contains no samples; cannot compute class weights")
counts = torch.tensor(class_counts, dtype=torch.float)
# Inverse-frequency weights. A class with no training samples gets weight 0
# rather than 1/0 = inf, which would turn the normalised weights into NaN.
class_weights = torch.where(counts > 0, 1.0 / counts.clamp(min=1), torch.zeros_like(counts))
class_weights = class_weights / class_weights.sum() * num_classes

# Loss, optimizer, scheduler
model = EmotionClassifier(vit, num_classes=num_classes).to(device)
criterion = nn.CrossEntropyLoss(weight=class_weights.to(device))
# Only parameters left trainable above (plus the new head) are optimised.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = optim.AdamW(trainable_params, lr=LR, weight_decay=0.01)
# mode='max' because the scheduler is stepped with validation accuracy.
# The `verbose` argument is deprecated in recent PyTorch releases, so it is
# not passed; read the learning rate from optimizer.param_groups instead.
scheduler = ReduceLROnPlateau(optimizer, mode='max', patience=2, factor=0.5)

In [None]:
# Training loop
best_val_acc = 0
patience_counter = 0
train_losses, val_accuracies = [], []
# Minimum gain in validation accuracy (percentage points, since accuracy is
# tracked on a 0-100 scale) that counts as an improvement.
min_delta = 0.1

# torch.save does not create missing parent directories, so make sure the
# checkpoint folder exists before the first save.
save_dir = os.path.dirname(MODEL_SAVE_PATH)
if save_dir:
    os.makedirs(save_dir, exist_ok=True)

for epoch in range(EPOCHS):
    # ---- Train for one epoch ----
    model.train()
    total_loss = 0
    correct = 0

    for imgs, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        imgs, labels = imgs.to(device), labels.to(device)
        outputs = model(imgs)
        loss = criterion(outputs, labels)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        correct += (outputs.argmax(1) == labels).sum().item()

    # Loss is averaged over batches, accuracy over samples.
    train_loss = total_loss / len(train_loader)
    acc = 100 * correct / len(train_dataset)
    train_losses.append(train_loss)
    print(f"Train Loss: {train_loss:.4f}, Accuracy: {acc:.2f}%")

    # Validation
    model.eval()
    val_correct = 0
    with torch.no_grad():
        for imgs, labels in val_loader:
            imgs, labels = imgs.to(device), labels.to(device)
            outputs = model(imgs)
            val_correct += (outputs.argmax(1) == labels).sum().item()

    val_acc = 100 * val_correct / len(val_dataset)
    val_accuracies.append(val_acc)
    print(f"Validation Accuracy: {val_acc:.2f}%")

    # The scheduler lowers the learning rate when validation accuracy plateaus.
    scheduler.step(val_acc)
    # Report the learning rate explicitly rather than relying on the
    # scheduler's own (deprecated) verbose printing.
    print(f"Current LR: {optimizer.param_groups[0]['lr']:.2e}")

    # Checkpoint on a significant improvement, otherwise count towards
    # early stopping.
    if val_acc > best_val_acc + min_delta:
        best_val_acc = val_acc
        patience_counter = 0
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        print(f"Saved improved model at Epoch {epoch+1}")
    else:
        patience_counter += 1
        print(f"No significant improvement. Patience: {patience_counter}/{PATIENCE}")
        if patience_counter >= PATIENCE:
            print("Early stopping triggered due to validation accuracy plateau.")
            break

In [None]:
# Plot
# Loss and accuracy (0-100) live on very different scales, so each curve
# gets its own y-axis; on a shared axis the loss curve would look flat.
fig, loss_ax = plt.subplots(figsize=(10, 5))
acc_ax = loss_ax.twinx()

loss_line, = loss_ax.plot(train_losses, color='tab:blue', label='Train Loss')
acc_line, = acc_ax.plot(val_accuracies, color='tab:orange', label='Validation Accuracy')

loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Train Loss')
acc_ax.set_ylabel('Validation Accuracy (%)')
loss_ax.set_title('Training Loss and Validation Accuracy')
# One combined legend, drawn on the top-most axes so no line covers it.
acc_ax.legend(handles=[loss_line, acc_line])
loss_ax.grid()
plt.show()

In [None]:
# Evaluation
# map_location lets the checkpoint load even when it was written on a
# different device (e.g. saved on GPU, evaluated on CPU).
model.load_state_dict(torch.load(MODEL_SAVE_PATH, map_location=device))
model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for imgs, labels in val_loader:
        imgs = imgs.to(device)
        outputs = model(imgs)
        preds = outputs.argmax(1).cpu().numpy()
        all_preds.extend(preds)
        all_labels.extend(labels.numpy())

print("Classification Report:")
# Passing `labels` keeps the report aligned with EMOTION_CLASSES even if a
# class is missing from both the ground truth and the predictions; without
# it, target_names would not match the labels sklearn infers.
print(classification_report(
    all_labels,
    all_preds,
    labels=list(range(len(EMOTION_CLASSES))),
    target_names=EMOTION_CLASSES,
))

In [None]:
# Confusion Matrix
# Fix the label set explicitly so the matrix always has one row/column per
# entry of EMOTION_CLASSES, matching display_labels even if a class never
# occurs in the labels or the predictions.
class_ids = list(range(len(EMOTION_CLASSES)))
# normalize='true' scales each row by the number of true samples of that class.
cm = confusion_matrix(all_labels, all_preds, labels=class_ids, normalize='true')
disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=EMOTION_CLASSES)
disp.plot(xticks_rotation=45, cmap='Blues')
plt.title("Confusion Matrix")
plt.grid(False)
plt.show()

In [None]:
# Prediction samples
def show_predictions(model, dataset, num=6):
    """Plot the first ``num`` samples of ``dataset`` with true and predicted labels.

    Titles are green when the prediction matches the ground truth and red
    otherwise. At most ``len(dataset)`` samples are shown; nothing is drawn
    when there is nothing to show.

    Args:
        model: Classifier returning logits for a batch of image tensors.
        dataset: Indexable dataset yielding ``(image_tensor, label)`` pairs.
            Assumes images were normalised with mean 0.5 / std 0.5, which is
            what the de-normalisation below undoes.
        num: Maximum number of samples to display.
    """
    # Never index past the end of the dataset.
    num = min(num, len(dataset))
    if num <= 0:
        return

    model.eval()
    # squeeze=False always yields a 2-D array of axes, so num == 1 works too.
    fig, axes = plt.subplots(1, num, figsize=(15, 3), squeeze=False)
    for i, ax in enumerate(axes[0]):
        img, label = dataset[i]
        with torch.no_grad():
            pred = model(img.unsqueeze(0).to(device)).argmax(1).item()
        # Channels-last layout for imshow, mapped back into [0, 1]; the clip
        # absorbs tiny floating-point overshoot.
        pixels = (img.permute(1, 2, 0).numpy() * 0.5 + 0.5).clip(0, 1)
        ax.imshow(pixels)
        ax.set_title(f"GT: {EMOTION_CLASSES[label]}\nPred: {EMOTION_CLASSES[pred]}",
                     color='green' if label == pred else 'red')
        ax.axis("off")
    plt.tight_layout()
    plt.show()

# Visual spot check: first few validation samples, ground truth vs. prediction.
show_predictions(model=model, dataset=val_dataset)

# Closing message pointing at the checkpoint path used by the training loop.
print(f"Training complete. Best model saved to {MODEL_SAVE_PATH}")