In [1]:
import copy
import time
from collections import defaultdict

import torch
import torch.nn as nn
from torch.optim import Adam
from torch.utils.data import DataLoader
from torch.utils.data import Subset
from torchvision import datasets, models, transforms

In [None]:
# --- Config ---
data_dir = "../../fer2013"
num_classes = 7  # FER-2013 has 7 emotions
batch_size = 64
num_epochs = 15
learning_rate = 1e-4
seed = 42
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Seed torch's global RNG so random augmentations, DataLoader shuffling, dropout
# and the new classifier head's initialisation repeat across runs
# (some GPU kernels may still be nondeterministic).
torch.manual_seed(seed)

# --- Data transforms ---
# ImageNet channel statistics: the pretrained backbone expects inputs normalised
# this way. Three values per list, so this assumes 3-channel (RGB) input images.
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]

train_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std)
])

# Validation: same resize/normalisation, no random augmentation.
val_transforms = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(imagenet_mean, imagenet_std)
])

In [3]:
# --- Datasets ---
# One ImageFolder per split; each split directory holds one sub-folder per class.
# NOTE(review): the `test` split is also the validation set that `train_model`
# uses to pick the best checkpoint, so its accuracy is an optimistic estimate.
train_dataset, val_dataset = (
    datasets.ImageFolder(root=split_root, transform=split_tfms)
    for split_root, split_tfms in (
        (f"{data_dir}/train", train_transforms),
        (f"{data_dir}/test", val_transforms),
    )
)

# Limit to n samples per class
def limit_per_class(dataset, n=20):
    """Return a Subset containing at most ``n`` samples of each class.

    Indices are visited in the dataset's own order and the first ``n`` seen for
    each label are kept, so the selection is deterministic (not random).

    Args:
        dataset: Map-style dataset whose labels are exposed either as a
            ``targets`` sequence (ImageFolder provides one; some torchvision
            datasets store it as a tensor) or, failing that, as a ``samples``
            list of ``(path, label)`` pairs.
        n: Maximum samples to keep per class; ``n <= 0`` gives an empty subset.

    Returns:
        ``torch.utils.data.Subset`` over the selected indices, in ascending order.
    """
    targets = getattr(dataset, "targets", None)
    if targets is None:
        targets = [label for _, label in dataset.samples]

    selected_idx = []
    class_counts = defaultdict(int)
    for idx, label in enumerate(targets):
        # Unwrap tensor / NumPy scalars: 0-dim tensors hash by identity, so
        # without this every label would get its own counter bucket.
        if hasattr(label, "item"):
            label = label.item()
        if class_counts[label] < n:
            selected_idx.append(idx)
            class_counts[label] += 1
    return Subset(dataset, selected_idx)

# --- Optional: limit dataset size for quick testing ---
# When enabled, both splits shrink to `samples_per_class` images per class, so the
# reported losses/accuracies only serve as a smoke test of the pipeline.
limit_samples = True   # 🔹 set to False to train/evaluate on the full dataset
samples_per_class = 20  # 🔹 how many per class when enabled

if limit_samples:
    train_dataset = limit_per_class(train_dataset, samples_per_class)
    val_dataset = limit_per_class(val_dataset, samples_per_class)
    print(f"⚙️ Using only {samples_per_class} samples per class for quick testing.")
else:
    print("✅ Using full dataset.")

⚙️ Using only 20 samples per class for quick testing.


In [6]:
# Update dataloaders
# Both splits share batch size and worker count; only the training split is shuffled.
train_loader, val_loader = (
    DataLoader(split_ds, batch_size=batch_size, shuffle=do_shuffle, num_workers=4)
    for split_ds, do_shuffle in ((train_dataset, True), (val_dataset, False))
)

### 💡 Why Replace the Classifier Head in EfficientNet-B0
EfficientNet-B0 was pretrained on ImageNet with 1000 object categories, so its final layer outputs 1000 logits. 

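For FER-2013 we only have 7 emotion classes, so we replace the classifier with `Dropout(p=0.4)` followed by `Linear(in_features=1280, out_features=7)`. Note the label order: FER-2013 traditionally lists the emotions as `angry, disgust, fear, happy, sad, surprise, neutral`. `ImageFolder`, however, assigns indices by sorting the class-folder names. Assuming the folders are named after the emotions, output indices 0–6 therefore map to `angry, disgust, fear, happy, neutral, sad, surprise`.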

This way, we keep the pretrained **feature extractor** (which already detects useful visual features like edges and textures) and only retrain the **final output layer** to learn how to classify emotions instead of objects.

**Summary:**
| Part | Kept or Changed | Purpose |
|------|-----------------|----------|
| Convolutional "features" | ✅ Kept | Extracts general visual patterns |
| Classifier "head" | 🔁 Replaced | Matches 7 emotion classes |
| Dropout | 🔧 Tuned (0.2 → 0.4) | Reduces overfitting on smaller datasets |

In [7]:
# --- Model ---
# Load ImageNet-pretrained EfficientNet-B0. The `weights=` enum replaces the
# deprecated `pretrained=True` flag (torchvision >= 0.13) and selects the same
# IMAGENET1K_V1 checkpoint.
model = models.efficientnet_b0(weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1)

# Freeze the entire convolutional backbone (`model.features`), not just its early
# layers, so only the new classifier head below is trained.
# Set to False to fine-tune the whole network instead.
freeze_backbone = True
for param in model.features.parameters():
    param.requires_grad = not freeze_backbone
# NOTE(review): requires_grad=False does not stop BatchNorm layers inside
# `model.features` from updating their running statistics while the model is in
# train() mode; call `model.features.eval()` after `model.train()` if the
# backbone should stay truly fixed.

# --- Replace classifier head ---

# Number of features coming from EfficientNet's backbone (1280 for B0)
in_features = model.classifier[1].in_features
model.classifier = nn.Sequential(
    # Dropout helps reduce overfitting since FER-2013 is relatively small
    nn.Dropout(0.4),

    # One logit per emotion class (num_classes = 7)
    nn.Linear(in_features, num_classes)
)
model = model.to(device)



Downloading: "https://download.pytorch.org/models/efficientnet_b0_rwightman-7f5810bc.pth" to ~/.cache/torch/hub/checkpoints/efficientnet_b0_rwightman-7f5810bc.pth


100%|██████████| 20.5M/20.5M [00:00<00:00, 71.5MB/s]


In [8]:
# --- Loss & Optimizer ---
# Multi-class cross-entropy on the raw logits. Adam is handed every parameter;
# parameters with requires_grad=False never receive a gradient, and Adam skips
# any parameter whose .grad is None.
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)

In [9]:
# --- Training Loop ---
def train_model(model, criterion, optimizer, num_epochs=15):
    best_acc = 0.0
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch+1}/{num_epochs}")
        print("-" * 20)

        for phase in ['train', 'val']:
            if phase == 'train':
                model.train()
                loader = train_loader
            else:
                model.eval()
                loader = val_loader

            running_loss = 0.0
            running_corrects = 0

            for inputs, labels in loader:
                inputs, labels = inputs.to(device), labels.to(device)
                optimizer.zero_grad()

                with torch.set_grad_enabled(phase == 'train'):
                    outputs = model(inputs)
                    _, preds = torch.max(outputs, 1)
                    loss = criterion(outputs, labels)

                    if phase == 'train':
                        loss.backward()
                        optimizer.step()

                running_loss += loss.item() * inputs.size(0)
                running_corrects += torch.sum(preds == labels.data)

            epoch_loss = running_loss / len(loader.dataset)
            epoch_acc = running_corrects.double() / len(loader.dataset)

            print(f"{phase} Loss: {epoch_loss:.4f} Acc: {epoch_acc:.4f}")

            # Save best model
            if phase == 'val' and epoch_acc > best_acc:
                best_acc = epoch_acc
                torch.save(model.state_dict(), "best_efficientnetb0_fer2013.pth")

    print(f"\nBest Validation Accuracy: {best_acc:.4f}")
    return model

# --- Train ---
# perf_counter is monotonic, so the elapsed time is unaffected by system clock
# adjustments during a long run.
start = time.perf_counter()
model = train_model(model, criterion, optimizer, num_epochs)
print(f"Training complete in {(time.perf_counter() - start)/60:.2f} minutes.")


Epoch 1/15
--------------------
train Loss: 1.9832 Acc: 0.1643
val Loss: 1.9672 Acc: 0.1500

Epoch 2/15
--------------------
train Loss: 1.9473 Acc: 0.2000
val Loss: 1.9607 Acc: 0.1857

Epoch 3/15
--------------------
train Loss: 1.9665 Acc: 0.1500
val Loss: 1.9584 Acc: 0.1786

Epoch 4/15
--------------------
train Loss: 1.9544 Acc: 0.1357
val Loss: 1.9565 Acc: 0.1786

Epoch 5/15
--------------------
train Loss: 1.9507 Acc: 0.1643
val Loss: 1.9540 Acc: 0.1500

Epoch 6/15
--------------------
train Loss: 1.9355 Acc: 0.1786
val Loss: 1.9483 Acc: 0.1357

Epoch 7/15
--------------------
train Loss: 1.9241 Acc: 0.1714
val Loss: 1.9385 Acc: 0.1571

Epoch 8/15
--------------------
train Loss: 1.9423 Acc: 0.1786
val Loss: 1.9349 Acc: 0.1857

Epoch 9/15
--------------------
train Loss: 1.9695 Acc: 0.1286
val Loss: 1.9323 Acc: 0.1929

Epoch 10/15
--------------------
train Loss: 1.9548 Acc: 0.1643
val Loss: 1.9325 Acc: 0.2071

Epoch 11/15
--------------------
train Loss: 1.9629 Acc: 0.1357
val 