# Task-2: Generating Faces

This part of the task requires training a face generation model from embeddings. To generate suitable embeddings, I currently plan to use the `gaunernst/vit_tiny_patch8_112.arcface_ms1mv3` model, a pre-trained ViT trained on the MS1MV3 dataset.


In [None]:
# dataset generation
import os
import torch
import timm
import random
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image

# Seeds only Python's `random` module, which drives the random flip/crop
# decisions in FaceDataset.__getitem__. No torch seed is set here.
random.seed(42)  # For reproducibility

class FaceDataset(Dataset):
    """Face images paired with identity embeddings from a pretrained timm model.

    Each item is a dict with two entries computed from the same (optionally
    augmented) PIL image:

    * ``'image'``: the image resized to 128x128, converted to a tensor and
      normalized with mean 0.5 / std 0.5.
    * ``'embedding'``: the L2-normalized output of the embedding model with
      the batch dimension removed.

    Args:
        image_dir: Directory whose ``.jpg`` files make up the dataset.
        flip_prob: Probability of mirroring an image horizontally.
        crop_prob: Probability of trimming a 5% border from every side.
            Set both probabilities to 0 to disable augmentation.

    Raises:
        ValueError: If ``image_dir`` contains no ``.jpg`` files.
    """

    def __init__(self, image_dir, flip_prob=0.2, crop_prob=0.15):
        self.image_dir = image_dir
        # os.listdir() returns entries in arbitrary order; sorting keeps the
        # index -> file mapping stable between runs and machines.
        self.image_files = sorted(
            f for f in os.listdir(image_dir) if f.endswith('.jpg')
        )
        # Fail fast, before the embedding model is downloaded and loaded.
        if not self.image_files:
            raise ValueError(f"No images found in directory: {image_dir}")

        self.flip_prob = flip_prob
        self.crop_prob = crop_prob

        self.embedding_model = timm.create_model(
            "hf_hub:gaunernst/vit_tiny_patch8_112.arcface_ms1mv3",
            pretrained=True,
        ).eval()
        self.data_config = timm.data.resolve_data_config(
            self.embedding_model.pretrained_cfg
        )
        # Preprocessing expected by the embedding model (inference variant).
        self.transform = timm.data.create_transform(
            **self.data_config, is_training=False
        )

        # Preprocessing for the image handed to the generative model.
        self.to_latent = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize(0.5, 0.5),
        ])

    def get_embedding(self, image):
        """Return the L2-normalized embedding of a PIL image as a 1-D tensor."""
        with torch.inference_mode():
            batch = self.transform(image).unsqueeze(0)
            embedding = self.embedding_model(batch)
            embedding = F.normalize(embedding, dim=1)
        return embedding.squeeze(0)  # drop the batch dimension: [1, D] -> [D]

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, idx):
        img_path = os.path.join(self.image_dir, self.image_files[idx])
        # convert() returns an independent, fully loaded copy, so the file
        # handle can be released right away instead of waiting for the GC.
        with Image.open(img_path) as img_file:
            pil_image = img_file.convert('RGB')

        if random.random() < self.flip_prob:
            pil_image = pil_image.transpose(Image.FLIP_LEFT_RIGHT)

        if random.random() < self.crop_prob:
            width, height = pil_image.size
            # Keep the central 90% of the image in each dimension.
            crop_box = (
                int(width * 0.05),
                int(height * 0.05),
                int(width * 0.95),
                int(height * 0.95),
            )
            pil_image = pil_image.crop(crop_box)

        # The embedding is computed from the augmented image so that it
        # always describes exactly the picture returned under 'image'.
        return {
            'image': self.to_latent(pil_image),
            'embedding': self.get_embedding(pil_image),
        }
    
    
# Training split: reshuffled on every pass over the loader.
train_dataset = FaceDataset('data/train')
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=32)

# Validation split: iterated in dataset order.
val_dataset = FaceDataset('data/val')
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=32)

print(len(train_dataset), "training images found.")
print(len(val_dataset), "validation images found.")

3296 training images found.
582 validation images found.


## Creating and Training the model.
Since I am a bit inexperienced in training diffusion models (as I previously mainly worked in the domain of NLP), I will try to keep the training process as simple and default as possible. The following code is adapted from the [diffusers_training_example.ipynb](https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/training_example.ipynb#scrollTo=67640279-979b-490d-80fe-65673b94ae00), which is a Hugging Face example for training diffusion models.

That means I will be using the classic `UNet2DModel` to run the diffusion process.

* Training runs for 50 epochs with a batch size of 32 and 2-step gradient accumulation. The model is trained on a dataset of 3296 images and validated on 582 images.
* These images are the faces detected using the `face_detection` model, and have been noted to contain at least 128x128 pixels of face area.
* The training will be done on a single GPU, and the model will be saved every 2 epochs.

In [None]:
from diffusers import UNet2DModel, DDPMScheduler
from diffusers.optimization import get_cosine_schedule_with_warmup
# Since I am unsure of the exact block types and configurations I would need,
# the UNet is built with the library's default down/up block layout for now.
OUTPUT_DIR = "model_outputs"
LR_WARMUP_STEPS = 500
NUM_EPOCHS = 50
# Must equal the gradient accumulation factor used inside train_loop (2).
GRADIENT_ACCUMULATION_STEPS = 2


model = UNet2DModel(
    sample_size=128,
    layers_per_block=2,
)

noise_scheduler = DDPMScheduler(num_train_timesteps=1000)
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-4)

# train_loop steps the LR scheduler once per optimizer update, i.e. once per
# GRADIENT_ACCUMULATION_STEPS batches, so the schedule length has to be
# counted in updates rather than batches. Sizing it in batches would leave
# the cosine decay only half finished when training ends. Ceil division
# gives an upper bound, so the schedule is never run past its end.
updates_per_epoch = -(-len(train_loader) // GRADIENT_ACCUMULATION_STEPS)
lr_scheduler = get_cosine_schedule_with_warmup(
    optimizer=optimizer,
    num_warmup_steps=LR_WARMUP_STEPS,
    num_training_steps=updates_per_epoch * NUM_EPOCHS,
)

In [None]:
from diffusers import DDPMPipeline
from huggingface_hub import create_repo, upload_folder
import wandb

from tqdm.auto import tqdm
from pathlib import Path
import os
import math

# "WANDB_API_KEY" is only a placeholder value: never commit a real key here.
# setdefault() keeps a key that is already configured in the environment
# instead of clobbering it with the placeholder.
os.environ.setdefault("WANDB_API_KEY", "WANDB_API_KEY")


def make_grid(images, rows, cols):
    """Paste equally sized PIL images into one ``rows`` x ``cols`` RGB sheet.

    Tiles are placed left to right, top to bottom. The tile size is taken
    from the first image.
    """
    tile_w, tile_h = images[0].size
    sheet = Image.new('RGB', size=(cols * tile_w, rows * tile_h))
    for index, tile in enumerate(images):
        row, col = divmod(index, cols)
        sheet.paste(tile, box=(col * tile_w, row * tile_h))
    return sheet

def evaluate(epoch, pipeline):
    """Sample images from ``pipeline``, save them as a grid and log it to wandb.

    Args:
        epoch: Epoch index; used for the file name (zero-padded to 4 digits)
            and logged under "epoch".
        pipeline: Callable accepting ``batch_size`` and ``generator`` and
            returning an object whose ``.images`` is a list of PIL images.
    """
    rows, cols = 4, 8

    # Sample from random noise (the backward diffusion process).
    # A dedicated generator is used on purpose: torch.manual_seed() would
    # reseed the *global* RNG, rewinding the random state of the training
    # loop to the same point after every evaluation.
    images = pipeline(
        batch_size=rows * cols,
        generator=torch.Generator(device="cpu").manual_seed(42),
    ).images

    # The grid holds exactly as many tiles as were sampled, so none of the
    # generated images is silently left out.
    image_grid = make_grid(images, rows=rows, cols=cols)

    # Save the images
    test_dir = os.path.join(OUTPUT_DIR, "samples")
    os.makedirs(test_dir, exist_ok=True)
    image_grid.save(f"{test_dir}/{epoch:04d}.png")

    # Log to wandb
    wandb.log({
        "generated_images": wandb.Image(image_grid),
        "epoch": epoch
    })


def train_loop(model, noise_scheduler, optimizer, train_dataloader, lr_scheduler,
               gradient_accumulation_steps=2):
    """Train ``model`` to predict the noise added by ``noise_scheduler``.

    Runs NUM_EPOCHS epochs with gradient accumulation, logs metrics to wandb,
    and every second epoch (plus the final one) samples demo images with
    ``evaluate`` and saves a ``DDPMPipeline`` to OUTPUT_DIR.

    Only ``batch['image']`` is read from each batch; ``batch['embedding']``
    is not used, so the model is trained unconditionally.

    Args:
        model: Network called as ``model(noisy_images, timesteps,
            return_dict=False)``; element 0 of the result is the prediction.
        noise_scheduler: Scheduler providing ``add_noise`` and
            ``config.num_train_timesteps``.
        optimizer: Optimizer over ``model``'s parameters.
        train_dataloader: Sized iterable of dicts with an ``'image'`` tensor.
        lr_scheduler: Stepped once per optimizer update.
        gradient_accumulation_steps: Number of batches whose gradients are
            accumulated before each optimizer update.

    Raises:
        ValueError: If ``gradient_accumulation_steps`` is smaller than 1.
    """
    if gradient_accumulation_steps < 1:
        raise ValueError("gradient_accumulation_steps must be >= 1")

    # Create the output directory first: wandb.init() is pointed at it.
    os.makedirs(OUTPUT_DIR, exist_ok=True)

    wandb.init(
        project="face-generation",
        name=f"diffusion_training_{NUM_EPOCHS}epochs",
        dir=OUTPUT_DIR,
        notes="",
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    batches_per_epoch = len(train_dataloader)
    num_train_timesteps = noise_scheduler.config.num_train_timesteps
    global_step = 0  # number of optimizer updates performed so far

    try:
        optimizer.zero_grad()
        for epoch in range(NUM_EPOCHS):
            model.train()
            progress_bar = tqdm(total=batches_per_epoch)
            progress_bar.set_description(f"Epoch {epoch}")

            epoch_loss = 0.0
            num_updates = 0
            group_loss = 0.0  # mean loss of the current accumulation group

            for step, batch in enumerate(train_dataloader):
                # The last group of an epoch may be shorter than the others;
                # scaling by its real size keeps the update a true mean.
                group_start = step - step % gradient_accumulation_steps
                group_size = min(
                    gradient_accumulation_steps,
                    batches_per_epoch - group_start,
                )

                clean_images = batch['image'].to(device)
                noise = torch.randn(clean_images.shape, device=device)
                bs = clean_images.shape[0]

                # Sample a random timestep for each image
                timesteps = torch.randint(
                    0, num_train_timesteps, (bs,),
                    device=device, dtype=torch.int64,
                )

                # Forward diffusion: noise the clean images according to the
                # noise magnitude at each sampled timestep.
                noisy_images = noise_scheduler.add_noise(
                    clean_images, noise, timesteps
                )

                # Predict the noise residual
                noise_pred = model(noisy_images, timesteps, return_dict=False)[0]
                loss = F.mse_loss(noise_pred, noise)

                (loss / group_size).backward()
                group_loss += loss.item() / group_size

                # Update at the end of every full group and also on the last
                # batch of the epoch, so that no gradients leak into the
                # next epoch.
                group_done = (
                    (step + 1) % gradient_accumulation_steps == 0
                    or (step + 1) == batches_per_epoch
                )
                if group_done:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

                    current_lr = lr_scheduler.get_last_lr()[0]
                    # No explicit `step=`: the other wandb.log() calls of
                    # this run rely on wandb's automatic step counter, and
                    # mixing both styles risks non-monotonic steps.
                    wandb.log({
                        "train/loss": group_loss,
                        "train/learning_rate": current_lr,
                        "train/step": global_step,
                    })
                    progress_bar.set_postfix(
                        loss=group_loss, lr=current_lr, step=global_step
                    )

                    epoch_loss += group_loss
                    num_updates += 1
                    global_step += 1
                    group_loss = 0.0

                progress_bar.update(1)

            progress_bar.close()

            avg_epoch_loss = epoch_loss / num_updates if num_updates > 0 else 0.0

            # Every second epoch (and after the last one): sample demo images
            # and save the pipeline.
            if (epoch + 1) % 2 == 0 or epoch == NUM_EPOCHS - 1:
                model.eval()
                with torch.no_grad():
                    pipeline = DDPMPipeline(unet=model, scheduler=noise_scheduler)
                    evaluate(epoch, pipeline)
                    pipeline.save_pretrained(OUTPUT_DIR)

            # Log epoch metrics
            wandb.log({
                "epoch": epoch,
                "train/epoch_loss": avg_epoch_loss,
            })
    except BaseException:
        # Close the run as failed instead of leaving it dangling.
        wandb.finish(exit_code=1)
        raise

    wandb.finish()

In [None]:
# Start training with the objects built in the cells above.
train_loop(
    model=model,
    noise_scheduler=noise_scheduler,
    optimizer=optimizer,
    train_dataloader=train_loader,
    lr_scheduler=lr_scheduler,
)

## Validation and Evaluation

In [23]:
# Rebuild the validation split; it is iterated in dataset order.
val_dataset = FaceDataset('data/val')
val_loader = DataLoader(val_dataset, shuffle=False, batch_size=32)

print(len(val_dataset), "validation images found.")

582 validation images found.
