# Multi-Subject Reenactment Finetuning


## 1. Setup & Imports

In [None]:
import sys
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import numpy as np
import cv2
from pathlib import Path
from tqdm import tqdm
import matplotlib.pyplot as plt

from fsgan.utils.utils import load_model
from fsgan.utils.img_utils import bgr2tensor, create_pyramid, tensor2bgr
from fsgan.utils.landmarks_utils import LandmarksHeatMapDecoder, filter_landmarks
from fsgan.criterions.vgg_loss import VGGLoss
from fsgan.notebook_helpers.reenact_preprocess import run_full_pipeline
from fsgan.utils.obj_factory import obj_factory

import dataloader

# Project-relative locations: pretrained weights are read from WEIGHTS_DIR,
# everything this notebook produces is written under OUT_DIR.
ROOT = Path('.')
WEIGHTS_DIR = ROOT.joinpath('fsgan', 'weights')
OUT_DIR = ROOT.joinpath('outputs')
OUT_DIR.mkdir(exist_ok=True)

# Use the first CUDA device when one is present, otherwise stay on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## 2. Configuration

In [None]:
# --- Data -------------------------------------------------------------
IMAGE_SIZE = 256
NB_IMAGES = 10  # images per person
DATASET_PATH = "../data/Face-Swap-M2-Dataset/dataset/smaller"

# --- Optimisation -----------------------------------------------------
FINETUNE_EPOCHS = 100
FINETUNE_LR = 1e-5
FINETUNE_BATCH_SIZE = 2
# Gradients of this many consecutive batches are accumulated before each
# optimizer step, i.e. an effective batch of 2 * 2 = 4 samples.
GRADIENT_ACCUMULATION_STEPS = 2
# Checkpoint period in epochs (the last epoch of a run is always saved too).
SAVE_EVERY = 200

# --- Loss weights -----------------------------------------------------
WEIGHT_PIXEL = 0.1
WEIGHT_PERCEPTUAL = 1.0
WEIGHT_REC = 1.0

# Echo the settings so every run records what it was launched with.
for _line in (
    "Configuration:",
    f"  Dataset: {DATASET_PATH}",
    f"  Image size: {IMAGE_SIZE}",
    f"  Images per person: {NB_IMAGES}",
    f"  Epochs: {FINETUNE_EPOCHS}",
    f"  Batch size: {FINETUNE_BATCH_SIZE}",
    f"  Gradient accumulation: {GRADIENT_ACCUMULATION_STEPS} (effective batch: {FINETUNE_BATCH_SIZE * GRADIENT_ACCUMULATION_STEPS})",
    f"  Learning rate: {FINETUNE_LR}",
):
    print(_line)

## 3. Load Dataset

In [None]:

# Build the train/test datasets from the images under DATASET_PATH.
train_dataset, test_dataset, nb_classes = dataloader.make_dataset(
    DATASET_PATH,
    NB_IMAGES,
    IMAGE_SIZE,
    0.8,  # presumably the train split fraction — confirm against dataloader.make_dataset
    crop_faces=False,
)

print(f"Train samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Number of identities: {nb_classes}")

# Labels may be plain ints or 0-d tensors (other cells in this notebook handle
# both), so coerce them to int before counting. The explicit integer dtype also
# keeps np.bincount from failing on an empty split.
train_labels = [int(label) for _, label in train_dataset]
class_counts = np.bincount(np.asarray(train_labels, dtype=np.int64))
for cls, count in enumerate(class_counts):
    print(f"  Identity {cls}: {count} images")

## 4. Create Paired Dataset for Reenactment

For reenactment training, we need pairs of images of the same person in different poses.

In [None]:
class ReenactmentPairDataset(Dataset):
    """
    Creates pairs of images from the same person for reenactment training.

    The wrapped dataset must yield ``(image, label)`` items. Identities with
    fewer than two samples are dropped, since no pair can be formed for them.
    Each item is ``(source_image, target_image, label)`` where the target is a
    randomly chosen *different* sample of the same identity.

    Args:
        base_dataset: indexable dataset of ``(image, label)`` items; labels may
            be plain values or single-element tensors.
        resolution: stored on the instance; not used by this class itself.
    """

    def __init__(self, base_dataset, resolution=256):
        self.base_dataset = base_dataset
        self.resolution = resolution

        # Group sample indices by label (person identity). Every item is read
        # once here because the label is only available through __getitem__.
        self.label_to_indices = {}
        for idx in range(len(base_dataset)):
            _, label = base_dataset[idx]
            self.label_to_indices.setdefault(self._plain_label(label), []).append(idx)

        # Keep only identities that can provide a (source, target) pair.
        self.valid_labels = [
            label for label, indices in self.label_to_indices.items() if len(indices) >= 2
        ]

        # Flat list of usable sample indices; it defines the dataset length.
        self.valid_indices = []
        for label in self.valid_labels:
            self.valid_indices.extend(self.label_to_indices[label])

        print(f"ReenactmentPairDataset: {len(self.valid_labels)} identities with 2+ images")
        print(f"Total valid samples: {len(self.valid_indices)}")

    @staticmethod
    def _plain_label(label):
        """Return a hashable Python value for a label that may be a tensor."""
        return label.item() if isinstance(label, torch.Tensor) else label

    def __len__(self):
        return len(self.valid_indices)

    def __getitem__(self, idx):
        # Source sample.
        src_idx = self.valid_indices[idx]
        src_img, src_label = self.base_dataset[src_idx]
        src_label = self._plain_label(src_label)

        # Candidate targets: every other sample of the same person.
        others = [i for i in self.label_to_indices[src_label] if i != src_idx]

        # Draw with the torch RNG rather than the global NumPy RNG: torch seeds
        # its generator differently in every DataLoader worker, whereas the
        # NumPy state can be duplicated across workers on older PyTorch
        # releases, which would make all workers pick the same partners.
        pick = int(torch.randint(len(others), (1,)).item())
        tgt_img, _ = self.base_dataset[others[pick]]

        return src_img, tgt_img, src_label

# Wrap the training split so that every item is a same-identity
# (source, target, label) triple.
finetune_dataset = ReenactmentPairDataset(train_dataset, resolution=IMAGE_SIZE)

# drop_last discards a trailing batch smaller than FINETUNE_BATCH_SIZE.
_loader_kwargs = dict(
    batch_size=FINETUNE_BATCH_SIZE,
    shuffle=True,
    num_workers=4,
    drop_last=True,
)
finetune_loader = DataLoader(finetune_dataset, **_loader_kwargs)

print(f"\nFinetune dataloader: {len(finetune_loader)} batches")

## 5. Load Pretrained Models

In [None]:
print("Loading pretrained models...")

# Reenactment generator: this is the network being finetuned, so it is put in
# training mode right after loading.
reenact_w = WEIGHTS_DIR / 'nfv_msrunet_256_1_2_reenactment_v2.1.pth'
Gr_finetune, ckpt = load_model(str(reenact_w), 'reenactment', device=device, return_checkpoint=True)
Gr_finetune.train()
print(f"Loaded Reenactment generator: {ckpt.get('arch', 'unknown')}")

# Landmarks network: only used to produce conditioning inputs, so it is kept in
# eval mode with gradients disabled on all of its parameters.
lms_w = WEIGHTS_DIR / 'hr18_wflw_landmarks.pth'
L_frozen, _ = load_model(str(lms_w), 'landmarks', device=device, return_checkpoint=True)
L_frozen.eval()
L_frozen.requires_grad_(False)
print("Loaded Landmarks model (frozen)")

# The number of pyramid levels is n_local_enhancers + 1. The attribute is looked
# up on the model itself, then on a wrapped `.module` if present, and defaults
# to 1 when neither exposes it.
n_local = getattr(Gr_finetune, 'n_local_enhancers', None)
if n_local is None:
    n_local = getattr(getattr(Gr_finetune, 'module', None), 'n_local_enhancers', None)
if n_local is None:
    n_local = 1
n_levels = n_local + 1
print(f"Pyramid levels: {n_levels}")

# Per-channel mean/std used to normalise images fed to the landmark model,
# shaped (1, 3, 1, 1) so they broadcast over a batch of 3-channel images.
imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=device).reshape(1, 3, 1, 1)
imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=device).reshape(1, 3, 1, 1)

## 6. Setup Loss Functions

In [None]:
print("Initializing loss functions...")

# Plain L1 distance between prediction and target pixels.
criterion_pixel = nn.L1Loss().to(device)


def _try_load_vgg_loss(weights_path, ok_message, fail_prefix):
    """Build a VGGLoss in eval mode; on any failure report it and return None."""
    try:
        criterion = VGGLoss(weights_path).to(device)
        criterion.eval()
        print(ok_message)
    except Exception as e:
        # Best effort: training can proceed without this perceptual term.
        print(f"{fail_prefix}: {e}")
        return None
    return criterion


# Perceptual loss backed by the identity VGG weights (optional).
vgg_id_path = str(WEIGHTS_DIR / 'vggface2_vgg19_256_1_2_id.pth')
criterion_id = _try_load_vgg_loss(
    vgg_id_path, "loaded VGG identity loss", "Could not load VGG identity loss"
)

# Perceptual loss backed by the attribute VGG weights (optional).
vgg_attr_path = str(WEIGHTS_DIR / 'celeba_vgg19_256_2_0_28_attr.pth')
criterion_attr = _try_load_vgg_loss(
    vgg_attr_path, "VGG attribute loss loaded", "Could not load VGG attribute loss"
)

# Adam over all generator parameters.
optimizer_G = optim.Adam(Gr_finetune.parameters(), lr=FINETUNE_LR, betas=(0.5, 0.999))
print(f"Optimizer: Adam, LR={FINETUNE_LR}")

## 7. Helper Functions

In [None]:
def prepare_reenactment_input(src_batch, tgt_batch, L_model, n_pyramid_levels, device):
    """
    Prepare source image pyramid with target landmarks.

    Args:
        src_batch: source images. If the batch minimum is >= 0 the values are
            rescaled with ``x * 2 - 1``; otherwise the batch is used as is.
        tgt_batch: target images, rescaled by the same rule.
        L_model: landmarks network, run without gradients on the target after
            mean/std normalisation with the module-level ``imagenet_mean`` /
            ``imagenet_std``.
        n_pyramid_levels: number of levels passed to ``create_pyramid``.
        device: unused; kept so existing callers keep working.

    Returns:
        ``(input_list, tgt_normalized)``: ``input_list`` holds one tensor per
        pyramid level, the channel-wise concatenation of the source image at
        that level and the landmark output resized to it; ``tgt_normalized`` is
        the target batch after the rescaling rule above.
    """

    def _to_signed_range(batch):
        # Heuristic range detection: a batch with no negative value is treated
        # as [0, 1] data and mapped to [-1, 1].
        return batch * 2 - 1 if batch.min() >= 0 else batch

    src_normalized = _to_signed_range(src_batch)
    tgt_normalized = _to_signed_range(tgt_batch)

    # The landmark model takes mean/std-normalised [0, 1] images; it is frozen,
    # so no graph is needed for this part.
    with torch.no_grad():
        tgt_01 = (tgt_normalized + 1) / 2
        tgt_for_lms = (tgt_01 - imagenet_mean) / imagenet_std
        tgt_landmarks = filter_landmarks(L_model(tgt_for_lms))

    # One generator input per pyramid level: source image + resized landmarks.
    input_list = []
    for level in create_pyramid(src_normalized, n_pyramid_levels):
        pyd_h, pyd_w = level.shape[2:]
        context = F.interpolate(tgt_landmarks, size=(pyd_h, pyd_w), mode='bilinear', align_corners=False)
        context = filter_landmarks(context)
        input_list.append(torch.cat((level, context), dim=1))

    # Locals are released on return; no explicit `del` is needed (the previous
    # `del context` also failed when the pyramid was empty).
    return input_list, tgt_normalized

print("Helper functions defined")

## 8. Training Loop

In [None]:
def finetune_reenactment(model, dataloader, optimizer, epochs, save_dir, resume_from=None):
    """
    Finetune the reenactment generator on paired face data.

    For every batch the generator is asked to reproduce the target image from
    the source image and the target's landmarks. The loss is
    ``WEIGHT_REC * (WEIGHT_PIXEL * L1 + 0.5 * id + 0.5 * attr)``, with the id
    and attr terms set to zero when the corresponding criterion is ``None``.
    Gradients are accumulated over ``GRADIENT_ACCUMULATION_STEPS`` batches
    before each optimizer step.

    Besides its arguments the function reads module-level state: ``device``,
    ``L_frozen``, ``n_levels``, ``criterion_pixel``, ``criterion_id``,
    ``criterion_attr``, the ``WEIGHT_*`` constants,
    ``GRADIENT_ACCUMULATION_STEPS``, ``SAVE_EVERY`` and ``ckpt`` (for the
    architecture string stored in saved checkpoints).

    Args:
        model: generator to train; called with the list returned by
            ``prepare_reenactment_input``.
        dataloader: iterable of ``(source, target, label)`` batches.
        optimizer: optimizer over the model parameters.
        epochs: total epoch count; training runs from the resume point (or 0)
            up to this number.
        save_dir: directory for checkpoints; created if missing.
        resume_from: Path to checkpoint to resume from, or None to start fresh

    Returns:
        dict of per-epoch averages under the keys ``'loss'``, ``'loss_pixel'``,
        ``'loss_id'`` and ``'loss_attr'`` (including resumed history, if any).

    Raises:
        ValueError: if an epoch yields no batches.
    """
    save_path = Path(save_dir)
    save_path.mkdir(parents=True, exist_ok=True)

    model.train()
    history = {'loss': [], 'loss_pixel': [], 'loss_id': [], 'loss_attr': []}
    start_epoch = 0

    if resume_from is not None:
        print(f"Resuming from checkpoint: {resume_from}")
        resume_ckpt = torch.load(resume_from, map_location=device)
        model.load_state_dict(resume_ckpt['state_dict'])
        optimizer.load_state_dict(resume_ckpt['optimizer'])
        # Checkpoints store the number of completed epochs, which is also the
        # zero-based index of the next epoch to run.
        start_epoch = resume_ckpt.get('epoch', 0)
        if 'history' in resume_ckpt:
            history = resume_ckpt['history']
        print(f"✓ Resumed from epoch {start_epoch}")
        # Report the minimum so the value matches its "best" label.
        print(f"  Previous best loss: {min(history['loss']) if history['loss'] else 'N/A'}")

    print("="*60)
    print("STARTING FINETUNING" if start_epoch == 0 else "RESUMING FINETUNING")
    print("="*60)
    print(f"Epochs: {start_epoch + 1} to {epochs}")
    print(f"Batches per epoch: {len(dataloader)}")
    print(f"Save directory: {save_path}")
    print(f"Gradient accumulation steps: {GRADIENT_ACCUMULATION_STEPS}")
    print("="*60)

    # Start from clean gradients: in a notebook an interrupted earlier run can
    # leave partially accumulated gradients on the parameters, which would
    # otherwise leak into the first optimizer step.
    optimizer.zero_grad()

    for epoch in range(start_epoch, epochs):
        epoch_losses = {'total': 0, 'pixel': 0, 'id': 0, 'attr': 0}
        n_batches = 0
        accumulation_step = 0

        pbar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{epochs}")

        for src_img, tgt_img, _ in pbar:
            src_img = src_img.to(device)
            tgt_img = tgt_img.to(device)

            input_list, tgt_normalized = prepare_reenactment_input(
                src_img, tgt_img, L_frozen, n_levels, device
            )

            # Multi-output models return one image per scale; the last entry is
            # the one compared with the target.
            output = model(input_list)
            pred = output[-1] if isinstance(output, (list, tuple)) else output

            loss_pixel = criterion_pixel(pred, tgt_normalized)

            if criterion_id is not None:
                loss_id = criterion_id(pred, tgt_normalized)
            else:
                loss_id = torch.tensor(0.0, device=device)

            if criterion_attr is not None:
                loss_attr = criterion_attr(pred, tgt_normalized)
            else:
                loss_attr = torch.tensor(0.0, device=device)

            loss_rec = WEIGHT_PIXEL * loss_pixel + 0.5 * loss_id + 0.5 * loss_attr
            # Divide by the accumulation length so the summed gradients match
            # the mean over the effective batch.
            loss_total = WEIGHT_REC * loss_rec / GRADIENT_ACCUMULATION_STEPS

            loss_total.backward()

            # Undo the accumulation scaling for logging.
            total_val = loss_total.item() * GRADIENT_ACCUMULATION_STEPS
            pixel_val = loss_pixel.item()
            id_val = loss_id.item() if criterion_id is not None else 0
            attr_val = loss_attr.item() if criterion_attr is not None else 0

            epoch_losses['total'] += total_val
            epoch_losses['pixel'] += pixel_val
            epoch_losses['id'] += id_val
            epoch_losses['attr'] += attr_val
            n_batches += 1

            # Drop references to this batch's tensors now rather than when the
            # names are rebound during the next iteration.
            del input_list, output, pred, loss_pixel, loss_id, loss_attr, loss_rec, loss_total
            del src_img, tgt_img, tgt_normalized

            accumulation_step += 1
            if accumulation_step % GRADIENT_ACCUMULATION_STEPS == 0:
                optimizer.step()
                optimizer.zero_grad()
                accumulation_step = 0

            pbar.set_postfix({
                'loss': f"{total_val:.4f}",
                'pix': f"{pixel_val:.4f}"
            })

            if (n_batches % 5) == 0:
                torch.cuda.empty_cache()

        if n_batches == 0:
            raise ValueError(
                "dataloader produced no batches; check the dataset size, batch_size and drop_last"
            )

        # Flush gradients left over when the batch count is not a multiple of
        # the accumulation length. They keep the 1/GRADIENT_ACCUMULATION_STEPS
        # scaling, so this last step is slightly down-weighted.
        if accumulation_step > 0:
            optimizer.step()
            optimizer.zero_grad()

        avg_loss = epoch_losses['total'] / n_batches
        avg_pixel = epoch_losses['pixel'] / n_batches
        avg_id = epoch_losses['id'] / n_batches
        avg_attr = epoch_losses['attr'] / n_batches

        history['loss'].append(avg_loss)
        history['loss_pixel'].append(avg_pixel)
        history['loss_id'].append(avg_id)
        history['loss_attr'].append(avg_attr)

        print(f"Epoch {epoch+1}: Loss={avg_loss:.4f}, Pixel={avg_pixel:.4f}, ID={avg_id:.4f}")

        # Save every SAVE_EVERY epochs and always after the final epoch.
        if (epoch + 1) % SAVE_EVERY == 0 or epoch == epochs - 1:
            ckpt_path = save_path / f'reenact_finetuned_epoch{epoch+1}.pth'
            torch.save({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                # Architecture string of the pretrained checkpoint (module-level `ckpt`).
                'arch': ckpt.get('arch', 'unknown'),
                'history': history,
            }, ckpt_path)
            print(f"  → Saved checkpoint: {ckpt_path}")

    print("\n" + "="*60)
    print("FINETUNING COMPLETE")
    print("="*60)

    return history

print("Training function defined")

## 9. Run Finetuning

In [None]:
# Optionally warm-start the generator from a specific earlier checkpoint.
file_name = "reenact_finetuned_epoch100.pth"
finetuned_dir = OUT_DIR / 'finetuned_models'

# `ckpts` holds the requested checkpoint if it is on disk, otherwise nothing.
_requested = finetuned_dir / file_name
ckpts = []
if _requested.exists():
    ckpts.append(_requested)

if ckpts:
    latest_ckpt = ckpts[-1]
    print(f"Loading: {latest_ckpt.name}")

    # Only the weights are restored here; optimizer state is handled by the
    # resume path of the training function.
    ckpt_data = torch.load(str(latest_ckpt), map_location=device)
    Gr_finetune.load_state_dict(ckpt_data['state_dict'])
    print(f"Loaded finetuned model from epoch {ckpt_data.get('epoch', '?')}")
else:
    print("No finetuned checkpoints found!")

### 9.1 Configure Resume (Optional)
If you have a checkpoint from a previous training run, you can resume from it instead of starting from scratch.

In [None]:

# Checkpoint to resume from; falls back to None (fresh training) when the path
# is empty or the file is not on disk.
RESUME_FROM = 'outputs/finetuned_models/reenact_finetuned_epoch300.pth'

if not RESUME_FROM or not Path(RESUME_FROM).exists():
    if RESUME_FROM:
        # A path was configured but nothing is there: say so before falling back.
        print(f"Checkpoint not found: {RESUME_FROM}")
        print("  Starting fresh training instead")
    RESUME_FROM = None
    print("Starting fresh training")
else:
    print(f"Will resume from: {RESUME_FROM}")

In [None]:
# Report how much GPU memory is currently held by tensors, in mebibytes.
allocated_memory = torch.cuda.memory_allocated()
_allocated_mb = allocated_memory / (1024 ** 2)
print(f"CUDA memory allocated: {_allocated_mb:.2f} MB")

In [None]:
import os
import subprocess

# Release cached GPU memory before training. Everything is skipped on hosts
# without CUDA, where the stats/reset calls have nothing to act on.
if torch.cuda.is_available():
    torch.cuda.empty_cache()
    print("CUDA cache emptied.")

    torch.cuda.reset_peak_memory_stats()
    print("GPU memory stats reset.")

    # Run nvidia-smi without a shell and look at its exit status, so success is
    # only reported when the command actually succeeded.
    # NOTE(review): a GPU reset typically needs elevated privileges and an idle
    # GPU, so from inside a notebook that already uses the GPU it will usually
    # be refused — confirm whether this step is still wanted.
    try:
        _reset = subprocess.run(
            ['nvidia-smi', '--gpu-reset'],
            capture_output=True,
            text=True,
            check=False,
        )
    except FileNotFoundError:
        print("nvidia-smi not found; GPU reset skipped.")
    else:
        if _reset.returncode == 0:
            print("GPU reset command executed.")
        else:
            _detail = (_reset.stderr or _reset.stdout).strip()
            print(f"GPU reset failed (exit code {_reset.returncode}): {_detail}")
else:
    print("CUDA not available; skipping GPU cache cleanup and reset.")

In [None]:
# Same measurement as before the cleanup cell, to see what was released.
_BYTES_PER_MB = 1024 ** 2
allocated_memory = torch.cuda.memory_allocated()
print(f"CUDA memory allocated: {allocated_memory / _BYTES_PER_MB:.2f} MB")

### 9.2 Execute Training

In [None]:

# Launch (or resume) finetuning; the returned per-epoch history feeds the
# plotting cell below.
# NOTE(review): `epochs` is hard-coded to 1500 here, while the configuration
# cell defines and prints FINETUNE_EPOCHS = 100 — confirm which is intended.
_finetune_kwargs = {
    'model': Gr_finetune,
    'dataloader': finetune_loader,
    'optimizer': optimizer_G,
    'epochs': 1500,
    'save_dir': finetuned_dir,
    'resume_from': RESUME_FROM,
}
history = finetune_reenactment(**_finetune_kwargs)

**How Resume Works:**
- If `RESUME_FROM=None`: Starts fresh training from epoch 0
- If `RESUME_FROM` points to a checkpoint: Loads model weights, optimizer state, and continues from that epoch
- Training history is preserved and extended
- New checkpoints continue the epoch numbering and are written every `SAVE_EVERY` epochs, plus once at the final epoch (e.g., with `SAVE_EVERY = 200`, resuming from epoch 300 saves next at epochs 400, 600, etc.)

## 10. Visualize Training Progress

In [None]:
%matplotlib inline

# Plot the per-epoch training curves, but only if a training run has produced
# a non-empty `history` in this session.
if 'history' in dir() and history:
    fig, axes = plt.subplots(1, 3, figsize=(15, 4))

    # The first two panels each show a single curve:
    # (axes, history key, line format, y label, title).
    _single_curve_panels = (
        (axes[0], 'loss', 'b-', 'Total Loss', 'Total Loss'),
        (axes[1], 'loss_pixel', 'g-', 'Pixel L1 Loss', 'Pixel Loss'),
    )
    for _ax, _key, _fmt, _ylabel, _title in _single_curve_panels:
        _ax.plot(history[_key], _fmt, linewidth=2)
        _ax.set_xlabel('Epoch')
        _ax.set_ylabel(_ylabel)
        _ax.set_title(_title)
        _ax.grid(True, alpha=0.3)

    # The third panel overlays the two perceptual terms.
    _perceptual_ax = axes[2]
    _perceptual_ax.plot(history['loss_id'], 'r-', linewidth=2, label='Identity')
    _perceptual_ax.plot(history['loss_attr'], 'm-', linewidth=2, label='Attribute')
    _perceptual_ax.set_xlabel('Epoch')
    _perceptual_ax.set_ylabel('Perceptual Loss')
    _perceptual_ax.set_title('Perceptual Losses')
    _perceptual_ax.legend()
    _perceptual_ax.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(str(OUT_DIR / 'finetuning_curves.png'), dpi=150)
    plt.show()

    _last_total = history['loss'][-1]
    _last_pixel = history['loss_pixel'][-1]
    print(f"Final losses - Total: {_last_total:.4f}, Pixel: {_last_pixel:.4f}")

## 11. Test Finetuned Model with Full Pipeline

In [None]:
import random

def analyze_dataset_identities(dataset):
    """Group dataset indices by identity label.

    Args:
        dataset: indexable collection of ``(item, label)`` pairs; a label may
            be a plain value or a single-element tensor.

    Returns:
        dict mapping each label (tensors converted with ``.item()``) to the
        list of indices carrying it, in dataset order.
    """
    groups = {}
    for position in range(len(dataset)):
        _, raw_label = dataset[position]
        key = raw_label.item() if isinstance(raw_label, torch.Tensor) else raw_label
        groups.setdefault(key, []).append(position)
    return groups

import re

NUM_SWAPS = 5
FULL_SWAP_DIR = OUT_DIR / 'finetuned_full_swaps'
FULL_SWAP_DIR.mkdir(exist_ok=True, parents=True)


def _checkpoint_epoch(path):
    """Sort key placing checkpoints in epoch order (files without a number first)."""
    match = re.search(r'epoch(\d+)', path.stem)
    return (int(match.group(1)) if match else -1, path.name)


def _tensor_to_rgb_uint8(img):
    """Convert a CHW image tensor to an HWC uint8 array.

    NOTE(review): the mapping assumes values in [-1, 1] — confirm against the
    dataset's normalisation.
    """
    return ((img.permute(1, 2, 0).numpy() + 1) / 2 * 255).clip(0, 255).astype('uint8')


print("Loading finetuned model...")
finetuned_dir = OUT_DIR / 'finetuned_models'
# Order by the epoch number in the file name: a plain lexicographic sort would
# rank 'epoch500' after 'epoch1000' and pick the wrong "latest" checkpoint.
ckpts = sorted(finetuned_dir.glob('*.pth'), key=_checkpoint_epoch)

if not ckpts:
    print("No finetuned checkpoints found!")
else:
    latest_ckpt = ckpts[-1]
    print(f"Loading: {latest_ckpt.name}")

    # Rebuild the generator from the architecture string stored in the
    # checkpoint, then load the finetuned weights into it.
    ckpt_data = torch.load(str(latest_ckpt), map_location=device)
    arch = ckpt_data.get('arch', 'res_unet.MultiScaleResUNet(in_nc=101,out_nc=3)')
    Gr_finetuned = obj_factory(arch).to(device)
    Gr_finetuned.load_state_dict(ckpt_data['state_dict'])
    Gr_finetuned.eval()

    print(f"✓ Loaded epoch {ckpt_data.get('epoch', '?')}")

    # Get available identities
    train_identities = analyze_dataset_identities(train_dataset)
    available_labels = list(train_identities.keys())

    if len(available_labels) < 2:
        print("Need at least 2 identities!")
    else:
        print(f"\nPerforming {NUM_SWAPS} full pipeline face swaps...")

        all_sources = []
        all_targets = []
        all_results = []
        swap_info = []

        for i in range(NUM_SWAPS):
            # Two distinct identities: one provides the face, the other the pose.
            src_label, tgt_label = random.sample(available_labels, 2)
            src_idx = random.choice(train_identities[src_label])
            tgt_idx = random.choice(train_identities[tgt_label])

            src_img, _ = train_dataset[src_idx]
            tgt_img, _ = train_dataset[tgt_idx]

            # The pipeline works on image files, so both inputs are written to disk.
            src_path = FULL_SWAP_DIR / f'temp_src_{i}.png'
            tgt_path = FULL_SWAP_DIR / f'temp_tgt_{i}.png'
            out_path = FULL_SWAP_DIR / f'full_swap_{i}_id{src_label}_pose{tgt_label}.png'

            src_np = _tensor_to_rgb_uint8(src_img)
            tgt_np = _tensor_to_rgb_uint8(tgt_img)

            cv2.imwrite(str(src_path), cv2.cvtColor(src_np, cv2.COLOR_RGB2BGR))
            cv2.imwrite(str(tgt_path), cv2.cvtColor(tgt_np, cv2.COLOR_RGB2BGR))

            # Run full pipeline with FINETUNED MODEL
            try:
                result_bgr, intermediates, src_crop, tgt_crop = run_full_pipeline(
                    str(src_path),
                    str(tgt_path),
                    out_path=str(out_path),
                    reenact=True,
                    use_detector=True,
                    device=device,
                    crop_scale=1.2,
                    resolution=256,
                    G_model=Gr_finetuned
                )

                result_rgb = cv2.cvtColor(result_bgr, cv2.COLOR_BGR2RGB)

                all_sources.append(src_np)
                all_targets.append(tgt_np)
                all_results.append(result_rgb)
                swap_info.append(f"ID {src_label} → Pose {tgt_label}")

                print(f"  {i+1}/{NUM_SWAPS}: ID {src_label} → Pose {tgt_label} ✓")

            except Exception as e:
                # Best effort: one failed swap must not abort the others.
                print(f"  {i+1}/{NUM_SWAPS}: Error - {e}")

        print(f"\n✓ Done! Results in: {FULL_SWAP_DIR}")

        # Visualize results (only the swaps that succeeded)
        if all_results:
            n_swaps = len(all_results)
            # squeeze=False keeps `axes` two-dimensional even for a single row.
            fig, axes = plt.subplots(n_swaps, 3, figsize=(12, 4*n_swaps), squeeze=False)

            for i in range(n_swaps):
                axes[i, 0].imshow(all_sources[i])
                axes[i, 0].set_title(f'Source\n{swap_info[i].split("→")[0].strip()}', fontsize=11)
                axes[i, 0].axis('off')

                axes[i, 1].imshow(all_targets[i])
                axes[i, 1].set_title(f'Target\n{swap_info[i].split("→")[1].strip()}', fontsize=11)
                axes[i, 1].axis('off')

                axes[i, 2].imshow(all_results[i])
                axes[i, 2].set_title(f'Result\n{swap_info[i]}', fontsize=11)
                axes[i, 2].axis('off')

            plt.suptitle('Finetuned Model - Face Swap Results', fontsize=14, fontweight='bold')
            plt.tight_layout()

            viz_path = FULL_SWAP_DIR / 'all_swaps_visualization.png'
            plt.savefig(str(viz_path), dpi=150, bbox_inches='tight')
            print(f"\nSaved visualization: {viz_path}")

            plt.show()

## 12. Compare Base vs Finetuned Models

Compare face swap results between the pretrained base model and finetuned model.

In [None]:
# Load FaceNet for identity similarity measurement
print("\n4. Loading FaceNet for identity verification...")

try:
    from facenet_pytorch import InceptionResnetV1
    
    facenet = InceptionResnetV1(pretrained='vggface2').eval().to(device)
    
    for param in facenet.parameters():
        param.requires_grad = False
    
    print("   ✓ FaceNet loaded successfully")
    
    from torchvision import transforms
    facenet_transform = transforms.Compose([
        transforms.ToPILImage(),
        transforms.Resize((160, 160)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])
    
    def get_facenet_embedding(img_rgb):
        """Extract FaceNet embedding from RGB image (HWC numpy array)."""
        img_tensor = facenet_transform(img_rgb).unsqueeze(0).to(device)
        with torch.no_grad():
            embedding = facenet(img_tensor)
        return embedding
    
    def cosine_similarity(emb1, emb2):
        """Compute cosine similarity between two embeddings."""
        return F.cosine_similarity(emb1, emb2).item()
    
    print("   ✓ FaceNet helper functions defined")
    FACENET_AVAILABLE = True
    
except ImportError:
    print("   ⚠ facenet-pytorch not installed. Run: pip install facenet-pytorch")
    print("   Skipping identity similarity metrics...")
    FACENET_AVAILABLE = False
except Exception as e:
    print(f"   ⚠ Error loading FaceNet: {e}")
    print("   Skipping identity similarity metrics...")
    FACENET_AVAILABLE = False

print("="*60)

In [None]:
import random


import re

NUM_COMPARISONS = 10
COMPARISON_DIR = OUT_DIR / 'model_comparison'
COMPARISON_DIR.mkdir(exist_ok=True, parents=True)


def _checkpoint_epoch(path):
    """Sort key placing checkpoints in epoch order (files without a number first)."""
    match = re.search(r'epoch(\d+)', path.stem)
    return (int(match.group(1)) if match else -1, path.name)


print("="*60)
print("MODEL COMPARISON: Base vs Finetuned")
print("="*60)

# Load base pretrained model
print("\n1. Loading BASE pretrained model...")
reenact_base_w = WEIGHTS_DIR / 'nfv_msrunet_256_1_2_reenactment_v2.1.pth'
Gr_base, ckpt_base = load_model(str(reenact_base_w), 'reenactment', device=device, return_checkpoint=True)
Gr_base.eval()
print(f"Base model loaded: {ckpt_base.get('arch', 'unknown')}")

# Load finetuned model
print("\n2. Loading FINETUNED model...")
finetuned_checkpoint = OUT_DIR / 'finetuned_models' / 'reenact_finetuned_epoch500.pth'

if not finetuned_checkpoint.exists():
    print(f"Checkpoint not found: {finetuned_checkpoint}")
    print("Looking for any available checkpoint...")
    # Fall back to the checkpoint with the highest epoch number. Sorting by the
    # parsed number matters: lexicographically 'epoch500' sorts after 'epoch1000'.
    available_ckpts = sorted((OUT_DIR / 'finetuned_models').glob('*.pth'), key=_checkpoint_epoch)
    if available_ckpts:
        finetuned_checkpoint = available_ckpts[-1]
        print(f"   Using: {finetuned_checkpoint.name}")
    else:
        raise FileNotFoundError("No finetuned checkpoints found!")

# Rebuild the generator from the stored architecture string, then load weights.
ckpt_finetuned = torch.load(str(finetuned_checkpoint), map_location=device)
arch = ckpt_finetuned.get('arch', 'res_unet.MultiScaleResUNet(in_nc=101,out_nc=3)')
Gr_finetuned_cmp = obj_factory(arch).to(device)
Gr_finetuned_cmp.load_state_dict(ckpt_finetuned['state_dict'])
Gr_finetuned_cmp.eval()
print(f"Finetuned model loaded from epoch {ckpt_finetuned.get('epoch', '?')}")

# Get available identities from dataset
def analyze_dataset_identities(dataset):
    """Analyze dataset to find all identities.

    Returns a dict mapping each label (tensors converted with ``.item()``) to
    the list of dataset indices carrying it. Defined here as well so this cell
    can run on its own.
    """
    label_to_indices = {}
    for idx in range(len(dataset)):
        _, label = dataset[idx]
        if isinstance(label, torch.Tensor):
            label = label.item()
        label_to_indices.setdefault(label, []).append(idx)
    return label_to_indices

train_identities = analyze_dataset_identities(train_dataset)
available_labels = list(train_identities.keys())

print("\n3. Dataset info:")
print(f"Available identities: {len(available_labels)}")
print(f"Will perform {NUM_COMPARISONS} swaps")
print("="*60)

In [None]:

# Run the same randomly chosen swaps through the base and the finetuned
# generator and, when FaceNet is available, score how close each result's
# embedding is to the source image's embedding.
print("\nPerforming face swaps with both models...\n")

# One entry per comparison, in loop order; consumed by the plotting cells.
all_sources = []
all_targets = []
all_base_results = []
all_finetuned_results = []
swap_info = []

# Metrics storage
# Each list receives exactly one value per comparison when FaceNet is
# available (0.0 is recorded for a failed run), so the three stay aligned.
metrics = {
    'base_similarity': [],
    'finetuned_similarity': [],
    'improvement': []           # Similarity deltas
}

for i in range(NUM_COMPARISONS):
    # Two distinct identities: one provides the face, the other the pose.
    src_label, tgt_label = random.sample(available_labels, 2)
    src_idx = random.choice(train_identities[src_label])
    tgt_idx = random.choice(train_identities[tgt_label])
    
    src_img, _ = train_dataset[src_idx]
    tgt_img, _ = train_dataset[tgt_idx]
    
    # The pipeline takes file paths, so both inputs are written to disk first.
    src_path = COMPARISON_DIR / f'temp_src_{i}.png'
    tgt_path = COMPARISON_DIR / f'temp_tgt_{i}.png'
    
    # CHW tensor -> HWC uint8 array.
    # NOTE(review): the (x + 1) / 2 mapping assumes values in [-1, 1] — confirm
    # against the dataset's normalisation.
    src_np = ((src_img.permute(1, 2, 0).numpy() + 1) / 2 * 255).clip(0, 255).astype('uint8')
    tgt_np = ((tgt_img.permute(1, 2, 0).numpy() + 1) / 2 * 255).clip(0, 255).astype('uint8')
    
    cv2.imwrite(str(src_path), cv2.cvtColor(src_np, cv2.COLOR_RGB2BGR))
    cv2.imwrite(str(tgt_path), cv2.cvtColor(tgt_np, cv2.COLOR_RGB2BGR))
    
    print(f"[{i+1}/{NUM_COMPARISONS}] Identity {src_label} -> Pose {tgt_label}")
    

    # Reference embedding: both results are compared against the source image.
    if FACENET_AVAILABLE:
        src_embedding = get_facenet_embedding(src_np)
    

    # --- Base model ---
    try:
        base_out_path = COMPARISON_DIR / f'swap_{i}_base_id{src_label}_pose{tgt_label}.png'
        result_base_bgr, _, _, _ = run_full_pipeline(
            str(src_path), 
            str(tgt_path), 
            out_path=str(base_out_path),
            reenact=True, 
            use_detector=True, 
            device=device,
            crop_scale=1.2, 
            resolution=256,
            G_model=Gr_base
        )
        result_base_rgb = cv2.cvtColor(result_base_bgr, cv2.COLOR_BGR2RGB)
        

        if FACENET_AVAILABLE:
            base_embedding = get_facenet_embedding(result_base_rgb)
            base_sim = cosine_similarity(src_embedding, base_embedding)
            metrics['base_similarity'].append(base_sim)
            print(f"  ✓ Base model (similarity: {base_sim:.4f})")
        else:
            print(f"  ✓ Base model")
            
    except Exception as e:
        # A failed run is shown as a black image and scored 0.0. This keeps the
        # lists aligned, but it also lowers the reported base average.
        print(f"  ✗ Base model failed: {e}")
        result_base_rgb = np.zeros_like(src_np)
        if FACENET_AVAILABLE:
            metrics['base_similarity'].append(0.0)
    

    # --- Finetuned model (same inputs) ---
    try:
        finetuned_out_path = COMPARISON_DIR / f'swap_{i}_finetuned_id{src_label}_pose{tgt_label}.png'
        result_finetuned_bgr, _, _, _ = run_full_pipeline(
            str(src_path), 
            str(tgt_path), 
            out_path=str(finetuned_out_path),
            reenact=True, 
            use_detector=True, 
            device=device,
            crop_scale=1.2, 
            resolution=256,
            G_model=Gr_finetuned_cmp
        )
        result_finetuned_rgb = cv2.cvtColor(result_finetuned_bgr, cv2.COLOR_BGR2RGB)
        
        # Compute identity similarity
        if FACENET_AVAILABLE:
            finetuned_embedding = get_facenet_embedding(result_finetuned_rgb)
            finetuned_sim = cosine_similarity(src_embedding, finetuned_embedding)
            metrics['finetuned_similarity'].append(finetuned_sim)
            
            # Calculate improvement
            # The last base entry belongs to this same swap. If the base run
            # failed it is the 0.0 placeholder, so the delta is then measured
            # against 0.0 rather than against a real score.
            improvement = finetuned_sim - metrics['base_similarity'][-1]
            metrics['improvement'].append(improvement)
            
            improvement_str = f"+{improvement:.4f}" if improvement > 0 else f"{improvement:.4f}"
            print(f"Finetuned model (similarity: {finetuned_sim:.4f}, Δ: {improvement_str})")
        else:
            print(f"Finetuned model")
            
    except Exception as e:
        # Same placeholder policy as for the base model; the improvement is
        # recorded as 0.0 for a failed finetuned run.
        print(f"  ✗ Finetuned model failed: {e}")
        result_finetuned_rgb = np.zeros_like(src_np)
        if FACENET_AVAILABLE:
            metrics['finetuned_similarity'].append(0.0)
            metrics['improvement'].append(0.0)
    
    # Record the comparison whether or not the runs succeeded, so every list
    # has NUM_COMPARISONS entries for the visualisation cells.
    all_sources.append(src_np)
    all_targets.append(tgt_np)
    all_base_results.append(result_base_rgb)
    all_finetuned_results.append(result_finetuned_rgb)
    swap_info.append(f"ID {src_label} → Pose {tgt_label}")
    print()

print(f"Completed {NUM_COMPARISONS} comparisons")
print(f"Results saved to: {COMPARISON_DIR}")

# Print summary statistics
# The means and standard deviations include the 0.0 placeholders of failed runs.
if FACENET_AVAILABLE and metrics['base_similarity']:
    print("\n" + "="*60)
    print("IDENTITY SIMILARITY METRICS (FaceNet Cosine Similarity)")
    print("="*60)
    print(f"Base Model Average:      {np.mean(metrics['base_similarity']):.4f} ± {np.std(metrics['base_similarity']):.4f}")
    print(f"Finetuned Model Average: {np.mean(metrics['finetuned_similarity']):.4f} ± {np.std(metrics['finetuned_similarity']):.4f}")
    print(f"Average Improvement:     {np.mean(metrics['improvement']):.4f}")
    print(f"Improvements > 0:        {sum(1 for x in metrics['improvement'] if x > 0)}/{len(metrics['improvement'])} swaps")
    print("="*60)

In [None]:
%matplotlib inline

# Three summary panels for the identity metrics: per-swap similarity of both
# models, per-swap delta, and the overall distribution of each model.
if FACENET_AVAILABLE and metrics['base_similarity']:

    fig, axes = plt.subplots(1, 3, figsize=(18, 5))

    # Size the x axis from the recorded metrics rather than NUM_COMPARISONS, so
    # the plot still works if that constant was edited after the swaps ran.
    x = np.arange(len(metrics['base_similarity']))
    width = 0.35

    # Panel 1: side-by-side bars, one pair per swap.
    axes[0].bar(x - width/2, metrics['base_similarity'], width, label='Base Model', color='steelblue', alpha=0.8)
    axes[0].bar(x + width/2, metrics['finetuned_similarity'], width, label='Finetuned Model', color='seagreen', alpha=0.8)
    axes[0].set_xlabel('Swap Index', fontsize=12)
    axes[0].set_ylabel('Cosine Similarity', fontsize=12)
    axes[0].set_title('Identity Preservation per Swap', fontsize=14, fontweight='bold')
    axes[0].set_xticks(x)
    axes[0].legend()
    axes[0].grid(True, alpha=0.3, axis='y')
    axes[0].set_ylim([0, 1])

    # Panel 2: finetuned minus base; green where the finetuned model scored higher.
    colors = ['green' if delta > 0 else 'red' for delta in metrics['improvement']]
    axes[1].bar(x, metrics['improvement'], color=colors, alpha=0.7)
    axes[1].axhline(y=0, color='black', linestyle='--', linewidth=1)
    axes[1].set_xlabel('Swap Index', fontsize=12)
    axes[1].set_ylabel('Similarity Improvement (delta)', fontsize=12)
    axes[1].set_title('Finetuned vs Base Improvement', fontsize=14, fontweight='bold')
    axes[1].set_xticks(x)
    axes[1].grid(True, alpha=0.3, axis='y')

    # Panel 3: box plots of both score distributions. Tick labels are set on
    # the axes afterwards because boxplot's `labels=` keyword is deprecated in
    # recent matplotlib releases; this form works on old and new versions.
    box_data = [metrics['base_similarity'], metrics['finetuned_similarity']]
    bp = axes[2].boxplot(box_data, patch_artist=True,
                          boxprops=dict(facecolor='lightblue', alpha=0.7),
                          medianprops=dict(color='red', linewidth=2))
    axes[2].set_xticklabels(['Base', 'Finetuned'])
    axes[2].set_ylabel('Cosine Similarity', fontsize=12)
    axes[2].set_title('Overall Distribution', fontsize=14, fontweight='bold')
    axes[2].grid(True, alpha=0.3, axis='y')
    axes[2].set_ylim([0, 1])

    plt.tight_layout()

    metrics_plot_path = COMPARISON_DIR / 'identity_similarity_metrics.png'
    plt.savefig(str(metrics_plot_path), dpi=150, bbox_inches='tight')
    print(f"\n✓ Saved metrics visualization: {metrics_plot_path}")

    plt.show()

In [None]:
%matplotlib inline

# One row per comparison: source, target, base result, finetuned result.
# squeeze=False keeps `axes` two-dimensional even when there is a single row.
fig, axes = plt.subplots(
    NUM_COMPARISONS, 4, figsize=(16, 4*NUM_COMPARISONS), squeeze=False
)

for i in range(NUM_COMPARISONS):
    src_ax, tgt_ax, base_ax, ft_ax = axes[i, 0], axes[i, 1], axes[i, 2], axes[i, 3]

    # Column 0: source image, captioned with the "ID ..." half of the label.
    src_ax.imshow(all_sources[i])
    src_ax.set_title(f'Source\n{swap_info[i].split("→")[0].strip()}', fontsize=12, fontweight='bold')
    src_ax.axis('off')

    # Column 1: target image, captioned with the "Pose ..." half of the label.
    tgt_ax.imshow(all_targets[i])
    tgt_ax.set_title(f'Target\n{swap_info[i].split("→")[1].strip()}', fontsize=12, fontweight='bold')
    tgt_ax.axis('off')

    # Column 2: base model result, with its similarity score when one exists.
    base_ax.imshow(all_base_results[i])
    title = f'Base Model\n{swap_info[i]}'
    if FACENET_AVAILABLE and i < len(metrics['base_similarity']):
        sim_score = metrics['base_similarity'][i]
        title = f'Base Model\n{swap_info[i]}\nSimilarity: {sim_score:.3f}'
    base_ax.set_title(title, fontsize=11, color='blue')
    base_ax.axis('off')

    # Column 3: finetuned result; the title colour reflects the sign of the
    # delta when metrics exist, and is plain green otherwise.
    ft_ax.imshow(all_finetuned_results[i])
    title = f'Finetuned Model\n{swap_info[i]}'
    title_color = 'green'
    if FACENET_AVAILABLE and i < len(metrics['finetuned_similarity']):
        sim_score = metrics['finetuned_similarity'][i]
        improvement = metrics['improvement'][i]
        if improvement > 0:
            improvement_str = f"+{improvement:.3f}"
            title_color = 'darkgreen'
        else:
            improvement_str = f"{improvement:.3f}"
            title_color = 'darkorange'
        title = f'Finetuned Model\n{swap_info[i]}\nSimilarity: {sim_score:.3f} (Δ{improvement_str})'
    ft_ax.set_title(title, fontsize=11, color=title_color)
    ft_ax.axis('off')

plt.suptitle('Model Comparison: Base vs Finetuned Face Swaps with Identity Metrics',
             fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()

# Save before showing so the file contains the rendered figure.
viz_path = COMPARISON_DIR / 'comparison_visualization.png'
plt.savefig(str(viz_path), dpi=150, bbox_inches='tight')
print(f"\n✓ Saved comparison visualization: {viz_path}")

plt.show()

### Analysis

In [None]:


import random

import re

# Number of random cross-person pairs to evaluate, and the folder their images are written to.
NUM_REENACT_TESTS = 5
REENACT_COMPARISON_DIR = OUT_DIR / 'reenactment_only_comparison'
REENACT_COMPARISON_DIR.mkdir(exist_ok=True, parents=True)

print("="*60)
print("REENACTMENT-ONLY COMPARISON: Base vs Finetuned")
print("="*60)


def _checkpoint_sort_key(ckpt_path):
    """Order checkpoints by the last integer in their file stem, then by name.

    A plain lexicographic sort ranks 'epoch500' after 'epoch1000', so taking
    the last element would not yield the highest-numbered checkpoint. Files
    whose stem contains no number sort first.
    NOTE(review): assumes the last number in the file name is the epoch;
    confirm against the code that saves the checkpoints.
    """
    numbers = re.findall(r'\d+', ckpt_path.stem)
    return (int(numbers[-1]) if numbers else -1, ckpt_path.name)


# Each model is loaded only when its name is not already defined in the
# session, so re-running this cell does not reload the weights.
if 'Gr_base' not in dir():
    print("\n1. Loading BASE reenactment model...")
    reenact_base_w = WEIGHTS_DIR / 'nfv_msrunet_256_1_2_reenactment_v2.1.pth'
    Gr_base, ckpt_base = load_model(str(reenact_base_w), 'reenactment', device=device, return_checkpoint=True)
    Gr_base.eval()
    print(f"   ✓ Base model loaded: {ckpt_base.get('arch', 'unknown')}")

if 'Gr_finetuned_cmp' not in dir():
    print("\n2. Loading FINETUNED reenactment model...")
    finetuned_dir = OUT_DIR / 'finetuned_models'
    finetuned_checkpoint = finetuned_dir / 'reenact_finetuned_epoch500.pth'

    if not finetuned_checkpoint.exists():
        print(f"Checkpoint not found: {finetuned_checkpoint}")
        print("   Looking for any available checkpoint...")
        # Fall back to the checkpoint with the highest number in its name.
        available_ckpts = sorted(finetuned_dir.glob('*.pth'), key=_checkpoint_sort_key)
        if available_ckpts:
            finetuned_checkpoint = available_ckpts[-1]
            print(f"   Using: {finetuned_checkpoint.name}")
        else:
            raise FileNotFoundError("No finetuned checkpoints found!")

    # The checkpoint is read as a dict: 'state_dict' is required, while 'arch'
    # and 'epoch' are optional and fall back to defaults below.
    ckpt_finetuned = torch.load(str(finetuned_checkpoint), map_location=device)
    arch = ckpt_finetuned.get('arch', 'res_unet.MultiScaleResUNet(in_nc=101,out_nc=3)')
    Gr_finetuned_cmp = obj_factory(arch).to(device)
    Gr_finetuned_cmp.load_state_dict(ckpt_finetuned['state_dict'])
    Gr_finetuned_cmp.eval()
    print(f"Finetuned model loaded from epoch {ckpt_finetuned.get('epoch', '?')}")


if 'L_frozen' not in dir():
    print("\n3. Loading landmarks model...")
    lms_w = WEIGHTS_DIR / 'hr18_wflw_landmarks.pth'
    L_frozen, _ = load_model(str(lms_w), 'landmarks', device=device, return_checkpoint=True)
    L_frozen.eval()
    # The landmarks network is used for inference only, so gradients are disabled.
    for param in L_frozen.parameters():
        param.requires_grad = False
    print("Landmarks model loaded")

print(f"\n4. Dataset: {len(available_labels)} identities available")
print(f"Will perform {NUM_REENACT_TESTS} cross-person reenactment tests")
print("="*60)


# Per-test images collected for the comparison grid, plus a text label per test.
reenact_sources, reenact_targets = [], []
reenact_base_outputs, reenact_finetuned_outputs = [], []
reenact_info = []

# One list per metric; identity metrics are filled only when FaceNet is available.
reenact_metrics = {
    metric_name: []
    for metric_name in (
        'base_similarity',
        'finetuned_similarity',
        'improvement',
        'pixel_l1_base',
        'pixel_l1_finetuned',
    )
}


def _to_uint8_image(img):
    """Return a CHW tensor as an HWC uint8 array, rescaled with (x + 1) / 2 * 255 and clipped to [0, 255]."""
    hwc = img.permute(1, 2, 0).cpu().numpy()
    return ((hwc + 1) / 2 * 255).clip(0, 255).astype('uint8')


def _final_output(model_output):
    """Return the last element when the generator returns a list/tuple, otherwise the output itself."""
    if isinstance(model_output, (list, tuple)):
        return model_output[-1]
    return model_output


print("\nPerforming direct reenactment tests (cross-person)...\n")

with torch.no_grad():
    for i in range(NUM_REENACT_TESTS):
        # Draw two distinct identities, then one training sample from each.
        src_label, tgt_label = random.sample(available_labels, 2)
        src_idx = random.choice(train_identities[src_label])
        tgt_idx = random.choice(train_identities[tgt_label])

        src_img, _ = train_dataset[src_idx]
        tgt_img, _ = train_dataset[tgt_idx]
        src_batch = src_img.unsqueeze(0).to(device)
        tgt_batch = tgt_img.unsqueeze(0).to(device)

        print(f"[{i+1}/{NUM_REENACT_TESTS}] Source ID {src_label} -> Target ID {tgt_label} (cross-person)")

        input_list, tgt_normalized = prepare_reenactment_input(
            src_batch, tgt_batch, L_frozen, n_levels, device
        )

        src_np = _to_uint8_image(src_img)
        tgt_np = _to_uint8_image(tgt_img)

        # Identity is scored against the source face embedding.
        if FACENET_AVAILABLE:
            src_embedding = get_facenet_embedding(src_np)

        # --- Base model ---
        output_base = Gr_base(input_list)
        pred_base = _final_output(output_base)
        pred_base_np = tensor2bgr(pred_base[0])
        pred_base_rgb = cv2.cvtColor(pred_base_np, cv2.COLOR_BGR2RGB)

        if FACENET_AVAILABLE:
            base_emb = get_facenet_embedding(pred_base_rgb)
            base_sim = cosine_similarity(src_embedding, base_emb)
            reenact_metrics['base_similarity'].append(base_sim)

        pixel_l1_base = F.l1_loss(pred_base, tgt_normalized).item()
        reenact_metrics['pixel_l1_base'].append(pixel_l1_base)

        # --- Finetuned model, fed the same inputs ---
        output_finetuned = Gr_finetuned_cmp(input_list)
        pred_finetuned = _final_output(output_finetuned)
        pred_finetuned_np = tensor2bgr(pred_finetuned[0])
        pred_finetuned_rgb = cv2.cvtColor(pred_finetuned_np, cv2.COLOR_BGR2RGB)

        if FACENET_AVAILABLE:
            finetuned_emb = get_facenet_embedding(pred_finetuned_rgb)
            finetuned_sim = cosine_similarity(src_embedding, finetuned_emb)
            reenact_metrics['finetuned_similarity'].append(finetuned_sim)

            improvement = finetuned_sim - base_sim
            reenact_metrics['improvement'].append(improvement)

            improvement_str = f"{'+' if improvement > 0 else ''}{improvement:.4f}"
            print(f"  Identity similarity - Base: {base_sim:.4f}, Finetuned: {finetuned_sim:.4f}, Δ: {improvement_str}")

        pixel_l1_finetuned = F.l1_loss(pred_finetuned, tgt_normalized).item()
        reenact_metrics['pixel_l1_finetuned'].append(pixel_l1_finetuned)

        # The printed delta is the change in loss, so a reduction shows a minus sign.
        pixel_improvement = pixel_l1_base - pixel_l1_finetuned
        pixel_sign = "-" if pixel_improvement > 0 else "+"
        pixel_improvement_str = f"{pixel_sign}{abs(pixel_improvement):.4f}"
        print(f"  Pixel L1 loss     - Base: {pixel_l1_base:.4f}, Finetuned: {pixel_l1_finetuned:.4f}, Δ: {pixel_improvement_str}")

        reenact_sources.append(src_np)
        reenact_targets.append(tgt_np)
        reenact_base_outputs.append(pred_base_rgb)
        reenact_finetuned_outputs.append(pred_finetuned_rgb)
        reenact_info.append(f"ID {src_label} → ID {tgt_label}")

        # Save all four images in BGR channel order for cv2.imwrite.
        images_to_write = (
            (f'reenact_{i}_source_id{src_label}.png', cv2.cvtColor(src_np, cv2.COLOR_RGB2BGR)),
            (f'reenact_{i}_target_id{tgt_label}.png', cv2.cvtColor(tgt_np, cv2.COLOR_RGB2BGR)),
            (f'reenact_{i}_base_id{src_label}_to_id{tgt_label}.png', pred_base_np),
            (f'reenact_{i}_finetuned_id{src_label}_to_id{tgt_label}.png', pred_finetuned_np),
        )
        for file_name, bgr_image in images_to_write:
            cv2.imwrite(str(REENACT_COMPARISON_DIR / file_name), bgr_image)

        print()

print(f"Completed {len(reenact_sources)} cross-person reenactment tests")
print(f"Results saved to: {REENACT_COMPARISON_DIR}")

# Print aggregate statistics over all reenactment tests collected in reenact_metrics.
print("\n" + "="*60)
print("REENACTMENT METRICS SUMMARY (Cross-Person)")
print("="*60)

# Identity metrics exist only when FaceNet was available during the tests.
if FACENET_AVAILABLE and reenact_metrics['base_similarity']:
    print("\nIdentity Preservation (FaceNet Cosine Similarity):")
    print(f"  Base Model:      {np.mean(reenact_metrics['base_similarity']):.4f} ± {np.std(reenact_metrics['base_similarity']):.4f}")
    print(f"  Finetuned Model: {np.mean(reenact_metrics['finetuned_similarity']):.4f} ± {np.std(reenact_metrics['finetuned_similarity']):.4f}")
    print(f"  Avg Improvement: {np.mean(reenact_metrics['improvement']):.4f}")
    print(f"  Better results:  {sum(1 for x in reenact_metrics['improvement'] if x > 0)}/{len(reenact_metrics['improvement'])} tests")

if reenact_metrics['pixel_l1_base']:
    print("\nPixel Reconstruction (L1 Loss with Target):")
    print(f"  Base Model:      {np.mean(reenact_metrics['pixel_l1_base']):.4f} ± {np.std(reenact_metrics['pixel_l1_base']):.4f}")
    print(f"  Finetuned Model: {np.mean(reenact_metrics['pixel_l1_finetuned']):.4f} ± {np.std(reenact_metrics['pixel_l1_finetuned']):.4f}")

    # Improvement is base minus finetuned loss, so a positive value means the
    # finetuned model reached a lower L1. The printed label states that
    # convention explicitly, matching the "Better results" count below.
    pixel_improvements = [
        base_l1 - finetuned_l1
        for base_l1, finetuned_l1 in zip(reenact_metrics['pixel_l1_base'],
                                         reenact_metrics['pixel_l1_finetuned'])
    ]
    print(f"  Avg Improvement: {np.mean(pixel_improvements):.4f} (positive = finetuned has lower L1)")
    print(f"  Better results:  {sum(1 for x in pixel_improvements if x > 0)}/{len(pixel_improvements)} tests")

print("="*60)

In [None]:

%matplotlib inline

# Visualise the reenactment tests: an image grid with one row per test, then,
# when identity metrics exist, a four-panel metrics figure. Both are saved
# under REENACT_COMPARISON_DIR before being shown.
if reenact_sources:
    n_tests = len(reenact_sources)

    # Columns: source, target (driving), base output, finetuned output.
    fig, axes = plt.subplots(n_tests, 4, figsize=(16, 4*n_tests))

    # With a single row plt.subplots returns a 1-D array; restore the row axis.
    if n_tests == 1:
        axes = axes.reshape(1, -1)

    for i in range(n_tests):

        axes[i, 0].imshow(reenact_sources[i])
        axes[i, 0].set_title(f'Source\n{reenact_info[i]}', fontsize=12, fontweight='bold')
        axes[i, 0].axis('off')

        axes[i, 1].imshow(reenact_targets[i])
        axes[i, 1].set_title(f'Target (Driving)\n{reenact_info[i]}', fontsize=12, fontweight='bold')
        axes[i, 1].axis('off')

        # Metric titles are used only when identity scores exist for this row.
        axes[i, 2].imshow(reenact_base_outputs[i])
        if FACENET_AVAILABLE and i < len(reenact_metrics['base_similarity']):
            sim = reenact_metrics['base_similarity'][i]
            l1 = reenact_metrics['pixel_l1_base'][i]
            title = f'Base Reenactment\nID Sim: {sim:.3f} | L1: {l1:.4f}'
        else:
            title = f'Base Reenactment\n{reenact_info[i]}'
        axes[i, 2].set_title(title, fontsize=11, color='blue')
        axes[i, 2].axis('off')

        axes[i, 3].imshow(reenact_finetuned_outputs[i])
        if FACENET_AVAILABLE and i < len(reenact_metrics['finetuned_similarity']):
            sim = reenact_metrics['finetuned_similarity'][i]
            l1 = reenact_metrics['pixel_l1_finetuned'][i]
            imp = reenact_metrics['improvement'][i]
            imp_str = f"+{imp:.3f}" if imp > 0 else f"{imp:.3f}"
            title = f'Finetuned Reenactment\nID Sim: {sim:.3f} (Δ{imp_str}) | L1: {l1:.4f}'
            # Title colour reflects the sign of the identity-similarity change.
            title_color = 'darkgreen' if imp > 0 else 'darkorange'
        else:
            title = f'Finetuned Reenactment\n{reenact_info[i]}'
            title_color = 'green'
        axes[i, 3].set_title(title, fontsize=11, color=title_color)
        axes[i, 3].axis('off')

    plt.suptitle('Direct Reenactment Comparison: Base vs Finetuned (No Pipeline)', 
                 fontsize=16, fontweight='bold', y=0.995)
    plt.tight_layout()

    viz_path = REENACT_COMPARISON_DIR / 'reenactment_comparison_grid.png'
    plt.savefig(str(viz_path), dpi=150, bbox_inches='tight')
    print(f"\nSaved visualization: {viz_path}")

    plt.show()

    if FACENET_AVAILABLE and reenact_metrics['base_similarity']:
        fig, axes = plt.subplots(1, 4, figsize=(20, 5))

        # Bar positions: one slot per test, base and finetuned bars side by side.
        x = np.arange(n_tests)
        width = 0.35

        # Panel 0: identity similarity per test.
        axes[0].bar(x - width/2, reenact_metrics['base_similarity'], width, 
                   label='Base', color='steelblue', alpha=0.8)
        axes[0].bar(x + width/2, reenact_metrics['finetuned_similarity'], width, 
                   label='Finetuned', color='seagreen', alpha=0.8)
        axes[0].set_xlabel('Test Index', fontsize=12)
        axes[0].set_ylabel('Identity Similarity', fontsize=12)
        axes[0].set_title('Identity Preservation', fontsize=14, fontweight='bold')
        axes[0].legend()
        axes[0].grid(True, alpha=0.3, axis='y')
        axes[0].set_ylim([0, 1])

        # Panel 1: pixel L1 loss per test.
        axes[1].bar(x - width/2, reenact_metrics['pixel_l1_base'], width, 
                   label='Base', color='coral', alpha=0.8)
        axes[1].bar(x + width/2, reenact_metrics['pixel_l1_finetuned'], width, 
                   label='Finetuned', color='lightgreen', alpha=0.8)
        axes[1].set_xlabel('Test Index', fontsize=12)
        axes[1].set_ylabel('L1 Loss', fontsize=12)
        axes[1].set_title('Pixel Reconstruction (Lower = Better)', fontsize=14, fontweight='bold')
        axes[1].legend()
        axes[1].grid(True, alpha=0.3, axis='y')

        # Panel 2: identity-similarity delta per test, coloured by sign.
        # The comprehension variable is named `delta` so it does not shadow
        # the bar-position array `x` used on the next line.
        colors = ['green' if delta > 0 else 'red' for delta in reenact_metrics['improvement']]
        axes[2].bar(x, reenact_metrics['improvement'], color=colors, alpha=0.7)
        axes[2].axhline(y=0, color='black', linestyle='--', linewidth=1)
        axes[2].set_xlabel('Test Index', fontsize=12)
        axes[2].set_ylabel('Identity Similarity delta', fontsize=12)
        axes[2].set_title('Identity Improvement', fontsize=14, fontweight='bold')
        axes[2].grid(True, alpha=0.3, axis='y')

        # Panel 3: distribution of identity similarity for both models.
        box_data = [
            reenact_metrics['base_similarity'],
            reenact_metrics['finetuned_similarity']
        ]
        # Tick labels are set on the axes rather than through boxplot's
        # `labels` keyword, which is deprecated since Matplotlib 3.9 (renamed
        # to `tick_labels`); set_xticklabels works on old and new versions.
        bp = axes[3].boxplot(box_data,
                            patch_artist=True,
                            boxprops=dict(facecolor='lightblue', alpha=0.7),
                            medianprops=dict(color='red', linewidth=2))
        axes[3].set_xticklabels(['Base', 'Finetuned'])
        axes[3].set_ylabel('Identity Similarity', fontsize=12)
        axes[3].set_title('Distribution', fontsize=14, fontweight='bold')
        axes[3].grid(True, alpha=0.3, axis='y')
        axes[3].set_ylim([0, 1])

        plt.tight_layout()

        metrics_viz_path = REENACT_COMPARISON_DIR / 'reenactment_metrics.png'
        plt.savefig(str(metrics_viz_path), dpi=150, bbox_inches='tight')
        print(f"Saved metrics visualization: {metrics_viz_path}")

        plt.show()