In [None]:
# Memory optimization utilities for Kaggle
import gc
import torch

def clear_memory():
    """Collect Python garbage and, when CUDA is usable, flush its cache.

    After the garbage collector runs, the CUDA caching allocator is emptied
    and the device is synchronized. A confirmation line is printed either way.
    """
    gc.collect()
    gpu_present = torch.cuda.is_available()
    if gpu_present:
        # Same order as before: release cached blocks, then wait for the device.
        for cuda_step in (torch.cuda.empty_cache, torch.cuda.synchronize):
            cuda_step()
    print(" Memory cleared")

def print_memory_usage():
    """Print allocated and reserved CUDA memory; do nothing without CUDA.

    Both figures are byte counts divided by 1024**3 and shown with two decimals.
    """
    if not torch.cuda.is_available():
        return
    bytes_per_unit = 1024**3
    allocated = torch.cuda.memory_allocated() / bytes_per_unit
    reserved = torch.cuda.memory_reserved() / bytes_per_unit
    print(f"GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
    
print(" Memory utilities loaded")

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import os
from pycocotools.coco import COCO
import requests
from tqdm import tqdm
import json
import cv2
import zipfile


class COCOPoseDataset(Dataset):
    """Pair COCO images with person keypoints, a rendered pose map and a caption.

    The sample list is driven by a captions JSON file that maps image file
    names (for example ``000000000139.jpg``) to caption strings. The numeric
    stem of each file name is used as the COCO image id. Images missing from
    ``img_dir`` are downloaded on demand from the ``coco_url`` stored in the
    COCO annotations.
    """

    # Keypoint-index pairs joined by a line when both endpoints are labelled.
    SKELETON = (
        (0, 1), (0, 2), (1, 3), (2, 4),
        (0, 5), (0, 6),
        (5, 7), (7, 9),
        (6, 8), (8, 10),
        (5, 11), (6, 12),
        (11, 12),
        (11, 13), (13, 15),
        (12, 14), (14, 16),
    )

    # How many times one image is opened (re-downloading in between) before
    # the sample is considered unusable.
    IMAGE_LOAD_ATTEMPTS = 3

    # Upper bound on how many following samples are tried when an image cannot
    # be loaded, so an offline run fails with an error instead of looping.
    MAX_FALLBACK_SAMPLES = 16

    def __init__(
        self,
        root_dir='./coco_data',
        split='train',
        transform=None,
        image_size=512,
        custom_captions_file=None,
        max_samples=None,
        download=True
    ):
        """
        Dataset that pairs COCO images with ground-truth human pose keypoints
        and externally generated, image-aligned captions.

        Args:
            root_dir: Directory holding ``annotations/`` and ``<split>2017/``.
            split: ``'train'`` selects the train annotations; any other value
                selects the val annotations.
            transform: Optional callable applied to the RGB PIL image. When
                omitted, the image is resized to ``image_size`` square,
                converted to a tensor and normalized with mean/std 0.5.
            image_size: Side length used for the default image resize and for
                the pose map resize.
            custom_captions_file: Path to a JSON object mapping image file
                names to captions. Required.
            max_samples: When truthy, keep only the first ``max_samples``
                entries of the captions file.
            download: Download the COCO annotation archive when the keypoint
                annotation file is missing.

        Raises:
            ValueError: No captions file was given, or a file name in it does
                not start with a numeric image id.
            FileNotFoundError: The captions file or the COCO annotation file
                does not exist.
        """
        # Validate the cheap preconditions before touching the file system.
        if not custom_captions_file:
            raise ValueError("Custom captions JSON file is required.")
        if not os.path.exists(custom_captions_file):
            raise FileNotFoundError(f"Caption file not found: {custom_captions_file}")

        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.image_size = image_size

        # Train and val keypoints ship in the same archive.
        self.annotation_urls = {
            'train': 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
            'val': 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip'
        }

        self.ann_dir = os.path.join(root_dir, 'annotations')
        self.img_dir = os.path.join(root_dir, f'{split}2017')

        split_name = 'train' if split == 'train' else 'val'
        self.ann_file = os.path.join(
            self.ann_dir, f'person_keypoints_{split_name}2017.json'
        )

        os.makedirs(self.ann_dir, exist_ok=True)
        os.makedirs(self.img_dir, exist_ok=True)

        with open(custom_captions_file, 'r', encoding='utf-8') as f:
            self.custom_captions = json.load(f)

        if download and not os.path.exists(self.ann_file):
            self._download_annotations()

        if not os.path.exists(self.ann_file):
            raise FileNotFoundError(f"COCO annotation file missing: {self.ann_file}")

        self.coco = COCO(self.ann_file)
        # Looked up once here instead of on every __getitem__ call.
        self._person_cat_ids = self.coco.getCatIds(catNms=['person'])

        # img_ids and img_filenames are parallel lists: position i of one
        # always describes the same sample as position i of the other.
        self.img_ids = []
        self.img_filenames = []
        self.custom_caption_map = {}

        items = list(self.custom_captions.items())
        if max_samples:
            items = items[:max_samples]

        for img_filename, caption in items:
            img_id = self._parse_image_id(img_filename)
            self.img_ids.append(img_id)
            self.img_filenames.append(img_filename)
            self.custom_caption_map[img_id] = caption

    @staticmethod
    def _parse_image_id(img_filename):
        """Return the integer image id encoded before the first dot of a file name."""
        stem = img_filename.split('.')[0]
        try:
            return int(stem.lstrip('0') or '0')
        except ValueError as err:
            raise ValueError(
                f"Cannot derive a COCO image id from file name: {img_filename!r}"
            ) from err

    def _download_annotations(self):
        """Download the annotation archive, extract it into ``root_dir``, delete the zip.

        Raises:
            requests.RequestException: The download failed or returned an
                HTTP error status.
        """
        # Splits other than 'train' use the val annotations everywhere else in
        # this class, so fall back to the val URL rather than raising KeyError.
        url = self.annotation_urls.get(self.split, self.annotation_urls['val'])
        zip_path = os.path.join(self.root_dir, 'annotations.zip')

        try:
            with requests.get(url, stream=True, timeout=30) as response:
                # Without this an HTML error page would be saved as the zip.
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))

                with open(zip_path, 'wb') as f:
                    with tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
                        for chunk in response.iter_content(chunk_size=1 << 20):
                            f.write(chunk)
                            pbar.update(len(chunk))

            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(self.root_dir)
        finally:
            # Never leave a partial or already-extracted archive behind.
            if os.path.exists(zip_path):
                os.remove(zip_path)

    def _download_image(self, img_id, img_filename):
        """Ensure the image exists under ``img_dir`` and return its path.

        The file is fetched from the ``coco_url`` of the COCO image record when
        it is not already on disk.
        """
        img_path = os.path.join(self.img_dir, img_filename)
        if os.path.exists(img_path):
            return img_path

        img_info = self.coco.loadImgs(img_id)[0]
        img_url = img_info['coco_url']

        response = requests.get(img_url, timeout=10)
        response.raise_for_status()

        with open(img_path, 'wb') as f:
            f.write(response.content)

        return img_path

    def _load_image(self, img_id, img_filename):
        """Open one image as RGB, re-downloading it when it is missing or unreadable.

        Raises:
            RuntimeError: Every attempt failed; the last underlying error is
                attached as the cause.
        """
        img_path = os.path.join(self.img_dir, img_filename)
        last_error = None

        for _ in range(self.IMAGE_LOAD_ATTEMPTS):
            try:
                if not os.path.exists(img_path):
                    self._download_image(img_id, img_filename)
                # The context manager closes the file handle; convert() returns
                # an independent, fully loaded copy.
                with Image.open(img_path) as handle:
                    return handle.convert('RGB')
            except Exception as err:
                # Deliberately broad: any download or decode failure makes the
                # cached file suspect, so drop it and try again from scratch.
                last_error = err
                if os.path.exists(img_path):
                    os.remove(img_path)

        raise RuntimeError(f"Could not load image {img_filename!r}") from last_error

    def __len__(self):
        """Return the number of captioned samples."""
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Return the sample at ``idx``, or the next loadable one after it.

        Returns:
            dict with keys ``image`` (transformed image), ``pose`` (PIL pose
            map resized to ``image_size``), ``raw_keypoints`` (``(K, 3)``
            array of x, y, visibility in original image coordinates; all zeros
            with K=17 when the image has no labelled person), ``image_id`` and
            ``captions`` (single-element list).

        Raises:
            IndexError: ``idx`` is out of range.
            RuntimeError: Neither this sample nor the following ones could be
                loaded.
        """
        # Index first so an out-of-range idx raises IndexError as usual.
        img_id = self.img_ids[idx]
        total = len(self.img_ids)
        start = idx % total

        image = None
        last_error = None
        for offset in range(min(total, self.MAX_FALLBACK_SAMPLES)):
            position = (start + offset) % total
            img_id = self.img_ids[position]
            try:
                image = self._load_image(img_id, self.img_filenames[position])
                break
            except RuntimeError as err:
                last_error = err

        if image is None:
            raise RuntimeError(
                f"No loadable image found starting at index {idx}"
            ) from last_error

        width, height = image.size
        caption = self.custom_caption_map[img_id]

        ann_ids = self.coco.getAnnIds(
            imgIds=img_id,
            catIds=self._person_cat_ids,
            iscrowd=False
        )
        anns = self.coco.loadAnns(ann_ids)

        # Use the first person annotation that has at least one labelled keypoint.
        keypoints = None
        for ann in anns:
            if ann.get('num_keypoints', 0) > 0:
                keypoints = np.array(ann['keypoints']).reshape(-1, 3)
                break

        if keypoints is None:
            keypoints = np.zeros((17, 3))

        # The skeleton is drawn at the original resolution and resized afterwards.
        pose_skeleton = self.create_pose_skeleton(keypoints, width, height)

        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.Compose([
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5])
            ])(image)

        pose_map = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((self.image_size, self.image_size))
        ])(pose_skeleton)

        return {
            'image': image,
            'pose': pose_map,
            'raw_keypoints': keypoints,
            'image_id': img_id,
            'captions': [caption]
        }

    def create_pose_skeleton(self, keypoints, width, height):
        """
        Render COCO keypoints as a white-on-black stick-figure pose map.

        Only keypoints with a visibility flag above zero contribute. Lines are
        drawn anti-aliased, so the result is grayscale rather than strictly
        binary.

        Args:
            keypoints: Array of ``(x, y, visibility)`` rows indexed as in SKELETON.
            width: Width of the output map in pixels.
            height: Height of the output map in pixels.

        Returns:
            ``uint8`` array of shape ``(height, width)``.
        """

        pose_img = np.zeros((height, width), dtype=np.uint8)

        # Stroke sizes scale with the shorter image side.
        line_thickness = max(2, int(min(width, height) / 100))
        circle_radius = max(3, int(min(width, height) / 80))

        for s, e in self.SKELETON:
            x1, y1, v1 = keypoints[s]
            x2, y2, v2 = keypoints[e]
            if v1 > 0 and v2 > 0:
                cv2.line(
                    pose_img,
                    (int(x1), int(y1)),
                    (int(x2), int(y2)),
                    255,
                    line_thickness,
                    cv2.LINE_AA
                )

        for x, y, v in keypoints:
            if v > 0:
                cv2.circle(
                    pose_img,
                    (int(x), int(y)),
                    circle_radius,
                    255,
                    -1
                )

        return pose_img


# Settings shared by both dataset splits.
_dataset_settings = dict(root_dir='./coco_data', image_size=512, download=True)

train_dataset = COCOPoseDataset(
    split='train',
    custom_captions_file='/kaggle/input/train_captions.json',
    **_dataset_settings,
)
val_dataset = COCOPoseDataset(
    split='val',
    custom_captions_file='/kaggle/input/val_captions.json',
    **_dataset_settings,
)

# NOTE(review): these loaders use the default collate function while each
# sample's 'pose' entry is a PIL image, which the default collate presumably
# rejects when a batch is drawn - confirm before iterating them.
_loader_settings = dict(batch_size=4, num_workers=0)
train_loader = DataLoader(train_dataset, shuffle=True, **_loader_settings)
val_loader = DataLoader(val_dataset, shuffle=False, **_loader_settings)

for _label, _dataset in (
    ("Training samples   ", train_dataset),
    ("Validation samples ", val_dataset),
):
    print(f"{_label}: {len(_dataset)}")

# Smoke test: load the first training sample and count its labelled keypoints.
if len(train_dataset) > 0:
    sample = train_dataset[0]
    _visibility_flags = sample['raw_keypoints'][:, 2]
    visible_kpts = (_visibility_flags > 0).sum()
    print(f"Visible keypoints in sample: {visible_kpts}")


In [None]:
import matplotlib.pyplot as plt


# Visual sanity check of one training sample: shows the image, its pose map,
# an overlay of both, the caption, COCO metadata and the labelled keypoints.
sample = train_dataset[0]
img_id = sample['image_id']

# Get annotations for this image from COCO
ann_ids = train_dataset.coco.getAnnIds(imgIds=img_id)
anns = train_dataset.coco.loadAnns(ann_ids)
img_info = train_dataset.coco.loadImgs(img_id)[0]

# Get captions (text prompts)
captions = sample['captions']

# Build metadata text
metadata = f"Image ID: {img_id}\nFilename: {img_info['file_name']}\n"
metadata += f"Size: {img_info['width']}x{img_info['height']}\n"
metadata += f"Person annotations: {len([a for a in anns if a.get('category_id') == 1])}\n"

# Describe only the first annotation that has at least one labelled keypoint.
for i, ann in enumerate(anns):
    if 'keypoints' in ann and ann.get('num_keypoints', 0) > 0:
        metadata += f"\nPerson {i+1}: {ann['num_keypoints']} keypoints"
        if 'area' in ann:
            metadata += f", area: {int(ann['area'])}"
        break

# Convert tensors back to displayable format
# (channels-first tensor -> channels-last array, as imshow expects)
image = sample['image'].permute(1, 2, 0).cpu().numpy()
image = (image * 0.5 + 0.5)  # Denormalize from [-1, 1] to [0, 1]
image = np.clip(image, 0, 1)

# Convert PIL Image to numpy array
pose_map = np.array(sample['pose'])


# 3x3 grid: the image spans the two upper rows of column 0; the remaining
# cells hold the pose map, the overlay, the caption and the text panels.
fig = plt.figure(figsize=(16, 10))
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

# Original image (larger)
ax1 = fig.add_subplot(gs[0:2, 0])
ax1.imshow(image)
ax1.set_title('Original Image', fontsize=14, fontweight='bold')
ax1.axis('off')

# Pose skeleton
ax2 = fig.add_subplot(gs[0, 1])
ax2.imshow(pose_map, cmap='gray')
ax2.set_title('Pose Skeleton', fontsize=14, fontweight='bold')
ax2.axis('off')

# Overlay
ax3 = fig.add_subplot(gs[0, 2])
ax3.imshow(image)
ax3.imshow(pose_map, cmap='hot', alpha=0.5)
ax3.set_title('Overlay', fontsize=14, fontweight='bold')
ax3.axis('off')

# Text Captions/Prompts - TOP SECTION
ax_captions = fig.add_subplot(gs[1, 1:3])
ax_captions.axis('off')
ax_captions.text(0.05, 0.95, ' TEXT PROMPTS (Custom Captions):', 
                fontsize=13, fontweight='bold', va='top')

# Number the captions; fall back to a placeholder when there are none.
if captions:
    caption_text = "\n\n".join([f"{i+1}. {cap}" for i, cap in enumerate(captions)])
else:
    caption_text = "No captions available for this image."

ax_captions.text(0.05, 0.80, caption_text, 
                fontsize=10, va='top', wrap=True,
                bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))

# Metadata
ax_meta = fig.add_subplot(gs[2, 0])
ax_meta.axis('off')
ax_meta.text(0.05, 0.95, 'Image Metadata:', 
            fontsize=12, fontweight='bold', va='top')
ax_meta.text(0.05, 0.75, metadata, 
            fontsize=10, va='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# Pose keypoint details
ax_kp = fig.add_subplot(gs[2, 1:3])
ax_kp.axis('off')

# Names listed in the row order of sample['raw_keypoints'].
keypoint_names = [
    'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear',
    'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
    'left_wrist', 'right_wrist', 'left_hip', 'right_hip',
    'left_knee', 'right_knee', 'left_ankle', 'right_ankle'
]

# List every keypoint whose visibility flag is non-zero; a flag of 2 is
# reported as "visible", any other non-zero flag as "occluded".
keypoint_text = "Keypoint Details:\n" + "─" * 40 + "\n"
for i, (x, y, v) in enumerate(sample['raw_keypoints']):
    if v > 0:  
        visibility = "visible" if v == 2 else "occluded"
        keypoint_text += f"{keypoint_names[i]:15s}: ({int(x):3d}, {int(y):3d}) - {visibility}\n"

ax_kp.text(0.05, 0.95, keypoint_text, 
          fontsize=9, va='top', family='monospace')

plt.show()


# Text summary of the same sample.
print(f"Image shape: {sample['image'].shape}")
print(f"Pose skeleton shape: {pose_map.shape}")
print(f"Number of visible keypoints: {(sample['raw_keypoints'][:, 2] > 0).sum()}")
print(f"Number of captions: {len(captions)}")



# ControlNet Training Setup

This section sets up and trains a ControlNet model for pose-guided image generation using:
- **Spatial Conditioning**: Pose skeleton (stick figure from COCO keypoints)
- **Text Conditioning**: Your Custom Captions
- **Base Model**: Stable Diffusion v1.5

The training uses your custom-captioned images with COCO pose annotations.


In [None]:
import torch
import torch.nn.functional as F
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, DDPMScheduler, UNet2DConditionModel, AutoencoderKL
from diffusers.optimization import get_scheduler
from transformers import CLIPTextModel, CLIPTokenizer
from accelerate import Accelerator
from tqdm.auto import tqdm
import os
from datetime import datetime



In [None]:

class TrainingConfig:
    """Static settings for the ControlNet fine-tuning run, held as class attributes."""

    # ---- Model ----
    pretrained_model_name = "runwayml/stable-diffusion-v1-5"
    controlnet_conditioning_channels = 1  # Grayscale pose skeleton
    resolution = 512

    # ---- Data ----
    train_captions_file = '/kaggle/input/train_captions.json'
    num_training_samples = None  # None is reported as "ALL" in the summary printout
    caption_dropout_prob = 0.5  # probability of swapping a caption for ""

    # ---- Optimisation ----
    num_epochs = 20
    train_batch_size = 1
    gradient_accumulation_steps = 8
    learning_rate = 1e-4
    lr_scheduler_type = "cosine"
    lr_warmup_steps = 375

    # ---- Memory / precision ----
    mixed_precision = "fp16"
    gradient_checkpointing = True
    use_8bit_optimizer = False

    # ---- Output and logging ----
    output_dir = "/kaggle/working/controlnet_pose_output"
    logging_dir = "./logs"
    report_to = "tensorboard"

    # ---- Validation ----
    validate_every_n_epochs = 1
    validation_steps = 500
    num_validation_images = 4
    validation_prompt = "a person standing"

# Instantiate the settings and make sure the output locations exist.
config = TrainingConfig()
for _required_dir in (config.output_dir, config.logging_dir):
    os.makedirs(_required_dir, exist_ok=True)

# Values derived once for the summary below.
_sample_scope = config.num_training_samples or 'ALL (COMPLETE DATASET)'
_effective_batch = config.train_batch_size * config.gradient_accumulation_steps
_eight_bit_flag = ' ENABLED' if config.use_8bit_optimizer else ' DISABLED'
_checkpointing_flag = ' ENABLED' if config.gradient_checkpointing else ' DISABLED'
_dropout_percent = config.caption_dropout_prob * 100

print(f"Output directory: {config.output_dir}")
print(f"Training {_sample_scope} samples for {config.num_epochs} epochs")
print(" Using ALL available training data points")

print("\nBatch Configuration:")
print(f"  - Batch size: {config.train_batch_size}")
print(f"  - Gradient accumulation: {config.gradient_accumulation_steps}")
print(f"  - Effective batch size: {_effective_batch}")

print("\nOptimizer & Memory:")
print(f"  - Learning rate: {config.learning_rate}")
print(f"  - Mixed precision: {config.mixed_precision}")
print(f"  - 8-bit optimizer: {_eight_bit_flag}")
print(f"  - Gradient checkpointing: {_checkpointing_flag}")

print("\nTraining Strategy:")
print(f"  - Caption dropout: {_dropout_percent:.0f}% (enables unconditional generation)")
print(f"  - Validation: Every {config.validate_every_n_epochs} epoch(s)")

In [None]:
import random
from PIL import Image
import torchvision.transforms as transforms

# Define collate function with caption dropout
def collate_fn(batch):
    """Merge dataset samples into one training batch, with caption dropout.

    Args:
        batch: Sequence of sample dicts providing 'image' (tensor), 'pose'
            (PIL image) and 'captions' (list of strings).

    Returns:
        dict with "images" and "poses" as stacked tensors and "captions" as a
        list of strings, one per sample.
    """
    to_tensor = transforms.ToTensor()
    image_tensors, pose_tensors, texts = [], [], []

    for item in batch:
        image_tensors.append(item['image'])

        # ToTensor yields values in [0, 1]; rescale the pose map to [-1, 1].
        pose_tensors.append(to_tensor(item['pose']) * 2 - 1)

        # Take the first caption if there is one, then blank it with
        # probability config.caption_dropout_prob.
        available = item['captions']
        text = available[0] if available else ""
        if random.random() < config.caption_dropout_prob:
            text = ""
        texts.append(text)

    return {
        "images": torch.stack(image_tensors),
        "poses": torch.stack(pose_tensors),
        "captions": texts
    }

# Training loader: shuffled, loaded in the main process, batched by collate_fn.
train_dataloader = DataLoader(
    dataset=train_dataset,
    collate_fn=collate_fn,
    batch_size=config.train_batch_size,
    num_workers=0,
    shuffle=True,
)

_batches_per_epoch = len(train_dataloader)
_dropout_share = config.caption_dropout_prob * 100

print(f"Total training samples: {len(train_dataset)} (USING ALL AVAILABLE DATA)")
print(f"Batches per epoch: {_batches_per_epoch}")
print(f"Total training iterations: {_batches_per_epoch * config.num_epochs}")
print(f"Caption dropout: {_dropout_share:.0f}% of samples use empty captions")
print("Pose conditioning normalized to [-1, 1] range")


In [None]:


def _count_parameters(module, trainable_only=False):
    """Sum the element counts of a module's parameters."""
    return sum(
        p.numel()
        for p in module.parameters()
        if p.requires_grad or not trainable_only
    )


# Every component is read from its own subfolder of the same base checkpoint.
_base_checkpoint = config.pretrained_model_name

tokenizer = CLIPTokenizer.from_pretrained(_base_checkpoint, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(_base_checkpoint, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(_base_checkpoint, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(_base_checkpoint, subfolder="unet")

# The ControlNet is derived from the loaded UNet, with the configured number
# of conditioning channels.
print("Initializing ControlNet...")
controlnet = ControlNetModel.from_unet(
    unet,
    conditioning_channels=config.controlnet_conditioning_channels
)

noise_scheduler = DDPMScheduler.from_pretrained(_base_checkpoint, subfolder="scheduler")

# Only the ControlNet is trained; the other networks are frozen.
for _frozen_model in (vae, text_encoder, unet):
    _frozen_model.requires_grad_(False)

print(" Models loaded successfully")
print(f"  - ControlNet parameters: {_count_parameters(controlnet, trainable_only=True):,}")
print(f"  - Text encoder (frozen): {_count_parameters(text_encoder):,}")
print(f"  - UNet (frozen): {_count_parameters(unet):,}")

# Clear memory after model loading (important for Kaggle)
clear_memory()
print_memory_usage()

In [None]:

# AdamW hyperparameters shared by the 8-bit and the standard implementation.
_adamw_settings = dict(
    lr=config.learning_rate,
    betas=(0.9, 0.999),
    weight_decay=1e-3,
    eps=1e-8,
)

if config.use_8bit_optimizer:
    #  8-bit optimizer from bitsandbytes (memory efficient!)
    import bitsandbytes as bnb
    print(f"\n Initializing 8-bit optimizer from bitsandbytes...")
    optimizer = bnb.optim.AdamW8bit(controlnet.parameters(), **_adamw_settings)
    print(f" 8-bit optimizer loaded successfully!")
else:
    print(f"\n Initializing standard AdamW optimizer...")
    optimizer = torch.optim.AdamW(controlnet.parameters(), **_adamw_settings)
    print(f" Standard AdamW optimizer loaded successfully!")

# The training loop steps the scheduler once per optimizer update, i.e. once
# every `gradient_accumulation_steps` micro-batches. The schedule length must
# therefore be counted in optimizer updates; scaling it up by the accumulation
# factor would stretch the warmup and keep the cosine decay from ever
# completing. (Single-process run assumed, as in this notebook.)
num_update_steps_per_epoch = -(-len(train_dataloader) // config.gradient_accumulation_steps)  # ceil division
max_train_steps = num_update_steps_per_epoch * config.num_epochs

# Warmup can never be longer than the whole run.
num_warmup_steps = min(config.lr_warmup_steps, max_train_steps)

lr_scheduler = get_scheduler(
    config.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=max_train_steps,
)

print(f"Optimizer type: {'8-bit AdamW (bitsandbytes)' if config.use_8bit_optimizer else 'Standard AdamW'}")
print(f"Learning rate: {config.learning_rate}")
print(f"Warmup steps: {num_warmup_steps}")
print(f"Total training steps: {max_train_steps}")
print(f"Learning rate scheduler: {config.lr_scheduler_type} with warmup")

In [None]:

# Accelerate handles gradient accumulation, mixed precision and logging setup.
accelerator = Accelerator(
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    mixed_precision=config.mixed_precision,
    log_with=config.report_to,
    project_dir=config.logging_dir,
)

if config.gradient_checkpointing:
    controlnet.enable_gradient_checkpointing()

# Only the trainable pieces go through prepare(); the frozen networks are
# moved to the device by hand and switched to eval mode.
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    controlnet, optimizer, train_dataloader, lr_scheduler
)

for _frozen_net in (unet, vae, text_encoder):
    _frozen_net.to(accelerator.device)
for _frozen_net in (unet, vae, text_encoder):
    _frozen_net.eval()

#MULTI-GPU SUPPORT: Use DataParallel for multiple GPUs
# NOTE(review): this wraps models in DataParallel after accelerator.prepare();
# whether that combination behaves as intended (especially with a batch size
# of 1) should be confirmed on a multi-GPU machine.
num_gpus = torch.cuda.device_count()
if num_gpus <= 1:
    print(f"\n SINGLE GPU MODE: Using {accelerator.device}")
else:
    print(f"\n MULTI-GPU DETECTED: {num_gpus} GPUs available")
    print(f"   GPU Details: {torch.cuda.get_device_name(0)}")
    controlnet = torch.nn.DataParallel(controlnet)
    unet = torch.nn.DataParallel(unet)
    vae = torch.nn.DataParallel(vae)
    text_encoder = torch.nn.DataParallel(text_encoder)
    print(f"    Models wrapped with DataParallel for distributed training")


In [None]:
import math

import os

import torch

import torch.nn.functional as F

from tqdm.auto import tqdm


def train_controlnet():
    """Fine-tune the global ``controlnet`` against the frozen Stable Diffusion parts.

    Each micro-batch is processed as follows: the image is encoded to latents
    with the VAE, noise is added at a random timestep, the caption is encoded
    with the text encoder, the ControlNet produces residuals from the noisy
    latents and the pose map, the frozen UNet predicts the noise using those
    residuals, and the MSE between predicted and true noise is backpropagated
    into the ControlNet only.

    Relies on these notebook globals: ``config``, ``accelerator``,
    ``controlnet``, ``unet``, ``vae``, ``text_encoder``, ``tokenizer``,
    ``noise_scheduler``, ``optimizer``, ``lr_scheduler`` and
    ``train_dataloader``.

    Side effects:
        - Writes ``epoch-<n>`` after every epoch and removes the previous
          epoch's directory.
        - Writes ``controlnet_final`` once training finishes.
        - Saves a full accelerator state every ``config.checkpointing_steps``
          optimizer steps, if that attribute is defined on the config.

    Returns:
        str: Path of the ``controlnet_final`` directory.
    """
    import shutil  # only needed here, for rotating the per-epoch checkpoints

    def unwrap(module):
        """Return the wrapped module when given a ``torch.nn.DataParallel``."""
        return module.module if isinstance(module, torch.nn.DataParallel) else module

    os.makedirs(config.output_dir, exist_ok=True)

    # accelerator.log() only reaches trackers created by init_trackers();
    # without this call the logging below writes nothing.
    accelerator.init_trackers("controlnet_pose")

    device = accelerator.device
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Cast frozen models to weight_dtype for memory efficiency
    vae_model = unwrap(vae)
    text_encoder_model = unwrap(text_encoder)
    unet_model = unwrap(unet)

    vae_model.to(device, dtype=weight_dtype)
    text_encoder_model.to(device, dtype=weight_dtype)
    unet_model.to(device, dtype=weight_dtype)

    # One optimizer update happens per `gradient_accumulation_steps` micro-batches.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.gradient_accumulation_steps)
    max_train_steps = config.num_epochs * num_update_steps_per_epoch

    progress_bar = tqdm(
        range(max_train_steps),
        desc="Steps",
        disable=not accelerator.is_local_main_process
    )

    accelerator.print(f" TRAINING STARTING (Multi-GPU Enabled)")
    accelerator.print(f"Total Epochs: {config.num_epochs}")
    accelerator.print(f"GPUs: {torch.cuda.device_count()}")
    accelerator.print(f"Precision: {weight_dtype}")
    accelerator.print("Validation: disabled during training")

    global_step = 0

    for epoch in range(config.num_epochs):
        controlnet.train()
        train_loss = 0.0       # loss accumulated over the current accumulation window
        epoch_loss_sum = 0.0   # sum of micro-batch losses over the whole epoch
        epoch_loss_count = 0   # number of micro-batches seen this epoch

        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(controlnet):
                pixel_values = batch["images"].to(device, dtype=torch.float32)
                controlnet_image = batch["poses"].to(device, dtype=torch.float32)
                captions = batch["captions"]

                with accelerator.autocast():
                    # Encode images into scaled VAE latents (no gradients needed).
                    with torch.no_grad():
                        latents = vae_model.encode(pixel_values.to(dtype=weight_dtype)).latent_dist.sample()
                        latents = latents * vae_model.config.scaling_factor

                    # Add noise at a random timestep for each sample.
                    noise = torch.randn_like(latents)
                    bsz = latents.shape[0]
                    timesteps = torch.randint(
                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=device
                    ).long()

                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                    # Text conditioning from the frozen text encoder.
                    with torch.no_grad():
                        inputs = tokenizer(
                            captions,
                            max_length=tokenizer.model_max_length,
                            padding="max_length",
                            truncation=True,
                            return_tensors="pt"
                        )
                        encoder_hidden_states = text_encoder_model(inputs.input_ids.to(device))[0]

                    # ControlNet forward
                    down_block_res_samples, mid_block_res_sample = controlnet(
                        noisy_latents,
                        timesteps,
                        encoder_hidden_states=encoder_hidden_states,
                        controlnet_cond=controlnet_image,
                        return_dict=False,
                    )

                    # UNet forward
                    model_pred = unet_model(
                        noisy_latents,
                        timesteps,
                        encoder_hidden_states=encoder_hidden_states,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                    ).sample

                    # The training target is the noise that was added.
                    loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")

                # Track the loss of every micro-batch, so the epoch average
                # covers all samples rather than one per accumulation window.
                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()
                step_loss = avg_loss.item()
                train_loss += step_loss / config.gradient_accumulation_steps
                epoch_loss_sum += step_loss
                epoch_loss_count += 1

                accelerator.backward(loss)

                # True only on the micro-batch that completes an accumulation window.
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(controlnet.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

                    progress_bar.update(1)
                    global_step += 1

                    logs = {"loss": train_loss, "lr": lr_scheduler.get_last_lr()[0]}
                    progress_bar.set_postfix(**logs)
                    accelerator.log(logs, step=global_step)
                    train_loss = 0.0

                    # Optional: only active when the config defines checkpointing_steps.
                    if hasattr(config, 'checkpointing_steps') and global_step % config.checkpointing_steps == 0:
                        save_path = os.path.join(config.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)

        accelerator.wait_for_everyone()
        if accelerator.is_main_process:

            if epoch_loss_count > 0:
                epoch_loss_avg = epoch_loss_sum / epoch_loss_count
                accelerator.print(f" Epoch {epoch+1}/{config.num_epochs} - Training Loss: {epoch_loss_avg:.6f}")
                # Log epoch loss for dashboards (e.g., TensorBoard)
                accelerator.log({"train_loss_epoch": epoch_loss_avg}, step=global_step)

            # Get the unwrapped model for saving
            controlnet_unwrapped = unwrap(controlnet)
            save_path = os.path.join(config.output_dir, f"epoch-{epoch+1}")
            controlnet_unwrapped.save_pretrained(save_path)
            accelerator.print(f" Epoch {epoch+1} Saved: {save_path}")

            # Keep only the newest epoch checkpoint to limit disk usage.
            if epoch > 0:
                prev_epoch_path = os.path.join(config.output_dir, f"epoch-{epoch}")
                if os.path.exists(prev_epoch_path):
                    shutil.rmtree(prev_epoch_path)
                    accelerator.print(f" Cleaned up previous checkpoint: epoch-{epoch}")

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Final Save
    final_path = os.path.join(config.output_dir, "controlnet_final")
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        unwrap(controlnet).save_pretrained(final_path)
        accelerator.print(f"Training Complete. Final model saved to {final_path}")

    # Finish the trackers so buffered log events are written out.
    accelerator.end_training()

    return final_path


In [None]:
# Run the training and print wall-clock timestamps before and after it.
_TIMESTAMP_FORMAT = '%Y-%m-%d %H:%M:%S'

print(" Starting training...")
_started_at = datetime.now().strftime(_TIMESTAMP_FORMAT)
print(f" Start time: {_started_at}\n")

trained_controlnet = train_controlnet()

_finished_at = datetime.now().strftime(_TIMESTAMP_FORMAT)
print(f"\n End time: {_finished_at}")

## Test the Trained ControlNet

Generate images using the trained ControlNet with pose skeleton conditioning.


In [None]:

from diffusers import StableDiffusionControlNetPipeline
from PIL import Image


# Free whatever training left behind before loading the inference pipeline.
clear_memory()
print_memory_usage()

print("Loading trained ControlNet pipeline...")

# Load the final ControlNet weights written to the output directory.
_inference_dtype = torch.float16
controlnet_path = os.path.join(config.output_dir, "controlnet_final")
controlnet_trained = ControlNetModel.from_pretrained(
    controlnet_path,
    torch_dtype=_inference_dtype,
)

# Create inference pipeline (base checkpoint + trained ControlNet, no safety
# checker) and move it to the accelerator's device.
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    config.pretrained_model_name,
    controlnet=controlnet_trained,
    torch_dtype=_inference_dtype,
    safety_checker=None,
).to(accelerator.device)

# Enable memory efficient attention (optional, if xformers is available)
try:
    pipe.enable_xformers_memory_efficient_attention()
except Exception as e:
    # Best effort: keep going with the default attention implementation.
    print(f"XFormers not available, using default attention: {e}")
    print("   (This is fine, just uses a bit more memory)")
else:
    print(" XFormers memory efficient attention enabled")

print(" Pipeline ready for inference!")

In [None]:

import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import cv2


# Generate one image per validation sample from its pose map and caption, and
# show each result next to the original image and the pose skeleton.
num_test_images = min(75, len(val_dataset))
if num_test_images == 0:
    # Fail early with a clear message; an empty grid and a division by zero
    # would otherwise fail further down with less helpful errors.
    raise RuntimeError("Validation dataset is empty; nothing to generate.")

print(f"\nGenerating results for {num_test_images} validation images...")
print(f"Each will show: [Original Image] [Pose Skeleton] [Generated Image]\n")

# Create figure with subplots for all results
fig = plt.figure(figsize=(20, 5 * num_test_images))
gs = GridSpec(num_test_images, 3, figure=fig, hspace=0.4, wspace=0.2)

# Fixed seeds so repeated runs produce the same images.
torch.manual_seed(42)
generator = torch.Generator(device=accelerator.device).manual_seed(42)

generated_images = []
test_results = []

for idx in range(num_test_images):
    print(f"Processing image {idx+1}/{num_test_images}...")

    test_sample = val_dataset[idx]
    test_image = test_sample['image']
    test_pose = test_sample['pose']
    test_caption = test_sample['captions'][0] if test_sample['captions'] else "a person standing"

    test_results.append({
        'caption': test_caption,
        'image_id': test_sample['image_id'],
        'keypoints': test_sample['raw_keypoints']
    })

    # Convert pose to tensor (values in [0, 1]; this copy is used for display)
    test_pose_tensor = transforms.ToTensor()(test_pose)

    # Prepare pose input: single channel, rescaled to [-1, 1] so the
    # conditioning matches what the training collate function fed the model.
    test_pose_input = (test_pose_tensor * 2 - 1).unsqueeze(0).to(
        accelerator.device, dtype=torch.float16
    )

    # Best effort: a failed generation is recorded as None and the row is
    # still drawn with a "Generation Failed" placeholder.
    try:
        output = pipe(
            prompt=test_caption,
            image=test_pose_input,
            num_inference_steps=20,
            generator=generator,
            guidance_scale=7.5,
        ).images[0]
    except Exception as e:
        print(f"Error generating image: {str(e)[:50]}")
        output = None
    generated_images.append(output)

    # Column 1: Original Image
    ax1 = fig.add_subplot(gs[idx, 0])
    orig_img = test_image.permute(1, 2, 0).cpu().numpy()
    orig_img = (orig_img * 0.5 + 0.5)  # undo the mean/std 0.5 normalization
    orig_img = np.clip(orig_img, 0, 1)
    ax1.imshow(orig_img)
    ax1.set_title(f'Image {idx+1}: Original', fontsize=12, fontweight='bold')
    ax1.axis('off')

    # Column 2: Pose Skeleton
    ax2 = fig.add_subplot(gs[idx, 1])
    ax2.imshow(test_pose_tensor.squeeze().cpu().numpy(), cmap='gray')
    num_keypoints = (test_sample['raw_keypoints'][:, 2] > 0).sum()
    ax2.set_title(f'Pose Skeleton\n({int(num_keypoints)} keypoints visible)',
                  fontsize=12, fontweight='bold')
    ax2.axis('off')

    # Column 3: Generated Image
    ax3 = fig.add_subplot(gs[idx, 2])
    if output is not None:
        ax3.imshow(output)
        ax3.set_title(f'Generated Image', fontsize=12, fontweight='bold')
    else:
        ax3.text(0.5, 0.5, 'Generation Failed', ha='center', va='center',
                fontsize=14, color='red', fontweight='bold')
        ax3.set_title('Generated Image', fontsize=12, fontweight='bold')
    ax3.axis('off')

    # Add caption as suptitle for this row
    caption_text = test_caption if len(test_caption) <= 60 else test_caption[:57] + "..."
    fig.text(0.5, 0.98 - (idx * (1/num_test_images)) - 0.015,
            f'Caption: "{caption_text}"',
            ha='center', fontsize=10, style='italic', alpha=0.7)

plt.suptitle(f'ControlNet Validation Results - {num_test_images} Images\n' +
             'Original Image | Pose Skeleton (from COCO) | Generated Image',
             fontsize=16, fontweight='bold', y=0.995)

plt.show()


print("RESULTS SUMMARY")

successful_generations = sum(1 for img in generated_images if img is not None)
print(f" Successfully generated: {successful_generations}/{num_test_images} images")
print(f" Success rate: {100*successful_generations/num_test_images:.1f}%")

print(f"\nValidation Captions & Keypoints:")
for i, result in enumerate(test_results[:5]):
    # Count labelled keypoints; summing the visibility flags (0/1/2) would
    # overstate the number.
    labelled_keypoints = int((result['keypoints'][:, 2] > 0).sum())
    print(f"  Image {i+1}: '{result['caption'][:50]}...' ({labelled_keypoints} keypoints)")

print(f"\n Validation testing complete!")
print(f" All {num_test_images} validation images tested with their COCO poses and custom captions")