# Model Loading and Testing

This notebook demonstrates how to load pretrained face recognition models:
- InsightFace models (buffalo_l)
- iResNet models

## Objectives
1. Load pretrained models
2. Test model inference
3. Extract face embeddings and compare them with cosine similarity


In [None]:
import sys
import os
sys.path.append(os.path.join(os.path.dirname(os.getcwd()), 'src'))

import torch
import numpy as np
import matplotlib.pyplot as plt
import yaml

from utils.model_loader import load_insightface_model, load_model_from_config
from data.dataset import MS1MV2Dataset
from torch.utils.data import DataLoader

# Load configuration and pick the compute device.
# The notebook is expected to run from a sub-directory of the project root
# (same assumption as the sys.path setup above); config.yaml lives in that root.
PROJECT_ROOT = os.path.dirname(os.getcwd())
CONFIG_PATH = os.path.join(PROJECT_ROOT, 'config.yaml')

# Explicit encoding: the platform default is locale-dependent (e.g. cp1252 on Windows).
with open(CONFIG_PATH, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# yaml.safe_load returns None for an empty file; fail here with a clear message
# instead of with an opaque TypeError swallowed by a later cell's try/except.
if not isinstance(config, dict):
    raise ValueError(
        f"{CONFIG_PATH} must contain a YAML mapping, got {type(config).__name__}"
    )

# Seed the RNGs used later (dummy input, DataLoader shuffling) so that
# Restart & Run All reproduces the same outputs.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")


## 1. Load InsightFace Model (buffalo_l)

In [None]:
# Optionally load an InsightFace FaceAnalysis app ('buffalo_l' unless config
# says otherwise). This cell is best-effort: any error (missing config key,
# model download/initialisation failure) is printed and the notebook moves on
# to the iResNet section below.
try:
    model_cfg = config['model']
    use_insightface = model_cfg.get('use_insightface', False)

    if not use_insightface:
        print("InsightFace not configured. Using iResNet models instead.")
    else:
        model_name = model_cfg.get('insightface_model', 'buffalo_l')
        print(f"Loading InsightFace model: {model_name}")

        app = load_insightface_model(model_name)
        print("InsightFace model loaded successfully")

        # InsightFace is used through a FaceAnalysis app rather than the
        # iResNet interface, so iResNet is the recommended path for linearization.
        print("\nNote: For linearization, we recommend using iResNet models")
        print("Set use_insightface: false in config.yaml to use iResNet")
except Exception as e:
    print(f"Error loading InsightFace model: {e}")
    print("Falling back to iResNet models")


## 2. Load iResNet Model

In [None]:
# Load the model described in config.yaml, then smoke-test it on a random input.
# Loading and inference are guarded separately so each failure is reported
# with an accurate message.
# NOTE(review): 112x112 is the input size this notebook feeds the model;
# confirm it matches the configured architecture.
INPUT_SIZE = 112

try:
    model = load_model_from_config(config)
    model = model.to(device)
    model.eval()
    print("Model loaded successfully")
    print(f"Model type: {type(model)}")
except Exception as e:
    # Discard any model left over from an earlier run of this cell so later
    # cells cannot silently use stale weights after a failed reload.
    model = None
    print(f"Error loading model: {e}")
    print("Please check model configuration in config.yaml")

if model is not None:
    try:
        dummy_input = torch.randn(1, 3, INPUT_SIZE, INPUT_SIZE, device=device)
        with torch.no_grad():
            embeddings = model.extract_features(dummy_input)
        print(f"Embedding shape: {embeddings.shape}")
        # Batch size is 1, so the per-row norm reduces to a single scalar.
        print(f"Embedding norm: {embeddings.norm(dim=1).item():.4f}")
    except Exception as e:
        print(f"Error running inference on dummy input: {e}")


## 3. Test on Sample Images

In [None]:
# Run the model on a small random batch from MS1MV2: show each face above the
# first few components of its embedding, then print pairwise cosine similarities.
BATCH_SIZE = 4       # number of faces to display
N_DIMS_SHOWN = 10    # embedding components plotted per face
SAMPLE_SEED = 0      # fixes which faces are drawn, so re-runs show the same batch

try:
    ms1mv2_path = config['data']['ms1mv2']['path']
    dataset = MS1MV2Dataset(ms1mv2_path, is_training=False)
    dataloader = DataLoader(
        dataset,
        batch_size=BATCH_SIZE,
        shuffle=True,
        generator=torch.Generator().manual_seed(SAMPLE_SEED),
    )

    # Take one batch; it may hold fewer than BATCH_SIZE faces for a tiny dataset.
    batch = next(iter(dataloader), None)
    if batch is None:
        raise ValueError(f"No images found in {ms1mv2_path}")
    images, labels = batch
    images = images.to(device)
    n_faces = images.shape[0]

    # Extract embeddings
    with torch.no_grad():
        embeddings = model.extract_features(images)

    emb_np = embeddings.cpu().numpy()
    n_dims = min(N_DIMS_SHOWN, emb_np.shape[1])
    shown = emb_np[:, :n_dims]
    # Symmetric, data-driven y-range shared by all bar plots: embeddings are not
    # necessarily unit-norm, so a fixed (-1, 1) range could clip bars.
    # Fall back to 1.0 if every shown value is zero (a zero-height range is invalid).
    y_lim = float(np.abs(shown).max()) or 1.0

    # squeeze=False keeps `axes` 2-D even when only one face is shown.
    fig, axes = plt.subplots(2, n_faces, figsize=(4 * n_faces, 8), squeeze=False)
    for i in range(n_faces):
        # NOTE(review): assumes images were normalised with mean=0.5, std=0.5
        # (undone by *0.5 + 0.5); confirm against MS1MV2Dataset's transforms.
        img = images[i].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5
        axes[0, i].imshow(np.clip(img, 0.0, 1.0))
        axes[0, i].set_title(f"Identity: {labels[i].item()}")
        axes[0, i].axis('off')

        axes[1, i].bar(range(n_dims), shown[i])
        axes[1, i].set_title(f"Embedding (first {n_dims} dims)")
        axes[1, i].set_ylim(-y_lim, y_lim)

    plt.tight_layout()
    plt.show()

    # Cosine similarity between every pair of faces in the batch.
    embeddings_norm = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    similarity_matrix = embeddings_norm @ embeddings_norm.T
    print("\nSimilarity matrix:")
    print(np.round(similarity_matrix.cpu().numpy(), 3))

except Exception as e:
    print(f"Error testing model: {e}")
