In [None]:
%%writefile train_human_pose.py

def main():

    # Memory optimization utilities for Kaggle
    import gc
    import torch
    
    def clear_memory():
        """Run the garbage collector and, when CUDA is present, empty its cache.

        Prints "Memory cleared" once finished. Returns nothing.
        """
        gc.collect()
        cuda_ready = torch.cuda.is_available()
        if cuda_ready:
            # Hand cached blocks back to the driver, then wait for the device.
            torch.cuda.empty_cache()
            torch.cuda.synchronize()
        print("Memory cleared")
    
    def print_memory_usage():
        """Print allocated and reserved CUDA memory; do nothing without CUDA."""
        if not torch.cuda.is_available():
            return
        bytes_per_gb = 1024**3
        allocated_gb = torch.cuda.memory_allocated() / bytes_per_gb
        reserved_gb = torch.cuda.memory_reserved() / bytes_per_gb
        print(f"GPU Memory: {allocated_gb:.2f}GB allocated, {reserved_gb:.2f}GB reserved")
        
    print("Memory utilities loaded")
    
    import torch
    from torch.utils.data import Dataset, DataLoader
    from PIL import Image
    import numpy as np
    import torchvision.transforms as transforms
    import os
    from pycocotools.coco import COCO
    import requests
    from tqdm import tqdm
    import json
    import cv2
    import zipfile
    
    class COCOPoseDataset(Dataset):
        """Dataset pairing COCO 2017 images with pose skeletons and custom captions.

        The caption file (a JSON object mapping image filename -> caption)
        decides which images belong to the dataset. Keypoints are read from the
        COCO ``person_keypoints`` annotations, and each image is downloaded from
        its ``coco_url`` the first time it is requested.

        Every item is a dict with the keys ``image``, ``pose`` (PIL image of the
        rendered skeleton resized to ``image_size``), ``raw_keypoints`` (array of
        ``x, y, visibility`` rows), ``image_id`` and ``captions`` (a list holding
        the single custom caption).
        """

        # Attempts to open (and re-download) one image before giving up on it.
        MAX_LOAD_RETRIES = 3
        # How many following samples may be substituted for an unreadable one.
        # Bounded so a long run of broken images cannot recurse or loop forever.
        MAX_FALLBACK_SAMPLES = 10
        # Seconds the annotation archive request may wait on the server.
        ANNOTATION_TIMEOUT = 60

        def __init__(self, root_dir='/kaggle/working/coco_data', split='train', transform=None, image_size=512, 
                     custom_captions_file=None, max_samples=None, download=True):
            """
            Custom Dataset with COCO Pose Skeletons + Custom Captions

            Args:
                root_dir: Directory holding ``annotations/`` and ``<split>2017/``.
                split: ``'train'`` selects the train annotations; any other
                    value selects the val annotations.
                transform: Optional callable applied to the PIL image. When
                    omitted, the image is resized to ``image_size``, converted
                    to a tensor and normalised with mean/std 0.5.
                image_size: Side length used for the image and the pose map.
                custom_captions_file: Path to the JSON caption file (required).
                max_samples: Keep only the first N caption entries when truthy.
                download: Download the annotation archive if it is missing.

            Raises:
                ValueError: ``custom_captions_file`` was not given.
                FileNotFoundError: The caption or annotation file is missing.
            """
            self.root_dir = root_dir
            self.split = split
            self.transform = transform
            self.image_size = image_size
            self.custom_captions_file = custom_captions_file
            
            # COCO 2017 URLs (train and val annotations ship in the same archive)
            self.annotation_urls = {
                'train': 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
                'val': 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip'
            }
            
            # Setup paths
            self.ann_dir = os.path.join(root_dir, 'annotations')
            self.img_dir = os.path.join(root_dir, f'{split}2017')
            
            split_name = 'train' if split == 'train' else 'val'
            self.ann_file = os.path.join(self.ann_dir, f'person_keypoints_{split_name}2017.json')
            
            # Create directories
            os.makedirs(self.ann_dir, exist_ok=True)
            os.makedirs(self.img_dir, exist_ok=True)
            
            # Parallel lists: img_filenames[i] is the caption-file key of img_ids[i].
            self.custom_captions = None
            self.custom_caption_map = {}
            self.img_ids = []
            self.img_filenames = []
            
            if not custom_captions_file:
                raise ValueError("custom_captions_file is required.")
            
            caption_path = custom_captions_file
            if not os.path.exists(caption_path):
                raise FileNotFoundError(
                    f"Caption file not found: {custom_captions_file}\n"
                )
            
            print(f"Loading custom captions from {caption_path}...")
            with open(caption_path, 'r', encoding='utf-8') as f:
                self.custom_captions = json.load(f)
            print(f"Loaded {len(self.custom_captions)} custom captions")
            
            # Download COCO annotations if needed
            if download and not os.path.exists(self.ann_file):
                print("Annotation file not found. Downloading COCO 2017 annotations...")
                self._download_annotations()
            
            # Check if annotation file exists
            if not os.path.exists(self.ann_file):
                raise FileNotFoundError(
                    f"Annotation file not found: {self.ann_file}\n"
                    f"Download COCO 2017 annotations from:\n"
                    f"http://images.cocodataset.org/annotations/annotations_trainval2017.zip\n"
                    f"Extract to: {self.ann_dir}"
                )
            
            print(f"Loading COCO {split} pose annotations...")
            self.coco = COCO(self.ann_file)
            # Looked up once here instead of on every __getitem__ call.
            self.person_cat_ids = self.coco.getCatIds(catNms=['person'])
            
            # Map images from caption file
            print("Setting up dataset with images, poses from COCO, and custom captions...")
            items_to_process = list(self.custom_captions.items())
            if max_samples:
                items_to_process = items_to_process[:max_samples]
            
            for img_filename, caption in items_to_process:
                img_id = self._image_id_from_filename(img_filename)
                self.img_ids.append(img_id)
                self.img_filenames.append(img_filename)
                self.custom_caption_map[img_id] = caption
            
            limit_msg = f" (limited to first {max_samples})" if max_samples else ""
            print(f"Dataset ready with {len(self.img_ids)} images")
            print(f"  - Images from: {self.img_dir}")
            print(f"  - Captions from: {custom_captions_file}{limit_msg}")
            print("  - Poses from: COCO person_keypoints annotations")
            print("Using actual pose skeletons as conditioning.\n")
            
            if len(self.img_ids) == 0:
                print("\nERROR: No images found in caption file.")
        
        @staticmethod
        def _image_id_from_filename(img_filename):
            """Extract the numeric id from a filename ('000000391895.jpg' -> 391895)."""
            return int(img_filename.split('.')[0].lstrip('0') or '0')
        
        def _download_annotations(self):
            """Download and extract the COCO annotation archive into ``root_dir``.

            Raises the ``requests`` exception on an HTTP or network failure. The
            temporary zip file is removed whether or not the download succeeds.
            """
            # Unknown splits use the val URL, matching how ann_file is chosen.
            url = self.annotation_urls.get(self.split, self.annotation_urls['val'])
            zip_path = os.path.join(self.root_dir, 'annotations.zip')
            
            print(f"Downloading from {url}...")
            try:
                with requests.get(url, stream=True, timeout=self.ANNOTATION_TIMEOUT) as response:
                    # Fail loudly instead of saving an HTML error page as a zip.
                    response.raise_for_status()
                    total_size = int(response.headers.get('content-length', 0))
                    
                    with open(zip_path, 'wb') as f:
                        with tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
                            for chunk in response.iter_content(chunk_size=1 << 20):
                                f.write(chunk)
                                pbar.update(len(chunk))
                
                print("Extracting annotations...")
                with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                    zip_ref.extractall(self.root_dir)
            finally:
                if os.path.exists(zip_path):
                    os.remove(zip_path)
            
            print("Annotations downloaded successfully!")
        
        def _download_image(self, img_id, img_filename):
            """Download a single image from COCO dataset on-the-fly.

            Returns the local image path (also when the file already exists).

            Raises:
                RuntimeError: The HTTP request or the file write failed.
            """
            img_path = os.path.join(self.img_dir, img_filename)
            
            # If image already exists, skip download
            if os.path.exists(img_path):
                return img_path
            
            # Get image info from COCO API
            img_info = self.coco.loadImgs(img_id)[0]
            img_url = img_info['coco_url']
            
            # Write to a per-process temp file and rename it into place, so a
            # second process sharing img_dir never opens a half-written image.
            tmp_path = f"{img_path}.{os.getpid()}.part"
            try:
                response = requests.get(img_url, timeout=10)
                response.raise_for_status()
                
                with open(tmp_path, 'wb') as f:
                    f.write(response.content)
                os.replace(tmp_path, img_path)
                
                return img_path
            except Exception as e:
                if os.path.exists(tmp_path):
                    os.remove(tmp_path)
                raise RuntimeError(
                    f"Failed to download image {img_filename} from {img_url}: {str(e)}"
                ) from e
        
        def _load_image(self, img_id, img_filename):
            """Return the image as an RGB PIL image, or None if it cannot be loaded.

            Downloads the file when missing. A file that fails to open is deleted
            and fetched again, for at most ``MAX_LOAD_RETRIES`` attempts.
            """
            img_path = os.path.join(self.img_dir, img_filename)
            
            for attempt in range(1, self.MAX_LOAD_RETRIES + 1):
                try:
                    if not os.path.exists(img_path):
                        img_path = self._download_image(img_id, img_filename)
                    # convert() returns a loaded copy, so the file can be closed.
                    with Image.open(img_path) as img_file:
                        return img_file.convert('RGB')
                except Exception as e:
                    print(f"Failed to load image {img_filename} (attempt {attempt}/{self.MAX_LOAD_RETRIES}): {str(e)[:50]}")
                    
                    # Delete corrupted file so the next attempt re-downloads it
                    if os.path.exists(img_path):
                        try:
                            os.remove(img_path)
                            print(f"Deleted corrupted file: {img_filename}")
                        except FileNotFoundError:
                            # Another process removed it first; nothing to do.
                            pass
            
            return None
        
        def _first_person_keypoints(self, img_id):
            """Return keypoints of the first annotated person, or zeros if none."""
            ann_ids = self.coco.getAnnIds(imgIds=img_id, catIds=self.person_cat_ids, iscrowd=False)
            for ann in self.coco.loadAnns(ann_ids):
                if 'keypoints' in ann and ann.get('num_keypoints', 0) > 0:
                    return np.array(ann['keypoints']).reshape(-1, 3)
            return np.zeros((17, 3))
        
        def __len__(self):
            """Number of caption entries kept for this dataset."""
            return len(self.img_ids)
        
        def __getitem__(self, idx):
            """Return the sample dict for ``idx``.

            If the image cannot be loaded, the following samples (wrapping
            around, at most ``MAX_FALLBACK_SAMPLES`` of them) are tried and the
            first readable one is returned instead. If none is readable, a black
            image is used with the last candidate's caption and keypoints.

            Raises:
                IndexError: ``idx`` is out of range.
            """
            # Index the lists directly first so an out-of-range idx raises
            # IndexError rather than being wrapped by the modulo below.
            img_id = self.img_ids[idx]
            img_filename = self.img_filenames[idx]
            image = self._load_image(img_id, img_filename)
            
            num_samples = len(self.img_ids)
            fallbacks_left = min(self.MAX_FALLBACK_SAMPLES, num_samples - 1)
            cursor = idx
            while image is None and fallbacks_left > 0:
                print(f"Giving up on {img_filename}, returning next valid image instead...")
                cursor = (cursor + 1) % num_samples
                fallbacks_left -= 1
                img_id = self.img_ids[cursor]
                img_filename = self.img_filenames[cursor]
                image = self._load_image(img_id, img_filename)
            
            if image is None:
                # Return black image as fallback
                print(f"No readable image found near {img_filename}; using a black placeholder.")
                image = Image.new('RGB', (self.image_size, self.image_size), color='black')
            
            width, height = image.size
            
            # Get caption
            caption = self.custom_caption_map[img_id]
            
            # Pose keypoints from the first person with labelled keypoints
            keypoints = self._first_person_keypoints(img_id)
            
            # Skeleton is drawn at the source resolution and resized afterwards
            pose_skeleton = self.create_pose_skeleton(keypoints, width, height)
            
            # Apply transforms
            if self.transform:
                image = self.transform(image)
            else:
                image = transforms.Compose([
                    transforms.Resize((self.image_size, self.image_size)),
                    transforms.ToTensor(),
                    transforms.Normalize([0.5], [0.5])
                ])(image)
            
            pose_map = transforms.Compose([
                transforms.ToPILImage(),
                transforms.Resize((self.image_size, self.image_size)),
            ])(pose_skeleton)
            
            return {
                'image': image,
                'pose': pose_map,
                'raw_keypoints': keypoints,
                'image_id': img_id,
                'captions': [caption]  # Your custom caption
            }
        
        def create_pose_skeleton(self, keypoints, width, height):
            """
            Create actual human pose skeleton from COCO keypoint annotations
            Draws a stick figure with bones connecting the 17 joints

            Args:
                keypoints: Rows of ``(x, y, v)``; a joint is drawn when ``v > 0``.
                width: Canvas width in pixels.
                height: Canvas height in pixels.

            Returns:
                ``uint8`` array of shape ``(height, width)``: white skeleton on
                a black background.
            """
            # Create black canvas
            pose_img = np.zeros((height, width), dtype=np.uint8)
            
            # COCO keypoint skeleton connections (bones linking joints)
            skeleton = [
                (0, 1), (0, 2),           # nose to eyes
                (1, 3), (2, 4),           # eyes to ears
                (0, 5), (0, 6),           # nose to shoulders
                (5, 7), (7, 9),           # left arm (shoulder -> elbow -> wrist)
                (6, 8), (8, 10),          # right arm (shoulder -> elbow -> wrist)
                (5, 11), (6, 12),         # shoulders to hips
                (11, 12),                 # hip to hip
                (11, 13), (13, 15),       # left leg (hip -> knee -> ankle)
                (12, 14), (14, 16)        # right leg (hip -> knee -> ankle)
            ]
            
            # Line thickness and circle radius scale with image size
            shorter_side = min(width, height)
            line_thickness = max(2, int(shorter_side / 100))
            circle_radius = max(3, int(shorter_side / 80))
            num_joints = len(keypoints)
            
            # Draw bones (connections)
            for start_idx, end_idx in skeleton:
                if start_idx < num_joints and end_idx < num_joints:
                    x1, y1, v1 = keypoints[start_idx]
                    x2, y2, v2 = keypoints[end_idx]
                    
                    # Draw line only if both keypoints are labelled (v > 0)
                    if v1 > 0 and v2 > 0:
                        cv2.line(pose_img, (int(x1), int(y1)), (int(x2), int(y2)), 
                                255, line_thickness, cv2.LINE_AA)
            
            # Draw keypoint circles on top of bones
            for x, y, v in keypoints:
                if v > 0:  # Only draw labelled keypoints
                    cv2.circle(pose_img, (int(x), int(y)), circle_radius, 255, -1)
            
            return pose_img
    
    # Build the train and validation datasets (COCO poses + custom captions).
    coco_root = '/kaggle/working/coco_data'
    train_dataset = COCOPoseDataset(
        root_dir=coco_root,
        split='train',
        image_size=512,
        custom_captions_file='/kaggle/input/captions/train_captions.json',
        download=True,
    )
    val_dataset = COCOPoseDataset(
        root_dir=coco_root,
        split='val',
        image_size=512,
        custom_captions_file='/kaggle/input/captions/val_captions.json',
        max_samples=None,
        download=True,
    )

    # Plain loaders with the default collate function. Nothing later in main()
    # iterates them; training builds its own loader with a custom collate_fn.
    train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=0)

    print("\nDataset loaded successfully.")
    print(f"Training samples: {len(train_dataset)}")
    print(f"Validation samples: {len(val_dataset)}")

    # Preview the first training sample (fetching it may download the image).
    if len(train_dataset) > 0:
        sample = train_dataset[0]
        sample_captions = sample['captions']
        if sample_captions:
            print("\nSample caption from training:")
            caption = sample_captions[0]
            # Long captions are cut at 200 characters for display.
            if len(caption) > 200:
                print(f"   \"{caption[:200]}...\"")
            else:
                print(f"   \"{caption}\"")
        visible_count = (sample['raw_keypoints'][:, 2] > 0).sum()
        print(f"Sample has {visible_count} visible pose keypoints")
    
    import matplotlib.pyplot as plt
    
    # Visualise one training sample: image, pose skeleton, overlay, caption,
    # COCO metadata and the list of labelled keypoints.
    sample = train_dataset[0]
    img_id = sample['image_id']
    
    # Get annotations for this image from COCO
    ann_ids = train_dataset.coco.getAnnIds(imgIds=img_id)
    anns = train_dataset.coco.loadAnns(ann_ids)
    img_info = train_dataset.coco.loadImgs(img_id)[0]
    
    # Get captions (text prompts)
    captions = sample['captions']
    
    # Build metadata text
    metadata = f"Image ID: {img_id}\nFilename: {img_info['file_name']}\n"
    metadata += f"Size: {img_info['width']}x{img_info['height']}\n"
    metadata += f"Person annotations: {len([a for a in anns if a.get('category_id') == 1])}\n"
    
    # Describe only the first annotation that carries labelled keypoints
    for i, ann in enumerate(anns):
        if 'keypoints' in ann and ann.get('num_keypoints', 0) > 0:
            metadata += f"\nPerson {i+1}: {ann['num_keypoints']} keypoints"
            if 'area' in ann:
                metadata += f", area: {int(ann['area'])}"
            break
    
    # Convert tensors back to displayable format (channels last)
    image = sample['image'].permute(1, 2, 0).cpu().numpy()
    image = (image * 0.5 + 0.5)  # Denormalize from [-1, 1] to [0, 1]
    image = np.clip(image, 0, 1)
    
    # Convert PIL Image to numpy array
    pose_map = np.array(sample['pose'])
    
    # Create a figure with subplots
    fig = plt.figure(figsize=(16, 10))
    gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)
    
    # Original image
    ax1 = fig.add_subplot(gs[0:2, 0])
    ax1.imshow(image)
    ax1.set_title('Original Image', fontsize=14, fontweight='bold')
    ax1.axis('off')
    
    # Pose skeleton
    ax2 = fig.add_subplot(gs[0, 1])
    ax2.imshow(pose_map, cmap='gray')
    ax2.set_title('Pose Skeleton', fontsize=14, fontweight='bold')
    ax2.axis('off')
    
    # Overlay
    ax3 = fig.add_subplot(gs[0, 2])
    ax3.imshow(image)
    ax3.imshow(pose_map, cmap='hot', alpha=0.5)
    ax3.set_title('Overlay', fontsize=14, fontweight='bold')
    ax3.axis('off')
    
    # Text Captions/Prompts
    ax_captions = fig.add_subplot(gs[1, 1:3])
    ax_captions.axis('off')
    ax_captions.text(0.05, 0.95, 'Text Prompts:', 
                    fontsize=13, fontweight='bold', va='top')
    
    if captions:
        caption_text = "\n\n".join([f"{i+1}. {cap}" for i, cap in enumerate(captions)])
    else:
        caption_text = "No captions available for this image."
    
    ax_captions.text(0.05, 0.80, caption_text, 
                    fontsize=10, va='top', wrap=True,
                    bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.7))
    
    # Metadata
    ax_meta = fig.add_subplot(gs[2, 0])
    ax_meta.axis('off')
    ax_meta.text(0.05, 0.95, 'Image Metadata:', 
                fontsize=12, fontweight='bold', va='top')
    ax_meta.text(0.05, 0.75, metadata, 
                fontsize=10, va='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
    
    # Pose keypoint details
    ax_kp = fig.add_subplot(gs[2, 1:3])
    ax_kp.axis('off')
    
    # Joint names in the same order as the keypoint rows
    keypoint_names = [
        'nose', 'left_eye', 'right_eye', 'left_ear', 'right_ear',
        'left_shoulder', 'right_shoulder', 'left_elbow', 'right_elbow',
        'left_wrist', 'right_wrist', 'left_hip', 'right_hip',
        'left_knee', 'right_knee', 'left_ankle', 'right_ankle'
    ]
    
    # The separator was mis-encoded text ("â”€"); it is the box-drawing dash.
    keypoint_text = "Keypoint Details:\n" + "─" * 40 + "\n"
    for i, (x, y, v) in enumerate(sample['raw_keypoints']):
        if v > 0:  # labelled keypoint
            visibility = "visible" if v == 2 else "occluded"
            keypoint_text += f"{keypoint_names[i]:15s}: ({int(x):3d}, {int(y):3d}) - {visibility}\n"
    
    ax_kp.text(0.05, 0.95, keypoint_text, 
              fontsize=9, va='top', family='monospace')
    
    plt.show()
    # Release the figure so it does not stay in memory during training.
    plt.close(fig)
    
    print(f"Image shape: {sample['image'].shape}")
    print(f"Pose skeleton shape: {pose_map.shape}")
    print(f"Number of visible keypoints: {(sample['raw_keypoints'][:, 2] > 0).sum()}")
    print(f"Number of captions: {len(captions)}")
    
    import torch
    import torch.nn.functional as F
    from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, DDPMScheduler, UNet2DConditionModel, AutoencoderKL
    from diffusers.optimization import get_scheduler
    from transformers import CLIPTextModel, CLIPTokenizer
    from accelerate import Accelerator
    from tqdm.auto import tqdm
    import os
    from datetime import datetime
    
    print("Imports successful.")

    # Training Configuration
    class TrainingConfig:
        """Hyper-parameters and paths for ControlNet pose training.

        All settings are plain class attributes. Comments marked "(unused)"
        flag values that no code visible in this script reads.
        """

        # Model settings
        pretrained_model_name = "runwayml/stable-diffusion-v1-5"
        controlnet_conditioning_channels = 1  # Grayscale pose skeleton
        
        # Training settings
        num_training_samples = None  # None means use all images (unused)
        num_epochs = 50
        train_batch_size = 1
        gradient_accumulation_steps = 8
        learning_rate = 1e-4
        lr_warmup_steps = 375
        lr_scheduler_type = "cosine"  # was misspelled "consine" (unused)
        caption_dropout_prob = 0.5  # Chance that a caption is replaced by ""
        
        # Data paths
        train_captions_file = '/kaggle/input/captions/train_captions.json'  # (unused)
        
        # Image settings
        resolution = 512  # (unused)
        
        # Checkpointing & Validation
        output_dir = "/kaggle/working/controlnet_pose_output"
        validate_every_n_epochs = 1  # Generate validation samples every N epochs (unused)
        
        # Logging
        logging_dir = "./logs"
        report_to = "tensorboard"
        
        # Hardware & Optimization
        mixed_precision = "fp16"  # Use "bf16" if available, "no" for CPU
        gradient_checkpointing = True
        use_8bit_optimizer = False
        
        # Validation (unused)
        validation_steps = 500
        num_validation_images = 4
        validation_prompt = "a person standing"
    
    # Instantiate the settings and make sure the checkpoint and log
    # directories exist before anything tries to write into them.
    config = TrainingConfig()
    os.makedirs(config.output_dir, exist_ok=True)
    os.makedirs(config.logging_dir, exist_ok=True)
    
    import random
    from PIL import Image
    import torchvision.transforms as transforms
    
    # Define collate function with caption dropout
    def collate_fn(batch):
        """Collate dataset samples into a training batch with caption dropout.

        Each pose image is converted to a tensor and rescaled from [0, 1] to
        [-1, 1]. Each caption is replaced by an empty string with probability
        ``config.caption_dropout_prob``.

        Returns a dict with stacked "images" and "poses" tensors and a list of
        "captions" strings.
        """
        to_tensor = transforms.ToTensor()

        image_tensors = [item['image'] for item in batch]
        pose_tensors = [to_tensor(item['pose']) * 2 - 1 for item in batch]

        texts = []
        for item in batch:
            text = item['captions'][0] if item['captions'] else ""
            # One random draw per sample decides whether its caption is dropped.
            if random.random() < config.caption_dropout_prob:
                text = ""
            texts.append(text)

        return {
            "images": torch.stack(image_tensors),
            "poses": torch.stack(pose_tensors),
            "captions": texts
        }
    
    # Training dataloader: shuffled, loaded in the main process
    # (num_workers=0), batched by collate_fn, which applies caption dropout.
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.train_batch_size,
        shuffle=True,
        num_workers=0,
        collate_fn=collate_fn
    )
    
        
    # Load the pretrained Stable Diffusion components and derive the ControlNet.
    print("Loading pretrained models...")
    model_id = config.pretrained_model_name

    # Text side: tokenizer + CLIP text encoder
    tokenizer = CLIPTokenizer.from_pretrained(model_id, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")

    # Image side: VAE and denoising UNet
    vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")

    # The ControlNet is created from the UNet, with the configured number of
    # conditioning channels for the pose map.
    print("Initializing ControlNet...")
    controlnet = ControlNetModel.from_unet(
        unet,
        conditioning_channels=config.controlnet_conditioning_channels,
    )

    # Noise scheduler used to add noise to latents during training
    noise_scheduler = DDPMScheduler.from_pretrained(model_id, subfolder="scheduler")

    # Only the ControlNet is trained; every other network is frozen.
    for frozen_model in (vae, text_encoder, unet):
        frozen_model.requires_grad_(False)

    print("Models loaded successfully.")

    # Clear memory after model loading
    clear_memory()
    print_memory_usage()
    
    # Setup optimizer and learning rate scheduler.
    # Both optimizer variants share the same AdamW hyper-parameters.
    optimizer_kwargs = dict(
        lr=config.learning_rate,
        betas=(0.9, 0.999),
        weight_decay=1e-3,
        eps=1e-8,
    )

    if config.use_8bit_optimizer:
        # Imported lazily so bitsandbytes is only required when enabled.
        import bitsandbytes as bnb

        optimizer = bnb.optim.AdamW8bit(controlnet.parameters(), **optimizer_kwargs)
    else:
        # Standard AdamW optimizer
        optimizer = torch.optim.AdamW(controlnet.parameters(), **optimizer_kwargs)

    # The training loop steps the scheduler once per optimizer update (inside
    # `if accelerator.sync_gradients`), so the horizon is counted in updates:
    # ceil(batches / accumulation) per epoch. The previous code multiplied by
    # gradient_accumulation_steps, which stretched the cosine schedule so far
    # that the learning rate hardly decayed during the run.
    # NOTE(review): len(train_dataloader) is the un-sharded length here. Under
    # multi-process training the prepared scheduler is expected to advance once
    # per process per update, which makes this total approximately right; the
    # warmup then spans fewer optimizer updates per process. Confirm against
    # the installed accelerate version.
    accumulation_steps = config.gradient_accumulation_steps
    updates_per_epoch = (len(train_dataloader) + accumulation_steps - 1) // accumulation_steps

    lr_scheduler = get_scheduler(
        "cosine",
        optimizer=optimizer,
        num_warmup_steps=config.lr_warmup_steps,
        num_training_steps=updates_per_epoch * config.num_epochs,
    )
    
    
    # Initialize Accelerator for distributed training and mixed precision
    accelerator = Accelerator(
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        mixed_precision=config.mixed_precision,
        log_with=config.report_to,
        project_dir=config.logging_dir,
    )

    # Trackers have to be initialised before accelerator.log() records
    # anything; without this call the configured logger receives no data.
    if accelerator.is_main_process:
        accelerator.init_trackers("controlnet_pose")
    
    # Enable gradient checkpointing to save memory
    if config.gradient_checkpointing:
        controlnet.enable_gradient_checkpointing()
    
    # Prepare the trainable model and its training objects with accelerator
    controlnet, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        controlnet, optimizer, train_dataloader, lr_scheduler
    )
    
    # Frozen models are not passed to prepare(); move them to the device manually
    unet.to(accelerator.device)
    vae.to(accelerator.device)
    text_encoder.to(accelerator.device)
    
    # Set models to eval mode (only ControlNet is in training mode)
    unet.eval()
    vae.eval()
    text_encoder.eval()
    
    # Number of visible CUDA devices (not read again in the visible code)
    num_gpus = torch.cuda.device_count()

    import math
    import os
    import torch
    import torch.nn.functional as F
    from tqdm.auto import tqdm
    
    def train_controlnet():
        """
        Train the ControlNet against the frozen Stable Diffusion components.

        Reads ``accelerator``, ``controlnet``, ``optimizer``, ``lr_scheduler``,
        ``train_dataloader``, ``vae``, ``text_encoder``, ``unet``, ``tokenizer``,
        ``noise_scheduler`` and ``config`` from the enclosing scope. Multi-process
        handling, gradient accumulation and mixed precision go through
        ``accelerator``.

        After each epoch the main process prints the mean training loss, saves
        the ControlNet to ``<output_dir>/epoch-<n>`` and deletes the previous
        epoch's directory. The final weights go to ``<output_dir>/controlnet_final``.

        Returns:
            Path of the ``controlnet_final`` directory.
        """
        import shutil  # only needed for deleting old epoch checkpoints

        os.makedirs(config.output_dir, exist_ok=True)
        
        device = accelerator.device
        weight_dtype = torch.float32
        if accelerator.mixed_precision == "fp16":
            weight_dtype = torch.float16
        elif accelerator.mixed_precision == "bf16":
            weight_dtype = torch.bfloat16
            
        # Cast frozen models to weight_dtype for memory efficiency. The
        # trainable ControlNet keeps its own dtype.
        vae.to(device, dtype=weight_dtype)
        text_encoder.to(device, dtype=weight_dtype)
        unet.to(device, dtype=weight_dtype)
    
        num_update_steps_per_epoch = math.ceil(len(train_dataloader) / config.gradient_accumulation_steps)
        max_train_steps = config.num_epochs * num_update_steps_per_epoch
    
        progress_bar = tqdm(
            range(max_train_steps), 
            desc="Steps", 
            disable=not accelerator.is_local_main_process
        )
        
        global_step = 0
        
        for epoch in range(config.num_epochs):
            controlnet.train()
            train_loss = 0.0        # loss accumulated over the current update window
            epoch_loss_sum = 0.0    # sum of per-micro-batch losses this epoch
            epoch_loss_count = 0    # number of micro-batches this epoch
            
            for step, batch in enumerate(train_dataloader):
                with accelerator.accumulate(controlnet):
                    pixel_values = batch["images"].to(device, dtype=torch.float32)
                    controlnet_image = batch["poses"].to(device, dtype=torch.float32)
                    captions = batch["captions"]
                    
                    with accelerator.autocast():
                        # VAE encode
                        with torch.no_grad():
                            latents = vae.encode(pixel_values.to(dtype=weight_dtype)).latent_dist.sample()
                            latents = latents * vae.config.scaling_factor
    
                        # Add noise at a random timestep per sample
                        noise = torch.randn_like(latents)
                        bsz = latents.shape[0]
                        timesteps = torch.randint(
                            0, noise_scheduler.config.num_train_timesteps, (bsz,), device=device
                        ).long()
    
                        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    
                        # Text embeddings
                        with torch.no_grad():
                            inputs = tokenizer(
                                captions, 
                                max_length=tokenizer.model_max_length, 
                                padding="max_length", 
                                truncation=True, 
                                return_tensors="pt"
                            )
                            encoder_hidden_states = text_encoder(inputs.input_ids.to(device))[0]
    
                        # ControlNet forward
                        down_block_res_samples, mid_block_res_sample = controlnet(
                            noisy_latents,
                            timesteps,
                            encoder_hidden_states=encoder_hidden_states,
                            controlnet_cond=controlnet_image,
                            return_dict=False,
                        )
    
                        # UNet forward with the ControlNet residuals injected
                        model_pred = unet(
                            noisy_latents,
                            timesteps,
                            encoder_hidden_states=encoder_hidden_states,
                            down_block_additional_residuals=down_block_res_samples,
                            mid_block_additional_residual=mid_block_res_sample,
                        ).sample
    
                        # Loss: the UNet output is compared with the added noise
                        loss = F.mse_loss(model_pred.float(), noise.float(), reduction="mean")
    
                    # Average the loss across processes for reporting
                    avg_loss = accelerator.gather(loss.repeat(config.train_batch_size)).mean()
                    step_loss = avg_loss.item()
                    train_loss += step_loss / config.gradient_accumulation_steps
                    # Count every micro-batch so the epoch summary is a true mean,
                    # not just the last micro-batch of each accumulation window.
                    epoch_loss_sum += step_loss
                    epoch_loss_count += 1
                    
                    # Backprop
                    accelerator.backward(loss)
                    
                    if accelerator.sync_gradients:
                        accelerator.clip_grad_norm_(controlnet.parameters(), 1.0)
                        optimizer.step()
                        lr_scheduler.step()
                        optimizer.zero_grad()
                        
                        progress_bar.update(1)
                        global_step += 1
                        
                        logs = {"loss": train_loss, "lr": lr_scheduler.get_last_lr()[0]}
                        progress_bar.set_postfix(**logs)
                        accelerator.log(logs, step=global_step)
                        train_loss = 0.0
    
                        # Optional step checkpoints; active only if the config
                        # defines `checkpointing_steps`.
                        if hasattr(config, 'checkpointing_steps') and global_step % config.checkpointing_steps == 0:
                             save_path = os.path.join(config.output_dir, f"checkpoint-{global_step}")
                             accelerator.save_state(save_path)
    
            # End of Epoch
            accelerator.wait_for_everyone()
            if accelerator.is_main_process:
                # Print epoch-level training loss summary
                if epoch_loss_count > 0:
                    epoch_loss_avg = epoch_loss_sum / epoch_loss_count
                    accelerator.print(f"Epoch {epoch+1}/{config.num_epochs} - Training Loss: {epoch_loss_avg:.6f}")
                    accelerator.log({"train_loss_epoch": epoch_loss_avg}, step=global_step)
    
                controlnet_unwrapped = accelerator.unwrap_model(controlnet)
            
                save_path = os.path.join(config.output_dir, f"epoch-{epoch+1}")
                
                controlnet_unwrapped.save_pretrained(
                    save_path,
                    save_function=accelerator.save
                )

                accelerator.print(f"Epoch {epoch+1} Saved: {save_path}")
                
                # Auto-cleanup: Delete previous epoch checkpoint
                if epoch > 0:
                    prev_epoch_path = os.path.join(config.output_dir, f"epoch-{epoch}")
                    if os.path.exists(prev_epoch_path):
                        shutil.rmtree(prev_epoch_path)
                        accelerator.print(f"Cleaned up previous checkpoint: epoch-{epoch}")
                
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
    
        # Final Save
        final_path = os.path.join(config.output_dir, "controlnet_final")
        accelerator.wait_for_everyone()
        if accelerator.is_main_process:
            # unwrap_model removes whatever wrapper accelerate added. The old
            # DataParallel check missed the distributed wrapper, leaving an
            # object without save_pretrained.
            controlnet_unwrapped = accelerator.unwrap_model(controlnet)
            controlnet_unwrapped.save_pretrained(final_path)
            accelerator.print(f"Training Complete. Final model saved to {final_path}")

        # Close any experiment trackers so buffered logs are written out.
        accelerator.end_training()
        
        return final_path
    
    # Start training: print a wall-clock timestamp before and after the run.
    # train_controlnet() returns the path of the final saved ControlNet.
    print("Starting training...")
    print(f"Start time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        
    trained_controlnet = train_controlnet()
        
    print(f"\nEnd time: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")

if __name__ == "__main__":
    main()

Writing train_human_pose.py


In [3]:
!accelerate launch --multi_gpu --num_processes=2 --mixed-precision=fp16 train_human_pose.py