# Conditional Face Synthesis with Embedding-Conditioned Generative Models

This notebook demonstrates how I train a generative model in PyTorch to synthesize human faces conditioned on embeddings from a face encoder (FaceNet).  

---

## Project Overview

- **Goal:**  
  Train a generative model that takes face embeddings as input and produces realistic 128×128 face images. The embeddings are obtained from a pre-trained or fine-tuned FaceNet encoder.
- **Zero-Shot Generalization:**  
  The model is evaluated on its ability to generate high-quality faces from unseen embeddings.
- **Architecture Choices:**  
  - **Encoder:** FaceNet (pre-trained on VGGFace2, optionally fine-tuned)
  - **Generator:** UNet-style GAN with skip connections and self-attention
  - **Discriminator:** Projection-based conditional discriminator (inspired by BigGAN)
- **Metrics:**  
  I use FID, SSIM, PSNR, and LPIPS to quantitatively assess image quality and identity preservation.
- **Experiment Tracking:**  
  All training runs, metrics, and sample generations are logged to Weights & Biases (W&B) for transparency and reproducibility.
- **Model Sharing:**  
  Trained model checkpoints are uploaded to the Hugging Face Hub for public access and future inference.



# Setup: Install Required Dependencies

This cell installs all the necessary libraries for face reconstruction, GANs, and evaluation.  
Run this cell only once per Colab or Jupyter session.

In [None]:
# Install all required libraries
!pip install facenet-pytorch torch torchvision wandb huggingface-hub "torchmetrics[image]" lpips torch-fidelity timm --quiet

# Imports

All necessary libraries and modules are imported here.  
This cell should be run before any other code cells to ensure all dependencies are available in the environment.

In [None]:
# Core imports and utilities
import os
from datetime import datetime
import zipfile

# PyTorch and TorchVision
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import transforms, datasets, utils
from torchvision.utils import make_grid

# Pretrained Face Encoder
from facenet_pytorch import InceptionResnetV1

# Experiment tracking & cloud
import wandb
from huggingface_hub import login, hf_hub_download
from google.colab import userdata

# Metrics and evaluation
from torchmetrics.image.fid import FrechetInceptionDistance
from lpips import LPIPS

# Learning rate scheduling
from torch.optim.lr_scheduler import ExponentialLR

# Base Configuration

All key hyperparameters and paths are defined in a single configuration dictionary for easy management and reproducibility.

In [None]:
BASE_CONFIG = {
    # Data settings
    "data_path": "./data/images",
    "image_size": 128,
    "channels": 3,
    "embedding_dim": 512,

    # Training settings
    "batch_size": 64,
    "num_epochs": 300,

    # Optimizer settings
    "lr_generator": 2e-4,
    "lr_discriminator": 1e-4,
    "lr_finetune_facenet": 1e-5,
    "weight_decay": 0,

    # Loss weights
    "lambda_emb": 1.0,
    "lambda_lpips": 0.8,
    "r1_gamma": 10.0,

    # Fine-tuning
    "finetune_facenet": True,

    # Device
    "device": "cuda" if torch.cuda.is_available() else "cpu",

    # Logging & Hugging Face Hub
    "wandb_project": "conditional-gan-facenet-exp-2",
    "wandb_run_name": "unet_conditional_gan_facenet-exp-2",
    "hf_repo_id": "Mayank022/conditional-gan-facenet-exp-2",
}

# Download and Extract Cropped Face Dataset [from Rust/YOLO]

This cell downloads the Cropped Face Dataset (128×128) from the Hugging Face Hub and extracts it to the specified data directory.

In [None]:
# Download the dataset zip file from Hugging Face Hub
zip_path = hf_hub_download(
    repo_id="Mayank022/Cropped_Face_Dataset_128x128",
    filename="output.zip",
    repo_type="dataset",
)

# Extract the dataset
extract_folder = "data/images"
os.makedirs(extract_folder, exist_ok=True)
with zipfile.ZipFile(zip_path, "r") as zf:
    zf.extractall(extract_folder)

print(f"Dataset extracted to {extract_folder}")

Dataset extracted to data/images
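
`ImageFolder`, used in the next section, expects the images to sit inside at least one class subfolder of the root directory, so it can help to check the extracted layout before building the dataset. The following is a minimal sketch using only the standard library; the helper name `count_images_per_class` and the extension set are illustrative assumptions, not part of the original notebook.

```python
from pathlib import Path

IMAGE_EXTS = {".jpg", ".jpeg", ".png", ".bmp", ".webp"}  # assumed extension set

def count_images_per_class(root):
    """Return {subfolder_name: number_of_image_files} for each direct subfolder of root."""
    counts = {}
    for sub in sorted(Path(root).iterdir()):
        if sub.is_dir():
            counts[sub.name] = sum(
                1 for f in sub.rglob("*")
                if f.is_file() and f.suffix.lower() in IMAGE_EXTS
            )
    return counts

# Hypothetical call on the folder extracted above:
# print(count_images_per_class("data/images"))
```

If the returned dictionary is empty, the archive was extracted without a class subfolder and `ImageFolder` would not find any images.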


# Data Preparation and DataLoader

This cell sets up the image transformations, loads the dataset using `torchvision.datasets.ImageFolder`, and creates a PyTorch DataLoader for efficient batch processing.

In [None]:
# Define image transformations
transform = transforms.Compose([
    transforms.Resize((BASE_CONFIG["image_size"], BASE_CONFIG["image_size"])),
    transforms.ToTensor(),
    transforms.Normalize([0.5] * 3, [0.5] * 3),
])

# Load dataset from the extracted folder
dataset = datasets.ImageFolder(BASE_CONFIG["data_path"], transform=transform)

# Create DataLoader for batch processing
dataloader = DataLoader(
    dataset,
    batch_size=BASE_CONFIG["batch_size"],
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

print(f"Loaded {len(dataset)} images from {BASE_CONFIG['data_path']}")
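
The `Normalize([0.5] * 3, [0.5] * 3)` step maps pixel values from [0, 1] to [-1, 1], which matches the generator's `Tanh` output range; later cells undo it with `(x + 1) * 0.5`. A small sketch of that round trip on scalar pixel values (the function names are illustrative only):

```python
def normalize(x, mean=0.5, std=0.5):
    """Per-channel operation performed by transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std

def denormalize(y):
    """Inverse mapping used later in the notebook: (y + 1) * 0.5."""
    return (y + 1) * 0.5

for pixel in (0.0, 0.25, 1.0):
    y = normalize(pixel)
    print(f"{pixel:.2f} -> {y:+.2f} -> {denormalize(y):.2f}")
```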

# Dataset Size

Check the total number of images loaded in the dataset.

In [None]:
num_images = len(dataset)
print(f"Total images in dataset: {num_images}")

Total images in dataset: 12007
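
With the dataset size known, the number of optimizer steps implied by the configuration can be worked out directly. This sketch assumes the `DataLoader` default `drop_last=False`, which is what the loader above uses since the argument is not set.

```python
import math

num_images = 12007   # dataset size printed above
batch_size = 64      # BASE_CONFIG["batch_size"]
num_epochs = 300     # BASE_CONFIG["num_epochs"]

# With drop_last=False the final, smaller batch is kept
batches_per_epoch = math.ceil(num_images / batch_size)
last_batch_size = num_images - (batches_per_epoch - 1) * batch_size
total_iterations = batches_per_epoch * num_epochs

print(batches_per_epoch, last_batch_size, total_iterations)
```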


# Weight Initialization Helper

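Utility function that initializes convolutional and linear layers from a normal distribution (mean 0, std 0.02), the convention popularized by DCGAN. Note that the generator and discriminator defined below apply their own `_init_weights` methods, so this helper is not called elsewhere in the notebook.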

In [None]:
def init_weights(m):
    """
    Initializes weights for Conv2d, ConvTranspose2d, and Linear layers
    using a normal distribution (mean=0.0, std=0.02).
    Biases are initialized to zero.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

# Model Architecture: UNet-Style Generator with Self-Attention

This section defines the core building blocks for the generator:
- `ResBlockUp`: Residual upsampling block
- `SelfAttention`: Self-attention mechanism for feature maps
- `UNetStyleFusionGenerator`: Main generator architecture with skip connections and embedding fusion

In [None]:
class ResBlockUp(nn.Module):
    """
    Residual block with upsampling for the generator.
    Upsamples the input and adds a residual connection.
    """
    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.up = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True)
        )
        self.res = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False),
            nn.Upsample(scale_factor=2, mode='nearest')
        )

    def forward(self, x):
        return self.up(x) + self.res(x)

class SelfAttention(nn.Module):
    """
    Self-attention module for spatial feature refinement.
    """
    def __init__(self, in_dim):
        super().__init__()
        self.query = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.key   = nn.Conv2d(in_dim, in_dim // 8, kernel_size=1)
        self.value = nn.Conv2d(in_dim, in_dim, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))

    def forward(self, x):
        B, C, H, W = x.size()
        q = self.query(x).view(B, -1, H * W).permute(0, 2, 1)
        k = self.key(x).view(B, -1, H * W)
        attn = torch.bmm(q, k).softmax(dim=-1)
        v = self.value(x).view(B, -1, H * W)
        out = torch.bmm(v, attn.permute(0, 2, 1)).view(B, C, H, W)
        return self.gamma * out + x

class UNetStyleFusionGenerator(nn.Module):
    """
    UNet-style generator with face embedding fusion and self-attention.
    - Embedding and noise are projected to a 4x4 feature map.
    - Embeddings are injected at multiple scales via skip connections.
    - Self-attention is applied at the 16x16 feature map.
    """
    def __init__(self, embedding_dim=512, noise_dim=128, base_ch=128, img_ch=3):
        super().__init__()
        self.noise_dim = noise_dim
        self.base_ch = base_ch

        # Project embedding + noise → 4×4 feature map
        self.fc = nn.Linear(embedding_dim + noise_dim, base_ch * 8 * 4 * 4)

        # Project embeddings for skip injection
        self.emb_proj = nn.ModuleList([
            nn.Linear(embedding_dim, base_ch * 8 * 8 * 8),       # for 8×8, 1024 channels
            nn.Linear(embedding_dim, base_ch * 4 * 16 * 16)      # for 16×16, 512 channels
        ])

        # Decoder blocks
        self.up1 = ResBlockUp(base_ch * 8, base_ch * 8)     # 4×4 → 8×8
        self.up2 = ResBlockUp(base_ch * 8, base_ch * 4)     # 8×8 → 16×16
        self.attn = SelfAttention(base_ch * 4)              # 16×16
        self.up3 = ResBlockUp(base_ch * 4, base_ch * 2)     # 16×16 → 32×32
        self.up4 = ResBlockUp(base_ch * 2, base_ch)         # 32×32 → 64×64
        self.up5 = nn.Sequential(                           # 64×64 → 128×128
            nn.Upsample(scale_factor=2, mode='nearest'),
            nn.Conv2d(base_ch, img_ch, kernel_size=3, padding=1),
            nn.Tanh()
        )

        self.apply(self._init_weights)

    def forward(self, emb, batch_size=None):
        """
        Forward pass for the generator.
        Args:
            emb: Face embedding tensor of shape [B, embedding_dim]
            batch_size: Optional batch size override
        Returns:
            Generated image tensor of shape [B, img_ch, 128, 128]
        """
        B = emb.size(0)
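        if batch_size is None:
            batch_size = B
        if batch_size != B:
            # The noise is concatenated with emb below, so both batch sizes must match
            raise ValueError(f"batch_size ({batch_size}) must equal emb.size(0) ({B})")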

        # Concatenate embedding and noise
        z = torch.randn(batch_size, self.noise_dim, device=emb.device)
        x = torch.cat([emb, z], dim=1)
        x = self.fc(x).view(B, self.base_ch * 8, 4, 4)

        # Embedding projections
        emb8 = self.emb_proj[0](emb).view(B, self.base_ch * 8, 8, 8)
        emb16 = self.emb_proj[1](emb).view(B, self.base_ch * 4, 16, 16)

        # Decode with skip injections
        d1 = self.up1(x)           # → [B, 1024, 8, 8]
        d1 = d1 + emb8             # inject @8
        d2 = self.up2(d1)          # → [B, 512, 16, 16]
        d2 = self.attn(d2 + emb16) # inject @16 + attention
        d3 = self.up3(d2)          # → [B, 256, 32, 32]
        d4 = self.up4(d3)          # → [B, 128, 64, 64]
        out = self.up5(d4)         # → [B, 3, 128, 128]

        return out

    @staticmethod
    def _init_weights(m):
        """
        Initializes weights for Conv2d and Linear layers using Kaiming normal.
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.kaiming_normal_(m.weight, a=0.2)
            if hasattr(m, 'bias') and m.bias is not None:
                nn.init.zeros_(m.bias)

# Instantiate the generator
device = BASE_CONFIG["device"]
G = UNetStyleFusionGenerator(
    embedding_dim=BASE_CONFIG["embedding_dim"],
    noise_dim=BASE_CONFIG.get("model_params", {}).get("noise_dim", 128),
    base_ch=BASE_CONFIG.get("model_params", {}).get("base_ch", 128),
    img_ch=BASE_CONFIG["channels"]
).to(device)
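
Most of the generator's parameters sit in the three `nn.Linear` layers that project the embedding to spatial feature maps. The sketch below counts them by hand for the default sizes used above; the helper `linear_params` is illustrative, and the float32 size is only an estimate for these three layers.

```python
embedding_dim, noise_dim, base_ch = 512, 128, 128  # defaults used above

def linear_params(in_features, out_features):
    """Weights plus biases of an nn.Linear layer."""
    return in_features * out_features + out_features

layers = {
    "fc (emb+noise -> 4x4)": linear_params(embedding_dim + noise_dim, base_ch * 8 * 4 * 4),
    "emb_proj[0] (-> 8x8)": linear_params(embedding_dim, base_ch * 8 * 8 * 8),
    "emb_proj[1] (-> 16x16)": linear_params(embedding_dim, base_ch * 4 * 16 * 16),
}
total = sum(layers.values())
for name, n in layers.items():
    print(f"{name:24s} {n:>12,d}")
print(f"total {total:,d} params, ~{total * 4 / 1e6:.0f} MB as float32")
```

These projections alone come to roughly 111 million parameters, which accounts for most of the generator checkpoint size listed in the upload output later in the notebook (516M).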

# Model Architecture: Conditional Discriminator

Defines a projection-based conditional discriminator, inspired by BigGAN, that uses both the image and the embedding vector for real/fake prediction.

In [None]:
class ConditionalDiscriminator(nn.Module):
    """
    Conditional discriminator with projection for embedding guidance.
    Takes both an image and a conditioning embedding (e.g., FaceNet) as input.
    """
    def __init__(self, img_channels=3, embedding_dim=512, base_ch=64):
        super().__init__()

        # Convolutional feature extractor
        self.conv1 = nn.Conv2d(img_channels, base_ch, kernel_size=4, stride=2, padding=1)    # 128→64
        self.conv2 = nn.Conv2d(base_ch, base_ch*2, kernel_size=4, stride=2, padding=1)       # 64→32
        self.bn2   = nn.BatchNorm2d(base_ch*2)

        self.conv3 = nn.Conv2d(base_ch*2, base_ch*4, kernel_size=4, stride=2, padding=1)     # 32→16
        self.bn3   = nn.BatchNorm2d(base_ch*4)

        self.conv4 = nn.Conv2d(base_ch*4, base_ch*8, kernel_size=4, stride=2, padding=1)     # 16→8
        self.bn4   = nn.BatchNorm2d(base_ch*8)

        self.conv5 = nn.Conv2d(base_ch*8, base_ch*16, kernel_size=4, stride=2, padding=1)    # 8→4
        self.bn5   = nn.BatchNorm2d(base_ch*16)

        # Final 4×4 feature → flattened for dot product
        self.final_linear = nn.Linear(base_ch*16*4*4, 1)

        # Projection for conditioning (dot product with image feature)
        self.embed_proj = nn.Linear(embedding_dim, base_ch*16*4*4)

        # Initialize weights
        self.apply(self._init_weights)

    def forward(self, x, emb):
        """
        Forward pass for the conditional discriminator.

        Args:
            x:   Input image tensor of shape [B, 3, 128, 128]
            emb: Conditioning embedding tensor of shape [B, embedding_dim]

        Returns:
            Real/fake logits, conditioned on the embedding [B, 1]
        """
        B = x.size(0)

        out = F.leaky_relu(self.conv1(x), 0.2)
        out = F.leaky_relu(self.bn2(self.conv2(out)), 0.2)
        out = F.leaky_relu(self.bn3(self.conv3(out)), 0.2)
        out = F.leaky_relu(self.bn4(self.conv4(out)), 0.2)
        out = F.leaky_relu(self.bn5(self.conv5(out)), 0.2)

        out_flat = out.view(B, -1)                        # [B, C]
        logits_real_fake = self.final_linear(out_flat)     # [B, 1]

        # Conditional projection
        proj = torch.sum(out_flat * self.embed_proj(emb), dim=1, keepdim=True)  # [B, 1]

        return logits_real_fake + proj

    @staticmethod
    def _init_weights(m):
        """
        Initializes weights for Conv2d and Linear layers using normal distribution.
        """
        if isinstance(m, (nn.Conv2d, nn.Linear)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if m.bias is not None:
                nn.init.zeros_(m.bias)

# Instantiate the discriminator
device = BASE_CONFIG["device"]
D = ConditionalDiscriminator(
    img_channels=BASE_CONFIG["channels"],
    embedding_dim=BASE_CONFIG["embedding_dim"],
    base_ch=64
).to(device)
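
The spatial sizes in the discriminator comments follow from the standard convolution output formula, `floor((n + 2p - k) / s) + 1`. A short sketch that traces the five stride-2 convolutions and derives the flattened feature width shared by `final_linear` and `embed_proj` (the helper name `conv_out` is illustrative):

```python
def conv_out(size, kernel=4, stride=2, padding=1):
    """Output size of a Conv2d along one spatial dimension."""
    return (size + 2 * padding - kernel) // stride + 1

size, base_ch = 128, 64
sizes = [size]
for _ in range(5):            # conv1 ... conv5
    size = conv_out(size)
    sizes.append(size)

# Input width of final_linear and output width of embed_proj
flat_features = base_ch * 16 * size * size
print(sizes, flat_features)
```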

# Authentication and Experiment Tracking Setup

This cell securely loads the Hugging Face and Weights & Biases (W&B) tokens from Colab secrets, initializes experiment tracking, and logs into the Hugging Face Hub.

In [None]:
# Google Colab-specific secure access
from google.colab import userdata

# 1. Hugging Face Token (from Colab secrets)
hf_token = userdata.get("HF_TOKEN")
if hf_token is None:
    raise ValueError("Missing HF_TOKEN in Colab userdata! Please add it via Colab secrets.")

# 2. W&B Token (prioritize Colab secret, fallback to env variable)
wandb_token = userdata.get("WANDB_API_KEY") or os.getenv("WANDB_API_KEY")
if wandb_token is None:
    raise ValueError("Missing WANDB_API_KEY in secrets or environment variables!")

# 3. Initialize W&B
wandb.login(key=wandb_token)
wandb.init(
    project=BASE_CONFIG["wandb_project"],
    name=BASE_CONFIG["wandb_run_name"],
    config=BASE_CONFIG,
)

# 4. Track model weights/gradients (optional but useful)
wandb.watch(models=[G, D], log="all", log_freq=100)

# 5. Log into Hugging Face Hub
login(token=hf_token)



W&B run summary:

| Metric | Value |
|---|---|
| epoch | 1.0 |
| losses/adv | 1.1998 |
| losses/discriminator | 19.83531 |
| losses/embedding | 0.00364 |
| losses/generator | 1.79062 |
| losses/lpips | 0.827 |
| losses/r1_penalty | 0.08343 |
| metrics/FID | 327.74582 |
| metrics/PSNR | 6.28497 |
| metrics/SSIM | 0.11712 |


# FaceNet Encoder Setup (for Conditional Face Generation)

In this section, I load and configure the FaceNet encoder, which provides the face embeddings used to condition the generative model.

**Assignment Context:**  
- The generator takes these embeddings as input and produces 128×128 face images.
- The encoder can be fine-tuned jointly with the generator, or kept fixed.
- This setup supports both options, controlled by the configuration.

**Key Points:**
- The encoder is pre-trained on face recognition (VGGFace2).
- Fine-tuning is optional and controlled by `BASE_CONFIG["finetune_facenet"]`.
- If fine-tuning, an optimizer is created for the encoder.
- Otherwise, the encoder is frozen for inference-only use.
- For the run in this notebook the encoder is fine-tuned (`finetune_facenet` is set to `True` in `BASE_CONFIG`).

In [None]:
# Load the FaceNet encoder (pretrained on VGGFace2 for strong face embeddings)
facenet = InceptionResnetV1(pretrained='vggface2').eval().to(device)

if BASE_CONFIG["finetune_facenet"]:
    # Optionally fine-tune the encoder jointly with the generator
    facenet.train()
    optim_f = torch.optim.Adam(
        facenet.parameters(),
        lr=BASE_CONFIG["lr_finetune_facenet"],
        weight_decay=BASE_CONFIG["weight_decay"],
    )
    print("FaceNet will be fine-tuned during training.")
else:
    # Freeze encoder weights for inference-only use
    for p in facenet.parameters():
        p.requires_grad = False
    optim_f = None  # no encoder optimizer when frozen; the training call still passes optim_f
    print("FaceNet is frozen (no fine-tuning).")

# Optimizers, Losses, and Evaluation Metrics

In this section, I set up the training components for the conditional face generator:

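- **Two-Time-Scale Update Rule (TTUR):**  
  The generator and discriminator get separate Adam optimizers with independently configurable learning rates. With the current configuration both resolve to 1e-4, because the configured generator rate (2e-4) is halved when the optimizer is created.
- **Loss Functions:**  
  I combine adversarial (BCE), embedding-consistency (MSE between the FaceNet embeddings of real and generated images), and perceptual (LPIPS) losses for high-quality, realistic outputs.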
- **Regularization:**  
  I apply R1 gradient penalty to the discriminator for improved stability.
- **Evaluation Metric:**  
  I use Frechet Inception Distance (FID) to quantitatively assess the quality of generated images.
- **Learning Rate Schedulers:**  
  I use exponential decay for both optimizers to encourage convergence.
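
To make the learning-rate settings concrete, the sketch below evaluates the exponential schedule by hand. It uses the values from `BASE_CONFIG`, the `* 0.5` factor applied to the generator rate, and the decay factor 0.99 that this section passes to `ExponentialLR`; the helper `lr_at` is illustrative and assumes the scheduler is stepped once per epoch, as in the training loop.

```python
lr_generator = 2e-4 * 0.5   # BASE_CONFIG["lr_generator"] * 0.5, as passed to optim_G
lr_discriminator = 1e-4     # BASE_CONFIG["lr_discriminator"]
gamma = 0.99                # decay factor used for both ExponentialLR schedulers

def lr_at(lr0, epoch, gamma=gamma):
    """Learning rate after `epoch` scheduler steps: lr0 * gamma ** epoch."""
    return lr0 * gamma ** epoch

for epoch in (0, 100, 200, 300):
    print(epoch, f"{lr_at(lr_generator, epoch):.2e}", f"{lr_at(lr_discriminator, epoch):.2e}")
```

Both optimizers start from the same effective rate, and after 300 epochs the rate has decayed to about 5% of its initial value.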

In [None]:
# 1. Two-Time-Scale Update Rule (TTUR)
#    G and D have separate optimizers so their learning rates can differ.
#    With the current config both end up at 1e-4 (2e-4 * 0.5 for G, 1e-4 for D).
optim_G = torch.optim.Adam(
    G.parameters(),
    lr=BASE_CONFIG["lr_generator"] * 0.5,    # half of the configured generator LR
    betas=(0.5, 0.999),                      # classic GAN betas
    weight_decay=BASE_CONFIG["weight_decay"],
)
optim_D = torch.optim.Adam(
    D.parameters(),
    lr=BASE_CONFIG["lr_discriminator"],      # discriminator LR, used as configured
    betas=(0.5, 0.999),
    weight_decay=BASE_CONFIG["weight_decay"],
)

# 2. Loss Functions
bce = nn.BCEWithLogitsLoss()  # Adversarial loss
mse = nn.MSELoss()            # Embedding loss (MSE between FaceNet embeddings of real and fake)
perceptual = LPIPS(net="vgg").to(device)  # Perceptual loss for realism

# 3. Regularization on D (R1 gradient penalty)
def r1_penalty(real_preds, real_imgs):
    grad_real = torch.autograd.grad(
        outputs=real_preds.sum(), inputs=real_imgs, create_graph=True
    )[0]
    return grad_real.pow(2).view(real_imgs.size(0), -1).sum(1).mean()

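# 4. Evaluation metrics
#    normalize=True means FID expects float images in [0, 1].
fid = FrechetInceptionDistance(feature=2048, normalize=True).to(device)

#    The training loop also expects ssim_metric and psnr_metric, which are not
#    defined anywhere else in this notebook. The torchmetrics classes and
#    data_range=1.0 below are assumptions that match the [0, 1] images passed to them.
from torchmetrics.image import StructuralSimilarityIndexMeasure, PeakSignalNoiseRatio
ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
psnr_metric = PeakSignalNoiseRatio(data_range=1.0).to(device)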

# 5. (Optional) Learning Rate Schedulers
sched_G = ExponentialLR(optim_G, gamma=0.99)
sched_D = ExponentialLR(optim_D, gamma=0.99)

Setting up [LPIPS] perceptual loss: trunk [vgg], v[0.1], spatial [off]
Loading model from: /usr/local/lib/python3.11/dist-packages/lpips/weights/v0.1/vgg.pth
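
The R1 penalty defined above is the squared norm of the discriminator's gradient with respect to the real images, averaged over the batch. One way to sanity-check it is a toy linear "discriminator" whose input gradient is known in closed form. This is only an illustrative sketch: the weight vector `w` and the 4-pixel "images" are made up for the check.

```python
import torch

def r1_penalty(real_preds, real_imgs):
    # Same computation as the function above: squared gradient norm, averaged over the batch
    grad_real = torch.autograd.grad(
        outputs=real_preds.sum(), inputs=real_imgs, create_graph=True
    )[0]
    return grad_real.pow(2).view(real_imgs.size(0), -1).sum(1).mean()

# Toy linear "discriminator" D(x) = <w, x>; its gradient w.r.t. x is w for every sample
w = torch.tensor([1.0, -2.0, 3.0, 0.5])
x = torch.randn(5, 4, requires_grad=True)   # 5 made-up "images" with 4 pixels each
preds = (x * w).sum(dim=1, keepdim=True)

penalty = r1_penalty(preds, x)
print(penalty.item(), w.pow(2).sum().item())   # the two values should agree
```

Because the gradient is the same vector `w` for every sample, the penalty equals the squared norm of `w` regardless of the random inputs.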


# Full Training Loop: Embedding-to-Image with W&B Logging and Metrics

This section implements the complete training loop for the conditional face generator:

- **Conditioning:** The generator is conditioned on embeddings from a (possibly fine-tuned) FaceNet encoder.
- **Losses:** Combines adversarial, embedding, and perceptual (LPIPS) losses for high-quality, identity-preserving synthesis.
- **Metrics:** Tracks FID, SSIM, and PSNR to quantitatively evaluate image quality and similarity.
- **Logging:** Logs all key losses, metrics, and sample images to Weights & Biases (W&B) for transparent experiment tracking.
- **Checkpoints:** Saves model checkpoints at regular intervals for reproducibility and future inference.
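- **Structure:** The loop is a single function that receives all models, optimizers, losses, and metrics as arguments, so it can be reused with other datasets or architectures.

**Note:**  
Zero-shot generalization to unseen face embeddings is the goal set by the assignment. The FID, SSIM, and PSNR values computed inside this loop are measured on training batches, so generalization has to be checked separately on held-out faces.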

In [None]:
def train_conditional_gan(
    G, D, facenet, dataloader, optim_G, optim_D, bce, mse, perceptual, fid,
    ssim_metric, psnr_metric, sched_G, sched_D, BASE_CONFIG, device, optim_f=None
):
    """
    Trains a conditional GAN for face synthesis, conditioned on embeddings from a FaceNet encoder.

    - Supports both fixed and fine-tuned FaceNet encoders.
    - Logs all key losses, metrics, and sample images to Weights & Biases (W&B).
    - Saves model checkpoints at regular intervals.
    - Tracks FID, SSIM, and PSNR for quantitative evaluation.

    Args:
        G: Generator model
        D: Discriminator model
        facenet: FaceNet encoder (pretrained or fine-tuned)
        dataloader: PyTorch DataLoader for training data
        optim_G: Optimizer for generator
        optim_D: Optimizer for discriminator
        bce: BCEWithLogitsLoss for adversarial loss
        mse: MSELoss for embedding loss
        perceptual: LPIPS perceptual loss
        fid: Frechet Inception Distance metric
        ssim_metric: SSIM metric
        psnr_metric: PSNR metric
        sched_G: LR scheduler for generator
        sched_D: LR scheduler for discriminator
        BASE_CONFIG: Configuration dictionary
        device: Device string ("cuda" or "cpu")
        optim_f: (Optional) Optimizer for FaceNet if fine-tuning

    Returns:
        None
    """


    # Create a timestamped checkpoint folder
    run_id = datetime.now().strftime("%Y%m%d_%H%M%S")
    ckpt_dir = f"hf_model/checkpoints/{run_id}"
    os.makedirs(ckpt_dir, exist_ok=True)

    for epoch in range(1, BASE_CONFIG["num_epochs"] + 1):
        G.train(); D.train()
        sum_g = 0.0
        sum_d = 0.0

        for batch_idx, (real_imgs, _) in enumerate(dataloader):
            real = real_imgs.to(device)

            # --- 1) Compute Embeddings ---
            emb = facenet((real + 1) * 0.5)
            if not BASE_CONFIG["finetune_facenet"]:
                emb = emb.detach()

            # --- 2) Generate Fake Images ---
            fake = G(emb)
            real = real.detach().requires_grad_()

            # --- 3) Discriminator Update ---
            real_logits = D(real, emb)
            fake_logits = D(fake.detach(), emb)
            loss_d = bce(real_logits, torch.ones_like(real_logits)) + \
                     bce(fake_logits, torch.zeros_like(fake_logits))
            # Add R1 penalty
            r1 = r1_penalty(real_logits, real)
            loss_d += BASE_CONFIG["r1_gamma"] * r1

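            optim_D.zero_grad()
            # retain_graph=True keeps the graph behind emb alive, since the generator
            # step below backpropagates through it again when FaceNet is fine-tuned
            loss_d.backward(retain_graph=True)
            optim_D.step()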

            # --- 4) Generator + Embedding Update ---
            fake_logits2 = D(fake, emb)
            adv_loss = bce(fake_logits2, torch.ones_like(fake_logits2))
            fake_emb = facenet((fake + 1) * 0.5)
            emb_loss = mse(fake_emb, emb)
            perceptual_loss = perceptual(fake, real).mean()

            loss_g = adv_loss \
                     + BASE_CONFIG["lambda_emb"] * emb_loss \
                     + BASE_CONFIG["lambda_lpips"] * perceptual_loss

            optim_G.zero_grad()
            if BASE_CONFIG["finetune_facenet"]:
                optim_f.zero_grad()

            loss_g.backward()
            optim_G.step()
            if BASE_CONFIG["finetune_facenet"]:
                optim_f.step()

            # --- 5) Update image metrics ---
            # Detach so the metrics never hold on to the autograd graph
            real_norm = (real.detach().clamp(-1, 1) + 1) * 0.5
            fake_norm = (fake.detach().clamp(-1, 1) + 1) * 0.5
            # fid was created with normalize=True, so it expects float images in [0, 1];
            # uint8 tensors would be rescaled a second time inside the metric
            fid.update(fake_norm, real=False)
            fid.update(real_norm, real=True)
            ssim_metric.update(fake_norm, real_norm)
            psnr_metric.update(fake_norm, real_norm)

            sum_g += loss_g.item()
            sum_d += loss_d.item()

            # --- 6) Log samples to W&B once per epoch ---
            if batch_idx == 0:
                combined = torch.cat([real_norm[:8], fake_norm[:8]], dim=0)
                grid = make_grid(combined, nrow=8)
                wandb.log({
                    f"Samples/Epoch_{epoch}": wandb.Image(grid, caption="Top: Real • Bottom: Fake")
                }, step=epoch)

        # --- 7) Epoch-end Metrics ---
        avg_g = sum_g / len(dataloader)
        avg_d = sum_d / len(dataloader)
        fid_score = fid.compute().item(); fid.reset()
        ssim_score = ssim_metric.compute().item(); ssim_metric.reset()
        psnr_score = psnr_metric.compute().item(); psnr_metric.reset()

        # --- 8) Print Summary ---
        print(
            f"[{epoch:03d}/{BASE_CONFIG['num_epochs']:03d}]  "
            f"G: {avg_g:.4f}  D: {avg_d:.4f}  "
            f"Adv: {adv_loss.item():.4f}  Emb: {emb_loss.item():.4f}  "
            f"LPIPS: {perceptual_loss.item():.4f}  "
            f"SSIM: {ssim_score:.4f}  PSNR: {psnr_score:.2f}  "
            f"FID: {fid_score:.2f}"
        )

        # --- 9) Log to W&B (structured for dashboards) ---
        wandb.log({
            "losses/generator": avg_g,
            "losses/discriminator": avg_d,
            "losses/adv": adv_loss.item(),
            "losses/embedding": emb_loss.item(),
            "losses/lpips": perceptual_loss.item(),
            "losses/r1_penalty": r1.item(),
            "metrics/SSIM": ssim_score,
            "metrics/PSNR": psnr_score,
            "metrics/FID": fid_score,
            "epoch": epoch,
        }, step=epoch)

        # --- 10) Step LR schedulers if enabled ---
        if sched_G is not None: sched_G.step()
        if sched_D is not None: sched_D.step()

        # --- 11) Save Checkpoints ---
        if epoch % 100 == 0 or epoch in [10, 50, BASE_CONFIG["num_epochs"]]:
            ckpt_dir_epoch = os.path.join("hf_model/checkpoints", f"epoch_{epoch:03d}")
            os.makedirs(ckpt_dir_epoch, exist_ok=True)

            torch.save(facenet.state_dict(), os.path.join(ckpt_dir_epoch, f"facenet_epoch{epoch:03d}.pt"))
            torch.save(G.state_dict(), os.path.join(ckpt_dir_epoch, f"generator_epoch{epoch:03d}.pt"))
            torch.save(D.state_dict(), os.path.join(ckpt_dir_epoch, f"discriminator_epoch{epoch:03d}.pt"))
            print(f"Saved checkpoints at epoch {epoch:03d} → {ckpt_dir_epoch}")

# To run the training loop, call:
# train_conditional_gan(G, D, facenet, dataloader, optim_G, optim_D, bce, mse, perceptual, fid,
#                      ssim_metric, psnr_metric, sched_G, sched_D, BASE_CONFIG, device, optim_f)

[001/300]  G: 4.3617  D: 0.8610  Adv: 4.1792  Emb: 0.0039  LPIPS: 0.8801  SSIM: 0.1598  PSNR: 5.91  FID: 323.07
[002/300]  G: 5.8846  D: 0.8379  Adv: 4.3380  Emb: 0.0034  LPIPS: 0.7829  SSIM: 0.0882  PSNR: 7.80  FID: 326.79
[003/300]  G: 4.2586  D: 1.4484  Adv: 3.9420  Emb: 0.0026  LPIPS: 0.7532  SSIM: 0.1643  PSNR: 9.74  FID: 326.13
[004/300]  G: 3.3303  D: 1.6065  Adv: 1.3502  Emb: 0.0023  LPIPS: 0.7682  SSIM: 0.2257  PSNR: 10.40  FID: 329.69
[005/300]  G: 2.8922  D: 1.4070  Adv: 0.9279  Emb: 0.0019  LPIPS: 0.7092  SSIM: 0.2562  PSNR: 10.66  FID: 327.06
[006/300]  G: 2.6238  D: 1.4932  Adv: 2.5685  Emb: 0.0018  LPIPS: 0.6960  SSIM: 0.2774  PSNR: 10.97  FID: 318.56
[007/300]  G: 2.5306  D: 1.4249  Adv: 0.4663  Emb: 0.0017  LPIPS: 0.6936  SSIM: 0.2908  PSNR: 11.14  FID: 306.91
[008/300]  G: 2.4499  D: 1.3356  Adv: 1.5787  Emb: 0.0017  LPIPS: 0.6870  SSIM: 0.2964  PSNR: 11.32  FID: 282.05
[009/300]  G: 2.3663  D: 1.3049  Adv: 2.0499  Emb: 0.0015  LPIPS: 0.6305  SSIM: 0.3052  PSNR: 11.27
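
The `Emb` and `PSNR` columns in this log are easier to read after converting them to more familiar quantities. The sketch below rests on two assumptions that are not stated in the notebook: that the encoder returns L2-normalized embeddings (as facenet-pytorch's `InceptionResnetV1` does when it is not used as a classifier), and that PSNR is computed with a data range of 1.0. Under the first assumption, the MSE between two unit vectors of dimension 512 relates to their cosine similarity as `cos = 1 - 256 * mse`.

```python
import math

EMBEDDING_DIM = 512  # BASE_CONFIG["embedding_dim"]

def cosine_from_embedding_mse(mse, dim=EMBEDDING_DIM):
    """For unit vectors a, b: ||a - b||^2 = 2 - 2*cos, and nn.MSELoss averages over dim."""
    return 1.0 - 0.5 * dim * mse

def pixel_mse_from_psnr(psnr_db, data_range=1.0):
    """Invert PSNR = 10 * log10(data_range^2 / MSE)."""
    return data_range ** 2 / 10 ** (psnr_db / 10)

# (Emb, PSNR) pairs copied from epochs 1 and 9 of the log above
for epoch, emb_mse, psnr in [(1, 0.0039, 5.91), (9, 0.0015, 11.27)]:
    cos = cosine_from_embedding_mse(emb_mse)
    rmse = math.sqrt(pixel_mse_from_psnr(psnr))
    print(f"epoch {epoch}: cosine ~ {cos:.3f}, pixel RMSE ~ {rmse:.3f}")
```

Under these assumptions, an embedding MSE that looks small in absolute terms can still correspond to a low cosine similarity, because the largest possible MSE between two unit vectors of this dimension is only 4 / 512 ≈ 0.0078. Note also that the logged `Emb` value comes from the last batch of each epoch, not an epoch average.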

# Upload Model Checkpoints to Hugging Face Hub

This section uploads selected model checkpoint folders to a Hugging Face Hub model repository for sharing, reproducibility, and future inference.

- The repository is created if it does not already exist.
- Only the specified checkpoint folders are uploaded.
- This enables easy sharing and public access to trained model weights.
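
The folders available for upload are determined by the save condition in the training loop, `epoch % 100 == 0 or epoch in [10, 50, num_epochs]`. A small sketch that lists them for a given run length (the helper `checkpoint_epochs` is illustrative and simply mirrors that condition):

```python
def checkpoint_epochs(num_epochs, every=100, extra=(10, 50)):
    """Epochs at which the training loop's save condition is true."""
    return [
        epoch for epoch in range(1, num_epochs + 1)
        if epoch % every == 0 or epoch in (*extra, num_epochs)
    ]

epochs = checkpoint_epochs(300)
folders = [f"epoch_{e:03d}" for e in epochs]
print(folders)   # folder names created under hf_model/checkpoints
```

Any subset of these folder names can be passed as `checkpoint_folders` to the upload function.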

In [None]:
def upload_checkpoints_to_hf(
    repo_id, hf_token, checkpoint_folders, base_path="hf_model/checkpoints", private=False
):
    """
    Uploads selected model checkpoint folders to a Hugging Face Hub model repository.

    Args:
        repo_id (str): Full Hugging Face repo ID (e.g., "username/repo_name").
        hf_token (str): Hugging Face access token.
        checkpoint_folders (list): List of folder names to upload.
        base_path (str): Local base path containing checkpoint folders.
        private (bool): Whether to create the repo as private.

    Returns:
        None
    """
    from huggingface_hub import create_repo, HfApi

    # Create model repo on Hugging Face (if it doesn't exist)
    create_repo(
        repo_id=repo_id,
        token=hf_token,
        repo_type="model",
        private=private,
        exist_ok=True  # avoids error if repo already exists
    )

    api = HfApi()

    # Upload each checkpoint folder
    for folder in checkpoint_folders:
        local_folder = os.path.join(base_path, folder)
        remote_folder = f"checkpoints/{folder}"
        print(f"Uploading {local_folder} → {remote_folder}")

        api.upload_folder(
            folder_path=local_folder,
            path_in_repo=remote_folder,
            repo_id=repo_id,
            token=hf_token,
            repo_type="model",
            commit_message=f"Upload {folder}"
        )

    print("All selected checkpoints uploaded.")

# Example usage:
repo_id = BASE_CONFIG["hf_repo_id"]
hf_token = userdata.get("HF_TOKEN")
checkpoint_folders = ["epoch_100", "epoch_200", "epoch_300"]
upload_checkpoints_to_hf(repo_id, hf_token, checkpoint_folders)

Uploading hf_model/checkpoints/epoch_100 → checkpoints/epoch_100
  Upload 3 LFS files: discriminator_epoch100.pt (78.3M), generator_epoch100.pt (516M), facenet_epoch100.pt (112M)
Uploading hf_model/checkpoints/epoch_200 → checkpoints/epoch_200
  Upload 3 LFS files: discriminator_epoch200.pt (78.3M), generator_epoch200.pt (516M), facenet_epoch200.pt (112M)
Uploading hf_model/checkpoints/epoch_300 → checkpoints/epoch_300
  Upload 3 LFS files: discriminator_epoch300.pt (78.3M), generator_epoch300.pt (516M), facenet_epoch300.pt (112M)
All selected checkpoints uploaded.


# Model Inference

Model inference and sample generation are demonstrated in a separate notebook for clarity and reproducibility.

**To run inference with the trained models, please see the `inference` notebook located in the `notebooks/` folder of this repository.**  
That notebook provides step-by-step instructions for loading checkpoints, generating new face images from embeddings, and visualizing results.

---

