# Prosopo Training Notebook

**Streamlined execution for face embedding training.**

Prerequisites:
- `aligned_casia.zip` in Google Drive (`/MyDrive/prosopo/`)
- T4 GPU runtime enabled

In [None]:
# Cell 1: Mount & Setup
from google.colab import drive
drive.mount('/content/drive')

# Install dependencies
!pip install -q torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -q albumentations tqdm

import os
os.makedirs('/content/drive/MyDrive/prosopo/checkpoints', exist_ok=True)
print('‚úÖ Environment Ready')

In [None]:
# Cell 2: Restore the Asset (Zero Alignment Time)
# Unpacks the pre-aligned face archive from Drive onto local disk (skipped when
# the target directory already exists) and reports how many files it contains.
import os
import zipfile

zip_path = '/content/drive/MyDrive/prosopo/aligned_casia.zip'
extract_path = '/content/data/aligned_casia'

# NOTE: only the existence of `extract_path` is checked, so an interrupted
# extraction is treated as complete; delete the directory to force a re-extract.
if not os.path.exists(extract_path):
    # Fail with an actionable message before announcing a backup that is not there.
    if not os.path.isfile(zip_path):
        raise FileNotFoundError(
            f"Archive not found at {zip_path}. Make sure Drive is mounted (Cell 1) "
            "and aligned_casia.zip is in /MyDrive/prosopo/."
        )

    print(f"üöÄ Detected Backup at {zip_path}")
    print("üìÇ Extracting aligned faces... (~2-3 mins)")

    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extractall('/content/data/')

    print("‚úÖ Extraction Complete.")
else:
    print("‚úÖ Data already extracted.")

# os.walk() yields nothing for a missing directory, which would otherwise be
# reported as "0 images" instead of an error. This catches an archive whose
# top-level folder is not `aligned_casia`.
if not os.path.isdir(extract_path):
    raise RuntimeError(
        f"Expected extracted data at {extract_path}, but it does not exist. "
        "Check the top-level folder name inside the archive."
    )

# Verify count (counts every file under extract_path, whatever its extension)
num_images = sum(len(files) for _, _, files in os.walk(extract_path))
print(f"üì∏ Total Training Images: {num_images:,}")

In [None]:
# Cell 3: Clone The Brain
import sys

# Remove old repo if exists
!rm -rf /content/prosopo

# Clone fresh
!git clone https://github.com/InanXR/Prosopo.git /content/prosopo

# Add to Python path
sys.path.insert(0, '/content/prosopo')

# Verify imports
try:
    from prosopo.models import Prosopo
    from prosopo.data import CASIAWebFaceDataset
    from prosopo.training import Config
    print("‚úÖ Codebase Integrity Verified. Imports working.")
except ImportError as e:
    print(f"‚ùå Import Failed: {e}")

In [None]:
# Cell 4: The Training Loop
import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import transforms
from tqdm import tqdm
import os

from prosopo.models import Prosopo
from prosopo.data import CASIAWebFaceDataset
from prosopo.training import Config

# --- CONFIGURATION ---
# Train on the GPU when torch can see one; otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Start from the project's Config and point its data / checkpoint locations
# at the Colab paths used by this notebook.
cfg = Config()
cfg.data_root = '/content/data/aligned_casia'
cfg.checkpoint_dir = '/content/drive/MyDrive/prosopo/checkpoints'

# Augmentations: resize to 112x112, random horizontal flip, convert to a tensor,
# then normalise each of the three channels with mean 0.5 and std 0.5.
train_transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5] * 3, std=[0.5] * 3),
])

# Dataset and a shuffled loader over it
print("üìÇ Loading Dataset...")
dataset = CASIAWebFaceDataset(cfg.data_root, transform=train_transform)
train_loader = DataLoader(dataset, batch_size=cfg.batch_size, shuffle=True,
                          num_workers=cfg.num_workers, pin_memory=True)

# The dataset reports how many identities it holds; the model is sized from it.
NUM_CLASSES = dataset.num_classes
print(f"üéØ Classes: {NUM_CLASSES} | Samples: {len(dataset):,} | Device: {DEVICE}")

# Model, built from the Config hyper-parameters and moved to the training device
model = Prosopo(num_classes=NUM_CLASSES, embedding_dim=cfg.embedding_dim,
                arcface_scale=cfg.arcface_scale, arcface_margin=cfg.arcface_margin)
model = model.to(DEVICE)

# Optimizer: SGD with momentum and weight decay taken from the Config
optimizer = optim.SGD(model.parameters(), lr=cfg.lr,
                      momentum=cfg.momentum, weight_decay=cfg.weight_decay)

# Step-wise LR decay: multiply the LR by cfg.lr_gamma at each epoch in cfg.lr_milestones
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=cfg.lr_milestones,
                                           gamma=cfg.lr_gamma)

# Cross-entropy is applied to the logits the model returns in the training loop.
criterion = torch.nn.CrossEntropyLoss()

# --- RESUME CHECK ---
# If a 'latest.pth' checkpoint exists, restore model, optimizer and LR-scheduler
# state and continue from the epoch after the one that was saved.
start_epoch = 0
resume_path = f"{cfg.checkpoint_dir}/latest.pth"

if os.path.exists(resume_path):
    print("üîÑ Resuming from checkpoint...")
    # map_location lets a checkpoint written on a GPU runtime load on whatever
    # device this session has (including a CPU-only runtime).
    checkpoint = torch.load(resume_path, map_location=DEVICE)
    model.load_state_dict(checkpoint['model_state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    start_epoch = checkpoint['epoch'] + 1

    # Restore the scheduler's position as well. Without this it restarts
    # counting from 0 and the remaining milestones fire at the wrong epochs.
    if 'scheduler_state_dict' in checkpoint:
        scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    else:
        # Checkpoints written before the scheduler state was saved: the
        # optimizer state already carries the current learning rate, so only
        # the scheduler's epoch counter has to catch up (one step per epoch).
        scheduler.last_epoch = start_epoch
    print(f"   Resuming from epoch {start_epoch}")

# --- TRAINING LOOP ---
print("\nüî• TRAINING STARTED")
print(f"   Epochs: {start_epoch} ‚Üí {cfg.epochs}")
print(f"   Batch Size: {cfg.batch_size}")
print(f"   Learning Rate: {cfg.lr}")
print("-" * 50)

for epoch in range(start_epoch, cfg.epochs):
    model.train()
    running_loss = 0.0

    pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{cfg.epochs}")
    for images, labels in pbar:
        images, labels = images.to(DEVICE), labels.to(DEVICE)

        optimizer.zero_grad()

        # Forward (Returns ArcFace Logits); labels are passed to the model too.
        outputs = model(images, labels)

        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        pbar.set_postfix({'loss': f'{loss.item():.4f}'})

    # One scheduler step per epoch, after that epoch's optimizer steps.
    scheduler.step()
    avg_loss = running_loss / len(train_loader)
    print(f"üìä Epoch {epoch+1} complete. Avg Loss: {avg_loss:.4f}")

    # Build the checkpoint once; it is written both as a per-epoch snapshot
    # and as 'latest' for easy resume. 'epoch' is the 0-based epoch just finished.
    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict(),
        'loss': avg_loss,
    }
    save_path = f"{cfg.checkpoint_dir}/epoch_{epoch+1}.pth"
    torch.save(checkpoint, save_path)
    torch.save(checkpoint, resume_path)

    print(f"üíæ Checkpoint saved: {save_path}")

print("\n‚úÖ TRAINING COMPLETE!")

In [None]:
# Cell 5: Export Final Model
# Writes the trained weights to Drive, then offers the file as a browser download.
# Relies on `model` from the training cell being present in the kernel.
import torch

final_path = '/content/drive/MyDrive/prosopo/prosopo_final.pth'

# Only the weights (state_dict) are stored here, not the optimizer or epoch bookkeeping.
final_weights = model.state_dict()
torch.save(final_weights, final_path)
print(f'‚úÖ Final model saved to: {final_path}')

# Download locally
from google.colab import files
files.download(final_path)