In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import autocast, GradScaler
from torchvision import transforms
from PIL import Image
import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns
from timm import create_model
import torch.optim.lr_scheduler as lr_scheduler
from tqdm import tqdm

In [None]:
# Dataset class for RAF-DB
class RAFDataset(Dataset):
    """Image/label dataset driven by a CSV split file.

    The split file is read once at construction time. Its first line is
    treated as a header and skipped; every following line must be
    ``<image_name>,<label>``. Images are looked up on disk as
    ``root_dir/<label>/<image_name>``, with two fallbacks:

    1. the same folder, using the alternate file name (``_aligned`` suffix
       removed if present, otherwise ``.jpg`` replaced by ``_aligned.jpg``);
    2. every other ``root_dir/1`` .. ``root_dir/7`` folder, trying both names.

    Rows that are malformed, carry a label outside ``EMOTIONS`` or whose image
    cannot be located are reported on stdout and skipped.

    Args:
        root_dir: Directory containing the per-label subfolders ``1`` .. ``7``.
        split_file: Path to the CSV file listing ``image_name,label`` rows.
        transform: Optional callable applied to the RGB ``PIL.Image`` in
            ``__getitem__``.

    Attributes:
        images: Resolved image paths, in CSV order.
        labels: Remapped integer labels (values of ``EMOTIONS``), aligned with
            ``images``.
    """

    # RAF-DB labels to 0-6.
    # NOTE(review): the target order presumably matches the class-name order
    # used when plotting the confusion matrix -- confirm.
    EMOTIONS = {1: 1, 2: 5, 3: 4, 4: 0, 5: 2, 6: 3, 7: 6}

    def __init__(self, root_dir, split_file, transform=None):
        self.root_dir = root_dir
        self.transform = transform
        self.images = []
        self.labels = []
        with open(split_file, 'r') as f:
            lines = f.readlines()[1:]  # Skip header
        for line in lines:
            parts = line.strip().split(',')
            if len(parts) != 2:
                print(f"Skipping malformed line: {line.strip()}")
                continue
            img_name, label = parts
            # Validate the label *before* touching the filesystem so that an
            # invalid row is reported as such and costs no disk probes.
            try:
                label_int = int(label)
                mapped_label = self.EMOTIONS[label_int]
            except (ValueError, KeyError):
                print(f"Skipping invalid label in line: {line.strip()}")
                continue
            img_path = self._resolve_image_path(img_name, str(label_int))
            if img_path is None:
                print(f"Warning: Image not found after searching all subfolders: {img_name}")
                continue
            self.images.append(img_path)
            self.labels.append(mapped_label)

    @staticmethod
    def _alternate_name(img_name):
        """Return the other spelling of ``img_name`` (with/without ``_aligned``)."""
        if '_aligned' in img_name:
            return img_name.replace('_aligned', '')
        return img_name.replace('.jpg', '_aligned.jpg')

    def _resolve_image_path(self, img_name, expected_subfolder):
        """Locate ``img_name`` on disk, returning its path or ``None``.

        The expected label folder is tried first (original name, then the
        alternate name). If that fails, the remaining ``1`` .. ``7`` folders
        are scanned in numeric order and a notice is printed on a hit.
        """
        candidates = [img_name]
        alt_name = self._alternate_name(img_name)
        if alt_name != img_name:
            candidates.append(alt_name)

        for name in candidates:
            path = os.path.join(self.root_dir, expected_subfolder, name)
            if os.path.exists(path):
                return path

        for subfolder in map(str, range(1, 8)):
            if subfolder == expected_subfolder:
                continue  # already probed above
            for name in candidates:
                path = os.path.join(self.root_dir, subfolder, name)
                if os.path.exists(path):
                    print(f"Found {name} in {self.root_dir}/{subfolder}/ (expected {self.root_dir}/{expected_subfolder}/)")
                    return path
        return None

    def __len__(self):
        """Return the number of successfully resolved samples."""
        return len(self.images)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, label)``.

        The image is converted to RGB and passed through ``self.transform``
        when one was supplied; the label is the remapped integer.
        """
        # The context manager releases the file handle as soon as the RGB
        # copy has been produced.
        with Image.open(self.images[idx]) as handle:
            img = handle.convert('RGB')
        label = self.labels[idx]
        if self.transform:
            img = self.transform(img)
        return img, label

In [None]:
# MLCA Module (Lightweight)
class MLCA(nn.Module):
    """Multi-head cross-attention between two token sequences.

    ``x1`` supplies the queries and ``x2`` supplies the keys and values. Both
    inputs are first projected to a common ``embed_dim``. The two sequences
    may have different lengths; the output always has the length of ``x1``.

    Args:
        x1_dim: Feature size of the last dimension of ``x1``.
        x2_dim: Feature size of the last dimension of ``x2``.
        embed_dim: Shared embedding size used inside the attention.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        ValueError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, x1_dim, x2_dim, embed_dim=512, num_heads=4):
        super().__init__()
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.head_dim = embed_dim // num_heads
        self.x1_proj = nn.Linear(x1_dim, embed_dim)
        self.x2_proj = nn.Linear(x2_dim, embed_dim)
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)
        self.scale = self.head_dim ** -0.5

    def _split_heads(self, tokens):
        """Reshape ``(B, L, embed_dim)`` to ``(B * num_heads, L, head_dim)``."""
        batch, length, _ = tokens.shape
        tokens = tokens.reshape(batch, length, self.num_heads, self.head_dim)
        tokens = tokens.permute(0, 2, 1, 3)
        return tokens.reshape(batch * self.num_heads, length, self.head_dim)

    def forward(self, x1, x2):
        """Attend from ``x1`` (queries) to ``x2`` (keys/values).

        Args:
            x1: Tensor of shape ``(B, N, x1_dim)``.
            x2: Tensor of shape ``(B, M, x2_dim)``; ``M`` may differ from ``N``.

        Returns:
            Tensor of shape ``(B, N, embed_dim)``.
        """
        x1 = self.x1_proj(x1)  # (B, N, embed_dim)
        x2 = self.x2_proj(x2)  # (B, M, embed_dim)
        B, N, C = x1.shape

        # Each projection is split using its *own* sequence length, so the
        # key/value sequence is not required to match the query length.
        q = self._split_heads(self.q_proj(x1))  # (B*H, N, D)
        k = self._split_heads(self.k_proj(x2))  # (B*H, M, D)
        v = self._split_heads(self.v_proj(x2))  # (B*H, M, D)

        attn = torch.bmm(q, k.transpose(1, 2)) * self.scale  # (B*H, N, M)
        attn = attn.softmax(dim=-1)
        out = torch.bmm(attn, v)  # (B*H, N, D)

        # Merge the heads back into the embedding dimension.
        out = out.reshape(B, self.num_heads, N, self.head_dim).permute(0, 2, 1, 3).reshape(B, N, C)
        out = self.out_proj(out)
        return out



# SwinFace Model (Lightweight)
class SwinFace(nn.Module):
    """Classifier that fuses the last two backbone stages with cross-attention.

    A timm backbone created with ``features_only=True`` yields one feature map
    per stage. The second-to-last map provides the queries and the last map
    the keys/values of an ``MLCA`` block; the fused tokens are mean-pooled and
    classified by a ``LayerNorm`` + ``Linear`` head.

    All sub-modules, including ``mlca``, are built in ``__init__`` so that
    their parameters are visible to ``model.parameters()`` (and hence to an
    optimizer created right after construction), are moved by ``.to(device)``
    and are present when ``load_state_dict`` is called on a fresh instance.

    Args:
        backbone_name: timm model name passed to ``create_model``.
        embed_dim: Embedding size used by the fusion block and the head.
        num_heads: Number of attention heads in the fusion block.
        num_classes: Number of output logits.
        max_tokens: Upper bound on the number of tokens taken from each
            feature map before fusion.
    """

    def __init__(self, backbone_name='swin_base_patch4_window7_224', embed_dim=512, num_heads=4, num_classes=7, max_tokens=32):
        super().__init__()
        self.backbone = create_model(backbone_name, pretrained=True, features_only=True)
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.max_tokens = max_tokens

        self.classifier = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, num_classes)
        )

        # Channel counts of the two stages being fused, as reported by the
        # backbone itself; this lets the fusion block be built eagerly.
        stage_channels = self.backbone.feature_info.channels()
        self.feature_channels = (stage_channels[-2], stage_channels[-1])
        self.mlca = MLCA(
            x1_dim=self.feature_channels[0],
            x2_dim=self.feature_channels[1],
            embed_dim=self.embed_dim,
            num_heads=self.num_heads
        )

    @staticmethod
    def _to_tokens(feature_map, channels):
        """Flatten a 4-D feature map into ``(B, H*W, channels)`` tokens.

        Backbones may return maps as ``(B, C, H, W)`` or ``(B, H, W, C)``
        depending on the architecture / timm version, so the channel axis is
        located by matching it against the expected ``channels`` count.

        Raises:
            ValueError: If the map is not 4-D or no axis matches ``channels``.
        """
        if feature_map.dim() != 4:
            raise ValueError(
                f"Expected a 4-D feature map, got shape {tuple(feature_map.shape)}"
            )
        if feature_map.shape[1] == channels:  # channels-first (B, C, H, W)
            return feature_map.flatten(2).transpose(1, 2)
        if feature_map.shape[-1] == channels:  # channels-last (B, H, W, C)
            return feature_map.flatten(1, 2)
        raise ValueError(
            f"Cannot find a channel axis of size {channels} in feature map "
            f"with shape {tuple(feature_map.shape)}"
        )

    def forward(self, x):
        """Compute class logits for an image batch.

        Args:
            x: Image batch accepted by the backbone.

        Returns:
            Tensor of shape ``(B, num_classes)`` with unnormalized logits.
        """
        features = self.backbone(x)
        f1_flat = self._to_tokens(features[-2], self.feature_channels[0])  # (B, H1*W1, C1)
        f2_flat = self._to_tokens(features[-1], self.feature_channels[1])  # (B, H2*W2, C2)

        # Both sequences are cut to the same length N.
        # NOTE(review): slicing keeps only the first N tokens in row-major
        # order, i.e. the top rows of each feature map; consider pooling
        # instead if the whole face should contribute.
        N = min(f1_flat.size(1), f2_flat.size(1), self.max_tokens)
        f1_flat = f1_flat[:, :N, :]
        f2_flat = f2_flat[:, :N, :]

        fused = self.mlca(f1_flat, f2_flat)
        pooled = fused.mean(dim=1)
        logits = self.classifier(pooled)
        return logits


In [None]:
# Transforms
def _resize_and_normalize():
    """Return fresh Resize -> ToTensor -> Normalize steps shared by both pipelines."""
    return [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]


# Training pipeline: random flip and +/-10 degree rotation, then the shared steps.
train_transform = transforms.Compose(
    [transforms.RandomHorizontalFlip(), transforms.RandomRotation(10)]
    + _resize_and_normalize()
)

# Evaluation pipeline: the shared deterministic steps only.
test_transform = transforms.Compose(_resize_and_normalize())

In [None]:
# Data Loaders
# Each split pairs an image root (with per-label subfolders) with its CSV label file.
train_dataset = RAFDataset(
    '/kaggle/input/DATASET/train',
    '/kaggle/input/train_labels.csv',
    transform=train_transform,
)
test_dataset = RAFDataset(
    '/kaggle/input/DATASET/test',
    '/kaggle/input/test_labels.csv',
    transform=test_transform,
)

print(f"Number of training samples: {len(train_dataset)}")
print(f"Number of test samples: {len(test_dataset)}")

# Fail fast: an empty split means the image lookup did not match the files on disk.
if not len(train_dataset) or not len(test_dataset):
    raise ValueError("One or both datasets are empty. Check the warnings above for missing images.")

# Reduced batch size to 16; only the training split is shuffled.
train_loader, test_loader = (
    DataLoader(split, batch_size=16, shuffle=reshuffle, num_workers=2)
    for split, reshuffle in ((train_dataset, True), (test_dataset, False))
)

In [None]:
from torch.amp import GradScaler

# Model Setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = SwinFace().to(device)
criterion = nn.CrossEntropyLoss()
# The optimizer only tracks parameters that exist on `model` at this point;
# any module created later (e.g. lazily during a forward pass) is not updated.
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5, weight_decay=1e-4)
# Learning rate is multiplied by 0.1 every 5 scheduler steps.
scheduler = lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)
# Tie loss scaling to the selected device: active on CUDA, an explicit no-op
# otherwise, instead of always assuming the default 'cuda' device.
scaler = GradScaler(device.type, enabled=(device.type == 'cuda'))

In [None]:
from sklearn.metrics import accuracy_score

epochs = 20
# Per-epoch history, one entry appended per epoch.
epoch_train_losses = []
epoch_val_losses = []
train_accuracies = []
val_accuracies = []

# Mixed precision is only used on CUDA; elsewhere autocast is an explicit no-op.
use_amp = device.type == 'cuda'

for epoch in range(epochs):
    # ---- training pass ----
    model.train()
    train_losses = []
    train_preds, train_truths = [], []

    for batch_idx, (imgs, labels) in tqdm(enumerate(train_loader), total=len(train_loader), desc=f"Epoch {epoch+1}/{epochs} - Training"):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()

        # Device-agnostic autocast (replaces the deprecated torch.cuda.amp.autocast()).
        with torch.amp.autocast(device_type=device.type, enabled=use_amp):
            outputs = model(imgs)
            loss = criterion(outputs, labels)

        # A NaN loss would poison the weights; drop the batch entirely.
        if torch.isnan(loss):
            print(f"NaN loss at epoch {epoch+1}, batch {batch_idx+1}. Skipping.")
            continue

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        train_losses.append(loss.item())
        train_preds.extend(outputs.argmax(dim=1).detach().cpu().numpy())
        train_truths.extend(labels.detach().cpu().numpy())

    # If every batch was skipped there is nothing to average; record NaN
    # rather than computing statistics on empty lists.
    avg_train_loss = np.mean(train_losses) if train_losses else float('nan')
    train_acc = accuracy_score(train_truths, train_preds) if train_truths else float('nan')
    epoch_train_losses.append(avg_train_loss)
    train_accuracies.append(train_acc)

    # ---- validation pass (no gradients) ----
    model.eval()
    val_losses = []
    val_preds, val_truths = [], []

    with torch.no_grad():
        for imgs, labels in tqdm(test_loader, desc=f"Epoch {epoch+1}/{epochs} - Validating"):
            imgs, labels = imgs.to(device), labels.to(device)
            with torch.amp.autocast(device_type=device.type, enabled=use_amp):
                outputs = model(imgs)
                loss = criterion(outputs, labels)
            val_losses.append(loss.item())

            val_preds.extend(outputs.argmax(dim=1).detach().cpu().numpy())
            val_truths.extend(labels.detach().cpu().numpy())

    avg_val_loss = np.mean(val_losses) if val_losses else float('nan')
    val_acc = accuracy_score(val_truths, val_preds) if val_truths else float('nan')
    epoch_val_losses.append(avg_val_loss)
    val_accuracies.append(val_acc)

    print(f"Epoch {epoch+1}/{epochs}")
    print(f"Train   - Loss: {avg_train_loss:.4f}, Acc: {train_acc:.4f}")
    print(f"Val     - Loss: {avg_val_loss:.4f}, Acc: {val_acc:.4f}")

    # Advance the learning-rate schedule once per epoch.
    scheduler.step()


In [None]:
# Save Model
# Only the weights (state dict) are written, not the full module object.
torch.save(
    obj=model.state_dict(),
    f='/kaggle/working/swinface_model.pth',
)

In [None]:
# Plotting Functions
def plot_training_progress(epoch_train_losses, val_accuracies):
    """Plot training loss and validation accuracy against the epoch number.

    The two series share the x-axis; the loss uses the left y-axis (blue)
    and the accuracy a twinned right y-axis (green).

    Args:
        epoch_train_losses: One training-loss value per epoch.
        val_accuracies: One validation-accuracy value per epoch.
    """
    def _draw_series(axis, values, colour, name):
        # Label, plot and tint one y-axis in a single colour.
        axis.set_ylabel(name, color=colour)
        axis.plot(epoch_numbers, values, color=colour, label=name)
        axis.tick_params(axis='y', labelcolor=colour)

    epoch_numbers = range(1, len(epoch_train_losses) + 1)  # epochs are 1-based
    figure, loss_axis = plt.subplots()
    loss_axis.set_xlabel('Epochs')
    _draw_series(loss_axis, epoch_train_losses, 'tab:blue', 'Training Loss')

    accuracy_axis = loss_axis.twinx()
    _draw_series(accuracy_axis, val_accuracies, 'tab:green', 'Validation Accuracy')

    figure.tight_layout()
    plt.title('Training Loss and Validation Accuracy over Epochs')
    plt.show()

def plot_confusion_matrix(truths, preds):
    """Draw a 7x7 confusion-matrix heatmap of true vs. predicted classes.

    Args:
        truths: Ground-truth class indices (0-6).
        preds: Predicted class indices (0-6), aligned with ``truths``.
    """
    class_names = ['Happy', 'Surprise', 'Sad', 'Anger', 'Disgust', 'Fear', 'Neutral']
    # Pin the label set so the matrix is always 7x7; without it a class that
    # never occurs is dropped and the tick labels no longer line up.
    cm = confusion_matrix(truths, preds, labels=list(range(len(class_names))))
    plt.figure(figsize=(8, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                xticklabels=class_names,
                yticklabels=class_names)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.title('Confusion Matrix')
    plt.show()

In [None]:
# Plot the training-loss / validation-accuracy curves
plot_training_progress(epoch_train_losses, val_accuracies)
# Plot the confusion matrix for the last epoch's validation predictions.
# (`truths` / `preds` were never defined; the training cell produces
# `val_truths` / `val_preds`.)
plot_confusion_matrix(val_truths, val_preds)