In [None]:
# Import
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from torch.amp import GradScaler
from torchvision import transforms
from PIL import Image
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from timm import create_model
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm
import pandas as pd


In [None]:
# Load dataset and processing data
# Human-readable emotion name for every mapped class index (0-6)
emotion_labels = dict(enumerate(
    ["Happiness", "Surprise", "Sadness", "Anger", "Disgust", "Fear", "Neutral"]
))

# Dataset class for RAF-DB custom dataset
class RAFDataset(Dataset):
    """Image/label dataset built from a RAF-DB style folder tree and a CSV split file.

    The split file is read as ``image_name,label`` rows after one header line.
    Images are looked up under ``root_dir/<label>/<image_name>``. When that
    path does not exist, the ``_aligned`` variant of the file name is tried,
    and finally every class subfolder ``1`` .. ``7`` is searched. Rows whose
    image cannot be located, whose label is not an integer in ``EMOTIONS``, or
    which do not have exactly two comma-separated fields are reported with
    ``print`` and skipped.

    Attributes:
        root_dir: Directory holding the per-class subfolders.
        transform: Optional callable applied to each PIL image in ``__getitem__``.
        images: Resolved image paths, one per kept row.
        labels: Mapped labels (0-6), aligned index-for-index with ``images``.
    """

    # Mapping from the label values in the split file to contiguous indices 0-6
    EMOTIONS = {1: 1, 2: 5, 3: 4, 4: 0, 5: 2, 6: 3, 7: 6}  # RAF-DB labels to 0-6

    # Class subfolders scanned when an image is not where its label says it should be
    _SUBFOLDERS = tuple(str(i) for i in range(1, 8))

    def __init__(self, root_dir, split_file, transform=None):
        """Index every usable row of ``split_file``.

        Args:
            root_dir: Directory containing the class subfolders.
            split_file: Path to the CSV file (first line is treated as a header).
            transform: Optional callable applied to each loaded image.
        """
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        with open(split_file, 'r') as f:
            lines = f.readlines()[1:]  # Skip header
        for line in lines:
            entry = line.strip()
            if not entry:
                # Blank lines (e.g. a trailing newline) carry no data; skip quietly.
                continue
            parts = entry.split(',')
            if len(parts) != 2:
                print(f"Skipping malformed line: {entry}")
                continue
            img_name, label = (field.strip() for field in parts)
            # Validate the label first so an unusable row never triggers a
            # filesystem scan across every subfolder.
            try:
                label_int = int(label)
                mapped_label = self.EMOTIONS[label_int]
            except (ValueError, KeyError):
                print(f"Skipping invalid label in line: {entry}")
                continue
            img_path = self._resolve_image_path(img_name, str(label_int))
            if img_path is None:
                print(f"Warning: Image not found after searching all subfolders: {img_name}")
                continue
            self.images.append(img_path)
            self.labels.append(mapped_label)

    @staticmethod
    def _alternate_name(img_name):
        """Return the file name with the ``_aligned`` suffix toggled."""
        if '_aligned' in img_name:
            return img_name.replace('_aligned', '')
        return img_name.replace('.jpg', '_aligned.jpg')

    def _resolve_image_path(self, img_name, expected_subfolder):
        """Return the first existing path for ``img_name``, or ``None`` if there is none."""
        candidates = [img_name]
        alt_name = self._alternate_name(img_name)
        if alt_name != img_name:
            candidates.append(alt_name)

        # 1) The subfolder named after the row's label.
        for name in candidates:
            path = os.path.join(self.root_dir, expected_subfolder, name)
            if os.path.exists(path):
                return path

        # 2) Every other class subfolder; the label from the split file is
        #    still the one that gets used, so report the mismatch.
        for subfolder in self._SUBFOLDERS:
            if subfolder == expected_subfolder:
                continue
            for name in candidates:
                path = os.path.join(self.root_dir, subfolder, name)
                if os.path.exists(path):
                    print(f"Found {name} in {self.root_dir}/{subfolder}/ (expected {self.root_dir}/{expected_subfolder}/)")
                    return path
        return None

    def __len__(self):
        """Number of indexed samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, mapped_label)``."""
        # The context manager releases the file handle as soon as the RGB copy exists.
        with Image.open(self.images[idx]) as raw:
            img = raw.convert('RGB')
        label = self.labels[idx]
        if self.transform is not None:
            img = self.transform(img)
        return img, label

# Data Loaders and split into train and test set data

# Preprocessing pipelines. Nothing earlier in the notebook defines
# `train_transform` / `test_transform`, so on a fresh kernel the dataset
# construction below failed with NameError.
# NOTE(review): the input size, flip augmentation and ImageNet statistics are
# assumed defaults for a 224x224 ImageNet-pretrained backbone — confirm they
# match the settings used in the original experiments.
IMAGE_SIZE = 224
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])
test_transform = transforms.Compose([
    transforms.Resize((IMAGE_SIZE, IMAGE_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# NOTE(review): these paths have no dataset folder between /kaggle/input and
# DATASET, while the prediction example at the end of the notebook reads from
# /kaggle/input/raf-db-dataset/DATASET/... — confirm which layout is correct.
train_dataset = RAFDataset(
    root_dir='/kaggle/input/DATASET/train',
    split_file='/kaggle/input/train_labels.csv',
    transform=train_transform
)
test_dataset = RAFDataset(
    root_dir='/kaggle/input/DATASET/test',
    split_file='/kaggle/input/test_labels.csv',
    transform=test_transform
)

# Print dataset sizes
print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")

# Fail fast: this guard must run before any loader is built or iterated,
# otherwise an empty dataset surfaces as an unrelated loader error.
if len(train_dataset) == 0 or len(test_dataset) == 0:
    raise ValueError("One or both datasets are empty. Check the warnings above for missing images.")

# Check class distribution in the training set.
# Read the labels already indexed by the dataset instead of iterating the
# dataset itself, which would open and transform every image just to count.
train_labels = list(train_dataset.labels)
train_class_counts = pd.Series(train_labels).value_counts().sort_index()
print("\nTraining set class distribution (Mapped labels 0-6 to emotion names):")
for mapped_label, count in train_class_counts.items():
    emotion_name = emotion_labels[mapped_label]
    print(f"Class {mapped_label} ({emotion_name}): {count} images")

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=2)
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, num_workers=2)

# Print shapes of one batch from each loader
train_image, train_label = next(iter(train_loader))
print(f"\nTrain batch: Image shape {train_image.shape}, Label shape {train_label.shape}")

test_image, test_label = next(iter(test_loader))
print(f"Test batch: Image shape {test_image.shape}, Label shape {test_label.shape}")

In [None]:
# Define models
# MLCA Module (Lightweight)
class MLCA(nn.Module):
    """Multi-head cross-attention between two token sequences.

    Both inputs are first projected to ``embed_dim``. Queries are computed
    from ``x1``; keys and values are computed from ``x2``. The two sequences
    may have different lengths: the output always has the length of ``x1``.

    Args:
        x1_dim: Size of the last dimension of ``x1``.
        x2_dim: Size of the last dimension of ``x2``.
        embed_dim: Width of the shared attention space and of the output.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, x1_dim, x2_dim, embed_dim=512, num_heads=4):
        super().__init__()
        if embed_dim % num_heads != 0:
            # Without this check the failure only appears later, as an opaque
            # reshape error inside forward().
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        self.x1_proj = nn.Linear(x1_dim, embed_dim)
        self.x2_proj = nn.Linear(x2_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.scale = self.head_dim ** -0.5

    def _split_heads(self, tokens):
        """Reshape (B, N, embed_dim) into (B * num_heads, N, head_dim)."""
        batch, length, _ = tokens.shape
        tokens = tokens.reshape(batch, length, self.num_heads, self.head_dim)
        tokens = tokens.permute(0, 2, 1, 3)
        return tokens.reshape(batch * self.num_heads, length, self.head_dim)

    def forward(self, x1, x2):
        """Attend from ``x1`` (queries) to ``x2`` (keys/values).

        Args:
            x1: Tensor of shape (B, N1, x1_dim).
            x2: Tensor of shape (B, N2, x2_dim).

        Returns:
            Tensor of shape (B, N1, embed_dim).

        Raises:
            ValueError: If the batch sizes of ``x1`` and ``x2`` differ.
        """
        if x1.size(0) != x2.size(0):
            raise ValueError(
                f"Batch size mismatch: x1 has {x1.size(0)}, x2 has {x2.size(0)}"
            )
        x1 = self.x1_proj(x1)  # (B, N1, embed_dim)
        x2 = self.x2_proj(x2)  # (B, N2, embed_dim)
        B, N_q, C = x1.shape

        # Each tensor is split using its OWN sequence length, so the key/value
        # sequence is no longer forced to match the query length.
        q = self._split_heads(self.q_proj(x1))
        k = self._split_heads(self.k_proj(x2))
        v = self._split_heads(self.v_proj(x2))

        attn = torch.bmm(q, k.transpose(1, 2)) * self.scale  # (B*h, N1, N2)
        attn = attn.softmax(dim=-1)
        out = torch.bmm(attn, v)  # (B*h, N1, head_dim)

        # Merge the heads back into the embedding dimension.
        out = out.reshape(B, self.num_heads, N_q, self.head_dim).permute(0, 2, 1, 3).reshape(B, N_q, C)
        out = self.out_proj(out)
        return out

# SwinFace Model (Lightweight)
class SwinFace(nn.Module):
    """Emotion classifier: timm backbone + cross-attention over its last two stages.

    The last two feature maps of the backbone are turned into token sequences
    whose feature dimension is the channel axis, truncated to a common length,
    fused with ``MLCA`` (queries from the second-to-last stage, keys/values
    from the last), mean-pooled over tokens and classified.

    Args:
        backbone_name: timm model name, created with ``features_only=True``.
        embed_dim: Width of the attention space and classifier input.
        num_heads: Number of attention heads in ``MLCA``.
        num_classes: Number of output logits.
        max_tokens: Upper bound on the number of tokens kept per stage.
    """

    def __init__(self, backbone_name='swin_base_patch4_window7_224', embed_dim=512, num_heads=4, num_classes=7, max_tokens=32):
        super().__init__()
        self.backbone = create_model(backbone_name, pretrained=True, features_only=True)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.max_tokens = max_tokens

        # Channel widths of the last two stages, read from the backbone's
        # metadata instead of being hard-coded. The previous values (14 and 7)
        # are spatial sizes, not the C1 = 512 / C2 = 1024 channel widths the
        # accompanying comment described, so attention ran over the wrong axis.
        # NOTE(review): this changes the shapes of mlca.x1_proj / mlca.x2_proj,
        # so checkpoints saved with the old definition will not load.
        stage_channels = self.backbone.feature_info.channels()
        self.c1 = stage_channels[-2]
        self.c2 = stage_channels[-1]
        self.mlca = MLCA(x1_dim=self.c1, x2_dim=self.c2, embed_dim=self.embed_dim, num_heads=self.num_heads)

        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

    @staticmethod
    def _to_tokens(feat, channels):
        """Convert a 4-D feature map into a (B, H*W, C) token sequence.

        Accepts either channels-last (B, H, W, C) or channels-first
        (B, C, H, W) input; the layout is identified by locating the axis
        whose size equals ``channels``.
        """
        if feat.dim() != 4:
            raise ValueError(f"Expected a 4-D feature map, got shape {tuple(feat.shape)}")
        if feat.shape[-1] == channels:  # channels-last: (B, H, W, C)
            return feat.flatten(1, 2)
        if feat.shape[1] == channels:  # channels-first: (B, C, H, W)
            return feat.flatten(2).transpose(1, 2)
        raise ValueError(
            f"Cannot find a channel axis of size {channels} in feature map of shape {tuple(feat.shape)}"
        )

    def forward(self, x):
        """Return class logits of shape (B, num_classes) for an image batch ``x``."""
        features = self.backbone(x)
        f1_flat = self._to_tokens(features[-2], self.c1)  # (B, N1, C1)
        f2_flat = self._to_tokens(features[-1], self.c2)  # (B, N2, C2)

        # Keep the same number of tokens from both stages, capped at max_tokens.
        # NOTE(review): slicing keeps only the first N positions in row-major
        # order; consider pooling if the whole face should contribute.
        N = min(f1_flat.size(1), f2_flat.size(1), self.max_tokens)
        f1_flat = f1_flat[:, :N, :]
        f2_flat = f2_flat[:, :N, :]

        fused = self.mlca(f1_flat, f2_flat)
        pooled = fused.mean(dim=1)
        logits = self.classifier(pooled)
        return logits

In [None]:
# Model Setup
# Select the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SwinFace().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=1e-4)
# Multiply the learning rate by 0.1 every 5 scheduler steps (one step per epoch).
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
# Loss scaling only applies to CUDA mixed precision. Tie it to the selected
# device so the CPU fallback gets an explicit no-op scaler instead of a
# CUDA scaler that has to disable itself with a warning.
scaler = torch.amp.GradScaler('cuda', enabled=(device.type == 'cuda'))

In [None]:
# Checkpoint Management Setup
checkpoint_dir = "/kaggle/working/checkpoints"
os.makedirs(checkpoint_dir, exist_ok=True)
checkpoint_path = os.path.join(checkpoint_dir, "checkpoint.pth")

epochs = 20
# Per-epoch history (one entry per completed epoch); accuracies are fractions in [0, 1]
train_losses = []
val_losses = []
train_accs = []
val_accs = []

# Early stopping variables
patience = 5
epochs_no_improve = 0
best_acc = 0.0
start_epoch = 0
early_stop = False

# Load checkpoint if it exists
if os.path.exists(checkpoint_path):
    checkpoint = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    # A disabled scaler saves an empty state; loading that into an enabled
    # scaler raises, so only restore when there is something to restore.
    if checkpoint.get('scaler_state_dict'):
        scaler.load_state_dict(checkpoint['scaler_state_dict'])
    # Older checkpoints have no scheduler state; they resume with a fresh schedule.
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    start_epoch = checkpoint['epoch'] + 1
    best_acc = checkpoint['best_acc']
    epochs_no_improve = checkpoint.get('epochs_no_improve', 0)
    train_losses = checkpoint.get('train_losses', [])
    val_losses = checkpoint.get('val_losses', [])
    train_accs = checkpoint.get('train_accs', [])
    val_accs = checkpoint.get('val_accs', [])
    print(f"Resumed training from epoch {start_epoch} with best accuracy {best_acc*100:.2f}%")

for epoch in range(start_epoch, epochs):
    # ---- Training pass ----
    model.train()
    epoch_train_losses = []
    train_preds, train_truths = [], []

    for batch_idx, (imgs, labels) in tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs} - Training"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()

        with torch.amp.autocast('cuda'):
            outputs = model(imgs)
            loss = criterion(outputs, labels)

        # Skip the update entirely rather than back-propagating a NaN loss.
        if torch.isnan(loss):
            print(f"NaN loss at epoch {epoch+1}, batch {batch_idx+1}. Skipping.")
            continue

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        epoch_train_losses.append(loss.item())
        train_preds.extend(outputs.argmax(dim=1).detach().cpu().numpy())
        train_truths.extend(labels.detach().cpu().numpy())

    # Store plain Python floats: NumPy scalars inside the checkpoint cannot be
    # restored by torch.load when it runs in weights-only mode.
    avg_train_loss = float(np.mean(epoch_train_losses)) if epoch_train_losses else float('nan')
    train_acc = float(accuracy_score(train_truths, train_preds))
    train_losses.append(avg_train_loss)
    train_accs.append(train_acc)

    # ---- Validation pass ----
    model.eval()
    epoch_val_losses = []
    val_preds, val_truths = [], []

    with torch.no_grad():
        for imgs, labels in tqdm(test_loader, desc=f"Epoch {epoch+1}/{epochs} - Validating"):
            imgs, labels = imgs.to(device), labels.to(device)
            with torch.amp.autocast('cuda'):
                outputs = model(imgs)
                loss = criterion(outputs, labels)
            epoch_val_losses.append(loss.item())

            val_preds.extend(outputs.argmax(dim=1).detach().cpu().numpy())
            val_truths.extend(labels.detach().cpu().numpy())

    avg_val_loss = float(np.mean(epoch_val_losses)) if epoch_val_losses else float('nan')
    val_acc = float(accuracy_score(val_truths, val_preds))
    val_losses.append(avg_val_loss)
    val_accs.append(val_acc)

    # Accuracies are fractions; scale to percent to match the '%' in the message.
    print(f"Epoch [{epoch+1}/{epochs}], Train Loss: {avg_train_loss:.4f}, Train Acc: {train_acc*100:.2f}%, Val Loss: {avg_val_loss:.4f}, Val Acc: {val_acc*100:.2f}%")

    scheduler.step()

    # Update best model and early-stopping counter BEFORE writing the
    # checkpoint, so the checkpoint records this epoch's result and not the
    # previous epoch's.
    if val_acc > best_acc:
        best_acc = val_acc
        epochs_no_improve = 0
        torch.save(model.state_dict(), "/kaggle/working/best_model.pth")
        print(f"Saved best model with Val Acc: {best_acc*100:.2f}%")
    else:
        epochs_no_improve += 1
        print(f"No improvement in Val Acc for {epochs_no_improve} epochs.")

    # Save checkpoint (everything needed to resume exactly where we left off)
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'scaler_state_dict': scaler.state_dict(),
        'best_acc': best_acc,
        'epochs_no_improve': epochs_no_improve,
        'train_losses': train_losses,
        'val_losses': val_losses,
        'train_accs': train_accs,
        'val_accs': val_accs
    }
    torch.save(checkpoint, checkpoint_path)

    if epochs_no_improve >= patience:
        print("Early stopping triggered.")
        early_stop = True
        break

In [None]:
# Save Model
# Persist the weights currently held by the model (state at the end of training)
final_weights = model.state_dict()
torch.save(final_weights, '/kaggle/working/swinface_model.pth')

In [None]:
# Plot Training Curves
# Plot the per-epoch histories collected by the training loop (train_losses,
# val_losses, train_accs, val_accs). The per-batch lists of the last epoch are
# not per-epoch curves, and no `train_accuracies` / `val_accuracies` names are
# defined anywhere in the notebook.
fig, (loss_ax, acc_ax) = plt.subplots(1, 2, figsize=(12, 5))

loss_ax.plot(range(1, len(train_losses) + 1), train_losses, label='Train Loss')
loss_ax.plot(range(1, len(val_losses) + 1), val_losses, label='Val Loss')
loss_ax.set_xlabel('Epoch')
loss_ax.set_ylabel('Loss')
loss_ax.set_title('Loss per epoch')
loss_ax.legend()

# Accuracies are stored as fractions in [0, 1].
acc_ax.plot(range(1, len(train_accs) + 1), train_accs, label='Train Acc')
acc_ax.plot(range(1, len(val_accs) + 1), val_accs, label='Val Acc')
acc_ax.set_xlabel('Epoch')
acc_ax.set_ylabel('Accuracy')
acc_ax.set_title('Accuracy per epoch')
acc_ax.legend()

fig.tight_layout()
plt.show()

In [None]:
# Load best model
# The training loop writes its best weights to best_model.pth; the file name
# previously used here is never created by this notebook. map_location keeps
# the load working when the weights were saved on a different device.
best_model_path = '/kaggle/working/best_model.pth'
model.load_state_dict(torch.load(best_model_path, map_location=device))

# Evaluate best model
model.eval()
val_preds = []
val_truths = []

with torch.no_grad():
    for imgs, labels in tqdm(test_loader, desc="Evaluating best model"):
        imgs, labels = imgs.to(device), labels.to(device)
        with torch.amp.autocast('cuda'):
            outputs = model(imgs)
        val_preds.extend(outputs.argmax(dim=1).detach().cpu().numpy())
        val_truths.extend(labels.detach().cpu().numpy())

# Confusion Matrix
# Build the tick labels from the existing index -> name mapping rather than
# rebinding `emotion_labels` to a different type.
class_ids = list(range(len(emotion_labels)))
class_names = [emotion_labels[i] for i in class_ids]
# Passing `labels` pins the matrix to one row/column per class, so the tick
# labels stay aligned even if a class never occurs in the evaluated data.
cm = confusion_matrix(val_truths, val_preds, labels=class_ids)

plt.figure(figsize=(10, 8))
sns.heatmap(cm, annot=True, fmt='d', xticklabels=class_names, yticklabels=class_names)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')
plt.show()

In [None]:
# Evaluate Model
model.eval()
correct, total = 0, 0
all_preds, all_labels = [], []

# Pick the weights to evaluate: the best checkpoint when present, otherwise
# the final saved model, otherwise whatever the model currently holds.
if not os.path.exists("/kaggle/working/best_model.pth"):
    print("Best model not found. Evaluating with the final model.")
    if not os.path.exists("/kaggle/working/swinface_model.pth"):
        print("No saved model found for evaluation.")
    else:
        model.load_state_dict(torch.load("/kaggle/working/swinface_model.pth", map_location=device))
else:
    model.load_state_dict(torch.load("/kaggle/working/best_model.pth", map_location=device))
    print("Loaded best model for evaluation.")

# Accumulate hit counts and keep every prediction/label pair for later analysis
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Evaluating on Test Set"):
        images = images.to(device)
        labels = labels.to(device)
        with torch.amp.autocast('cuda'):
            outputs = model(images)
            predicted = torch.max(outputs.data, 1).indices
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

# Report accuracy as a percentage; an empty loader yields 0.0 and a warning
if total == 0:
    print("Test set is empty or not loaded correctly. Cannot compute accuracy.")
    test_acc = 0.0
else:
    test_acc = 100 * correct / total
    print(f"Test Accuracy: {test_acc:.2f}%")

In [None]:
# Predict Emotion Function
def predict_emotion(img_path, model, transform, device):
    """Classify a single image file and return the predicted emotion name.

    Args:
        img_path: Path of the image to open.
        model: Callable returning class logits for a batch of one image;
            it is switched to eval mode before inference.
        transform: Callable converting the RGB PIL image into a tensor.
        device: Device the input tensor is moved to.

    Returns:
        The name at the arg-max class index, one of 'Happiness', 'Surprise',
        'Sadness', 'Anger', 'Disgust', 'Fear', 'Neutral'.
    """
    class_names = ('Happiness', 'Surprise', 'Sadness', 'Anger', 'Disgust', 'Fear', 'Neutral')

    # Open, force three channels, then add the batch dimension the model expects
    rgb_image = Image.open(img_path).convert('RGB')
    batch = transform(rgb_image).unsqueeze(0).to(device)

    model.eval()
    with torch.no_grad():
        with torch.amp.autocast('cuda'):
            logits = model(batch)
            class_idx = logits.argmax(dim=1).item()

    return class_names[class_idx]

# Example usage: classify one test image and report the predicted label
img_path = '/kaggle/input/raf-db-dataset/DATASET/test/4/test_0003_aligned.jpg'
predicted_emotion = predict_emotion(
    img_path=img_path,
    model=model,
    transform=test_transform,
    device=device,
)
print(f'Predicted emotion: {predicted_emotion}')