# LoRA img2img Caricature Generation

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/USERNAME/REPO/blob/main/aicaricaturist/pix2pix_caricature/caricature_training.ipynb)

This notebook implements a demo of a LoRA-based img2img model for face-to-caricature translation using PyTorch.

In [None]:
!pip install accelerate transformers diffusers==0.14.0 huggingface_hub safetensors --quiet
!pip install opencv-python Pillow --quiet
!pip install git+https://github.com/cloneofsimo/lora.git --quiet


In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm import tqdm

import accelerate
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
from diffusers.optimization import get_scheduler

# (Optional) Seed PyTorch's global RNG so weight init, noise sampling and
# shuffling are repeatable from run to run.
torch.random.manual_seed(42)


In [None]:
from google.colab import drive
# Folder on Google Drive that holds the paired NNN_f.png / NNN_c.png images.
DATA_DIR = '/content/drive/MyDrive/caricature Project Diffusion/paired_caricature'
# Mount Drive at /content/drive so DATA_DIR is readable from the Colab VM.
drive.mount('/content/drive')


## Custom Dataset

In [None]:
class CaricatureDataset(Dataset):
    """Paired face / caricature images stored as ``NNN_f.png`` / ``NNN_c.png``.

    Args:
        data_dir: folder containing the image pairs.
        transform_face: optional callable applied to the face PIL image.
        transform_caric: optional callable applied to the caricature PIL image.
        num_pairs: highest pair index to probe. Indices ``1..num_pairs`` are
            checked, and only indices where BOTH files exist are kept.
            Defaults to 42, the original dataset size.

    Each item is a dict with keys ``"face"`` (the condition input) and
    ``"caric"`` (the target reconstructed via diffusion).
    """

    def __init__(self, data_dir, transform_face=None, transform_caric=None, num_pairs=42):
        super().__init__()
        self.data_dir = data_dir
        self.transform_face = transform_face
        self.transform_caric = transform_caric

        # Keep only complete pairs; a face without its caricature (or vice
        # versa) is silently skipped.
        self.pairs = []
        for i in range(1, num_pairs + 1):
            face_path = os.path.join(data_dir, f"{i:03d}_f.png")
            caric_path = os.path.join(data_dir, f"{i:03d}_c.png")
            if os.path.exists(face_path) and os.path.exists(caric_path):
                self.pairs.append((face_path, caric_path))

    def __len__(self):
        return len(self.pairs)

    @staticmethod
    def _load_rgb(path):
        """Load *path* as an RGB PIL image and close the underlying file."""
        # convert() returns a new, fully loaded image, so it stays valid after
        # the file handle is closed by the context manager.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx):
        face_path, caric_path = self.pairs[idx]

        face_img = self._load_rgb(face_path)
        caric_img = self._load_rgb(caric_path)

        if self.transform_face is not None:
            face_img = self.transform_face(face_img)
        if self.transform_caric is not None:
            caric_img = self.transform_caric(caric_img)

        return {
            "face": face_img,       # Condition input
            "caric": caric_img,     # Target to reconstruct via diffusion
        }

# Define image transforms.
#
# Faces are fed to the ImageNet-pretrained ResNet-50 image encoder, so they are
# normalised with the ImageNet channel statistics that backbone expects.
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

transform_face = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Caricatures are encoded by the Stable Diffusion VAE, which works on pixels in
# [-1, 1] (inference decodes with `image / 2 + 0.5`). ToTensor() yields [0, 1],
# so shift/scale to [-1, 1] to keep training and decoding consistent.
transform_caric = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

# Build the paired dataset from the Drive folder and report how many complete
# face/caricature pairs were found.
dataset = CaricatureDataset(
    DATA_DIR,
    transform_face=transform_face,
    transform_caric=transform_caric,
)
print("Number of pairs:", len(dataset))

# Shuffled mini-batches of two pairs each for training.
train_dataloader = DataLoader(dataset, shuffle=True, batch_size=2)


## Replacing the Text Encoder with an Image Encoder
In Stable Diffusion, the conditioning typically comes from a CLIP text encoder. We want to replace that with an image-based encoder. A straightforward approach is to use a pretrained image model (e.g., CLIP’s vision transformer) to produce a latent embedding from the face image, which we feed as the “condition” to the UNet’s cross-attention.

Below is a toy example that uses a pretrained ResNet-50 from `torchvision.models`, followed by a linear projection to the 768-dim cross-attention size. In practice, you might want to use CLIP's vision encoder (e.g. `openai/clip-vit-base-patch32`) or something similar.

In [None]:
import torchvision.models as models

class SimpleImageEncoder(nn.Module):
    """Encode an image into one vector for the UNet's cross-attention.

    A pretrained torchvision ResNet-50 (classification head removed) produces
    a pooled feature vector. A linear layer then maps it to ``out_dim``
    (default 768), the size used in place of the text embedding.

    Args:
        out_dim: size of the output embedding.
    """

    def __init__(self, out_dim=768):
        super().__init__()
        backbone = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)
        # Every child except the final fully connected layer, i.e. the conv
        # trunk plus its global pooling.
        head_free_layers = list(backbone.children())[:-1]
        self.feature_extractor = nn.Sequential(*head_free_layers)
        self.proj = nn.Linear(backbone.fc.in_features, out_dim)

    def forward(self, x):
        # Pooled features come out as [B, C, 1, 1]; flatten everything after
        # the batch dimension before projecting to [B, out_dim].
        pooled = self.feature_extractor(x)
        return self.proj(torch.flatten(pooled, start_dim=1))


## Inject LoRA into the UNet Cross-Attention Layers
We need to adapt the UNet so that LoRA adapters are added to its cross-attention layers. You can do this with a small hand-written wrapper around the attention projections, or with an existing LoRA library.

Important: the diffusers library has its own LoRA training approach (see the LoRA scripts in the Diffusers examples), which you could adapt in a similar way. The snippet below is for demonstration only.

In [None]:
class LoRALinear(nn.Module):
    """Wrap an ``nn.Linear`` with a trainable low-rank (LoRA) update.

    Output is ``base(x) + scale * lora_up(lora_down(dropout(x)))`` with
    ``scale = lora_alpha / r``. ``lora_up`` is zero-initialised, so right
    after wrapping the layer computes exactly what ``base`` computed.

    Args:
        base: the linear layer to adapt (its weights are not modified here).
        r: rank of the low-rank update; must be positive.
        lora_alpha: scaling numerator for the update.
        lora_dropout: dropout probability applied to the LoRA input branch.
    """

    def __init__(self, base, r=4, lora_alpha=1.0, lora_dropout=0.0):
        super().__init__()
        if r <= 0:
            raise ValueError(f"LoRA rank r must be positive, got {r}")
        self.base = base
        self.lora_down = nn.Linear(base.in_features, r, bias=False)
        self.lora_up = nn.Linear(r, base.out_features, bias=False)
        self.lora_dropout = nn.Dropout(p=lora_dropout) if lora_dropout > 0.0 else nn.Identity()
        self.scale = lora_alpha / r
        nn.init.kaiming_uniform_(self.lora_down.weight, a=5 ** 0.5)
        nn.init.zeros_(self.lora_up.weight)
        # Keep the adapter on the same device/dtype as the wrapped layer.
        self.lora_down.to(device=base.weight.device, dtype=base.weight.dtype)
        self.lora_up.to(device=base.weight.device, dtype=base.weight.dtype)

    def forward(self, x):
        return self.base(x) + self.lora_up(self.lora_down(self.lora_dropout(x))) * self.scale


# In diffusers, the cross-attention projections live in the UNet's transformer
# blocks under names such as "...attn2.to_q", "...attn2.to_k", "...attn2.to_v"
# and "...attn2.to_out.0". The exact names can differ between diffusers versions.
def apply_lora_to_unet(unet, r=4, lora_alpha=1.0, lora_dropout=0.0):
    """Inject LoRA adapters into the UNet's cross-attention Linear layers.

    Every ``nn.Linear`` whose qualified name contains ``"attn2.to_"`` is
    replaced in place by a :class:`LoRALinear` wrapping it. All other UNet
    parameters are frozen, so only the LoRA factors remain trainable.
    Layers that are already wrapped are left alone, so calling this twice
    does not stack adapters.

    Args:
        unet: the model to modify (modified in place).
        r, lora_alpha, lora_dropout: forwarded to :class:`LoRALinear`.

    Returns:
        The same ``unet`` object, for convenience.
    """
    import warnings

    # Freeze everything first; LoRA factors created below start trainable.
    unet.requires_grad_(False)

    # Collect targets before mutating: replacing children while iterating
    # named_modules() would change the tree being walked.
    targets = []
    for name, module in unet.named_modules():
        if "attn2.to_" not in name or not isinstance(module, nn.Linear):
            continue
        parent_name, _, child_name = name.rpartition(".")
        parent = unet.get_submodule(parent_name) if parent_name else unet
        if isinstance(parent, LoRALinear):
            continue  # base layer / LoRA factor of an already-wrapped projection
        targets.append((parent, child_name, module))

    # setattr re-registers the child (also works for ModuleList entries "0", "1", ...).
    for parent, child_name, base in targets:
        setattr(
            parent,
            child_name,
            LoRALinear(base, r=r, lora_alpha=lora_alpha, lora_dropout=lora_dropout),
        )

    lora_modules = [m for m in unet.modules() if isinstance(m, LoRALinear)]
    if not lora_modules:
        warnings.warn(
            "apply_lora_to_unet: no nn.Linear layer matched 'attn2.to_'; "
            "the UNet has no trainable parameters."
        )
    # (Re-)enable gradients for every LoRA factor, including adapters added by
    # an earlier call that requires_grad_(False) above just froze.
    for m in lora_modules:
        m.lora_down.requires_grad_(True)
        m.lora_up.requires_grad_(True)

    return unet


## Assemble the Full Model (VAE + UNet + ImageEncoder)
We'll load a pretrained Stable Diffusion UNet and VAE (e.g. from runwayml/stable-diffusion-v1-5) but skip the text encoder. We then inject LoRA into the cross-attention layers, so that only the LoRA parameters of the UNet remain trainable; the image encoder is trained as well.

In [None]:
from diffusers import AutoencoderKL, UNet2DConditionModel

pretrained_model_name_or_path = "runwayml/stable-diffusion-v1-5"
# Target device: the GPU when one is visible, otherwise the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# VAE: frozen; only used to move images into / out of latent space.
vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder="vae")
vae.requires_grad_(False)
vae.eval()

# UNet: loaded from the same checkpoint, then given LoRA adapters (rank 4)
# on its cross-attention layers.
unet = apply_lora_to_unet(
    UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder="unet"),
    r=4,
)

# Face encoder; its 768-dim output matches the text-embedding width that
# stable diffusion v1.5's cross-attention consumes.
image_encoder = SimpleImageEncoder(out_dim=768)

# Move all three networks onto the chosen device.
vae.to(device)
unet.to(device)
image_encoder.to(device)


## Diffusion Scheduler and Optimizer
We'll use a DDPM-style scheduler and a standard AdamW optimizer for the trainable parameters. Hugging Face's text-to-image fine-tuning examples also train with `DDPMScheduler`. Faster samplers such as `DDIMScheduler` or `DPMSolverMultistepScheduler` are usually swapped in only at inference time.

In [None]:
noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder="scheduler")

# Give the optimizer only parameters that actually receive gradients: the
# LoRA weights in the UNet (apply_lora_to_unet is meant to leave only those
# trainable) plus the image encoder. Passing the frozen UNet weights as well
# would make AdamW walk over hundreds of millions of parameters every step
# for nothing.
params_to_train = [
    p
    for p in list(unet.parameters()) + list(image_encoder.parameters())
    if p.requires_grad
]
optimizer = torch.optim.AdamW(params_to_train, lr=1e-5)


## Training

In [None]:
def encode_images(vae, x):
    """Encode a pixel-space image batch into scaled VAE latents.

    Args:
        vae: VAE whose ``encode(x)`` result exposes ``latent_dist.sample()``.
        x: pixel batch to encode (presumably [B, 3, H, W] caricatures from
           the dataloader; confirm against callers).

    Returns:
        A sample from the latent distribution multiplied by 0.18215,
        computed without tracking gradients.
    """
    latent_scale = 0.18215  # Stable Diffusion v1 latent scaling constant
    with torch.no_grad():
        posterior = vae.encode(x).latent_dist
        scaled = latent_scale * posterior.sample()
    return scaled

def training_step(batch, unet, vae, image_encoder, noise_scheduler, optimizer):
    """Run one noise-prediction training step and return the scalar loss.

    The caricature is encoded to VAE latents and noised at a random timestep.
    The UNet, conditioned on the face embedding, then learns to predict the
    added noise (MSE loss).

    Args:
        batch: dict with ``"face"`` (condition) and ``"caric"`` (target) tensors.
        unet, vae, image_encoder: the models assembled earlier in the notebook.
        noise_scheduler: diffusers scheduler providing ``add_noise`` and a
            ``config.num_train_timesteps`` entry.
        optimizer: optimizer over the trainable parameters.

    Returns:
        The loss as a Python float.
    """
    face = batch["face"].to(device)
    caric = batch["caric"].to(device)

    # 1) Latents of the target caricature (the VAE is frozen; no gradients).
    with torch.no_grad():
        latents = encode_images(vae, caric)

    # 2) Gaussian noise with the latents' shape.
    noise = torch.randn_like(latents)

    # 3) One random timestep per sample. Read the count from the scheduler's
    #    config: in diffusers 0.14 it is not exposed as a plain attribute.
    num_train_timesteps = noise_scheduler.config.num_train_timesteps
    timesteps = torch.randint(
        0, num_train_timesteps, (latents.shape[0],), device=device
    ).long()

    # 4) Forward diffusion: noise the latents to the sampled timesteps.
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # 5) Face embedding [B, D], used as a single cross-attention "token"
    #    [B, 1, D] in place of the usual text-token sequence.
    cond_emb = image_encoder(face).unsqueeze(1)

    # 6) Predicted noise has the same shape as `latents` (for SD v1.5's VAE:
    #    4 channels at 1/8 of the input resolution, i.e. 32x32 for 256x256).
    model_pred = unet(
        sample=noisy_latents,
        timestep=timesteps,
        encoder_hidden_states=cond_emb,
    ).sample

    # 7) Epsilon-prediction objective.
    loss = F.mse_loss(model_pred, noise)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    return loss.item()


num_epochs = 5
num_training_steps = num_epochs * len(train_dataloader)

# Linear decay of the learning rate to zero over the whole run, no warmup.
lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

for epoch in range(num_epochs):
    # diffusers' from_pretrained() hands back the UNet in eval mode, and other
    # cells may switch modes too, so explicitly enable training mode (dropout,
    # BatchNorm statistics in the ResNet encoder) for the trainable networks.
    unet.train()
    image_encoder.train()

    pbar = tqdm(train_dataloader, desc=f"Epoch {epoch}", total=len(train_dataloader))
    running_loss = 0.0
    for step, batch in enumerate(pbar, start=1):
        loss = training_step(batch, unet, vae, image_encoder, noise_scheduler, optimizer)
        lr_scheduler.step()
        running_loss += loss
        # Per-step loss is noisy with tiny batches; also show the epoch average.
        pbar.set_postfix({"loss": loss, "avg_loss": running_loss / step})


## Inference

In [None]:
@torch.no_grad()
def generate_caricature(face_img_path, unet, vae, image_encoder, noise_scheduler, num_inference_steps=50):
    """Generate a caricature for the face image stored at *face_img_path*.

    The face is preprocessed with ``transform_face`` and embedded by
    ``image_encoder``. The result is used as the UNet's cross-attention
    condition while the scheduler denoises random latents, which are finally
    decoded by the VAE.

    Args:
        face_img_path: path to the input face image.
        unet, vae, image_encoder: the (trained) models.
        noise_scheduler: diffusers scheduler used for sampling; its timesteps
            are reset via ``set_timesteps``.
        num_inference_steps: number of denoising steps.

    Returns:
        The generated caricature as a PIL image.
    """
    # Sample in eval mode: in train mode the ResNet encoder's BatchNorm would
    # normalise with batch-of-1 statistics and overwrite its running stats.
    # Remember the previous modes so an interleaved training run is unaffected.
    prev_unet_mode, prev_encoder_mode = unet.training, image_encoder.training
    unet.eval()
    image_encoder.eval()
    try:
        # 1) Load and preprocess the face exactly like the training faces.
        with Image.open(face_img_path) as img:
            face = img.convert("RGB")
        face = transform_face(face).unsqueeze(0).to(device)

        # 2) Face embedding as a single cross-attention token: [1, 1, D].
        cond_emb = image_encoder(face).unsqueeze(1)

        # 3) Start from noise at the latent resolution used in training. Both
        #    transforms resize to the same size, so derive it from the face
        #    tensor and the VAE's downsampling factor (8 for SD v1.5) instead
        #    of hard-coding 64x64 (which corresponds to 512px images).
        vae_scale_factor = 2 ** (len(vae.config.block_out_channels) - 1)
        latent_shape = (
            1,
            unet.config.in_channels,
            face.shape[-2] // vae_scale_factor,
            face.shape[-1] // vae_scale_factor,
        )
        latents = torch.randn(latent_shape, device=device) * noise_scheduler.init_noise_sigma

        noise_scheduler.set_timesteps(num_inference_steps)

        for t in noise_scheduler.timesteps:
            # No-op for DDPM, but required by schedulers such as Euler/DPM.
            latent_model_input = noise_scheduler.scale_model_input(latents, t)
            noise_pred = unet(
                latent_model_input, t, encoder_hidden_states=cond_emb
            ).sample
            # x_t -> x_{t-1}
            latents = noise_scheduler.step(noise_pred, t, latents).prev_sample

        # 4) Undo the latent scaling applied at encode time, then decode.
        image = vae.decode(1 / 0.18215 * latents).sample
    finally:
        unet.train(prev_unet_mode)
        image_encoder.train(prev_encoder_mode)

    # VAE output is in [-1, 1]; map to [0, 255] uint8 HWC for PIL.
    image = (image / 2 + 0.5).clamp(0, 1)
    image = image.cpu().permute(0, 2, 3, 1).float().numpy()[0]
    image = (image * 255).round().astype(np.uint8)
    return Image.fromarray(image)

# Example: caricature the first training face and save it next to the notebook.
test_face_path = "/content/drive/MyDrive/caricature Project Diffusion/paired_caricature/001_f.png"
caric_out = generate_caricature(
    test_face_path,
    unet,
    vae,
    image_encoder,
    noise_scheduler,
    num_inference_steps=50,
)
caric_out.save("generated_caric.png")
# Last expression of the cell: Jupyter renders the PIL image inline.
caric_out
