In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import timm
from tqdm import tqdm
from PIL import Image
import os

# =====================================================================
# 1. Ensure images are always RGB
# =====================================================================
def rgb_converter(img):
    """Return ``img`` converted to the ``"RGB"`` mode.

    Intended as a transform-pipeline step so that every image handed to the
    later transforms has the same channel layout.

    Args:
        img: Any object exposing ``convert(mode)`` -- presumably a
            ``PIL.Image.Image`` supplied by the dataset loader (confirm
            against the caller).

    Returns:
        The value produced by ``img.convert("RGB")``.
    """
    target_mode = "RGB"
    converted = img.convert(target_mode)
    return converted

# =====================================================================
# 2. Corrected Model Definition
# =====================================================================
class CustomSwinTransformer(nn.Module):
    """A timm backbone (classifier removed) followed by a small MLP head.

    The head is ``Linear -> ReLU -> Dropout -> Linear``. Its layer order is
    fixed so the ``state_dict`` keys (``classifier.0.*`` / ``classifier.3.*``)
    stay compatible with checkpoints saved by earlier versions of this class.

    Args:
        pretrained: Forwarded to ``timm.create_model``; whether to load
            pretrained backbone weights.
        num_classes: Number of output logits produced by the head.
        model_name: timm model identifier for the backbone. Defaults to the
            Swin-Base variant this class was originally hard-wired to.
        hidden_dim: Width of the hidden layer in the head.
        dropout: Dropout probability applied between the two head layers.
    """

    def __init__(self, pretrained=True, num_classes=7,
                 model_name='swin_base_patch4_window7_224',
                 hidden_dim=512, dropout=0.6):
        super().__init__()
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,  # remove original classifier
        )
        # The head input width is read from the backbone, so swapping
        # ``model_name`` does not require touching the head definition.
        self.classifier = nn.Sequential(
            nn.Linear(self.backbone.num_features, hidden_dim),
            nn.ReLU(),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_dim, num_classes),
        )

    def forward(self, x):
        """Run ``x`` through the backbone, then the head, and return the logits.

        Assumes the backbone returns features whose last dimension equals
        ``backbone.num_features`` -- TODO confirm for non-default backbones.
        """
        features = self.backbone(x)
        return self.classifier(features)

# =====================================================================
# 3. Training and Validation Functions
# =====================================================================
def train_one_epoch(model, dataloader, optimizer, criterion, device, desc="Training"):
    """Run one optimisation pass over ``dataloader`` and report loss and accuracy.

    The model is put in training mode; for every batch the gradients are
    reset, the loss is back-propagated and the optimizer takes one step.

    Args:
        model: Callable module mapping a batch of inputs to per-class scores.
        dataloader: Iterable yielding ``(inputs, labels)`` pairs.
        optimizer: Optimizer whose ``zero_grad``/``step`` are called per batch.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device the inputs and labels are moved to.
        desc: Label shown on the progress bar.

    Returns:
        ``(loss, accuracy)``. ``loss`` is a Python float: the sum of
        ``batch loss * batch size`` divided by the number of labels seen
        (a per-sample mean, assuming ``criterion`` returns a batch mean).
        ``accuracy`` is a 0-dim float64 tensor: correct predictions / labels.
    """
    model.train()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    progress = tqdm(dataloader, desc=desc)
    for batch_inputs, batch_labels in progress:
        batch_inputs = batch_inputs.to(device)
        batch_labels = batch_labels.to(device)

        optimizer.zero_grad()
        logits = model(batch_inputs)
        batch_loss = criterion(logits, batch_labels)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so that a short final batch
        # does not skew the epoch-level average.
        loss_sum += batch_loss.item() * batch_inputs.size(0)
        predicted = logits.max(dim=1).indices
        num_correct += (predicted == batch_labels).sum()
        num_seen += batch_labels.size(0)

    mean_loss = loss_sum / num_seen
    accuracy = num_correct.double() / num_seen
    return mean_loss, accuracy

def validate_model(model, dataloader, criterion, device, desc="Validating"):
    """Evaluate ``model`` on ``dataloader`` without updating any weights.

    The model is switched to eval mode and the whole loop runs under
    ``torch.no_grad()``.

    Args:
        model: Callable module mapping a batch of inputs to per-class scores.
        dataloader: Iterable yielding ``(inputs, labels)`` pairs.
        criterion: Loss callable taking ``(outputs, labels)``.
        device: Device the inputs and labels are moved to.
        desc: Label shown on the progress bar.

    Returns:
        ``(loss, accuracy)``. ``loss`` is a Python float: the sum of
        ``batch loss * batch size`` divided by the number of labels seen
        (a per-sample mean, assuming ``criterion`` returns a batch mean).
        ``accuracy`` is a 0-dim float64 tensor: correct predictions / labels.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_seen = 0

    with torch.no_grad():
        progress = tqdm(dataloader, desc=desc)
        for batch_inputs, batch_labels in progress:
            batch_inputs = batch_inputs.to(device)
            batch_labels = batch_labels.to(device)

            logits = model(batch_inputs)
            batch_loss = criterion(logits, batch_labels)

            # Size-weighted accumulation keeps the average exact when the
            # batches are not all the same length.
            loss_sum += batch_loss.item() * batch_inputs.size(0)
            predicted = logits.max(dim=1).indices
            num_correct += (predicted == batch_labels).sum()
            num_seen += batch_labels.size(0)

    mean_loss = loss_sum / num_seen
    accuracy = num_correct.double() / num_seen
    return mean_loss, accuracy

# =====================================================================
# 4. Main (for Jupyter)
# =====================================================================

# Dataset and checkpoint locations.
# Each path can be overridden with an environment variable of the same name,
# so the notebook runs on another machine without editing this cell; the
# defaults are the paths the notebook was originally written against.
FANE_TRAIN_DATA_PATH = os.environ.get(
    "FANE_TRAIN_DATA_PATH",
    "/Users/sanskarparab/CC Emotion Detection /Facial-Expression-Recognition-FER-for-Mental-Health-Detection-/traintestsplit/train",
)
FANE_VAL_DATA_PATH = os.environ.get(
    "FANE_VAL_DATA_PATH",
    "/Users/sanskarparab/CC Emotion Detection /Facial-Expression-Recognition-FER-for-Mental-Health-Detection-/traintestsplit/val",
)
MODEL_SAVE_PATH = os.environ.get(
    "MODEL_SAVE_PATH",
    "/Users/sanskarparab/CC Emotion Detection /Facial-Expression-Recognition-FER-for-Mental-Health-Detection-/Models/Swin_FANE_Best_Model.pth",
)

NUM_CLASSES = 7

# Seed the global torch RNG so weight initialisation and the default
# DataLoader shuffle order are repeatable from a fresh kernel.
SEED = 42
torch.manual_seed(SEED)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"✅ Using device: {DEVICE}")

# Resize to 224x224 and normalise with the per-channel mean/std below
# (presumably the ImageNet statistics the pretrained backbone expects).
transform = transforms.Compose([
    rgb_converter,
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

train_dataset = datasets.ImageFolder(root=FANE_TRAIN_DATA_PATH, transform=transform)
val_dataset = datasets.ImageFolder(root=FANE_VAL_DATA_PATH, transform=transform)

# Fail fast: ImageFolder derives label indices from the folder names, so a
# missing or extra class folder in one split would silently shift the labels
# and make the validation accuracy meaningless.
if train_dataset.class_to_idx != val_dataset.class_to_idx:
    raise ValueError(
        "Train/val class folders differ: "
        f"{train_dataset.class_to_idx} vs {val_dataset.class_to_idx}"
    )
if len(train_dataset.classes) != NUM_CLASSES:
    raise ValueError(
        f"Expected {NUM_CLASSES} classes but found {len(train_dataset.classes)}: "
        f"{train_dataset.classes}"
    )

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

print(f"Found {len(train_dataset)} training and {len(val_dataset)} validation images.")
print("Class Mapping:", train_dataset.class_to_idx)

# Create Model
model = CustomSwinTransformer(pretrained=True, num_classes=NUM_CLASSES).to(DEVICE)
criterion = nn.CrossEntropyLoss()

# --- Phase 1: Train Classifier Head ---
# Named once here so the loop bound and the progress labels cannot drift apart.
HEAD_EPOCHS = 5
HEAD_LR = 5e-4

# Freeze the backbone: only the classifier head is handed to the optimizer
# below, and no gradients are computed for the backbone weights.
model.backbone.requires_grad_(False)

optimizer = optim.AdamW(model.classifier.parameters(), lr=HEAD_LR)
# Best validation accuracy seen so far; read by the checkpointing logic in the
# fine-tuning phase.
best_val_acc = 0.0

print("\n--- PHASE 1: Training Classifier Head ---")
for epoch in range(HEAD_EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE, desc=f"Head Train {epoch+1}/{HEAD_EPOCHS}")
    val_loss, val_acc = validate_model(model, val_loader, criterion, DEVICE, desc=f"Head Val {epoch+1}/{HEAD_EPOCHS}")
    print(f"Epoch {epoch+1}: Train Acc={train_acc:.4f}, Val Acc={val_acc:.4f}")

# --- Phase 2: Fine-tune Whole Model ---
# Epoch numbering continues from the head-training phase so the printed
# "Epoch N" labels run 6..15.
FINE_TUNE_START_EPOCH = 5
TOTAL_EPOCHS = 15
FINE_TUNE_LR = 1e-5

# Unfreeze everything (backbone and head) for end-to-end fine-tuning.
model.requires_grad_(True)

optimizer = optim.AdamW(model.parameters(), lr=FINE_TUNE_LR)

# Make sure the checkpoint directory exists *before* training starts;
# otherwise the first torch.save would fail only after a full epoch has
# already been spent.
checkpoint_dir = os.path.dirname(MODEL_SAVE_PATH)
if checkpoint_dir:
    os.makedirs(checkpoint_dir, exist_ok=True)

print("\n--- PHASE 2: Fine-Tuning Entire Model ---")
for epoch in range(FINE_TUNE_START_EPOCH, TOTAL_EPOCHS):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, DEVICE, desc=f"Full Train {epoch+1}/{TOTAL_EPOCHS}")
    val_loss, val_acc = validate_model(model, val_loader, criterion, DEVICE, desc=f"Full Val {epoch+1}/{TOTAL_EPOCHS}")
    print(f"Epoch {epoch+1}: Train Acc={train_acc:.4f}, Val Acc={val_acc:.4f}")
    if val_acc > best_val_acc:
        # Store a plain float rather than the tensor returned by
        # validate_model, so the running best does not hold on to a
        # (possibly device-resident) tensor.
        best_val_acc = float(val_acc)
        torch.save(model.state_dict(), MODEL_SAVE_PATH)
        print(f"⭐ Saved new best model with Validation Accuracy: {best_val_acc:.4f}")


✅ Using device: cpu
Found 13563 training and 2160 validation images.
Class Mapping: {'Angry': 0, 'Disgust': 1, 'Fear': 2, 'Happy': 3, 'Neutral': 4, 'Sad': 5, 'Surprise': 6}


'(ReadTimeoutError("HTTPSConnectionPool(host='huggingface.co', port=443): Read timed out. (read timeout=10)"), '(Request ID: ba0ea447-ac06-4d07-8612-0a5c4725bb0b)')' thrown while requesting HEAD https://huggingface.co/timm/swin_base_patch4_window7_224.ms_in22k_ft_in1k/resolve/main/model.safetensors
Retrying in 1s [Retry 1/5].



--- PHASE 1: Training Classifier Head ---


Head Train 1/5: 100%|███████████████████████████████████████████████████████████████████████| 424/424 [06:16<00:00,  1.12it/s]
Head Val 1/5: 100%|███████████████████████████████████████████████████████████████████████████| 68/68 [01:00<00:00,  1.12it/s]


Epoch 1: Train Acc=0.4750, Val Acc=0.5662


Head Train 2/5: 100%|███████████████████████████████████████████████████████████████████████| 424/424 [06:22<00:00,  1.11it/s]
Head Val 2/5: 100%|███████████████████████████████████████████████████████████████████████████| 68/68 [01:03<00:00,  1.07it/s]


Epoch 2: Train Acc=0.5430, Val Acc=0.6153


Head Train 3/5: 100%|███████████████████████████████████████████████████████████████████████| 424/424 [06:41<00:00,  1.06it/s]
Head Val 3/5: 100%|███████████████████████████████████████████████████████████████████████████| 68/68 [01:05<00:00,  1.04it/s]


Epoch 3: Train Acc=0.5685, Val Acc=0.6384


Head Train 4/5: 100%|███████████████████████████████████████████████████████████████████████| 424/424 [06:24<00:00,  1.10it/s]
Head Val 4/5: 100%|███████████████████████████████████████████████████████████████████████████| 68/68 [00:58<00:00,  1.15it/s]


Epoch 4: Train Acc=0.5884, Val Acc=0.6306


Head Train 5/5: 100%|███████████████████████████████████████████████████████████████████████| 424/424 [06:28<00:00,  1.09it/s]
Head Val 5/5: 100%|███████████████████████████████████████████████████████████████████████████| 68/68 [01:02<00:00,  1.09it/s]


Epoch 5: Train Acc=0.6063, Val Acc=0.6745

--- PHASE 2: Fine-Tuning Entire Model ---


Full Train 6/15: 100%|██████████████████████████████████████████████████████████████████████| 424/424 [23:03<00:00,  3.26s/it]
Full Val 6/15: 100%|██████████████████████████████████████████████████████████████████████████| 68/68 [01:09<00:00,  1.02s/it]


Epoch 6: Train Acc=0.6575, Val Acc=0.7338
⭐ Saved new best model with Validation Accuracy: 0.7338


Full Train 7/15: 100%|██████████████████████████████████████████████████████████████████████| 424/424 [22:14<00:00,  3.15s/it]
Full Val 7/15: 100%|██████████████████████████████████████████████████████████████████████████| 68/68 [00:57<00:00,  1.17it/s]


Epoch 7: Train Acc=0.7096, Val Acc=0.7833
⭐ Saved new best model with Validation Accuracy: 0.7833


Full Train 8/15: 100%|██████████████████████████████████████████████████████████████████████| 424/424 [20:34<00:00,  2.91s/it]
Full Val 8/15: 100%|██████████████████████████████████████████████████████████████████████████| 68/68 [00:56<00:00,  1.20it/s]


Epoch 8: Train Acc=0.7508, Val Acc=0.8250
⭐ Saved new best model with Validation Accuracy: 0.8250


Full Train 9/15: 100%|██████████████████████████████████████████████████████████████████████| 424/424 [20:48<00:00,  2.94s/it]
Full Val 9/15: 100%|██████████████████████████████████████████████████████████████████████████| 68/68 [00:57<00:00,  1.18it/s]


Epoch 9: Train Acc=0.7832, Val Acc=0.8560
⭐ Saved new best model with Validation Accuracy: 0.8560


Full Train 10/15: 100%|█████████████████████████████████████████████████████████████████████| 424/424 [21:02<00:00,  2.98s/it]
Full Val 10/15: 100%|█████████████████████████████████████████████████████████████████████████| 68/68 [00:57<00:00,  1.18it/s]


Epoch 10: Train Acc=0.8195, Val Acc=0.8704
⭐ Saved new best model with Validation Accuracy: 0.8704


Full Train 11/15: 100%|█████████████████████████████████████████████████████████████████████| 424/424 [21:02<00:00,  2.98s/it]
Full Val 11/15: 100%|█████████████████████████████████████████████████████████████████████████| 68/68 [00:57<00:00,  1.18it/s]


Epoch 11: Train Acc=0.8382, Val Acc=0.8847
⭐ Saved new best model with Validation Accuracy: 0.8847


Full Train 12/15: 100%|█████████████████████████████████████████████████████████████████████| 424/424 [22:55<00:00,  3.24s/it]
Full Val 12/15: 100%|█████████████████████████████████████████████████████████████████████████| 68/68 [01:05<00:00,  1.04it/s]


Epoch 12: Train Acc=0.8567, Val Acc=0.8898
⭐ Saved new best model with Validation Accuracy: 0.8898


Full Train 13/15: 100%|█████████████████████████████████████████████████████████████████████| 424/424 [22:00<00:00,  3.12s/it]
Full Val 13/15: 100%|█████████████████████████████████████████████████████████████████████████| 68/68 [00:57<00:00,  1.17it/s]


Epoch 13: Train Acc=0.8771, Val Acc=0.9060
⭐ Saved new best model with Validation Accuracy: 0.9060


Full Train 14/15: 100%|█████████████████████████████████████████████████████████████████████| 424/424 [21:27<00:00,  3.04s/it]
Full Val 14/15: 100%|█████████████████████████████████████████████████████████████████████████| 68/68 [00:58<00:00,  1.16it/s]


Epoch 14: Train Acc=0.8900, Val Acc=0.8958


Full Train 15/15: 100%|█████████████████████████████████████████████████████████████████████| 424/424 [21:29<00:00,  3.04s/it]
Full Val 15/15: 100%|█████████████████████████████████████████████████████████████████████████| 68/68 [00:57<00:00,  1.18it/s]


Epoch 15: Train Acc=0.8952, Val Acc=0.9106
⭐ Saved new best model with Validation Accuracy: 0.9106


In [2]:
pip install timm


Collecting timm
  Downloading timm-1.0.21-py3-none-any.whl.metadata (62 kB)
Downloading timm-1.0.21-py3-none-any.whl (2.5 MB)
[2K   [38;2;114;156;31m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m22.6 MB/s[0m  [33m0:00:00[0m
[?25hInstalling collected packages: timm
Successfully installed timm-1.0.21
Note: you may need to restart the kernel to use updated packages.
