# Prosopo Training Notebook

Train a face embedding model from scratch using ArcFace loss.

**Target:** 99%+ accuracy on the LFW benchmark

---

## Pipeline Overview
1. Mount Drive (checkpoint survival)
2. Download CASIA-WebFace from Kaggle (.rec format)
3. Unpack RecordIO → raw images
4. Align faces with MTCNN → 112×112 crops
5. Train ResNet-50 + ArcFace
6. Evaluate on LFW

## 1. Setup & Mount Drive

⚠️ **CRITICAL:** Mount Drive FIRST so that checkpoints survive session disconnects.

In [None]:
# Mount Google Drive so checkpoints persist across session disconnects
from google.colab import drive
drive.mount('/content/drive')

# Create directories: checkpoints go on Drive; raw/processed data stays under
# /content (outside Drive).
import os
os.makedirs('/content/drive/MyDrive/prosopo/checkpoints', exist_ok=True)
os.makedirs('/content/data', exist_ok=True)
print('✅ Drive mounted')

In [None]:
# Install dependencies
!pip install -q torch torchvision
!pip install -q albumentations facenet-pytorch scikit-image
!pip install -q tqdm scikit-learn opencv-python
!pip install -q mxnet  # For RecordIO unpacking
!pip install -q kaggle

# Clone Prosopo repo
!git clone https://github.com/InanXR/Prosopo.git /content/prosopo

print('‚úÖ Dependencies installed')

## 2. Download CASIA-WebFace from Kaggle

Dataset: `debarghamitraroy/casia-webface` (~2.73 GB)

In [None]:
# Setup Kaggle API credentials
import os

# Set your Kaggle API token
os.environ['KAGGLE_API_TOKEN'] = 'KGAT_1200fe88c38a44a77c8879998f9413ac'

# Alternative: Create kaggle.json
!mkdir -p ~/.kaggle
kaggle_json = '{"username":"YOUR_KAGGLE_USERNAME","key":"YOUR_KAGGLE_KEY"}'
# Uncomment and fill in if token method doesn't work:
# with open('/root/.kaggle/kaggle.json', 'w') as f:
#     f.write(kaggle_json)
# !chmod 600 ~/.kaggle/kaggle.json

print('‚úÖ Kaggle credentials configured')

In [None]:
# Download CASIA-WebFace dataset
!kaggle datasets download -d debarghamitraroy/casia-webface -p /content/data/
!unzip -q /content/data/casia-webface.zip -d /content/data/raw_rec

# Check what we got
!ls -lh /content/data/raw_rec/
print('‚úÖ Dataset downloaded')

## 3. Unpack RecordIO to Raw Images

The dataset comes in MXNet RecordIO format (a `.rec` data file plus a `.idx` index). We unpack it into one folder per identity.

In [None]:
import mxnet as mx
from mxnet import recordio
import cv2
import os
from tqdm import tqdm

def unpack_rec_file(rec_path, output_dir):
    """
    Unpack MXNet RecordIO file to image folders.
    
    Creates structure:
        output_dir/
            0000001/
                1.jpg
                2.jpg
            0000002/
                ...
    """
    print(f"Unpacking {rec_path} to {output_dir}...")
    
    idx_path = rec_path.replace('.rec', '.idx')
    if not os.path.exists(rec_path):
        print(f"‚ùå Error: {rec_path} not found!")
        # Try to find .rec file
        import glob
        rec_files = glob.glob('/content/data/raw_rec/**/*.rec', recursive=True)
        print(f"Found .rec files: {rec_files}")
        return
    
    # Open RecordIO
    imgrec = recordio.MXIndexedRecordIO(idx_path, rec_path, 'r')
    
    # Read header to get total count
    s = imgrec.read_idx(0)
    header, _ = recordio.unpack(s)
    
    # header.label[0] contains the number of images
    if isinstance(header.label, float):
        num_images = int(header.label)
    else:
        num_images = int(header.label[0])
    
    print(f"Total images to unpack: {num_images}")
    os.makedirs(output_dir, exist_ok=True)
    
    # Unpack each image
    for idx in tqdm(range(1, num_images + 1), desc="Unpacking"):
        try:
            s = imgrec.read_idx(idx)
            header, img_data = recordio.unpack(s)
            
            # Decode image
            img = mx.image.imdecode(img_data).asnumpy()
            
            # Convert RGB (MXNet) to BGR (OpenCV)
            img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
            
            # Get label (identity)
            if isinstance(header.label, float):
                label = int(header.label)
            else:
                label = int(header.label[0])
            
            # Create identity folder
            folder_name = f"{label:07d}"
            save_dir = os.path.join(output_dir, folder_name)
            os.makedirs(save_dir, exist_ok=True)
            
            # Save image
            filename = f"{idx}.jpg"
            cv2.imwrite(os.path.join(save_dir, filename), img)
            
        except Exception as e:
            if idx % 10000 == 0:
                print(f"Warning at {idx}: {e}")
            continue
    
    print(f"\n‚úÖ Unpacked to {output_dir}")
    print(f"   Identities: {len(os.listdir(output_dir))}")

In [None]:
# Find and unpack the .rec file
import glob
import os

# glob's order is filesystem-dependent: sort it for a deterministic choice, and
# prefer the main training archive (train.rec) if several .rec files exist.
rec_files = sorted(glob.glob('/content/data/raw_rec/**/*.rec', recursive=True))
print(f"Found .rec files: {rec_files}")

if rec_files:
    train_recs = [p for p in rec_files if os.path.basename(p) == 'train.rec']
    rec_path = (train_recs or rec_files)[0]
    print(f"Using {rec_path}")
    unpack_rec_file(rec_path, '/content/data/raw_casia')
else:
    print("❌ No .rec file found! Check the download.")

## 4. Align Faces with MTCNN

Detect faces and warp them to a canonical 112×112 pose.

In [None]:
# Run alignment (this takes ~2-4 hours for full dataset)
# NOTE(review): sys.path only affects this kernel. The `!python` command below
# runs in a separate process, so this path change does not apply to it.
import sys
sys.path.insert(0, '/content/prosopo')

# Reads the unpacked identity folders from /content/data/raw_casia and writes
# to /content/data/aligned_casia (the data_root used in the training config).
# Flags are defined by scripts/preprocess.py in the Prosopo repo (not shown here).
# --skip-existing presumably lets an interrupted run resume — confirm.
# NOTE(review): if the .rec images are already aligned 112x112 crops (common
# for InsightFace-packed CASIA releases), this step may be unnecessary — check
# a sample image from raw_casia first.
!python /content/prosopo/scripts/preprocess.py \
    --input /content/data/raw_casia \
    --output /content/data/aligned_casia \
    --skip-existing

In [None]:
# Verify alignment results
import os

ALIGNED_IMAGE_EXTS = ('.jpg', '.jpeg', '.png', '.bmp')


def count_aligned(root):
    """Return ``(num_identities, num_images)`` for an aligned dataset root.

    Identities are the immediate sub-directories of *root*. Images are files
    anywhere under *root* whose extension (case-insensitive) is in
    ``ALIGNED_IMAGE_EXTS``. Non-image files, such as the
    ``class_indices.json`` expected in this folder, are therefore not counted.
    """
    identities = sum(
        1 for entry in os.listdir(root) if os.path.isdir(os.path.join(root, entry))
    )
    images = sum(
        1
        for _, _, files in os.walk(root)
        for name in files
        if name.lower().endswith(ALIGNED_IMAGE_EXTS)
    )
    return identities, images


aligned_dir = '/content/data/aligned_casia'
if os.path.isdir(aligned_dir):
    num_identities, total_images = count_aligned(aligned_dir)
    print("✅ Alignment complete!")
    print(f"   Identities: {num_identities}")
    print(f"   Total aligned images: {total_images}")
else:
    print(f"❌ {aligned_dir} not found. Run the alignment cell first.")

## 5. Download LFW for Evaluation

In [None]:
# Download LFW
# NOTE(review): `wget -q` hides download errors, and the success message below
# is printed unconditionally. If evaluation fails later, check that lfw.tgz and
# pairs.txt exist and are non-empty.
# NOTE(review): extraction presumably creates /content/data/lfw (the lfw_root in
# the training config) — confirm. Whether evaluate_lfw aligns these images to
# match the 112x112 training crops is not visible here.
!wget -q http://vis-www.cs.umass.edu/lfw/lfw.tgz -O /content/data/lfw.tgz
!tar -xzf /content/data/lfw.tgz -C /content/data/

# Download pairs.txt (the verification pair list passed to evaluate_lfw)
!wget -q http://vis-www.cs.umass.edu/lfw/pairs.txt -O /content/data/pairs.txt

print('‚úÖ LFW downloaded')

## 6. Configure Training

In [None]:
import sys
sys.path.insert(0, '/content/prosopo')

from prosopo.training import TrainingConfig, Trainer

# Training configuration
config = TrainingConfig(
    # Data paths
    data_root='/content/data/aligned_casia',
    class_indices_path='/content/data/aligned_casia/class_indices.json',
    lfw_root='/content/data/lfw',
    lfw_pairs_path='/content/data/pairs.txt',

    # Model
    backbone='resnet50',
    embedding_dim=512,
    # NOTE(review): the notebook intro says "from scratch", but pretrained=True
    # presumably initialises the backbone from pretrained weights — confirm
    # which one is intended.
    pretrained=True,

    # ArcFace
    arcface_scale=64.0,
    arcface_margin=0.5,

    # Training
    batch_size=128,
    accumulation_steps=2,
    epochs=25,
    lr=0.1,
    num_workers=2,

    # Checkpointing (to Drive!)
    checkpoint_dir='/content/drive/MyDrive/prosopo/checkpoints',
    save_every=1,

    # Validation epochs
    val_epochs=[10, 15, 20, 25],

    # Resume from checkpoint (set path if resuming after disconnect)
    resume_from=None,
)

print('✅ Config ready')
print(f'   Effective batch size: {config.batch_size * config.accumulation_steps}')
print(f'   Effective batch size: {config.batch_size * config.accumulation_steps}')

## 7. Train Model

⏱️ **Expected time:** ~8-12 hours on a T4 GPU

If the session disconnects, set `resume_from` in the training config (section 6) to the path of the latest checkpoint and re-run the configuration and training cells.

In [None]:
# Check GPU
import torch

# Fail early with an actionable message. Without a GPU, get_device_name(0)
# raises a low-level CUDA error instead.
if not torch.cuda.is_available():
    raise RuntimeError(
        'No CUDA GPU detected. In Colab use Runtime > Change runtime type > GPU, '
        'then re-run the notebook.'
    )

print(f'GPU: {torch.cuda.get_device_name(0)}')
print(f'VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB')

In [None]:
# Initialize trainer and start training.
# Trainer (from prosopo.training) receives the config built in section 6,
# including checkpoint_dir (on Drive) and resume_from.
# NOTE(review): `trainer` exists only in this kernel. After a disconnect, set
# resume_from in the config cell and re-run it and this cell to recreate it.
trainer = Trainer(config)
trainer.train()

## 8. Final Evaluation

In [None]:
from prosopo.evaluation import evaluate_lfw

# Put the model in inference mode before evaluating: layers such as BatchNorm
# and Dropout behave differently in training mode. This is harmless if
# evaluate_lfw also switches modes itself (not visible from here).
# NOTE(review): trainer.model holds the weights from the last training step,
# not necessarily the best validation checkpoint.
trainer.model.eval()

accuracy, threshold = evaluate_lfw(
    trainer.model,
    config.lfw_root,
    config.lfw_pairs_path,
)

print(f'\n🎯 LFW Accuracy: {accuracy:.2%}')
print(f'   Optimal threshold: {threshold:.3f}')

## 9. Export Model

In [None]:
# Save final model weights (state_dict only) to Drive
import os
import torch

final_path = '/content/drive/MyDrive/prosopo/prosopo_final.pth'
# Make sure the target folder exists even if this cell is run on its own.
os.makedirs(os.path.dirname(final_path), exist_ok=True)
torch.save(trainer.model.state_dict(), final_path)

print(f'✅ Model saved to: {final_path}')

In [None]:
# Download to local machine
# Uses `final_path` defined in the export cell above (run that cell first).
# google.colab.files.download hands the file to the browser for download.
from google.colab import files
files.download(final_path)