In [None]:
# Memory optimization utilities for Kaggle
import gc
import torch

def clear_memory():
    """Free Python garbage and, if CUDA is usable, release cached GPU memory.

    Runs the garbage collector first, then (only when CUDA is available)
    empties the CUDA cache and synchronizes, and finally prints a
    confirmation line.
    """
    gc.collect()
    cuda_steps = (torch.cuda.empty_cache, torch.cuda.synchronize)
    if torch.cuda.is_available():
        for step in cuda_steps:
            step()
    print("‚úì Memory cleared")

def print_memory_usage():
    """Print allocated and reserved CUDA memory; do nothing without CUDA.

    Both figures are the byte counts reported by torch divided by 1024**3.
    """
    if not torch.cuda.is_available():
        return
    bytes_per_unit = 1024**3
    allocated = torch.cuda.memory_allocated() / bytes_per_unit
    reserved = torch.cuda.memory_reserved() / bytes_per_unit
    print(f"GPU Memory: {allocated:.2f}GB allocated, {reserved:.2f}GB reserved")
    
# Signal that this cell finished and both helpers above are now defined.
print("‚úì Memory utilities loaded")

# ControlNet Training for Pose-Based Spatial Conditioning

## 📋 **Pre-Flight Checklist - MUST DO BEFORE RUNNING!**

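### ✅ **1. Update Caption File Path**
In Cell 8 (Training Configuration), update:
```python
train_captions_file = './train_captions.json'  # ← CHANGE THIS!
```
Also update the `custom_captions_file` arguments in the dataset cell (`train_dataset` / `val_dataset`); they are set there independently of the configuration value.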

Options:
- **Local**: `'./train_captions.json'` or `'path/to/your/train_captions.json'`
- **Kaggle**: `'/kaggle/input/captions/train_captions.json'`

### ✅ **2. Verify Caption File Format**
Your JSON file should look like:
```json
{
  "000000391895.jpg": "A person wearing a red shirt...",
  "000000522418.jpg": "A person standing outdoors...",
  ...
}
```

### ✅ **3. Check GPU/Hardware**
- Requires CUDA GPU (tested on T4, P100, V100)
- ~16GB GPU memory minimum
- ~50GB disk space for COCO data + checkpoints

### ✅ **4. Install Dependencies** (if not already installed)
```bash
pip install diffusers transformers accelerate xformers safetensors tensorboard pycocotools
```

---

## 🚀 **Once Ready:**
Run cells sequentially from top to bottom. Training will:
- Download COCO annotations automatically
- Download images on-the-fly as needed
- Save checkpoints to `./controlnet_pose_output/`
- Generate validation samples every epoch
- Auto-cleanup old checkpoints to save disk space

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import os
from pycocotools.coco import COCO
import requests
from tqdm import tqdm
import json
import cv2
import zipfile

class COCOPoseDataset(Dataset):
    """COCO 2017 images paired with custom captions and rendered pose skeletons.

    Every item combines three sources:
      * the RGB image, read from ``<root_dir>/<split>2017`` and fetched from the
        image's ``coco_url`` when it is not on disk yet,
      * the caption stored for that filename in ``custom_captions_file``,
      * a single-channel stick figure drawn from the first COCO person
        annotation that has at least one labelled keypoint.
    """

    # Pairs of COCO keypoint indices that are joined by a line ("bones").
    SKELETON = (
        (0, 1), (0, 2),           # nose to eyes
        (1, 3), (2, 4),           # eyes to ears
        (0, 5), (0, 6),           # nose to shoulders
        (5, 7), (7, 9),           # left arm (shoulder -> elbow -> wrist)
        (6, 8), (8, 10),          # right arm (shoulder -> elbow -> wrist)
        (5, 11), (6, 12),         # shoulders to hips
        (11, 12),                 # hip to hip
        (11, 13), (13, 15),       # left leg (hip -> knee -> ankle)
        (12, 14), (14, 16),       # right leg (hip -> knee -> ankle)
    )

    # How many times one image is opened (and re-downloaded) before giving up.
    MAX_LOAD_RETRIES = 3
    # Network timeouts in seconds (connect / per-read, as interpreted by requests).
    ANNOTATION_TIMEOUT = 60
    IMAGE_TIMEOUT = 10

    def __init__(self, root_dir='./coco_data', split='train', transform=None, image_size=512,
                 custom_captions_file=None, max_samples=None, download=True):
        """
        Custom Dataset with COCO Pose Skeletons + Your Custom Captions

        Args:
            root_dir (str): Root directory to store COCO data
            split (str): 'train' or 'val'
            transform: Optional transform to be applied on images
            image_size (int): Size to resize images to
            custom_captions_file (str): Path to JSON file with custom captions {image_name: caption}
            max_samples (int): Optional limit on number of samples to load
            download (bool): Whether to download COCO annotations if not found

        Raises:
            ValueError: If no caption file is given, the JSON is not an object,
                or a filename does not start with a numeric COCO image id.
            FileNotFoundError: If the caption file or (after the optional
                download) the annotation file does not exist.
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.image_size = image_size
        self.custom_captions_file = custom_captions_file

        # COCO 2017 URLs (train and val annotations ship in the same archive)
        self.annotation_urls = {
            'train': 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
            'val': 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip'
        }

        # Setup paths
        self.ann_dir = os.path.join(root_dir, 'annotations')
        self.img_dir = os.path.join(root_dir, f'{split}2017')

        split_name = 'train' if split == 'train' else 'val'
        self.ann_file = os.path.join(self.ann_dir, f'person_keypoints_{split_name}2017.json')

        self.custom_captions = None
        self.custom_caption_map = {}
        self.img_ids = []
        # Filenames in the same order as img_ids, so __getitem__ never has to
        # rebuild the caption key list.
        self.img_filenames = []

        # Validate the caption file first: fail fast, before any directory is
        # created or any download is started.
        if not custom_captions_file:
            raise ValueError(f"‚ùå custom_captions_file is REQUIRED! Please provide a JSON file with captions.")

        caption_path = custom_captions_file
        if not os.path.exists(caption_path):
            raise FileNotFoundError(
                f"‚ùå Caption file not found: {custom_captions_file}\n"
                f"Tried paths:\n"
                f"  - {custom_captions_file}\n"
                f"  - {os.path.abspath(custom_captions_file)}\n"
                f"Current working directory: {os.getcwd()}\n"
                f"Please ensure the caption file exists with format: {{'image_filename.jpg': 'caption text', ...}}"
            )

        # Create directories
        os.makedirs(self.ann_dir, exist_ok=True)
        os.makedirs(self.img_dir, exist_ok=True)

        print(f"Loading custom captions from {caption_path}...")
        with open(caption_path, 'r', encoding='utf-8') as f:
            self.custom_captions = json.load(f)
        if not isinstance(self.custom_captions, dict):
            raise ValueError(
                f"Caption file {caption_path} must contain a JSON object mapping "
                f"image filenames to captions, got {type(self.custom_captions).__name__}"
            )
        print(f"‚úì Loaded {len(self.custom_captions)} custom captions")

        # Download COCO annotations if needed
        if download and not os.path.exists(self.ann_file):
            print(f"Annotation file not found. Downloading COCO 2017 annotations...")
            self._download_annotations()

        # Check if annotation file exists
        if not os.path.exists(self.ann_file):
            raise FileNotFoundError(
                f"Annotation file not found: {self.ann_file}\n"
                f"Please download COCO 2017 annotations from:\n"
                f"http://images.cocodataset.org/annotations/annotations_trainval2017.zip\n"
                f"Extract to: {self.ann_dir}"
            )

        print(f"Loading COCO {split} pose annotations...")
        self.coco = COCO(self.ann_file)

        # Map images from caption file
        print("Setting up dataset with YOUR images, poses from COCO, and YOUR captions...")
        items_to_process = list(self.custom_captions.items())
        if max_samples:
            items_to_process = items_to_process[:max_samples]

        for img_filename, caption in items_to_process:
            img_id = self._parse_image_id(img_filename)
            self.img_ids.append(img_id)
            self.img_filenames.append(img_filename)
            self.custom_caption_map[img_id] = caption

        limit_msg = f" (limited to first {max_samples})" if max_samples else ""
        print(f"‚úì Dataset ready with {len(self.img_ids)} images")
        print(f"  - Images from: {self.img_dir}")
        print(f"  - Captions from: {custom_captions_file}{limit_msg}")
        print(f"  - Poses from: COCO person_keypoints annotations")
        print(f"‚úÖ Using ACTUAL POSE SKELETONS (stick figures) as conditioning!\n")

        if len(self.img_ids) == 0:
            print("\n‚ùå ERROR: No images found in caption file!")

    @staticmethod
    def _parse_image_id(img_filename):
        """Return the integer COCO image id encoded in a filename.

        Example: '000000391895.jpg' -> 391895.

        Raises:
            ValueError: If the part before the first '.' is not an integer.
        """
        stem = img_filename.split('.')[0]
        try:
            # lstrip/'0' fallback keeps an all-zero or empty stem mapping to id 0.
            return int(stem.lstrip('0') or '0')
        except ValueError:
            raise ValueError(
                f"Cannot derive a COCO image id from caption key {img_filename!r}; "
                f"expected a name like '000000391895.jpg'"
            ) from None

    def _download_annotations(self):
        """Download and extract the COCO annotation archive into ``root_dir``.

        The temporary zip is removed whether or not the download or the
        extraction succeeds; HTTP errors are raised instead of being written
        to disk as if they were the archive.
        """
        url = self.annotation_urls[self.split]
        zip_path = os.path.join(self.root_dir, 'annotations.zip')

        print(f"Downloading from {url}...")
        try:
            with requests.get(url, stream=True, timeout=self.ANNOTATION_TIMEOUT) as response:
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))

                with open(zip_path, 'wb') as f:
                    with tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
                        for chunk in response.iter_content(chunk_size=8192):
                            f.write(chunk)
                            pbar.update(len(chunk))

            print("Extracting annotations...")
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(self.root_dir)
        finally:
            if os.path.exists(zip_path):
                os.remove(zip_path)

        print("‚úì Annotations downloaded successfully!")

    def _download_image(self, img_id, img_filename):
        """Download a single image from COCO dataset on-the-fly.

        Returns:
            str: Path of the image inside ``img_dir`` (existing files are reused).

        Raises:
            RuntimeError: If the image is unknown to the annotations or the
                download fails; the original error is chained as the cause.
        """
        img_path = os.path.join(self.img_dir, img_filename)

        # If image already exists, skip download
        if os.path.exists(img_path):
            return img_path

        img_url = None
        # Write to a side file first so an interrupted write never leaves a
        # truncated image under the final name.
        tmp_path = img_path + '.part'
        try:
            img_url = self.coco.loadImgs(img_id)[0]['coco_url']
            response = requests.get(img_url, timeout=self.IMAGE_TIMEOUT)
            response.raise_for_status()

            with open(tmp_path, 'wb') as f:
                f.write(response.content)
            os.replace(tmp_path, img_path)

            return img_path
        except Exception as e:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise RuntimeError(f"Failed to download image {img_filename} from {img_url}: {str(e)}") from e

    def _load_image(self, idx):
        """Return the RGB image for position ``idx``, or None if it cannot be loaded.

        Missing files are downloaded; unreadable files are deleted and fetched
        again, up to ``MAX_LOAD_RETRIES`` attempts in total.
        """
        img_id = self.img_ids[idx]
        img_filename = self.img_filenames[idx]
        img_path = os.path.join(self.img_dir, img_filename)

        for attempt in range(1, self.MAX_LOAD_RETRIES + 1):
            try:
                if not os.path.exists(img_path):
                    if attempt == 1:
                        print(f"üì• Downloading image: {img_filename}...")
                    else:
                        print(f"   Re-downloading image...")
                    img_path = self._download_image(img_id, img_filename)

                # Context manager closes the file handle; convert() returns a
                # fully loaded copy that outlives it.
                with Image.open(img_path) as img:
                    return img.convert('RGB')
            except Exception as e:
                # Best-effort boundary: decoding and download can fail with
                # many different exception types, and all of them mean "retry".
                print(f"‚ö†Ô∏è  Failed to load image {img_filename} (attempt {attempt}/{self.MAX_LOAD_RETRIES}): {str(e)[:50]}")

                # Delete corrupted file so the next attempt re-downloads it
                if os.path.exists(img_path):
                    os.remove(img_path)
                    print(f"   Deleted corrupted file: {img_filename}")

        return None

    def __len__(self):
        """Number of caption entries kept (after the optional max_samples cut)."""
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Return one training sample.

        If the requested image cannot be loaded, the following positions are
        tried in order (wrapping around); if none can be loaded, a black image
        is returned for the requested position.

        Returns:
            dict with keys 'image' (transformed image), 'pose' (PIL image of
            the skeleton resized to image_size), 'raw_keypoints' (array with
            one (x, y, visibility) row per keypoint), 'image_id' and
            'captions' (single-element list).

        Raises:
            IndexError: If ``idx`` is outside the dataset.
        """
        num_items = len(self.img_ids)
        if not -num_items <= idx < num_items:
            raise IndexError(f"index {idx} is out of range for dataset of size {num_items}")
        idx %= num_items  # normalise negative indices

        # Bounded, iterative fallback: every position is tried at most once.
        image = None
        chosen = idx
        for offset in range(num_items):
            candidate = (idx + offset) % num_items
            image = self._load_image(candidate)
            if image is not None:
                chosen = candidate
                break
            print(f"‚ùå Giving up on {self.img_filenames[candidate]}, returning next valid image instead...")

        if image is None:
            # Return black image as fallback
            chosen = idx
            image = Image.new('RGB', (self.image_size, self.image_size), color='black')

        width, height = image.size
        img_id = self.img_ids[chosen]

        # Get caption
        caption = self.custom_caption_map[img_id]

        # Extract pose keypoints from the COCO person annotations
        ann_ids = self.coco.getAnnIds(imgIds=img_id, catIds=self.coco.getCatIds(catNms=['person']), iscrowd=False)
        anns = self.coco.loadAnns(ann_ids)

        # Get pose keypoints from first person with visible keypoints
        keypoints = None
        for ann in anns:
            if 'keypoints' in ann and ann.get('num_keypoints', 0) > 0:
                keypoints = np.array(ann['keypoints']).reshape(-1, 3)
                break

        if keypoints is None:
            # No annotated person: an all-zero array draws an empty skeleton.
            keypoints = np.zeros((17, 3))

        # Draw the stick figure at the image's own resolution
        pose_skeleton = self.create_pose_skeleton(keypoints, width, height)

        # Apply transforms
        if self.transform:
            image = self.transform(image)
        else:
            image = transforms.Compose([
                transforms.Resize((self.image_size, self.image_size)),
                transforms.ToTensor(),
                transforms.Normalize([0.5], [0.5])
            ])(image)

        pose_map = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((self.image_size, self.image_size)),
        ])(pose_skeleton)

        return {
            'image': image,
            'pose': pose_map,
            'raw_keypoints': keypoints,
            'image_id': img_id,
            'captions': [caption]  # Your custom caption
        }

    def create_pose_skeleton(self, keypoints, width, height):
        """
        Create ACTUAL human pose skeleton from COCO keypoint annotations
        Draws a stick figure with bones connecting the 17 joints

        COCO Keypoints (17 total):
        0: nose, 1: L eye, 2: R eye, 3: L ear, 4: R ear,
        5: L shoulder, 6: R shoulder, 7: L elbow, 8: R elbow, 9: L wrist, 10: R wrist,
        11: L hip, 12: R hip, 13: L knee, 14: R knee, 15: L ankle, 16: R ankle

        Args:
            keypoints: Rows of (x, y, visibility); rows with visibility 0 are skipped.
            width (int): Canvas width in pixels.
            height (int): Canvas height in pixels.

        Returns:
            uint8 array of shape (height, width): white skeleton on black.
        """
        # Create black canvas
        pose_img = np.zeros((height, width), dtype=np.uint8)

        # Line thickness and circle radius scale with image size
        line_thickness = max(2, int(min(width, height) / 100))
        circle_radius = max(3, int(min(width, height) / 80))

        # Draw bones (connections) FIRST
        for start_idx, end_idx in self.SKELETON:
            if start_idx < len(keypoints) and end_idx < len(keypoints):
                x1, y1, v1 = keypoints[start_idx]
                x2, y2, v2 = keypoints[end_idx]

                # Draw line ONLY if both keypoints are visible (v > 0)
                if v1 > 0 and v2 > 0:
                    cv2.line(pose_img, (int(x1), int(y1)), (int(x2), int(y2)),
                            255, line_thickness, cv2.LINE_AA)

        # Draw keypoint circles ON TOP of bones
        for x, y, v in keypoints:
            if v > 0:  # Only draw visible keypoints
                cv2.circle(pose_img, (int(x), int(y)), circle_radius, 255, -1)

        return pose_img

# Informational banner describing the dataset pipeline defined above.
print("="*80)
print("CUSTOM DATASET LOADER WITH ACTUAL POSE SKELETONS")
print("="*80)
print("Pipeline:")
print("1. ‚úÖ Downloads COCO images on-the-fly (from COCO CDN)")
print("2. ‚úÖ Loads YOUR custom captions")
print("3. ‚úÖ Extracts ACTUAL pose keypoints from COCO annotations")
print("4. ‚úÖ Draws stick figures (nose‚Üíeyes‚Üíears, shoulders‚Üíelbows‚Üíwrists, etc.)")
print("5. ‚úÖ Passes: image + caption + pose skeleton to training")
print("="*80)

# Training dataset: COCO poses + custom captions.
# NOTE(review): the caption path is hard-coded to the Kaggle input directory and
# no max_samples is passed, so the train_captions_file / num_training_samples
# values in the configuration cell do not affect this dataset.
train_dataset = COCOPoseDataset(
    root_dir='./coco_data',
    split='train',
    image_size=512,
    custom_captions_file='/kaggle/input/captions/train_captions.json',
    download=True
)

# Validation dataset (same root directory, 'val' split, its own caption file)
val_dataset = COCOPoseDataset(
    root_dir='./coco_data',
    split='val',
    image_size=512,
    custom_captions_file='/kaggle/input/captions/val_captions.json',
    max_samples=None,
    download=True
)

# Plain dataloaders without a collate_fn.
# NOTE(review): __getitem__ returns 'pose' as a PIL image, which PyTorch's
# default collation is not expected to batch; confirm these loaders are only
# used with a collate_fn that converts the pose to a tensor.
train_loader = DataLoader(
    train_dataset,
    batch_size=4,
    shuffle=True,
    num_workers=0
)

val_loader = DataLoader(
    val_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=0
)

print(f"\n‚úì Dataset loaded successfully!")
print(f"Training samples: {len(train_dataset)} (images + your captions + COCO poses)")
print(f"Validation samples: {len(val_dataset)} (images + your captions + COCO poses)")

# Show a sample (this fetches the first training image if it is not cached yet)
if len(train_dataset) > 0:
    sample = train_dataset[0]
    if sample['captions']:
        print(f"\nüìù Sample caption from training:")
        caption = sample['captions'][0]
        # Captions longer than 200 characters are truncated for display only.
        print(f"   \"{caption[:200]}...\"" if len(caption) > 200 else f"   \"{caption}\"")
    print(f"ü¶¥ Sample has {(sample['raw_keypoints'][:, 2] > 0).sum()} visible pose keypoints")

In [None]:
import matplotlib.pyplot as plt

# Visual sanity check: show the first training sample with its pose skeleton,
# caption, COCO metadata and per-keypoint coordinates.

# Get one sample from the training dataset
sample = train_dataset[0]
img_id = sample['image_id']

# Get annotations for this image from COCO (all categories, not only persons)
ann_ids = train_dataset.coco.getAnnIds(imgIds=img_id)
anns = train_dataset.coco.loadAnns(ann_ids)
img_info = train_dataset.coco.loadImgs(img_id)[0]

# Get captions (text prompts)
captions = sample['captions']

# Build metadata text
metadata = f"Image ID: {img_id}\nFilename: {img_info['file_name']}\n"
metadata += f"Size: {img_info['width']}x{img_info['height']}\n"
metadata += f"Person annotations: {len([a for a in anns if a.get('category_id') == 1])}\n"

# Report only the first annotation that carries labelled keypoints (note the break).
for i, ann in enumerate(anns):
    if 'keypoints' in ann and ann.get('num_keypoints', 0) > 0:
        metadata += f"\nPerson {i+1}: {ann['num_keypoints']} keypoints"
        if 'area' in ann:
            metadata += f", area: {int(ann['area'])}"
        break

# Convert the image tensor from channels-first to channels-last for imshow
image = sample['image'].permute(1, 2, 0).cpu().numpy()
image = (image * 0.5 + 0.5)  # Denormalize from [-1, 1] to [0, 1]
image = np.clip(image, 0, 1)

# Convert PIL Image to numpy array
pose_map = np.array(sample['pose'])

# Create a figure with a 3x3 grid of subplots
fig = plt.figure(figsize=(16, 10))
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)

# Original image (larger: spans the first two rows of the first column)
ax1 = fig.add_subplot(gs[0:2, 0])
ax1.imshow(image)
ax1.set_title('Original Image', fontsize=14, fontweight='bold')
ax1.axis('off')

# Pose skeleton
ax2 = fig.add_subplot(gs[0, 1])
ax2.imshow(pose_map, cmap='gray')
ax2.set_title('Pose Skeleton', fontsize=14, fontweight='bold')
ax2.axis('off')

# Overlay: skeleton drawn semi-transparently on top of the image
ax3 = fig.add_subplot(gs[0, 2])
ax3.imshow(image)
ax3.imshow(pose_map, cmap='hot', alpha=0.5)
ax3.set_title('Overlay', fontsize=14, fontweight='bold')
ax3.axis('off')

# Text Captions/Prompts - TOP SECTION
ax_captions = fig.add_subplot(gs[1, 1:3])
ax_captions.axis('off')
ax_captions.text(0.05, 0.95, 'üìù TEXT PROMPTS (Custom Captions):', 
                fontsize=13, fontweight='bold', va='top')

if captions:
    caption_text = "\n\n".join([f"{i+1}. {cap}" for i, cap in enumerate(captions)])
else:
    caption_text = "No captions available for this image."

ax_captions.text(0.05, 0.80, caption_text, 
                fontsize=10, va='top', wrap=True,
                bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))

# Metadata
ax_meta = fig.add_subplot(gs[2, 0])
ax_meta.axis('off')
ax_meta.text(0.05, 0.95, 'Image Metadata:', 
            fontsize=12, fontweight='bold', va='top')
ax_meta.text(0.05, 0.75, metadata, 
            fontsize=10, va='top',
            bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

# Pose keypoint details
ax_kp = fig.add_subplot(gs[2, 1:3])
ax_kp.axis('off')

# Names in COCO keypoint order; indexed by the row number of raw_keypoints.
keypoint_names = [
    'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear',
    'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
    'left_wrist', 'right_wrist', 'left_hip', 'right_hip',
    'left_knee', 'right_knee', 'left_ankle', 'right_ankle'
]

keypoint_text = "Keypoint Details:\n" + "‚îÄ" * 40 + "\n"
for i, (x, y, v) in enumerate(sample['raw_keypoints']):
    if v > 0:  # labelled keypoint (v == 2 is reported as visible, otherwise occluded)
        visibility = "visible" if v == 2 else "occluded"
        keypoint_text += f"{keypoint_names[i]:15s}: ({int(x):3d}, {int(y):3d}) - {visibility}\n"

ax_kp.text(0.05, 0.95, keypoint_text, 
          fontsize=9, va='top', family='monospace')

plt.show()

# Text summary of the same sample
print("=" * 80)
print(f"Image shape: {sample['image'].shape}")
print(f"Pose skeleton shape: {pose_map.shape}")
print(f"Number of visible keypoints: {(sample['raw_keypoints'][:, 2] > 0).sum()}")
print(f"Number of captions: {len(captions)}")
print("=" * 80)


# ControlNet Training Setup

This section sets up and trains a ControlNet model for pose-guided image generation using:
- **Spatial Conditioning**: Pose skeleton (stick figure from COCO keypoints)
- **Text Conditioning**: Your Custom Captions
- **Base Model**: Stable Diffusion v1.5

The training uses your custom-captioned images with COCO pose annotations.


In [None]:
# Install required packages (run once)
# !pip install diffusers transformers accelerate xformers safetensors tensorboard
# ✅ NEW: Required for 8-bit optimizer
# !pip install bitsandbytes

import torch
import torch.nn.functional as F
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, DDPMScheduler, UNet2DConditionModel, AutoencoderKL
from diffusers.optimization import get_scheduler
from transformers import CLIPTextModel, CLIPTokenizer
from accelerate import Accelerator
from tqdm.auto import tqdm
import os
from datetime import datetime

# Post-import banner: confirms the imports above worked and lists the packages
# the 8-bit-optimizer setup relies on. One print call, one line per argument.
print(
    "‚úì Imports successful!",
    "\n" + "=" * 70,
    "REQUIRED PACKAGES FOR 8-BIT OPTIMIZER TRAINING",
    "=" * 70,
    "‚úÖ Core packages: diffusers, transformers, accelerate, xformers",
    "‚úÖ Bitsandbytes: For 8-bit optimizer (memory efficient)",
    "‚úÖ Tensorboard: For training visualization",
    "\nIf you get 'bitsandbytes' import error, run:",
    "  pip install bitsandbytes",
    "=" * 70,
    sep="\n",
)

## Training Configuration

### ✅ **NEW: Memory-Efficient 8-bit Optimizer with Accelerate + FP16**

This training setup now uses **state-of-the-art memory optimization**:

#### 1. **8-bit Optimizer (bitsandbytes)**
- Uses `AdamW8bit` from bitsandbytes library
- **Memory reduction**: ~75% less memory for optimizer states
- **Previous**: Standard AdamW (stores full precision optimizer states)
- **Now**: 8-bit quantized optimizer states
- **Benefits**:
  - **Larger batch sizes** (16 instead of 1)
  - **Better gradient updates** from larger batches
  - **Faster convergence** with more diverse samples per step

#### 2. **Accelerate + FP16 Mixed Precision**
- `accelerate` library handles distributed training and memory optimization
- FP16 (float16) computation reduces memory by 50%
- Maintains numerical stability with loss scaling
- **Combined with 8-bit optimizer**: Near optimal memory efficiency

#### 3. **Effective Batch Size = 16** ✅
- **Before**: batch_size=1, gradient_accumulation=8 → effective batch=8
- **After**: batch_size=16, gradient_accumulation=1 → effective batch=16
- **Why better**:
  - 2x larger effective batch (16x larger per-step batch) = better gradient estimates
  - No gradient accumulation overhead
  - Faster training iterations
  - More stable training dynamics

#### 4. **Gradient Checkpointing**
- Trades computation for memory during forward/backward pass
- Reduces activation memory by ~40-50%
- Minimal speed impact with modern GPUs

### 📊 **Memory Impact Summary**
```
Configuration              | GPU Memory | Batch Size | Training Speed
Standard FP32 (Baseline)   | 100%       | 1x         | 100%
FP16 only                  | 50%        | 2x         | 110%
FP16 + Gradient Ckpt       | 25-30%     | 4x         | 105%
✅ 8-bit + FP16 + Ckpt     | 15-20%     | 16x        | 95%
```

### 🎯 **Why This Matters**
- **Effective batch size of 16** provides much better gradient estimates
- **8-bit optimizer** doesn't hurt convergence (bitsandbytes handles this carefully)
- **Combined approach** achieves 5-6x memory savings vs standard FP32
- **Result**: Better training quality in same memory budget

### 📝 **Configuration Options**
In the Training Configuration cell, you can control:
- `use_8bit_optimizer`: Toggle 8-bit optimizer (True/False)
- `train_batch_size`: Main batch size (default: 16)
- `gradient_accumulation_steps`: Accumulation steps (default: 1, increase if OOM)
- `num_training_samples`: Set to `None` to use ALL data, or limit to a number

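### ⚠️ **Important Notes**
- **bitsandbytes required** (only when `use_8bit_optimizer = True`): Install with `pip install bitsandbytes`
- **Current default**: the Training Configuration cell sets `use_8bit_optimizer = False`, so standard AdamW is used unless you enable it
- **CUDA recommended**: 8-bit operations optimized for NVIDIA GPUs
- **All training data used**: `num_training_samples = None` uses complete dataset
- **Expected training time**: ~30-50% faster with larger batch size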

In [None]:
# Weights & Biases login.
# Preferred: provide the key through the WANDB_API_KEY environment variable so
# it is never saved inside the notebook. Pasting it below still works, but
# remember to remove it before sharing or committing this file.
import os
import wandb

_WANDB_PLACEHOLDER = "PASTE_YOUR_WANDB_API_KEY_HERE"

# Optional fallback: put your key in the string below
WANDB_KEY = "PASTE_YOUR_WANDB_API_KEY_HERE"

# A pasted key wins; an untouched placeholder falls back to the environment.
if WANDB_KEY and WANDB_KEY != _WANDB_PLACEHOLDER:
    _resolved_wandb_key = WANDB_KEY
else:
    _resolved_wandb_key = os.environ.get("WANDB_API_KEY", "")

if _resolved_wandb_key:
    os.environ["WANDB_API_KEY"] = _resolved_wandb_key
    wandb.login(key=_resolved_wandb_key, relogin=True)
    print("‚úì W&B logged in with provided key")
else:
    print("‚ö†Ô∏è Set WANDB_KEY to your actual API key to enable W&B logging.")

# Default project if not provided via env/secret
os.environ.setdefault("WANDB_PROJECT", "controlnet-pose")


In [None]:
# Training Configuration
class TrainingConfig:
    """Plain namespace of hyper-parameters and paths read by the cells below.

    All values are class attributes; the notebook instantiates the class once
    as ``config`` and reads the attributes from that instance.
    """

    # Model settings
    pretrained_model_name = "runwayml/stable-diffusion-v1-5"
    controlnet_conditioning_channels = 1  # Grayscale pose skeleton
    
    # Training settings
    # NOTE(review): in the visible cells this value is only printed and logged;
    # the dataset is built without max_samples, so it does not limit the data.
    num_training_samples = None  # Set to None to use ALL images, or set to a number (e.g., 1000) to limit
    num_epochs = 10
    train_batch_size = 16
    gradient_accumulation_steps = 1
    learning_rate = 1e-4
    lr_warmup_steps = 500
    # NOTE(review): the optimizer cell passes the literal "cosine" to
    # get_scheduler, so this setting is currently not applied - confirm intent.
    lr_scheduler_type = "constant_with_warmup"
    # Probability that collate_fn replaces a sample's caption with "".
    caption_dropout_prob = 0.5
    
    # Data paths - UPDATE THESE TO YOUR PATHS!
    # NOTE(review): the dataset cell hard-codes its own caption paths; this
    # attribute is not read by any visible cell.
    train_captions_file = './train_captions.json'  # Path to training captions JSON
    # For Kaggle, use: '/kaggle/input/train_captions.json'
    
    # Image settings
    resolution = 512
    
    # Checkpointing & Validation
    output_dir = "./controlnet_pose_output"
    validate_every_n_epochs = 1  # Generate validation samples every N epochs
    
    # Logging
    logging_dir = "./logs"
    report_to = "wandb"
    # Evaluated once, when the class body runs: later changes to the
    # environment variable are not picked up.
    wandb_project = os.environ.get("WANDB_PROJECT", "controlnet-pose")
    wandb_run_name = None  # Optionally set a custom run name
    
    # Hardware & Optimization
    mixed_precision = "fp16"  # Use "bf16" if available, "no" for CPU
    gradient_checkpointing = True
    use_8bit_optimizer = False  # Disabled: the optimizer cell then uses torch.optim.AdamW
    
    # Validation
    validation_steps = 500
    num_validation_images = 4
    validation_prompt = "a person standing"

# Instantiate the configuration and make sure the output/log directories exist.
config = TrainingConfig()
os.makedirs(config.output_dir, exist_ok=True)
os.makedirs(config.logging_dir, exist_ok=True)

# Human-readable summary of the settings (no side effects beyond printing).
print(f"Output directory: {config.output_dir}")
# NOTE(review): num_training_samples is only reported here; the visible dataset
# cell builds train_dataset without max_samples, so a non-None value would be
# printed but not enforced - confirm.
print(f"Training {config.num_training_samples or 'ALL (COMPLETE DATASET)'} samples for {config.num_epochs} epochs")
print(f"‚úÖ Using ALL available training data points")
print(f"\nBatch Configuration:")
print(f"  - Batch size: {config.train_batch_size}")
print(f"  - Gradient accumulation: {config.gradient_accumulation_steps}")
print(f"  - Effective batch size: {config.train_batch_size * config.gradient_accumulation_steps}")
print(f"\nOptimizer & Memory:")
print(f"  - Learning rate: {config.learning_rate}")
print(f"  - Mixed precision: {config.mixed_precision}")
print(f"  - 8-bit optimizer: {'‚úÖ ENABLED' if config.use_8bit_optimizer else '‚ùå DISABLED'}")
print(f"  - Gradient checkpointing: {'‚úÖ ENABLED' if config.gradient_checkpointing else '‚ùå DISABLED'}")
print(f"\nTraining Strategy:")
print(f"  - Caption dropout: {config.caption_dropout_prob*100:.0f}% (enables unconditional generation)")
print(f"  - Validation: Every {config.validate_every_n_epochs} epoch(s)")
print(f"\nLogging:")
print(f"  - Reporting to: {config.report_to}")
print(f"  - W&B Project: {config.wandb_project}")
print(f"  - Run name: {config.wandb_run_name or 'auto'}")

In [None]:
import random
from PIL import Image
import torchvision.transforms as transforms

# Define collate function with caption dropout
def collate_fn(batch):
    """
    Merge dataset samples into one training batch.

    For every sample this:
    1. Keeps the already-transformed image tensor as is
    2. Converts the PIL pose image to a tensor and rescales it from [0, 1] to [-1, 1]
    3. Takes the first caption (or "" if there is none) and, with probability
       config.caption_dropout_prob, replaces it with "" (caption dropout)

    Returns a dict with stacked "images" and "poses" tensors and a list of
    "captions" strings. Stacking requires all samples to share one shape.
    """
    to_tensor = transforms.ToTensor()

    def pick_caption(sample):
        # The caption is resolved before the random draw, one draw per sample.
        text = sample['captions'][0] if sample['captions'] else ""
        return "" if random.random() < config.caption_dropout_prob else text

    image_batch = torch.stack([sample['image'] for sample in batch])
    pose_batch = torch.stack([to_tensor(sample['pose']) * 2 - 1 for sample in batch])
    caption_batch = [pick_caption(sample) for sample in batch]

    return {
        "images": image_batch,
        "poses": pose_batch,
        "captions": caption_batch
    }

# Create the training dataloader used by the training cells. Unlike the plain
# loaders built next to the dataset, this one batches through collate_fn, which
# converts the PIL pose image to a tensor and applies caption dropout.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=config.train_batch_size,
    shuffle=True,
    num_workers=0,  # samples are loaded in the main process
    collate_fn=collate_fn
)

# Summary of the resulting dataset/batch counts (printing only).
print(f"\n" + "="*70)
print(f"DATASET CONFIGURATION")
print(f"="*70)
print(f"‚úÖ Total training samples: {len(train_dataset)} (USING ALL AVAILABLE DATA)")
print(f"‚úÖ Batches per epoch: {len(train_dataloader)}")
print(f"‚úÖ Total training iterations: {len(train_dataloader) * config.num_epochs}")
print(f"‚úÖ Caption dropout: {config.caption_dropout_prob*100:.0f}% of samples use empty captions")
print(f"‚úÖ Pose conditioning normalized to [-1, 1] range")
print(f"="*70)

In [None]:
# Initialize models: every component is loaded from the matching subfolder of
# the pretrained Stable Diffusion checkpoint named in the config.
print("Loading pretrained models...")

# Load tokenizer and text encoder
tokenizer = CLIPTokenizer.from_pretrained(
    config.pretrained_model_name, 
    subfolder="tokenizer"
)

text_encoder = CLIPTextModel.from_pretrained(
    config.pretrained_model_name, 
    subfolder="text_encoder"
)

# Load VAE
vae = AutoencoderKL.from_pretrained(
    config.pretrained_model_name, 
    subfolder="vae"
)

# Load UNet
unet = UNet2DConditionModel.from_pretrained(
    config.pretrained_model_name, 
    subfolder="unet"
)

# Initialize ControlNet from UNet; conditioning_channels matches the
# single-channel (grayscale) pose skeleton.
print("Initializing ControlNet...")
controlnet = ControlNetModel.from_unet(
    unet,
    conditioning_channels=config.controlnet_conditioning_channels
)

# Load noise scheduler
noise_scheduler = DDPMScheduler.from_pretrained(
    config.pretrained_model_name,
    subfolder="scheduler"
)

# Freeze VAE, text encoder and UNet - only the ControlNet is trained
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
unet.requires_grad_(False)

print("‚úì Models loaded successfully!")
print(f"  - ControlNet parameters: {sum(p.numel() for p in controlnet.parameters() if p.requires_grad):,}")
print(f"  - Text encoder (frozen): {sum(p.numel() for p in text_encoder.parameters()):,}")
print(f"  - UNet (frozen): {sum(p.numel() for p in unet.parameters()):,}")

# Clear memory after model loading (important for Kaggle); both helpers come
# from the memory-utilities cell at the top of the notebook.
clear_memory()
print_memory_usage()

In [None]:
# Build the optimizer for the ControlNet weights, then the LR schedule.
# Both AdamW flavours get identical hyper-parameters; only the class differs.
if config.use_8bit_optimizer:
    # Imported lazily so bitsandbytes is only required when the flag is on.
    import bitsandbytes as bnb
    print(f"\n‚úÖ Initializing 8-bit optimizer from bitsandbytes...")
    _optimizer_cls = bnb.optim.AdamW8bit
    _optimizer_ready_msg = f"‚úÖ 8-bit optimizer loaded successfully!"
else:
    print(f"\n‚úÖ Initializing standard AdamW optimizer...")
    _optimizer_cls = torch.optim.AdamW
    _optimizer_ready_msg = f"‚úÖ Standard AdamW optimizer loaded successfully!"

optimizer = _optimizer_cls(
    controlnet.parameters(),
    lr=config.learning_rate,
    betas=(0.9, 0.999),
    weight_decay=1e-2,
    eps=1e-8,
)
print(_optimizer_ready_msg)

# NOTE(review): the schedule name is the literal "cosine"; the value of
# config.lr_scheduler_type is not consulted here - confirm which is intended.
_total_batches = len(train_dataloader) * config.num_epochs
_accum_steps = config.gradient_accumulation_steps
lr_scheduler = get_scheduler(
    "cosine",
    optimizer=optimizer,
    num_warmup_steps=config.lr_warmup_steps * _accum_steps,
    num_training_steps=_total_batches * _accum_steps,
)

# Summary of the optimizer setup (printing only).
_optimizer_label = '8-bit AdamW (bitsandbytes)' if config.use_8bit_optimizer else 'Standard AdamW'
print(
    f"\n" + "="*70,
    f"OPTIMIZER CONFIGURATION",
    f"="*70,
    f"‚úÖ Optimizer type: {_optimizer_label}",
    f"‚úÖ Learning rate: {config.learning_rate}",
    f"‚úÖ Warmup steps: {config.lr_warmup_steps}",
    f"‚úÖ Total training steps: {_total_batches}",
    f"‚úÖ Learning rate scheduler: cosine with warmup",
    f"="*70,
    sep="\n",
)

In [None]:
# Initialize Accelerator for mixed precision, gradient accumulation and logging
accelerator = Accelerator(
    gradient_accumulation_steps=config.gradient_accumulation_steps,
    mixed_precision=config.mixed_precision,
    log_with=config.report_to,
    project_dir=config.logging_dir,
)

if config.report_to:
    # Hyper-parameters recorded with the run
    tracker_config = {
        "learning_rate": config.learning_rate,
        "num_epochs": config.num_epochs,
        "train_batch_size": config.train_batch_size,
        "grad_accum": config.gradient_accumulation_steps,
        "num_training_samples": config.num_training_samples,
        "resolution": config.resolution,
    }
    init_kwargs = {}
    if "wandb" in str(config.report_to):
        init_kwargs["wandb"] = {"name": config.wandb_run_name} if config.wandb_run_name else {}
    # Always pass a dict: an empty one means "no per-tracker options", whereas
    # None is not a mapping and would break trackers other than wandb.
    accelerator.init_trackers(
        project_name=config.wandb_project,
        config=tracker_config,
        init_kwargs=init_kwargs,
    )

# Enable gradient checkpointing to save memory
if config.gradient_checkpointing:
    controlnet.enable_gradient_checkpointing()

# Prepare the trainable model, optimizer, dataloader and scheduler with accelerator
controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
    controlnet, optimizer, train_dataloader, lr_scheduler
)

# Move the frozen models to the training device
unet.to(accelerator.device)
vae.to(accelerator.device)
text_encoder.to(accelerator.device)

# Set frozen models to eval mode (only ControlNet is trained)
unet.eval()
vae.eval()
text_encoder.eval()

# MULTI-GPU SUPPORT: wrap the models in DataParallel when several GPUs are visible.
# NOTE(review): this wraps a model that accelerator.prepare() already handled;
# combining Accelerate with nn.DataParallel is unusual - confirm it behaves as
# intended with the training loop.
num_gpus = torch.cuda.device_count()
if num_gpus > 1:
    print(f"\n‚úÖ MULTI-GPU DETECTED: {num_gpus} GPUs available!")
    print(f"   GPU Details: {torch.cuda.get_device_name(0)}")
    controlnet = torch.nn.DataParallel(controlnet)
    unet = torch.nn.DataParallel(unet)
    vae = torch.nn.DataParallel(vae)
    text_encoder = torch.nn.DataParallel(text_encoder)
    print(f"   ‚úÖ Models wrapped with DataParallel for distributed training")
else:
    print(f"\n‚úÖ SINGLE GPU MODE: Using {accelerator.device}")

# DataParallel splits each batch across the GPUs, so the GPU count changes the
# per-device share, not the batch the optimizer sees. Multiplying by num_gpus
# overstated the batch (and printed 0 when no GPU was present).
effective_batch_size = (
    config.train_batch_size
    * config.gradient_accumulation_steps
    * accelerator.num_processes
)
if num_gpus > 1:
    batch_per_gpu = -(-config.train_batch_size // num_gpus)  # ceiling division
else:
    batch_per_gpu = config.train_batch_size

print(f"\n" + "="*70)
print(f"ACCELERATE CONFIGURATION")
print(f"="*70)
print(f"‚úÖ Device: {accelerator.device}")
print(f"‚úÖ GPUs Available: {num_gpus}")
print(f"‚úÖ Mixed precision: {config.mixed_precision}")
print(f"‚úÖ Gradient accumulation steps: {config.gradient_accumulation_steps}")
print(f"‚úÖ Gradient checkpointing: {'ENABLED' if config.gradient_checkpointing else 'DISABLED'}")
print(f"\nüöÄ TRAINING READINESS SUMMARY:")
print(f"  - Effective Batch Size: {effective_batch_size}")
print(f"  - Batch per GPU: {batch_per_gpu}")
print(f"  - GPUs in use: {num_gpus}")
print(f"  - Total Training Data: {len(train_dataloader) * config.train_batch_size} samples/epoch")
print(f"  - Memory Optimization: fp16 + gradient checkpointing + multi-GPU")
print(f"  - Expected Memory per GPU: ~4-6 GB")
print(f"="*70)

In [None]:
import math

import os

import torch

import torch.nn.functional as F

from tqdm.auto import tqdm


def train_controlnet():
    """
    Train the ControlNet while the VAE, text encoder and UNet stay frozen.

    Relies on notebook-level globals: ``config``, ``accelerator``, ``controlnet``,
    ``unet``, ``vae``, ``text_encoder``, ``tokenizer``, ``noise_scheduler``,
    ``optimizer``, ``lr_scheduler`` and ``train_dataloader``.

    For every epoch the function prints and logs the mean training loss (averaged
    over every micro-batch of the epoch), saves the ControlNet to
    ``<output_dir>/epoch-<N>`` and removes the previous epoch's directory.
    After the last epoch the model is saved to ``<output_dir>/controlnet_final``.

    Returns:
        str: Path of the ``controlnet_final`` directory.
    """
    # Local import: only needed for end-of-epoch cleanup, hoisted out of the loop.
    import shutil

    os.makedirs(config.output_dir, exist_ok=True)

    # 1. Setup Device & Precision
    device = accelerator.device
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Cast frozen models to weight_dtype for memory efficiency.
    # DataParallel only parallelises forward(); the loop below calls the
    # underlying modules directly (e.g. vae.encode), so unwrap them here.
    vae_model = vae.module if isinstance(vae, torch.nn.DataParallel) else vae
    text_encoder_model = text_encoder.module if isinstance(text_encoder, torch.nn.DataParallel) else text_encoder
    unet_model = unet.module if isinstance(unet, torch.nn.DataParallel) else unet

    vae_model.to(device, dtype=weight_dtype)
    text_encoder_model.to(device, dtype=weight_dtype)
    unet_model.to(device, dtype=weight_dtype)

    # 2. Calculate Steps
    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.gradient_accumulation_steps)
    max_train_steps = config.num_epochs * num_update_steps_per_epoch

    # Step-based checkpointing is optional: a missing, None or 0 value disables it
    # instead of raising on the modulo below.
    checkpointing_steps = getattr(config, 'checkpointing_steps', None)

    # 3. Setup Progress Bar
    progress_bar = tqdm(
        range(max_train_steps),
        desc="Steps",
        disable=not accelerator.is_local_main_process
    )

    accelerator.print(f"\n{'='*50}")
    accelerator.print(f"üöÄ TRAINING STARTING (Multi-GPU Enabled)")
    accelerator.print(f"{'='*50}")
    accelerator.print(f"Total Epochs: {config.num_epochs}")
    accelerator.print(f"GPUs: {torch.cuda.device_count()}")
    accelerator.print(f"Precision: {weight_dtype}")
    accelerator.print("Validation: disabled during training")

    global_step = 0

    for epoch in range(config.num_epochs):
        controlnet.train()
        train_loss = 0.0       # loss of the current accumulation window (reset after each optimizer step)
        epoch_loss_sum = 0.0   # sum of every micro-batch loss seen this epoch
        epoch_loss_count = 0   # number of micro-batches seen this epoch

        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(controlnet):
                pixel_values = batch["images"].to(device, dtype=torch.float32)
                controlnet_image = batch["poses"].to(device, dtype=torch.float32)
                captions = batch["captions"]

                with accelerator.autocast():
                    # VAE encode
                    with torch.no_grad():
                        latents = vae_model.encode(pixel_values.to(dtype=weight_dtype)).latent_dist.sample()
                        latents = latents * vae_model.config.scaling_factor

                    # Add noise at a random timestep per sample
                    noise = torch.randn_like(latents)
                    bsz = latents.shape[0]
                    timesteps = torch.randint(
                        0, noise_scheduler.config.num_train_timesteps, (bsz,), device=device
                    ).long()

                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                    # Text embeddings
                    with torch.no_grad():
                        inputs = tokenizer(
                            captions,
                            max_length=tokenizer.model_max_length,
                            padding="max_length",
                            truncation=True,
                            return_tensors="pt"
                        )
                        encoder_hidden_states = text_encoder_model(inputs.input_ids.to(device))[0]

                    # ControlNet forward
                    down_block_res_samples, mid_block_res_sample = controlnet(
                        noisy_latents,
                        timesteps,
                        encoder_hidden_states=encoder_hidden_states,
                        controlnet_cond=controlnet_image,
                        return_dict=False,
                    )

                    # UNet forward, conditioned on the ControlNet residuals
                    model_pred = unet_model(
                        noisy_latents,
                        timesteps,
                        encoder_hidden_states=encoder_hidden_states,
                        down_block_additional_residuals=down_block_res_samples,
                        mid_block_additional_residual=mid_block_res_sample,
                    ).sample

                    # Loss: the UNet is trained to predict the added noise
                    loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")

                # Gather the loss across processes for logging (read back once per micro-batch)
                avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()
                step_loss = avg_loss.item()
                train_loss += step_loss / config.gradient_accumulation_steps

                # Every micro-batch contributes to the epoch summary, not only the
                # one that happens to trigger the optimizer step.
                epoch_loss_sum += step_loss
                epoch_loss_count += 1

                # Backprop
                accelerator.backward(loss)

                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(controlnet.parameters(), 1.0)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad()

                    progress_bar.update(1)
                    global_step += 1

                    logs = {"loss": train_loss, "lr": lr_scheduler.get_last_lr()[0]}
                    progress_bar.set_postfix(**logs)
                    accelerator.log(logs, step=global_step)
                    train_loss = 0.0

                    if checkpointing_steps and global_step % checkpointing_steps == 0:
                        save_path = os.path.join(config.output_dir, f"checkpoint-{global_step}")
                        accelerator.save_state(save_path)

        # End of Epoch
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            # Print epoch-level training loss summary
            if epoch_loss_count > 0:
                epoch_loss_avg = epoch_loss_sum / epoch_loss_count
                accelerator.print(f"üìâ Epoch {epoch+1}/{config.num_epochs} - Training Loss: {epoch_loss_avg:.6f}")
                # Log epoch loss for dashboards (e.g., TensorBoard/W&B)
                accelerator.log({"train_loss_epoch": epoch_loss_avg}, step=global_step)

            # Get the unwrapped model for saving
            controlnet_unwrapped = controlnet.module if isinstance(controlnet, torch.nn.DataParallel) else controlnet
            save_path = os.path.join(config.output_dir, f"epoch-{epoch+1}")
            controlnet_unwrapped.save_pretrained(save_path)
            accelerator.print(f"‚úÖ Epoch {epoch+1} Saved: {save_path}")

            # Auto-cleanup: delete the previous epoch's checkpoint, but only after
            # the new one has been written, so one epoch checkpoint always exists.
            if epoch > 0:
                prev_epoch_path = os.path.join(config.output_dir, f"epoch-{epoch}")
                if os.path.exists(prev_epoch_path):
                    shutil.rmtree(prev_epoch_path)
                    accelerator.print(f"üßπ Cleaned up previous checkpoint: epoch-{epoch}")

            if torch.cuda.is_available():
                torch.cuda.empty_cache()

    # Final Save
    accelerator.wait_for_everyone()
    final_path = os.path.join(config.output_dir, "controlnet_final")
    if accelerator.is_main_process:
        controlnet_unwrapped = controlnet.module if isinstance(controlnet, torch.nn.DataParallel) else controlnet
        controlnet_unwrapped.save_pretrained(final_path)
        accelerator.print(f"üéâ Training Complete! Final model saved to {final_path}")

    accelerator.end_training()

    return final_path


In [None]:
# Kick off the training run, printing a wall-clock timestamp before and after it
print("üöÄ Starting training...")
started_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print(f"‚è∞ Start time: {started_at}\n")

# train_controlnet() returns the directory holding the final saved ControlNet
trained_controlnet = train_controlnet()

finished_at = datetime.now().strftime('%Y-%m-%d %H:%M:%S')
print(f"\n‚è∞ End time: {finished_at}")

## Test the Trained ControlNet

Generate images using the trained ControlNet with pose skeleton conditioning.


In [None]:
# Load the trained ControlNet and create pipeline

from diffusers import StableDiffusionControlNetPipeline

from PIL import Image



# Clear memory before inference (important for Kaggle)
clear_memory()
print_memory_usage()

print("Loading trained ControlNet pipeline...")

# Load the saved ControlNet (the "controlnet_final" directory under config.output_dir)
controlnet_path = os.path.join(config.output_dir, "controlnet_final")
controlnet_trained = ControlNetModel.from_pretrained(controlnet_path, torch_dtype=torch.float16)

# Create inference pipeline around the base model plus the trained ControlNet
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    config.pretrained_model_name,
    controlnet=controlnet_trained,
    torch_dtype=torch.float16,
    safety_checker=None,
)
pipe = pipe.to(accelerator.device)

# Enable memory efficient attention (optional, if xformers is available).
# Best-effort on purpose: any failure just falls back to default attention.
try:
    pipe.enable_xformers_memory_efficient_attention()
    print("‚úì XFormers memory efficient attention enabled")
except Exception as e:
    print(f"‚ö†Ô∏è  XFormers not available, using default attention: {e}")
    print("   (This is fine, just uses a bit more memory)")

print("‚úì Pipeline ready for inference!")

# --- Quick validation preview on a single held-out sample ---
# Use validation dataset, not training, to avoid optimistic bias
val_sample = val_dataset[0]
val_pose = val_sample['pose']
val_caption = val_sample['captions'][0] if val_sample['captions'] else "a person"

# Convert pose to a batched fp16 tensor for pipeline input
val_pose_tensor = transforms.ToTensor()(val_pose)
val_pose_input = val_pose_tensor.unsqueeze(0).to(accelerator.device, dtype=torch.float16)

# Generate one preview image using the trained controlnet + validation pose.
# A failure is reported and leaves preview_image as None so the plot still renders.
try:
    preview_image = pipe(
        prompt=val_caption,
        image=val_pose_input,
        num_inference_steps=20,
        guidance_scale=7.5,
    ).images[0]
    print("‚úì Generated a validation preview image")
except Exception as e:
    preview_image = None
    print(f"‚ö†Ô∏è  Preview generation failed: {str(e)[:80]}")

# Visualize original image, pose, and generated preview
import matplotlib.pyplot as plt

fig, axes = plt.subplots(1, 3, figsize=(16, 5))

# Undo the normalisation (presumably [-1, 1] -> [0, 1]; confirm against the dataset
# transforms) and clip so imshow does not warn about out-of-range float RGB values.
val_image_np = val_sample['image'].permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5
axes[0].imshow(np.clip(val_image_np, 0, 1))
axes[0].set_title('Validation: Original Image')
axes[0].axis('off')

axes[1].imshow(val_pose_tensor.squeeze().cpu().numpy(), cmap='gray')
axes[1].set_title('Validation: Pose Skeleton')
axes[1].axis('off')

# Fall back to a black placeholder when generation failed
axes[2].imshow(preview_image if preview_image is not None else np.zeros((256, 256, 3)))
axes[2].set_title('Validation: Generated Image' if preview_image is not None else 'Generation Failed')
axes[2].axis('off')

plt.show()

In [None]:
# Test Trained ControlNet on 25 Validation Images
import matplotlib.pyplot as plt
from matplotlib.gridspec import GridSpec
import cv2

print("="*80)
print("TESTING TRAINED CONTROLNET ON VALIDATION IMAGES")
print("="*80)

# Number of images to test
num_test_images = min(25, len(val_dataset))
print(f"\nGenerating results for {num_test_images} validation images...")
print(f"Each will show: [Original Image] [Pose Skeleton] [Generated Image]\n")

# Create figure with one row of three subplots per validation image
fig = plt.figure(figsize=(20, 5 * num_test_images))
gs = GridSpec(num_test_images, 3, figure=fig, hspace=0.4, wspace=0.2)

# Set seed for reproducibility; the same generator is reused for every image
torch.manual_seed(42)
generator = torch.Generator(device=accelerator.device).manual_seed(42)

generated_images = []  # one entry per image; None marks a failed generation
test_results = []

# Process each validation image
for idx in range(num_test_images):
    print(f"Processing image {idx+1}/{num_test_images}...")

    # Get validation sample
    test_sample = val_dataset[idx]
    test_image = test_sample['image']
    test_pose = test_sample['pose']
    test_caption = test_sample['captions'][0] if test_sample['captions'] else "a person standing"

    # Store for results
    test_results.append({
        'caption': test_caption,
        'image_id': test_sample['image_id'],
        'keypoints': test_sample['raw_keypoints']
    })

    # Convert pose to tensor
    test_pose_tensor = transforms.ToTensor()(test_pose)

    # Prepare pose input: add a batch dimension and cast to fp16 on the device.
    # No [-1, 1] normalisation is applied here; the tensor is passed as ToTensor produced it.
    test_pose_input = test_pose_tensor.unsqueeze(0).to(accelerator.device, dtype=torch.float16)

    # Generate image. A failure is recorded as None rather than skipping the row,
    # so the original image and pose are still plotted next to a "Generation Failed" cell.
    try:
        output = pipe(
            prompt=test_caption,
            image=test_pose_input,
            num_inference_steps=20,
            generator=generator,
            guidance_scale=7.5,
        ).images[0]
    except Exception as e:
        print(f"  ‚ö†Ô∏è  Error generating image: {str(e)[:50]}")
        output = None
    generated_images.append(output)

    # --- Plot Results for this image ---

    # Column 1: Original Image (de-normalised and clipped to the displayable range)
    ax1 = fig.add_subplot(gs[idx, 0])
    orig_img = test_image.permute(1, 2, 0).cpu().numpy()
    orig_img = (orig_img * 0.5 + 0.5)
    orig_img = np.clip(orig_img, 0, 1)
    ax1.imshow(orig_img)
    ax1.set_title(f'Image {idx+1}: Original', fontsize=12, fontweight='bold')
    ax1.axis('off')

    # Column 2: Pose Skeleton
    ax2 = fig.add_subplot(gs[idx, 1])
    ax2.imshow(test_pose_tensor.squeeze().cpu().numpy(), cmap='gray')
    # Column 2 of raw_keypoints is treated as a visibility flag: > 0 counts as visible
    num_keypoints = (test_sample['raw_keypoints'][:, 2] > 0).sum()
    ax2.set_title(f'Pose Skeleton\n({int(num_keypoints)} keypoints visible)',
                  fontsize=12, fontweight='bold')
    ax2.axis('off')

    # Column 3: Generated Image
    ax3 = fig.add_subplot(gs[idx, 2])
    if output is not None:
        ax3.imshow(output)
        ax3.set_title(f'Generated Image', fontsize=12, fontweight='bold')
    else:
        ax3.text(0.5, 0.5, 'Generation Failed', ha='center', va='center',
                fontsize=14, color='red', fontweight='bold')
        ax3.set_title('Generated Image', fontsize=12, fontweight='bold')
    ax3.axis('off')

    # Add caption as suptitle for this row (truncated to 60 characters)
    caption_text = test_caption if len(test_caption) <= 60 else test_caption[:57] + "..."
    fig.text(0.5, 0.98 - (idx * (1/num_test_images)) - 0.015,
            f'Caption: "{caption_text}"',
            ha='center', fontsize=10, style='italic', alpha=0.7)

plt.suptitle(f'ControlNet Validation Results - {num_test_images} Images\n' +
             'Original Image | Pose Skeleton (from COCO) | Generated Image',
             fontsize=16, fontweight='bold', y=0.995)

plt.show()

print("\n" + "="*80)
print("RESULTS SUMMARY")
print("="*80)

successful_generations = sum(1 for img in generated_images if img is not None)
print(f"‚úì Successfully generated: {successful_generations}/{num_test_images} images")
print(f"‚úì Success rate: {100*successful_generations/num_test_images:.1f}%")

print(f"\nValidation Captions & Keypoints:")
for i, result in enumerate(test_results[:5]):  # Show first 5 as examples
    # Only add an ellipsis when the caption was actually truncated
    short_caption = result['caption'] if len(result['caption']) <= 50 else result['caption'][:50] + "..."
    # Count entries with a positive flag (same rule as the plot titles) instead of summing the flag values
    visible_keypoints = int((result['keypoints'][:, 2] > 0).sum())
    print(f"  Image {i+1}: '{short_caption}' ({visible_keypoints} keypoints)")

print(f"\n‚úì Validation testing complete!")
print(f"‚úì All {num_test_images} validation images tested with their COCO poses and custom captions")

Saving the Model Weights