[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Guidosalimbeni/aicaricaturist/blob/main/lora_img2img/lora_img2img_v3.ipynb)



In [None]:
# import shutil
# shutil.rmtree("caricature-lora-model")  # Deletes the folder and all its contents

In [None]:
# Enhanced LoRA-based Image-to-Image Diffusion Training
# @title Setup and Imports

!pip install torch torchvision
!pip install diffusers
!pip install accelerate
!pip install Pillow
!pip install tqdm

!pip install -q peft

from peft import LoraConfig, get_peft_model
from peft import PeftModel, LoraConfig, get_peft_model

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm.auto import tqdm
import random
import logging
from torchvision import models  # This was missing from the imports
from transformers import get_cosine_schedule_with_warmup

import accelerate
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
#from diffusers.optimization import get_scheduler # This import is not used in the code, consider removing
from accelerate import Accelerator
from torch.optim.lr_scheduler import CosineAnnealingLR # Import the missing class
from matplotlib import pyplot as plt

In [None]:


# Seed every random number generator the notebook draws from (PyTorch,
# Python's `random`, NumPy) with the same value so runs are repeatable.
for _seed_rng in (torch.manual_seed, random.seed, np.random.seed):
    _seed_rng(42)

# Attach Google Drive; DATA_DIR below points at a folder inside the mount.
from google.colab import drive

drive.mount('/content/drive')
DATA_DIR = '/content/drive/MyDrive/caricature Project Diffusion/paired_caricature'

class AugmentedCaricatureDataset(Dataset):
    """Paired face / caricature images loaded from a single directory.

    Files are expected to be named ``NNN_f.png`` (face) and ``NNN_c.png``
    (caricature) with a zero-padded, 1-based index; an index is kept only when
    both files exist. Each item is a dict with keys ``"face"`` and ``"caric"``,
    each a ``(3, 256, 256)`` float tensor in ``[0, 1]`` (a grayscale image
    repeated over three channels).

    Args:
        data_dir: Directory containing the image pairs.
        split: ``'train'`` enables the random horizontal flip; any other value
            disables augmentation.
        max_index: Highest pair index probed on disk (inclusive).
        flip_prob: Probability of flipping a training pair horizontally.
    """

    def __init__(self, data_dir, split='train', max_index=42, flip_prob=0.5):
        super().__init__()
        self.data_dir = data_dir
        self.split = split
        # Augmentation is only active for the training split.
        self.flip_prob = flip_prob if split == 'train' else 0.0

        # Deterministic preprocessing shared by both images. The random flip is
        # applied separately in __getitem__ so that one coin toss drives both
        # images of a pair.
        base_transforms = [
            transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
        ]
        self.transform_face = transforms.Compose(base_transforms)
        self.transform_caric = transforms.Compose(base_transforms)

        # Collect every index for which both halves of the pair are present.
        self.pairs = []
        for i in range(1, max_index + 1):
            face_path = os.path.join(data_dir, f"{i:03d}_f.png")
            caric_path = os.path.join(data_dir, f"{i:03d}_c.png")
            if os.path.exists(face_path) and os.path.exists(caric_path):
                self.pairs.append((face_path, caric_path))

    def __len__(self):
        """Return the number of complete (face, caricature) pairs found."""
        return len(self.pairs)

    @staticmethod
    def _flip_pair(face, caric, prob):
        """Mirror both tensors left-right together with probability ``prob``.

        A single random draw decides the flip for the whole pair, so the face
        and its caricature always stay aligned. The global RNG is only sampled,
        never reseeded.
        """
        if prob > 0 and torch.rand(1).item() < prob:
            return torch.flip(face, dims=[-1]), torch.flip(caric, dims=[-1])
        return face, caric

    def __getitem__(self, idx):
        """Load, preprocess and (for training) jointly augment pair ``idx``."""
        face_path, caric_path = self.pairs[idx]

        # Load as grayscale directly; the context manager releases the file
        # handle as soon as the pixel data has been converted.
        with Image.open(face_path) as img:
            face_img = img.convert("L")
        with Image.open(caric_path) as img:
            caric_img = img.convert("L")

        face_tensor = self.transform_face(face_img)
        caric_tensor = self.transform_caric(caric_img)

        # Same flip decision for both images (previously done by reseeding the
        # global torch RNG twice, which also disturbed every other consumer of
        # that RNG in the calling process).
        face_tensor, caric_tensor = self._flip_pair(
            face_tensor, caric_tensor, self.flip_prob
        )

        # Repeat grayscale channel to match model input
        face_tensor = face_tensor.repeat(3, 1, 1)
        caric_tensor = caric_tensor.repeat(3, 1, 1)

        return {
            "face": face_tensor,
            "caric": caric_tensor
        }

class EnhancedImageEncoder(nn.Module):
    """Encode an image batch into one conditioning token per image.

    An ImageNet-pretrained ResNet-34 trunk produces a 512-channel feature map,
    a small 1x1-conv head gates every spatial location with a sigmoid weight,
    the gated map is average-pooled to a vector, and a linear layer followed by
    LayerNorm projects it to ``out_dim``. ``forward`` returns a tensor of shape
    ``(batch, 1, out_dim)``.
    """

    def __init__(self, out_dim=768):
        super().__init__()
        backbone = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)

        # Keep every child except the last two (global average pool and the
        # fully connected classifier) so the trunk outputs a spatial map.
        trunk_layers = list(backbone.children())[:-2]
        self.feature_extractor = nn.Sequential(*trunk_layers)

        # Per-location gate: 512 -> 128 -> 1 channels, squashed into (0, 1).
        self.attention = nn.Sequential(
            nn.Conv2d(512, 128, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(128, 1, kernel_size=1),
            nn.Sigmoid(),
        )

        # Project the pooled 512-d descriptor to the conditioning width.
        self.proj = nn.Sequential(
            nn.Linear(512, out_dim),
            nn.LayerNorm(out_dim),
        )

    def forward(self, x):
        feature_map = self.feature_extractor(x)
        gated_map = feature_map * self.attention(feature_map)
        pooled = F.adaptive_avg_pool2d(gated_map, (1, 1)).flatten(start_dim=1)
        # Insert a length-1 sequence axis: (batch, out_dim) -> (batch, 1, out_dim).
        return self.proj(pooled).unsqueeze(1)

def apply_enhanced_lora(unet, r=8, lora_alpha=1.0, lora_dropout=0.1):
    """Wrap ``unet`` with LoRA adapters and return the resulting PEFT model.

    Adapters are attached to the modules named ``to_k``, ``to_q``, ``to_v`` and
    ``to_out.0``; bias terms are left untouched. The trainable-parameter
    summary is printed as a side effect.

    Args:
        unet: Model to adapt.
        r: LoRA rank.
        lora_alpha: LoRA scaling numerator passed through to ``LoraConfig``.
        lora_dropout: Dropout probability used inside the LoRA layers.
    """
    adapted_module_names = ["to_k", "to_q", "to_v", "to_out.0"]
    lora_config = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        lora_dropout=lora_dropout,
        target_modules=adapted_module_names,
        bias="none",
    )
    peft_unet = get_peft_model(unet, lora_config)
    peft_unet.print_trainable_parameters()
    return peft_unet

class EMA:
    """Exponential moving average of model parameters.

    Args:
        beta: Decay factor; each update keeps ``beta`` of the running average
            and blends in ``1 - beta`` of the current value.
    """

    def __init__(self, beta=0.9999):
        self.beta = beta

    @torch.no_grad()
    def update_model_average(self, ma_model, current_model):
        """Blend ``current_model``'s parameters into ``ma_model`` in place.

        Parameters are paired in ``parameters()`` iteration order, so both
        models must share the same architecture. The update is done in place
        (no per-parameter temporaries, no autograd tracking); the current
        weights are cast to the averaged parameter's device and dtype first so
        the two models do not have to live on the same device.
        """
        for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
            up_weight = current_params.data.to(device=ma_params.device, dtype=ma_params.dtype)
            # ma <- beta * ma + (1 - beta) * current
            ma_params.data.mul_(self.beta).add_(up_weight, alpha=1 - self.beta)

    def update_average(self, old, new):
        """Return the blended value, or ``new`` when there is no previous average."""
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new

class TrainingConfig:
    """Hyperparameters and output settings for the LoRA training run.

    ``TrainingConfig()`` yields the defaults listed in ``__init__``. Any field
    can be overridden by keyword, e.g. ``TrainingConfig(num_epochs=10)``.

    Raises:
        TypeError: If an override names a field that does not exist, so typos
            are caught instead of silently creating a new attribute.
    """

    def __init__(self, **overrides):
        # Batching / schedule
        self.train_batch_size = 1
        self.eval_batch_size = 1
        self.num_epochs = 200  # reduced from 300
        self.gradient_accumulation_steps = 4
        self.learning_rate = 1e-6  # Reduced learning rate
        # NOTE(review): min_learning_rate and lr_warmup_steps do not appear to
        # be read by main() as written — confirm whether they are still needed.
        self.min_learning_rate = 1e-7
        self.lr_warmup_steps = 100
        # Validation / sampling cadence and checkpoint cadence, in epochs.
        self.save_image_epochs = 5
        self.save_model_epochs = 10
        self.mixed_precision = "fp16"
        self.output_dir = "caricature-lora-model"

        # LoRA specific
        self.lora_r = 8  # Increased rank
        self.lora_alpha = 1.0
        self.lora_dropout = 0.1

        # Optimizer
        self.adam_beta1 = 0.9
        self.adam_beta2 = 0.999
        self.adam_weight_decay = 1e-2
        self.adam_epsilon = 1e-08

        # Apply caller overrides, accepting only names defined above.
        unknown = sorted(set(overrides) - set(vars(self)))
        if unknown:
            raise TypeError(
                f"Unknown TrainingConfig option(s): {', '.join(unknown)}"
            )
        vars(self).update(overrides)

def plot_losses(train_losses, val_losses, save_path, save_epochs):
    """Plot training and validation loss curves and save the figure.

    Args:
        train_losses: One training loss per epoch.
        val_losses: One validation loss per validation round; round ``k`` is
            drawn at epoch ``k * save_epochs``.
        save_path: Destination image path.
        save_epochs: Number of epochs between validation rounds.
    """
    fig = plt.figure(figsize=(12, 6))
    try:
        plt.plot(train_losses, label='Train Loss', alpha=0.7)

        # Validation points sit on every `save_epochs`-th epoch. Trim both
        # sequences to a common length so a mismatch between the two lists
        # cannot make matplotlib raise.
        val_epochs = list(range(0, len(train_losses), save_epochs))
        n_val = min(len(val_epochs), len(val_losses))
        if n_val:
            plt.plot(val_epochs[:n_val], list(val_losses[:n_val]),
                     label='Val Loss', alpha=0.7)

        # Add moving average only if enough data points are available
        window_size = 10
        if len(train_losses) >= window_size:
            train_ma = np.convolve(train_losses, np.ones(window_size)/window_size, mode='valid')
            plt.plot(range(window_size-1, len(train_losses)), train_ma,
                     label=f'Train {window_size}-epoch MA', linestyle='--', alpha=0.5)

        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training and Validation Losses')
        plt.grid(True, alpha=0.3)
        # Build the legend only after every labelled line has been drawn;
        # calling it earlier left the moving-average entry out.
        plt.legend()

        plt.savefig(save_path)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close(fig)

def _noise_prediction_loss(vae, unet, image_encoder, noise_scheduler, face, caric,
                           cond_dropout_prob=0.0):
    """Return the noise-prediction MSE for one batch of (face, caricature) pairs.

    The caricature is encoded to VAE latents, noised at a random timestep, and
    the UNet (conditioned on the face embedding) is asked to predict that
    noise. With ``cond_dropout_prob > 0`` the embedding of a random subset of
    samples is zeroed, which is the "unconditional" input CaricaturePipeline
    feeds for classifier-free guidance.
    """
    # The dataset yields images in [0, 1] (ToTensor), while CaricaturePipeline
    # decodes with `image / 2 + 0.5`, i.e. it assumes the VAE works in
    # [-1, 1]. Rescale before encoding so training and inference agree.
    with torch.no_grad():
        latents = vae.encode(caric * 2.0 - 1.0).latent_dist.sample()
        latents = latents * 0.18215

    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (latents.shape[0],),
        device=latents.device
    )
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    cond_embedding = image_encoder(face)
    if cond_dropout_prob > 0:
        keep = torch.rand(cond_embedding.shape[0], device=cond_embedding.device) >= cond_dropout_prob
        cond_embedding = cond_embedding * keep.view(-1, 1, 1).to(cond_embedding.dtype)

    noise_pred = unet(
        noisy_latents,
        timesteps,
        encoder_hidden_states=cond_embedding
    ).sample

    return F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")


def main():
    """Train the LoRA-adapted UNet and the image encoder on the paired dataset.

    Side effects: writes the best model, periodic checkpoints, sample images
    and a loss plot under ``TrainingConfig().output_dir``; prints one progress
    line per validation round.

    Raises:
        RuntimeError: If fewer than two image pairs are found in ``DATA_DIR``.
    """
    config = TrainingConfig()
    # Probability of zeroing the face embedding during training so the
    # zero-embedding branch used for guidance at inference is actually learned.
    cond_dropout_prob = getattr(config, "cond_dropout_prob", 0.1)

    # Initialize accelerator
    accelerator = Accelerator(
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        mixed_precision=config.mixed_precision,
    )
    if accelerator.is_main_process:
        os.makedirs(config.output_dir, exist_ok=True)

    # Create dataset and split into train/val
    dataset = AugmentedCaricatureDataset(DATA_DIR)
    if len(dataset) < 2:
        raise RuntimeError(
            f"Need at least 2 image pairs in {DATA_DIR!r} for a train/val split, "
            f"found {len(dataset)}"
        )
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

    # Create dataloaders
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=2
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config.eval_batch_size,
        shuffle=False,
        num_workers=2
    )

    # Load models
    model_path = "runwayml/stable-diffusion-v1-5"

    device = accelerator.device
    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(device)
    unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet").to(device)
    image_encoder = EnhancedImageEncoder(out_dim=768).to(device)

    # Apply LoRA
    unet = apply_enhanced_lora(
        unet,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout
    )

    # Freeze VAE
    vae.requires_grad_(False)
    vae.eval()

    # Initialize noise scheduler
    noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

    # Optimize only what is actually trainable: the LoRA adapters and the image
    # encoder. The frozen UNet base weights never receive gradients, so handing
    # them to the optimizer and the gradient clipper is pure overhead.
    params_to_train = [
        p
        for p in list(unet.parameters()) + list(image_encoder.parameters())
        if p.requires_grad
    ]
    optimizer = torch.optim.AdamW(
        params_to_train,
        lr=config.learning_rate,
        betas=(config.adam_beta1, config.adam_beta2),
        weight_decay=config.adam_weight_decay,
        eps=config.adam_epsilon
    )

    # The schedule is measured in optimizer updates: one per
    # `gradient_accumulation_steps` batches (plus one for a trailing partial
    # group at the end of each epoch). It is stepped once per update inside
    # the loop below; stepping it once per epoch, as before, left the learning
    # rate at 0 for the first epoch and never got past warm-up.
    # NOTE: sized from the un-sharded dataloader, i.e. assumes a single process.
    updates_per_epoch = max(
        1, -(-len(train_dataloader) // config.gradient_accumulation_steps)
    )
    num_training_steps = config.num_epochs * updates_per_epoch
    num_warmup_steps = int(0.1 * num_training_steps)  # 10% warmup

    lr_scheduler = get_cosine_schedule_with_warmup(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps
    )

    # Prepare models for training
    unet, image_encoder, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
        unet, image_encoder, optimizer, train_dataloader, val_dataloader
    )

    # Initialize EMA as an exact, frozen copy of the UNet being trained. A
    # freshly constructed LoRA wrapper would start from different random
    # adapter weights, and with beta this close to 1 the average would stay
    # near that unrelated initialisation.
    ema = EMA(beta=0.9999)
    ema_unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")
    ema_unet = apply_enhanced_lora(
        ema_unet,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout
    )
    ema_unet.load_state_dict(accelerator.unwrap_model(unet).state_dict())
    ema_unet.to(device)
    ema_unet.requires_grad_(False)
    ema_unet.eval()

    # Training tracking
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    # Counted in validation rounds (one every `save_image_epochs` epochs),
    # not in epochs.
    patience = 20
    patience_counter = 0

    # Training loop
    for epoch in range(config.num_epochs):
        unet.train()
        image_encoder.train()
        train_loss = 0.0

        for batch in train_dataloader:
            # Both trained modules take part in gradient accumulation.
            with accelerator.accumulate(unet, image_encoder):
                loss = _noise_prediction_loss(
                    vae,
                    unet,
                    image_encoder,
                    noise_scheduler,
                    batch["face"],
                    batch["caric"],
                    cond_dropout_prob=cond_dropout_prob
                )

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(params_to_train, 0.5)
                optimizer.step()
                optimizer.zero_grad()

                # Only on real optimizer updates: advance the LR schedule and
                # fold the new weights into the EMA copy.
                if accelerator.sync_gradients:
                    lr_scheduler.step()
                    ema.update_model_average(ema_unet, accelerator.unwrap_model(unet))

            train_loss += loss.detach().item()

        # Calculate average train loss
        avg_train_loss = train_loss / len(train_dataloader)
        train_losses.append(avg_train_loss)

        # Validation
        if epoch % config.save_image_epochs == 0:
            unet.eval()
            image_encoder.eval()
            val_loss = 0.0

            with torch.no_grad():
                for batch in val_dataloader:
                    val_loss += _noise_prediction_loss(
                        vae,
                        unet,
                        image_encoder,
                        noise_scheduler,
                        batch["face"],
                        batch["caric"]
                    ).item()

            val_loss /= len(val_dataloader)
            val_losses.append(val_loss)

            # Early stopping check
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
                # Save best model
                if accelerator.is_main_process:
                    save_path = os.path.join(config.output_dir, "best_model")
                    os.makedirs(save_path, exist_ok=True)
                    unwrapped_unet = accelerator.unwrap_model(unet)
                    unwrapped_unet.save_pretrained(save_path)
                    torch.save(
                        accelerator.unwrap_model(image_encoder).state_dict(),
                        os.path.join(save_path, "image_encoder.pt")
                    )
            else:
                patience_counter += 1

            # Scientific notation: the learning rates used here are around
            # 1e-6 and below, which a fixed 6-decimal format prints as 0.
            print(f"Epoch {epoch}: Train Loss = {avg_train_loss:.4f}, Val Loss = {val_loss:.4f}, LR = {lr_scheduler.get_last_lr()[0]:.2e}")

            # Plot losses
            if accelerator.is_main_process:
                plot_losses(
                    train_losses,
                    val_losses,
                    os.path.join(config.output_dir, 'loss_plot.png'),
                    config.save_image_epochs
                )

            # Save checkpoint and generate samples
            if epoch % config.save_model_epochs == 0:
                accelerator.wait_for_everyone()
                if accelerator.is_main_process:
                    # Save checkpoint
                    save_path = os.path.join(config.output_dir, f"checkpoint-{epoch}")
                    os.makedirs(save_path, exist_ok=True)
                    unwrapped_unet = accelerator.unwrap_model(unet)
                    unwrapped_unet.save_pretrained(save_path)
                    torch.save(
                        accelerator.unwrap_model(image_encoder).state_dict(),
                        os.path.join(save_path, "image_encoder.pt")
                    )

                    # Generate samples
                    generate_samples(
                        vae,
                        ema_unet if epoch > 0 else unet,
                        image_encoder,
                        noise_scheduler,
                        val_dataset,
                        epoch,
                        config
                    )

            # Early stopping
            if patience_counter >= patience:
                print(f"Early stopping triggered after {epoch} epochs")
                break

class CaricaturePipeline:
    """Sample caricatures from noise, conditioned on a face-image embedding.

    All three networks are moved to CUDA when available (CPU otherwise) and
    put in eval mode on construction.

    Args:
        vae: Autoencoder used to decode the final latents.
        unet: Denoising network taking ``encoder_hidden_states``.
        image_encoder: Maps a face batch to the conditioning embedding.
        scheduler: Diffusion scheduler providing ``set_timesteps``,
            ``scale_model_input`` and ``step``.
    """

    def __init__(self, vae, unet, image_encoder, scheduler):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.vae = vae.to(self.device)
        self.unet = unet.to(self.device)
        self.image_encoder = image_encoder.to(self.device)
        self.scheduler = scheduler

        self.vae.eval()
        self.unet.eval()
        self.image_encoder.eval()

    @torch.no_grad()
    def __call__(
        self,
        face_image,
        height=256,
        width=256,
        num_inference_steps=50,
        guidance_scale=7.5,
        generator=None
    ):
        """Generate one caricature per face in ``face_image``.

        Args:
            face_image: Batch of face tensors accepted by the image encoder.
            height, width: Output size in pixels; latents are 1/8 of this.
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance weight. With exactly 1.0
                the unconditional pass is skipped, since the guided result
                then equals the conditional prediction.
            generator: Optional ``torch.Generator`` used for the initial noise
                and passed to every scheduler step.

        Returns:
            ``{"images": [PIL.Image, ...]}`` with one image per input face.
        """
        # Get face embedding (input is moved to the pipeline device so callers
        # do not have to match it by hand).
        face_image = face_image.to(self.device)
        cond_embedding = self.image_encoder(face_image)
        batch_size = cond_embedding.shape[0]

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        # Generate initial noise on the generator's own device (a CPU generator
        # cannot fill a CUDA tensor directly), then move it over.
        noise_device = generator.device if generator is not None else self.device
        latents = torch.randn(
            (batch_size, 4, height // 8, width // 8),
            generator=generator,
            device=noise_device
        ).to(self.device)
        # Scale to the scheduler's expected starting noise level, if it
        # declares one; falls back to a no-op factor of 1.0.
        latents = latents * getattr(self.scheduler, "init_noise_sigma", 1.0)

        # Classifier-free guidance setup: the unconditional branch is an
        # all-zero embedding. The stacked embeddings do not change between
        # steps, so build them once.
        use_guidance = guidance_scale != 1.0
        if use_guidance:
            uncond_embedding = torch.zeros_like(cond_embedding)
            hidden_states = torch.cat([uncond_embedding, cond_embedding])
        else:
            hidden_states = cond_embedding

        # Denoising loop
        for t in tqdm(timesteps):
            # Expand latents for classifier-free guidance
            latent_model_input = torch.cat([latents] * 2) if use_guidance else latents
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            # Predict noise residual
            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=hidden_states
            ).sample

            # Perform guidance
            if use_guidance:
                noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            # Compute previous noisy sample; forwarding the generator keeps the
            # noise drawn inside the scheduler step reproducible as well.
            latents = self.scheduler.step(
                noise_pred, t, latents, generator=generator
            ).prev_sample

        # Scale and decode latents
        latents = 1 / 0.18215 * latents
        image = self.vae.decode(latents).sample

        # Convert to PIL: map [-1, 1] to [0, 1], go channels-last, quantise.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.permute(0, 2, 3, 1).cpu().numpy()
        image = (image * 255).round().astype("uint8")
        image = [Image.fromarray(img) for img in image]

        return {"images": image}

def generate_samples(vae, unet, image_encoder, noise_scheduler, val_dataset, epoch, config):
    """Render face | generated | ground-truth comparison strips for a few items.

    The first, middle and last validation items are sampled (duplicates are
    dropped when the set is tiny) and each comparison is written to
    ``<config.output_dir>/samples/sample-<epoch>-<idx>.png``. Nothing is
    written when ``val_dataset`` is empty.

    Args:
        vae, unet, image_encoder, noise_scheduler: Components handed to
            ``CaricaturePipeline``.
        val_dataset: Indexable dataset whose items are dicts with ``"face"``
            and ``"caric"`` tensors.
        epoch: Epoch number used in the output file names.
        config: Object exposing ``output_dir``.
    """
    num_items = len(val_dataset)
    if num_items == 0:
        logging.warning("generate_samples: validation set is empty, nothing to render")
        return

    pipeline = CaricaturePipeline(
        vae=vae,
        unet=unet,
        image_encoder=image_encoder,
        scheduler=noise_scheduler
    )
    # Use the device the pipeline actually placed its models on.
    device = pipeline.device

    # First, middle, last, as explicit non-negative and de-duplicated indices,
    # so file names never contain "-1" and no item is rendered twice.
    eval_indices = sorted({0, num_items // 2, num_items - 1})
    samples_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(samples_dir, exist_ok=True)

    to_pil = transforms.ToPILImage()
    for idx in eval_indices:
        sample = val_dataset[idx]
        face_image = sample["face"].unsqueeze(0).to(device)

        # The pipeline call is already wrapped in torch.no_grad().
        generated_image = pipeline(
            face_image,
            num_inference_steps=50,
            guidance_scale=7.5
        )['images'][0]

        # Three 256-px tiles side by side: input face, model output, target.
        comparison = Image.new('RGB', (768, 256))
        comparison.paste(to_pil(sample["face"]), (0, 0))
        comparison.paste(generated_image, (256, 0))
        comparison.paste(to_pil(sample["caric"]), (512, 0))

        comparison.save(os.path.join(samples_dir, f"sample-{epoch}-{idx}.png"))



In [None]:

# Show INFO-level (and more severe) log records in the cell output.
logging.basicConfig(level=logging.INFO)

# Ensure the run's output folder exists before training starts writing to it.
os.makedirs(name=TrainingConfig().output_dir, exist_ok=True)

# Kick off the full training run.
main()

# Inference

In [None]:
# @title Generate Caricatures from Test Images
# !pip install -q peft

def load_and_generate_caricature(
    face_image_path,
    checkpoint_path,
    model_path="runwayml/stable-diffusion-v1-5"
):
    """Load a saved checkpoint and generate a caricature for one face image.

    Args:
        face_image_path: Path of the face image to caricature.
        checkpoint_path: Directory holding the saved LoRA adapter and
            ``image_encoder.pt``.
        model_path: Base Stable Diffusion model the adapter is applied to.

    Returns:
        The generated caricature as a PIL image.
    """
    # Load models
    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
    # Load the base UNet model first, then attach the LoRA weights to it.
    base_unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")
    unet = PeftModel.from_pretrained(base_unet, checkpoint_path)
    image_encoder = EnhancedImageEncoder(out_dim=768)
    # map_location lets a checkpoint saved from GPU tensors load on a CPU-only
    # machine; the pipeline moves the model to its own device afterwards.
    encoder_state = torch.load(
        os.path.join(checkpoint_path, "image_encoder.pt"),
        map_location="cpu"
    )
    image_encoder.load_state_dict(encoder_state)
    scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

    # Create pipeline
    pipeline = CaricaturePipeline(
        vae=vae,
        unet=unet,
        image_encoder=image_encoder,
        scheduler=scheduler
    )

    # Load and preprocess image - same resize / tensor conversion as training
    transform = transforms.Compose([
        transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
    ])

    with Image.open(face_image_path) as img:
        face_image = img.convert("L")  # Convert to grayscale
    face_tensor = transform(face_image)
    # Repeat the channel 3 times
    face_tensor = face_tensor.repeat(3, 1, 1)
    # Follow the device the pipeline selected rather than assuming CUDA.
    face_tensor = face_tensor.unsqueeze(0).to(pipeline.device)

    # Generate caricature
    output = pipeline(
        face_tensor,
        num_inference_steps=50,
        guidance_scale=7.5
    )

    return output["images"][0]

# Example usage: pick the face to caricature and the checkpoint to load.
test_image_path = "/content/drive/MyDrive/caricature Project Diffusion/test_06.png"
checkpoint_path = "caricature-lora-model/checkpoint-90"  # Adjust epoch number as needed

generated_caricature = load_and_generate_caricature(
    face_image_path=test_image_path,
    checkpoint_path=checkpoint_path,
)

# Keep a copy on disk and show the result inline in the notebook.
generated_caricature.save("generated_caricature.png")
display(generated_caricature)