In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Subset
from torchvision import transforms, datasets, models
from retinaface import RetinaFace
import cv2
import numpy as np
from PIL import Image
import os
from glob import glob
from sklearn.model_selection import StratifiedShuffleSplit
from tqdm import tqdm
import pickle

# ============================================================================
# STEP 1: ONE-TIME PREPROCESSING - Run this ONCE and save results
# ============================================================================

def _center_square_crop(img_np):
    """Return the centered square region of an image array.

    The side is ``min(h, w)`` rounded down to an even number. Falls back to the
    whole image when that region would be empty (a side of 1 pixel), so the
    caller can always resize the result.
    """
    h, w = img_np.shape[:2]
    center_y, center_x = h // 2, w // 2
    half_size = min(h, w) // 2
    crop = img_np[
        max(0, center_y - half_size):min(h, center_y + half_size),
        max(0, center_x - half_size):min(w, center_x + half_size)
    ]
    return crop if crop.size else img_np


def _detect_face_crop(img_np):
    """Crop the highest-scoring RetinaFace detection from an RGB image.

    The box is clipped to the image bounds. Returns None when no face is
    detected or the clipped box is empty.
    """
    # retina-face handles ndarray input like a cv2.imread() result (BGR order),
    # so convert from the RGB array PIL produced.
    faces = RetinaFace.detect_faces(cv2.cvtColor(img_np, cv2.COLOR_RGB2BGR))
    if not isinstance(faces, dict) or not faces:
        return None

    # Dict order is not a confidence ranking; pick the most confident face explicitly.
    best = max(faces.values(), key=lambda face: face.get("score", 0.0))
    x1, y1, x2, y2 = (int(v) for v in best["facial_area"])

    # Negative indices would wrap around in numpy slicing, so clip to the image.
    h, w = img_np.shape[:2]
    x1, x2 = max(0, x1), min(w, x2)
    y1, y2 = max(0, y1), min(h, y2)
    if x2 <= x1 or y2 <= y1:
        return None
    return img_np[y1:y2, x1:x2]


def preprocess_and_save_faces(data_path, output_path, target_size=(48, 48),
                              skip_existing=True, use_detector=True,
                              extensions=(".png",)):
    """
    Detect faces once and save cropped, resized faces to disk.

    This should be run ONCE before training. The directory layout
    ``<data_path>/{train,test}/<class>/<image>`` is mirrored under ``output_path``.
    When no face is found (or detection is disabled), a centered square crop
    is used instead.

    Args:
        data_path: root folder containing ``train`` and ``test`` sub-folders,
            each with one sub-folder per class.
        output_path: root folder the cropped faces are written to.
        target_size: ``(width, height)`` passed to ``cv2.resize``.
        skip_existing: skip images whose output file already exists, so an
            interrupted run resumes instead of starting over. Note that files
            written by a previous run with different settings are kept as-is.
        use_detector: run RetinaFace on every image. Set False when the inputs
            are already face crops, to skip the per-image detector call and
            only center-crop and resize.
        extensions: file extensions (including the dot) collected from each
            class folder.
    """
    os.makedirs(output_path, exist_ok=True)

    for split in ['train', 'test']:
        split_path = os.path.join(data_path, split)
        output_split_path = os.path.join(output_path, split)

        # Only sub-directories are classes; stray files in the split folder are ignored.
        classes = sorted(
            entry for entry in os.listdir(split_path)
            if os.path.isdir(os.path.join(split_path, entry))
        )

        for emotion_class in classes:
            class_path = os.path.join(split_path, emotion_class)
            output_class_path = os.path.join(output_split_path, emotion_class)
            os.makedirs(output_class_path, exist_ok=True)

            # A set de-duplicates matches when globbing is case-insensitive (Windows).
            image_files = sorted({
                path
                for ext in extensions
                for path in glob(os.path.join(class_path, f"*{ext}"))
            })

            print(f"Processing {split}/{emotion_class}: {len(image_files)} images")

            for img_path in tqdm(image_files, desc=f"{split}/{emotion_class}"):
                output_file = os.path.join(output_class_path, os.path.basename(img_path))
                if skip_existing and os.path.exists(output_file):
                    continue

                # Context manager closes the underlying file handle promptly.
                with Image.open(img_path) as img:
                    img_np = np.array(img.convert("RGB"))

                crop = _detect_face_crop(img_np) if use_detector else None
                if crop is None:
                    # No face detected (or detection disabled) - center crop fallback
                    crop = _center_square_crop(img_np)
                crop = cv2.resize(crop, target_size)

                # Save cropped face
                Image.fromarray(crop).save(output_file)

    print("\nPreprocessing complete! Cropped faces saved to:", output_path)


def extract_and_save_features(preprocessed_path, output_file, device=None,
                              batch_size=32, num_workers=4):
    """
    Extract EfficientNet-B0 features from all preprocessed faces ONCE and save to disk.

    This should be run ONCE after preprocessing faces. The saved file holds
    ``{'train': {...}, 'test': {...}}``, where each split maps ``'features'``
    to an (N, D) tensor, ``'labels'`` to an (N,) tensor and ``'classes'`` to
    the ImageFolder class names.

    Args:
        preprocessed_path: folder containing ``train`` and ``test`` ImageFolder trees.
        output_file: path passed to ``torch.save``.
        device: torch device string. None picks ``'cuda'`` when available,
            else ``'cpu'``.
        batch_size: images per forward pass.
        num_workers: DataLoader worker processes. Use 0 if workers hang,
            e.g. in some notebook setups.

    Raises:
        ValueError: if the train and test splits do not have identical class
            folders. Their label indices would otherwise not refer to the same
            classes.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # Load pre-trained backbone and drop its final child (the classifier),
    # keeping the convolutional features + global average pooling.
    backbone = models.efficientnet_b0(
        weights=models.EfficientNet_B0_Weights.IMAGENET1K_V1
    )
    feature_extractor = nn.Sequential(*list(backbone.children())[:-1])
    feature_extractor.eval()
    feature_extractor.to(device)

    # Transformation for feature extraction (ImageNet normalisation statistics)
    transform = transforms.Compose([
        transforms.Resize((224, 224)),  # EfficientNet input size
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])

    features_dict = {}

    for split in ['train', 'test']:
        dataset = datasets.ImageFolder(
            os.path.join(preprocessed_path, split),
            transform=transform
        )
        dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=False,
                                num_workers=num_workers)

        split_features = []
        split_labels = []

        print(f"\nExtracting features for {split} set...")
        with torch.no_grad():
            for images, labels in tqdm(dataloader):
                features = feature_extractor(images.to(device))
                # Pooled output has trailing 1x1 spatial dims; flatten to (batch, channels).
                split_features.append(torch.flatten(features, 1).cpu())
                split_labels.append(labels)

        features_dict[split] = {
            'features': torch.cat(split_features),
            'labels': torch.cat(split_labels),
            'classes': dataset.classes
        }

    # ImageFolder numbers classes by sorted folder name, so a missing or extra
    # class folder in one split would silently shift every label index.
    if features_dict['train']['classes'] != features_dict['test']['classes']:
        raise ValueError(
            "train/test class folders differ: "
            f"{features_dict['train']['classes']} vs {features_dict['test']['classes']}"
        )

    # Save features
    torch.save(features_dict, output_file)
    print(f"\nFeatures saved to: {output_file}")
    print(f"Feature dimension: {features_dict['train']['features'].shape[1]}")


# ============================================================================
# STEP 2: FAST TRAINING DATASET - Uses pre-extracted features
# ============================================================================

class PreExtractedFeatureDataset(Dataset):
    """
    Fast map-style dataset over features that are already in memory.

    Args:
        features: indexable per-sample feature vectors (e.g. an (N, D) tensor).
        labels: indexable per-sample labels with the same length N.

    Raises:
        ValueError: if ``features`` and ``labels`` differ in length. Without this
            check, extra features are silently ignored or indexing fails later.
    """
    def __init__(self, features, labels):
        if len(features) != len(labels):
            raise ValueError(
                "features and labels must have the same length, "
                f"got {len(features)} and {len(labels)}"
            )
        self.features = features
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return the ``(feature, label)`` pair at position ``idx``."""
        return self.features[idx], self.labels[idx]


# ============================================================================
# STEP 3: TRAINING FUNCTIONS
# ============================================================================

class FER_EfficientNetClassifier(nn.Module):
    """
    MLP classification head applied to pre-extracted feature vectors.

    Each hidden layer is Linear -> BatchNorm1d -> ReLU -> Dropout. A final
    Linear layer outputs raw logits (no softmax), as expected by
    ``nn.CrossEntropyLoss``.

    Layers are created in the same order as the original fixed
    ``(512, 256)`` architecture. With the default arguments, the
    ``state_dict`` keys (``classifier.0`` ... ``classifier.8``) are unchanged,
    so existing checkpoints still load.

    Args:
        feature_dim: size of each input feature vector.
        num_classes: number of output logits.
        hidden_dims: widths of the hidden layers, in order.
        dropout: dropout probability applied after every hidden layer.
    """
    def __init__(self, feature_dim, num_classes=7, hidden_dims=(512, 256), dropout=0.3):
        super().__init__()
        layers = []
        in_dim = feature_dim
        for width in hidden_dims:
            layers += [
                nn.Linear(in_dim, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(dropout),
            ]
            in_dim = width
        layers.append(nn.Linear(in_dim, num_classes))
        self.classifier = nn.Sequential(*layers)

    def forward(self, x):
        """Map a batch of feature vectors, presumably (batch, feature_dim), to logits."""
        return self.classifier(x)


def _run_epoch(model, loader, criterion, device, optimizer=None, desc=''):
    """Run one pass over ``loader``: train when ``optimizer`` is given, else evaluate.

    Returns ``(mean_loss, accuracy)`` over all samples seen.
    Raises ValueError if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)
    total_loss, correct, total = 0.0, 0, 0

    bar = tqdm(loader, desc=desc)
    # Gradients are only tracked for the training pass.
    with torch.set_grad_enabled(training):
        for features, labels in bar:
            features, labels = features.to(device), labels.to(device)

            if training:
                optimizer.zero_grad()
            outputs = model(features)
            loss = criterion(outputs, labels)
            if training:
                loss.backward()
                optimizer.step()

            # criterion returns a batch mean; weight by batch size for an exact epoch mean.
            batch_size = labels.size(0)
            total_loss += loss.item() * batch_size
            correct += (outputs.argmax(dim=1) == labels).sum().item()
            total += batch_size

            if training:
                bar.set_postfix({
                    'Loss': f'{loss.item():.4f}',
                    'Acc': f'{correct/total:.4f}'
                })

    if total == 0:
        raise ValueError(f"{desc or 'Data'} loader produced no samples")
    return total_loss / total, correct / total


def train_model(model, train_loader, val_loader, criterion, optimizer,
                num_epochs, device=None, patience=5, checkpoint_path='best_model.pth'):
    """
    Train ``model`` with early stopping on validation accuracy.

    After every epoch that improves validation accuracy, the weights are
    written to ``checkpoint_path`` and kept in memory. Training stops once
    ``patience`` consecutive epochs fail to improve. Before returning, the
    best weights are loaded back into ``model``, so the caller gets the best
    model rather than the last epoch's.

    Args:
        model: module to train; moved to ``device``.
        train_loader / val_loader: iterables of ``(features, labels)`` batches.
        criterion: loss returning a batch-mean scalar.
        optimizer: optimizer over ``model``'s parameters.
        num_epochs: maximum number of epochs.
        device: torch device string; None picks ``'cuda'`` when available, else ``'cpu'``.
        patience: epochs without improvement tolerated before stopping.
        checkpoint_path: where the best ``state_dict`` is saved.

    Returns:
        ``model`` holding the best-validation-accuracy weights.

    Raises:
        ValueError: if either loader is empty.
    """
    if device is None:
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model.to(device)
    # -inf guarantees the first epoch is always checkpointed, even at 0% accuracy.
    best_val_acc = float('-inf')
    best_state = None
    patience_counter = 0

    for epoch in range(num_epochs):
        print(f"\n--- Epoch {epoch+1}/{num_epochs} ---")

        train_loss, train_acc = _run_epoch(
            model, train_loader, criterion, device, optimizer, desc='Training'
        )
        val_loss, val_acc = _run_epoch(
            model, val_loader, criterion, device, desc='Validation'
        )

        print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc*100:.2f}%")
        print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc*100:.2f}%")

        # Early stopping
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            patience_counter = 0
            # Clone so later optimizer steps don't overwrite the stored best weights.
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            torch.save(best_state, checkpoint_path)
            print(f"✓ New best model saved! (Val Acc: {val_acc*100:.2f}%)")
        else:
            patience_counter += 1
            if patience_counter >= patience:
                print(f"\nEarly stopping triggered after {epoch+1} epochs")
                break

    if best_state is not None:
        model.load_state_dict(best_state)
        print(f"Restored best model weights (Val Acc: {best_val_acc*100:.2f}%)")

    return model


# ============================================================================
# STEP 4: MAIN TRAINING SCRIPT
# ============================================================================

def main():
    """Run the pipeline end to end.

    Preprocessing and feature extraction run only when the cached features
    file is missing (or ``FORCE_PREPROCESS`` is True). The classifier head is
    then trained on the cached features, and the best checkpoint is evaluated
    on the test split.
    """
    # Paths
    DATA_PATH = "C:/adam/AMIT_Diploma/grad_project/archive (1)"
    PREPROCESSED_PATH = "C:/adam/AMIT_Diploma/grad_project/preprocessed_faces"
    FEATURES_FILE = "C:/adam/AMIT_Diploma/grad_project/extracted_features.pt"
    # File that train_model() writes its best checkpoint to (its default path).
    CHECKPOINT_PATH = "best_model.pth"
    BATCH_SIZE = 64
    # Set True to redo preprocessing + feature extraction even if FEATURES_FILE exists.
    FORCE_PREPROCESS = False

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    # ========================================================================
    # OPTION A: First time setup (only runs while the cached features are missing)
    # ========================================================================
    if FORCE_PREPROCESS or not os.path.exists(FEATURES_FILE):
        print("Step 1: Preprocessing faces...")
        preprocess_and_save_faces(DATA_PATH, PREPROCESSED_PATH)

        print("\nStep 2: Extracting features...")
        # Pass the detected device so extraction also works on CPU-only machines.
        extract_and_save_features(PREPROCESSED_PATH, FEATURES_FILE, device=device)
    else:
        print(f"Using cached features from {FEATURES_FILE}")

    # ========================================================================
    # OPTION B: Fast training on the cached features
    # ========================================================================
    print("Loading pre-extracted features...")
    features_dict = torch.load(FEATURES_FILE, map_location='cpu')

    # Label indices are only comparable if both splits share the same class list.
    class_names = features_dict['train']['classes']
    if features_dict['test']['classes'] != class_names:
        raise ValueError(
            f"train/test classes differ: {class_names} vs {features_dict['test']['classes']}"
        )
    num_classes = len(class_names)

    # Get feature dimension
    feature_dim = features_dict['train']['features'].shape[1]
    print(f"Feature dimension: {feature_dim}")

    # Split train into train/val
    train_features = features_dict['train']['features']
    train_labels = features_dict['train']['labels']

    # Stratified split keeps the class proportions equal in train and val.
    splitter = StratifiedShuffleSplit(n_splits=1, test_size=0.1, random_state=42)
    train_idx, val_idx = next(splitter.split(
        np.arange(len(train_labels)),
        train_labels.numpy()
    ))

    # Create datasets
    train_dataset = PreExtractedFeatureDataset(
        train_features[train_idx],
        train_labels[train_idx]
    )
    val_dataset = PreExtractedFeatureDataset(
        train_features[val_idx],
        train_labels[val_idx]
    )
    test_dataset = PreExtractedFeatureDataset(
        features_dict['test']['features'],
        features_dict['test']['labels']
    )

    # BatchNorm1d raises on a single-sample batch in training mode, so drop a
    # trailing batch only when it would contain exactly one sample.
    drop_last = len(train_dataset) % BATCH_SIZE == 1
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                              drop_last=drop_last)
    val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
    test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

    print(f"Train samples: {len(train_dataset)}")
    print(f"Val samples: {len(val_dataset)}")
    print(f"Test samples: {len(test_dataset)}")

    # Initialize model (class count comes from the data, not a hard-coded 7)
    model = FER_EfficientNetClassifier(feature_dim=feature_dim, num_classes=num_classes)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=0.001)

    # Train
    print("\nStarting training...")
    trained_model = train_model(
        model, train_loader, val_loader, criterion, optimizer,
        num_epochs=50, device=device, patience=5
    )

    # Evaluate the best checkpoint rather than whatever the last epoch left behind.
    if os.path.exists(CHECKPOINT_PATH):
        trained_model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))

    trained_model.to(device)
    trained_model.eval()
    test_correct, test_total = 0, 0
    with torch.no_grad():
        for features, labels in test_loader:
            preds = trained_model(features.to(device)).argmax(dim=1)
            test_correct += (preds == labels.to(device)).sum().item()
            test_total += labels.size(0)
    if test_total:
        print(f"Test Acc: {test_correct / test_total * 100:.2f}%")

    print("\nTraining complete!")


# Entry point: runs main() when executed as a script. Inside a notebook kernel
# __name__ is also "__main__", so running this cell starts the pipeline immediately.
if __name__ == "__main__":
    main()

Step 1: Preprocessing faces...
Processing train/angry: 3995 images


train/angry:   0%|          | 0/3995 [00:00<?, ?it/s]

train/angry:  32%|███▏      | 1284/3995 [49:13<1:43:56,  2.30s/it]


KeyboardInterrupt: 