[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Guidosalimbeni/aicaricaturist/blob/main/lora_img2img/lora_img2img_v6.ipynb)



In [None]:
# import shutil
# shutil.rmtree("caricature-lora-model")  # Deletes the folder and all its contents

In [None]:
# Enhanced LoRA-based Image-to-Image Diffusion Training

!pip install torch torchvision
!pip install diffusers
!pip install accelerate
!pip install Pillow
!pip install tqdm

!pip install -q peft

from peft import LoraConfig, get_peft_model
from peft import PeftModel, LoraConfig, get_peft_model

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import Dataset, DataLoader, random_split
from torchvision import transforms
from PIL import Image
import numpy as np
from tqdm.auto import tqdm
import random
import logging
from torchvision import models  # This was missing from the imports
from transformers import get_cosine_schedule_with_warmup

import accelerate
from diffusers import UNet2DConditionModel, DDPMScheduler, AutoencoderKL
#from diffusers.optimization import get_scheduler # This import is not used in the code, consider removing
from accelerate import Accelerator
from torch.optim.lr_scheduler import CosineAnnealingLR # Import the missing class
from matplotlib import pyplot as plt
import math

In [None]:
# Upper bound on the numeric pair index used when the dataset below scans
# DATA_DIR for `NNN_f.png` / `NNN_c.png` files.
num_of_caricature = 59


# For reproducibility: seed the torch, Python and NumPy global RNGs.
torch.manual_seed(42)
random.seed(42)
np.random.seed(42)

# Mount Google Drive (this import only resolves inside a Google Colab runtime)
from google.colab import drive
drive.mount('/content/drive')

# Folder on the mounted drive that holds the paired face/caricature images.
DATA_DIR = '/content/drive/MyDrive/caricature Project Diffusion/paired_caricature'

class AugmentedCaricatureDataset(Dataset):
    """Paired face/caricature images stored as ``NNN_f.png`` / ``NNN_c.png``.

    Each item is a dict with keys ``"face"`` and ``"caric"``. Both images are
    loaded as grayscale, resized to 256x256, min-max scaled to ``[-1, 1]`` and
    repeated to three channels.

    Args:
        data_dir: Directory that contains the numbered image pairs.
        split: ``'train'`` enables a random horizontal flip (p=0.5) applied
            identically to both images of a pair; any other value disables
            augmentation.
    """

    def __init__(self, data_dir, split='train'):
        super().__init__()
        self.data_dir = data_dir
        self.split = split

        # The flip is applied in __getitem__ (not inside the transform) so one
        # random draw can be shared by the face and its caricature.
        self.flip_prob = 0.5 if split == 'train' else 0.0

        # Simple transforms without normalization
        self.transform = transforms.Compose([
            transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.ToTensor(),
        ])

        # Collect every pair for which both files exist. The upper bound is
        # inclusive so the pair numbered `num_of_caricature` is not skipped;
        # indices with a missing file are ignored by the existence check.
        self.pairs = []
        for i in range(1, num_of_caricature + 1):
            face_path = os.path.join(data_dir, f"{i:03d}_f.png")
            caric_path = os.path.join(data_dir, f"{i:03d}_c.png")
            if os.path.exists(face_path) and os.path.exists(caric_path):
                self.pairs.append((face_path, caric_path))

    def __len__(self):
        return len(self.pairs)

    def normalize_tensor(self, tensor):
        """Min-max scale ``tensor`` to ``[-1, 1]``.

        A constant tensor has no range to stretch; it is mapped to all zeros
        instead of dividing by zero and producing NaNs.
        """
        min_val = tensor.min()
        max_val = tensor.max()
        value_range = max_val - min_val
        if value_range == 0:
            return torch.zeros_like(tensor)
        return 2 * (tensor - min_val) / value_range - 1

    def __getitem__(self, idx):
        face_path, caric_path = self.pairs[idx]

        # Load as grayscale; the context managers release the file handles,
        # convert() returns an independent in-memory image.
        with Image.open(face_path) as face_file:
            face_img = face_file.convert("L")
        with Image.open(caric_path) as caric_file:
            caric_img = caric_file.convert("L")

        face_tensor = self.transform(face_img)
        caric_tensor = self.transform(caric_img)

        # One draw decides the flip for both images, keeping the pair aligned
        # without re-seeding (and thereby clobbering) the global torch RNG.
        if self.flip_prob > 0 and torch.rand(1).item() < self.flip_prob:
            face_tensor = face_tensor.flip(-1)
            caric_tensor = caric_tensor.flip(-1)

        # Apply MinMax normalization
        face_tensor = self.normalize_tensor(face_tensor)
        caric_tensor = self.normalize_tensor(caric_tensor)

        # Repeat single channel to 3 channels
        face_tensor = face_tensor.repeat(3, 1, 1)
        caric_tensor = caric_tensor.repeat(3, 1, 1)

        return {
            "face": face_tensor,
            "caric": caric_tensor
        }
    

class EnhancedImageEncoder(nn.Module):
    """Encode an image into a single conditioning token.

    An ImageNet-pretrained ResNet-34 trunk produces a 512-channel feature
    map, a small 1x1-conv head gates every spatial position with a sigmoid
    weight, the gated map is average-pooled and finally projected to
    ``out_dim`` with a linear layer followed by LayerNorm.

    Args:
        out_dim: Size of the produced embedding.
    """

    def __init__(self, out_dim=768):
        super().__init__()
        backbone = models.resnet34(weights=models.ResNet34_Weights.IMAGENET1K_V1)

        # Drop the last two children of the ResNet (its average pooling and
        # fc layers) so the trunk outputs a spatial feature map.
        trunk_layers = list(backbone.children())[:-2]
        self.feature_extractor = nn.Sequential(*trunk_layers)

        # Spatial gate: 512 -> 128 -> 1 channel, squashed to (0, 1).
        gate_layers = [
            nn.Conv2d(512, 128, 1),
            nn.ReLU(),
            nn.Conv2d(128, 1, 1),
            nn.Sigmoid(),
        ]
        self.attention = nn.Sequential(*gate_layers)

        # Projection from the pooled 512-d vector to the embedding size.
        self.proj = nn.Sequential(
            nn.Linear(512, out_dim),
            nn.LayerNorm(out_dim),
        )

    def forward(self, x):
        """Return the embedding of ``x`` with a length-1 sequence axis at dim 1."""
        feature_map = self.feature_extractor(x)
        gated_map = feature_map * self.attention(feature_map)
        pooled = F.adaptive_avg_pool2d(gated_map, (1, 1))
        flattened = pooled.view(pooled.size(0), -1)
        return self.proj(flattened).unsqueeze(1)

def apply_enhanced_lora(unet, r=32, lora_alpha=4.0, lora_dropout=0.2):
    """Wrap ``unet`` with LoRA adapters and report its trainable parameters.

    Adapters are attached to the modules named ``to_k``, ``to_q``, ``to_v``,
    ``to_out.0``, ``proj_in`` and ``proj_out``; no bias terms are trained.

    Args:
        unet: Model to wrap.
        r: LoRA rank.
        lora_alpha: LoRA scaling numerator.
        lora_dropout: Dropout probability inside the LoRA layers.

    Returns:
        The PEFT-wrapped model returned by ``get_peft_model``.
    """
    adapted_module_names = ["to_k", "to_q", "to_v", "to_out.0", "proj_in", "proj_out"]
    lora_settings = LoraConfig(
        r=r,
        lora_alpha=lora_alpha,
        target_modules=adapted_module_names,
        lora_dropout=lora_dropout,
        bias="none",
    )
    wrapped_unet = get_peft_model(unet, lora_settings)
    wrapped_unet.print_trainable_parameters()
    return wrapped_unet

class EMA:
    """Exponential moving average of model parameters.

    Args:
        beta: Weight kept from the running average at every update; the
            incoming value contributes ``1 - beta``.
    """

    def __init__(self, beta=0.9999):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        """Blend ``current_model``'s parameters into ``ma_model``.

        Parameters are matched purely by their order in ``.parameters()``.
        """
        paired_params = zip(current_model.parameters(), ma_model.parameters())
        for live_param, averaged_param in paired_params:
            blended = self.update_average(averaged_param.data, live_param.data)
            averaged_param.data = blended

    def update_average(self, old, new):
        """Return ``old * beta + (1 - beta) * new``, or ``new`` if ``old`` is None."""
        if old is not None:
            return old * self.beta + (1 - self.beta) * new
        return new

class TrainingConfig:
    def __init__(self):
        # Increased epochs and adjusted early stopping
        self.num_epochs = 1000
        self.patience = 50
        
        # Batch and optimization
        self.train_batch_size = 2
        self.eval_batch_size = 2
        self.gradient_accumulation_steps = 4
        
        # Learning rate settings
        self.learning_rate = 1e-5
        self.min_learning_rate = 1e-7
        self.lr_warmup_steps = 500
        
        # Saving frequency
        self.save_image_epochs = 10
        self.save_model_epochs = 20
        
        # Mixed precision
        self.mixed_precision = "fp16"
        self.output_dir = "caricature-lora-model"
        
        # LoRA settings
        self.lora_r = 32
        self.lora_alpha = 4.0
        self.lora_dropout = 0.2
        
        # Optimizer settings
        self.adam_beta1 = 0.9
        self.adam_beta2 = 0.999
        self.adam_weight_decay = 1e-2
        self.adam_epsilon = 1e-08
        
        # Style loss weight
        self.style_weight = 0.2

def get_lr_scheduler(optimizer, config, num_training_steps):
    """Build a ``LambdaLR`` with linear warmup followed by cosine decay.

    Args:
        optimizer: Optimizer whose learning rate is scheduled.
        config: Object exposing ``lr_warmup_steps``, ``learning_rate`` and
            ``min_learning_rate``.
        num_training_steps: Total number of scheduler steps; the cosine
            reaches its end at this step.

    Returns:
        A ``torch.optim.lr_scheduler.LambdaLR`` whose multiplier rises
        linearly from 0 to 1 during warmup and then follows half a cosine,
        never dropping below ``min_learning_rate / learning_rate``.
    """
    def lr_lambda(current_step):
        warmup = config.lr_warmup_steps

        # Linear ramp while still warming up.
        if current_step < warmup:
            return float(current_step) / float(max(1, warmup))

        # Cosine decay over the remaining steps, clipped at the floor ratio.
        decay_span = float(max(1, num_training_steps - warmup))
        progress = float(current_step - warmup) / decay_span
        cosine_multiplier = 0.5 * (1.0 + math.cos(math.pi * progress))
        floor_multiplier = config.min_learning_rate / config.learning_rate
        return max(floor_multiplier, cosine_multiplier)

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return scheduler

def plot_losses(train_losses, val_losses, save_path, save_epochs):
    """Plot per-epoch training loss and periodic validation loss to a file.

    Args:
        train_losses: One training loss per epoch.
        val_losses: Validation losses, the i-th one recorded at epoch
            ``i * save_epochs``.
        save_path: Image file the figure is written to.
        save_epochs: Number of epochs between two validation losses.
    """
    plt.figure(figsize=(12, 6))
    try:
        plt.plot(train_losses, label='Train Loss', alpha=0.7)

        # Derive the x positions from the validation list itself so x and y
        # always have equal length, whatever len(train_losses) currently is.
        val_epochs = [i * save_epochs for i in range(len(val_losses))]
        plt.plot(val_epochs, val_losses, label='Val Loss', alpha=0.7)

        # Moving average of the training loss once enough epochs exist.
        window_size = 10
        if len(train_losses) >= window_size:
            train_ma = np.convolve(train_losses, np.ones(window_size)/window_size, mode='valid')
            plt.plot(range(window_size-1, len(train_losses)), train_ma,
                     label=f'Train {window_size}-epoch MA', linestyle='--', alpha=0.5)

        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.title('Training and Validation Losses')
        plt.grid(True, alpha=0.3)
        # The legend is built after every line has been drawn; created
        # earlier it would not list the moving-average curve.
        plt.legend()

        plt.savefig(save_path)
    finally:
        # Always release the figure, even if plotting or saving failed.
        plt.close()

class StyleLoss(nn.Module):
    """Mean squared error between the Gram matrices of two 4-D tensors."""

    def __init__(self):
        super().__init__()

    def gram_matrix(self, x):
        """Return the per-sample channel Gram matrix of ``x``.

        ``x`` is unpacked as ``(b, c, h, w)``; the result has shape
        ``(b, c, c)`` and is divided by ``c * h * w``.
        """
        batch, channels, height, width = x.size()
        flat = x.view(batch, channels, height * width)
        unscaled = torch.bmm(flat, flat.transpose(1, 2))
        return unscaled / (channels * height * width)

    def forward(self, pred, target):
        """Return the MSE between the Gram matrices of ``pred`` and ``target``."""
        return F.mse_loss(self.gram_matrix(pred), self.gram_matrix(target))

class CaricaturePipeline:
    """Sampling pipeline: face image -> grayscale caricature.

    The face is embedded by ``image_encoder``, latents are denoised by
    ``unet`` with classifier-free guidance and decoded by ``vae``.

    Args:
        vae: Autoencoder providing ``decode``.
        unet: Conditional noise-prediction model.
        image_encoder: Module mapping a face batch to conditioning embeddings.
        scheduler: Diffusion scheduler providing ``set_timesteps``,
            ``scale_model_input`` and ``step``.
    """

    # Latent scaling used by the training loop in this file; sampling has to
    # undo the very same factor before decoding.
    LATENT_SCALE = 0.18215

    def __init__(self, vae, unet, image_encoder, scheduler):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

        self.vae = vae.to(self.device)
        self.unet = unet.to(self.device)
        self.image_encoder = image_encoder.to(self.device)
        self.scheduler = scheduler

        self.vae.eval()
        self.unet.eval()
        self.image_encoder.eval()

    @torch.no_grad()
    def __call__(
        self,
        face_image,
        height=256,
        width=256,
        num_inference_steps=75,
        guidance_scale=9.5,
        generator=None
    ):
        """Generate one caricature per face in ``face_image``.

        Args:
            face_image: Batch of preprocessed faces accepted by the encoder.
            height: Output height in pixels (latents use ``height // 8``).
            width: Output width in pixels (latents use ``width // 8``).
            num_inference_steps: Number of denoising steps.
            guidance_scale: Classifier-free guidance weight.
            generator: Optional ``torch.Generator`` for the initial noise.

        Returns:
            ``{"images": [PIL.Image, ...]}`` with one mode-``L`` image per
            input face.
        """
        # Make sure the input lives where the models were moved to.
        face_image = face_image.to(self.device)

        # Get face embedding
        cond_embedding = self.image_encoder(face_image)
        batch_size = cond_embedding.shape[0]

        # Set timesteps
        self.scheduler.set_timesteps(num_inference_steps)
        timesteps = self.scheduler.timesteps

        # Initial noise: one latent per input face, so the latent batch always
        # matches the conditioning batch (previously hard-coded to 1).
        latents = torch.randn(
            (batch_size, 4, height // 8, width // 8),
            generator=generator,
            device=self.device
        )
        # Schedulers that expect a non-unit starting sigma expose it here;
        # fall back to 1.0 when the attribute is absent.
        latents = latents * getattr(self.scheduler, "init_noise_sigma", 1.0)

        # Classifier-free guidance setup: the unconditional branch is an
        # all-zero embedding.
        # NOTE(review): guidance is only meaningful if training also exposed
        # the UNet to this zero embedding — confirm against the training loop.
        uncond_embedding = torch.zeros_like(cond_embedding)
        guidance_context = torch.cat([uncond_embedding, cond_embedding])

        # Denoising loop
        for t in tqdm(timesteps):
            # Duplicate latents: first half unconditional, second half conditional.
            latent_model_input = torch.cat([latents] * 2)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

            noise_pred = self.unet(
                latent_model_input,
                t,
                encoder_hidden_states=guidance_context
            ).sample

            noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
            noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)

            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

        # Scale and decode latents
        latents = latents / self.LATENT_SCALE
        image = self.vae.decode(latents).sample

        # Convert to grayscale by averaging channels -> (B, H, W)
        image = image.mean(dim=1)

        # Rescale from [-1, 1] to [0, 1]
        image = ((image + 1) / 2).clamp(0, 1)

        # Build mode-"L" images straight from the 2-D uint8 planes; the pixel
        # values equal those of the former RGB -> "L" round trip because all
        # three channels were identical.
        image = (image.float().cpu().numpy() * 255).round().astype("uint8")
        image = [Image.fromarray(plane) for plane in image]

        return {"images": image}


def _signed_plane_to_pil(plane):
    """Convert a 2-D tensor with values in [-1, 1] to a PIL image.

    ``ToPILImage`` expects floats in [0, 1]; passing [-1, 1] data directly
    corrupts every negative pixel, so the range is shifted and clamped first.
    """
    return transforms.ToPILImage()(((plane + 1) / 2).clamp(0, 1))


def generate_samples(vae, unet, image_encoder, noise_scheduler, val_dataset, epoch, config):
    """Write face | generated | ground-truth comparison strips for one epoch.

    Up to five validation items are rendered: three fixed positions (first,
    middle, last) plus two randomly drawn ones. Files are saved as
    ``<output_dir>/samples/sample-<epoch>-<fixed|random>-<idx>.png``.

    Args:
        vae, unet, image_encoder, noise_scheduler: Components handed to
            ``CaricaturePipeline``.
        val_dataset: Indexable dataset returning ``{"face", "caric"}`` dicts.
        epoch: Epoch number used in the file names.
        config: Object exposing ``output_dir``.
    """
    num_items = len(val_dataset)
    if num_items == 0:
        # Nothing to visualise.
        return

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    pipeline = CaricaturePipeline(
        vae=vae,
        unet=unet,
        image_encoder=image_encoder,
        scheduler=noise_scheduler
    )

    # Fixed samples for consistent tracking. The last item is addressed by its
    # positive index so it cannot appear twice under two different names.
    fixed_indices = [0, num_items // 2, num_items - 1]
    # Never request more random items than the dataset holds.
    random_indices = random.sample(range(num_items), k=min(2, num_items))
    # Order-preserving de-duplication: each item is rendered exactly once.
    eval_indices = list(dict.fromkeys(fixed_indices + random_indices))

    samples_dir = os.path.join(config.output_dir, "samples")
    os.makedirs(samples_dir, exist_ok=True)

    for idx in eval_indices:
        sample = val_dataset[idx]
        face_image = sample["face"].unsqueeze(0).to(device)

        with torch.no_grad():
            generated_image = pipeline(
                face_image,
                num_inference_steps=50,
                guidance_scale=9.5
            )['images'][0]

        # Channel 0 suffices: the dataset repeats one gray channel three times.
        face_pil = _signed_plane_to_pil(sample["face"][0])
        gt_pil = _signed_plane_to_pil(sample["caric"][0])

        index_type = "fixed" if idx in fixed_indices else "random"

        # Three 256-pixel-wide tiles side by side.
        comparison = Image.new('L', (768, 256))
        comparison.paste(face_pil, (0, 0))
        comparison.paste(generated_image, (256, 0))
        comparison.paste(gt_pil, (512, 0))

        comparison.save(
            os.path.join(samples_dir, f"sample-{epoch}-{index_type}-{idx}.png")
        )

def _denoising_forward(batch, vae, unet, image_encoder, noise_scheduler, cond_dropout_prob=0.0):
    """Run one noise-prediction pass on ``batch``; return ``(noise_pred, noise)``.

    With ``cond_dropout_prob > 0`` the conditioning embedding of each sample is
    replaced by zeros with that probability.
    """
    face = batch["face"]
    caric = batch["caric"]

    # The VAE is frozen, so its encoding never needs an autograd graph.
    with torch.no_grad():
        latents = vae.encode(caric).latent_dist.sample()
        latents = latents * 0.18215

    # Add noise at a random timestep per sample
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (latents.shape[0],),
        device=latents.device
    )
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    # Get conditioning
    cond_embedding = image_encoder(face)
    if cond_dropout_prob > 0:
        drop = torch.rand(cond_embedding.shape[0], device=cond_embedding.device) < cond_dropout_prob
        cond_embedding = cond_embedding.masked_fill(drop[:, None, None], 0.0)

    # Predict noise
    noise_pred = unet(
        noisy_latents,
        timesteps,
        encoder_hidden_states=cond_embedding
    ).sample
    return noise_pred, noise


def _save_adapter_and_encoder(accelerator, unet, image_encoder, save_path):
    """Save the unwrapped UNet adapter and image-encoder weights to ``save_path``."""
    os.makedirs(save_path, exist_ok=True)
    accelerator.unwrap_model(unet).save_pretrained(save_path)
    # Unwrap first so the state-dict keys carry no wrapper prefix and load
    # into a plain EnhancedImageEncoder.
    torch.save(
        accelerator.unwrap_model(image_encoder).state_dict(),
        os.path.join(save_path, "image_encoder.pt")
    )


def main():
    """Train the LoRA-adapted UNet and the image encoder on the paired dataset.

    Every ``save_image_epochs`` epochs the validation loss is computed, the
    best model is saved, and the loss plot is refreshed. On those epochs that
    are also multiples of ``save_model_epochs`` a checkpoint and sample images
    are written. Training stops once ``patience`` consecutive validation
    rounds fail to improve.

    Raises:
        ValueError: If fewer than two image pairs are found in ``DATA_DIR``.
    """
    config = TrainingConfig()
    # Plots and checkpoints are written below this directory.
    os.makedirs(config.output_dir, exist_ok=True)

    accelerator = Accelerator(
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        mixed_precision=config.mixed_precision,
    )

    dataset = AugmentedCaricatureDataset(DATA_DIR)
    if len(dataset) < 2:
        raise ValueError(
            f"Need at least 2 face/caricature pairs in {DATA_DIR!r}, found {len(dataset)}"
        )
    train_size = int(0.9 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_split = random_split(dataset, [train_size, val_size])
    # Validation reuses the held-out indices on a non-augmenting dataset, so
    # validation losses and the "fixed" sample images are not randomly flipped.
    val_dataset = torch.utils.data.Subset(
        AugmentedCaricatureDataset(DATA_DIR, split='val'),
        val_split.indices
    )

    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=2
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config.eval_batch_size,
        shuffle=False,
        num_workers=2
    )

    model_path = "runwayml/stable-diffusion-v1-5"

    device = accelerator.device
    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae").to(device)
    unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet").to(device)
    image_encoder = EnhancedImageEncoder(out_dim=768).to(device)

    # Apply LoRA with updated parameters
    unet = apply_enhanced_lora(
        unet,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout
    )

    # Freeze VAE
    vae.requires_grad_(False)
    vae.eval()

    # Initialize noise scheduler
    noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

    # Only parameters that actually receive gradients are optimised/clipped.
    params_to_train = [p for p in unet.parameters() if p.requires_grad]
    params_to_train += [p for p in image_encoder.parameters() if p.requires_grad]
    optimizer = torch.optim.AdamW(
        params_to_train,
        lr=config.learning_rate,
        betas=(config.adam_beta1, config.adam_beta2),
        weight_decay=config.adam_weight_decay,
        eps=config.adam_epsilon
    )

    # The scheduler is stepped once per batch, hence epochs * batches steps.
    num_training_steps = config.num_epochs * len(train_dataloader)
    lr_scheduler = get_lr_scheduler(optimizer, config, num_training_steps)

    style_criterion = StyleLoss().to(device)

    # Prepare models for training
    unet, image_encoder, optimizer, train_dataloader, val_dataloader = accelerator.prepare(
        unet, image_encoder, optimizer, train_dataloader, val_dataloader
    )

    # Initialize EMA. The averaged UNet gets the same LoRA shape as the
    # trained one and starts as an exact copy of its weights; an independently
    # initialised copy would blend unrelated random LoRA weights into the
    # average.
    ema = EMA(beta=0.9999)
    ema_unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")
    ema_unet = apply_enhanced_lora(
        ema_unet,
        r=config.lora_r,
        lora_alpha=config.lora_alpha,
        lora_dropout=config.lora_dropout
    )
    ema_unet.load_state_dict(accelerator.unwrap_model(unet).state_dict())
    ema_unet.requires_grad_(False)
    ema_unet.to(device).eval()

    # Sampling uses classifier-free guidance with a zero embedding as the
    # unconditional input, so training must sometimes show the UNet that
    # embedding. A `cond_dropout_prob` attribute on the config overrides the
    # default.
    cond_dropout_prob = getattr(config, "cond_dropout_prob", 0.1)

    # Training tracking
    train_losses = []
    val_losses = []
    best_val_loss = float('inf')
    patience_counter = 0

    for epoch in range(config.num_epochs):
        unet.train()
        image_encoder.train()
        train_loss = 0.0

        current_lr = optimizer.param_groups[0]['lr']
        # Scientific notation: fixed-point with six decimals printed warmup
        # and floor learning rates as 0.000000.
        print(f"Epoch {epoch}: Current learning rate: {current_lr:.2e}")

        # Training loop
        progress_bar = tqdm(total=len(train_dataloader), desc=f"Epoch {epoch}")
        for batch in train_dataloader:
            # Both trained modules take part in gradient accumulation.
            with accelerator.accumulate(unet, image_encoder):
                noise_pred, noise = _denoising_forward(
                    batch, vae, unet, image_encoder, noise_scheduler, cond_dropout_prob
                )
                noise_pred = noise_pred.float()
                noise = noise.float()

                # Calculate losses
                content_loss = F.mse_loss(noise_pred, noise, reduction="mean")
                style_loss = style_criterion(noise_pred, noise)
                loss = content_loss + config.style_weight * style_loss

                # Backward pass
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(params_to_train, 1.0)
                optimizer.step()
                optimizer.zero_grad()

                # Update EMA model only when a real optimizer step happened
                if accelerator.sync_gradients:
                    ema.update_model_average(ema_unet, accelerator.unwrap_model(unet))

            batch_loss = loss.detach().item()
            train_loss += batch_loss
            lr_scheduler.step()

            progress_bar.update(1)
            progress_bar.set_postfix({"loss": batch_loss})

        progress_bar.close()

        # Calculate average train loss
        avg_train_loss = train_loss / len(train_dataloader)
        train_losses.append(avg_train_loss)

        # Validation (and everything below) only runs every save_image_epochs
        if epoch % config.save_image_epochs != 0:
            continue

        unet.eval()
        image_encoder.eval()
        val_loss = 0.0

        with torch.no_grad():
            for batch in val_dataloader:
                noise_pred, noise = _denoising_forward(
                    batch, vae, unet, image_encoder, noise_scheduler
                )
                val_loss += F.mse_loss(
                    noise_pred.float(),
                    noise.float(),
                    reduction="mean"
                ).item()

        val_loss /= len(val_dataloader)
        val_losses.append(val_loss)

        # Early stopping check; patience counts validation rounds, not epochs
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            if accelerator.is_main_process:
                _save_adapter_and_encoder(
                    accelerator, unet, image_encoder,
                    os.path.join(config.output_dir, "best_model")
                )
        else:
            patience_counter += 1

        print(f"Epoch {epoch}: Train Loss = {avg_train_loss:.4f}, Val Loss = {val_loss:.4f}")

        # Plot losses
        plot_losses(
            train_losses,
            val_losses,
            os.path.join(config.output_dir, 'loss_plot.png'),
            config.save_image_epochs
        )

        # Save checkpoint and generate samples
        if epoch % config.save_model_epochs == 0:
            accelerator.wait_for_everyone()
            if accelerator.is_main_process:
                _save_adapter_and_encoder(
                    accelerator, unet, image_encoder,
                    os.path.join(config.output_dir, f"checkpoint-{epoch}")
                )

                generate_samples(
                    vae,
                    ema_unet if epoch > 0 else unet,
                    image_encoder,
                    noise_scheduler,
                    val_dataset,
                    epoch,
                    config
                )

        # Early stopping
        if patience_counter >= config.patience:
            print(f"Early stopping triggered after {epoch} epochs")
            break



In [None]:

# Set up logging: INFO and above from all loggers goes to the root handler.
logging.basicConfig(level=logging.INFO)

# Create output directory (a throw-away TrainingConfig is built only to read
# its output_dir; main() constructs its own config).
os.makedirs(TrainingConfig().output_dir, exist_ok=True)

# Start training
main()

# Inference

In [None]:
def load_and_generate_caricature(
    face_image_path,
    checkpoint_path,
    model_path="runwayml/stable-diffusion-v1-5"
):
    """Load a trained checkpoint and generate a caricature for one face image.

    Args:
        face_image_path: Path of the input face image.
        checkpoint_path: Directory holding the saved LoRA adapter and
            ``image_encoder.pt``.
        model_path: Base model the VAE, UNet and scheduler are loaded from.

    Returns:
        The first image of the pipeline's ``"images"`` output.
    """
    # Load models
    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")
    unet = PeftModel.from_pretrained(unet, checkpoint_path)
    image_encoder = EnhancedImageEncoder(out_dim=768)
    # map_location="cpu" lets weights saved from a GPU run load on any
    # machine; the pipeline moves the module to its own device afterwards.
    encoder_weights = torch.load(
        os.path.join(checkpoint_path, "image_encoder.pt"),
        map_location="cpu"
    )
    image_encoder.load_state_dict(encoder_weights)
    scheduler = DDPMScheduler.from_pretrained(model_path, subfolder="scheduler")

    # Create pipeline
    pipeline = CaricaturePipeline(
        vae=vae,
        unet=unet,
        image_encoder=image_encoder,
        scheduler=scheduler
    )

    # Load and preprocess image - matching training preprocessing
    transform = transforms.Compose([
        transforms.Resize((256, 256), interpolation=transforms.InterpolationMode.BILINEAR),
        transforms.ToTensor(),
    ])

    # The context manager closes the file; convert() returns an in-memory copy.
    with Image.open(face_image_path) as image_file:
        face_image = image_file.convert("L")
    face_tensor = transform(face_image)

    # Apply MinMax normalization to [-1, 1]; a constant image has no range to
    # stretch and is mapped to zeros instead of dividing by zero.
    min_val = face_tensor.min()
    max_val = face_tensor.max()
    value_range = max_val - min_val
    if value_range == 0:
        face_tensor = torch.zeros_like(face_tensor)
    else:
        face_tensor = 2 * (face_tensor - min_val) / value_range - 1

    # Repeat grayscale channel to 3 channels
    face_tensor = face_tensor.repeat(3, 1, 1)
    # Use the device the pipeline selected (it falls back to CPU when CUDA is
    # unavailable) rather than assuming "cuda".
    face_tensor = face_tensor.unsqueeze(0).to(pipeline.device)

    # Generate caricature
    output = pipeline(
        face_tensor,
        num_inference_steps=100,
        guidance_scale=9.5  # Increased from 7.5 to match training
    )

    # The output is already in the correct format from the pipeline
    generated_image = output["images"][0]

    return generated_image

# Example usage:
# NOTE(review): with the TrainingConfig values in this file, validation runs
# every 10 epochs and checkpoints are written only when the epoch is also a
# multiple of 20, so main() never creates a `checkpoint-90` directory — confirm
# this path exists, or point at another saved directory such as `best_model`.
checkpoint_path = "caricature-lora-model/checkpoint-90"  # Adjust epoch number as needed
test_image_path = '/content/drive/MyDrive/caricature Project Diffusion/test_06.png'
generated_caricature = load_and_generate_caricature(test_image_path, checkpoint_path)
generated_caricature.save("generated_caricature.png")
# `display` is not defined in this file; presumably the notebook-provided
# IPython helper — verify when running outside a notebook.
display(generated_caricature)