In [1]:
#!/usr/bin/env python3
"""
Optimized FER Training Pipeline with:
 - Fast in-memory data loading (pre-cached images)
 - Stratified train/val split
 - Modern PyTorch APIs
 - Efficient data augmentation
 - Early stopping and learning rate scheduling
"""


'\nOptimized FER Training Pipeline with:\n - Fast in-memory data loading (pre-cached images)\n - Stratified train/val split\n - Modern PyTorch APIs\n - Efficient data augmentation\n - Early stopping and learning rate scheduling\n'

In [2]:
import os
import random
from pathlib import Path
from collections import Counter
from typing import Tuple, List

import numpy as np
from PIL import Image

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader

from torchvision import transforms, models
from sklearn.model_selection import StratifiedShuffleSplit
from tqdm import tqdm

In [3]:




# ============================================================
#                  1. Data Collection
# ============================================================

def gather_image_paths_and_labels(root_dir: str) -> Tuple[List[str], List[int], List[str]]:
    """
    Collect all images from emotion class subdirectories.
    Handles both structures:
      - root/emotion/*.jpg
      - root/train/emotion/*.jpg and root/test/emotion/*.jpg
    
    Returns:
        image_paths: List of file paths
        labels: List of integer labels
        class_names: List of class names (sorted)
    """
    root = Path(root_dir)
    if not root.exists():
        raise FileNotFoundError(f"Dataset root not found: {root_dir}")

    # Collect all images with their parent directory (emotion class)
    image_data = []
    exts = {".jpg", ".jpeg", ".png", ".bmp"}
    
    for img_path in root.rglob("*"):
        if img_path.suffix.lower() in exts:
            # Get emotion class (parent directory name)
            emotion_class = img_path.parent.name
            # Skip if parent is 'train' or 'test' - go up one more level
            if emotion_class in ['train', 'test']:
                emotion_class = img_path.parent.parent.name
            
            image_data.append((str(img_path), emotion_class))
    
    if not image_data:
        raise ValueError(f"No images found under {root_dir}")
    
    # Create class mapping
    unique_classes = sorted(set(emotion for _, emotion in image_data))
    class_to_idx = {name: idx for idx, name in enumerate(unique_classes)}
    
    # Convert to lists
    image_paths = [path for path, _ in image_data]
    labels = [class_to_idx[emotion] for _, emotion in image_data]
    
    return image_paths, labels, unique_classes

In [4]:



# ============================================================
#                  2. Pre-cached Dataset
# ============================================================

class PreCachedImageDataset(Dataset):
    """
    Ultra-fast dataset that stores pre-processed tensors in memory.
    For training: applies minimal random augmentation on cached tensors.
    For validation: returns tensors directly (no augmentation).
    """
    def __init__(self, image_paths: List[str], labels: List[int], 
                 transform=None, cache_images: bool = True, img_size: int = 224,
                 is_train: bool = True):
        self.labels = torch.tensor(labels, dtype=torch.long)  # Pre-convert labels to tensor
        self.transform = transform
        self.is_train = is_train
        self.cached_tensors = []
        
        if cache_images:
            print(f"Pre-loading and processing {len(image_paths)} images...")
            
            # Pre-process transform (applied once during caching)
            imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
            imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)
            
            resize_transform = transforms.Compose([
                transforms.Resize(img_size),
                transforms.ToTensor(),
            ])
            
            for path in tqdm(image_paths, desc="Caching"):
                try:
                    img = Image.open(path).convert("RGB")
                    img_tensor = resize_transform(img)
                    
                    # For validation, pre-normalize to avoid doing it on every access
                    if not is_train:
                        img_tensor = (img_tensor - imagenet_mean) / imagenet_std
                    
                    self.cached_tensors.append(img_tensor)
                except Exception as e:
                    print(f"Warning: Failed to load {path}: {e}")
                    blank = torch.zeros(3, img_size, img_size)
                    if not is_train:
                        blank = (blank - imagenet_mean) / imagenet_std
                    self.cached_tensors.append(blank)
        else:
            self.cached_tensors = None
            self.image_paths = image_paths

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        if self.cached_tensors is not None:
            img = self.cached_tensors[idx]
            
            # For training, apply augmentation
            if self.is_train and self.transform:
                img = self.transform(img)
            # For validation, tensor is already normalized, return as-is
            
        else:
            img = Image.open(self.image_paths[idx]).convert("RGB")
            if self.transform:
                img = self.transform(img)
        
        return img, self.labels[idx]


In [5]:


# ============================================================
#                  3. Transforms
# ============================================================

def get_augmentation_transforms() -> Tuple[transforms.Compose, transforms.Compose]:
    """
    Returns lightweight augmentation transforms for pre-cached tensors.
    Applied on GPU-ready tensors for speed.
    """
    imagenet_mean = [0.485, 0.456, 0.406]
    imagenet_std = [0.229, 0.224, 0.225]
    
    # Train: augmentation on tensors (fast)
    train_transform = transforms.Compose([
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.Normalize(imagenet_mean, imagenet_std),
        transforms.RandomErasing(p=0.1, scale=(0.02, 0.1)),  # Applied after normalize
    ])
    
    # Val: just normalize
    val_transform = transforms.Compose([
        transforms.Normalize(imagenet_mean, imagenet_std),
    ])
    
    return train_transform, val_transform


In [6]:


# ============================================================
#                  4. Model Builder
# ============================================================

def build_model(model_name: str = 'resnet18', num_classes: int = 7, 
                dropout_rate: float = 0.5, freeze_backbone: bool = False) -> nn.Module:
    """
    Build ResNet-based classifier with modern PyTorch API.
    """
    if model_name == 'resnet18':
        weights = models.ResNet18_Weights.IMAGENET1K_V1
        model = models.resnet18(weights=weights)
    elif model_name == 'resnet50':
        weights = models.ResNet50_Weights.IMAGENET1K_V1
        model = models.resnet50(weights=weights)
    else:
        raise ValueError(f"Unsupported model: {model_name}")

    # Replace classifier head
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Dropout(dropout_rate / 2),  # Light dropout before first layer
        nn.Linear(in_features, 512),
        nn.BatchNorm1d(512),
        nn.ReLU(inplace=True),
        nn.Dropout(dropout_rate),
        nn.Linear(512, 256),
        nn.BatchNorm1d(256),
        nn.ReLU(inplace=True),
        nn.Dropout(dropout_rate),
        nn.Linear(256, num_classes)
    )

    # Freeze backbone if requested
    if freeze_backbone:
        for name, param in model.named_parameters():
            if "fc" not in name:
                param.requires_grad = False

    return model


In [7]:


# ============================================================
#                  5. Training Functions
# ============================================================

def train_one_epoch(model: nn.Module, loader: DataLoader, criterion: nn.Module,
                   optimizer: optim.Optimizer, device: torch.device, epoch: int) -> Tuple[float, float]:
    """Train for one epoch with optimized data loading"""
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch} [Train]", leave=False)

    for images, labels in pbar:
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)  # Faster than zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item() * images.size(0)
        preds = outputs.argmax(dim=1)
        correct += (preds == labels).sum().item()
        total += images.size(0)

        # Update progress bar less frequently for speed
        if total % (images.size(0) * 5) == 0:  # Update every 5 batches
            avg_loss = running_loss / total
            avg_acc = 100.0 * correct / total
            pbar.set_postfix({"loss": f"{avg_loss:.4f}", "acc": f"{avg_acc:.2f}%"})

    avg_loss = running_loss / total
    avg_acc = 100.0 * correct / total
    return avg_loss, avg_acc

    return avg_loss, avg_acc


def evaluate(model: nn.Module, loader: DataLoader, criterion: nn.Module,
            device: torch.device, epoch: int) -> Tuple[float, float]:
    """Evaluate on validation set"""
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc=f"Epoch {epoch} [Val]", leave=False)

    with torch.no_grad():
        for images, labels in pbar:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)

            outputs = model(images)
            loss = criterion(outputs, labels)

            running_loss += loss.item() * images.size(0)
            preds = outputs.argmax(dim=1)
            correct += (preds == labels).sum().item()
            total += images.size(0)

            # Update progress bar
            avg_loss = running_loss / total
            avg_acc = 100.0 * correct / total
            pbar.set_postfix({"loss": f"{avg_loss:.4f}", "acc": f"{avg_acc:.2f}%"})

    return avg_loss, avg_acc

In [8]:



# ============================================================
#                  6. Main Training Pipeline
# ============================================================

def main():
    # ============= Configuration =============
    CONFIG = {
        'data_root': r"C:\Users\be724\Downloads\archive (3)",
        'model_name': 'resnet18',  # or 'resnet50'
        'img_size': 112,
        'batch_size': 64,  # Large batch size works great with cached data
        'epochs': 30,
        'lr': 1e-3,  # Higher initial LR (will be reduced by scheduler)
        'weight_decay': 1e-4,
        'dropout': 0.5,
        'freeze_backbone': False,
        'val_size': 0.15,
        'random_seed': 42,
        'num_workers': 0,  # Must be 0 for cached tensor datasets
        'patience': 7,
        'use_class_weights': True,
        'save_dir': './checkpoints',
        'cache_images': True,  # Pre-load images into RAM
    }

    # ============= Setup =============
    random.seed(CONFIG['random_seed'])
    np.random.seed(CONFIG['random_seed'])
    torch.manual_seed(CONFIG['random_seed'])
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(CONFIG['random_seed'])
        torch.backends.cudnn.benchmark = True  # Optimize for fixed input size

    os.makedirs(CONFIG['save_dir'], exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ============= Load Data =============
    print("\n" + "="*60)
    print("Loading dataset...")
    print("="*60)
    
    image_paths, labels, class_names = gather_image_paths_and_labels(CONFIG['data_root'])
    num_classes = len(class_names)
    
    print(f"\nDataset Statistics:")
    print(f"  Total images: {len(image_paths)}")
    print(f"  Classes ({num_classes}): {class_names}")
    
    # Class distribution
    label_counts = Counter(labels)
    print(f"\n  Class distribution:")
    for cls_idx, cls_name in enumerate(class_names):
        print(f"    {cls_name}: {label_counts[cls_idx]}")

    # ============= Stratified Split =============
    print(f"\nCreating stratified split (val_size={CONFIG['val_size']})...")
    
    sss = StratifiedShuffleSplit(n_splits=1, test_size=CONFIG['val_size'],
                                random_state=CONFIG['random_seed'])
    indices = np.arange(len(labels))
    train_idx, val_idx = next(sss.split(indices, labels))

    train_paths = [image_paths[i] for i in train_idx]
    train_labels_list = [labels[i] for i in train_idx]
    val_paths = [image_paths[i] for i in val_idx]
    val_labels_list = [labels[i] for i in val_idx]

    print(f"  Train samples: {len(train_paths)}")
    print(f"  Val samples: {len(val_paths)}")

    # ============= Create Datasets =============
    train_tf, val_tf = get_augmentation_transforms()
    
    train_dataset = PreCachedImageDataset(
        train_paths, train_labels_list, 
        transform=train_tf, 
        cache_images=CONFIG['cache_images'],
        img_size=CONFIG['img_size'],
        is_train=True
    )
    val_dataset = PreCachedImageDataset(
        val_paths, val_labels_list,
        transform=None,  # Validation is pre-normalized
        cache_images=CONFIG['cache_images'],
        img_size=CONFIG['img_size'],
        is_train=False
    )

    train_loader = DataLoader(
        train_dataset, 
        batch_size=CONFIG['batch_size'],
        shuffle=True,
        num_workers=0,  # CRITICAL: Use 0 workers with cached data to avoid pickle overhead
        pin_memory=True if torch.cuda.is_available() else False,
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=CONFIG['batch_size'],
        shuffle=False,
        num_workers=0,  # CRITICAL: Use 0 workers with cached data
        pin_memory=True if torch.cuda.is_available() else False,
    )

    # ============= Build Model =============
    print(f"\n{'='*60}")
    print("Building model...")
    print("="*60)
    
    model = build_model(
        model_name=CONFIG['model_name'],
        num_classes=num_classes,
        dropout_rate=CONFIG['dropout'],
        freeze_backbone=CONFIG['freeze_backbone']
    ).to(device)
    
    print(f"  Model: {CONFIG['model_name']}")
    print(f"  Freeze backbone: {CONFIG['freeze_backbone']}")
    print(f"  Dropout: {CONFIG['dropout']}")
    print(f"  Trainable params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

    # ============= Loss & Optimizer =============
    if CONFIG['use_class_weights']:
        train_counts = np.bincount(train_labels_list, minlength=num_classes)
        class_weights = 1.0 / (train_counts + 1e-6)
        class_weights = class_weights * (len(train_labels_list) / class_weights.sum())
        class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)
        criterion = nn.CrossEntropyLoss(weight=class_weights)
        print(f"\n  Using class weights: {class_weights.cpu().numpy()}")
    else:
        criterion = nn.CrossEntropyLoss()

    optimizer = optim.AdamW(
        filter(lambda p: p.requires_grad, model.parameters()),
        lr=CONFIG['lr'],
        weight_decay=CONFIG['weight_decay']
    )
    
    scheduler = ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=3
    )

    # ============= Training Loop =============
    print(f"\n{'='*60}")
    print("Starting training...")
    print("="*60 + "\n")
    
    best_val_acc = 0.0
    best_epoch = 0
    patience_counter = 0

    for epoch in range(1, CONFIG['epochs'] + 1):
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch
        )
        val_loss, val_acc = evaluate(
            model, val_loader, criterion, device, epoch
        )

        print(f"\n[Epoch {epoch}/{CONFIG['epochs']}]")
        print(f"  Train -> Loss: {train_loss:.4f}  Acc: {train_acc:.2f}%")
        print(f"  Val   -> Loss: {val_loss:.4f}  Acc: {val_acc:.2f}%")

        scheduler.step(val_loss)

        # Save best model
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            best_epoch = epoch
            patience_counter = 0
            
            best_path = os.path.join(CONFIG['save_dir'], "best_model.pth")
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'val_acc': val_acc,
                'val_loss': val_loss,
                'class_names': class_names,
                'config': CONFIG
            }, best_path)
            print(f"  ✓ New best model saved! (Val Acc: {val_acc:.2f}%)")
        else:
            patience_counter += 1
            print(f"  No improvement ({patience_counter}/{CONFIG['patience']})")
            
            if patience_counter >= CONFIG['patience']:
                print(f"\n{'='*60}")
                print(f"Early stopping triggered (no improvement for {CONFIG['patience']} epochs)")
                print("="*60)
                break

    # ============= Training Complete =============
    print(f"\n{'='*60}")
    print("Training Complete!")
    print("="*60)
    print(f"  Best Val Acc: {best_val_acc:.2f}% (Epoch {best_epoch})")
    print(f"  Model saved to: {CONFIG['save_dir']}/best_model.pth")


if __name__ == "__main__":
    main()

Using device: cpu

Loading dataset...

Dataset Statistics:
  Total images: 35887
  Classes (7): ['angry', 'disgusted', 'fearful', 'happy', 'neutral', 'sad', 'surprised']

  Class distribution:
    angry: 4953
    disgusted: 547
    fearful: 5121
    happy: 8989
    neutral: 6198
    sad: 6077
    surprised: 4002

Creating stratified split (val_size=0.15)...
  Train samples: 30503
  Val samples: 5384
Pre-loading and processing 30503 images...


Caching: 100%|██████████| 30503/30503 [00:40<00:00, 755.21it/s] 


Pre-loading and processing 5384 images...


Caching: 100%|██████████| 5384/5384 [00:11<00:00, 469.07it/s]



Building model...
  Model: resnet18
  Freeze backbone: False
  Dropout: 0.5
  Trainable params: 11,573,831

  Using class weights: [ 2114.7783 19146.703   2045.306   1165.3425  1690.0564  1723.7594
  2617.0537]

Starting training...



                                                                                           


[Epoch 1/30]
  Train -> Loss: 1.6421  Acc: 37.63%
  Val   -> Loss: 1.4565  Acc: 39.67%
  ✓ New best model saved! (Val Acc: 39.67%)


                                                                                           


[Epoch 2/30]
  Train -> Loss: 1.3898  Acc: 47.78%
  Val   -> Loss: 1.3160  Acc: 52.14%
  ✓ New best model saved! (Val Acc: 52.14%)


                                                                                           


[Epoch 3/30]
  Train -> Loss: 1.2659  Acc: 52.27%
  Val   -> Loss: 1.2822  Acc: 54.62%
  ✓ New best model saved! (Val Acc: 54.62%)


                                                                                           


[Epoch 4/30]
  Train -> Loss: 1.2051  Acc: 54.13%
  Val   -> Loss: 1.2436  Acc: 55.55%
  ✓ New best model saved! (Val Acc: 55.55%)


                                                                                           


[Epoch 5/30]
  Train -> Loss: 1.1437  Acc: 56.50%
  Val   -> Loss: 1.1804  Acc: 55.20%
  No improvement (1/7)


                                                                                           


[Epoch 6/30]
  Train -> Loss: 1.0583  Acc: 59.37%
  Val   -> Loss: 1.1982  Acc: 57.56%
  ✓ New best model saved! (Val Acc: 57.56%)


                                                                                           


[Epoch 7/30]
  Train -> Loss: 1.0505  Acc: 60.10%
  Val   -> Loss: 1.0719  Acc: 57.80%
  ✓ New best model saved! (Val Acc: 57.80%)


                                                                                           


[Epoch 8/30]
  Train -> Loss: 0.9787  Acc: 62.21%
  Val   -> Loss: 1.1655  Acc: 58.08%
  ✓ New best model saved! (Val Acc: 58.08%)


                                                                                           


[Epoch 9/30]
  Train -> Loss: 0.9496  Acc: 63.68%
  Val   -> Loss: 1.0744  Acc: 61.70%
  ✓ New best model saved! (Val Acc: 61.70%)


                                                                                            


[Epoch 10/30]
  Train -> Loss: 0.8839  Acc: 66.56%
  Val   -> Loss: 1.1152  Acc: 63.32%
  ✓ New best model saved! (Val Acc: 63.32%)


                                                                                            


[Epoch 11/30]
  Train -> Loss: 0.8308  Acc: 67.56%
  Val   -> Loss: 1.1231  Acc: 61.79%
  No improvement (1/7)


                                                                                            


[Epoch 12/30]
  Train -> Loss: 0.7007  Acc: 73.10%
  Val   -> Loss: 1.0381  Acc: 64.71%
  ✓ New best model saved! (Val Acc: 64.71%)


                                                                                            


[Epoch 13/30]
  Train -> Loss: 0.6287  Acc: 75.92%
  Val   -> Loss: 1.0434  Acc: 63.87%
  No improvement (1/7)


                                                                                            


[Epoch 14/30]
  Train -> Loss: 0.5546  Acc: 78.95%
  Val   -> Loss: 1.1433  Acc: 63.06%
  No improvement (2/7)


                                                                                            


[Epoch 15/30]
  Train -> Loss: 0.4769  Acc: 82.20%
  Val   -> Loss: 1.1997  Acc: 63.47%
  No improvement (3/7)


                                                                                            


[Epoch 16/30]
  Train -> Loss: 0.4327  Acc: 84.20%
  Val   -> Loss: 1.2726  Acc: 64.28%
  No improvement (4/7)


                                                                                            


[Epoch 17/30]
  Train -> Loss: 0.3190  Acc: 88.75%
  Val   -> Loss: 1.2956  Acc: 65.21%
  ✓ New best model saved! (Val Acc: 65.21%)


                                                                                            


[Epoch 18/30]
  Train -> Loss: 0.2405  Acc: 91.77%
  Val   -> Loss: 1.4284  Acc: 65.62%
  ✓ New best model saved! (Val Acc: 65.62%)


                                                                                            


[Epoch 19/30]
  Train -> Loss: 0.1973  Acc: 93.41%
  Val   -> Loss: 1.4902  Acc: 65.32%
  No improvement (1/7)


                                                                                                    


[Epoch 20/30]
  Train -> Loss: 0.1693  Acc: 94.44%
  Val   -> Loss: 1.6254  Acc: 64.62%
  No improvement (2/7)


                                                                                                


[Epoch 21/30]
  Train -> Loss: 0.1258  Acc: 95.90%
  Val   -> Loss: 1.6347  Acc: 65.42%
  No improvement (3/7)


                                                                                             


[Epoch 22/30]
  Train -> Loss: 0.1099  Acc: 96.63%
  Val   -> Loss: 1.6134  Acc: 65.64%
  ✓ New best model saved! (Val Acc: 65.64%)


                                                                                              


[Epoch 23/30]
  Train -> Loss: 0.0971  Acc: 97.04%
  Val   -> Loss: 1.7265  Acc: 65.38%
  No improvement (1/7)


                                                                                                    


[Epoch 24/30]
  Train -> Loss: 0.0825  Acc: 97.54%
  Val   -> Loss: 1.7945  Acc: 65.27%
  No improvement (2/7)


                                                                                            


[Epoch 25/30]
  Train -> Loss: 0.0643  Acc: 98.14%
  Val   -> Loss: 1.8197  Acc: 65.73%
  ✓ New best model saved! (Val Acc: 65.73%)


                                                                                            


[Epoch 26/30]
  Train -> Loss: 0.0587  Acc: 98.27%
  Val   -> Loss: 1.8484  Acc: 65.81%
  ✓ New best model saved! (Val Acc: 65.81%)


                                                                                            


[Epoch 27/30]
  Train -> Loss: 0.0558  Acc: 98.34%
  Val   -> Loss: 1.8597  Acc: 65.82%
  ✓ New best model saved! (Val Acc: 65.82%)


                                                                                            


[Epoch 28/30]
  Train -> Loss: 0.0504  Acc: 98.49%
  Val   -> Loss: 1.9368  Acc: 65.86%
  ✓ New best model saved! (Val Acc: 65.86%)


                                                                                            


[Epoch 29/30]
  Train -> Loss: 0.0466  Acc: 98.67%
  Val   -> Loss: 1.8468  Acc: 66.10%
  ✓ New best model saved! (Val Acc: 66.10%)


                                                                                            


[Epoch 30/30]
  Train -> Loss: 0.0419  Acc: 98.77%
  Val   -> Loss: 1.9034  Acc: 66.08%
  No improvement (1/7)

Training Complete!
  Best Val Acc: 66.10% (Epoch 29)
  Model saved to: ./checkpoints/best_model.pth
