# Prosopo Training Notebook

Train a face embedding model from scratch using ArcFace loss.

**Target:** 99%+ accuracy on LFW benchmark

---

## Pipeline Overview
1. Mount Drive (checkpoint survival)
2. Download CASIA-WebFace from Kaggle (.rec format)
3. Unpack RecordIO → raw images
4. Align faces with MTCNN → 112×112
5. Train ResNet-50 + ArcFace
6. Evaluate on LFW

## 1. Setup & Mount Drive

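⚠️ **CRITICAL:** Mount Drive FIRST so that checkpoints survive session disconnects. (Only the checkpoint folder is on Drive; the dataset downloaded to `/content/data` lives on the VM's local disk.)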

In [None]:
# Checkpoints must be written to Google Drive: local storage under /content
# does not persist once the Colab VM is recycled.
from google.colab import drive
import os

drive.mount('/content/drive')

# Working directories: scratch data on local disk, checkpoints on Drive.
os.makedirs('/content/data', exist_ok=True)
os.makedirs('/content/drive/MyDrive/prosopo/checkpoints', exist_ok=True)
print('‚úÖ Drive mounted')

In [None]:
# Install dependencies.
# `!` shell lines do not raise on failure, so the final print runs even if an
# install failed - scan the output for pip errors.
# NOTE(review): Colab images usually ship torch/torchvision preinstalled, and
# facenet-pytorch may pin a specific torch version; reinstalling can swap the
# torch build - confirm torch.__version__ / CUDA after this cell.
!pip install -q torch torchvision
!pip install -q albumentations facenet-pytorch scikit-image
!pip install -q tqdm scikit-learn opencv-python
# mxnet is used only in section 4, to read the .rec/.idx RecordIO files.
# NOTE(review): the last mxnet release is reported to fail at import with
# NumPy >= 1.24 (removed `np.bool` alias) - confirm on the current Colab image;
# pinning numpy or a pure-Python RecordIO reader may be needed.
!pip install -q mxnet  # For RecordIO unpacking
!pip install -q kaggle

print('‚úÖ Dependencies installed')

In [None]:
# Clone Prosopo repo and verify structure
!git clone https://github.com/InanXR/Prosopo.git /content/prosopo

# Verify the package exists
import os
expected_files = [
    '/content/prosopo/prosopo/__init__.py',
    '/content/prosopo/prosopo/models/arcface.py',
    '/content/prosopo/prosopo/training/trainer.py',
    '/content/prosopo/scripts/preprocess.py',
]

all_ok = True
for f in expected_files:
    if os.path.exists(f):
        print(f'‚úÖ {f}')
    else:
        print(f'‚ùå MISSING: {f}')
        all_ok = False

if all_ok:
    print('\n‚úÖ Prosopo repo verified!')
else:
    print('\n‚ùå Some files missing - check GitHub repo!')

## 2. Setup Kaggle API

⚠️ **DO NOT hardcode your API key!** Use Colab Secrets or upload kaggle.json.

In [None]:
# Option 1: Upload kaggle.json (RECOMMENDED)
# Go to kaggle.com -> Settings -> API -> Create New Token
# This downloads kaggle.json

from google.colab import files
import os

kaggle_path = '/root/.kaggle/kaggle.json'

# Skip the prompt if credentials are already in place (e.g. written by Option 2).
if not os.path.exists(kaggle_path):
    print('Upload your kaggle.json file:')
    uploaded = files.upload()  # {filename: file bytes}
    if not uploaded:
        raise RuntimeError('No file uploaded - kaggle.json is required to download the dataset.')

    # Write the uploaded bytes directly instead of `mv kaggle.json`: Colab may
    # save a repeat upload under a different name (e.g. "kaggle (1).json").
    name, content = next(iter(uploaded.items()))
    os.makedirs(os.path.dirname(kaggle_path), exist_ok=True)
    with open(kaggle_path, 'wb') as fh:
        fh.write(content)
    os.chmod(kaggle_path, 0o600)  # the Kaggle CLI warns on world-readable keys

    # files.upload() also saved a copy in the working directory; remove it so
    # the API key is not left lying around (the old `mv` did the same).
    if os.path.exists(name):
        os.remove(name)
    print('‚úÖ Kaggle configured')
else:
    print('‚úÖ Kaggle already configured')

In [None]:
# Option 2: Use Colab Secrets (Alternative)
# Store KAGGLE_USERNAME and KAGGLE_KEY in Colab's secret manager, uncomment the
# code below and run it BEFORE the upload cell above: that cell skips the
# upload prompt once /root/.kaggle/kaggle.json exists.
# json.dump handles quoting/escaping, unlike a hand-built JSON f-string.

# from google.colab import userdata
# import json
# import os
#
# os.makedirs('/root/.kaggle', exist_ok=True)
# creds = {'username': userdata.get('KAGGLE_USERNAME'), 'key': userdata.get('KAGGLE_KEY')}
# with open('/root/.kaggle/kaggle.json', 'w') as f:
#     json.dump(creds, f)
# os.chmod('/root/.kaggle/kaggle.json', 0o600)
# print('‚úÖ Kaggle configured from secrets')

## 3. Download CASIA-WebFace from Kaggle

Dataset: `debarghamitraroy/casia-webface` (~2.73 GB)

In [None]:
# Download CASIA-WebFace dataset (needs /root/.kaggle/kaggle.json from section 2).
# The unzip cell below expects the archive at /content/data/casia-webface.zip.
# NOTE(review): /content/data is local VM storage, not Drive, so this download
# must be repeated after the VM is recycled. The success message below prints
# even if the kaggle command failed - check the ls output.
!kaggle datasets download -d debarghamitraroy/casia-webface -p /content/data/

print('\n‚úÖ Download complete. Checking contents...')
!ls -lh /content/data/

In [None]:
# Unzip and verify structure
!unzip -q /content/data/casia-webface.zip -d /content/data/raw_rec

print('\nContents after unzip:')
!find /content/data/raw_rec -name "*.rec" -o -name "*.idx" | head -20

# Find the .rec file path
import glob
rec_files = glob.glob('/content/data/raw_rec/**/*.rec', recursive=True)
print(f'\nFound .rec files: {rec_files}')

## 4. Unpack RecordIO to Raw Images

The dataset comes in MXNet RecordIO format. We unpack it into one image folder per identity.

⏱️ **Time:** ~30 minutes for 490K images

In [None]:
import mxnet as mx
from mxnet import recordio
import cv2
import os
from tqdm import tqdm

def _first_label(label):
    """Return a RecordIO header label as an int.

    Headers carry either a scalar label or a sequence of floats; for a
    sequence the first element is the class label.
    """
    import numbers

    # numbers.Real also covers NumPy float scalars, not just Python floats.
    if isinstance(label, numbers.Real):
        return int(label)
    return int(label[0])


def unpack_rec_file(rec_path, output_dir):
    """
    Unpack an MXNet RecordIO file into one folder of JPEGs per identity.

    Args:
        rec_path: Path to the ``.rec`` file. The matching ``.idx`` file must
            sit next to it with the same stem.
        output_dir: Destination root; images are written to
            ``output_dir/<label:07d>/<record index>.jpg``.

    Returns:
        ``(success_count, error_count)``. Records that fail to decode or write
        are counted and skipped (best effort) rather than aborting the run.

    Raises:
        FileNotFoundError: if the ``.rec`` or ``.idx`` file is missing.
    """
    print(f"Unpacking {rec_path}...")

    # splitext only swaps the extension; str.replace('.rec', '.idx') would
    # also rewrite any '.rec' occurring earlier in the path.
    idx_path = os.path.splitext(rec_path)[0] + '.idx'
    if not os.path.exists(rec_path):
        raise FileNotFoundError(f"{rec_path} not found!")
    if not os.path.exists(idx_path):
        raise FileNotFoundError(f"{idx_path} not found! (Required alongside .rec)")

    imgrec = recordio.MXIndexedRecordIO(idx_path, rec_path, 'r')
    try:
        header, _ = recordio.unpack(imgrec.read_idx(0))
        if header.flag > 0:
            # Record 0 is a meta header (InsightFace-style packing): label[0]
            # is one past the last image index, so images are 1 .. label[0]-1.
            # Records from label[0] on are not images and must not be decoded.
            image_indices = range(1, int(header.label[0]))
        else:
            # No meta header: every indexed record is an image.
            image_indices = sorted(imgrec.keys)

        num_images = len(image_indices)
        print(f"Total images: {num_images:,}")
        os.makedirs(output_dir, exist_ok=True)

        success_count = 0
        error_count = 0
        first_error = None
        created_dirs = set()  # avoid one makedirs syscall per image

        for idx in tqdm(image_indices, desc="Unpacking", mininterval=1.0):
            try:
                header, img_data = recordio.unpack(imgrec.read_idx(idx))

                img = mx.image.imdecode(img_data).asnumpy()
                # Decoded as RGB; cv2.imwrite expects BGR channel order.
                img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

                save_dir = os.path.join(output_dir, f"{_first_label(header.label):07d}")
                if save_dir not in created_dirs:
                    os.makedirs(save_dir, exist_ok=True)
                    created_dirs.add(save_dir)

                # imwrite signals failure via its return value, not an exception.
                if not cv2.imwrite(os.path.join(save_dir, f"{idx}.jpg"), img):
                    raise IOError(f"cv2.imwrite failed in {save_dir}")
                success_count += 1

            except Exception as exc:  # best effort: one bad record must not abort a long run
                error_count += 1
                if first_error is None:
                    first_error = f"record {idx}: {exc!r}"
    finally:
        imgrec.close()

    print(f"\n‚úÖ Unpacked {success_count:,} images ({error_count} errors)")
    if first_error is not None:
        print(f"   First error: {first_error}")
    print(f"   Output: {output_dir}")
    print(f"   Identities: {len(os.listdir(output_dir))}")
    return success_count, error_count

In [None]:
# Find and unpack the .rec file
import glob

rec_files = glob.glob('/content/data/raw_rec/**/*.rec', recursive=True)

if not rec_files:
    print("‚ùå No .rec file found! Checking directory structure...")
    !find /content/data -type f | head -30
else:
    rec_path = rec_files[0]
    print(f"Using: {rec_path}")
    unpack_rec_file(rec_path, '/content/data/raw_casia')

## 5. Align Faces with MTCNN

Detect faces and warp each one to a canonical 112×112 pose.

⏱️ **Time:** ~2-4 hours for 490K images

In [None]:
# Run alignment: the repo's scripts/preprocess.py reads the unpacked identity
# folders in /content/data/raw_casia and writes aligned crops to
# /content/data/aligned_casia.
# NOTE(review): flag semantics live in preprocess.py (not shown here);
# --skip-existing presumably lets an interrupted run resume without
# re-aligning finished images - confirm. The next cell and the training
# config also expect class_indices.json in the output folder.
!python /content/prosopo/scripts/preprocess.py \
    --input /content/data/raw_casia \
    --output /content/data/aligned_casia \
    --skip-existing

In [None]:
# Verify alignment results
import os

aligned_dir = '/content/data/aligned_casia'

if os.path.exists(aligned_dir):
    num_identities = len([d for d in os.listdir(aligned_dir)
                          if os.path.isdir(os.path.join(aligned_dir, d))])

    # Count only files inside identity sub-folders: top-level files such as
    # class_indices.json are metadata, not aligned images. os.walk yields the
    # top directory itself first, under exactly the path passed in.
    total_images = sum(len(names) for root, _, names in os.walk(aligned_dir)
                       if root != aligned_dir)

    print("‚úÖ Alignment complete!")
    print(f"   Identities: {num_identities:,}")
    print(f"   Total aligned images: {total_images:,}")

    # Training reads class_indices_path from this folder (see the config cell).
    if os.path.exists(f"{aligned_dir}/class_indices.json"):
        print("   ‚úÖ class_indices.json exists")
    else:
        print("   ‚ö†Ô∏è class_indices.json not found")
else:
    print("‚ùå Aligned directory not found!")

## 6. Download LFW for Evaluation

In [None]:
# Download LFW
!wget -q http://vis-www.cs.umass.edu/lfw/lfw.tgz -O /content/data/lfw.tgz
!tar -xzf /content/data/lfw.tgz -C /content/data/
!wget -q http://vis-www.cs.umass.edu/lfw/pairs.txt -O /content/data/pairs.txt

print('‚úÖ LFW downloaded')
!ls /content/data/lfw | head -5

## 7. Configure Training

In [None]:
# Put the cloned repo first on the import path, then smoke-test the package
# imports the training cells rely on.
import sys
sys.path.insert(0, '/content/prosopo')

try:
    from prosopo.training import TrainingConfig, Trainer
    from prosopo.models import Prosopo
except ImportError as e:
    print(f'‚ùå Import failed: {e}')
    print('\nCheck that all files were pushed to GitHub!')
else:
    print('‚úÖ Prosopo imports successful')

In [None]:
from prosopo.training import TrainingConfig

# NOTE(review): the notebook intro says "from scratch", but pretrained=True
# presumably initialises the backbone from pretrained weights - confirm which
# is intended and make the two agree.
config = TrainingConfig(
    # --- Data: aligned CASIA-WebFace for training, LFW for validation ---
    data_root='/content/data/aligned_casia',
    class_indices_path='/content/data/aligned_casia/class_indices.json',
    lfw_root='/content/data/lfw',
    lfw_pairs_path='/content/data/pairs.txt',

    # --- Network: ResNet-50 backbone emitting 512-d embeddings ---
    backbone='resnet50',
    pretrained=True,
    embedding_dim=512,

    # --- ArcFace head: additive angular margin m and logit scale s ---
    arcface_margin=0.5,
    arcface_scale=64.0,

    # --- Optimisation: 128 samples per step, 2 steps accumulated ---
    lr=0.1,
    epochs=25,
    batch_size=128,
    accumulation_steps=2,
    num_workers=2,

    # --- Persistence: checkpoints go to Drive so they outlive the VM ---
    checkpoint_dir='/content/drive/MyDrive/prosopo/checkpoints',
    save_every=1,
    val_epochs=[10, 15, 20, 25],

    # After a crash, point this at the last checkpoint, e.g.
    # '/content/drive/MyDrive/prosopo/checkpoints/epoch_10.pth'.
    resume_from=None,
)

print('‚úÖ Config ready')
print('   Batch size: {} x {} = {} effective'.format(
    config.batch_size,
    config.accumulation_steps,
    config.batch_size * config.accumulation_steps,
))

## 8. Train Model

⏱️ **Expected time:** ~8-12 hours on T4 GPU

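If the session disconnects:
1. Re-run sections 1-6. Only checkpoints are stored on Drive: the dataset under `/content/data` is on the VM's local disk, so if the VM was recycled the download, unpack and alignment steps (sections 3-5) run from scratch again (several hours). If the VM survived, those cells can be skipped.
2. Set `resume_from` to the last checkpoint path and re-run the section 7 cells
3. Re-run training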

In [None]:
import torch

# Fail fast with an actionable message: on a CPU runtime get_device_name(0)
# raises an opaque CUDA error, and training is planned around a GPU anyway.
if not torch.cuda.is_available():
    raise RuntimeError('No CUDA GPU detected - switch to a GPU runtime '
                       '(Runtime > Change runtime type) before training.')

print(f'GPU: {torch.cuda.get_device_name(0)}')
print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

In [None]:
# Build the trainer from `config` (section 7) and run the training loop.
# NOTE(review): checkpointing presumably follows config.checkpoint_dir /
# save_every, and resuming uses config.resume_from - Trainer is not shown here.
from prosopo.training import Trainer

trainer = Trainer(config)
trainer.train()

## 9. Final Evaluation

In [None]:
from prosopo.evaluation import evaluate_lfw

# Switch to inference mode before extracting embeddings. Right after
# trainer.train() the model is presumably still in train mode, where
# BatchNorm/Dropout behave differently. evaluate_lfw's internals are not
# visible here; eval() is idempotent, so this is harmless if it does the same.
trainer.model.eval()

accuracy, threshold = evaluate_lfw(
    trainer.model,
    config.lfw_root,
    config.lfw_pairs_path,
)

print(f'\nüéØ LFW Accuracy: {accuracy:.2%}')
print(f'   Optimal threshold: {threshold:.3f}')

## 10. Export Model

In [None]:
import torch

# Save only the weights (state_dict) to Drive; reloading requires building the
# same model architecture first and calling load_state_dict on it.
# NOTE(review): Drive syncs asynchronously - consider drive.flush_and_unmount()
# before ending the session so the file is fully uploaded.
final_path = '/content/drive/MyDrive/prosopo/prosopo_final.pth'
torch.save(trainer.model.state_dict(), final_path)
print(f'‚úÖ Model saved to: {final_path}')

In [None]:
# Optional: also download the weights through the browser. The file is already
# saved on Drive at final_path, so this can be skipped on slow connections.
from google.colab import files
files.download(final_path)