In [None]:
!pip install bitsandbytes

In [None]:
%%writefile train_seg.py
import os
import torch
import numpy as np
import random
from PIL import Image
from tqdm.auto import tqdm
from pycocotools.coco import COCO
from torchvision import transforms
import bitsandbytes as bnb
import wandb
from kaggle_secrets import UserSecretsClient

# Imports
from accelerate import Accelerator
from diffusers import (
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    DDPMScheduler,
    AutoencoderKL,
    UNet2DConditionModel
)
from transformers import AutoTokenizer, CLIPTextModel

# Config
class Config:
    # Central namespace of training settings; read as class attributes
    # (e.g. ``Config.BATCH_SIZE``) by the dataset, validation and main loop.

    # --- Data --------------------------------------------------------------
    COCO_ROOT = "/kaggle/input/coco-2017-dataset/coco2017"
    TRAIN_IMG_DIR = os.path.join(COCO_ROOT, "train2017")
    TRAIN_ANN_FILE = os.path.join(COCO_ROOT, "annotations/instances_train2017.json")
    RESOLUTION = 512
    MAX_SAMPLES = 20000
    NUM_WORKERS = 1
    # Probability that a training prompt is replaced by the empty string.
    PROMPT_DROPOUT_PROB = 0.4

    # --- Model and checkpoints ----------------------------------------------
    MODEL_ID = "runwayml/stable-diffusion-v1-5"
    OUTPUT_DIR = "/kaggle/working/controlnet-coco-seg"
    # "latest" resumes from the newest checkpoint-<step> in OUTPUT_DIR; a path
    # resumes from that checkpoint; a falsy value starts from scratch.
    RESUME_FROM_CHECKPOINT = "latest"

    # --- Optimisation ---------------------------------------------------------
    BATCH_SIZE = 8
    GRAD_ACCUM_STEPS = 1
    LEARNING_RATE = 3e-5
    NUM_EPOCHS = 10

    # --- Logging and saving (intervals are in optimizer steps) ---------------
    LOG_INTERVAL = 100
    LOG_BATCH_SIZE = 8
    SAVE_INTERVAL = 1000

# Dataset
class COCOSegmentationDataset(torch.utils.data.Dataset):
    """COCO images paired with colour-coded segmentation maps and text prompts.

    Every annotated instance is painted onto a black canvas with a fixed colour
    per COCO category. That map is the ControlNet conditioning image. The prompt
    lists the (alphabetically sorted) category names present in the image.

    Args:
        img_dir: directory containing the COCO image files.
        ann_file: path to a COCO ``instances_*.json`` annotation file.
        tokenizer: tokenizer used to encode prompts; output is padded/truncated
            to ``tokenizer.model_max_length``.
        size: output side length; both images are resized (shorter side) and
            centre-cropped to ``size x size``.
        max_samples: if truthy, keep only the first ``max_samples`` annotated images.
    """

    def __init__(self, img_dir, ann_file, tokenizer, size=512, max_samples=None):
        self.img_dir = img_dir
        self.coco = COCO(ann_file)
        self.tokenizer = tokenizer
        self.size = size
        # Keep only images with at least one annotation; an image without any
        # would produce an all-black control map.
        self.img_ids = [
            img_id for img_id in self.coco.getImgIds()
            if len(self.coco.getAnnIds(imgIds=img_id)) > 0
        ]
        if max_samples:
            self.img_ids = self.img_ids[:max_samples]
        # Photo: bilinear resize + centre crop, normalised to [-1, 1].
        self.image_transforms = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])
        # Control map: same geometry, but nearest-neighbour so category colours
        # are never blended; left in [0, 1].
        self.cond_transforms = transforms.Compose([
            transforms.Resize(size, interpolation=transforms.InterpolationMode.NEAREST),
            transforms.CenterCrop(size),
            transforms.ToTensor(),
        ])
        self.color_map = self._generate_color_map()

    def _generate_color_map(self):
        """Map each COCO category id to a deterministic RGB colour.

        The colour is the first three bytes of the MD5 hex digest of the
        category id's decimal string, so it is identical across runs and processes.
        """
        import hashlib

        palette = {}
        for cat in self.coco.loadCats(self.coco.getCatIds()):
            hex_hash = hashlib.md5(str(cat['id']).encode()).hexdigest()
            palette[cat['id']] = (
                int(hex_hash[0:2], 16),
                int(hex_hash[2:4], 16),
                int(hex_hash[4:6], 16),
            )
        return palette

    def draw_segmentation_map(self, img_shape, anns):
        """Paint all annotations of one image onto a black RGB canvas.

        Args:
            img_shape: ``(width, height)``, as given by ``PIL.Image.size``.
            anns: COCO annotation dicts for that image.

        Returns:
            RGB ``PIL.Image``. Larger instances are painted first so smaller ones
            stay visible on top; categories missing from the palette are white.
        """
        width, height = img_shape
        mask = np.zeros((height, width, 3), dtype=np.uint8)
        for ann in sorted(anns, key=lambda x: x['area'], reverse=True):
            color = self.color_map.get(ann['category_id'], (255, 255, 255))
            mask[self.coco.annToMask(ann) == 1] = color
        return Image.fromarray(mask)

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Return one training example as a dict.

        Keys: ``pixel_values`` (photo in [-1, 1]), ``conditioning_pixel_values``
        (segmentation map in [0, 1]), ``input_ids`` (tokenised prompt) and
        ``raw_prompt`` (prompt text). With probability
        ``Config.PROMPT_DROPOUT_PROB`` the prompt is the empty string.
        """
        img_id = self.img_ids[idx]
        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.img_dir, img_info['file_name'])
        # Close the file handle right away; convert() returns an independent copy.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        anns = self.coco.loadAnns(self.coco.getAnnIds(imgIds=img_id))
        control_image = self.draw_segmentation_map(image.size, anns)
        cats = self.coco.loadCats([ann['category_id'] for ann in anns])
        # Sorted so the prompt is deterministic: iteration order of a set of
        # strings depends on the per-process hash seed.
        cat_names = sorted({cat['name'] for cat in cats})
        if random.random() < Config.PROMPT_DROPOUT_PROB:
            text_prompt = ""
        elif cat_names:
            text_prompt = f"A photorealistic image containing {', '.join(cat_names)}"
        else:
            text_prompt = "A photorealistic image"
        return {
            "pixel_values": self.image_transforms(image),
            "conditioning_pixel_values": self.cond_transforms(control_image),
            "input_ids": self.tokenizer(
                text_prompt, max_length=self.tokenizer.model_max_length,
                padding="max_length", truncation=True, return_tensors="pt"
            ).input_ids[0],
            "raw_prompt": text_prompt
        }

# Validation helper
def log_validation(accelerator, controlnet, unet, vae, text_encoder, tokenizer, val_batch, step):
    """Log ControlNet validation samples to Weights & Biases (main process only).

    For the first ``Config.LOG_BATCH_SIZE`` items of *val_batch* this logs four
    images each: the segmentation map, the ground-truth photo, a sample with
    ``controlnet_conditioning_scale=1.0`` and a sample with ``0.0``. Both
    samples use the same seed, so the effect of the conditioning can be compared.

    Validation is best-effort: any exception is printed and swallowed, so a
    failed validation never interrupts training.

    Args:
        accelerator: the training ``Accelerator``; non-main processes return immediately.
        controlnet, unet, vae, text_encoder, tokenizer: components reused to
            build a temporary ``StableDiffusionControlNetPipeline``.
        val_batch: a collated dataset batch with ``pixel_values`` (in [-1, 1]),
            ``conditioning_pixel_values`` (in [0, 1]) and ``raw_prompt``.
        step: global optimisation step, logged alongside the images.
    """
    if not accelerator.is_main_process:
        return
    pipeline = None
    try:
        pipeline = StableDiffusionControlNetPipeline.from_pretrained(
            Config.MODEL_ID,
            vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
            unet=unet, controlnet=controlnet,
            safety_checker=None, torch_dtype=torch.float16
        ).to(accelerator.device)
        pipeline.set_progress_bar_config(disable=True)
        to_pil = transforms.ToPILImage()
        log_images = []
        num_samples = min(len(val_batch["raw_prompt"]), Config.LOG_BATCH_SIZE)
        for i in range(num_samples):
            # Prompt dropout may have blanked the prompt; use a generic one instead.
            prompt = val_batch["raw_prompt"][i] or "A photorealistic image of the scene"
            # These tensors are only turned into PIL images, so there is no need
            # to move them to the GPU or down-cast them to fp16 first.
            gt_image = to_pil((val_batch["pixel_values"][i].float() / 2 + 0.5).clamp(0, 1))
            cond_image = to_pil(val_batch["conditioning_pixel_values"][i].float())
            samples = []
            for caption, scale in (("With Control", 1.0), ("No Control", 0.0)):
                # Re-seed for each call so both samples start from identical noise.
                generator = torch.Generator(device=accelerator.device).manual_seed(42 + i)
                with torch.autocast("cuda"):
                    image = pipeline(
                        prompt,
                        image=cond_image,
                        num_inference_steps=20,
                        generator=generator,
                        controlnet_conditioning_scale=scale,
                    ).images[0]
                samples.append(wandb.Image(image, caption=f"#{i} {caption}"))
            log_images.append(wandb.Image(cond_image, caption=f"#{i} Seg Map"))
            log_images.append(wandb.Image(gt_image, caption=f"#{i} Truth"))
            log_images.extend(samples)
        wandb.log({"validation": log_images, "global_step": step})
    except Exception as e:
        print(f"Skipping validation log due to error: {e}")
    finally:
        # Drop the temporary pipeline and release cached CUDA memory even when
        # generation failed part-way through.
        del pipeline
        torch.cuda.empty_cache()

# Main
def main():
    accelerator = Accelerator(
        gradient_accumulation_steps=Config.GRAD_ACCUM_STEPS,
        mixed_precision="fp16",
        log_with="wandb",
    )
    if accelerator.is_main_process:
        try:
            user_secrets = UserSecretsClient()
            wandb.login(key=user_secrets.get_secret("wandb"))
            accelerator.init_trackers("controlnet-coco-seg", config=Config.__dict__)
        except Exception as e:
            print(f"WandB init warning: {e}")
    if accelerator.is_main_process: print("Loading models...")
    tokenizer = AutoTokenizer.from_pretrained(Config.MODEL_ID, subfolder="tokenizer", use_fast=False)
    noise_scheduler = DDPMScheduler.from_pretrained(Config.MODEL_ID, subfolder="scheduler")
    text_encoder = CLIPTextModel.from_pretrained(Config.MODEL_ID, subfolder="text_encoder", torch_dtype=torch.float16)
    vae = AutoencoderKL.from_pretrained(Config.MODEL_ID, subfolder="vae", torch_dtype=torch.float16)
    unet = UNet2DConditionModel.from_pretrained(Config.MODEL_ID, subfolder="unet", torch_dtype=torch.float16)
    controlnet = ControlNetModel.from_unet(unet)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    unet.requires_grad_(False)
    controlnet.train()
    controlnet.enable_gradient_checkpointing()
    unet.enable_gradient_checkpointing()
    if accelerator.is_main_process: print("Loading dataset...")
    dataset = COCOSegmentationDataset(Config.TRAIN_IMG_DIR, Config.TRAIN_ANN_FILE, tokenizer, max_samples=Config.MAX_SAMPLES)
    train_dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=Config.BATCH_SIZE,
        shuffle=True,
        num_workers=Config.NUM_WORKERS,
        pin_memory=True,
        drop_last=True
)
    val_batch = next(iter(train_dataloader))
    optimizer = bnb.optim.AdamW8bit(controlnet.parameters(), lr=Config.LEARNING_RATE)
    controlnet, optimizer, train_dataloader = accelerator.prepare(
        controlnet, optimizer, train_dataloader
)
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device)
    unet.to(accelerator.device)
    global_step = 0
    first_epoch = 0
    if Config.RESUME_FROM_CHECKPOINT:
        if Config.RESUME_FROM_CHECKPOINT == "latest":
            dirs = os.listdir(Config.OUTPUT_DIR)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = dirs[-1] if len(dirs) > 0 else None
            if path:
                accelerator.print(f"Resuming from latest checkpoint: {path}")
                accelerator.load_state(os.path.join(Config.OUTPUT_DIR, path))
                global_step = int(path.split("-")[1])
                first_epoch = global_step // len(train_dataloader)
            else:
                accelerator.print("No checkpoint found. Starting from scratch.")
        else:
            accelerator.print(f"Resuming from checkpoint: {Config.RESUME_FROM_CHECKPOINT}")
            accelerator.load_state(Config.RESUME_FROM_CHECKPOINT)
            global_step = int(Config.RESUME_FROM_CHECKPOINT.split("-")[-1])
            first_epoch = global_step // len(train_dataloader)
    if accelerator.is_main_process: print(f"Starting training from Step {global_step}, Epoch {first_epoch}...")
    for epoch in range(first_epoch, Config.NUM_EPOCHS):
        progress_bar = tqdm(total=len(train_dataloader), disable=not accelerator.is_local_main_process, desc=f"Epoch {epoch}")
        for batch in train_dataloader:
            with accelerator.accumulate(controlnet):
                latents = vae.encode(batch["pixel_values"].to(dtype=torch.float16)).latent_dist.sample()
                latents = latents * vae.config.scaling_factor
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device).long()
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
                encoder_hidden_states = text_encoder(batch["input_ids"])[0]
                control_cond = batch["conditioning_pixel_values"].to(dtype=torch.float16)
                down_res, mid_res = controlnet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states,
                    controlnet_cond=control_cond,
                    return_dict=False,
                )
                model_pred = unet(
                    noisy_latents,
                    timesteps,
                    encoder_hidden_states=encoder_hidden_states,
                    down_block_additional_residuals=[r.to(dtype=torch.float16) for r in down_res],
                    mid_block_additional_residual=mid_res.to(dtype=torch.float16),
                ).sample
                loss = torch.nn.functional.mse_loss(model_pred.float(), noise.float(), reduction="mean")
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()
            if accelerator.sync_gradients:
                global_step += 1
                progress_bar.update(1)
                if accelerator.is_main_process:
                    wandb.log({"train_loss": loss.item(), "global_step": global_step})
                    if global_step % Config.LOG_INTERVAL == 0:
                        log_validation(accelerator, accelerator.unwrap_model(controlnet), unet, vae, text_encoder, tokenizer, val_batch, global_step)
                    if global_step % Config.SAVE_INTERVAL == 0:
                        save_path = os.path.join(Config.OUTPUT_DIR, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
    if accelerator.is_main_process:
        accelerator.unwrap_model(controlnet).save_pretrained(os.path.join(Config.OUTPUT_DIR, "final_controlnet_seg"))
        print("Training Finished.")
        accelerator.end_training()

# Script entry point: run training when this file is executed directly
# (the notebook launches it with `accelerate launch ... train_seg.py`).
if __name__ == "__main__":
    main()

In [None]:
!accelerate launch --multi_gpu --num_processes=2 --mixed-precision=fp16 train_seg.py

# Upload the model to the Hugging Face Hub

In [None]:
from huggingface_hub import HfApi
from kaggle_secrets import UserSecretsClient

# Read the Hugging Face access token from the Kaggle secret named "HF_TOKEN".
token = UserSecretsClient().get_secret("HF_TOKEN")

api = HfApi(token=token)

# Upload a local folder to the model repo in a single commit.
# NOTE(review): training (train_seg.py) saves the final model to
# /kaggle/working/controlnet-coco-seg/final_controlnet_seg, but this uploads
# /kaggle/working/controlnet_hf, which nothing in this notebook creates.
# Confirm that folder is produced elsewhere, or point folder_path at the final model dir.
api.upload_folder(
    folder_path="/kaggle/working/controlnet_hf",
    repo_id="ritishshrirao/Controlnet_SD1.5_coco_segmentation",
    repo_type="model",
    commit_message="Upload ControlNet model"
)


In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
import glob
import cv2
from PIL import Image
from pycocotools.coco import COCO
from diffusers import (
    StableDiffusionControlNetPipeline,
    ControlNetModel,
    UniPCMultistepScheduler,
    UNet2DConditionModel
)

# COCO 2017 validation split (images + instance annotations).
VAL_IMG_DIR = "/kaggle/input/coco-2017-dataset/coco2017/val2017"
VAL_ANN_FILE = "/kaggle/input/coco-2017-dataset/coco2017/annotations/instances_val2017.json"
# Training output directory searched for the final model / checkpoints
# (same path as Config.OUTPUT_DIR in train_seg.py).
OUTPUT_DIR = "/kaggle/working/controlnet-coco-seg"
BASE_MODEL = "runwayml/stable-diffusion-v1-5"

# Number of validation images to generate, and how many go through the pipeline per call.
NUM_SAMPLES = 16
BATCH_SIZE = 16
# Side length (pixels) that ground-truth images and control maps are resized to.
RESOLUTION = 512
# NOTE(review): the in-training validation uses controlnet_conditioning_scale=1.0;
# 5.0 applies much stronger conditioning — confirm this is intentional.
CONTROLNET_COND_SCALE = 5.0

# Unseeded generator: image selection and per-sample seeds differ on every run.
RNG = np.random.default_rng()

def get_latest_checkpoint(output_dir):
    """Locate the trained ControlNet to evaluate.

    Prefers the final ``save_pretrained`` export (``final_controlnet_seg``
    containing a ``config.json``). Otherwise it picks the ``checkpoint-<step>``
    directory with the highest numeric step.

    Returns:
        ``(path, kind)``, where kind is ``"diffusers"`` for the final export
        and ``"accelerator"`` for an Accelerate ``save_state`` directory.

    Raises:
        FileNotFoundError: if *output_dir* is missing or holds no usable checkpoint.
    """
    if not os.path.exists(output_dir):
        raise FileNotFoundError(f"Output directory {output_dir} does not exist.")
    final_path = os.path.join(output_dir, "final_controlnet_seg")
    if os.path.exists(os.path.join(final_path, "config.json")):
        return final_path, "diffusers"
    # Only directories whose suffix is an integer step count. Stray files or
    # names such as "checkpoint-foo" would otherwise crash int().
    checkpoints = []
    for path in glob.glob(os.path.join(output_dir, "checkpoint-*")):
        step = os.path.basename(path).rsplit("-", 1)[-1]
        if os.path.isdir(path) and step.isdigit():
            checkpoints.append((int(step), path))
    if not checkpoints:
        raise FileNotFoundError(f"No checkpoints found in {output_dir}")
    return max(checkpoints)[1], "accelerator"

def load_controlnet(checkpoint_path, type_hint):
    if type_hint == "diffusers":
        print(f"Loading standard Diffusers model from {checkpoint_path}")
        return ControlNetModel.from_pretrained(checkpoint_path, torch_dtype=torch.float16)
    print(f"Loading Accelerator state from {checkpoint_path}")
    unet = UNet2DConditionModel.from_pretrained(BASE_MODEL, subfolder="unet")
    controlnet = ControlNetModel.from_unet(unet)
    possible_weights = os.path.join(checkpoint_path, "pytorch_model.bin")
    if os.path.exists(possible_weights):
        state_dict = torch.load(possible_weights, map_location="cpu")
    else:
        bins = glob.glob(os.path.join(checkpoint_path, "*.bin"))
        if len(bins) > 0:
            state_dict = torch.load(bins[0], map_location="cpu")
        else:
            raise FileNotFoundError(f"Could not find model weights in {checkpoint_path}")
    new_state_dict = {}
    for k, v in state_dict.items():
        new_state_dict[k[7:]] = v if k.startswith("module.") else v
    missing, unexpected = controlnet.load_state_dict(new_state_dict, strict=False)
    print(f"Loaded weights. Missing keys: {len(missing)}, Unexpected keys: {len(unexpected)}")
    return controlnet.to(dtype=torch.float16)

# Load the trained ControlNet: the final diffusers export if present,
# otherwise the newest Accelerate checkpoint.
ckpt_path, ckpt_type = get_latest_checkpoint(OUTPUT_DIR)
controlnet = load_controlnet(ckpt_path, ckpt_type)

print("Setting up Pipeline...")
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL,
    controlnet=controlnet,
    torch_dtype=torch.float16,
    safety_checker=None,
)

pipe.scheduler = UniPCMultistepScheduler.from_config(pipe.scheduler.config)
pipe.set_progress_bar_config(disable=True)
# Model CPU offload moves each sub-model to the GPU only while it runs and
# manages device placement itself. The pipeline is therefore not moved to CUDA
# beforehand; that would load every component onto the GPU first.
pipe.enable_model_cpu_offload()

def generate_color_map(coco):
    """Assign every COCO category id a deterministic RGB colour.

    The colour is the first three bytes of the MD5 digest of the category id's
    decimal string, so a given id always maps to the same colour.

    Args:
        coco: ``COCO`` instance providing ``getCatIds`` / ``loadCats``.

    Returns:
        dict mapping category id -> ``(r, g, b)`` ints in [0, 255].
    """
    import hashlib

    def colour_for(category_id):
        red, green, blue = hashlib.md5(str(category_id).encode()).digest()[:3]
        return (red, green, blue)

    return {cat["id"]: colour_for(cat["id"]) for cat in coco.loadCats(coco.getCatIds())}

def draw_segmentation_map(img_wh, anns, coco, color_map):
    """Render *anns* as a colour-coded segmentation map of size *img_wh*.

    Args:
        img_wh: ``(width, height)`` of the output map.
        anns: COCO annotation dicts for one image.
        coco: ``COCO`` instance used to rasterise each annotation (``annToMask``).
        color_map: category id -> RGB tuple; unknown ids are painted white.

    Returns:
        RGB ``PIL.Image``. Instances are painted largest-area first, so smaller
        objects stay visible on top. Each instance mask is resized to exactly
        ``(width, height)`` with nearest-neighbour interpolation, which
        stretches it if the aspect ratio differs from the source image.
    """
    width, height = img_wh
    canvas = np.zeros((height, width, 3), dtype=np.uint8)
    by_area_desc = sorted(anns, key=lambda a: a.get("area", 0.0), reverse=True)
    for annotation in by_area_desc:
        resized = cv2.resize(
            coco.annToMask(annotation),
            (width, height),
            interpolation=cv2.INTER_NEAREST,
        )
        canvas[resized == 1] = color_map.get(annotation["category_id"], (255, 255, 255))
    return Image.fromarray(canvas)

def build_prompt(coco, anns):
    """Build the text prompt for one image from its annotations' categories.

    Category names are de-duplicated and sorted alphabetically. Without any
    categories, the generic prompt is returned.
    """
    categories = coco.loadCats([annotation["category_id"] for annotation in anns])
    names = sorted({category["name"] for category in categories})
    if names:
        return f"A photorealistic image containing {', '.join(names)}"
    return "A photorealistic image"

print("Loading COCO...")
coco = COCO(VAL_ANN_FILE)
color_map = generate_color_map(coco)

# Pick NUM_SAMPLES distinct validation images that have at least one annotation.
img_ids = coco.getImgIds()
img_ids = [img_id for img_id in img_ids if len(coco.getAnnIds(imgIds=img_id)) > 0]
selected_ids = RNG.choice(img_ids, size=NUM_SAMPLES, replace=False)

results = []
print("Generating in batches...")
# One seed per sample, each driving its own torch.Generator.
seeds = RNG.integers(low=0, high=2**31 - 1, size=NUM_SAMPLES, dtype=np.int64)

for start in range(0, NUM_SAMPLES, BATCH_SIZE):
    end = min(start + BATCH_SIZE, NUM_SAMPLES)
    batch_ids = selected_ids[start:end]
    batch_prompts = []
    batch_control_images = []
    batch_gt_images = []
    batch_generators = []
    for j, img_id in enumerate(batch_ids):
        img_info = coco.loadImgs(int(img_id))[0]
        img_path = os.path.join(VAL_IMG_DIR, img_info["file_name"])
        # NOTE(review): this stretches the photo to RESOLUTION x RESOLUTION,
        # whereas the training dataset resizes the shorter side and centre-crops.
        with Image.open(img_path) as img:
            gt_image = img.convert("RGB").resize((RESOLUTION, RESOLUTION))
        ann_ids = coco.getAnnIds(imgIds=int(img_id))
        anns = coco.loadAnns(ann_ids)
        control_image = draw_segmentation_map((RESOLUTION, RESOLUTION), anns, coco, color_map)
        prompt = build_prompt(coco, anns)
        batch_gt_images.append(gt_image)
        batch_control_images.append(control_image)
        batch_prompts.append(prompt)
        g = torch.Generator(device="cuda").manual_seed(int(seeds[start + j]))
        batch_generators.append(g)
    with torch.inference_mode():
        out = pipe(
            prompt=batch_prompts,
            image=batch_control_images,
            num_inference_steps=20,
            generator=batch_generators,
            controlnet_conditioning_scale=CONTROLNET_COND_SCALE,
        )
    batch_preds = out.images
    results.extend(zip(batch_control_images, batch_gt_images, batch_preds, batch_prompts))
    print(f"Processed {end}/{NUM_SAMPLES}")

print("Plotting...")
# squeeze=False keeps `axes` 2-D (rows x 3) even when NUM_SAMPLES == 1.
fig, axes = plt.subplots(NUM_SAMPLES, 3, figsize=(15, 5 * NUM_SAMPLES), squeeze=False)
for idx, (cond, gt, pred, prompt) in enumerate(results):
    # Only add an ellipsis when the prompt was actually truncated.
    short_prompt = prompt if len(prompt) <= 60 else f"{prompt[:60]}..."
    panels = (
        (cond, "Seg Mask"),
        (gt, "Ground Truth"),
        (pred, f"Pred: {short_prompt}"),
    )
    for col, (panel_image, title) in enumerate(panels):
        ax = axes[idx, col]
        ax.imshow(panel_image)
        ax.set_title(title, fontsize=10)
        ax.axis("off")
plt.tight_layout()
plt.show()