In [None]:
# Memory optimization utilities for Kaggle
import gc
import torch

def clear_memory():
    """Release unreferenced Python objects and cached CUDA allocator memory.

    Runs the garbage collector, then (when CUDA is available) returns the
    allocator's unused cached blocks to the driver and waits for queued CUDA
    work to finish.

    Note: tensors that are still referenced (e.g. models held in notebook
    globals) are not freed by this; ``del`` them first to reclaim that memory.
    """
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
    print("✓ Memory cleared")

def print_memory_usage():
    """Print allocated and reserved CUDA memory in GiB; prints nothing without CUDA."""
    if not torch.cuda.is_available():
        return
    gib = 1024 ** 3
    allocated_gib = torch.cuda.memory_allocated() / gib
    reserved_gib = torch.cuda.memory_reserved() / gib
    print(f"GPU Memory: {allocated_gib:.2f}GB allocated, {reserved_gib:.2f}GB reserved")
    
print("✓ Memory utilities loaded")

# Memory Optimization for Kaggle

**Important Notes for Kaggle Execution:**
- Kaggle has stricter memory limits than Colab
- Clear GPU cache between major operations
- Use smaller batch sizes if OOM occurs
- Monitor memory usage with `nvidia-smi`

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import os
from pycocotools.coco import COCO
import requests
from tqdm import tqdm
import json

class COCOKeypointDataset(Dataset):
    """COCO 2017 person-keypoint dataset (via pycocotools) with optional captions.

    Each item is a dict with:
        'image'         -- the transformed RGB image (default: resized to
                           image_size x image_size and normalized to [-1, 1])
        'keypoints'     -- a (1, image_size, image_size) heatmap tensor in [0, 1]
        'raw_keypoints' -- float32 array of (x, y, visibility) rows for the first
                           annotated person, in original-image pixels
        'image_id'      -- the COCO image id
        'captions'      -- list of caption strings (empty if captions are unavailable)

    Annotations are downloaded at construction time when ``download=True``;
    images are fetched lazily from their ``coco_url`` on first access.
    """

    # Keypoints per COCO person annotation; shape of the all-zero fallback.
    NUM_KEYPOINTS = 17
    # Seconds to wait on the COCO servers (connect / between bytes) per request.
    REQUEST_TIMEOUT = 60

    def __init__(self, root_dir='./coco_data', split='train', transform=None, image_size=512, max_samples=None, download=True):
        """
        Official COCO Keypoint Dataset using pycocotools with captions

        Args:
            root_dir (str): Root directory to store COCO data
            split (str): 'train' or 'val'
            transform: Optional transform to be applied on images
            image_size (int): Size to resize images to
            max_samples (int): Optional cap on the number of keypoint-annotated
                images kept; None/0 means no cap
            download (bool): Whether to download annotations if not found

        Raises:
            FileNotFoundError: if the keypoint annotation file is still missing
                after the optional download.
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.image_size = image_size

        # COCO 2017 URLs (train and val annotations ship in the same archive)
        self.annotation_urls = {
            'train': 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
            'val': 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip'
        }

        # Setup paths
        self.ann_dir = os.path.join(root_dir, 'annotations')
        self.img_dir = os.path.join(root_dir, f'{split}2017')

        split_name = 'train' if split == 'train' else 'val'
        self.ann_file = os.path.join(self.ann_dir, f'person_keypoints_{split_name}2017.json')
        self.caption_file = os.path.join(self.ann_dir, f'captions_{split_name}2017.json')

        os.makedirs(self.ann_dir, exist_ok=True)
        os.makedirs(self.img_dir, exist_ok=True)

        if download and not os.path.exists(self.ann_file):
            print("Annotation file not found. Downloading COCO 2017 annotations...")
            self._download_annotations()

        if not os.path.exists(self.ann_file):
            raise FileNotFoundError(
                f"Annotation file not found: {self.ann_file}\n"
                f"Please download COCO 2017 annotations from:\n"
                f"http://images.cocodataset.org/annotations/annotations_trainval2017.zip\n"
                f"Extract to: {self.ann_dir}"
            )

        print(f"Loading COCO {split} annotations from official dataset...")
        self.coco = COCO(self.ann_file)

        # Captions are optional: the dataset still works without them.
        self.coco_caps = None
        if os.path.exists(self.caption_file):
            print(f"Loading COCO {split} captions...")
            self.coco_caps = COCO(self.caption_file)
        else:
            print(f"⚠️  Caption file not found: {self.caption_file}")
            print("Captions will not be available. The annotations.zip should contain both keypoints and captions.")

        # Built once here rather than on every __getitem__ call.
        self._default_image_transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
        self._keypoint_map_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor()
        ])

        print("Filtering images with keypoint annotations...")
        self.cat_ids = self.coco.getCatIds(catNms=['person'])
        all_img_ids = self.coco.getImgIds(catIds=self.cat_ids)
        self.img_ids = self._collect_keypoint_img_ids(all_img_ids, self.cat_ids, max_samples)

        print(f"Total images with keypoints: {len(self.img_ids)}")
        print(f"Images will be loaded from: {self.img_dir}")

        if len(self.img_ids) == 0:
            print("\n⚠️ WARNING: No images with keypoints found!")

    @staticmethod
    def _has_keypoints(ann):
        """True when a COCO annotation carries at least one labelled keypoint."""
        return 'keypoints' in ann and ann.get('num_keypoints', 0) > 0

    def _collect_keypoint_img_ids(self, candidate_img_ids, cat_ids, max_samples=None):
        """Return candidate ids (in order) having a non-crowd annotation with keypoints.

        Scanning stops once ``max_samples`` ids are kept, so the cap applies to
        the returned images rather than to the candidates inspected.
        """
        kept = []
        for img_id in candidate_img_ids:
            if max_samples and len(kept) >= max_samples:
                break
            ann_ids = self.coco.getAnnIds(imgIds=img_id, catIds=cat_ids, iscrowd=False)
            if any(self._has_keypoints(ann) for ann in self.coco.loadAnns(ann_ids)):
                kept.append(img_id)
        return kept

    def _download_annotations(self):
        """Download and extract the COCO 2017 annotation archive into root_dir.

        The archive is deleted afterwards even if the download or extraction fails.

        Raises:
            requests.HTTPError: on a non-2xx response.
        """
        import zipfile

        url = self.annotation_urls[self.split]
        zip_path = os.path.join(self.root_dir, 'annotations.zip')

        print(f"Downloading from {url}...")
        try:
            with requests.get(url, stream=True, timeout=self.REQUEST_TIMEOUT) as response:
                # Fail loudly instead of saving an HTML error page as the zip.
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))
                with open(zip_path, 'wb') as f, tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
                        pbar.update(len(chunk))

            print("Extracting annotations...")
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(self.root_dir)
        finally:
            if os.path.exists(zip_path):
                os.remove(zip_path)
        print("Annotations downloaded successfully!")

    def _download_image(self, img_info):
        """Fetch one image from its ``coco_url`` into img_dir (no-op if cached).

        The payload goes to a temporary ``.part`` file that is renamed into place
        only after a successful (2xx) response. A failed request therefore never
        leaves a truncated or error-page file that later loads would trust.

        Returns:
            The local image path, or None if the download failed (best effort).
        """
        img_path = os.path.join(self.img_dir, img_info['file_name'])
        if os.path.exists(img_path):
            return img_path

        tmp_path = img_path + '.part'
        try:
            response = requests.get(img_info['coco_url'], timeout=self.REQUEST_TIMEOUT)
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                f.write(response.content)
            os.replace(tmp_path, img_path)
        except Exception as e:
            print(f"Error downloading {img_info['file_name']}: {e}")
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            return None

        return img_path

    def get_captions(self, img_id):
        """Get text captions for an image (empty list when captions are not loaded)."""
        if self.coco_caps is None:
            return []

        ann_ids = self.coco_caps.getAnnIds(imgIds=img_id)
        anns = self.coco_caps.loadAnns(ann_ids)
        return [ann['caption'] for ann in anns]

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        If an image cannot be downloaded, the following indices are tried in
        turn (wrapping around). RuntimeError is raised only when every image
        fails; the original recursion would hit RecursionError instead.

        Raises:
            IndexError: for an out-of-range ``idx`` (keeps sequence iteration working).
            RuntimeError: if no image in the dataset could be downloaded.
        """
        num_items = len(self)
        if not -num_items <= idx < num_items:
            raise IndexError(f"index {idx} out of range for dataset of size {num_items}")

        for offset in range(num_items):
            sample = self._load_sample((idx + offset) % num_items)
            if sample is not None:
                return sample
        raise RuntimeError("Could not download any image for this dataset.")

    def _load_sample(self, idx):
        """Build the sample dict for ``idx``; None if its image could not be downloaded."""
        img_id = self.img_ids[idx]

        img_info = self.coco.loadImgs(img_id)[0]
        img_path = os.path.join(self.img_dir, img_info['file_name'])

        if not os.path.exists(img_path):
            print(f"Downloading image: {img_info['file_name']}")
            img_path = self._download_image(img_info)
            if img_path is None:
                return None

        # The context manager closes the file handle; convert() returns a loaded copy.
        with Image.open(img_path) as pil_image:
            image = pil_image.convert('RGB')
        width, height = image.size

        ann_ids = self.coco.getAnnIds(imgIds=img_id, catIds=self.cat_ids, iscrowd=False)
        keypoints = self._first_person_keypoints(self.coco.loadAnns(ann_ids))

        captions = self.get_captions(img_id)

        # Heatmap is drawn at the original resolution, then resized like the image.
        keypoint_map = self.create_keypoint_map(keypoints, width, height)

        image_transform = self.transform if self.transform else self._default_image_transform
        image = image_transform(image)
        keypoint_map = self._keypoint_map_transform(keypoint_map)

        return {
            'image': image,
            'keypoints': keypoint_map,
            'raw_keypoints': keypoints,
            'image_id': img_id,
            'captions': captions  # List of text descriptions
        }

    def _first_person_keypoints(self, anns):
        """float32 (x, y, v) rows of the first annotation with keypoints; zeros if none.

        A single dtype for real and fallback arrays keeps batches homogeneous.
        """
        for ann in anns:
            if self._has_keypoints(ann):
                return np.asarray(ann['keypoints'], dtype=np.float32).reshape(-1, 3)
        return np.zeros((self.NUM_KEYPOINTS, 3), dtype=np.float32)

    def create_keypoint_map(self, keypoints, width, height, sigma=2):
        """
        Create a keypoint heatmap from keypoint annotations.

        Each keypoint with visibility > 0 whose truncated (x, y) lies inside the
        image gets an unnormalized Gaussian exp(-d^2 / (2 sigma^2)) stamped into a
        square window of radius int(3 * sigma). Overlapping Gaussians are combined
        with max.

        Args:
            keypoints: array-like of (x, y, v) rows.
            width, height: output size in pixels.
            sigma: Gaussian standard deviation in pixels.

        Returns:
            (height, width) uint8 array, 255 at keypoint centres.
        """
        keypoint_map = np.zeros((height, width), dtype=np.float32)

        radius = int(3 * sigma)
        offsets = np.arange(-radius, radius + 1)
        dx, dy = np.meshgrid(offsets, offsets)
        kernel = np.exp(-(dx ** 2 + dy ** 2) / (2 * sigma ** 2)).astype(np.float32)

        for x, y, v in np.asarray(keypoints).reshape(-1, 3):
            if v <= 0:  # unlabelled keypoint
                continue
            x, y = int(x), int(y)
            if not (0 <= x < width and 0 <= y < height):
                continue
            top, left = y - radius, x - radius
            # Clip the window to the image and take the matching kernel slice.
            y0, y1 = max(top, 0), min(y + radius + 1, height)
            x0, x1 = max(left, 0), min(x + radius + 1, width)
            region = keypoint_map[y0:y1, x0:x1]
            np.maximum(region, kernel[y0 - top:y1 - top, x0 - left:x1 - left], out=region)

        return (keypoint_map * 255).astype(np.uint8)


print("=" * 80)
print("COCO OFFICIAL DATASET LOADER WITH CAPTIONS")
print("=" * 80)
print("\nThis uses the official COCO 2017 dataset with pycocotools.")
print("Includes both keypoint annotations AND text captions!")
print("\nThe dataset will:")
print("1. Download annotations automatically (~252MB)")
print("2. Download images on-demand as you access them")
print("3. Cache everything in './coco_data' directory")
print("\nAlternatively, you can manually download:")
print("- Annotations: http://images.cocodataset.org/annotations/annotations_trainval2017.zip")
print("- Images: http://images.cocodataset.org/zips/train2017.zip (18GB)")
print("- Images: http://images.cocodataset.org/zips/val2017.zip (1GB)")
print("=" * 80)

# Create datasets - annotations will download automatically, images download on-demand
train_dataset = COCOKeypointDataset(
    root_dir='./coco_data',
    split='train',
    image_size=512,
    max_samples=100,  # Start with 100 for testing, remove for full dataset
    download=True
)

val_dataset = COCOKeypointDataset(
    root_dir='./coco_data',
    split='val',
    image_size=512,
    max_samples=50,
    download=True
)


def collate_with_captions(batch):
    """Collate dataset samples while keeping each sample's caption list intact.

    COCO images have differing numbers of captions. torch's default collate
    rejects lists of unequal length, so 'captions' becomes a list of per-sample
    lists. Every other field uses the default collate rules.
    """
    caption_lists = [item['captions'] for item in batch]
    tensor_fields = [{key: value for key, value in item.items() if key != 'captions'} for item in batch]
    collated = torch.utils.data.default_collate(tensor_fields)
    collated['captions'] = caption_lists
    return collated


# Create dataloaders
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0,  # Set to 0 to avoid issues with on-demand downloading
    collate_fn=collate_with_captions,
)

val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_with_captions,
)

print(f"\n✓ Dataset loaded successfully!")
print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"\nImages will be downloaded automatically as needed.")

# Show a sample caption
if len(train_dataset) > 0:
    sample = train_dataset[0]
    if sample['captions']:
        print(f"\n📝 Sample caption: \"{sample['captions'][0]}\"")
    else:
        print("\n⚠️  No captions available (caption file not loaded)")

In [None]:
import matplotlib.pyplot as plt

# Get one sample from the training dataset
sample = train_dataset[0]
img_id = sample['image_id']

# All annotations for this image plus its image record
ann_ids = train_dataset.coco.getAnnIds(imgIds=img_id)
anns = train_dataset.coco.loadAnns(ann_ids)
img_info = train_dataset.coco.loadImgs(img_id)[0]
person_anns = [a for a in anns if a.get('category_id') == 1]

# Get captions (text prompts)
captions = sample['captions']

# Build metadata text
metadata = f"Image ID: {img_id}\nFilename: {img_info['file_name']}\n"
metadata += f"Size: {img_info['width']}x{img_info['height']}\n"
metadata += f"Person annotations: {len(person_anns)}\n"

# Describe the first person with keypoints; numbering counts person annotations only.
for person_number, ann in enumerate(person_anns, start=1):
    if 'keypoints' in ann and ann.get('num_keypoints', 0) > 0:
        metadata += f"\nPerson {person_number}: {ann['num_keypoints']} keypoints"
        if 'area' in ann:
            metadata += f", area: {int(ann['area'])}"
        break

# Convert tensors back to displayable format
image = sample['image'].permute(1, 2, 0).cpu().numpy()
image = (image * 0.5 + 0.5)  # Denormalize from [-1, 1] to [0, 1]
image = np.clip(image, 0, 1)

keypoint_map = sample['keypoints'].squeeze().cpu().numpy()

# Create a figure with subplots
fig = plt.figure(figsize=(16, 10))
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

# Original image (larger)
ax1 = fig.add_subplot(gs[0:2, 0])
ax1.imshow(image)
ax1.set_title('Original Image', fontsize=14, fontweight='bold')
ax1.axis('off')

# Keypoint heatmap
ax2 = fig.add_subplot(gs[0, 1])
ax2.imshow(keypoint_map, cmap='hot')
ax2.set_title('Keypoint Heatmap', fontsize=14, fontweight='bold')
ax2.axis('off')

# Overlay
ax3 = fig.add_subplot(gs[0, 2])
ax3.imshow(image)
ax3.imshow(keypoint_map, cmap='hot', alpha=0.5)
ax3.set_title('Overlay', fontsize=14, fontweight='bold')
ax3.axis('off')

# Text Captions/Prompts - TOP SECTION
# (no emoji in the heading: matplotlib's default font has no glyph for it)
ax_captions = fig.add_subplot(gs[1, 1:3])
ax_captions.axis('off')
ax_captions.text(0.05, 0.95, 'TEXT PROMPTS (COCO Captions):',
                fontsize=13, fontweight='bold', va='top')

if captions:
    caption_text = "\n\n".join([f"{i+1}. {cap}" for i, cap in enumerate(captions)])
else:
    caption_text = "No captions available for this image."

ax_captions.text(0.05, 0.80, caption_text,
                fontsize=10, va='top', wrap=True,
                bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))

# Metadata
ax_meta = fig.add_subplot(gs[2, 0])
ax_meta.axis('off')
ax_meta.text(0.05, 0.95, 'Image Metadata:',
            fontsize=12, fontweight='bold', va='top')
ax_meta.text(0.05, 0.75, metadata,
            fontsize=10, va='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# Keypoint details
ax_kp = fig.add_subplot(gs[2, 1:3])
ax_kp.axis('off')

keypoint_names = [
    'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear',
    'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
    'left_wrist', 'right_wrist', 'left_hip', 'right_hip',
    'left_knee', 'right_knee', 'left_ankle', 'right_ankle'
]

# raw_keypoints are in original-image pixels (not the resized display)
keypoint_text = "Keypoint Details:\n" + "─" * 40 + "\n"
for i, (x, y, v) in enumerate(sample['raw_keypoints']):
    if v > 0:  # labelled keypoint (v == 2 visible, v == 1 occluded)
        visibility = "visible" if v == 2 else "occluded"
        keypoint_text += f"{keypoint_names[i]:15s}: ({int(x):3d}, {int(y):3d}) - {visibility}\n"

ax_kp.text(0.05, 0.95, keypoint_text,
          fontsize=9, va='top', family='monospace')

plt.show()

print("=" * 80)
print(f"Image shape: {sample['image'].shape}")
print(f"Keypoint map shape: {sample['keypoints'].shape}")
print(f"Number of keypoints detected: {(sample['raw_keypoints'][:, 2] > 0).sum()}")
print(f"Number of captions: {len(captions)}")
print("=" * 80)

# ControlNet Training Setup

This section sets up and trains a ControlNet model for pose-guided image generation using:
- **Spatial Conditioning**: Keypoint heatmaps
- **Text Conditioning**: COCO captions
- **Base Model**: Stable Diffusion v1.5

Training uses up to 1000 COCO images that have person-keypoint annotations. This cap is set by `TrainingConfig.num_training_samples`.

In [None]:
# Install required packages (run once)
# !pip install diffusers transformers accelerate xformers safetensors tensorboard

import torch
import torch.nn.functional as F
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, DDPMScheduler, UNet2DConditionModel, AutoencoderKL
from diffusers.optimization import get_scheduler
from transformers import CLIPTextModel, CLIPTokenizer
from accelerate import Accelerator
from tqdm.auto import tqdm
import os
from datetime import datetime

print("✓ Imports successful!")

In [None]:
# Training Configuration
class TrainingConfig:
    """Hyper-parameters and paths for ControlNet keypoint training (class attributes only)."""

    # --- Base model -------------------------------------------------------
    pretrained_model_name = "runwayml/stable-diffusion-v1-5"
    # The conditioning input is a single-channel (grayscale) keypoint heatmap.
    controlnet_conditioning_channels = 1

    # --- Optimisation -----------------------------------------------------
    num_training_samples = 1000
    num_epochs = 10
    # A per-step batch of 1 keeps activations small on Kaggle GPUs; 8
    # accumulation steps give an effective batch size of 1 * 8 = 8.
    train_batch_size = 1
    gradient_accumulation_steps = 8
    learning_rate = 1e-5
    lr_warmup_steps = 500

    # --- Data -------------------------------------------------------------
    resolution = 512  # images and heatmaps are resized to resolution x resolution

    # --- Checkpointing ----------------------------------------------------
    output_dir = "./controlnet_keypoint_output"
    save_steps = 500  # not referenced by the training loop in this notebook
    checkpointing_steps = 1000

    # --- Logging ----------------------------------------------------------
    logging_dir = "./logs"
    report_to = "tensorboard"

    # --- Hardware ---------------------------------------------------------
    # Switch to "bf16" on GPUs that support it, or "no" for full precision / CPU.
    mixed_precision = "fp16"
    gradient_checkpointing = True

    # --- Validation (not referenced by the training loop in this notebook) --
    validation_steps = 500
    num_validation_images = 4
    validation_prompt = "a person standing"

config = TrainingConfig()

# Checkpoint and TensorBoard directories must exist before training starts.
for _required_dir in (config.output_dir, config.logging_dir):
    os.makedirs(_required_dir, exist_ok=True)

print(
    f"Output directory: {config.output_dir}\n"
    f"Training {config.num_training_samples} samples for {config.num_epochs} epochs\n"
    f"Batch size: {config.train_batch_size}, Gradient accumulation: {config.gradient_accumulation_steps}\n"
    f"Effective batch size: {config.train_batch_size * config.gradient_accumulation_steps}"
)

In [None]:
# Prepare the training dataset, capped at config.num_training_samples images
print(f"Recreating training dataset with {config.num_training_samples} samples...")

train_dataset_full = COCOKeypointDataset(
    root_dir='./coco_data',
    split='train',
    image_size=config.resolution,
    max_samples=config.num_training_samples,
    download=True
)

# Create dataloader with collate function
def collate_fn(examples):
    """Batch dataset samples for training.

    Stacks the image and keypoint-heatmap tensors. For each sample it takes the
    first COCO caption, or "a person" when the sample has none.

    Returns:
        dict with 'images' and 'keypoints' (stacked tensors) and 'captions' (list of str).
    """
    def _pick_caption(example):
        caption_list = example['captions']
        # Fallback caption when COCO captions are missing for this image
        return caption_list[0] if caption_list else "a person"

    return {
        'images': torch.stack([ex['image'] for ex in examples]),
        'keypoints': torch.stack([ex['keypoints'] for ex in examples]),
        'captions': [_pick_caption(ex) for ex in examples],
    }

train_dataloader = DataLoader(
    train_dataset_full,
    batch_size=config.train_batch_size,
    shuffle=True,
    num_workers=0,  # images are downloaded on demand inside __getitem__
    collate_fn=collate_fn
)

print(f"✓ Training dataset ready: {len(train_dataset_full)} samples")
print(f"✓ Total batches per epoch: {len(train_dataloader)}")

# Clear memory after dataset creation (important for Kaggle)
clear_memory()
print_memory_usage()

In [None]:
# Initialize models
print("Loading pretrained models...")

# Tokenizer and text encoder (text conditioning)
tokenizer = CLIPTokenizer.from_pretrained(
    config.pretrained_model_name,
    subfolder="tokenizer"
)

text_encoder = CLIPTextModel.from_pretrained(
    config.pretrained_model_name,
    subfolder="text_encoder"
)

# VAE (maps images to and from latent space)
vae = AutoencoderKL.from_pretrained(
    config.pretrained_model_name,
    subfolder="vae"
)

# UNet (noise predictor); stays frozen and receives ControlNet residuals
unet = UNet2DConditionModel.from_pretrained(
    config.pretrained_model_name,
    subfolder="unet"
)

# ControlNet initialised from the UNet, with a conditioning input sized for the
# keypoint heatmap channels.
print("Initializing ControlNet...")
controlnet = ControlNetModel.from_unet(
    unet,
    conditioning_channels=config.controlnet_conditioning_channels
)

# Noise scheduler used for the forward diffusion process during training
noise_scheduler = DDPMScheduler.from_pretrained(
    config.pretrained_model_name,
    subfolder="scheduler"
)

# Freeze VAE, text encoder and UNet - only ControlNet is trained
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)

print("✓ Models loaded successfully!")
print(f"  - ControlNet parameters: {sum(p.numel() for p in controlnet.parameters() if p.requires_grad):,}")
print(f"  - Text encoder (frozen): {sum(p.numel() for p in text_encoder.parameters()):,}")
print(f"  - UNet (frozen): {sum(p.numel() for p in unet.parameters()):,}")

# Clear memory after model loading (important for Kaggle)
clear_memory()
print_memory_usage()

In [None]:
# Setup optimizer and learning rate scheduler
import math

optimizer = torch.optim.AdamW(
    controlnet.parameters(),
    lr=config.learning_rate,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8,
)

# After accelerator.prepare(), the LR scheduler advances only when the optimizer
# actually steps, i.e. once every `gradient_accumulation_steps` batches. Warmup
# and schedule length are therefore counted in optimizer updates. Scaling them
# by the accumulation factor would leave the LR stuck in warmup for the whole run.
# NOTE(review): assumes a single process; for multi-GPU runs, accelerate steps
# the wrapped scheduler once per process, so scale both counts by num_processes.
num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.gradient_accumulation_steps)
max_train_steps = num_update_steps_per_epoch * config.num_epochs

lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps,
    num_training_steps=max_train_steps,
)

print(f"✓ Optimizer configured")
print(f"  - Learning rate: {config.learning_rate}")
print(f"  - Warmup steps: {config.lr_warmup_steps}")
print(f"  - Total training steps: {max_train_steps}")

In [None]:
# Initialize Accelerator for distributed training and mixed precision
accelerator = Accelerator(
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    mixed_precision=config.mixed_precision,
    log_with=config.report_to,
    project_dir=config.logging_dir,
)

# Enable gradient checkpointing to save memory
if config.gradient_checkpointing:
    controlnet.enable_gradient_checkpointing()

# Only the trainable model and its optimizer/scheduler/dataloader are prepared
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    controlnet, optimizer, train_dataloader, lr_scheduler
)

# Move frozen models to device.
# NOTE(review): they stay in fp32 even with mixed_precision="fp16". Casting them
# to fp16 would save GPU memory but needs matching input dtypes in the training step.
unet.to(accelerator.device)
vae.to(accelerator.device)
text_encoder.to(accelerator.device)

# Set models to eval mode (only ControlNet is in training mode)
unet.eval()
vae.eval()
text_encoder.eval()

print(f"✓ Accelerator initialized")
print(f"  - Device: {accelerator.device}")
print(f"  - Mixed precision: {config.mixed_precision}")
print(f"  - Distributed: {accelerator.num_processes} process(es)")
print(f"  - Gradient accumulation: {config.gradient_accumulation_steps} steps")

In [None]:
# Training function
def _encode_prompts(captions):
    """Return CLIP text-encoder hidden states for a list of caption strings.

    Captions are padded/truncated to the tokenizer's max length; no gradients
    are tracked through the frozen text encoder.
    """
    with torch.no_grad():
        text_inputs = tokenizer(
            captions,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        return text_encoder(text_inputs.input_ids.to(accelerator.device))[0]


def _compute_diffusion_loss(batch):
    """Noise-prediction MSE loss for one batch produced by `collate_fn`.

    Images are encoded to scaled VAE latents and noised at random timesteps.
    The frozen UNet then predicts that noise, guided by ControlNet residuals
    computed from the keypoint heatmaps and caption embeddings.
    """
    images = batch['images'].to(accelerator.device, dtype=torch.float32)
    keypoint_conditioning = batch['keypoints'].to(accelerator.device, dtype=torch.float32)

    # Encode images to latent space with the frozen VAE
    with torch.no_grad():
        latents = vae.encode(images).latent_dist.sample()
        latents = latents * vae.config.scaling_factor

    # Forward diffusion: one random timestep per sample
    noise = torch.randn_like(latents)
    timesteps = torch.randint(
        0, noise_scheduler.config.num_train_timesteps, (latents.shape[0],),
        device=latents.device,
    ).long()
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

    text_embeddings = _encode_prompts(batch['captions'])

    down_block_res_samples, mid_block_res_sample = controlnet(
        noisy_latents,
        timesteps,
        encoder_hidden_states=text_embeddings,
        controlnet_cond=keypoint_conditioning,
        return_dict=False,
    )

    model_pred = unet(
        noisy_latents,
        timesteps,
        encoder_hidden_states=text_embeddings,
        down_block_additional_residuals=down_block_res_samples,
        mid_block_additional_residual=mid_block_res_sample,
    ).sample

    return F.mse_loss(model_pred.float(), noise.float(), reduction="mean")


def train_controlnet():
    """Main training loop for ControlNet.

    Trains the global `controlnet` for `config.num_epochs` epochs with gradient
    accumulation. Metrics are logged every 10 optimizer steps and a checkpoint
    is saved every `config.checkpointing_steps` optimizer steps. The final
    weights go to `<output_dir>/controlnet_final`.

    Returns:
        The (accelerator-prepared) trained ControlNet.
    """
    import math

    # accelerator.log() is a no-op until trackers exist; `log_with` alone does not create them.
    if not accelerator.trackers:
        accelerator.init_trackers("controlnet_keypoint")

    # The progress bar and global_step advance once per optimizer update, not per batch.
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.gradient_accumulation_steps)
    max_train_steps = num_update_steps_per_epoch * config.num_epochs
    global_step = 0

    progress_bar = tqdm(
        range(max_train_steps),
        desc="Training",
        disable=not accelerator.is_local_main_process,
    )

    print(f"\n{'='*80}")
    print(f"Starting ControlNet Training")
    print(f"{'='*80}")
    print(f"Total epochs: {config.num_epochs}")
    print(f"Samples per epoch: {len(train_dataset_full)}")
    print(f"Batches per epoch: {len(train_dataloader)}")
    print(f"Optimizer steps per epoch: {num_update_steps_per_epoch}")
    print(f"Total training steps: {max_train_steps}")
    print(f"{'='*80}\n")

    for epoch in range(config.num_epochs):
        controlnet.train()
        epoch_loss = 0.0

        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(controlnet):
                loss = _compute_diffusion_loss(batch)
                accelerator.backward(loss)

                # Clip only on the step that actually applies the accumulated gradients
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(controlnet.parameters(), 1.0)

                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

                epoch_loss += loss.detach().item()

            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1

                # Logging
                if global_step % 10 == 0:
                    avg_loss = epoch_loss / (step + 1)
                    logs = {
                        "loss": loss.detach().item(),
                        "avg_loss": avg_loss,
                        "lr": lr_scheduler.get_last_lr()[0],
                        "epoch": epoch,
                    }
                    progress_bar.set_postfix(**logs)
                    accelerator.log(logs, step=global_step)

                # Memory monitoring (Kaggle optimization)
                if global_step % 100 == 0:
                    print_memory_usage()

                # Save checkpoint
                if global_step % config.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(config.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)
                        print(f"\n✓ Checkpoint saved: {save_path}")

        # End of epoch summary (mean loss per batch)
        avg_epoch_loss = epoch_loss / len(train_dataloader)
        print(f"\n{'─'*80}")
        print(f"Epoch {epoch + 1}/{config.num_epochs} completed")
        print(f"Average loss: {avg_epoch_loss:.4f}")
        print(f"{'─'*80}\n")

    progress_bar.close()

    # Save final model
    if accelerator.is_main_process:
        controlnet_save_path = os.path.join(config.output_dir, "controlnet_final")
        unwrapped_controlnet = accelerator.unwrap_model(controlnet)
        unwrapped_controlnet.save_pretrained(controlnet_save_path)
        print(f"\n{'='*80}")
        print(f"✓ Training completed!")
        print(f"✓ Final ControlNet saved to: {controlnet_save_path}")
        print(f"{'='*80}")

    return controlnet


print("✓ Training function ready")



In [None]:
# Start training!
print("🚀 Starting training...")
print(f"⏰ Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")

trained_controlnet = train_controlnet()

print(f"\n⏰ End time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

## Test the Trained ControlNet

Generate images using the trained ControlNet with keypoint conditioning.

In [None]:
# Load the trained ControlNet and create pipeline
from diffusers import StableDiffusionControlNetPipeline
from PIL import Image

# Clear memory before inference (important for Kaggle)
# NOTE: this only frees memory that is no longer referenced; the training models
# are still held by notebook globals (`del` them if GPU memory runs out).
clear_memory()
print_memory_usage()

print("Loading trained ControlNet pipeline...")

# Load the saved ControlNet
controlnet_path = os.path.join(config.output_dir, "controlnet_final")
controlnet_trained = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)

# Create inference pipeline
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    config.pretrained_model_name,
    controlnet=controlnet_trained,
    torch_dtype=torch.float16,
    safety_checker=None,
)
pipe = pipe.to(accelerator.device)

# Enable memory efficient attention (optional, if xformers is available)
try:
    pipe.enable_xformers_memory_efficient_attention()
    print("✓ XFormers memory efficient attention enabled")
except Exception as e:
    # xformers is optional; the pipeline falls back to the default attention.
    print(f"⚠️  XFormers not available, using default attention: {e}")
    print("   (This is fine, just uses a bit more memory)")

print("✓ Pipeline ready for inference!")

In [None]:
# Generate images using trained ControlNet
import matplotlib.pyplot as plt

# Get a test sample from validation dataset
test_sample = val_dataset[0]
test_keypoint = test_sample['keypoints']
test_caption = test_sample['captions'][0] if test_sample['captions'] else "a person standing"

print(f"Test prompt: \"{test_caption}\"")
print(f"Keypoint shape: {test_keypoint.shape}")

# Keep the conditioning single-channel (the ControlNet was built with
# conditioning_channels=1); add a batch dimension and move it to the device.
test_keypoint_input = test_keypoint.unsqueeze(0).to(accelerator.device, dtype=torch.float16)
print(f"Keypoint input shape: {test_keypoint_input.shape}")

# Generate image (fixed seed for reproducibility)
print("\nGenerating image with ControlNet...")
generator = torch.Generator(device=accelerator.device).manual_seed(42)

# Pass the tensor directly; a PIL image would be converted to 3-channel RGB
output = pipe(
    prompt=test_caption,
    image=test_keypoint_input,
    num_inference_steps=20,
    generator=generator,
    guidance_scale=7.5,
).images[0]

# Visualize results
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

# Original image (denormalized from [-1, 1] to [0, 1])
orig_img = test_sample['image'].permute(1, 2, 0).cpu().numpy()
orig_img = (orig_img * 0.5 + 0.5)
orig_img = np.clip(orig_img, 0, 1)
axes[0].imshow(orig_img)
axes[0].set_title('Original Image', fontsize=14, fontweight='bold')
axes[0].axis('off')

# Keypoint conditioning
axes[1].imshow(test_keypoint.squeeze().cpu().numpy(), cmap='hot')
axes[1].set_title('Keypoint Conditioning', fontsize=14, fontweight='bold')
axes[1].axis('off')

# Generated image
axes[2].imshow(output)
axes[2].set_title('Generated Image', fontsize=14, fontweight='bold')
axes[2].axis('off')

# Original image with the keypoint heatmap overlaid
axes[3].imshow(orig_img)
axes[3].imshow(test_keypoint.squeeze().cpu().numpy(), cmap='hot', alpha=0.3)
axes[3].set_title('Original + Keypoints', fontsize=14, fontweight='bold')
axes[3].axis('off')

plt.suptitle(f'Prompt: "{test_caption}"', fontsize=16, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

print("\n✓ Image generation complete!")