In [None]:
!pip install bitsandbytes

In [11]:
%%writefile train_bbox.py
import os
import torch
import numpy as np
import random
import glob
from PIL import Image, ImageDraw
from tqdm.auto import tqdm
from pycocotools.coco import COCO
from torchvision import transforms
import bitsandbytes as bnb
import wandb
from kaggle_secrets import UserSecretsClient

# Accelerate & Diffusers
from accelerate import Accelerator
from diffusers import (
    StableDiffusionControlNetPipeline, 
    ControlNetModel, 
    DDPMScheduler,
    AutoencoderKL,
    UNet2DConditionModel
)
from transformers import AutoTokenizer, CLIPTextModel

# --- CONFIGURATION ---
class Config:
    """Static paths and hyperparameters shared by the dataset, training loop and logging."""

    # ---- Data: COCO 2017 train split as mounted on Kaggle ----
    COCO_ROOT = "/kaggle/input/coco-2017-dataset/coco2017"
    TRAIN_IMG_DIR = os.path.join(COCO_ROOT, "train2017")
    TRAIN_ANN_FILE = os.path.join(COCO_ROOT, "annotations/instances_train2017.json")

    # ---- Model source and where checkpoints / final weights are written ----
    MODEL_ID = "runwayml/stable-diffusion-v1-5"
    OUTPUT_DIR = "/kaggle/working/controlnet-coco-bbox"

    # Resume mode: "latest" (newest checkpoint-<step> under OUTPUT_DIR),
    # an explicit checkpoint path, or None to always start fresh.
    RESUME_FROM_CHECKPOINT = "latest"

    # ---- Optimisation ----
    RESOLUTION = 512
    BATCH_SIZE = 8
    GRAD_ACCUM_STEPS = 1
    LEARNING_RATE = 1e-4  # learning rate for the ControlNet, the only trainable module
    NUM_EPOCHS = 10

    # ---- Logging / checkpointing cadence (in optimizer steps) ----
    LOG_INTERVAL = 100
    LOG_BATCH_SIZE = 8  # max samples rendered per validation
    SAVE_INTERVAL = 1000
    MAX_SAMPLES = 20_000  # cap on the number of training images

    # ---- DataLoader ----
    NUM_WORKERS = 1

    # ---- Robustness: probability of training with an empty prompt ----
    PROMPT_DROPOUT_PROB = 0.4

# --- DATASET CLASS ---
class COCOBBoxDataset(torch.utils.data.Dataset):
    """COCO images paired with a rendered bounding-box control map and a text prompt.

    Only images with at least one annotation are kept. Each item is a dict:

    - ``pixel_values``: the RGB image, resized (shorter side) and center-cropped
      to ``size`` x ``size``, normalized to [-1, 1].
    - ``conditioning_pixel_values``: the box map (white 2-px outlines on black),
      drawn at the original resolution and put through the same resize/crop
      (nearest-neighbour) so the boxes stay aligned with the image; values in [0, 1].
    - ``input_ids``: the prompt tokenized and padded to ``tokenizer.model_max_length``.
    - ``raw_prompt``: the prompt string (empty when prompt dropout fires).

    Args:
        img_dir: directory containing the COCO image files.
        ann_file: path to a COCO ``instances_*.json`` annotation file.
        tokenizer: a Hugging Face tokenizer, called with padding/truncation.
        size: output edge length in pixels.
        max_samples: keep only the first ``max_samples`` annotated images
            (a falsy value keeps all of them).
        prompt_dropout_prob: probability of replacing the prompt with ``""``.
            ``None`` (default) uses ``Config.PROMPT_DROPOUT_PROB``.
    """

    def __init__(self, img_dir, ann_file, tokenizer, size=512, max_samples=None, prompt_dropout_prob=None):
        self.img_dir = img_dir
        self.coco = COCO(ann_file)
        self.tokenizer = tokenizer
        self.size = size
        # None means "defer to Config.PROMPT_DROPOUT_PROB at sampling time".
        self.prompt_dropout_prob = prompt_dropout_prob

        # Keep only images that have annotations: they are the ones with boxes to draw.
        self.img_ids = [
            img_id for img_id in self.coco.getImgIds()
            if len(self.coco.getAnnIds(imgIds=img_id)) > 0
        ]
        if max_samples:
            self.img_ids = self.img_ids[:max_samples]

        self.image_transforms = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1]
        ])

        # Same geometry as the image, but NEAREST so the thin box outlines are
        # not blurred into grey; no normalization, the map stays in [0, 1].
        self.cond_transforms = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
        ])

    def draw_bbox_map(self, img_shape, anns):
        """Draw every COCO ``[x, y, w, h]`` box as a white 2-px outline on a black canvas of size ``img_shape``."""
        canvas = Image.new("RGB", img_shape, (0, 0, 0))
        draw = ImageDraw.Draw(canvas)
        for ann in anns:
            x, y, w, h = ann['bbox']
            draw.rectangle([x, y, x + w, y + h], outline=(255, 255, 255), width=2)
        return canvas

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.img_dir, img_info['file_name'])
        # Close the underlying file deterministically; DataLoader workers open many files.
        with Image.open(img_path) as raw:
            image = raw.convert("RGB")

        anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
        # Boxes are in original-image pixel coordinates, so draw at the original size
        # and let cond_transforms apply the same resize/crop as the image.
        control_image = self.draw_bbox_map(image.size, anns)

        cats = self.coco.loadCats([ann['category_id'] for ann in anns])
        # sorted(): a set's iteration order depends on per-process string hashing,
        # so sorting makes the prompt for a given image reproducible across runs.
        cat_names = sorted({cat['name'] for cat in cats})

        # --- PROMPT DROPOUT LOGIC ---
        dropout_prob = (
            Config.PROMPT_DROPOUT_PROB if self.prompt_dropout_prob is None else self.prompt_dropout_prob
        )
        if random.random() < dropout_prob:
            text_prompt = ""
        else:
            text_prompt = f"A photorealistic image containing {', '.join(cat_names)}" if cat_names else "A photorealistic image"

        return {
            "pixel_values": self.image_transforms(image),
            "conditioning_pixel_values": self.cond_transforms(control_image),
            "input_ids": self.tokenizer(
                text_prompt, max_length=self.tokenizer.model_max_length,
                padding="max_length", truncation=True, return_tensors="pt"
            ).input_ids[0],
            "raw_prompt": text_prompt
        }

# --- VALIDATION HELPER ---
def log_validation(accelerator, controlnet, unet, vae, text_encoder, tokenizer, val_batch, step):
    """Render a validation grid and log it to Weights & Biases (main process only).

    For up to ``Config.LOG_BATCH_SIZE`` samples of ``val_batch`` four images are
    logged:
    - the bbox control map;
    - the ground-truth image;
    - a generation with ControlNet at scale 1.0;
    - the same seed with ControlNet at scale 0.0 (ablation).

    Any error is reported and swallowed so that validation can never stop a training run.

    Args:
        accelerator: the run's ``Accelerator``; non-main processes return immediately.
        controlnet: the (unwrapped) ControlNet being trained.
        unet, vae, text_encoder, tokenizer: frozen SD components reused by the pipeline.
        val_batch: a collated dataset batch with ``pixel_values``,
            ``conditioning_pixel_values`` and ``raw_prompt``.
        step: current global optimizer step, logged alongside the images.
    """
    if not accelerator.is_main_process:
        return

    pipeline = None
    to_pil = transforms.ToPILImage()
    try:
        pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            Config.MODEL_ID,
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
            unet=unet, controlnet=controlnet,
            safety_checker=None, torch_dtype=torch.float16
        ).to(accelerator.device)
        pipeline.set_progress_bar_config(disable=True)

        log_images = []
        num_samples = min(len(val_batch["raw_prompt"]), Config.LOG_BATCH_SIZE)

        for i in range(num_samples):
            # Dropped-out (empty) prompts are replaced so the samples stay interpretable.
            prompt = val_batch["raw_prompt"][i]
            if prompt == "":
                prompt = "A photorealistic image of the scene"

            # 1. Ground truth: undo the [-1, 1] normalization. Converted on the CPU in
            #    float32; a GPU/fp16 round-trip is unnecessary for a PIL conversion.
            gt_image = to_pil((val_batch["pixel_values"][i] / 2 + 0.5).clamp(0, 1))

            # 2. Control image (bbox map), already in [0, 1].
            cond_image = to_pil(val_batch["conditioning_pixel_values"][i])

            # 3./4. Same seed, ControlNet at full strength vs. switched off (ablation).
            generations = {}
            for label, scale in (("with", 1.0), ("without", 0.0)):
                generator = torch.Generator(device=accelerator.device).manual_seed(42 + i)
                with torch.autocast("cuda"):
                    generations[label] = pipeline(
                        prompt,
                        image=cond_image,
                        num_inference_steps=20,
                        generator=generator,
                        controlnet_conditioning_scale=scale,
                    ).images[0]

            log_images.append(wandb.Image(cond_image, caption=f"#{i} BBox Map"))
            log_images.append(wandb.Image(gt_image, caption=f"#{i} Truth"))
            log_images.append(wandb.Image(generations["with"], caption=f"#{i} With Control"))
            log_images.append(wandb.Image(generations["without"], caption=f"#{i} No Control"))

        # Log the optimizer step with the images so they can be plotted against the
        # same `global_step` axis as the training loss.
        wandb.log({"validation": log_images, "global_step": step})
    except Exception as e:
        # Best-effort: a failed validation must never stop training.
        print(f"Skipping validation log due to error: {e}")
    finally:
        # Also runs on the error path, so a failed validation does not leave the
        # pipeline and cached CUDA allocations behind.
        del pipeline
        torch.cuda.empty_cache()

# --- RESUME HELPERS ---
def _find_latest_checkpoint(output_dir):
    """Return the path of the highest-numbered ``checkpoint-<step>`` dir in *output_dir*.

    Entries whose suffix is not an integer are ignored. Returns None when
    *output_dir* does not exist or contains no such checkpoint.
    """
    if not os.path.isdir(output_dir):
        return None
    steps = [
        int(suffix)
        for prefix, _, suffix in (name.partition("-") for name in os.listdir(output_dir))
        if prefix == "checkpoint" and suffix.isdigit()
    ]
    if not steps:
        return None
    return os.path.join(output_dir, f"checkpoint-{max(steps)}")


def _checkpoint_step(path):
    """Parse the optimizer step from a ``.../checkpoint-<step>`` path (a trailing slash is allowed)."""
    return int(os.path.basename(os.path.normpath(path)).split("-")[-1])


# --- MAIN FUNCTION ---
def main():
    """Train a ControlNet on COCO bounding-box maps on top of a frozen Stable Diffusion.

    Steps:
    - Set up Accelerate (fp16 mixed precision) and, best-effort, WandB tracking.
    - Load the frozen SD components and initialize a ControlNet from the UNet.
    - Optionally resume from a checkpoint (see ``Config.RESUME_FROM_CHECKPOINT``).
    - Train with 8-bit AdamW.
    - Periodically log validation images and save checkpoints.
    - Finally export the ControlNet weights to ``OUTPUT_DIR/final_controlnet_bbox``.
    """
    accelerator = Accelerator(
        gradient_accumulation_steps=Config.GRAD_ACCUM_STEPS,
        mixed_precision="fp16",
        log_with="wandb",
    )

    if accelerator.is_main_process:
        try:
            user_secrets = UserSecretsClient()
            wandb.login(key=user_secrets.get_secret("wandb"))
            # vars(Config) also holds class dunders (__module__, __dict__, ...);
            # only the public settings belong in the run config.
            run_config = {k: v for k, v in vars(Config).items() if not k.startswith("_")}
            accelerator.init_trackers("controlnet-coco-bbox", config=run_config)
        except Exception as e:
            # Best-effort: training continues without experiment tracking.
            print(f"WandB init warning: {e}")

    # Load Models
    if accelerator.is_main_process: print("Loading models...")

    tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_ID, subfolder="tokenizer", use_fast=False)
    noise_scheduler = DDPMScheduler.from_pretrained(Config.MODEL_ID, subfolder="scheduler")

    text_encoder = CLIPTextModel.from_pretrained(Config.MODEL_ID, subfolder="text_encoder", torch_dtype=torch.float16)
    vae = AutoencoderKL.from_pretrained(Config.MODEL_ID, subfolder="vae", torch_dtype=torch.float16)
    unet = UNet2DConditionModel.from_pretrained(Config.MODEL_ID, subfolder="unet", torch_dtype=torch.float16)

    # Initialize the ControlNet from the pretrained UNet (ControlNetModel.from_unet).
    controlnet = ControlNetModel.from_unet(unet)

    # Only the ControlNet is trained; the SD backbone stays frozen.
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)

    controlnet.train()
    controlnet.enable_gradient_checkpointing()
    # Gradients flow through the frozen UNet back into the ControlNet residuals
    # (see the forward pass below), so the UNet is checkpointed too.
    unet.enable_gradient_checkpointing()

    # Dataset
    if accelerator.is_main_process: print("Loading dataset...")
    dataset = COCOBBoxDataset(Config.TRAIN_IMG_DIR, Config.TRAIN_ANN_FILE, tokenizer, max_samples=Config.MAX_SAMPLES)

    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=True,
        num_workers=Config.NUM_WORKERS,
        pin_memory=True,
        drop_last=True
    )

    # Fixed batch for visual validation, drawn from the plain (un-prepared) DataLoader.
    val_batch = next(iter(train_dataloader))
    optimizer = bnb.optim.AdamW8bit(controlnet.parameters(), lr=Config.LEARNING_RATE)

    # Prepare
    controlnet, optimizer, train_dataloader = accelerator.prepare(
        controlnet, optimizer, train_dataloader
    )

    # Frozen models are not wrapped by Accelerate; place them on this process's device.
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device)
    unet.to(accelerator.device)

    # global_step counts optimizer updates, and with gradient accumulation there is
    # one update per GRAD_ACCUM_STEPS batches (ceil-divided), not one per batch.
    num_update_steps_per_epoch = -(-len(train_dataloader) // Config.GRAD_ACCUM_STEPS)

    # --- RESUME LOGIC ---
    global_step = 0
    first_epoch = 0
    resume_path = None

    if Config.RESUME_FROM_CHECKPOINT == "latest":
        resume_path = _find_latest_checkpoint(Config.OUTPUT_DIR)
        if resume_path:
            accelerator.print(f"Resuming from latest checkpoint: {os.path.basename(resume_path)}")
        else:
            accelerator.print("No checkpoint found. Starting from scratch.")
    elif Config.RESUME_FROM_CHECKPOINT:
        resume_path = Config.RESUME_FROM_CHECKPOINT
        accelerator.print(f"Resuming from checkpoint: {resume_path}")

    if resume_path:
        accelerator.load_state(resume_path)
        global_step = _checkpoint_step(resume_path)
        first_epoch = global_step // num_update_steps_per_epoch

    # Batches of the resumed epoch that were already trained on before the checkpoint.
    resume_skip_batches = (global_step % num_update_steps_per_epoch) * Config.GRAD_ACCUM_STEPS

    # --- TRAINING LOOP ---
    if accelerator.is_main_process: print(f"Starting training from Step {global_step}, Epoch {first_epoch}...")

    for epoch in range(first_epoch, Config.NUM_EPOCHS):
        progress_bar = tqdm(total=num_update_steps_per_epoch, disable=not accelerator.is_local_main_process, desc=f"Epoch {epoch}")

        epoch_loader = train_dataloader
        if epoch == first_epoch and resume_skip_batches > 0:
            # Fast-forward past the batches consumed before the checkpoint so a resumed
            # run neither retrains them nor overshoots the epoch's step budget.
            epoch_loader = accelerator.skip_first_batches(train_dataloader, resume_skip_batches)
            progress_bar.update(resume_skip_batches // Config.GRAD_ACCUM_STEPS)

        for batch in epoch_loader:
            with accelerator.accumulate(controlnet):
                # Encode target images into scaled latents with the frozen fp16 VAE.
                latents = vae.encode(batch["pixel_values"].to(dtype=torch.float16)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor

                # Forward diffusion: noise each sample at a uniformly drawn timestep.
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
                control_cond = batch["conditioning_pixel_values"].to(dtype=torch.float16)

                # The ControlNet's residuals are injected into the frozen UNet.
                down_res, mid_res = controlnet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states,
                    controlnet_cond=control_cond,
                    return_dict=False,
                )

                model_pred = unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states,
                    down_block_additional_residuals=[r.to(dtype=torch.float16) for r in down_res],
                    mid_block_additional_residual=mid_res.to(dtype=torch.float16),
                ).sample

                # Noise-prediction MSE, computed in fp32.
                loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")

                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                global_step += 1
                progress_bar.update(1)
                # accelerator.log is a no-op when no tracker was initialised, whereas a
                # direct wandb.log would raise if the WandB setup above failed.
                accelerator.log({"train_loss": loss.item(), "global_step": global_step}, step=global_step)

                if accelerator.is_main_process:
                    if global_step % Config.LOG_INTERVAL == 0:
                        log_validation(accelerator, accelerator.unwrap_model(controlnet), unet, vae, text_encoder, tokenizer, val_batch, global_step)

                    if global_step % Config.SAVE_INTERVAL == 0:
                        save_path = os.path.join(Config.OUTPUT_DIR, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)

        progress_bar.close()

    # Make sure every process has finished its last step before exporting weights.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        accelerator.unwrap_model(controlnet).save_pretrained(os.path.join(Config.OUTPUT_DIR, "final_controlnet_bbox"))
        print("Training Finished.")
    # Called on every process: Accelerate handles the main-process-only tracker
    # shutdown itself and performs any per-process teardown.
    accelerator.end_training()

if __name__ == "__main__":
    main()

Overwriting train_accelerate.py


In [None]:
# Launch the script written by the `%%writefile train_bbox.py` cell above.
!accelerate launch --multi_gpu --num_processes=2 --mixed_precision=fp16 train_bbox.py

ipex flag is deprecated, will be removed in Accelerate v1.10. From 2.7.0, PyTorch has all needed optimizations for Intel CPU and XPU.
The following values were not passed to `accelerate launch` and had defaults used instead:
	`--num_machines` was set to a value of `1`
	`--dynamo_backend` was set to a value of `'no'`
2025-12-13 14:20:14.834274: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2025-12-13 14:20:14.834279: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765635614.858817    1323 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765635614.858944    1324 cuda_dnn.cc:8310] Unable to register cuDNN factor

In [None]:
# ==========================================
# EVALUATION & VISUALIZATION CELL
# ==========================================
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
from pycocotools.coco import COCO
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, UniPCMultistepScheduler

# 1. Configuration
VAL_IMG_DIR = "/kaggle/input/coco-2017-dataset/coco2017/val2017"
VAL_ANN_FILE = "/kaggle/input/coco-2017-dataset/coco2017/annotations/instances_val2017.json"
# Must match where train_bbox.py exports its final weights
# (OUTPUT_DIR + "/final_controlnet_bbox").
MODEL_PATH = "/kaggle/working/controlnet-coco-bbox/final_controlnet_bbox"
BASE_MODEL = "runwayml/stable-diffusion-v1-5"  # same base model the ControlNet was trained on
NUM_SAMPLES = 20  # number of random validation images to render

# 2. Helper to Draw Bounding Boxes
def draw_bbox_map(img_shape, anns):
    """Render COCO boxes as white 2-px outlines on a black RGB canvas of size ``img_shape``.

    Each annotation's ``bbox`` is read as ``[x, y, w, h]`` and drawn unchanged,
    so ``img_shape`` should be the size of the image the boxes belong to.
    """
    canvas = Image.new("RGB", img_shape, (0, 0, 0))
    pen = ImageDraw.Draw(canvas)
    for left, top, width, height in (a['bbox'] for a in anns):
        pen.rectangle(
            [left, top, left + width, top + height],
            outline=(255, 255, 255),
            width=2,
        )
    return canvas

# 3. Load COCO Validation Data
print("Loading COCO Annotations...")
coco = COCO(VAL_ANN_FILE)

# Only annotated images are useful: they are the ones with a non-empty box layout.
img_ids = [
    candidate
    for candidate in coco.getImgIds()
    if len(coco.getAnnIds(imgIds=candidate)) > 0
]

# Reproducible (global NumPy seed) subset of NUM_SAMPLES distinct image ids.
np.random.seed(42)
selected_indices = np.random.choice(img_ids, NUM_SAMPLES, replace=False)

# 4. Load Model Pipeline
print(f"Loading Model from {MODEL_PATH}...")
controlnet = ControlNetModel.from_pretrained(MODEL_PATH, torch_dtype=torch.float16)
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL,
    controlnet=controlnet,
    torch_dtype=torch.float16,
    safety_checker=None,
)

# Use a fast scheduler for inference
pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)

# Model CPU offload places each sub-model on the GPU only while it runs and
# manages device placement itself. The pipeline is therefore NOT moved with
# .to("cuda") first; diffusers advises against that, since it negates the
# offload's memory savings.
pipe.enable_model_cpu_offload()

# 5. Run Inference & Collect Results
# Bring the ground truth and the box map into the exact frame the ControlNet saw
# during training: shorter side -> IMG_SIZE, then an IMG_SIZE x IMG_SIZE center crop.
# COCO boxes are in original-image pixel coordinates, so they are drawn at the
# original resolution first and the box map then gets the same resize/crop.
from torchvision import transforms

IMG_SIZE = 512
to_model_frame_rgb = transforms.Compose([
    transforms.Resize(IMG_SIZE, interpolation=transforms.InterpolationMode.BILINEAR),
    transforms.CenterCrop(IMG_SIZE),
])
# NEAREST keeps the thin white outlines crisp, matching the training cond_transforms.
to_model_frame_cond = transforms.Compose([
    transforms.Resize(IMG_SIZE, interpolation=transforms.InterpolationMode.NEAREST),
    transforms.CenterCrop(IMG_SIZE),
])

results = []
print("Generating Images...")

for i, img_id in enumerate(selected_indices):
    image_id = int(img_id)  # np.int64 -> plain int for pycocotools lookups
    img_info = coco.loadImgs(image_id)[0]
    img_path = os.path.join(VAL_IMG_DIR, img_info['file_name'])

    # Load Original Image
    with Image.open(img_path) as raw:
        original = raw.convert("RGB")

    ann_ids = coco.getAnnIds(imgIds=image_id)
    anns = coco.loadAnns(ann_ids)

    # Ground truth and layout condition, both in the training frame.
    gt_image = to_model_frame_rgb(original)
    control_image = to_model_frame_cond(draw_bbox_map(original.size, anns))

    # Create Prompt (sorted so the category order is reproducible)
    cats = coco.loadCats([ann['category_id'] for ann in anns])
    cat_names = sorted({cat['name'] for cat in cats})
    prompt = f"A photorealistic image containing {', '.join(cat_names)}"

    # Generate
    generator = torch.Generator(device="cuda").manual_seed(42)
    with torch.inference_mode():
        pred_image = pipe(
            prompt,
            image=control_image,
            num_inference_steps=25,
            generator=generator
        ).images[0]

    results.append((control_image, gt_image, pred_image, prompt))
    print(f"Processed {i+1}/{NUM_SAMPLES}")

# 6. Plotting
print("Plotting results...")
if not results:
    print("No results to plot.")
else:
    # One row per generated sample: input layout | ground truth | prediction.
    # squeeze=False keeps `axes` 2-D even with a single row, so axes[row, col] works.
    n_rows = len(results)
    fig, axes = plt.subplots(n_rows, 3, figsize=(15, 5 * n_rows), squeeze=False)

    for idx, (cond, gt, pred, prompt) in enumerate(results):
        short_prompt = prompt if len(prompt) <= 50 else prompt[:50] + "..."
        panels = (
            (cond, "Input Layout\n(BBox Map)"),
            (gt, "Ground Truth"),
            (pred, f"Prediction\nPrompt: {short_prompt}"),
        )
        for col, (img, title) in enumerate(panels):
            ax = axes[idx, col]
            ax.imshow(img)
            ax.set_title(title, fontsize=10)
            ax.axis("off")

    plt.tight_layout()
    plt.show()