In [1]:
# !pip install bitsandbytes

Collecting bitsandbytes
  Downloading bitsandbytes-0.49.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch<3,>=2.3->bitsandbytes)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch<3,>=2.3->bitsandbytes)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch<3,>=2.3->bitsandbytes)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch<3,>=2.3->bitsandbytes)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch<3,>=2.3->bitsandbytes)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-c

In [17]:
%%writefile train_seg.py
import os
import torch
import numpy as np
import random 
from PIL import Image
from tqdm.auto import tqdm
from pycocotools.coco import COCO
from torchvision import transforms
import bitsandbytes as bnb
import wandb
from kaggle_secrets import UserSecretsClient

# Accelerate & Diffusers
from accelerate import Accelerator
from diffusers import (
    StableDiffusionControlNetPipeline, 
    ControlNetModel, 
    DDPMScheduler,
    AutoencoderKL,
    UNet2DConditionModel
)
from transformers import AutoTokenizer, CLIPTextModel

# --- CONFIGURATION ---
class Config:
    COCO_ROOT = "/kaggle/input/coco-2017-dataset/coco2017"
    TRAIN_IMG_DIR = os.path.join(COCO_ROOT, "train2017")
    TRAIN_ANN_FILE = os.path.join(COCO_ROOT, "annotations/instances_train2017.json")
    
    MODEL_ID = "runwayml/stable-diffusion-v1-5"
    OUTPUT_DIR = "/kaggle/working/controlnet-coco-seg"
    
    RESUME_FROM_CHECKPOINT = "latest"
    
    # Hyperparameters
    RESOLUTION = 512
    BATCH_SIZE = 8          
    GRAD_ACCUM_STEPS = 1
    LEARNING_RATE = 1e-4    # High LR to wake up zero-convolutions
    NUM_EPOCHS = 10  
    
    LOG_INTERVAL = 100       
    LOG_BATCH_SIZE = 8
    SAVE_INTERVAL = 1000    
    MAX_SAMPLES = 20000 
    
    NUM_WORKERS = 1   
    
    PROMPT_DROPOUT_PROB = 0.4 

# --- DATASET CLASS ---
class COCOSegmentationDataset(torch.utils.data.Dataset):
    def __init__(self, img_dir, ann_file, tokenizer, size=512, max_samples=None):
        self.img_dir = img_dir
        self.coco = COCO(ann_file)
        self.img_ids = self.coco.getImgIds()
        self.tokenizer = tokenizer
        self.size = size
        
        self.img_ids = [img_id for img_id in self.img_ids if len(self.coco.getAnnIds(imgIds=img_id)) > 0]
        
        if max_samples:
            self.img_ids = self.img_ids[:max_samples]

        self.image_transforms = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]), 
        ])
        
        self.cond_transforms = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.CenterCrop(size),
            transforms.ToTensor(), 
        ])
        
        self.color_map = self._generate_color_map()

    def _generate_color_map(self):
        cats = self.coco.loadCats(self.coco.getCatIds())
        palette = {}
        for cat in cats:
            import hashlib
            hash_object = hashlib.md5(str(cat['id']).encode())
            hex_hash = hash_object.hexdigest()
            r = int(hex_hash[0:2], 16)
            g = int(hex_hash[2:4], 16)
            b = int(hex_hash[4:6], 16)
            palette[cat['id']] = (r, g, b)
        return palette

    def draw_segmentation_map(self, img_shape, anns):
        mask = np.zeros((img_shape[1], img_shape[0], 3), dtype=np.uint8)
        anns = sorted(anns, key=lambda x: x['area'], reverse=True)
        for ann in anns:
            cat_id = ann['category_id']
            color = self.color_map.get(cat_id, (255, 255, 255))
            binary_mask = self.coco.annToMask(ann)
            mask[binary_mask == 1] = color
        return Image.fromarray(mask)

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        img_id = self.img_ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.img_dir, img_info['file_name'])
        image = Image.open(img_path).convert("RGB")
        
        ann_ids = self.coco.getAnnIds(imgIds=img_id)
        anns = self.coco.loadAnns(ann_ids)
        control_image = self.draw_segmentation_map(image.size, anns)
        
        cat_ids = [ann['category_id'] for ann in anns]
        cats = self.coco.loadCats(cat_ids)
        cat_names = list(set([cat['name'] for cat in cats]))
        
        # --- HIGH DROPOUT LOGIC ---
        if random.random() < Config.PROMPT_DROPOUT_PROB:
            text_prompt = ""
        else:
            text_prompt = f"A photorealistic image containing {', '.join(cat_names)}" if cat_names else "A photorealistic image"
        
        return {
            "pixel_values": self.image_transforms(image),
            "conditioning_pixel_values": self.cond_transforms(control_image),
            "input_ids": self.tokenizer(
                text_prompt, max_length=self.tokenizer.model_max_length, 
                padding="max_length", truncation=True, return_tensors="pt"
            ).input_ids[0],
            "raw_prompt": text_prompt 
        }

# --- VALIDATION HELPER ---
def log_validation(accelerator, controlnet, unet, vae, text_encoder, tokenizer, val_batch, step):
    if not accelerator.is_main_process: return

    try:
        pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            Config.MODEL_ID,
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
            unet=unet, controlnet=controlnet,
            safety_checker=None, torch_dtype=torch.float16
        ).to(accelerator.device)
        pipeline.set_progress_bar_config(disable=True)
        
        log_images = []
        num_samples = min(len(val_batch["raw_prompt"]), Config.LOG_BATCH_SIZE)
        
        for i in range(num_samples):
            # Force prompt to exist for validation visualization
            prompt = val_batch["raw_prompt"][i]
            if prompt == "": prompt = "A photorealistic image of the scene"

            gt_tensor = val_batch["pixel_values"][i].to(accelerator.device, dtype=torch.float16)
            gt_image = (gt_tensor / 2 + 0.5).clamp(0, 1)
            gt_image = transforms.ToPILImage()(gt_image)

            cond_tensor = val_batch["conditioning_pixel_values"][i].to(accelerator.device, dtype=torch.float16)
            cond_image = transforms.ToPILImage()(cond_tensor)

            # 1. With Control
            generator = torch.Generator(device=accelerator.device).manual_seed(42 + i)
            with torch.autocast("cuda"):
                pred_image = pipeline(
                    prompt, 
                    image=cond_image, 
                    num_inference_steps=20, 
                    generator=generator,
                    controlnet_conditioning_scale=1.0 
                ).images[0]

            # 2. Without Control
            generator = torch.Generator(device=accelerator.device).manual_seed(42 + i)
            with torch.autocast("cuda"):
                base_image = pipeline(
                    prompt, 
                    image=cond_image, 
                    num_inference_steps=20, 
                    generator=generator,
                    controlnet_conditioning_scale=0.0
                ).images[0]
            
            log_images.append(wandb.Image(cond_image, caption=f"#{i} Seg Map"))
            log_images.append(wandb.Image(gt_image, caption=f"#{i} Truth"))
            log_images.append(wandb.Image(pred_image, caption=f"#{i} With Control"))
            log_images.append(wandb.Image(base_image, caption=f"#{i} No Control"))
        
        wandb.log({"validation": log_images})
        del pipeline
        torch.cuda.empty_cache()
    except Exception as e:
        print(f"Skipping validation log due to error: {e}")

# --- MAIN FUNCTION ---
def main():
    accelerator = Accelerator(
        gradient_accumulation_steps=Config.GRAD_ACCUM_STEPS,
        mixed_precision="fp16",
        log_with="wandb",
    )
    
    if accelerator.is_main_process:
        try:
            user_secrets = UserSecretsClient()
            wandb.login(key=user_secrets.get_secret("wandb"))
            accelerator.init_trackers("controlnet-coco-seg", config=Config.__dict__)
        except Exception as e:
            print(f"WandB init warning: {e}")

    if accelerator.is_main_process: print("Loading models...")
    
    tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_ID, subfolder="tokenizer", use_fast=False)
    noise_scheduler = DDPMScheduler.from_pretrained(Config.MODEL_ID, subfolder="scheduler")
    
    text_encoder = CLIPTextModel.from_pretrained(Config.MODEL_ID, subfolder="text_encoder", torch_dtype=torch.float16)
    vae = AutoencoderKL.from_pretrained(Config.MODEL_ID, subfolder="vae", torch_dtype=torch.float16)
    unet = UNet2DConditionModel.from_pretrained(Config.MODEL_ID, subfolder="unet", torch_dtype=torch.float16)
    
    controlnet = ControlNetModel.from_unet(unet)
    
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    
    controlnet.train()
    controlnet.enable_gradient_checkpointing()
    unet.enable_gradient_checkpointing()

    if accelerator.is_main_process: print("Loading dataset...")
    dataset = COCOSegmentationDataset(Config.TRAIN_IMG_DIR, Config.TRAIN_ANN_FILE, tokenizer, max_samples=Config.MAX_SAMPLES)
    
    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=True,
        num_workers=Config.NUM_WORKERS,
        pin_memory=True,
        drop_last=True
    )

    val_batch = next(iter(train_dataloader))
    optimizer = bnb.optim.AdamW8bit(controlnet.parameters(), lr=Config.LEARNING_RATE)

    controlnet, optimizer, train_dataloader = accelerator.prepare(
        controlnet, optimizer, train_dataloader
    )
    
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device)
    unet.to(accelerator.device)

    # --- RESUME LOGIC ---
    global_step = 0
    first_epoch = 0
    
    if Config.RESUME_FROM_CHECKPOINT:
        if Config.RESUME_FROM_CHECKPOINT == "latest":
            dirs = os.listdir(Config.OUTPUT_DIR)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None
            
            if path:
                accelerator.print(f"Resuming from latest checkpoint: {path}")
                accelerator.load_state(os.path.join(Config.OUTPUT_DIR, path))
                global_step = int(path.split("-")[1])
                first_epoch = global_step // len(train_dataloader)
            else:
                accelerator.print("No checkpoint found. Starting from scratch.")
        else:
            accelerator.print(f"Resuming from checkpoint: {Config.RESUME_FROM_CHECKPOINT}")
            accelerator.load_state(Config.RESUME_FROM_CHECKPOINT)
            global_step = int(Config.RESUME_FROM_CHECKPOINT.split("-")[-1])
            first_epoch = global_step // len(train_dataloader)

    if accelerator.is_main_process: print(f"Starting training from Step {global_step}, Epoch {first_epoch}...")
    
    for epoch in range(first_epoch, Config.NUM_EPOCHS):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process, desc=f"Epoch {epoch}")
        
        for batch in train_dataloader:
            with accelerator.accumulate(controlnet):
                latents = vae.encode(batch["pixel_values"].to(dtype=torch.float16)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor
                
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
                control_cond = batch["conditioning_pixel_values"].to(dtype=torch.float16)
                
                down_res, mid_res = controlnet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states,
                    controlnet_cond=control_cond,
                    return_dict=False,
                )
                
                model_pred = unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states,
                    down_block_additional_residuals=[r.to(dtype=torch.float16) for r in down_res],
                    mid_block_additional_residual=mid_res.to(dtype=torch.float16),
                ).sample
                
                loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
                
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

            if accelerator.sync_gradients:
                global_step += 1
                progress_bar.update(1)
                
                if accelerator.is_main_process:
                    wandb.log({"train_loss": loss.item(), "global_step": global_step})
                    
                    if global_step % Config.LOG_INTERVAL == 0:
                        log_validation(accelerator, accelerator.unwrap_model(controlnet), unet, vae, text_encoder, tokenizer, val_batch, global_step)
                        
                    if global_step % Config.SAVE_INTERVAL == 0:
                        save_path = os.path.join(Config.OUTPUT_DIR, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
    
    if accelerator.is_main_process:
        accelerator.unwrap_model(controlnet).save_pretrained(os.path.join(Config.OUTPUT_DIR, "final_controlnet_seg"))
        print("Training Finished.")
        accelerator.end_training()

if __name__ == "__main__":
    main()

Overwriting train_seg.py


In [None]:
!accelerate launch --multi_gpu --num_processes=2 --mixed-precision=fp16 train_seg.py

In [16]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pycocotools.coco import COCO
from diffusers import (
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    UniPCMultistepScheduler,
)

# -----------------------------
# 1) Config
# -----------------------------
VAL_IMG_DIR = "/kaggle/input/coco-2017-dataset/coco2017/val2017"
VAL_ANN_FILE = "/kaggle/input/coco-2017-dataset/coco2017/annotations/instances_val2017.json"
MODEL_PATH = "/kaggle/working/controlnet-coco-seg/final_controlnet_seg"
BASE_MODEL = "runwayml/stable-diffusion-v1-5"

NUM_SAMPLES = 5
RESOLUTION = 512
SEED_BASE = 42

# If outputs don't follow the mask, try 1.5-2.0
CONTROLNET_COND_SCALE = 1.0

# -----------------------------
# 2) Helpers (match training)
# -----------------------------
def generate_color_map(coco):
    cats = coco.loadCats(coco.getCatIds())
    palette = {}
    import hashlib
    for cat in cats:
        h = hashlib.md5(str(cat["id"]).encode()).hexdigest()
        palette[cat["id"]] = (int(h[0:2], 16), int(h[2:4], 16), int(h[4:6], 16))
    return palette

def draw_segmentation_map(img_wh, anns, coco, color_map):
    # img_wh = (W, H)
    mask = np.zeros((img_wh[1], img_wh[0], 3), dtype=np.uint8)
    anns = sorted(anns, key=lambda x: x.get("area", 0.0), reverse=True)
    for ann in anns:
        cat_id = ann["category_id"]
        color = color_map.get(cat_id, (255, 255, 255))
        binary_mask = coco.annToMask(ann)
        mask[binary_mask == 1] = color
    return Image.fromarray(mask)

def build_prompt(coco, anns):
    cat_ids = [ann["category_id"] for ann in anns]
    cats = coco.loadCats(cat_ids)
    cat_names = sorted(list(set([cat["name"] for cat in cats])))
    if not cat_names:
        return "A photorealistic image"
    return f"A photorealistic image containing {', '.join(cat_names)}"

@torch.no_grad()
def conditioning_effect_test(pipe, prompt, control_image, steps=10, seed=123):
    """
    Quick sanity check: compare output with real control vs black control.
    Returns MAD in pixel space (0..255).
    """
    gen = torch.Generator(device="cuda").manual_seed(seed)
    real = pipe(
        prompt,
        image=control_image,
        num_inference_steps=steps,
        generator=gen,
        controlnet_conditioning_scale=CONTROLNET_COND_SCALE,
    ).images[0]

    zero_img = Image.new("RGB", control_image.size, (0, 0, 0))
    gen = torch.Generator(device="cuda").manual_seed(seed)
    zero = pipe(
        prompt,
        image=zero_img,
        num_inference_steps=steps,
        generator=gen,
        controlnet_conditioning_scale=CONTROLNET_COND_SCALE,
    ).images[0]

    a = np.asarray(real).astype(np.float32)
    b = np.asarray(zero).astype(np.float32)
    mad = float(np.mean(np.abs(a - b)))
    return mad, real, zero

# -----------------------------
# 3) Load COCO + model
# -----------------------------
print("Loading COCO...")
coco = COCO(VAL_ANN_FILE)
color_map = generate_color_map(coco)

img_ids = coco.getImgIds()
img_ids = [img_id for img_id in img_ids if len(coco.getAnnIds(imgIds=img_id)) > 0]

np.random.seed(42)
selected_ids = np.random.choice(img_ids, NUM_SAMPLES, replace=False)

print(f"Loading ControlNet from {MODEL_PATH}...")
controlnet = ControlNetModel.from_pretrained(MODEL_PATH, torch_dtype=torch.float16)

pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL,
    controlnet=controlnet,
    torch_dtype=torch.float16,
    safety_checker=None,
).to("cuda")

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.set_progress_bar_config(disable=True)

# NOTE: cpu offload is fine for Kaggle memory limits
pipe.enable_model_cpu_offload()

# -----------------------------
# 4) Inference
# -----------------------------
results = []
print("Generating...")

for i, img_id in enumerate(selected_ids):
    img_info = coco.loadImgs(int(img_id))[0]
    img_path = os.path.join(VAL_IMG_DIR, img_info["file_name"])

    gt_image = Image.open(img_path).convert("RGB").resize((RESOLUTION, RESOLUTION))

    ann_ids = coco.getAnnIds(imgIds=int(img_id))
    anns = coco.loadAnns(ann_ids)

    control_image = draw_segmentation_map((RESOLUTION, RESOLUTION), anns, coco, color_map)
    prompt = build_prompt(coco, anns)

    generator = torch.Generator(device="cuda").manual_seed(SEED_BASE + i)
    with torch.inference_mode():
        pred_image = pipe(
            prompt,
            image=control_image,
            num_inference_steps=20,
            generator=generator,
            controlnet_conditioning_scale=CONTROLNET_COND_SCALE,
        ).images[0]

    # Run the conditioning-vs-zero check only once (first sample) to avoid slowing everything
    extra = None
    if i == 0:
        try:
            mad, pred_real_10, pred_zero_10 = conditioning_effect_test(
                pipe, prompt, control_image, steps=10, seed=999
            )
            print(f"[cond-effect] MAD(real cond vs ZERO cond) = {mad:.2f} (0..255 scale)")
            extra = (pred_zero_10, mad)
        except Exception as e:
            print(f"[cond-effect] skipped due to error: {e}")

    results.append((control_image, gt_image, pred_image, prompt, extra))
    print(f"Processed {i+1}/{NUM_SAMPLES}")

# -----------------------------
# 5) Plot
# -----------------------------
print("Plotting...")
has_extra = results[0][4] is not None
ncols = 4 if has_extra else 3

fig, axes = plt.subplots(NUM_SAMPLES, ncols, figsize=(5 * ncols, 5 * NUM_SAMPLES))

# Handle NUM_SAMPLES == 1 case
if NUM_SAMPLES == 1:
    axes = np.expand_dims(axes, 0)

for idx, (cond, gt, pred, prompt, extra) in enumerate(results):
    axes[idx, 0].imshow(cond)
    axes[idx, 0].set_title("Seg Mask", fontsize=10)
    axes[idx, 0].axis("off")

    axes[idx, 1].imshow(gt)
    axes[idx, 1].set_title("Ground Truth", fontsize=10)
    axes[idx, 1].axis("off")

    axes[idx, 2].imshow(pred)
    axes[idx, 2].set_title(f"Pred: {prompt[:60]}...", fontsize=10)
    axes[idx, 2].axis("off")

    if ncols == 4:
        if extra is not None:
            pred_zero_10, mad = extra
            axes[idx, 3].imshow(pred_zero_10)
            axes[idx, 3].set_title(f"Pred (ZERO cond)\nMAD={mad:.2f}", fontsize=10)
            axes[idx, 3].axis("off")
        else:
            axes[idx, 3].axis("off")

plt.tight_layout()
plt.show()


2025-12-13 22:33:31.400254: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1765665211.422434      47 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1765665211.429306      47 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

AttributeError: 'MessageFactory' object has no attribute 'GetPrototype'

Loading COCO...
loading annotations into memory...
Done (t=0.65s)
creating index...
index created!
Loading ControlNet from /kaggle/working/controlnet-coco-seg/final_controlnet_seg...


OSError: We couldn't connect to 'https://huggingface.co' to load this model, couldn't find it in the cached files and it looks like /kaggle/working/controlnet-coco-seg/final_controlnet_seg is not the path to a directory containing a config.json file.
Checkout your internet connection or see how to run the library in offline mode at 'https://huggingface.co/docs/diffusers/installation#offline-mode'.