In [3]:
# ==============================================================================
#  SWIN + FANE TEST SCRIPT (Final Version)
#  Evaluates trained model on test set and reports final accuracy
# ==============================================================================

import torch
import torch.nn as nn
from torchvision import datasets, transforms
import timm
from torch.utils.data import DataLoader
from PIL import Image
from tqdm import tqdm
import os
import warnings

# Suppress tqdm notebook warning (optional)
warnings.filterwarnings("ignore", category=UserWarning)

# ==============================================================================
# 1. CONFIGURATION
# ==============================================================================

# Dataset and checkpoint locations. Each one can be overridden through an
# environment variable so the notebook runs on another machine without
# editing the code; when the variable is unset the original local path is used.
TEST_DATA_DIR = os.environ.get(
    "FER_TEST_DATA_DIR",
    "/Users/sanskarparab/CC Emotion Detection /Facial-Expression-Recognition-FER-for-Mental-Health-Detection-/traintestsplit/test",
)
MODEL_PATH = os.environ.get(
    "FER_MODEL_PATH",
    "/Users/sanskarparab/CC Emotion Detection /Facial-Expression-Recognition-FER-for-Mental-Health-Detection-/Models/Swin_FANE_Best_Model.pth",
)

NUM_CLASSES = 7   # passed to the model as the width of its final layer
BATCH_SIZE = 32   # number of images per DataLoader batch
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(f"✅ Using device: {DEVICE}")

# ==============================================================================
# 2. MODEL DEFINITION (Ensure exact match with training)
# ==============================================================================

class CustomSwinTransformer(nn.Module):
    """Swin-Base backbone from timm followed by a two-layer classifier head.

    The attribute names (``backbone``, ``classifier``) and the order of the
    layers inside ``classifier`` determine the state-dict keys, so they must
    stay as they are for a checkpoint saved with this definition to load.

    Args:
        pretrained: Forwarded unchanged to ``timm.create_model``.
        num_classes: Output width of the final linear layer.
    """

    def __init__(self, pretrained=True, num_classes=7):
        super().__init__()

        # The backbone is created with num_classes=0; its output is fed to
        # the custom head below, sized from ``backbone.num_features``.
        self.backbone = timm.create_model(
            'swin_base_patch4_window7_224',
            pretrained=pretrained,
            num_classes=0,
        )

        hidden_dim = 512
        head_layers = [
            nn.Linear(self.backbone.num_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=0.6),
            nn.Linear(hidden_dim, num_classes),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Run ``x`` through the backbone and return the classifier output."""
        features = self.backbone(x)
        logits = self.classifier(features)
        return logits

# ==============================================================================
# 3. DATA TRANSFORMS AND LOADER
# ==============================================================================

def rgb_converter(img):
    """Return ``img`` converted to "RGB" mode through its ``convert`` method.

    Used as the first transform step so single-channel inputs come out with
    three channels.
    """
    rgb_img = img.convert("RGB")
    return rgb_img

# Evaluation-time preprocessing, applied in order: force RGB, resize to
# 224x224, convert to a tensor, then normalise each channel with the
# mean/std values listed below.
_eval_steps = [
    rgb_converter,
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_eval_steps)

# Images are read from TEST_DATA_DIR and served in a fixed (unshuffled) order.
test_dataset = datasets.ImageFolder(TEST_DATA_DIR, transform=transform)
test_loader = DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

print(f"📁 Found {len(test_dataset)} test images.")
print(f"🔤 Class Mapping: {test_dataset.class_to_idx}")

# ==============================================================================
# 4. LOAD MODEL AND WEIGHTS
# ==============================================================================

# Build the architecture with pretrained=False; the trained parameters are
# read from the checkpoint at MODEL_PATH instead.
model = CustomSwinTransformer(pretrained=False, num_classes=NUM_CLASSES).to(DEVICE)

try:
    # weights_only=True limits unpickling to tensors and plain containers,
    # which is all a saved state dict needs, and avoids executing arbitrary
    # pickled code from the checkpoint file.
    state_dict = torch.load(MODEL_PATH, map_location=DEVICE, weights_only=True)
    model.load_state_dict(state_dict)
    print(f"✅ Model weights loaded from: {MODEL_PATH}")
except FileNotFoundError:
    print(f"❌ ERROR: Model file not found at {MODEL_PATH}")
    # Re-raise rather than call exit(): in a notebook exit() asks the kernel
    # to shut down, whereas raising stops this cell and keeps the traceback.
    raise
except RuntimeError as e:
    print(f"❌ ERROR loading weights: {e}")
    print("Ensure the CustomSwinTransformer definition matches your training script.")
    raise

# Evaluation mode: disables the Dropout layer in the classifier head.
model.eval()

# ==============================================================================
# 5. MODEL EVALUATION
# ==============================================================================

# Running tallies: correctly classified images and images seen so far.
correct, total = 0, 0

print("\n🧪 Starting evaluation on Test Set...\n")

# Gradients are not needed for evaluation, so the whole pass runs under no_grad.
with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing", unit="batch"):
        images = images.to(DEVICE)
        labels = labels.to(DEVICE)
        outputs = model(images)
        # Predicted class = index of the largest output along dim 1.
        preds = outputs.argmax(dim=1)
        correct += int((preds == labels).sum().item())
        total += labels.shape[0]

# Percentage of test images whose prediction equals the label.
accuracy = 100 * correct / total
print("\n=======================================================")
print(f"📊 FINAL TEST SET ACCURACY: {accuracy:.2f}%")
print("=======================================================\n")


✅ Using device: cpu
📁 Found 4616 test images.
🔤 Class Mapping: {'Angry': 0, 'Disgust': 1, 'Fear': 2, 'Happy': 3, 'Neutral': 4, 'Sad': 5, 'Surprise': 6}
✅ Model weights loaded from: /Users/sanskarparab/CC Emotion Detection /Facial-Expression-Recognition-FER-for-Mental-Health-Detection-/Models/Swin_FANE_Best_Model.pth

🧪 Starting evaluation on Test Set...



Testing: 100%|███████████████████████████████████████████████████████████████████████████| 145/145 [02:04<00:00,  1.16batch/s]


📊 FINAL TEST SET ACCURACY: 90.77%






In [4]:
from sklearn.metrics import classification_report, confusion_matrix
import numpy as np

# Run inference over the test set again, this time keeping every true label
# and every prediction so that per-class metrics can be computed.
y_true = []
y_pred = []

model.eval()
with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(DEVICE), labels.to(DEVICE)
        outputs = model(images)
        _, preds = torch.max(outputs, 1)
        y_true.extend(labels.cpu().numpy())
        y_pred.extend(preds.cpu().numpy())

# Order the class names by their integer index explicitly instead of relying
# on dict insertion order, and pass the matching label ids so the report
# still lines up when a class never appears in y_true or y_pred.
class_to_idx = test_dataset.class_to_idx
class_names = sorted(class_to_idx, key=class_to_idx.get)
class_ids = [class_to_idx[name] for name in class_names]

# Generate the report
print("\n📄 Classification Report:")
print(classification_report(
    y_true,
    y_pred,
    labels=class_ids,
    target_names=class_names,
    digits=4
))



📄 Classification Report:
              precision    recall  f1-score   support

       Angry     0.9302    0.9335    0.9318       571
     Disgust     0.9171    0.8624    0.8889       436
        Fear     0.8879    0.9233    0.9052       926
       Happy     0.9697    0.9791    0.9744       621
     Neutral     0.8932    0.8947    0.8940       589
         Sad     0.8690    0.8869    0.8778       875
    Surprise     0.9196    0.8612    0.8895       598

    accuracy                         0.9077      4616
   macro avg     0.9124    0.9059    0.9088      4616
weighted avg     0.9081    0.9077    0.9076      4616

