# Per-Subject Reenactment Finetuning

This notebook performs intensive finetuning on a **single identity** for maximum face swap quality.

**Purpose:** Overfit the model to one specific person for perfect quality on that individual.

**When to use:**
- You want the best possible quality for a specific person
- You have a few images of that person (roughly 2-10, and no fewer than 2)
- You're willing to sacrifice generalization for quality

**Contents:**
1. Load dataset and select target identity
2. Create single-identity dataset with augmentation
3. Setup per-subject training
4. Train with aggressive overfitting
5. Evaluate results

## 1. Setup & Imports

In [None]:
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchvision.transforms as T
import numpy as np
import cv2
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt

from fsgan.utils.utils import load_model
from fsgan.utils.img_utils import bgr2tensor, create_pyramid, tensor2bgr
from fsgan.utils.landmarks_utils import filter_landmarks
from fsgan.criterions.vgg_loss import VGGLoss
from fsgan.utils.obj_factory import obj_factory

import dataloader

# Project-relative locations: pretrained weights are read from WEIGHTS_DIR,
# while figures and checkpoints are written below OUT_DIR.
ROOT = Path('.')
WEIGHTS_DIR = ROOT.joinpath('fsgan', 'weights')
OUT_DIR = ROOT.joinpath('outputs')
OUT_DIR.mkdir(exist_ok=True)

# Use the first CUDA device when one is visible, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 2. Load Dataset

Load your dataset first to see available identities.

In [None]:
# Dataset configuration
IMAGE_SIZE = 256
NB_IMAGES = 10
DATASET_PATH = "../data/Face-Swap-M2-Dataset/dataset/smaller"

# Build both splits and the identity count in a single call.
# NOTE(review): the positional 0.8 is presumably the train fraction — confirm
# against dataloader.make_dataset before changing it.
train_dataset, test_dataset, nb_classes = dataloader.make_dataset(
    DATASET_PATH, NB_IMAGES, IMAGE_SIZE, 0.8, crop_faces=False
)

# One line per figure, emitted with a single print call.
print(
    f"Train samples: {len(train_dataset)}",
    f"Test samples: {len(test_dataset)}",
    f"Number of identities: {nb_classes}",
    sep="\n",
)

## 3. Analyze Available Identities

See which identities are available and how many images each has.

In [None]:
def analyze_dataset_identities(dataset):
    """Group the sample indices of ``dataset`` by identity label.

    Every item is read as an ``(image, label)`` pair. Tensor labels are
    converted to plain Python scalars with ``.item()`` so they can serve as
    dictionary keys.

    Returns:
        dict mapping each label to the list of indices (in ascending order)
        at which that label occurs.
    """
    indices_by_label = {}

    for position in range(len(dataset)):
        _image, identity = dataset[position]
        key = identity.item() if isinstance(identity, torch.Tensor) else identity
        # setdefault creates the bucket on first sight of a label.
        indices_by_label.setdefault(key, []).append(position)

    return indices_by_label

print("="*60)
print("ANALYZING DATASET FOR PER-SUBJECT FINETUNING")
print("="*60)

# Map every identity label in the training split to its sample indices.
train_identities = analyze_dataset_identities(train_dataset)
print(f"\nTRAIN SET: {len(train_identities)} identities")
print("-"*40)
for label in sorted(train_identities):
    count = len(train_identities[label])
    # Pair construction needs a second image of the same person.
    status = "✓" if count >= 2 else "✗ (need more)"
    print(f"  Label {label}: {count} images {status}")

# Find best candidates: identities ordered by image count, largest first.
sorted_by_count = sorted(train_identities.items(), key=lambda x: len(x[1]), reverse=True)
if sorted_by_count:
    best_label, best_indices = sorted_by_count[0]
    print(f"\n→ Recommended: Label {best_label} ({len(best_indices)} images)")
else:
    # An empty training split would otherwise fail with an IndexError above.
    print("\n→ No identities found in the training set; check the dataset path.")

## 4. Configuration

**Set the identity to finetune on here:**

In [None]:
# ============================================================
# SELECT WHICH IDENTITY TO FINETUNE ON
# ============================================================
SUBJECT_LABEL = 0  # <-- CHANGE THIS to the identity you want

# Optimisation schedule, deliberately aggressive so the model overfits.
SUBJECT_EPOCHS = 500           # number of passes over the subject data
SUBJECT_LR = 1e-4              # Adam learning rate
SUBJECT_BATCH_SIZE = 2         # upper bound on the batch size
SUBJECT_MIN_IMAGES = 2         # below this a warning is printed

# Loss weighting (pixel term dominates) and checkpoint cadence.
SUBJECT_WEIGHT_PIXEL = 10.0    # weight of the L1 pixel loss
SUBJECT_WEIGHT_PERC = 1.0      # weight of the perceptual losses
SUBJECT_SAVE_EVERY = 50        # epochs between regular checkpoints

# Toggle for the random augmentation in the single-identity dataset.
USE_AUGMENTATION = True

# Echo the chosen settings, one per line.
print(
    "Per-subject finetuning configuration:",
    f"  Target identity: {SUBJECT_LABEL}",
    f"  Epochs: {SUBJECT_EPOCHS}",
    f"  Learning rate: {SUBJECT_LR}",
    f"  Batch size: {SUBJECT_BATCH_SIZE}",
    f"  Augmentation: {USE_AUGMENTATION}",
    sep="\n",
)

## 5. Create Single-Identity Dataset

Extract one identity and apply heavy augmentation for few-shot learning.

In [None]:
class SingleIdentityDataset(Dataset):
    """Pairs of images of one identity for per-subject finetuning.

    Each item is a ``(source, target)`` tuple built from images that carry
    ``target_label`` in ``base_dataset``. The dataset length is the number of
    matching images times ``num_augmentations``. Indices in the first pass
    return the unmodified source; later passes return an augmented source.
    Targets are augmented whenever augmentation is enabled. Source and target
    are augmented by separate calls, so their random transforms are drawn
    independently.

    Args:
        base_dataset: indexable, sized dataset yielding ``(image, label)``.
            The augmentation maps images through ``(x + 1) / 2`` and back, so
            images are assumed to be tensors in [-1, 1] — TODO confirm
            against the loader.
        target_label: label of the identity to keep.
        augment: whether to apply random flip / affine / colour jitter.
        num_augmentations: number of passes over the identity's images that
            make up one epoch; ignored (treated as 1) when ``augment`` is
            False.

    Raises:
        ValueError: if ``augment`` is set and ``num_augmentations`` < 1, or
            if no image in ``base_dataset`` carries ``target_label``.
    """

    def __init__(self, base_dataset, target_label, augment=True, num_augmentations=5):
        if augment and num_augmentations < 1:
            # A non-positive count would silently produce an empty dataset.
            raise ValueError(
                f"num_augmentations must be >= 1, got {num_augmentations}"
            )

        self.base_dataset = base_dataset
        self.target_label = target_label
        self.augment = augment
        self.num_augmentations = num_augmentations if augment else 1

        # Find all indices for this identity (reads every item once).
        self.original_indices = []
        for idx in range(len(base_dataset)):
            _, label = base_dataset[idx]
            if isinstance(label, torch.Tensor):
                label = label.item()
            if label == target_label:
                self.original_indices.append(idx)

        if not self.original_indices:
            raise ValueError(f"No images found for label {target_label}")

        # Built eagerly only when needed; _apply_augmentation builds it on
        # demand if ``augment`` is switched on after construction.
        self.augment_transform = self._build_augment_transform() if augment else None

        print(f"SingleIdentityDataset for label {target_label}:")
        print(f"  Original images: {len(self.original_indices)}")
        print(f"  With augmentation: {len(self)} pairs")

    @staticmethod
    def _build_augment_transform():
        """Return the random flip / affine / colour-jitter pipeline."""
        return T.Compose([
            T.RandomHorizontalFlip(p=0.5),
            T.RandomAffine(
                degrees=15,
                translate=(0.1, 0.1),
                scale=(0.9, 1.1),
            ),
            T.ColorJitter(
                brightness=0.2,
                contrast=0.2,
                saturation=0.1,
                hue=0.05,
            ),
        ])

    def __len__(self):
        return len(self.original_indices) * self.num_augmentations

    def _get_image(self, idx):
        """Return only the image of item ``idx`` of the base dataset."""
        img, _ = self.base_dataset[idx]
        return img

    def _apply_augmentation(self, img_tensor):
        """Return a randomly augmented copy, or the input if augmentation is off."""
        if not self.augment:
            return img_tensor
        if self.augment_transform is None:
            self.augment_transform = self._build_augment_transform()

        # The transforms run on [0, 1] data; map there, augment, map back.
        img_01 = torch.clamp((img_tensor + 1) / 2, 0, 1)
        return self.augment_transform(img_01) * 2 - 1

    def __getitem__(self, idx):
        total = len(self)
        if idx < 0:
            idx += total
        if not 0 <= idx < total:
            # Without this, plain iteration over the dataset (which stops on
            # IndexError) would never terminate.
            raise IndexError(f"index {idx} out of range for {total} pairs")

        n_originals = len(self.original_indices)
        base_idx = idx % n_originals

        # Source: unmodified in the first pass, augmented in later passes.
        src_img = self._get_image(self.original_indices[base_idx])
        if self.augment and idx >= n_originals:
            src_img = self._apply_augmentation(src_img)

        # Target: a different image of the same person when one exists.
        # A non-zero offset modulo the count is uniform over the others.
        if n_originals > 1:
            offset = np.random.randint(1, n_originals)
            tgt_base_idx = (base_idx + offset) % n_originals
        else:
            tgt_base_idx = base_idx
        tgt_img = self._get_image(self.original_indices[tgt_base_idx])

        # Augment target too
        if self.augment:
            tgt_img = self._apply_augmentation(tgt_img)

        return src_img, tgt_img

# Create dataset
if SUBJECT_LABEL in train_identities:
    num_images = len(train_identities[SUBJECT_LABEL])
    print(f"\nCreating per-subject dataset for label {SUBJECT_LABEL}...")
    print(f"  Found {num_images} images")

    if num_images < SUBJECT_MIN_IMAGES:
        print(f"  ⚠ Warning: Only {num_images} images (minimum: {SUBJECT_MIN_IMAGES})")

    # Never ask for a batch larger than the number of distinct images.
    effective_batch_size = min(SUBJECT_BATCH_SIZE, num_images)

    # Fewer source images -> more augmented passes per epoch.
    if num_images <= 5:
        num_aug = 20  # Heavy augmentation
    elif num_images <= 10:
        num_aug = 10
    else:
        num_aug = 5

    subject_dataset = SingleIdentityDataset(
        train_dataset,
        target_label=SUBJECT_LABEL,
        augment=USE_AUGMENTATION,
        num_augmentations=num_aug
    )

    subject_loader = DataLoader(
        subject_dataset,
        batch_size=effective_batch_size,
        shuffle=True,
        num_workers=0,
        drop_last=True
    )

    print("\n✓ Dataset ready:")
    print(f"  Augmented pairs: {len(subject_dataset)}")
    print(f"  Batch size: {effective_batch_size}")
    print(f"  Batches per epoch: {len(subject_loader)}")
else:
    print(f"\n✗ Label {SUBJECT_LABEL} not found!")
    print(f"  Available: {list(train_identities)}")
    # Later cells gate on `'subject_loader' in dir()`. Remove objects left
    # over from an earlier run of this cell so they cannot silently train or
    # test on the previously selected identity.
    for _stale_name in ('subject_dataset', 'subject_loader'):
        globals().pop(_stale_name, None)

## 6. Load Models and Setup Training

In [None]:
print("Loading pretrained models...")

# Load reenactment generator; `ckpt` is kept because the training loop copies
# its 'arch' entry into every checkpoint it writes.
reenact_w = WEIGHTS_DIR / 'nfv_msrunet_256_1_2_reenactment_v2.1.pth'
Gr_subject, ckpt = load_model(str(reenact_w), 'reenactment', device=device, return_checkpoint=True)
Gr_subject.train()
print("✓ Reenactment generator")

# Load landmarks model (frozen): eval mode and no gradients.
lms_w = WEIGHTS_DIR / 'hr18_wflw_landmarks.pth'
L_subject, _ = load_model(str(lms_w), 'landmarks', device=device, return_checkpoint=True)
L_subject.eval()
for p in L_subject.parameters():
    p.requires_grad = False
print("✓ Landmarks model (frozen)")

# Get pyramid levels
n_levels_subject = getattr(Gr_subject, 'n_local_enhancers', 1) + 1
print(f"Pyramid levels: {n_levels_subject}")

# Optimizer over all generator parameters.
optimizer_subject = optim.Adam(
    Gr_subject.parameters(),
    lr=SUBJECT_LR,
    betas=(0.5, 0.999)
)
print(f"✓ Optimizer: Adam, LR={SUBJECT_LR}")

# Loss functions. The two VGG-based losses are optional: when their weights
# cannot be loaded the criterion is set to None and the training loop skips
# that term. The broad `except` is deliberate best-effort, but the reason is
# always reported so a missing weights file is not mistaken for a working run.
criterion_pixel = nn.L1Loss().to(device)
try:
    vgg_id_path = str(WEIGHTS_DIR / 'vggface2_vgg19_256_1_2_id.pth')
    criterion_id = VGGLoss(vgg_id_path).to(device)
    criterion_id.eval()
    print("✓ VGG identity loss")
except Exception as e:
    print(f"⚠ Could not load VGG: {e}")
    criterion_id = None

try:
    vgg_attr_path = str(WEIGHTS_DIR / 'celeba_vgg19_256_2_0_28_attr.pth')
    criterion_attr = VGGLoss(vgg_attr_path).to(device)
    criterion_attr.eval()
    print("✓ VGG attribute loss")
except Exception as e:
    print(f"⚠ Could not load VGG attribute loss: {e}")
    criterion_attr = None

# ImageNet normalization, shaped (1, 3, 1, 1) to broadcast over image batches.
imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)

## 7. Helper Functions

In [None]:
def prepare_reenactment_input(src_batch, tgt_batch, L_model, n_pyramid_levels, device):
    """Build the generator's multi-scale input from a source/target batch.

    The target batch is passed through ``L_model`` to obtain landmark maps.
    The source batch is turned into an image pyramid, and at every pyramid
    level the landmark maps are resized to that level and concatenated to the
    source along the channel axis.

    Args:
        src_batch: source images.
        tgt_batch: target images.
        L_model: landmarks network, called without gradient tracking.
        n_pyramid_levels: number of pyramid levels passed to ``create_pyramid``.
        device: not used by this function; kept for call compatibility.

    Returns:
        ``(input_list, tgt_normalized)`` — one concatenated tensor per pyramid
        level, and the target batch after the range adjustment below.

    Relies on the module-level ``imagenet_mean`` / ``imagenet_std`` tensors.
    """
    # Range heuristic: a batch without negative values is treated as [0, 1]
    # data and rescaled to [-1, 1]; anything else is used unchanged.
    src_normalized = (src_batch * 2 - 1) if src_batch.min() >= 0 else src_batch
    tgt_normalized = (tgt_batch * 2 - 1) if tgt_batch.min() >= 0 else tgt_batch

    # The landmarks model is fed [0, 1] data with ImageNet normalisation.
    lms_input = ((tgt_normalized + 1) / 2 - imagenet_mean) / imagenet_std

    with torch.no_grad():
        tgt_landmarks = filter_landmarks(L_model(lms_input))

    input_list = []
    for level in create_pyramid(src_normalized, n_pyramid_levels):
        height, width = level.shape[2:]
        # Resize the landmark maps to this level and filter them again.
        resized = F.interpolate(
            tgt_landmarks, size=(height, width), mode='bilinear', align_corners=False
        )
        input_list.append(torch.cat((level, filter_landmarks(resized)), dim=1))

    return input_list, tgt_normalized

print("✓ Helper functions defined")

## 8. Per-Subject Training Loop

In [None]:
def _save_finetune_checkpoint(path, model, optimizer, epoch, history, **extra):
    """Write a training checkpoint to ``path``; ``extra`` adds further keys."""
    payload = {
        'epoch': epoch,
        'state_dict': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        # `ckpt` is the module-level checkpoint the generator was loaded from.
        'arch': ckpt.get('arch', 'unknown'),
        'history': history,
    }
    payload.update(extra)
    torch.save(payload, path)


def per_subject_finetune(model, dataloader, L_model, optimizer,
                         epochs=300, save_dir='models/per_subject',
                         weight_pixel=10.0, weight_perc=1.0):
    """
    Per-subject finetuning for maximum quality.
    Intentionally overfits to a single person.

    Args:
        model: generator to finetune; switched to train mode.
        dataloader: sized iterable of ``(source, target)`` image batches.
            (Inside this function the name hides the imported ``dataloader``
            module.)
        L_model: landmarks network; switched to eval mode.
        optimizer: optimizer stepping ``model``'s parameters.
        epochs: number of passes over ``dataloader``.
        save_dir: directory for checkpoints; created if missing.
        weight_pixel: weight of the L1 pixel loss (half of it is applied to
            the stepwise loss).
        weight_perc: weight of the summed identity and attribute losses.

    Returns:
        dict of per-epoch average losses with keys ``loss``, ``loss_pixel``,
        ``loss_id``, ``loss_attr`` and ``loss_stepwise``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.

    Reads the module-level ``device``, ``criterion_pixel``, ``criterion_id``,
    ``criterion_attr``, ``ckpt`` and ``SUBJECT_SAVE_EVERY``. Writes
    ``per_subject_best.pth`` whenever the epoch loss improves and
    ``per_subject_epoch<N>.pth`` every ``SUBJECT_SAVE_EVERY`` epochs and after
    the last epoch.
    """
    if len(dataloader) == 0:
        # Otherwise the epoch averages below divide by zero.
        raise ValueError("dataloader yields no batches; nothing to train on")

    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    model.train()
    L_model.eval()

    history = {
        'loss': [], 'loss_pixel': [], 'loss_id': [],
        'loss_attr': [], 'loss_stepwise': []
    }

    n_levels = getattr(model, 'n_local_enhancers', 1) + 1

    print("="*70)
    print("PER-SUBJECT FINETUNING (Aggressive Overfitting)")
    print("="*70)
    print(f"Epochs: {epochs}")
    print(f"Batches per epoch: {len(dataloader)}")
    print(f"Weight pixel: {weight_pixel}, Weight perceptual: {weight_perc}")
    print("="*70)

    best_loss = float('inf')

    for epoch in range(epochs):
        epoch_losses = {'total': 0, 'pixel': 0, 'id': 0, 'attr': 0, 'stepwise': 0}
        n_batches = 0

        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        for src_img, tgt_img in pbar:
            src_img = src_img.to(device)
            tgt_img = tgt_img.to(device)

            # Prepare inputs
            input_list, tgt_normalized = prepare_reenactment_input(
                src_img, tgt_img, L_model, n_levels, device
            )

            # Forward pass. A list/tuple output is treated as intermediate
            # results followed by the final prediction.
            output = model(input_list)
            if isinstance(output, (list, tuple)):
                pred = output[-1]
                # Stepwise consistency: each intermediate output is compared
                # with the target resized to its resolution.
                loss_stepwise = torch.tensor(0.0, device=device)
                for out_i in output[:-1]:
                    tgt_down = F.interpolate(tgt_normalized, size=out_i.shape[2:],
                                             mode='bilinear', align_corners=False)
                    loss_stepwise = loss_stepwise + criterion_pixel(out_i, tgt_down)
                # NOTE(review): the sum has len(output) - 1 terms but is
                # divided by len(output); kept to preserve the loss scale.
                loss_stepwise = loss_stepwise / len(output)
            else:
                pred = output
                loss_stepwise = torch.tensor(0.0, device=device)

            # Losses; an unavailable perceptual criterion contributes zero.
            loss_pixel = criterion_pixel(pred, tgt_normalized)

            if criterion_id is not None:
                loss_id = criterion_id(pred, tgt_normalized)
            else:
                loss_id = torch.tensor(0.0, device=device)

            if criterion_attr is not None:
                loss_attr = criterion_attr(pred, tgt_normalized)
            else:
                loss_attr = torch.tensor(0.0, device=device)

            # Combined loss
            loss_total = (
                weight_pixel * loss_pixel +
                weight_perc * (loss_id + loss_attr) +
                weight_pixel * 0.5 * loss_stepwise
            )

            # Backward
            optimizer.zero_grad()
            loss_total.backward()
            optimizer.step()

            # Track
            epoch_losses['total'] += loss_total.item()
            epoch_losses['pixel'] += loss_pixel.item()
            epoch_losses['id'] += loss_id.item()
            epoch_losses['attr'] += loss_attr.item()
            epoch_losses['stepwise'] += loss_stepwise.item()
            n_batches += 1

            pbar.set_postfix({'loss': f"{loss_total.item():.4f}"})

        # Epoch summary
        avg_loss = epoch_losses['total'] / n_batches
        avg_pixel = epoch_losses['pixel'] / n_batches
        avg_id = epoch_losses['id'] / n_batches
        avg_attr = epoch_losses['attr'] / n_batches
        avg_stepwise = epoch_losses['stepwise'] / n_batches

        history['loss'].append(avg_loss)
        history['loss_pixel'].append(avg_pixel)
        history['loss_id'].append(avg_id)
        history['loss_attr'].append(avg_attr)
        history['loss_stepwise'].append(avg_stepwise)

        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Pixel={avg_pixel:.4f}")

        # Save best (judged on the average training loss of the epoch)
        if avg_loss < best_loss:
            best_loss = avg_loss
            best_path = save_path / 'per_subject_best.pth'
            _save_finetune_checkpoint(best_path, model, optimizer, epoch + 1,
                                      history, best_loss=best_loss)
            print(f"  → Best model saved: {best_path}")

        # Regular checkpoints
        if (epoch + 1) % SUBJECT_SAVE_EVERY == 0 or epoch == epochs - 1:
            ckpt_path = save_path / f'per_subject_epoch{epoch+1}.pth'
            _save_finetune_checkpoint(ckpt_path, model, optimizer, epoch + 1, history)
            print(f"  → Checkpoint: {ckpt_path}")

    print(f"\nBest loss: {best_loss:.4f}")
    return history

print("✓ Training function defined")

## 9. Run Per-Subject Training

In [None]:
# Training needs the loader built by the dataset cell; bail out with a message
# when that cell did not produce one.
if not ('subject_loader' in dir() and subject_loader is not None):
    print("Cannot run training: subject_loader not available")
else:
    print(f"Starting per-subject finetuning with {len(subject_dataset)} pairs...")

    # Hyperparameters come from the configuration cell; checkpoints are
    # written below OUT_DIR.
    subject_history = per_subject_finetune(
        model=Gr_subject,
        dataloader=subject_loader,
        L_model=L_subject,
        optimizer=optimizer_subject,
        epochs=SUBJECT_EPOCHS,
        save_dir=str(OUT_DIR / 'per_subject_models'),
        weight_pixel=SUBJECT_WEIGHT_PIXEL,
        weight_perc=SUBJECT_WEIGHT_PERC,
    )

## 10. Visualize Training Progress

In [None]:
%matplotlib inline

# Plot the per-epoch loss curves on a 2x2 grid, save the figure, and print
# the final total and pixel losses.
if 'subject_history' in dir() and subject_history:
    fig, axes = plt.subplots(2, 2, figsize=(14, 10))

    # One entry per panel:
    # (axis, [(history key, line format, legend label or None)], y label, title)
    panels = [
        (axes[0, 0], [('loss', 'b-', None)], 'Total Loss', 'Total Loss'),
        (axes[0, 1], [('loss_pixel', 'g-', None)], 'Pixel L1 Loss', 'Pixel Loss'),
        (axes[1, 0],
         [('loss_id', 'r-', 'Identity'), ('loss_attr', 'm-', 'Attribute')],
         'Perceptual Loss', 'Perceptual Losses'),
        (axes[1, 1], [('loss_stepwise', 'c-', None)], 'Stepwise Loss', 'Stepwise Consistency'),
    ]

    for ax, curves, y_label, title in panels:
        has_legend = False
        for key, line_format, legend_label in curves:
            if legend_label is None:
                ax.plot(subject_history[key], line_format, linewidth=2)
            else:
                ax.plot(subject_history[key], line_format, linewidth=2, label=legend_label)
                has_legend = True
        ax.set_xlabel('Epoch')
        ax.set_ylabel(y_label)
        ax.set_title(title)
        # Only the panel with labelled curves gets a legend.
        if has_legend:
            ax.legend()
        ax.grid(True, alpha=0.3)

    plt.suptitle('Per-Subject Training Progress', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(str(OUT_DIR / 'per_subject_training_curves.png'), dpi=150)
    plt.show()

    print(
        f"\nFinal losses:",
        f"  Total: {subject_history['loss'][-1]:.4f}",
        f"  Pixel: {subject_history['loss_pixel'][-1]:.4f}",
        sep="\n",
    )

## 11. Test Per-Subject Model

In [None]:
def test_per_subject_model(model, dataloader, L_model, num_samples=6):
    """Run the model on a few pairs, plot the results and report PSNR.

    Up to ``num_samples`` pairs are taken from ``dataloader``. For each pair
    the source, target and prediction are shown side by side, and the PSNR
    between prediction and target is computed on the 8-bit images. The
    figure is saved to ``OUT_DIR / 'per_subject_test_results.png'``.

    Args:
        model: generator to evaluate; run in eval mode and put back into the
            mode it was in before the call.
        dataloader: iterable of ``(source, target)`` image batches.
        L_model: landmarks network passed to ``prepare_reenactment_input``.
        num_samples: maximum number of pairs to visualise.

    Returns:
        Average PSNR in dB over the visualised pairs, or NaN when the loader
        yielded no samples.

    Reads the module-level ``device`` and ``OUT_DIR``.
    """
    n_levels = getattr(model, 'n_local_enhancers', 1) + 1

    all_src, all_tgt, all_pred = [], [], []

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for src_batch, tgt_batch in dataloader:
                if len(all_src) >= num_samples:
                    break

                src_batch = src_batch.to(device)
                tgt_batch = tgt_batch.to(device)

                input_list, tgt_normalized = prepare_reenactment_input(
                    src_batch, tgt_batch, L_model, n_levels, device
                )

                output = model(input_list)
                pred = output[-1] if isinstance(output, (list, tuple)) else output

                # Display conversion below assumes [-1, 1]; apply the same
                # range heuristic to the source that the helper applies.
                src_shown = (src_batch * 2 - 1) if src_batch.min() >= 0 else src_batch

                n_take = min(src_batch.shape[0], num_samples - len(all_src))
                for i in range(n_take):
                    all_src.append(src_shown[i].cpu())
                    # Use the normalised target: it is what the prediction is
                    # trained against, so PSNR compares like with like.
                    all_tgt.append(tgt_normalized[i].cpu())
                    all_pred.append(pred[i].cpu())
    finally:
        # Restore the caller's mode even if inference failed.
        model.train(was_training)

    def to_np(t):
        # CHW tensor in [-1, 1] -> HWC uint8 image in [0, 255].
        return ((t.numpy().transpose(1, 2, 0) + 1) / 2 * 255).clip(0, 255).astype('uint8')

    src_np = [to_np(t) for t in all_src]
    tgt_np = [to_np(t) for t in all_tgt]
    pred_np = [to_np(t) for t in all_pred]

    n = len(src_np)
    if n == 0:
        # An empty grid cannot be plotted and the average would divide by zero.
        print("\nNo samples available for testing")
        return float('nan')

    # Visualization; squeeze=False keeps `axes` two-dimensional for n == 1.
    fig, axes = plt.subplots(n, 3, figsize=(12, 4*n), squeeze=False)

    total_psnr = 0
    for i in range(n):
        axes[i, 0].imshow(src_np[i])
        axes[i, 0].set_title('Source', fontsize=11)
        axes[i, 0].axis('off')

        axes[i, 1].imshow(tgt_np[i])
        axes[i, 1].set_title('Target (GT)', fontsize=11)
        axes[i, 1].axis('off')

        mse = np.mean((tgt_np[i].astype(float) - pred_np[i].astype(float)) ** 2)
        # The small epsilon keeps the ratio finite for identical images.
        psnr = 10 * np.log10(255**2 / (mse + 1e-10))
        total_psnr += psnr

        axes[i, 2].imshow(pred_np[i])
        axes[i, 2].set_title(f'Prediction (PSNR: {psnr:.1f} dB)', fontsize=11)
        axes[i, 2].axis('off')

    plt.suptitle('Per-Subject Model Results', fontsize=14, fontweight='bold')
    plt.tight_layout()
    plt.savefig(str(OUT_DIR / 'per_subject_test_results.png'), dpi=150)
    plt.show()

    avg_psnr = total_psnr / n
    print(f"\nAverage PSNR: {avg_psnr:.2f} dB")

    return avg_psnr

# Run test. Every object the call needs must exist; the loader must also be
# non-None, matching the guard used before training.
_test_ready = (
    all(name in dir() for name in ('Gr_subject', 'L_subject', 'subject_loader'))
    and subject_loader is not None
)
if _test_ready:
    test_psnr = test_per_subject_model(
        Gr_subject,
        subject_loader,
        L_subject,
        num_samples=6
    )
else:
    print("Cannot run test: models or subject_loader not available")

## Summary

This notebook:
- ✓ Extracted a single identity from your dataset
- ✓ Applied heavy data augmentation for few-shot learning
- ✓ Trained with aggressive overfitting (intentional!)
- ✓ Achieved maximum quality for that specific person

**The best model is saved as `per_subject_best.pth`**

**Use this model when:**
- You need the highest quality for a specific person
- Source and target are both this person
- The poses you need are similar to those in the training data