[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Guidosalimbeni/aicaricaturist/blob/main/lora_img2img/lora_img2img_v2.ipynb)



In [None]:
# Enhanced LoRA-based Image-to-Image Diffusion Training
# @title Setup and Imports

!pip install torch torchvision
!pip install diffusers
!pip install accelerate
!pip install Pillow
!pip install tqdm

!pip install -q peft

from peft import LoraConfig, get_peft_model

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm.auto import tqdm
import random
import logging
from torchvision import models  # This was missing from the imports

import accelerate
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
#from diffusers.optimization import get_scheduler # This import is not used in the code, consider removing
from accelerate import Accelerator


In [None]:


# Fix the seeds of every RNG used below so repeated runs match
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)

# Attach Google Drive (Colab only); the paired training images are read from it
from google.colab import drive
drive.mount('/content/drive')
DATA_DIR = '/content/drive/MyDrive/caricature Project Diffusion/paired_caricature'

# Enhanced Data Augmentation Pipeline
class AugmentedCaricatureDataset(Dataset):
    """Paired face/caricature images stored as ``NNN_f.png`` / ``NNN_c.png``.

    Each item is a dict with keys ``"face"`` and ``"caric"``. Both values are
    3-channel grayscale tensors, resized to ``image_size`` x ``image_size`` and
    normalized with mean 0.5 / std 0.5 per channel (i.e. values in [-1, 1]).

    Args:
        data_dir: Directory containing the image pairs.
        split: ``'train'`` mirrors each pair horizontally with probability
            0.5; any other value disables augmentation.
        num_pairs: Highest pair index probed (``001`` .. ``num_pairs``).
            Indices with a missing face or caricature file are skipped.
        image_size: Side length in pixels of the square output images.
    """

    def __init__(self, data_dir, split='train', num_pairs=42, image_size=256):
        super().__init__()
        self.data_dir = data_dir
        self.split = split
        # The only random augmentation is the paired horizontal flip, which
        # is applied in __getitem__ so both images share one decision.
        self.augment = split == 'train'

        # Face and caricature go through the same deterministic pipeline.
        shared_transform = transforms.Compose([
            transforms.Resize((image_size, image_size)),
            transforms.Grayscale(num_output_channels=3),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        self.transform_face = shared_transform
        self.transform_caric = shared_transform

        self.pairs = self._find_pairs(data_dir, num_pairs)

    @staticmethod
    def _find_pairs(data_dir, num_pairs):
        """Return ``(face_path, caric_path)`` for every index with both files."""
        pairs = []
        for i in range(1, num_pairs + 1):
            face_path = os.path.join(data_dir, f"{i:03d}_f.png")
            caric_path = os.path.join(data_dir, f"{i:03d}_c.png")
            if os.path.exists(face_path) and os.path.exists(caric_path):
                pairs.append((face_path, caric_path))
        return pairs

    def __len__(self):
        return len(self.pairs)

    def __getitem__(self, idx):
        face_path, caric_path = self.pairs[idx]

        with Image.open(face_path) as img:
            face_img = img.convert("RGB")
        with Image.open(caric_path) as img:
            caric_img = img.convert("RGB")

        face_tensor = self.transform_face(face_img)
        caric_tensor = self.transform_caric(caric_img)

        # Draw ONE flip decision and apply it to both tensors. This keeps the
        # pair aligned without reseeding the global torch RNG, which would
        # otherwise overwrite the notebook-level seed on every item fetched.
        if self.augment and torch.rand(1).item() < 0.5:
            face_tensor = torch.flip(face_tensor, dims=[-1])
            caric_tensor = torch.flip(caric_tensor, dims=[-1])

        return {
            "face": face_tensor,
            "caric": caric_tensor
        }

# Enhanced Image Encoder with ResNet50
class EnhancedImageEncoder(nn.Module):
    """ResNet-50 based encoder that turns an image batch into one token each.

    The convolutional feature map is gated by a learned per-location sigmoid
    weight, average-pooled over space, and projected to ``out_dim``.

    Args:
        out_dim: Size of the returned embedding vector.
    """

    def __init__(self, out_dim=768):
        super().__init__()
        backbone = torchvision.models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V2)

        # Keep the convolutional trunk only: the last two children
        # (average pooling and the fc classifier) are dropped.
        trunk_layers = list(backbone.children())[:-2]
        self.feature_extractor = nn.Sequential(*trunk_layers)

        # Per-location gate in (0, 1): 2048 -> 512 -> 1 channels via 1x1 convs
        self.attention = nn.Sequential(
            nn.Conv2d(2048, 512, 1),
            nn.ReLU(),
            nn.Conv2d(512, 1, 1),
            nn.Sigmoid()
        )

        # MLP head mapping the pooled 2048-d vector to out_dim
        self.proj = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.LayerNorm(1024),
            nn.ReLU(),
            nn.Linear(1024, out_dim)
        )

    def forward(self, x):
        """Encode ``x`` and return a tensor of shape ``[B, 1, out_dim]``."""
        fmap = self.feature_extractor(x)  # [B, 2048, H, W]

        # Gate every spatial location by its attention weight
        gated = fmap * self.attention(fmap)

        # Collapse the spatial grid to a single vector per image
        pooled = F.adaptive_avg_pool2d(gated, (1, 1)).flatten(1)

        token = self.proj(pooled)
        return token.unsqueeze(1)  # [B, 1, out_dim]

# Enhanced LoRA application
def apply_enhanced_lora(unet, r=4, lora_alpha=1.0, lora_dropout=0.1):
    """Wrap ``unet`` in a PEFT model with LoRA adapters and return it.

    Adapters are attached to modules matching ``to_k``, ``to_q``, ``to_v``
    and ``to_out.0``; no bias terms are trained. A summary of trainable
    parameters is printed as a side effect.

    Args:
        unet: Model to wrap.
        r: LoRA rank.
        lora_alpha: LoRA scaling numerator.
        lora_dropout: Dropout probability inside the adapters.
    """
    attention_projections = ["to_k", "to_q", "to_v", "to_out.0"]
    lora_settings = dict(
        r=r,
        lora_alpha=lora_alpha,
        target_modules=attention_projections,
        lora_dropout=lora_dropout,
        bias="none",
    )
    peft_unet = get_peft_model(unet, LoraConfig(**lora_settings))
    # Report how many parameters will actually receive gradients
    peft_unet.print_trainable_parameters()
    return peft_unet

# Training Utils
class EMA:
    """Exponential moving average of model parameters.

    Args:
        beta: Weight given to the existing average on every update; the
            incoming value is weighted by ``1 - beta``.
    """

    def __init__(self, beta=0.9999):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Blend ``current_model``'s parameters into ``ma_model``.

        Parameters are matched positionally, so both models must yield
        their parameters in the same order.
        """
        paired = zip(ma_model.parameters(), current_model.parameters())
        for averaged, live in paired:
            averaged.data = self.update_average(averaged.data, live.data)

    def update_average(self, old, new):
        """Return ``beta * old + (1 - beta) * new``, or ``new`` if ``old`` is None."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

# Training configuration
class TrainingConfig:
    """Hyperparameters and output settings for the LoRA training run.

    Every field is a plain instance attribute set to its default in
    ``__init__``; override after construction if needed.
    """

    def __init__(self):
        # Run length and batching
        self.num_epochs = 100
        self.train_batch_size = self.eval_batch_size = 2
        self.gradient_accumulation_steps = 2
        self.mixed_precision = "fp16"

        # Learning-rate settings
        self.learning_rate = 1e-5
        self.lr_scheduler = "cosine"
        self.lr_warmup_steps = 500

        # AdamW hyperparameters
        self.adam_beta1, self.adam_beta2 = 0.9, 0.999
        self.adam_weight_decay = 1e-2
        self.adam_epsilon = 1e-08

        # LoRA adapter settings
        self.lora_r = 4
        self.lora_alpha = 1.0
        self.lora_dropout = 0.1

        # Output cadence (in epochs) and destination directory
        self.save_image_epochs = 5
        self.save_model_epochs = 10
        self.output_dir = "caricature-lora-model"

def _denoising_loss(batch, vae, unet, image_encoder, noise_scheduler):
    """Return the noise-prediction MSE for one batch of face/caricature pairs.

    The caricature is encoded to VAE latents and noised at one random
    timestep per sample; the UNet then predicts that noise conditioned on
    the face embedding.
    """
    face = batch["face"]
    caric = batch["caric"]

    # Encode caricature to latent space (same scaling constant is undone
    # again when latents are decoded at sampling time)
    latents = vae.encode(caric).latent_dist.sample()
    latents = latents * 0.18215

    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (latents.shape[0],),
        device=latents.device
    )
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    cond_embedding = image_encoder(face)

    noise_pred = unet(
        noisy_latents,
        timesteps,
        encoder_hidden_states=cond_embedding
    ).sample

    return F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")


def main():
    """Train LoRA adapters on the UNet together with the face image encoder.

    Reads image pairs from ``DATA_DIR``, holds out 10% for validation, and
    every ``save_image_epochs`` epochs prints train/validation loss. When the
    epoch is also a multiple of ``save_model_epochs`` it writes the LoRA
    weights plus ``image_encoder.pt`` to ``<output_dir>/checkpoint-<epoch>``
    and renders sample images.

    Raises:
        RuntimeError: If fewer than two image pairs are found, since both a
            training and a validation subset are required.

    NOTE(review): training never replaces the conditioning with zeros, yet
    sampling uses a zero embedding as the unconditional branch for guidance.
    Consider adding conditioning dropout - confirm intended behavior.
    NOTE(review): with an EMA decay of 0.9999 the averaged weights move very
    slowly; confirm this suits the number of optimizer steps in this run.
    """
    config = TrainingConfig()

    # Initialize accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        mixed_precision=config.mixed_precision,
    )

    # Two views over the same files: one with training augmentation, one
    # without. The split indices come from the training view and are reused
    # for the validation view so held-out items are never augmented.
    train_source = AugmentedCaricatureDataset(DATA_DIR, split='train')
    val_source = AugmentedCaricatureDataset(DATA_DIR, split='val')
    if len(train_source) < 2:
        raise RuntimeError(
            f"Need at least 2 image pairs in {DATA_DIR!r}, found {len(train_source)}"
        )
    train_size = int(0.9 * len(train_source))
    val_size = len(train_source) - train_size
    train_dataset, val_split = random_split(train_source, [train_size, val_size])
    val_dataset = torch.utils.data.Subset(val_source, val_split.indices)

    # Create dataloaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=2
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config.eval_batch_size,
        shuffle=False,
        num_workers=2
    )

    # Load models
    model_path = "runwayml/stable-diffusion-v1-5"

    device = accelerator.device
    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(device)
    unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet").to(device)
    image_encoder = EnhancedImageEncoder(out_dim=768).to(device)

    # Apply LoRA
    unet = apply_enhanced_lora(
        unet,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout
    )

    # Freeze VAE
    vae.requires_grad_(False)
    vae.eval()

    # Initialize noise scheduler
    noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

    # Only hand the optimizer parameters that will actually receive gradients
    # (the base UNet weights are frozen by the LoRA wrapper).
    params_to_train = [
        p
        for p in list(unet.parameters()) + list(image_encoder.parameters())
        if p.requires_grad
    ]
    optimizer = torch.optim.AdamW(
        params_to_train,
        lr=config.learning_rate,
        betas=(config.adam_beta1, config.adam_beta2),
        weight_decay=config.adam_weight_decay,
        eps=config.adam_epsilon
    )

    # Prepare models for training
    unet, image_encoder, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
        unet, image_encoder, optimizer, train_dataloader, val_dataloader
    )

    # Initialize EMA as an exact copy of the training UNet. LoRA adapters are
    # randomly initialized, so without the state-dict copy the average would
    # start from unrelated adapter weights.
    ema = EMA(beta=0.9999)
    ema_unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")
    ema_unet = apply_enhanced_lora(
        ema_unet,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout
    )
    ema_unet.load_state_dict(accelerator.unwrap_model(unet).state_dict())
    # The EMA copy is never optimized, so it is simply moved to the device
    # instead of being wrapped by the accelerator.
    ema_unet.requires_grad_(False)
    ema_unet.eval()
    ema_unet.to(device)

    # Training loop
    for epoch in range(config.num_epochs):
        unet.train()
        image_encoder.train()
        train_loss = 0.0

        for batch in train_dataloader:
            # Both trained models take part in gradient accumulation
            with accelerator.accumulate(unet, image_encoder):
                loss = _denoising_loss(batch, vae, unet, image_encoder, noise_scheduler)

                # Backward pass
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(params_to_train, 1.0)
                optimizer.step()
                optimizer.zero_grad()

                # Update EMA only on real optimizer steps
                if accelerator.sync_gradients:
                    ema.update_model_average(ema_unet, accelerator.unwrap_model(unet))

            train_loss += loss.detach().item()

        # Validation
        if epoch % config.save_image_epochs == 0:
            unet.eval()
            image_encoder.eval()
            val_loss = 0.0

            with torch.no_grad():
                for batch in val_dataloader:
                    val_loss += _denoising_loss(
                        batch, vae, unet, image_encoder, noise_scheduler
                    ).item()

            val_loss /= len(val_dataloader)

            print(f"Epoch {epoch}: Train Loss = {train_loss/len(train_dataloader):.4f}, Val Loss = {val_loss:.4f}")

            # Save checkpoint
            if epoch % config.save_model_epochs == 0:
                accelerator.wait_for_everyone()
                if accelerator.is_main_process:
                    # Save LoRA weights
                    save_path = os.path.join(config.output_dir, f"checkpoint-{epoch}")
                    os.makedirs(save_path, exist_ok=True)
                    accelerator.unwrap_model(unet).save_pretrained(save_path)
                    # Save the unwrapped image encoder so the state-dict keys
                    # carry no wrapper prefix and load into a plain module
                    torch.save(
                        accelerator.unwrap_model(image_encoder).state_dict(),
                        os.path.join(save_path, "image_encoder.pt")
                    )

                    # Generate sample images
                    generate_samples(
                        vae,
                        ema_unet if epoch > 0 else accelerator.unwrap_model(unet),
                        accelerator.unwrap_model(image_encoder),
                        noise_scheduler,
                        val_dataset,
                        epoch,
                        config
                    )

def generate_samples(vae, unet, image_encoder, noise_scheduler, val_dataset, epoch, config):
    """Generate and save sample images during training.

    For up to three validation items (first, middle, last) a caricature is
    sampled and saved next to its input face and ground truth as a
    768x256 strip ``<output_dir>/samples/sample-<epoch>-<idx>.png``.
    Returns without writing anything when ``val_dataset`` is empty.

    Note: building the pipeline switches ``vae``, ``unet`` and
    ``image_encoder`` to eval mode; callers that keep training must switch
    them back.
    """
    num_val = len(val_dataset)
    if num_val == 0:
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pipeline = CaricaturePipeline(
        vae=vae,
        unet=unet,
        image_encoder=image_encoder,
        scheduler=noise_scheduler
    )

    # Beginning, middle, end - as a set so tiny validation sets do not
    # render the same item several times, and with non-negative indices so
    # file names stay readable.
    eval_indices = sorted({0, num_val // 2, num_val - 1})
    sample_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(sample_dir, exist_ok=True)

    to_pil = transforms.ToPILImage()

    def _to_displayable(tensor):
        # Dataset tensors are normalized to [-1, 1] (mean 0.5, std 0.5) but
        # ToPILImage expects floats in [0, 1]; undo the normalization first.
        return to_pil((tensor * 0.5 + 0.5).clamp(0, 1).cpu())

    for idx in eval_indices:
        sample = val_dataset[idx]
        face_image = sample["face"].unsqueeze(0).to(device)

        with torch.no_grad():
            generated_image = pipeline(
                face_image,
                num_inference_steps=50,
                guidance_scale=7.5
            )['images'][0]

        # Save the triplet (input face, generated caricature, ground truth)
        comparison = Image.new('RGB', (768, 256))
        comparison.paste(_to_displayable(sample["face"]), (0, 0))
        comparison.paste(generated_image, (256, 0))
        comparison.paste(_to_displayable(sample["caric"]), (512, 0))

        comparison.save(os.path.join(sample_dir, f"sample-{epoch}-{idx}.png"))

class CaricaturePipeline:
    """Pipeline for generating caricatures from face images.

    Args:
        vae: Autoencoder used to decode the final latents.
        unet: Noise-prediction model conditioned on the face embedding.
        image_encoder: Maps a face batch to conditioning embeddings.
        scheduler: Diffusion scheduler providing ``set_timesteps``,
            ``scale_model_input`` and ``step``.

    All three models are moved to CUDA when available (CPU otherwise) and
    put in eval mode.
    """

    def __init__(self, vae, unet, image_encoder, scheduler):
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.vae = vae.to(self.device)
        self.unet = unet.to(self.device)
        self.image_encoder = image_encoder.to(self.device)
        self.scheduler = scheduler

        for module in (self.vae, self.unet, self.image_encoder):
            module.eval()

    def _prepare_latents(self, batch_size, height, width, generator):
        """Draw the starting noise, one latent per conditioning image."""
        shape = (batch_size, 4, height // 8, width // 8)
        # A generator can only sample on its own device; draw there and move,
        # so a CPU generator also works when the models live on the GPU.
        noise_device = generator.device if generator is not None else self.device
        latents = torch.randn(shape, generator=generator, device=noise_device)
        latents = latents.to(self.device)
        # Schedulers that expect differently scaled initial noise expose
        # init_noise_sigma; fall back to 1.0 when the attribute is absent.
        return latents * getattr(self.scheduler, "init_noise_sigma", 1.0)

    @torch.no_grad()
    def __call__(
        self,
        face_image,
        height=256,
        width=256,
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=None
    ):
        """Sample one caricature per face in ``face_image``.

        Args:
            face_image: Batch of preprocessed face tensors, already on the
                pipeline device.
            height: Output height in pixels (latents use ``height // 8``).
            width: Output width in pixels (latents use ``width // 8``).
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance weight; the prediction
                is ``uncond + guidance_scale * (cond - uncond)``.
            generator: Optional ``torch.Generator`` for the initial noise.

        Returns:
            ``{"images": [PIL.Image, ...]}`` with one image per input face.
        """
        # 0. Get face embedding
        cond_embedding = self.image_encoder(face_image)
        batch_size = cond_embedding.shape[0]

        # 1. Set timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        # 2. Generate initial noise (sized to the input batch, not fixed to 1)
        latents = self._prepare_latents(batch_size, height, width, generator)

        # 3. Classifier-free guidance setup: a zero embedding serves as the
        # unconditional branch. The stacked context is loop-invariant.
        uncond_embedding = torch.zeros_like(cond_embedding)
        guidance_context = torch.cat([uncond_embedding, cond_embedding])

        # 4. Denoising loop
        for t in tqdm(timesteps):
            # Expand latents for classifier-free guidance
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Predict noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=guidance_context
            ).sample

            # Perform guidance
            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # Compute previous noisy sample
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # 5. Scale and decode latents
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        # 6. Convert from [-1, 1] to uint8 PIL images
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.permute(0, 2, 3, 1).cpu().numpy()
        image = (image * 255).round().astype("uint8")
        image = [Image.fromarray(img) for img in image]

        return {"images": image}



In [None]:

# Make sure the checkpoint/sample directory exists before training starts
os.makedirs(TrainingConfig().output_dir, exist_ok=True)

# Emit INFO-level log records
logging.basicConfig(level=logging.INFO)

# Run the full training loop
main()

# Inference

In [None]:
# @title Generate Caricatures from Test Images
# !pip install -q peft
from peft import PeftModel, LoraConfig, get_peft_model

def load_and_generate_caricature(
    face_image_path,
    checkpoint_path,
    model_path="runwayml/stable-diffusion-v1-5"
):
    """Load a trained checkpoint and generate a caricature for one face image.

    Args:
        face_image_path: Path of the input face image.
        checkpoint_path: Directory holding the saved LoRA adapter and
            ``image_encoder.pt``.
        model_path: Base Stable Diffusion model providing the VAE, UNet and
            scheduler.

    Returns:
        The generated caricature as a PIL image.
    """
    # Load models
    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
    # Load the base UNet model first, then attach the LoRA weights
    base_unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")
    unet = PeftModel.from_pretrained(base_unet, checkpoint_path)
    image_encoder = EnhancedImageEncoder(out_dim=768)
    # map_location keeps loading working on machines without the device the
    # checkpoint was saved from; the pipeline moves the model afterwards.
    encoder_state = torch.load(
        os.path.join(checkpoint_path, "image_encoder.pt"),
        map_location="cpu"
    )
    image_encoder.load_state_dict(encoder_state)
    scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

    # Create pipeline
    pipeline = CaricaturePipeline(
        vae=vae,
        unet=unet,
        image_encoder=image_encoder,
        scheduler=scheduler
    )

    # Preprocess exactly like the training data, including the 3-channel
    # grayscale conversion the encoder was trained on.
    transform = transforms.Compose([
        transforms.Resize((256, 256)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    ])

    with Image.open(face_image_path) as img:
        face_image = img.convert("RGB")
    # Follow the pipeline's device instead of assuming CUDA is present
    face_tensor = transform(face_image).unsqueeze(0).to(pipeline.device)

    # Generate caricature
    output = pipeline(
        face_tensor,
        num_inference_steps=50,
        guidance_scale=7.5
    )

    return output["images"][0]

# Example usage: point these at a face photo and a trained checkpoint
test_image_path = "/content/drive/MyDrive/caricature Project Diffusion/test_06.png"
checkpoint_path = "caricature-lora-model/checkpoint-90"  # pick the epoch to load

generated_caricature = load_and_generate_caricature(test_image_path, checkpoint_path)
generated_caricature.save("generated_caricature.png")
# Show the result inline in the notebook
display(generated_caricature)