# Data Exploration

This notebook explores the face recognition datasets used in the project:
- MS1MV2 training dataset
- Evaluation datasets (LFW, CFP-FP, AgeDB-30, etc.)

## Objectives
1. Load and visualize datasets
2. Understand data structure and statistics
3. Prepare data for training and evaluation


In [None]:
import sys
import os
sys.path.append(os.path.join(os.path.dirname(os.getcwd()), 'src'))

import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
import yaml

from data.dataset import MS1MV2Dataset, LFWDataset
from data.dataloader import get_ms1mv2_dataloader, get_lfw_dataloader

# Load the project configuration. The path is relative to the notebook's
# working directory (the same assumption the sys.path tweak above makes).
# The encoding is pinned to UTF-8 so parsing does not depend on the platform's
# locale default (e.g. cp1252 on Windows), which could garble non-ASCII paths.
with open('../config.yaml', 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

print("Configuration loaded successfully")
# Report which device torch would use; nothing is moved to it in this cell.
print(f"Device: {'cuda' if torch.cuda.is_available() else 'cpu'}")


## 1. MS1MV2 Dataset Exploration

In [None]:
# Load the MS1MV2 training set and preview its first few samples.
ms1mv2_path = config['data']['ms1mv2']['path']

try:
    dataset = MS1MV2Dataset(ms1mv2_path, is_training=True)
    print("MS1MV2 Dataset loaded successfully")
    print(f"Total samples: {len(dataset)}")
    print(f"Total identities: {len(dataset.identities)}")

    # Preview up to rows*cols samples. Capping at len(dataset) keeps a small or
    # partially downloaded dataset from indexing past its end, which the
    # handler below would otherwise misreport as a loading failure.
    preview_rows, preview_cols = 2, 5
    num_preview = min(preview_rows * preview_cols, len(dataset))

    fig, axes = plt.subplots(preview_rows, preview_cols, figsize=(15, 6))
    for ax in axes.flat:
        ax.axis('off')  # also blanks grid cells that end up without an image
    for i in range(num_preview):
        img, label = dataset[i]
        ax = axes[i // preview_cols, i % preview_cols]
        # Denormalize (assumes mean=0.5/std=0.5 normalization and a CHW tensor
        # -- TODO confirm against MS1MV2Dataset), then clamp so float rounding
        # cannot push values outside the [0, 1] range imshow expects.
        ax.imshow((img.permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1))
        ax.set_title(f"Identity: {label}")
    plt.tight_layout()
    plt.show()

except Exception as e:
    # Best-effort cell: keep the notebook running when the dataset is absent.
    print(f"Error loading MS1MV2 dataset: {e}")
    print("Please ensure the dataset is downloaded and path is correct in config.yaml")


## 2. Evaluation Datasets

In [None]:
# Explore the LFW verification pairs and preview the first few of them.
lfw_path = config['data']['evaluation']['lfw']
lfw_pairs_file = os.path.join(lfw_path, 'pairs.txt')

try:
    if os.path.exists(lfw_path):
        # NOTE(review): this rebinds `dataset` (previously the MS1MV2 dataset);
        # the name is kept as-is in case later cells rely on it.
        dataset = LFWDataset(lfw_path, lfw_pairs_file)
        print("LFW Dataset loaded successfully")
        print(f"Total pairs: {len(dataset)}")

        # Show up to 6 pairs in a 3x4 grid: each pair takes two adjacent
        # columns, so every row holds two pairs. Indexing the grid by pair
        # number alone (axes[i, 0]) would run past the 3 available rows.
        grid_rows, grid_cols = 3, 4
        pairs_per_row = grid_cols // 2
        num_pairs = min(grid_rows * pairs_per_row, len(dataset))

        fig, axes = plt.subplots(grid_rows, grid_cols, figsize=(12, 9))
        for ax in axes.flat:
            ax.axis('off')  # also blanks cells left unused by a short dataset
        for i in range(num_pairs):
            img1, img2, is_same = dataset[i]
            row, slot = divmod(i, pairs_per_row)
            left_ax = axes[row, 2 * slot]
            right_ax = axes[row, 2 * slot + 1]
            # Denormalize (assumes mean=0.5/std=0.5 normalization and CHW
            # tensors -- TODO confirm against LFWDataset) and clamp to the
            # [0, 1] range imshow expects for float images.
            left_ax.imshow((img1.permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1))
            left_ax.set_title("Image 1")
            right_ax.imshow((img2.permute(1, 2, 0) * 0.5 + 0.5).clamp(0, 1))
            right_ax.set_title(f"Image 2\nSame: {is_same}")
        plt.tight_layout()
        plt.show()
    else:
        print(f"LFW dataset not found at {lfw_path}")
except Exception as e:
    # Best-effort cell: report the problem instead of stopping the notebook.
    print(f"Error loading LFW dataset: {e}")
