In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import insightface
from insightface.app import FaceAnalysis
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from retinaface import RetinaFace
import cv2
import numpy as np
from PIL import Image
import os
from glob import glob
from sklearn.model_selection import StratifiedShuffleSplit
from tqdm import tqdm
from torchvision.models import efficientnet_b0


# ============================================================
#                1. RetinaFace Feature Extractor
# ============================================================

class RetinaFaceFeatureExtractor:
    """Detect a face with InsightFace and return its embedding and crop.

    Wraps an ``insightface.app.FaceAnalysis`` instance loaded with the
    ``buffalo_l`` model pack. Only the first face returned by the detector
    is used.
    """

    # Default location of the downloaded InsightFace model pack.
    DEFAULT_MODEL_ROOT = "C:/adam/AMIT_Diploma/grad_project/insightface_models"

    def __init__(self, root=DEFAULT_MODEL_ROOT, providers=None,
                 det_size=(640, 640), ctx_id=0):
        """Create and prepare the underlying ``FaceAnalysis`` detector.

        Args:
            root: Directory passed to ``FaceAnalysis`` as ``root``.
            providers: Execution providers passed to ``FaceAnalysis``;
                defaults to ``['CPUExecutionProvider']``.
            det_size: ``(width, height)`` passed to ``prepare`` as
                ``det_size``.
            ctx_id: Context id passed to ``prepare``.
        """
        from insightface.app import FaceAnalysis

        if providers is None:
            providers = ['CPUExecutionProvider']

        self.detector = FaceAnalysis(
            name="buffalo_l",
            providers=list(providers),
            root=root,
        )
        self.detector.prepare(ctx_id=ctx_id, det_size=det_size)

    def detect_and_extract(self, image):
        """Return ``(feature_vector, cropped_face)`` for the first face found.

        Args:
            image: Image array indexed as ``image[y, x]`` (for example the
                result of ``cv2.imread``).

        Returns:
            A tuple ``(features, face_img)`` where ``features`` is a tensor
            built from the face embedding (512-dim according to the original
            author's note) and ``face_img`` is the crop of ``image`` inside
            the clamped bounding box. Returns ``(None, None)`` when no face
            is detected, the clamped box is empty, or the face carries no
            embedding.
        """
        faces = self.detector.get(image)
        if len(faces) == 0:
            return None, None

        face = faces[0]

        # Clamp the detector's box to the image so the slice below is valid.
        img_h, img_w = image.shape[:2]
        box = face.bbox.astype(int)
        x1 = max(0, int(box[0]))
        y1 = max(0, int(box[1]))
        x2 = min(img_w, int(box[2]))
        y2 = min(img_h, int(box[3]))
        if x2 <= x1 or y2 <= y1:
            return None, None

        face_img = image[y1:y2, x1:x2]
        if face_img.size == 0:
            return None, None

        # Without an embedding there is nothing to return as features;
        # treat it like a failed detection instead of crashing in torch.
        embedding = getattr(face, "embedding", None)
        if embedding is None:
            return None, None

        # The embedding is the feature vector. The earlier resize/convert of
        # the crop produced a tensor that was never used, so it was removed.
        features = torch.tensor(embedding)

        return features, face_img


# ============================================================
#                2. Pre-extracted Feature Dataset 
# ============================================================

class PreExtractedFeatureDataset(Dataset):
    """Dataset pairing pre-extracted feature vectors with their labels.

    Item ``i`` is ``(features[i], labels[i])``; no processing happens at
    access time.
    """

    def __init__(self, features, labels):
        """Store the parallel ``features`` and ``labels`` sequences.

        Raises:
            ValueError: If the two sequences differ in length, which would
                otherwise fail later inside a DataLoader or silently drop
                samples.
        """
        if len(features) != len(labels):
            raise ValueError(
                f"features and labels must have the same length, "
                f"got {len(features)} and {len(labels)}"
            )
        self.features = features
        self.labels = labels

    def __len__(self):
        """Return the number of stored samples."""
        return len(self.features)

    def __getitem__(self, idx):
        """Return the ``(feature, label)`` pair at position ``idx``."""
        return self.features[idx], self.labels[idx]


# ============================================================
#            3. EfficientNet Classifier (Feature Input)
# ============================================================

class EfficientNetFeatureClassifier(nn.Module):
    """MLP classification head applied to pre-extracted feature vectors.

    Despite the name, no EfficientNet layers are part of this module. The
    original constructor loaded an ImageNet-pretrained ``efficientnet_b0``
    into a local variable and discarded it; that load was removed. The
    registered parameters, and therefore the ``state_dict`` layout, are
    unchanged.

    Args:
        feature_dim: Size of each input feature vector.
        num_classes: Number of output logits.
    """

    def __init__(self, feature_dim=512, num_classes=7):
        super().__init__()

        # Linear → ReLU → Dropout, twice, then the final classification layer.
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 1024),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Dropout(0.4),
            nn.Linear(512, num_classes),
        )

    def forward(self, x):
        """Return class logits for a batch of feature vectors ``x``."""
        return self.classifier(x)


# ============================================================
#                      4. Train / Val Functions
# ============================================================

def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader``.

    Args:
        model: Module to train; switched to training mode.
        loader: Iterable yielding ``(features, labels)`` tensor batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean_batch_loss, accuracy)`` where the loss is averaged over the
        batches actually processed and accuracy over the samples actually
        seen.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train()
    loss_sum = 0.0
    num_correct = 0
    num_batches = 0
    num_samples = 0

    for features, labels in tqdm(loader, desc="Training", leave=False):
        features, labels = features.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(features)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        loss_sum += loss.item()
        num_correct += (outputs.argmax(1) == labels).sum().item()
        # Count what was really processed: len(loader.dataset) is wrong when
        # the loader drops the last batch or uses a sampler.
        num_batches += 1
        num_samples += labels.numel()

    if num_batches == 0 or num_samples == 0:
        raise ValueError("train_one_epoch: loader yielded no samples")

    return loss_sum / num_batches, num_correct / num_samples


def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without computing gradients.

    Args:
        model: Module to evaluate; switched to evaluation mode.
        loader: Iterable yielding ``(features, labels)`` tensor batches.
        criterion: Loss callable applied to ``(outputs, labels)``.
        device: Device the batches are moved to.

    Returns:
        ``(mean_batch_loss, accuracy)`` where the loss is averaged over the
        batches actually processed and accuracy over the samples actually
        seen.

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.eval()
    loss_sum = 0.0
    num_correct = 0
    num_batches = 0
    num_samples = 0

    with torch.no_grad():
        for features, labels in tqdm(loader, desc="Validating", leave=False):
            features, labels = features.to(device), labels.to(device)
            outputs = model(features)
            loss = criterion(outputs, labels)

            loss_sum += loss.item()
            num_correct += (outputs.argmax(1) == labels).sum().item()
            # Count what was really processed: len(loader.dataset) is wrong
            # when the loader drops the last batch or uses a sampler.
            num_batches += 1
            num_samples += labels.numel()

    if num_batches == 0 or num_samples == 0:
        raise ValueError("validate: loader yielded no samples")

    return loss_sum / num_batches, num_correct / num_samples


# ============================================================
#                  5. Feature Pre-extraction
# ============================================================

def extract_all_features(image_paths, labels, extractor, feature_dim=512,
                         skip_undetected=False):
    """Pre-extract one feature vector per readable image.

    Args:
        image_paths: Sequence of image file paths.
        labels: Sequence of integer labels parallel to ``image_paths``.
        extractor: Object exposing ``detect_and_extract(image)`` that returns
            ``(features, face)`` with ``features`` set to ``None`` on failure.
        feature_dim: Length of the zero vector substituted when detection
            fails, and of the empty result when nothing was extracted.
        skip_undetected: When ``True``, drop images with no detected face
            instead of substituting a zero vector. The default keeps the
            original zero-vector behaviour.

    Returns:
        ``(features_tensor, labels_tensor)``: the stacked float features and
        a ``torch.long`` tensor of the labels of the images that were kept.

    Raises:
        ValueError: If ``image_paths`` and ``labels`` differ in length.
    """
    if len(image_paths) != len(labels):
        raise ValueError(
            f"image_paths and labels must have the same length, "
            f"got {len(image_paths)} and {len(labels)}"
        )

    features_list = []
    valid_labels = []
    skipped = 0
    undetected = 0

    print("Pre-extracting features from all images...")
    for img_path, label in tqdm(zip(image_paths, labels), total=len(image_paths)):
        image = cv2.imread(img_path)

        # cv2.imread returns None for unreadable files; drop those samples.
        if image is None:
            skipped += 1
            continue

        features, _ = extractor.detect_and_extract(image)

        if features is None:
            undetected += 1
            if skip_undetected:
                continue
            # Keep the sample but with an all-zero feature vector.
            features = torch.zeros(feature_dim)

        features_list.append(features.float())
        valid_labels.append(label)

    if skipped > 0:
        print(f"Skipped {skipped} corrupted/unreadable images")

    # Zero-vector samples carry no signal, so make their number visible.
    if undetected > 0:
        action = "dropped" if skip_undetected else "replaced by zero vectors"
        print(f"No face detected in {undetected} images ({action})")

    labels_tensor = torch.tensor(valid_labels, dtype=torch.long)

    # torch.stack rejects an empty list; return an empty (0, feature_dim) batch.
    if not features_list:
        return torch.empty((0, feature_dim)), labels_tensor

    return torch.stack(features_list), labels_tensor


# ============================================================
#                        6. Main Script
# ============================================================

def main(data_path="C:/adam/AMIT_Diploma/grad_project/archive (1)/train",
         num_epochs=15, batch_size=32, learning_rate=1e-3,
         output_path="fer_model_final.pth"):
    """Extract face features, train the classifier head and save its weights.

    Args:
        data_path: Directory with one sub-directory of ``.png`` images per
            class; the sub-directory name is the class label.
        num_epochs: Number of training epochs.
        batch_size: Batch size for both data loaders.
        learning_rate: Learning rate for the Adam optimizer.
        output_path: File the trained ``state_dict`` is written to.

    Raises:
        FileNotFoundError: If no ``.png`` images are found under
            ``data_path``.
    """
    # Sort so the seeded split below does not depend on filesystem order.
    image_paths = sorted(glob(os.path.join(data_path, "*/*.png")))
    if not image_paths:
        raise FileNotFoundError(
            f"No .png images found in class sub-directories of {data_path!r}"
        )

    class_names = [os.path.basename(os.path.dirname(p)) for p in image_paths]

    # Convert class names → integer IDs
    class_to_idx = {c: i for i, c in enumerate(sorted(set(class_names)))}
    labels = [class_to_idx[c] for c in class_names]

    # Stratified 80/20 split
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=42)
    train_idx, val_idx = next(sss.split(image_paths, labels))

    train_paths = [image_paths[i] for i in train_idx]
    val_paths = [image_paths[i] for i in val_idx]

    train_labels = [labels[i] for i in train_idx]
    val_labels = [labels[i] for i in val_idx]

    # Initialize extractor once
    print("Initializing RetinaFace extractor...")
    extractor = RetinaFaceFeatureExtractor()

    # Pre-extract all features
    print("\nExtracting training features...")
    train_features, train_labels_tensor = extract_all_features(train_paths, train_labels, extractor)

    print("\nExtracting validation features...")
    val_features, val_labels_tensor = extract_all_features(val_paths, val_labels, extractor)

    # Create datasets with pre-extracted features
    train_dataset = PreExtractedFeatureDataset(train_features, train_labels_tensor)
    val_dataset = PreExtractedFeatureDataset(val_features, val_labels_tensor)

    # Features are already tensors in memory, so loading runs in the main
    # process (num_workers=0).
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=0)

    # Model
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nUsing device: {device}")

    # Size the output layer from the classes actually present in the data.
    model = EfficientNetFeatureClassifier(
        feature_dim=512, num_classes=len(class_to_idx)
    ).to(device)

    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)

    # Training
    print("\nStarting training...")
    for epoch in range(num_epochs):
        print(f"\nEpoch {epoch + 1}/{num_epochs}")

        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)

        print(f"Train Loss {train_loss:.4f}  |  Acc {train_acc:.4f}")
        print(f"Val   Loss {val_loss:.4f}  |  Acc {val_acc:.4f}")

    # Save the trained model
    torch.save(model.state_dict(), output_path)
    print(f"\nModel saved as '{output_path}'")


# Run the full extract/train/save pipeline only when executed as a script
# (or notebook cell), not when this module is imported.
if __name__ == "__main__":
    main()

Initializing RetinaFace extractor...
download_path: C:/adam/AMIT_Diploma/grad_project/insightface_models\models\buffalo_l
Downloading C:/adam/AMIT_Diploma/grad_project/insightface_models\models\buffalo_l.zip from https://github.com/deepinsight/insightface/releases/download/v0.7/buffalo_l.zip...


100%|██████████| 281857/281857 [00:52<00:00, 5347.75KB/s]


Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}
find model: C:/adam/AMIT_Diploma/grad_project/insightface_models\models\buffalo_l\1k3d68.onnx landmark_3d_68 ['None', 3, 192, 192] 0.0 1.0
Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}
find model: C:/adam/AMIT_Diploma/grad_project/insightface_models\models\buffalo_l\2d106det.onnx landmark_2d_106 ['None', 3, 192, 192] 0.0 1.0
Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}
find model: C:/adam/AMIT_Diploma/grad_project/insightface_models\models\buffalo_l\det_10g.onnx detection [1, 3, '?', '?'] 127.5 128.0
Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}
find model: C:/adam/AMIT_Diploma/grad_project/insightface_models\models\buffalo_l\genderage.onnx genderage ['None', 3, 96, 96] 0.0 1.0
Applied providers: ['CPUExecutionProvider'], with options: {'CPUExecutionProvider': {}}
find model

100%|██████████| 22967/22967 [29:42<00:00, 12.89it/s]  



Extracting validation features...
Pre-extracting features from all images...


100%|██████████| 5742/5742 [07:24<00:00, 12.91it/s]



Using device: cpu

Starting training...

Epoch 1/15


                                                            

Train Loss 1.8195  |  Acc 0.2527
Val   Loss 1.8119  |  Acc 0.2534

Epoch 2/15


                                                               

Train Loss 1.8041  |  Acc 0.2587
Val   Loss 1.8142  |  Acc 0.2532

Epoch 3/15


                                                            

Train Loss 1.7927  |  Acc 0.2625
Val   Loss 1.8355  |  Acc 0.2546

Epoch 4/15


                                                               

Train Loss 1.7933  |  Acc 0.2632
Val   Loss 1.8459  |  Acc 0.2539

Epoch 5/15


                                                               

Train Loss 1.7892  |  Acc 0.2638
Val   Loss 1.8564  |  Acc 0.2536

Epoch 6/15


                                                            

Train Loss 1.7902  |  Acc 0.2642
Val   Loss 1.8680  |  Acc 0.2539

Epoch 7/15


                                                               

Train Loss 1.7885  |  Acc 0.2646
Val   Loss 1.8887  |  Acc 0.2544

Epoch 8/15


                                                            

Train Loss 1.7851  |  Acc 0.2650
Val   Loss 1.8864  |  Acc 0.2539

Epoch 9/15


                                                               

Train Loss 1.7826  |  Acc 0.2657
Val   Loss 1.8966  |  Acc 0.2541

Epoch 10/15


                                                               

Train Loss 1.7839  |  Acc 0.2657
Val   Loss 1.9101  |  Acc 0.2536

Epoch 11/15


                                                               

Train Loss 1.7877  |  Acc 0.2653
Val   Loss 1.9036  |  Acc 0.2541

Epoch 12/15


                                                               

Train Loss 1.7862  |  Acc 0.2652
Val   Loss 1.9270  |  Acc 0.2534

Epoch 13/15


                                                               

Train Loss 1.7866  |  Acc 0.2654
Val   Loss 1.9319  |  Acc 0.2541

Epoch 14/15


                                                            

Train Loss 1.7911  |  Acc 0.2651
Val   Loss 1.9397  |  Acc 0.2532

Epoch 15/15


                                                            

Train Loss 1.7826  |  Acc 0.2658
Val   Loss 1.9482  |  Acc 0.2539

Model saved as 'fer_model_final.pth'
