# Prosopo Training Notebook

Train a face embedding model from scratch using ArcFace loss.

**Target:** 99%+ accuracy on LFW benchmark

## 1. Setup & Mount Drive

⚠️ **CRITICAL:** Mount Drive FIRST to ensure checkpoints survive session disconnects.

In [None]:
# Mount Google Drive for checkpoint persistence
from google.colab import drive
drive.mount('/content/drive')

# Create checkpoint directory
import os
os.makedirs('/content/drive/MyDrive/prosopo/checkpoints', exist_ok=True)
print('✅ Drive mounted and checkpoint directory created')
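
After a disconnect, the most recent checkpoint in this directory is the natural one to resume from. The helper below is a minimal sketch, not part of Prosopo: `find_latest_checkpoint` is a hypothetical name, and it assumes checkpoints are named `epoch_<N>.pth`, as the `resume_from` example in the config cell suggests. The actual file names depend on how `Trainer` writes them.

```python
import re
from pathlib import Path
from typing import Optional

# Assumed naming scheme, based on the `epoch_10.pth` example in the config comment.
_EPOCH_PATTERN = re.compile(r"^epoch_(\d+)\.pth$")


def find_latest_checkpoint(checkpoint_dir: str) -> Optional[str]:
    """Return the path of the highest-numbered epoch_<N>.pth file, or None."""
    directory = Path(checkpoint_dir)
    if not directory.is_dir():
        return None

    best_epoch, best_path = -1, None
    for path in directory.iterdir():
        match = _EPOCH_PATTERN.match(path.name)
        # Compare epochs numerically so that epoch_10 sorts after epoch_9.
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), path
    return str(best_path) if best_path is not None else None
```

The result could be passed as `resume_from` in the training config. It is `None` when no matching checkpoint exists, which matches the config's default.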

In [None]:
# Install dependencies
!pip install -q torch torchvision
!pip install -q albumentations facenet-pytorch scikit-image
!pip install -q tqdm scikit-learn

# Clone Prosopo repo
!git clone https://github.com/YOUR_USERNAME/prosopo.git /content/prosopo
%cd /content/prosopo

print('✅ Dependencies installed')

## 2. Download Data

The datasets are downloaded to `/content` (fast local SSD), NOT to Drive (slow).

In [None]:
# Download CASIA-WebFace (aligned version)
# NOTE: Replace with actual download link for aligned dataset
!gdown YOUR_ALIGNED_CASIA_LINK -O /content/casia_aligned.zip
!unzip -q /content/casia_aligned.zip -d /content/data/

print('✅ CASIA-WebFace downloaded')

In [None]:
# Download LFW for evaluation
# Make sure the target directory exists; `tar -C` fails if it does not
!mkdir -p /content/data
!wget -q http://vis-www.cs.umass.edu/lfw/lfw.tgz -O /content/lfw.tgz
!tar -xzf /content/lfw.tgz -C /content/data/

# Download pairs.txt
!wget -q http://vis-www.cs.umass.edu/lfw/pairs.txt -O /content/data/pairs.txt

print('✅ LFW downloaded')

## 3. Configure Training

In [None]:
import sys
sys.path.insert(0, '/content/prosopo')

from prosopo.training import TrainingConfig, Trainer

# Training configuration
config = TrainingConfig(
    # Data paths
    data_root='/content/data/CASIA-WebFace-aligned',
    lfw_root='/content/data/lfw',
    lfw_pairs_path='/content/data/pairs.txt',
    
    # Model
    backbone='resnet50',
    embedding_dim=512,
    pretrained=True,
    
    # ArcFace
    arcface_scale=64.0,
    arcface_margin=0.5,
    
    # Training
    batch_size=128,
    accumulation_steps=2,
    epochs=25,
    lr=0.1,
    num_workers=2,
    
    # Checkpointing (to Drive!)
    checkpoint_dir='/content/drive/MyDrive/prosopo/checkpoints',
    save_every=1,
    
    # Validation epochs
    val_epochs=[10, 15, 20, 25],
    
    # Resume from checkpoint (set to a checkpoint path if resuming)
    resume_from=None,  # e.g., '/content/drive/MyDrive/prosopo/checkpoints/epoch_10.pth'
)

print('✅ Config ready')
print(f'   Effective batch size: {config.batch_size * config.accumulation_steps}')
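
The effective batch size printed above (128 × 2 = 256) comes from gradient accumulation: gradients from several micro-batches are summed before each optimizer step. The internals of `Trainer` are not shown in this notebook, so the loop below is only a sketch of how accumulation is commonly written in PyTorch. `train_with_accumulation` and the toy model are hypothetical and are not Prosopo code.

```python
import torch
from torch import nn


def train_with_accumulation(model, optimizer, loss_fn, batches, accumulation_steps=2):
    """Run one pass over `batches`, stepping every `accumulation_steps` micro-batches.

    Returns the number of optimizer updates performed.
    """
    model.train()
    optimizer.zero_grad()
    num_updates = 0
    for step, (inputs, targets) in enumerate(batches, start=1):
        loss = loss_fn(model(inputs), targets)
        # Scale each loss so the accumulated gradient is the mean over
        # the effective batch rather than the sum of micro-batch means.
        (loss / accumulation_steps).backward()
        if step % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()
            num_updates += 1
    # Note: trailing micro-batches that do not fill a full group are
    # never stepped in this simplified version.
    return num_updates


# Toy usage: 4 micro-batches of 8 samples, accumulated in pairs.
torch.manual_seed(0)
model = nn.Linear(4, 3)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
batches = [(torch.randn(8, 4), torch.randint(0, 3, (8,))) for _ in range(4)]
updates = train_with_accumulation(model, optimizer, nn.CrossEntropyLoss(), batches)
```

Accumulation trades wall-clock time for memory: each update sees 256 samples while only 128 need to fit on the GPU at once.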

## 4. Train Model

⏱️ **Expected time:** ~8-12 hours on a T4 GPU

In [None]:
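# Check GPU (fail gracefully if the runtime has no CUDA device)
import torch
if torch.cuda.is_available():
    print(f'GPU: {torch.cuda.get_device_name(0)}')
    print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')
else:
    print('No GPU detected. Enable one under Runtime > Change runtime type.')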

In [None]:
# Initialize trainer
trainer = Trainer(config)

# Start training
trainer.train()
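
The loss optimized during training is shaped by `arcface_scale=64.0` and `arcface_margin=0.5` from the config. ArcFace adds the margin to the angle between an embedding and its true class center, then multiplies all cosine logits by the scale before softmax. The function below is a framework-free sketch of that formula for a single sample. `arcface_logits` is a hypothetical name, and Prosopo's own implementation may differ, for example in how it handles angles where θ + m exceeds π.

```python
import math
from typing import List


def arcface_logits(cosines: List[float], target: int,
                   scale: float = 64.0, margin: float = 0.5) -> List[float]:
    """Apply the additive angular margin to the target class, then scale.

    `cosines[j]` is the cosine similarity between the L2-normalized
    embedding and the L2-normalized weight vector of class j.
    """
    logits = []
    for j, cos_theta in enumerate(cosines):
        # Clamp to [-1, 1] so acos is safe against rounding error.
        cos_theta = max(-1.0, min(1.0, cos_theta))
        if j == target:
            theta = math.acos(cos_theta)
            cos_theta = math.cos(theta + margin)  # penalize the true class
        logits.append(scale * cos_theta)
    return logits


# One sample, three classes, true class is index 0.
logits = arcface_logits([0.8, 0.1, -0.3], target=0)
```

With these inputs the true-class logit drops from 51.2 (64 × 0.8) to roughly 26.5, so the model must pull the embedding closer to its class center to reach the same confidence.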

## 5. Evaluate on LFW

In [None]:
from prosopo.evaluation import evaluate_lfw

accuracy, threshold = evaluate_lfw(
    trainer.model,
    config.lfw_root,
    config.lfw_pairs_path,
)

print(f'\nüéØ LFW Accuracy: {accuracy:.2%}')
print(f'   Optimal threshold: {threshold:.3f}')
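
LFW verification classifies each image pair as same or different by comparing the similarity of the two embeddings to a threshold. How `evaluate_lfw` chooses that threshold is not shown here. The sketch below illustrates the basic idea with a hypothetical `best_threshold` function that scans candidate thresholds and keeps the most accurate one. The standard LFW protocol selects thresholds with 10-fold cross-validation, which this simplified version omits.

```python
from typing import List, Tuple


def best_threshold(similarities: List[float], is_same: List[bool]) -> Tuple[float, float]:
    """Return (accuracy, threshold) for the most accurate similarity cutoff.

    A pair is predicted "same" when its similarity is >= the threshold.
    Candidate thresholds are the observed similarity values.
    """
    if not similarities or len(similarities) != len(is_same):
        raise ValueError("similarities and is_same must be non-empty and equal length")

    best_acc, best_thr = 0.0, 0.0
    for thr in sorted(set(similarities)):
        correct = sum((sim >= thr) == same
                      for sim, same in zip(similarities, is_same))
        acc = correct / len(similarities)
        if acc > best_acc:  # ties keep the lowest threshold
            best_acc, best_thr = acc, thr
    return best_acc, best_thr


# Toy example: two matching pairs and two non-matching pairs.
acc, thr = best_threshold([0.9, 0.8, 0.3, 0.1], [True, True, False, False])
```

Selecting the threshold on the same pairs used to report accuracy gives an optimistic number, which is why the benchmark protocol uses held-out folds.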

## 6. Export Model

In [None]:
# Save final model to Drive
import torch

final_path = '/content/drive/MyDrive/prosopo/prosopo_final.pth'
torch.save(trainer.model.state_dict(), final_path)

print(f'✅ Model saved to: {final_path}')

In [None]:
# Download to local machine
from google.colab import files
files.download(final_path)