# 01 - Data Exploration

This notebook explores the FER-2013 dataset for facial emotion recognition.

## Contents
1. Load and inspect dataset
2. Visualize label distribution
3. Display sample images
4. Analyze image statistics

In [None]:
import sys
sys.path.append('..')

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path

from src.data import FERDataset, get_test_transforms
from src.utils.visualization import plot_label_distribution

%matplotlib inline
plt.style.use('seaborn-v0_8-whitegrid')

## 1. Load Dataset

In [None]:
# Location of the FER-2013 data, relative to this notebook
DATA_DIR = '../data'

# Build the train and test splits; each one gets its own transform object
# from get_test_transforms().
train_dataset = FERDataset(
    DATA_DIR,
    split='train',
    transform=get_test_transforms(),
)
test_dataset = FERDataset(
    DATA_DIR,
    split='test',
    transform=get_test_transforms(),
)

# Report how many samples each split contains
print(
    f'Training samples: {len(train_dataset)}',
    f'Test samples: {len(test_dataset)}',
    sep='\n',
)

## 2. Label Distribution

In [None]:
# Label distribution of the training split (iterated below as emotion -> count)
train_dist = train_dataset.get_label_distribution()

# Text summary: a header followed by one indented line per emotion
print('Training label distribution:')
for emotion_name, sample_count in train_dist.items():
    print(f'  {emotion_name}: {sample_count}')

# Visualize the same distribution
plot_label_distribution(
    train_dist,
    title='FER-2013 Training Set Distribution',
)

## 3. Sample Images

In [None]:
from src.data.dataset import EMOTION_LABELS

# Grid layout: one column per emotion, one row per example image.
SAMPLES_PER_EMOTION = 2
n_emotions = len(EMOTION_LABELS)

# Single pass over the labels: keep the first SAMPLES_PER_EMOTION dataset
# indices of every emotion instead of building the complete index list once
# per emotion.
sample_indices = {emotion_idx: [] for emotion_idx in range(n_emotions)}
n_missing = SAMPLES_PER_EMOTION * n_emotions
for idx, label in enumerate(train_dataset.labels):
    # int() lets plain ints, NumPy integers and 0-d tensors all act as dict keys
    picked = sample_indices.get(int(label))
    if picked is None or len(picked) >= SAMPLES_PER_EMOTION:
        continue
    picked.append(idx)
    n_missing -= 1
    if n_missing == 0:
        break  # every subplot has an image; no need to scan the rest

# squeeze=False keeps `axes` two-dimensional for any grid size, so the
# axes[row, col] indexing below is always valid.
fig, axes = plt.subplots(
    SAMPLES_PER_EMOTION, n_emotions, figsize=(16, 5), squeeze=False
)

for emotion_idx in range(n_emotions):
    emotion_samples = sample_indices[emotion_idx]

    for row in range(SAMPLES_PER_EMOTION):
        ax = axes[row, emotion_idx]
        ax.axis('off')
        if row == 0:
            ax.set_title(EMOTION_LABELS[emotion_idx])

        # An emotion with too few samples leaves its cell blank rather than
        # raising IndexError.
        if row >= len(emotion_samples):
            continue

        image, _ = train_dataset[emotion_samples[row]]
        # squeeze() drops size-1 dimensions so imshow receives a 2-D array
        ax.imshow(image.squeeze(), cmap='gray')

fig.suptitle('Sample Images for Each Emotion', fontsize=14)
fig.tight_layout()
plt.show()

## 4. Image Statistics

In [None]:
# Per-channel mean and standard deviation over every pixel of the training set.
#
# The std is derived from the accumulated sum and sum of squares of all pixels
# (Var[x] = E[x^2] - E[x]^2). Averaging the per-image standard deviations
# instead would ignore the variation *between* images and under-report the
# dataset std.
#
# NOTE(review): train_dataset is built with get_test_transforms(); if that
# transform already normalizes the images, the numbers below describe the
# normalized tensors rather than the raw pixels - confirm before reusing them
# as normalization constants.
import torch
from torch.utils.data import DataLoader

loader = DataLoader(train_dataset, batch_size=100, shuffle=False)

pixel_sum = 0.
pixel_sq_sum = 0.
n_pixels = 0    # pixels accumulated per channel
n_samples = 0   # images seen

for images, _ in loader:
    batch_samples = images.size(0)
    # Flatten everything after the channel dimension: (batch, channels, pixels).
    # float64 keeps rounding error small while summing millions of values.
    images = images.reshape(batch_samples, images.size(1), -1).double()
    pixel_sum += images.sum(dim=(0, 2))
    pixel_sq_sum += (images ** 2).sum(dim=(0, 2))
    n_pixels += batch_samples * images.size(2)
    n_samples += batch_samples

if n_pixels == 0:
    raise ValueError('Cannot compute image statistics: train_dataset is empty')

channel_mean = pixel_sum / n_pixels
# clamp() guards against a tiny negative variance caused by rounding
channel_var = (pixel_sq_sum / n_pixels - channel_mean ** 2).clamp(min=0)

# Expose the results as float32 tensors with one entry per channel
mean = channel_mean.float()
std = channel_var.sqrt().float()

if mean.numel() == 1:
    print(f'Dataset mean: {mean.item():.4f}')
    print(f'Dataset std: {std.item():.4f}')
else:
    # Multi-channel images: report one value per channel
    print(f'Dataset mean: {[round(value, 4) for value in mean.tolist()]}')
    print(f'Dataset std: {[round(value, 4) for value in std.tolist()]}')