# **Training**

In [None]:
!pip install -q -U bitsandbytes diffusers accelerate transformers pycocotools
!pip install -U peft

In [None]:
%%writefile train_script.py
import os
import random
import torch
from pathlib import Path
from PIL import Image, ImageDraw
from tqdm.auto import tqdm
from dataclasses import dataclass
from typing import Dict, List, Tuple, Optional

from pycocotools.coco import COCO
from torchvision import transforms
from transformers import AutoTokenizer, CLIPTextModel
from diffusers import ControlNetModel, AutoencoderKL, UNet2DConditionModel, DDPMScheduler
from accelerate import Accelerator
import bitsandbytes as bnb 

@dataclass
class TrainingConfig:
    """Paths and hyper-parameters for the ControlNet layout training run.

    Every field has a default, so ``TrainingConfig()`` yields the complete
    configuration that the dataset and ``main`` read from.
    """

    # --- Dataset: image dir and annotation file are joined onto coco_root ---
    coco_root: str = "/kaggle/input/coco-2017-dataset/coco2017"
    train_img_dir: str = "train2017"
    train_ann_file: str = "annotations/instances_train2017.json"

    # --- Where checkpoints go and which base model is loaded ---
    output_dir: str = "/kaggle/working/controlnet-layout-model"
    model_id: str = "runwayml/stable-diffusion-v1-5"

    # --- Optimisation settings ---
    resolution: int = 512  # side length used for resize + centre crop
    batch_size: int = 2
    grad_accumulation: int = 4  # handed to Accelerator(gradient_accumulation_steps=...)
    num_epochs: int = 10
    learning_rate: float = 1e-5

    # Cap on the number of images used; a falsy value (None or 0) disables the cap.
    max_samples: Optional[int] = 5

def render_layout_mask(img_size, annotations):
    """Draw every annotation's bounding box as a white outline on a black image.

    Args:
        img_size: Size passed straight to ``Image.new`` for the RGB canvas.
        annotations: Iterable of dicts whose ``'bbox'`` entry unpacks to
            ``(x, y, w, h)``.

    Returns:
        The RGB canvas with one 2-pixel-wide rectangle outline per annotation.
    """
    white = (255, 255, 255)
    layout = Image.new("RGB", img_size, (0, 0, 0))
    pen = ImageDraw.Draw(layout)
    for annotation in annotations:
        left, top, box_w, box_h = annotation['bbox']
        # Convert (x, y, w, h) into the two opposite corners
        corners = [left, top, left + box_w, top + box_h]
        pen.rectangle(corners, outline=white, width=2)
    return layout

class LayoutConditionedDataset(torch.utils.data.Dataset):
    """COCO images paired with bounding-box layout masks and a fixed caption.

    Each item is a dict holding the normalised image tensor
    (``pixel_values``), the layout-mask tensor
    (``conditioning_pixel_values``) and the tokenised prompt (``input_ids``).
    """

    def __init__(self, config, tokenizer):
        """Load the COCO index and build the image / mask transform pipelines.

        Args:
            config: Object providing ``coco_root``, ``train_img_dir``,
                ``train_ann_file``, ``resolution`` and ``max_samples``.
            tokenizer: Callable tokenizer used on the prompt in ``__getitem__``.
        """
        self.config = config
        self.tokenizer = tokenizer
        self.root_dir = Path(config.coco_root)
        self.img_dir = self.root_dir / config.train_img_dir
        self.ann_file = self.root_dir / config.train_ann_file
        self.coco = COCO(self.ann_file)

        # Keep only images that carry at least one annotation.
        annotated = []
        for image_id in self.coco.imgs.keys():
            if len(self.coco.getAnnIds(imgIds=image_id)) > 0:
                annotated.append(image_id)
        # A falsy max_samples (None / 0) means "use everything".
        if config.max_samples:
            annotated = annotated[:config.max_samples]
        self.ids = annotated

        side = config.resolution
        # Images are additionally normalised with mean 0.5 / std 0.5.
        self.img_transforms = transforms.Compose([
            transforms.Resize(side),
            transforms.CenterCrop(side),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        # Masks get the same geometry but no normalisation.
        self.mask_transforms = transforms.Compose([
            transforms.Resize(side),
            transforms.CenterCrop(side),
            transforms.ToTensor(),
        ])

    def __len__(self):
        """Return the number of retained image ids."""
        return len(self.ids)

    def __getitem__(self, idx):
        """Return the image, layout-mask and prompt tensors for item ``idx``."""
        image_id = self.ids[idx]
        record = self.coco.loadImgs(image_id)[0]
        picture = Image.open(self.img_dir / record['file_name']).convert("RGB")
        boxes = self.coco.loadAnns(self.coco.getAnnIds(imgIds=image_id))
        # The mask is drawn at the original image size so that the shared
        # resize + crop keeps it aligned with the photo.
        layout = render_layout_mask(picture.size, boxes)

        caption = "A photorealistic image"
        encoded = self.tokenizer(
            caption,
            max_length=77,
            padding="max_length",
            truncation=True,
            return_tensors="pt",
        )
        token_ids = encoded.input_ids[0]

        return {
            "pixel_values": self.img_transforms(picture),
            "conditioning_pixel_values": self.mask_transforms(layout),
            "input_ids": token_ids,
        }

def main():
    """Fine-tune a ControlNet on COCO box layouts, saving a checkpoint per epoch.

    The Stable Diffusion VAE, text encoder and UNet are loaded with gradients
    disabled; only the ControlNet initialised from the UNet is optimised.
    After every epoch the main process writes the unwrapped ControlNet to
    ``{config.output_dir}/epoch_{epoch}``.
    """
    config = TrainingConfig()
    accelerator = Accelerator(mixed_precision="fp16", gradient_accumulation_steps=config.grad_accumulation)

    if accelerator.is_main_process: print("üöÄ Starting Training Script...")

    tokenizer = AutoTokenizer.from_pretrained(config.model_id, subfolder="tokenizer")
    noise_scheduler = DDPMScheduler.from_pretrained(config.model_id, subfolder="scheduler")
    vae = AutoencoderKL.from_pretrained(config.model_id, subfolder="vae")
    text_encoder = CLIPTextModel.from_pretrained(config.model_id, subfolder="text_encoder")
    unet = UNet2DConditionModel.from_pretrained(config.model_id, subfolder="unet")

    # Freeze every pretrained component; the ControlNet is the only trainable model.
    vae.enable_slicing()
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    controlnet = ControlNetModel.from_unet(unet)
    controlnet.train()
    controlnet.enable_gradient_checkpointing()

    optimizer = bnb.optim.AdamW8bit(controlnet.parameters(), lr=config.learning_rate)
    dataset = LayoutConditionedDataset(config, tokenizer)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=4, pin_memory=True)

    # Only the trainable pieces go through accelerator.prepare; the frozen
    # models are simply moved to the accelerator's device.
    controlnet, optimizer, dataloader = accelerator.prepare(controlnet, optimizer, dataloader)
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device)
    unet.to(accelerator.device)

    for epoch in range(config.num_epochs):
        # Only the main process owns a progress bar; other ranks keep None.
        pbar = tqdm(total=len(dataloader), desc=f"Epoch {epoch+1}") if accelerator.is_main_process else None

        for batch in dataloader:
            with accelerator.accumulate(controlnet):
                # Scale latents with the VAE's own configured factor instead of a
                # hard-coded constant, so a different base model stays correct.
                latents = vae.encode(batch["pixel_values"].to(torch.float32)).latent_dist.sample() * vae.config.scaling_factor
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                enc_hidden_states = text_encoder(batch["input_ids"])[0]

                # The ControlNet produces residuals that are injected into the frozen UNet.
                down, mid = controlnet(noisy_latents, timesteps, encoder_hidden_states=enc_hidden_states, controlnet_cond=batch["conditioning_pixel_values"].to(torch.float32), return_dict=False)
                pred = unet(noisy_latents, timesteps, encoder_hidden_states=enc_hidden_states, down_block_additional_residuals=down, mid_block_additional_residual=mid).sample

                # The UNet output is regressed onto the sampled noise.
                loss = torch.nn.functional.mse_loss(pred.float(), noise.float(), reduction="mean")
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

            if pbar is not None:
                pbar.update(1)
                pbar.set_postfix(loss=loss.item())

        # Close the bar so finished epochs do not leave dangling progress bars.
        if pbar is not None:
            pbar.close()

        # Let every rank finish the epoch before the checkpoint is written.
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            # Unwrap the model from the multi-GPU container before saving.
            save_path = f"{config.output_dir}/epoch_{epoch}"
            accelerator.unwrap_model(controlnet).save_pretrained(save_path)
            print(f"Saved epoch {epoch} to {save_path}")

if __name__ == "__main__":
    main()

In [None]:
!accelerate launch --multi_gpu --num_processes 2 --mixed_precision fp16 train_script.py

# **METRICS**

In [None]:
import torch
import os
import random
import numpy as np
import pandas as pd
from PIL import Image, ImageDraw
from tqdm.auto import tqdm
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from torchmetrics.image import StructuralSimilarityIndexMeasure
from pycocotools.coco import COCO
from torchvision import transforms

# --- CONFIGURATION ---
# The list of epochs you want to compare
epochs_to_test = ["epoch_9"]

# We define WHERE to look for each model.
# The code will check these paths in order for each epoch name.
search_paths = [
    "/kaggle/working/controlnet-layout-model" # Check the newly trained one
]

num_samples = 2  # quick smoke-test size; raise to ~50 samples for a good score

# --- SETUP METRICS ---
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"‚öôÔ∏è Setting up evaluation on {device}...")

# LPIPS (Perceptual Loss) - Lower is Better
# normalize=True makes the metric accept inputs in [0, 1], which is what
# transforms.ToTensor() produces; the default (False) expects [-1, 1].
lpips_metric = LearnedPerceptualImagePatchSimilarity(net_type='alex', normalize=True).to(device)
# SSIM (Structural Similarity) - Higher is Better
# data_range=1.0 matches the same [0, 1] inputs.
ssim_metric = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)

# --- DATASET HELPER ---
coco_root = "/kaggle/input/coco-2017-dataset/coco2017"
img_dir = os.path.join(coco_root, "train2017")
ann_file = os.path.join(coco_root, "annotations/instances_train2017.json")
coco = COCO(ann_file)

def render_layout_mask(img_size, annotations):
    """Draw every annotation's bounding box as a white outline on a black image.

    Args:
        img_size: Size passed straight to ``Image.new`` for the RGB canvas.
        annotations: Iterable of dicts whose ``'bbox'`` entry unpacks to
            ``(x, y, w, h)``.

    Returns:
        The RGB canvas with one 2-pixel-wide rectangle outline per annotation.
    """
    white = (255, 255, 255)
    layout = Image.new("RGB", img_size, (0, 0, 0))
    pen = ImageDraw.Draw(layout)
    for annotation in annotations:
        left, top, box_w, box_h = annotation['bbox']
        # Convert (x, y, w, h) into the two opposite corners
        corners = [left, top, left + box_w, top + box_h]
        pen.rectangle(corners, outline=white, width=2)
    return layout

def get_eval_batch(num_samples):
    """Build a reproducible evaluation set of (real image, layout mask, prompt).

    Reseeds the global ``random`` module with 42 before sampling, so every
    call picks the same image ids and every model is scored on the same
    inputs.

    Args:
        num_samples: How many annotated COCO images to draw.

    Returns:
        A list of dicts with keys ``"real"`` (512x512 RGB image), ``"mask"``
        (512x512 layout mask) and ``"prompt"`` (text listing the categories).

    Raises:
        ValueError: From ``random.sample`` when ``num_samples`` exceeds the
            number of annotated images.
    """
    # Get random IDs that have annotations
    ids = [i for i in coco.imgs if coco.getAnnIds(imgIds=i)]
    random.seed(42) # Fixed seed ensures we test on the SAME images for every model
    test_ids = random.sample(ids, num_samples)

    batch_data = []
    for img_id in test_ids:
        # Load Real Image (Ground Truth); the context manager releases the file handle
        img_info = coco.loadImgs(img_id)[0]
        img_path = os.path.join(img_dir, img_info['file_name'])
        with Image.open(img_path) as source_img:
            original_img = source_img.convert("RGB").resize((512, 512))

        # Load Annotations & Make Mask (boxes are rescaled to the 512x512 frame)
        anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
        scale_x = 512 / img_info['width']
        scale_y = 512 / img_info['height']
        scaled_anns = [
            {'bbox': [x * scale_x, y * scale_y, w * scale_x, h * scale_y]}
            for x, y, w, h in (a['bbox'] for a in anns)
        ]
        mask = render_layout_mask((512, 512), scaled_anns)

        # Make Prompt. Names are sorted because plain set iteration order
        # changes between interpreter runs (string hash randomisation), which
        # would alter the prompt and break run-to-run comparability.
        cats = coco.loadCats([a['category_id'] for a in anns])
        names = sorted({c['name'] for c in cats})
        prompt = f"A photorealistic image comprising {', '.join(names)}"

        batch_data.append({"real": original_img, "mask": mask, "prompt": prompt})
    return batch_data

def find_model_path(epoch_name, search_paths):
    """Locate the checkpoint folder ``epoch_name`` inside one of ``search_paths``.

    Two layouts are recognised under each base path, searched in the order
    the base paths are given:

    * ``<base>/<epoch_name>/<epoch_name>`` - the double folder that uploading
      a zip can create
    * ``<base>/<epoch_name>`` - the plain layout

    The nested layout is probed first: whenever it exists the outer folder
    exists as well, so testing the outer folder first would always return the
    wrapper directory and the nested check could never be reached.

    Args:
        epoch_name: Folder name of the checkpoint, e.g. ``"epoch_0"``.
        search_paths: Iterable of base directories to look in.

    Returns:
        The first matching path, or ``None`` when nothing is found.
    """
    for base_path in search_paths:
        direct = os.path.join(base_path, epoch_name)
        nested = os.path.join(direct, epoch_name)

        # e.g. /kaggle/input/my-dataset/epoch_0/epoch_0
        if os.path.exists(nested):
            return nested

        # e.g. /kaggle/input/my-dataset/epoch_0
        if os.path.exists(direct):
            return direct

    return None

# --- EVALUATION LOOP ---
# For each requested checkpoint: locate it, build a ControlNet pipeline,
# generate one image per test sample and average LPIPS / SSIM against the
# real image. A failure on one checkpoint is reported and the loop moves on.
results = []
test_data = get_eval_batch(num_samples)
print(f"üì¶ Prepared {len(test_data)} samples for evaluation.")

to_tensor = transforms.ToTensor()

for epoch_name in epochs_to_test:
    # 1. FIND THE MODEL
    model_path = find_model_path(epoch_name, search_paths)

    if model_path is None:
        print(f"‚ö†Ô∏è SKIPPING {epoch_name}: Could not find folder in any search path.")
        continue

    print(f"\nüöÄ Evaluating {epoch_name} from: {model_path}")

    # Bound up-front so the cleanup in `finally` works even if loading fails.
    controlnet = pipe = None
    try:
        # 2. LOAD MODEL
        controlnet = ControlNetModel.from_pretrained(model_path, torch_dtype=torch.float16)
        # The pipeline is NOT moved to CUDA by hand: enable_model_cpu_offload()
        # manages device placement itself, and a manual move first would load
        # every sub-model onto the GPU at once, defeating the offload.
        pipe = StableDiffusionControlNetPipeline.from_pretrained(
            "runwayml/stable-diffusion-v1-5", controlnet=controlnet, torch_dtype=torch.float16, safety_checker=None
        )
        pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
        pipe.set_progress_bar_config(disable=True) # Silence individual progress bars
        pipe.enable_model_cpu_offload() # Save VRAM

        lpips_scores = []
        ssim_scores = []

        # 3. RUN INFERENCE
        for item in tqdm(test_data, desc=f"Testing {epoch_name}"):
            # Generate (Fixed seed 42 for consistency across models)
            gen = torch.manual_seed(42)

            gen_img = pipe(
                item["prompt"],
                image=item["mask"],
                num_inference_steps=20,
                guidance_scale=7.5,
                generator=gen
            ).images[0]

            # Convert to Tensor for Metrics (ToTensor yields values in 0-1)
            real_t = to_tensor(item["real"]).unsqueeze(0).to(device)
            gen_t = to_tensor(gen_img).unsqueeze(0).to(device)

            # Compute Metrics
            with torch.no_grad():
                # Both metrics receive [0, 1] tensors.
                # NOTE(review): LPIPS only interprets [0, 1] correctly when it was
                # constructed with normalize=True - confirm at the metric setup.
                l_score = lpips_metric(gen_t, real_t)
                s_score = ssim_metric(gen_t, real_t)

            lpips_scores.append(l_score.item())
            ssim_scores.append(s_score.item())

        avg_lpips = np.mean(lpips_scores)
        avg_ssim = np.mean(ssim_scores)

        print(f"   üìä Results: LPIPS: {avg_lpips:.4f} | SSIM: {avg_ssim:.4f}")
        results.append({
            "Epoch": epoch_name,
            "LPIPS (Lower is Better)": avg_lpips,
            "SSIM (Higher is Better)": avg_ssim
        })

    except Exception as e:
        # Best-effort: report the failure and continue with the next checkpoint.
        print(f"   ‚ùå Error evaluating {epoch_name}: {e}")
    finally:
        # Cleanup memory for next model - also after a failure, so a broken
        # checkpoint does not leave its weights resident on the GPU.
        del pipe, controlnet
        torch.cuda.empty_cache()

# --- FINAL TABLE ---
print("\n" + "="*60)
print("FINAL QUANTITATIVE RESULTS")
print("="*60)
df = pd.DataFrame(results)
print(df.to_string(index=False))

# **VISUALISATION OF IMAGES**

In [None]:
import torch
import os
import random
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler
from pycocotools.coco import COCO

# --- 1. CONFIGURATION ---
# Paths of the two checkpoints compared below; both are expected as
# sub-folders "epoch_0" and "epoch_9" of output_dir.
NUM_IMAGES = 4  # Number of comparison rows
output_dir = "/kaggle/working/controlnet-layout-model"
epoch_0_path = os.path.join(output_dir, "epoch_0")
epoch_9_path = os.path.join(output_dir, "epoch_9")

# Echo the resolved paths so a wrong location is visible before loading starts.
print(f"üìâ Epoch 0 Path: {epoch_0_path}")
print(f"üìà Epoch 9 Path: {epoch_9_path}")

# --- 2. DATASET & FILTERING ---
# COCO train2017 images plus their instance annotations; constructing COCO
# loads the annotation file given by ann_file.
coco_root = "/kaggle/input/coco-2017-dataset/coco2017"
img_dir = os.path.join(coco_root, "train2017")
ann_file = os.path.join(coco_root, "annotations/instances_train2017.json")
coco = COCO(ann_file)

def find_simple_images(coco, limit=5):
    """Pick at most one "simple" image per target category, in random order.

    An image counts as simple when it has one or two annotations and each of
    them has an ``'area'`` above 40000. Categories and their image lists are
    shuffled with the global ``random`` module, so results vary between calls.

    Args:
        coco: COCO-style index offering ``getCatIds``, ``getImgIds``,
            ``getAnnIds`` and ``loadAnns``.
        limit: Stop visiting further categories once this many ids are collected.

    Returns:
        List of selected image ids (no duplicates).
    """
    wanted = ['dog', 'cat', 'horse', 'bird', 'airplane', 'bus']
    category_ids = coco.getCatIds(catNms=wanted)
    picked = []

    def is_simple(image_id):
        # Strict filter: 1-2 objects, large area
        objects = coco.loadAnns(coco.getAnnIds(imgIds=image_id))
        if not 1 <= len(objects) <= 2:
            return False
        return all(obj['area'] > 40000 for obj in objects)

    random.shuffle(category_ids)
    for category_id in category_ids:
        candidates = coco.getImgIds(catIds=[category_id])
        random.shuffle(candidates)
        # First not-yet-picked candidate passing the filter, if there is one
        match = next(
            (c for c in candidates if c not in picked and is_simple(c)),
            None,
        )
        if match is not None:
            picked.append(match)
        if len(picked) >= limit:
            break
    return picked

def render_layout_mask(img_size, annotations):
    """Draw every annotation's bounding box as a white outline on a black image.

    Args:
        img_size: Size passed straight to ``Image.new`` for the RGB canvas.
        annotations: Iterable of dicts whose ``'bbox'`` entry unpacks to
            ``(x, y, w, h)``.

    Returns:
        The RGB canvas with one 2-pixel-wide rectangle outline per annotation.
    """
    white = (255, 255, 255)
    layout = Image.new("RGB", img_size, (0, 0, 0))
    pen = ImageDraw.Draw(layout)
    for annotation in annotations:
        left, top, box_w, box_h = annotation['bbox']
        # Convert (x, y, w, h) into the two opposite corners
        corners = [left, top, left + box_w, top + box_h]
        pen.rectangle(corners, outline=white, width=2)
    return layout

# --- 3. PREPARE DATA (Images & Prompts) ---
# One record per curated image: the real photo, its layout mask, the text
# prompt, and two empty slots that the generation step fills in later.
curated_ids = find_simple_images(coco, limit=NUM_IMAGES)
data_batch = []

print(f"üì¶ Preparing {len(curated_ids)} test cases...")

for img_id in curated_ids:
    # Load Real (the context manager releases the file handle once converted)
    img_info = coco.loadImgs(img_id)[0]
    img_path = os.path.join(img_dir, img_info['file_name'])
    with Image.open(img_path) as source_img:
        original_img = source_img.convert("RGB").resize((512, 512))

    # Load Anns and rescale the boxes into the 512x512 frame
    anns = coco.loadAnns(coco.getAnnIds(imgIds=img_id))
    scale_x, scale_y = 512 / img_info['width'], 512 / img_info['height']
    scaled_anns = [
        {'bbox': [x * scale_x, y * scale_y, w * scale_x, h * scale_y]}
        for x, y, w, h in (a['bbox'] for a in anns)
    ]

    # Create Mask
    control_mask = render_layout_mask((512, 512), scaled_anns)

    # Create Prompt. Names are sorted because plain set iteration order
    # changes between interpreter runs, which would alter the prompt.
    cats = coco.loadCats([a['category_id'] for a in anns])
    names = sorted({c['name'] for c in cats})
    prompt = f"A photorealistic image comprising {', '.join(names)}"

    data_batch.append({
        "id": img_id,
        "real": original_img,
        "mask": control_mask,
        "prompt": prompt,
        "epoch_0_img": None,
        "epoch_9_img": None
    })

# --- 4. GENERATION LOOP (Model Swapping) ---
# The two checkpoints are loaded one after the other so that only a single
# pipeline is held in memory at a time. Neither pipeline is moved to CUDA by
# hand: enable_model_cpu_offload() manages device placement itself, and a
# manual move first would load every sub-model onto the GPU at once.

# A. Generate with EPOCH 0
print("\nü§ñ Loading EPOCH 0 Model...")
cnet_0 = ControlNetModel.from_pretrained(epoch_0_path, torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=cnet_0, torch_dtype=torch.float16, safety_checker=None
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

print("   Generating Epoch 0 outputs...")
for i, item in enumerate(data_batch):
    # Per-item seed; the epoch 9 pass below reuses exactly the same seeds.
    gen = torch.manual_seed(100 + i)
    item["epoch_0_img"] = pipe(item["prompt"], image=item["mask"], num_inference_steps=20, generator=gen).images[0]

# Cleanup Memory before the second checkpoint is loaded
del cnet_0, pipe
torch.cuda.empty_cache()

# B. Generate with EPOCH 9
print("\nüöÄ Loading EPOCH 9 Model...")
cnet_9 = ControlNetModel.from_pretrained(epoch_9_path, torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", controlnet=cnet_9, torch_dtype=torch.float16, safety_checker=None
)
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.enable_model_cpu_offload()

print("   Generating Epoch 9 outputs...")
for i, item in enumerate(data_batch):
    gen = torch.manual_seed(100 + i) # SAME SEED ensures fair comparison
    item["epoch_9_img"] = pipe(item["prompt"], image=item["mask"], num_inference_steps=20, generator=gen).images[0]

# --- 5. FINAL VISUALIZATION ---
# One figure per test case with four panels side by side:
# layout mask, real photo, epoch 0 output, epoch 9 output.
print("\nüé® Displaying Results...")
for item in data_batch:
    fig, axs = plt.subplots(1, 4, figsize=(20, 5))

    # (image, title) per column, left to right. The third column must show the
    # epoch 0 output; showing "epoch_9_img" there would make both model
    # columns identical and hide any difference between the checkpoints.
    panels = [
        (item["mask"], "Input Layout (Control)"),
        (item["real"], "Ground Truth (Real)"),
        (item["epoch_0_img"], "Epoch 0 (Untrained)"),
        (item["epoch_9_img"], "Epoch 9 (Final)"),
    ]
    for ax, (image, title) in zip(axs, panels):
        ax.imshow(image)
        ax.set_title(title)
        ax.axis("off")

    plt.show()