# **Real-Time Facial Expression Detection using Few-Shot Learning**

Shri Hari S - PES1UG22AM154

Venkat Subramanian - PES1UG22AM188

Vishwanath Sridhar - PES1UG22AM194

Vismaya Vadana - PES1UG22AM195

Kaggle Data Source to Import FER2013 Dataset

In [None]:
# Download the FER2013 dataset from Kaggle through kagglehub.
# NOTE: the auto-generated notice suggested deleting this cell after running it,
# but the next cell reads `msambare_fer2013_path`, so keep it for a clean
# Restart & Run All.
# NOTE: this notebook environment differs from Kaggle's Python environment, so
# libraries used by the notebook may be missing.
import kagglehub
msambare_fer2013_path = kagglehub.dataset_download('msambare/fer2013')

print('Data source import complete.')


Downloading from https://www.kaggle.com/api/v1/datasets/download/msambare/fer2013?dataset_version_number=1...


100%|██████████| 60.3M/60.3M [00:00<00:00, 76.9MB/s]

Extracting files...





Data source import complete.


In [None]:
import os

# Resolve the dataset root and its train/test split folders
dataset_dir = msambare_fer2013_path
train_dir = os.path.join(dataset_dir, 'train')
test_dir = os.path.join(dataset_dir, 'test')

# Show what each folder contains: the root first, then the emotion
# sub-directories of the training and testing splits
for heading, folder in (
    ("Contents of the dataset directory:", dataset_dir),
    ("Train directory contents:", train_dir),
    ("Test directory contents:", test_dir),
):
    print(heading)
    print(os.listdir(folder))


Contents of the dataset directory:
['train', 'test']
Train directory contents:
['sad', 'angry', 'fear', 'disgust', 'surprise', 'neutral', 'happy']
Test directory contents:
['sad', 'angry', 'fear', 'disgust', 'surprise', 'neutral', 'happy']


**Installing EasyFSL**

The `easyfsl` library simplifies the implementation of few-shot learning pipelines by providing ready-to-use modules for methods such as Prototypical Networks, Matching Networks, and Relation Networks. It also includes utilities for dataset management, task sampling, and model evaluation in few-shot learning scenarios. Install it with `pip install easyfsl` before running the import cell below.


Import Statements

In [1]:
import os
import torch
from torch import nn, optim
from torch.utils.data import DataLoader, Subset
from torchvision import transforms
from torchvision.datasets import ImageFolder
from torchvision.models import resnet18
from tqdm import tqdm
from easyfsl.samplers import TaskSampler
from easyfsl.utils import plot_images, sliding_average
import cv2
import torch.nn.functional as F
from PIL import Image
import random
import numpy as np

# Dataset locations. A local copy of FER2013 in the working directory is
# preferred (the original behaviour). If it is missing and the kagglehub
# download cell has been run, fall back to the downloaded copy so the paths set
# by that cell are not silently replaced by folders that do not exist.
train_dir = 'train'
test_dir = 'test'
_downloaded_root = globals().get('msambare_fer2013_path')
if _downloaded_root and not (os.path.isdir(train_dir) and os.path.isdir(test_dir)):
    train_dir = os.path.join(_downloaded_root, 'train')
    test_dir = os.path.join(_downloaded_root, 'test')

Defining the constants for training the Meta-Learning Model

In [2]:
# Constants shared by the transforms, the task samplers and the training loop
image_size = 48  # Side length (pixels) every image is resized to; set to match the FER2013 image size
N_WAY = 5  # Number of classes in a task
N_SHOT = 10  # Number of images per class in the support set
N_QUERY = 15  # Number of images per class in the query set
N_EVALUATION_TASKS = 100  # Number of tasks drawn by the test sampler
N_TRAINING_EPISODES = 60000  # Number of tasks drawn by the train sampler; also the scheduler's T_max

Data Augmentation and Preprocessing

In [3]:
# Steps shared by the training and evaluation pipelines
_resize = transforms.Resize((image_size, image_size))
_to_normalized_tensor = [
    transforms.ToTensor(),  # Convert to tensor first
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),  # Then normalize
]

# Training pipeline: resize, random augmentation, then tensor conversion + normalisation
train_transforms = transforms.Compose([
    _resize,
    transforms.RandomHorizontalFlip(),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2),
    transforms.RandomAffine(degrees=0, translate=(0.1, 0.1)),
    *_to_normalized_tensor,
])

# Evaluation pipeline: same resize and normalisation, no augmentation
test_transforms = transforms.Compose([_resize, *_to_normalized_tensor])

In [4]:
# Load the datasets: one sub-folder per class under each root, with the
# augmenting pipeline for training data and the plain pipeline for test data
train_set = ImageFolder(root=train_dir, transform=train_transforms)
test_set = ImageFolder(root=test_dir, transform=test_transforms)

Creating Task Samplers for Training and Testing

In [5]:

def _labels_of(dataset):
    """Return a zero-argument callable listing the class index of every sample in `dataset`."""
    return lambda: [label for _, label in dataset.samples]


# TaskSampler asks the dataset for its labels through `get_labels`, so attach it
train_set.get_labels = _labels_of(train_set)
train_sampler = TaskSampler(
    train_set,
    n_way=N_WAY,
    n_shot=N_SHOT,
    n_query=N_QUERY,
    n_tasks=N_TRAINING_EPISODES,
)

test_set.get_labels = _labels_of(test_set)
test_sampler = TaskSampler(
    test_set,
    n_way=N_WAY,
    n_shot=N_SHOT,
    n_query=N_QUERY,
    n_tasks=N_EVALUATION_TASKS,
)

In [6]:
def _episodic_loader(dataset, sampler):
    """Wrap `dataset` in a DataLoader that yields one few-shot task per batch.

    The sampler picks the indices of each task and its `episodic_collate_fn`
    assembles them into the tuple consumed by the training/evaluation loops.
    """
    return DataLoader(
        dataset,
        batch_sampler=sampler,
        num_workers=2,
        pin_memory=True,
        collate_fn=sampler.episodic_collate_fn,
    )


# Episodic loaders for training and testing
train_loader = _episodic_loader(train_set, train_sampler)
test_loader = _episodic_loader(test_set, test_sampler)

**Prototypical Networks Architecture**


Implements Prototypical Networks with a backbone for feature extraction and a projection head for dimensionality reduction, computing class prototypes and classification scores based on Euclidean distances.

In [7]:
class PrototypicalNetworks(nn.Module):
    """Prototypical Network: a backbone followed by a small projection head.

    Queries are scored by their negative Euclidean distance to one prototype
    per support class, where a prototype is the mean projected embedding of
    that class's support images.
    """

    def __init__(self, backbone: nn.Module):
        """
        Args:
            backbone: feature extractor whose output is a (batch, 512) tensor
                (the projection head's first layer expects 512 input features).
        """
        super().__init__()
        self.backbone = backbone
        self.projection = nn.Sequential(
            nn.Linear(512, 256),  # Assuming ResNet18's output is 512
            nn.ReLU(),
            nn.Linear(256, 128)
        )

    def forward(self, support_images, support_labels, query_images):
        """Score every query image against the prototype of every support class.

        Args:
            support_images: batch of support inputs for the backbone.
            support_labels: 1-D tensor with one class label per support image.
            query_images: batch of query inputs for the backbone.

        Returns:
            Tensor of shape (n_query, n_classes). Column i is the negative
            Euclidean distance to the prototype of the i-th smallest label in
            `support_labels`; for labels 0..n_way-1 column i is class i.
        """
        z_support = self.projection(self.backbone(support_images))
        z_query = self.projection(self.backbone(query_images))

        # Iterate over the labels actually present (torch.unique returns them
        # sorted). Looping over range(n_way) would average an empty selection,
        # and produce NaN prototypes, whenever the labels are not 0..n_way-1.
        class_labels = torch.unique(support_labels)
        z_proto = torch.stack(
            [z_support[support_labels == label].mean(0) for label in class_labels]
        )

        # Euclidean distance from each query to each prototype
        dists = torch.cdist(z_query, z_proto)

        # Closer prototype => higher score
        scores = -dists
        return scores

**Model Initialization**

Initialize the Prototypical Networks model using a ResNet18 backbone pretrained on ImageNet, with the fully connected layer replaced by a flattening operation.

In [8]:
# Pick the compute device once; the same expression is used by the later loading cell
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize the model: ImageNet-pretrained ResNet18 whose classification layer
# is replaced by a flatten, so the backbone returns features instead of logits
convolutional_network = resnet18(weights='IMAGENET1K_V1')
convolutional_network.fc = nn.Flatten()
# Move the model to the device up front: the evaluation, training and inference
# cells all send their inputs to the GPU/`device`, and fail if the weights
# are left on the CPU.
model = PrototypicalNetworks(convolutional_network).to(device)


**Evaluation Function**

Defines a task-based evaluation process for Prototypical Networks, calculating accuracy across multiple tasks using support and query sets.


In [None]:
# Evaluation function
def evaluate_on_one_task(support_images: torch.Tensor, support_labels: torch.Tensor,
                         query_images: torch.Tensor, query_labels: torch.Tensor) -> "tuple[int, int]":
    """Classify the query set of one task with the module-level `model`.

    Inputs are moved to whatever device the model's parameters live on, so the
    function works on CPU-only machines as well as on CUDA.

    Returns:
        (number of correctly classified queries, total number of queries).
    """
    target_device = next(model.parameters()).device
    scores = model(
        support_images.to(target_device),
        support_labels.to(target_device),
        query_images.to(target_device),
    ).detach()
    # The predicted class is the column with the highest score
    predictions = scores.argmax(dim=1)
    n_correct = (predictions == query_labels.to(target_device)).sum().item()
    return n_correct, len(query_labels)

def evaluate(data_loader: DataLoader) -> None:
    """Run the module-level `model` over every task in `data_loader` and print the accuracy.

    The model is switched to eval mode and is left in eval mode afterwards.

    Args:
        data_loader: loader yielding (support_images, support_labels,
            query_images, query_labels, class_ids) tuples, one per task.
    """
    total_predictions = 0
    correct_predictions = 0
    model.eval()
    with torch.no_grad():
        for support_images, support_labels, query_images, query_labels, _ in tqdm(data_loader):
            correct, total = evaluate_on_one_task(support_images, support_labels, query_images, query_labels)
            total_predictions += total
            correct_predictions += correct

    # An empty loader would otherwise divide by zero below
    if total_predictions == 0:
        print(f"Model tested on {len(data_loader)} tasks. No query samples were evaluated.")
        return

    print(f"Model tested on {len(data_loader)} tasks. Accuracy: {(100 * correct_predictions / total_predictions):.2f}%")

# Evaluate the model on the test tasks before any few-shot training (baseline)
evaluate(test_loader)

100%|██████████| 100/100 [00:11<00:00,  8.83it/s]

Model tested on 100 tasks. Accuracy: 29.97%





**Training and Validation Loop**

Implements the training loop for Prototypical Networks with early stopping, a cosine annealing learning rate scheduler, and periodic validation for model performance monitoring and checkpoint saving.


In [None]:
# Training setup
criterion = nn.CrossEntropyLoss()
# NOTE: no optimizer is created here. The AdamW optimizer defined below,
# together with its scheduler, is the one `fit` uses; the Adam instance that
# used to be built here was overwritten before the first training step.


def fit(support_images: torch.Tensor, support_labels: torch.Tensor,
        query_images: torch.Tensor, query_labels: torch.Tensor) -> float:
    """Run one optimisation step on a single task and return its loss.

    Uses the module-level `model`, `criterion` and `optimizer` (looked up at
    call time). Inputs are moved to the device holding the model's parameters.
    """
    target_device = next(model.parameters()).device
    optimizer.zero_grad()
    classification_scores = model(
        support_images.to(target_device),
        support_labels.to(target_device),
        query_images.to(target_device),
    )
    loss = criterion(classification_scores, query_labels.to(target_device))
    loss.backward()
    optimizer.step()
    return loss.item()


# Complete training loop with periodic validation and early stopping.
# Early-stopping state: `patience` is the number of consecutive validation
# rounds without a new best accuracy that are tolerated before stopping.
best_acc = 0
patience = 5
patience_counter = 0

# Optimizer and scheduler. This assignment rebinds the module-level `optimizer`
# that `fit` reads, so AdamW is the optimizer actually used for training.
optimizer = optim.AdamW(model.parameters(), lr=0.0003, weight_decay=0.01)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer,
                                                T_max=N_TRAINING_EPISODES,
                                                eta_min=1e-6)

# Training loop
log_update_frequency = 10  # refresh the progress-bar statistics every 10 episodes
all_loss = []  # loss of every episode, in order
model.train()

with tqdm(enumerate(train_loader), total=len(train_loader)) as tqdm_train:
    for episode_index, (support_images, support_labels, query_images, query_labels, _) in tqdm_train:
        # Training step
        loss_value = fit(support_images, support_labels, query_images, query_labels)
        all_loss.append(loss_value)

        # Update learning rate (once per episode, matching T_max above)
        scheduler.step()

        # Update progress bar
        if episode_index % log_update_frequency == 0:
            tqdm_train.set_postfix(
                loss=sliding_average(all_loss, log_update_frequency),
                lr=scheduler.get_last_lr()[0]
            )

        # Validation step: every 1000 episodes, including episode 0
        if episode_index % 1000 == 0:
            model.eval()
            current_acc = 0
            n_tasks = 0

            # Compute validation accuracy.
            # NOTE(review): validation iterates `test_loader`, so checkpoint
            # selection and early stopping are tuned on the test set; accuracy
            # later reported on that loader is therefore optimistic.
            with torch.no_grad():
                for val_support_images, val_support_labels, val_query_images, val_query_labels, _ in test_loader:
                    correct, total = evaluate_on_one_task(
                        val_support_images,
                        val_support_labels,
                        val_query_images,
                        val_query_labels
                    )
                    # `current_acc` accumulates correct predictions and
                    # `n_tasks` accumulates query counts until the division below
                    current_acc += correct
                    n_tasks += total

            current_acc = (100 * current_acc / n_tasks)
            print(f"\nValidation accuracy at episode {episode_index}: {current_acc:.2f}%")

            # Early stopping check: checkpoint on improvement, otherwise count a miss
            if current_acc > best_acc:
                best_acc = current_acc
                torch.save({
                    'episode': episode_index,
                    'model_state_dict': model.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'accuracy': best_acc,
                }, 'best_model.pth')
                patience_counter = 0
                print(f"New best accuracy: {best_acc:.2f}%")
            else:
                patience_counter += 1

            if patience_counter >= patience:
                print("Early stopping triggered!")
                # Leaving here skips model.train() below, so the model stays in eval mode
                break

            model.train()

# Load the best checkpoint written by the training loop above.
# weights_only=True restricts unpickling to tensors and plain Python
# containers/scalars, which is all this checkpoint holds; it avoids executing
# arbitrary pickled code and silences the torch.load FutureWarning.
checkpoint = torch.load('best_model.pth', weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])
print(f"\nFinal evaluation with best model (acc: {checkpoint['accuracy']:.2f}%):")

  0%|          | 0/60000 [00:00<?, ?it/s, loss=1.61, lr=0.0003]


Validation accuracy at episode 0: 27.87%


  0%|          | 2/60000 [00:08<55:42:13,  3.34s/it, loss=1.61, lr=0.0003] 

New best accuracy: 27.87%


  2%|▏         | 1000/60000 [02:36<2:12:20,  7.43it/s, loss=0.997, lr=0.0003]


Validation accuracy at episode 1000: 57.36%


  2%|▏         | 1002/60000 [02:44<22:56:50,  1.40s/it, loss=0.997, lr=0.0003]

New best accuracy: 57.36%


  3%|▎         | 1999/60000 [05:11<2:03:38,  7.82it/s, loss=0.885, lr=0.000299]


Validation accuracy at episode 2000: 61.25%


  3%|▎         | 2002/60000 [05:18<16:56:29,  1.05s/it, loss=0.885, lr=0.000299]

New best accuracy: 61.25%


  5%|▌         | 3000/60000 [07:46<2:02:24,  7.76it/s, loss=0.774, lr=0.000298]


Validation accuracy at episode 3000: 65.59%


  5%|▌         | 3002/60000 [07:53<18:25:22,  1.16s/it, loss=0.774, lr=0.000298]

New best accuracy: 65.59%


  7%|▋         | 3999/60000 [10:22<1:57:09,  7.97it/s, loss=0.677, lr=0.000297]


Validation accuracy at episode 4000: 67.27%


  7%|▋         | 4002/60000 [10:30<15:42:17,  1.01s/it, loss=0.677, lr=0.000297]

New best accuracy: 67.27%


  8%|▊         | 4999/60000 [12:59<2:01:07,  7.57it/s, loss=0.763, lr=0.000295]


Validation accuracy at episode 5000: 68.20%


  8%|▊         | 5002/60000 [13:06<18:33:48,  1.22s/it, loss=0.763, lr=0.000295]

New best accuracy: 68.20%


 10%|█         | 6002/60000 [15:45<17:01:18,  1.13s/it, loss=0.748, lr=0.000293]


Validation accuracy at episode 6000: 67.49%


 12%|█▏        | 6999/60000 [18:14<1:55:28,  7.65it/s, loss=0.612, lr=0.00029]


Validation accuracy at episode 7000: 68.31%


 12%|█▏        | 7002/60000 [18:23<19:18:54,  1.31s/it, loss=0.612, lr=0.00029]

New best accuracy: 68.31%


 13%|█▎        | 8002/60000 [21:00<19:05:51,  1.32s/it, loss=0.598, lr=0.000287]


Validation accuracy at episode 8000: 68.31%


 15%|█▌        | 9000/60000 [23:29<1:47:41,  7.89it/s, loss=0.573, lr=0.000284]


Validation accuracy at episode 9000: 68.77%


 15%|█▌        | 9002/60000 [23:38<18:21:01,  1.30s/it, loss=0.573, lr=0.000284]

New best accuracy: 68.77%


 17%|█▋        | 10003/60000 [26:14<15:28:36,  1.11s/it, loss=0.536, lr=0.00028]


Validation accuracy at episode 10000: 68.13%


 18%|█▊        | 11003/60000 [28:52<13:51:34,  1.02s/it, loss=0.492, lr=0.000276]


Validation accuracy at episode 11000: 67.69%


 20%|█▉        | 11999/60000 [31:22<1:41:15,  7.90it/s, loss=0.512, lr=0.000271]


Validation accuracy at episode 12000: 69.37%


 20%|██        | 12002/60000 [31:31<16:24:08,  1.23s/it, loss=0.512, lr=0.000271]

New best accuracy: 69.37%


 22%|██▏       | 13003/60000 [34:09<15:05:34,  1.16s/it, loss=0.433, lr=0.000267]


Validation accuracy at episode 13000: 68.23%


 23%|██▎       | 13999/60000 [36:39<1:34:11,  8.14it/s, loss=0.414, lr=0.000262]


Validation accuracy at episode 14000: 69.49%


 23%|██▎       | 14003/60000 [36:48<13:35:17,  1.06s/it, loss=0.414, lr=0.000262]

New best accuracy: 69.49%


 25%|██▍       | 14999/60000 [39:19<2:18:23,  5.42it/s, loss=0.463, lr=0.000256]


Validation accuracy at episode 15000: 69.75%


 25%|██▌       | 15002/60000 [39:28<14:56:14,  1.20s/it, loss=0.463, lr=0.000256]

New best accuracy: 69.75%


 27%|██▋       | 16003/60000 [42:05<13:42:12,  1.12s/it, loss=0.427, lr=0.000251]


Validation accuracy at episode 16000: 67.99%


 28%|██▊       | 17002/60000 [44:43<17:47:14,  1.49s/it, loss=0.336, lr=0.000245]


Validation accuracy at episode 17000: 69.57%


 30%|██▉       | 17999/60000 [47:12<1:39:04,  7.07it/s, loss=0.328, lr=0.000238]


Validation accuracy at episode 18000: 70.09%


 30%|███       | 18002/60000 [47:21<13:57:35,  1.20s/it, loss=0.328, lr=0.000238]

New best accuracy: 70.09%


 32%|███▏      | 19003/60000 [49:58<11:09:17,  1.02it/s, loss=0.281, lr=0.000232]


Validation accuracy at episode 19000: 69.77%


 33%|███▎      | 20003/60000 [52:33<12:36:11,  1.13s/it, loss=0.288, lr=0.000225]


Validation accuracy at episode 20000: 67.75%


 35%|███▌      | 21003/60000 [55:10<14:33:15,  1.34s/it, loss=0.225, lr=0.000218]


Validation accuracy at episode 21000: 68.96%


 37%|███▋      | 22001/60000 [57:46<14:36:14,  1.38s/it, loss=0.278, lr=0.000211]


Validation accuracy at episode 22000: 70.03%


 38%|███▊      | 23000/60000 [1:00:13<1:19:51,  7.72it/s, loss=0.201, lr=0.000204]


Validation accuracy at episode 23000: 71.29%


 38%|███▊      | 23002/60000 [1:00:21<13:34:27,  1.32s/it, loss=0.201, lr=0.000204]

New best accuracy: 71.29%


 40%|████      | 24002/60000 [1:02:58<11:17:27,  1.13s/it, loss=0.148, lr=0.000197]


Validation accuracy at episode 24000: 69.89%


 42%|████▏     | 25002/60000 [1:05:32<9:49:51,  1.01s/it, loss=0.236, lr=0.000189] 


Validation accuracy at episode 25000: 69.83%


 43%|████▎     | 26002/60000 [1:08:06<8:04:11,  1.17it/s, loss=0.168, lr=0.000182]


Validation accuracy at episode 26000: 70.65%


 45%|████▌     | 27003/60000 [1:10:42<8:03:46,  1.14it/s, loss=0.144, lr=0.000174] 


Validation accuracy at episode 27000: 69.29%


 47%|████▋     | 28000/60000 [1:13:17<1:23:45,  6.37it/s, loss=0.137, lr=0.000166]


Validation accuracy at episode 28000: 69.69%
Early stopping triggered!

Final evaluation with best model (acc: 71.29%):



  checkpoint = torch.load('best_model.pth')


In [None]:
# Evaluate the model again after the best checkpoint has been restored
evaluate(test_loader)


100%|██████████| 100/100 [00:08<00:00, 11.93it/s]

Model tested on 100 tasks. Accuracy: 69.23%





**Model Loading and Final Evaluation**

Selects the compute device (CUDA if available, otherwise CPU), loads the saved checkpoint into the model, and prints the accuracy stored in that checkpoint.


In [9]:
# Dynamically set the device to CUDA if available, otherwise CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# The training cell writes 'best_model.pth'. 'model.pth' is still preferred when
# it exists (e.g. a renamed or shipped checkpoint), so both workflows keep working.
checkpoint_path = 'model.pth' if os.path.exists('model.pth') else 'best_model.pth'

# map_location lets a GPU-trained checkpoint load on a CPU-only machine;
# weights_only=True restricts unpickling to tensors and plain Python values.
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(checkpoint['model_state_dict'])

# Inference below sends its inputs to `device`, so the weights must be there
# too. Eval mode makes BatchNorm use its stored statistics instead of the
# statistics of the current (possibly single-image) batch.
model.to(device)
model.eval()
print(f"\nFinal evaluation with best model (acc: {checkpoint['accuracy']:.2f}%):")


Final evaluation with best model (acc: 71.29%):


  checkpoint = torch.load('model.pth', map_location=device)  # Map to the appropriate device


**Live Facial Expression Recognition with Prototypical Networks**

Implements a real-time facial expression recognition system using a webcam: faces are located with a Haar Cascade detector, embedded with the trained Prototypical Network, and assigned to the nearest class prototype, with predictions smoothed over recent frames for more stable output.


In [24]:
# Load the pre-trained Haar Cascade frontal-face detector shipped with OpenCV
face_cascade = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml')

# Preprocessing for the live query face. It mirrors `test_transforms` so that
# queries and support images are embedded at the same resolution (`image_size`,
# the size used during training) rather than at 224x224.
# NOTE(review): FER2013 images are assumed to be grayscale (loaded as three
# identical channels), so the colour webcam crop is converted the same way.
# Confirm against the dataset.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=3),
    transforms.Resize((image_size, image_size)),
    transforms.ToTensor(),  # Convert to tensor
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])  # Standard normalization
])

MAX_SUPPORT_PER_CLASS = 100  # cap on support images per class


def compute_prototypes(embedder, loader, target_device):
    """Embed every image from `loader` and average the embeddings per class.

    Args:
        embedder: model exposing `backbone` and `projection` modules.
        loader: iterable of (images, labels) batches.
        target_device: device the batches are moved to before embedding.

    Returns:
        (prototypes, class_labels): `prototypes[i]` is the mean embedding of
        the images labelled `class_labels[i]`; labels are sorted ascending.
    """
    embedded, labels = [], []
    with torch.no_grad():
        for images, targets in loader:
            embedded.append(embedder.projection(embedder.backbone(images.to(target_device))))
            labels.append(targets.to(target_device))
    embedded = torch.cat(embedded)
    labels = torch.cat(labels)
    class_labels = torch.unique(labels)
    class_prototypes = torch.stack([embedded[labels == label].mean(0) for label in class_labels])
    return class_prototypes, class_labels


# Smooth prediction using a majority vote over the most recent predictions
def smooth_prediction(predictions, buffer_size):
    """Return the most frequent value among the last `buffer_size` predictions."""
    return np.bincount(predictions[-buffer_size:]).argmax()


# Build a balanced support set from the training images. The evaluation
# transforms are used so the prototypes are not perturbed by random augmentation.
support_source = ImageFolder(root=train_dir, transform=test_transforms)

class_indices = {}
for idx, (_, label) in enumerate(support_source.samples):
    class_indices.setdefault(label, []).append(idx)

# Limit to a maximum of MAX_SUPPORT_PER_CLASS images per class
balanced_indices = []
for label, indices in class_indices.items():
    balanced_indices.extend(random.sample(indices, min(MAX_SUPPORT_PER_CLASS, len(indices))))

balanced_train_set = Subset(support_source, balanced_indices)
# A dedicated loader: reusing the name `train_loader` would clobber the
# episodic training loader and break a re-run of the training cell.
support_loader = DataLoader(balanced_train_set, batch_size=64, shuffle=False)

# Inference needs the weights on `device` and BatchNorm in eval mode
model.to(device)
model.eval()

# Prototypes are computed once, from the WHOLE balanced support set, so every
# class is represented. A single random batch of 8 images usually misses
# classes, which makes the arg-min index point at the wrong class.
prototypes, prototype_labels = compute_prototypes(model, support_loader, device)

# Start video capture from the webcam
video_capture = cv2.VideoCapture(0)

if not video_capture.isOpened():
    video_capture.release()
    # Raise instead of exit(): exit() would shut down the notebook kernel
    raise RuntimeError("Error: Could not access webcam.")

# Recent predictions used for smoothing
previous_predictions = []
prediction_buffer_size = 10

try:
    while True:
        # Capture frame-by-frame
        ret, frame = video_capture.read()
        if not ret:
            print("Error: Could not read frame.")
            break

        # Convert frame to grayscale for the detector
        gray_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY)

        # Detect faces
        faces = face_cascade.detectMultiScale(gray_frame, scaleFactor=1.1, minNeighbors=5, minSize=(30, 30))

        for (x, y, w, h) in faces:
            face = frame[y:y + h, x:x + w]

            # Convert numpy.ndarray (OpenCV, BGR) to PIL.Image (RGB)
            face_image = Image.fromarray(cv2.cvtColor(face, cv2.COLOR_BGR2RGB))

            # Apply transformation, add batch dimension and move to device
            face_tensor = transform(face_image).unsqueeze(0).to(device)

            # Embed the live face and measure its Euclidean distance to every prototype
            with torch.no_grad():
                query_embedding = model.projection(model.backbone(face_tensor))
                distances = torch.cdist(query_embedding, prototypes).squeeze(0)

            # Map the nearest prototype's position back to its class label
            # (an index into `support_source.classes`)
            predicted_class = int(prototype_labels[distances.argmin()])

            # Keep only the most recent predictions so the buffer cannot grow without bound
            previous_predictions.append(predicted_class)
            del previous_predictions[:-prediction_buffer_size]

            # Majority vote over the buffered predictions
            smoothed_prediction = smooth_prediction(previous_predictions, prediction_buffer_size)

            # Display the smoothed predicted class on the live video
            cv2.putText(frame, f"Predicted Class: {smoothed_prediction}", (x, y - 10),
                        cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 2)

            # Draw rectangle around the detected face
            cv2.rectangle(frame, (x, y), (x + w, y + h), (255, 0, 0), 2)

        # Display the resulting frame
        cv2.imshow('Video - Face Detection with Model Inference', frame)

        # Break the loop if 'q' is pressed
        if cv2.waitKey(1) & 0xFF == ord('q'):
            break
finally:
    # Always release the camera and close the windows, even if inference raised
    video_capture.release()
    cv2.destroyAllWindows()
