# Prosopo Training (Self-Contained)

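All project code (dataset, model, ArcFace head, training loop) is defined inline in this notebook. Nothing is imported from project modules — only installable libraries (torch, torchvision, tqdm, Pillow) and Colab's Drive helper are used. Run the cells in order.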

In [None]:
# Cell 1: Setup
from google.colab import drive
drive.mount('/content/drive')

!pip install -q torch torchvision tqdm

import os
os.makedirs('/content/drive/MyDrive/prosopo/checkpoints', exist_ok=True)
print('âœ… Setup Complete')

In [None]:
# Cell 2: Extract Data
# Unpacks the dataset archive from Drive onto the local Colab disk (skipped if
# the target directory already exists) and reports how many files it holds.
import os
import zipfile

zip_path = '/content/drive/MyDrive/prosopo/aligned_casia.zip'
extract_path = '/content/data/aligned_casia'

if os.path.exists(extract_path):
    print('âœ… Already extracted')
else:
    print('ðŸ“‚ Extracting...')
    # The archive is unpacked into the parent directory of extract_path.
    with zipfile.ZipFile(zip_path, 'r') as archive:
        archive.extractall('/content/data/')
    print('âœ… Done')

# Count every file found anywhere below extract_path.
file_count = 0
for _dirpath, _dirnames, filenames in os.walk(extract_path):
    file_count += len(filenames)
print(f'ðŸ“¸ Images: {file_count:,}')

In [None]:
# Cell 3: Define Everything Inline
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, models
from PIL import Image
import os
import glob
import math
from tqdm import tqdm

# === DATASET ===
class FaceDataset(Dataset):
    """Image-folder dataset with one sub-directory of ``root`` per class.

    Every ``<root>/<class_dir>/*.jpg`` file is one sample; its label is the
    index of ``class_dir`` in the alphabetically sorted list of
    sub-directories of ``root``.

    Args:
        root: Directory that contains one sub-directory per class.
        transform: Optional callable applied to the loaded RGB PIL image.

    Attributes:
        images: Sorted list of sample file paths.
        classes: Sorted list of sub-directory names found in ``root``.
        class_to_idx: Mapping from class name to integer label.
    """

    def __init__(self, root, transform=None):
        self.root = root
        self.transform = transform
        # glob returns files in filesystem-dependent order; sorting makes the
        # index -> sample mapping reproducible across runs and machines.
        self.images = sorted(glob.glob(os.path.join(root, '*', '*.jpg')))
        self.classes = sorted(
            d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d))
        )
        self.class_to_idx = {c: i for i, c in enumerate(self.classes)}

    def __len__(self):
        """Return the number of image files found."""
        return len(self.images)

    def __getitem__(self, idx):
        """Return ``(image, label)`` for the sample at position ``idx``."""
        path = self.images[idx]
        # The label comes from the name of the directory holding the file.
        label = self.class_to_idx[os.path.basename(os.path.dirname(path))]
        # The context manager closes the underlying file as soon as the RGB
        # copy has been made instead of leaving it to garbage collection.
        with Image.open(path) as raw:
            img = raw.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img, label

# === ARCFACE HEAD ===
class ArcFace(nn.Module):
    """Additive angular margin (ArcFace) classification head.

    Produces scaled cosine logits between L2-normalised embeddings and
    L2-normalised class weights, with the angular margin ``m`` added to the
    angle of each sample's target class. The output is meant to be fed to a
    cross-entropy loss.

    Args:
        in_features: Size of the input embedding.
        num_classes: Number of classes (rows of the weight matrix).
        s: Scale applied to the final logits.
        m: Angular margin, in radians, added to the target-class angle.
        eps: Lower bound for ``1 - cos^2`` before the square root, which keeps
            the gradient finite when a cosine is exactly +/-1.
    """

    def __init__(self, in_features, num_classes, s=64.0, m=0.5, eps=1e-7):
        super().__init__()
        self.s = s
        self.m = m
        self.eps = eps
        self.weight = nn.Parameter(torch.empty(num_classes, in_features))
        nn.init.xavier_uniform_(self.weight)
        # Precomputed terms for cos(theta + m) = cos*cos_m - sin*sin_m.
        self.cos_m = math.cos(m)
        self.sin_m = math.sin(m)
        # Where cosine <= th (theta + m would exceed pi), the linear
        # fallback ``cosine - mm`` is used instead of cos(theta + m).
        self.th = math.cos(math.pi - m)
        self.mm = math.sin(math.pi - m) * m

    def forward(self, x, label):
        """Return margin-adjusted logits of shape ``(len(x), num_classes)``.

        Args:
            x: Embeddings; the last dimension must equal ``in_features``.
            label: Integer class indices (as accepted by ``F.one_hot``).
        """
        cosine = F.linear(F.normalize(x), F.normalize(self.weight))
        # d/du sqrt(u) is infinite at u == 0, which turns into NaN gradients
        # during backprop; clamping to a small positive floor zeroes the
        # gradient there instead.
        sine = torch.sqrt((1.0 - cosine.pow(2)).clamp(min=self.eps, max=1.0))
        phi = cosine * self.cos_m - sine * self.sin_m
        phi = torch.where(cosine > self.th, phi, cosine - self.mm)
        one_hot = F.one_hot(label, self.weight.size(0)).float()
        # Margin logit for the target class, plain cosine for all others.
        output = (one_hot * phi) + ((1.0 - one_hot) * cosine)
        return output * self.s

# === MODEL ===
class Prosopo(nn.Module):
    """Face-embedding network: ResNet-50 trunk, embedding layer, ArcFace head.

    Args:
        num_classes: Number of classes for the ArcFace head.
        embed_dim: Size of the embedding produced by the ``fc`` layer.
    """

    def __init__(self, num_classes, embed_dim=512):
        super().__init__()
        pretrained = models.resnet50(weights='IMAGENET1K_V1')
        # Keep every child module of the torchvision ResNet-50 except the
        # last one (its classification layer).
        trunk_layers = list(pretrained.children())
        trunk_layers.pop()
        self.backbone = nn.Sequential(*trunk_layers)
        feat_dim = 2048
        self.bn1 = nn.BatchNorm1d(feat_dim)
        self.fc = nn.Linear(feat_dim, embed_dim)
        self.bn2 = nn.BatchNorm1d(embed_dim)
        self.head = ArcFace(embed_dim, num_classes)

    def forward(self, x, label=None):
        """Return ArcFace logits if ``label`` is given, else unit embeddings."""
        feats = self.backbone(x).flatten(start_dim=1)
        embedding = self.bn2(self.fc(self.bn1(feats)))
        if label is None:
            # Inference path: L2-normalised embedding only.
            return F.normalize(embedding)
        return self.head(embedding, label)

# Confirmation message printed once the definitions in this cell have run.
print('âœ… Classes defined')

In [None]:
# Cell 4: Train
# Trains Prosopo with SGD and a multi-step LR schedule. A checkpoint is written
# to Drive after every epoch so an interrupted session can resume from
# latest.pth.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
DATA_ROOT = '/content/data/aligned_casia'
CKPT_DIR = '/content/drive/MyDrive/prosopo/checkpoints'
BATCH_SIZE = 128
EPOCHS = 25
LR = 0.1

# Resize to 112x112, random horizontal flip, then normalise each channel with
# mean 0.5 and std 0.5.
transform = transforms.Compose([
    transforms.Resize((112, 112)),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize([0.5]*3, [0.5]*3)
])

print('ðŸ“‚ Loading data...')
dataset = FaceDataset(DATA_ROOT, transform)
loader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, pin_memory=True)
NUM_CLASSES = len(dataset.classes)
print(f'ðŸŽ¯ Classes: {NUM_CLASSES} | Samples: {len(dataset):,} | Device: {DEVICE}')

model = Prosopo(NUM_CLASSES).to(DEVICE)
optimizer = optim.SGD(model.parameters(), lr=LR, momentum=0.9, weight_decay=5e-4)
scheduler = optim.lr_scheduler.MultiStepLR(optimizer, [10, 18, 22], 0.1)
criterion = nn.CrossEntropyLoss()

# Resume
start_epoch = 0
resume = f'{CKPT_DIR}/latest.pth'
if os.path.exists(resume):
    # map_location lets a checkpoint written on a GPU runtime be restored on a
    # CPU-only runtime (and the other way round).
    ckpt = torch.load(resume, map_location=DEVICE)
    model.load_state_dict(ckpt['model'])
    optimizer.load_state_dict(ckpt['optim'])
    start_epoch = ckpt['epoch'] + 1
    if 'sched' in ckpt:
        scheduler.load_state_dict(ckpt['sched'])
    else:
        # Checkpoints written before scheduler state was saved: the optimizer
        # state already carries the decayed LR, so only the epoch counter has
        # to be moved forward for the remaining milestones to fire on time.
        scheduler.last_epoch = start_epoch
    print(f'ðŸ”„ Resuming from epoch {start_epoch}')

print(f'\nðŸ”¥ Training epochs {start_epoch} â†’ {EPOCHS}')
for epoch in range(start_epoch, EPOCHS):
    model.train()
    total_loss = 0
    pbar = tqdm(loader, desc=f'Epoch {epoch+1}/{EPOCHS}')
    for imgs, labels in pbar:
        imgs, labels = imgs.to(DEVICE), labels.to(DEVICE)
        optimizer.zero_grad()
        # Labels go into the model too: the ArcFace head needs them to apply
        # the margin to the target class.
        out = model(imgs, labels)
        loss = criterion(out, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
        pbar.set_postfix(loss=f'{loss.item():.4f}')
    scheduler.step()

    # Save
    # Build the checkpoint once; 'sched' keeps the LR milestones aligned with
    # the true epoch count after a resume.
    state = {
        'epoch': epoch,
        'model': model.state_dict(),
        'optim': optimizer.state_dict(),
        'sched': scheduler.state_dict(),
    }
    torch.save(state, f'{CKPT_DIR}/epoch_{epoch+1}.pth')
    # Write latest.pth through a temporary file so an interrupted write cannot
    # leave a truncated resume checkpoint behind.
    tmp_resume = f'{resume}.tmp'
    torch.save(state, tmp_resume)
    os.replace(tmp_resume, resume)
    print(f'ðŸ’¾ Epoch {epoch+1} saved. Avg loss: {total_loss/len(loader):.4f}')

print('\nâœ… DONE!')

In [None]:
# Cell 5: Export
# Writes the trained model's state dict (weights only, no optimizer state) to
# a single file on Drive.
final = '/content/drive/MyDrive/prosopo/prosopo_final.pth'
weights = model.state_dict()
torch.save(weights, final)
print(f'âœ… Saved to {final}')