# In [None]:
import torch
import json
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from tqdm import tqdm
import os
import torchvision.transforms as transforms

from diffusers import (
    StableDiffusionControlNetPipeline,
    ControlNetModel
)

# CONFIG
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Half precision is only practical on GPU; many CPU kernels lack float16
# support, so load the weights in float32 when running on CPU.
DTYPE = torch.float16 if DEVICE == "cuda" else torch.float32

# Fine-tuned pose ControlNet weights (Kaggle model input).
CONTROLNET_PATH = "/kaggle/input/controlnet-humna-pose/pytorch/default/1"
# NOTE(review): the `runwayml/stable-diffusion-v1-5` Hub repo has reportedly
# been taken down; if loading fails, `stable-diffusion-v1-5/stable-diffusion-v1-5`
# is believed to host the same weights -- confirm before switching.
BASE_MODEL = "runwayml/stable-diffusion-v1-5"

VAL_CAPTIONS_FILE = "/kaggle/input/captions/val_captions.json"
COCO_ROOT = "/kaggle/working/coco_data"

# Upper bound on how many validation samples the generation loop processes.
NUM_VALIDATION_IMAGES = 75
IMAGE_SIZE = 512

print("Validation config loaded")

OUTPUT_DIR = "/kaggle/working/validation_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

print(f"Saving generated images to: {OUTPUT_DIR}")

# LOAD CONTROLNET PIPELINE
print("Loading ControlNet pipeline...")

controlnet = ControlNetModel.from_pretrained(
    CONTROLNET_PATH,
    torch_dtype=DTYPE,
)

pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL,
    controlnet=controlnet,
    torch_dtype=DTYPE,
    safety_checker=None,  # no safety checker is loaded for validation runs
)

pipe = pipe.to(DEVICE)
# Optional: when xformers is installed, calling
# pipe.enable_xformers_memory_efficient_attention() here reduces GPU memory use.

print("Pipeline ready")

from pycocotools.coco import COCO
import cv2
import os

import torch
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
import torchvision.transforms as transforms
import os
from pycocotools.coco import COCO
import requests
from tqdm import tqdm
import json
import cv2
import zipfile

class COCOPoseDataset(Dataset):
    """COCO 2017 images paired with pose-skeleton maps and custom captions.

    The sample list comes from a required JSON caption file that maps image
    filenames (e.g. ``"000000397133.jpg"``) to caption strings. Images are
    downloaded from COCO on first access. The pose map for each image is a
    stick figure drawn from the first annotated person with at least one
    labelled keypoint in the COCO ``person_keypoints`` annotations.
    """

    def __init__(self, root_dir='/kaggle/working/coco_data', split='train', transform=None, image_size=512,
                 custom_captions_file=None, max_samples=None, download=True):
        """Index the caption file and load the COCO keypoint annotations.

        Args:
            root_dir: Directory holding ``annotations/`` and ``{split}2017/``.
            split: ``'train'`` or ``'val'`` (anything other than ``'train'``
                reads the val annotation file).
            transform: Optional callable applied to the RGB PIL image. When
                falsy, the image is resized to ``image_size`` squared, converted
                to a tensor and normalized to [-1, 1].
            image_size: Side length of the square output image and pose map.
            custom_captions_file: Path to the JSON caption file. Required.
            max_samples: If truthy, keep only the first ``max_samples`` entries
                of the caption file.
            download: Download the annotation archive when the keypoint
                annotation file is missing.

        Raises:
            ValueError: if ``custom_captions_file`` is not given.
            FileNotFoundError: if the caption or annotation file is missing.
        """
        self.root_dir = root_dir
        self.split = split
        self.transform = transform
        self.image_size = image_size
        self.custom_captions_file = custom_captions_file

        # COCO 2017 URLs (train and val keypoints ship in the same archive)
        self.annotation_urls = {
            'train': 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip',
            'val': 'http://images.cocodataset.org/annotations/annotations_trainval2017.zip'
        }

        # Setup paths
        self.ann_dir = os.path.join(root_dir, 'annotations')
        self.img_dir = os.path.join(root_dir, f'{split}2017')

        split_name = 'train' if split == 'train' else 'val'
        self.ann_file = os.path.join(self.ann_dir, f'person_keypoints_{split_name}2017.json')

        os.makedirs(self.ann_dir, exist_ok=True)
        os.makedirs(self.img_dir, exist_ok=True)

        self.custom_captions = None
        self.custom_caption_map = {}
        self.img_ids = []
        # Filenames kept parallel to img_ids so __getitem__ indexes in O(1)
        # instead of rebuilding the caption-key list on every call.
        self.img_filenames = []

        if not custom_captions_file:
            raise ValueError(f"custom_captions_file is REQUIRED! Please provide a JSON file with captions.")

        # Check if caption file exists (try the given path, then its absolute form)
        caption_path = custom_captions_file
        if not os.path.exists(caption_path):
            caption_path = os.path.abspath(custom_captions_file)

        if not os.path.exists(caption_path):
            raise FileNotFoundError(
                f"Caption file not found: {custom_captions_file}\n"
                f"Tried paths:\n"
                f"  - {custom_captions_file}\n"
                f"  - {os.path.abspath(custom_captions_file)}\n"
                f"Current working directory: {os.getcwd()}\n"
                f"Please ensure the caption file exists with format: {{'image_filename.jpg': 'caption text', ...}}"
            )

        print(f"Loading custom captions from {caption_path}...")
        # Explicit UTF-8: captions may contain non-ASCII text and the default
        # text encoding is platform-dependent.
        with open(caption_path, 'r', encoding='utf-8') as f:
            self.custom_captions = json.load(f)
        print(f"Loaded {len(self.custom_captions)} custom captions")

        # Download COCO annotations if needed
        if download and not os.path.exists(self.ann_file):
            print(f"Annotation file not found. Downloading COCO 2017 annotations...")
            self._download_annotations()

        if not os.path.exists(self.ann_file):
            raise FileNotFoundError(
                f"Annotation file not found: {self.ann_file}\n"
                f"Please download COCO 2017 annotations from:\n"
                f"http://images.cocodataset.org/annotations/annotations_trainval2017.zip\n"
                f"Extract to: {self.ann_dir}"
            )

        print(f"Loading COCO {split} pose annotations...")
        self.coco = COCO(self.ann_file)
        # Loop-invariant lookups/objects, built once instead of per item.
        self.person_cat_ids = self.coco.getCatIds(catNms=['person'])
        self._default_transform = transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5])
        ])
        self._pose_transform = transforms.Compose([
            transforms.ToPILImage(),
            transforms.Resize((self.image_size, self.image_size)),
        ])

        # Map images from caption file
        print("Setting up dataset with images, poses from COCO, and custom captions...")
        items_to_process = list(self.custom_captions.items())
        if max_samples:
            items_to_process = items_to_process[:max_samples]

        for img_filename, caption in items_to_process:
            img_id = self._image_id_from_filename(img_filename)
            self.img_ids.append(img_id)
            self.img_filenames.append(img_filename)
            self.custom_caption_map[img_id] = caption

        limit_msg = f" (limited to first {max_samples})" if max_samples else ""
        print(f"Dataset ready with {len(self.img_ids)} images")
        print(f"  - Images from: {self.img_dir}")
        print(f"  - Captions from: {custom_captions_file}{limit_msg}")
        print(f"  - Poses from: COCO person_keypoints annotations")
        print(f"Using ACTUAL POSE SKELETONS (stick figures) as conditioning!\n")

        if len(self.img_ids) == 0:
            print("\nERROR: No images found in caption file!")

    @staticmethod
    def _image_id_from_filename(img_filename):
        """Return the integer COCO image id encoded in a name like ``000000397133.jpg``."""
        stem = os.path.basename(img_filename).split('.')[0]
        # int() accepts leading zeros in strings, so no manual stripping is needed.
        return int(stem)

    def _download_annotations(self):
        """Download the COCO 2017 annotation archive and extract it into ``root_dir``.

        Raises:
            requests.HTTPError: if the server returns an error status.
        """
        url = self.annotation_urls[self.split]
        zip_path = os.path.join(self.root_dir, 'annotations.zip')

        print(f"Downloading from {url}...")
        try:
            # Context-managed streaming response so the connection is always
            # released; raise_for_status stops an HTML error page from being
            # saved as the "zip".
            with requests.get(url, stream=True, timeout=60) as response:
                response.raise_for_status()
                total_size = int(response.headers.get('content-length', 0))
                with open(zip_path, 'wb') as f, tqdm(total=total_size, unit='B', unit_scale=True) as pbar:
                    for chunk in response.iter_content(chunk_size=8192):
                        f.write(chunk)
                        pbar.update(len(chunk))

            print("Extracting annotations...")
            with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                zip_ref.extractall(self.root_dir)
        finally:
            # Remove the archive whether extraction succeeded or a partial
            # download/corrupt zip raised.
            if os.path.exists(zip_path):
                os.remove(zip_path)

        print("Annotations downloaded successfully!")

    def _download_image(self, img_id, img_filename):
        """Download one COCO image into ``img_dir`` unless it is already present.

        Returns:
            The local path of the image file.

        Raises:
            RuntimeError: if ``img_id`` is unknown to the annotations or the
                download fails.
        """
        img_path = os.path.join(self.img_dir, img_filename)

        # If image already exists, skip download
        if os.path.exists(img_path):
            return img_path

        try:
            img_url = self.coco.loadImgs(img_id)[0]['coco_url']
        except (KeyError, IndexError) as e:
            raise RuntimeError(f"Image id {img_id} ({img_filename}) not found in COCO annotations") from e

        # Write to a temporary name and rename only after a complete write, so
        # an interrupted save never leaves a truncated file at img_path.
        tmp_path = img_path + '.part'
        try:
            response = requests.get(img_url, timeout=10)
            response.raise_for_status()
            with open(tmp_path, 'wb') as f:
                f.write(response.content)
            os.replace(tmp_path, img_path)
        except (requests.RequestException, OSError) as e:
            if os.path.exists(tmp_path):
                os.remove(tmp_path)
            raise RuntimeError(f"Failed to download image {img_filename} from {img_url}: {str(e)}") from e

        return img_path

    def _load_image(self, img_id, img_filename, max_retries=3):
        """Return the RGB PIL image for ``img_filename``, or ``None`` if unusable.

        The file is downloaded on demand. A file that fails to decode is deleted
        and fetched again, for at most ``max_retries`` attempts in total.
        """
        for attempt in range(1, max_retries + 1):
            try:
                # No-op when the file is already on disk.
                img_path = self._download_image(img_id, img_filename)
            except RuntimeError as download_err:
                print(f"   Download failed: {str(download_err)[:50]}")
                continue

            try:
                with Image.open(img_path) as raw:
                    return raw.convert('RGB')
            except Exception as e:  # deliberate best-effort: PIL raises several types on bad files
                print(f"Failed to load image {img_filename} (attempt {attempt}/{max_retries}): {str(e)[:50]}")
                # Delete the corrupted file so the next attempt re-downloads it.
                if os.path.exists(img_path):
                    os.remove(img_path)
                    print(f"   Deleted corrupted file: {img_filename}")
        return None

    def _fallback_size(self, img_id):
        """Return ``(width, height)`` of ``img_id`` from COCO metadata, else ``image_size`` squared."""
        try:
            info = self.coco.loadImgs(img_id)[0]
            return info['width'], info['height']
        except (KeyError, IndexError):
            return self.image_size, self.image_size

    def _build_sample(self, img_id, image):
        """Assemble the sample dict for ``img_id`` from an already-loaded RGB PIL image."""
        width, height = image.size
        caption = self.custom_caption_map[img_id]

        ann_ids = self.coco.getAnnIds(imgIds=img_id, catIds=self.person_cat_ids, iscrowd=False)
        anns = self.coco.loadAnns(ann_ids)

        # Pose keypoints from the first person with any labelled keypoint;
        # an all-zero (17, 3) array yields an empty skeleton otherwise.
        keypoints = next(
            (np.array(ann['keypoints']).reshape(-1, 3)
             for ann in anns
             if 'keypoints' in ann and ann.get('num_keypoints', 0) > 0),
            None,
        )
        if keypoints is None:
            keypoints = np.zeros((17, 3))

        pose_skeleton = self.create_pose_skeleton(keypoints, width, height)

        transform = self.transform if self.transform else self._default_transform
        image_tensor = transform(image)
        pose_map = self._pose_transform(pose_skeleton)

        return {
            'image': image_tensor,
            'pose': pose_map,
            'raw_keypoints': keypoints,
            'image_id': img_id,
            'captions': [caption]
        }

    def __len__(self):
        return len(self.img_ids)

    def __getitem__(self, idx):
        """Return the sample at ``idx``.

        If that image cannot be loaded, the following indices are tried in
        order (wrapping around) and the first loadable one is returned. If no
        image in the dataset loads, a black image of the original COCO size is
        used with ``idx``'s caption and pose.

        Raises:
            IndexError: for an out-of-range ``idx``, which sequence iteration
                relies on to stop.
        """
        img_id = self.img_ids[idx]  # raises IndexError when idx is out of range
        num_items = len(self.img_ids)
        start = idx % num_items

        # Bounded iteration: each index is tried at most once, so a run of
        # broken images can no longer recurse without limit.
        for offset in range(num_items):
            cand = (start + offset) % num_items
            cand_id = self.img_ids[cand]
            cand_name = self.img_filenames[cand]
            image = self._load_image(cand_id, cand_name)
            if image is not None:
                return self._build_sample(cand_id, image)
            print(f"Giving up on {cand_name}, returning next valid image instead...")

        # Nothing loaded: black placeholder sized like the original image so the
        # keypoints (in original pixel coordinates) land in the right place.
        width, height = self._fallback_size(img_id)
        image = Image.new('RGB', (width, height), color='black')
        return self._build_sample(img_id, image)

    def create_pose_skeleton(self, keypoints, width, height):
        """
        Draw a COCO keypoint stick figure on a black single-channel canvas.

        Bones join keypoint pairs whose visibility flags are both > 0; every
        keypoint with visibility > 0 gets a filled circle drawn on top.

        Args:
            keypoints: (N, 3) rows of ``(x, y, visibility)`` in pixel coords.
            width, height: Canvas size in pixels.

        Returns:
            ``uint8`` array of shape ``(height, width)``; 255 where drawn.

        COCO Keypoints (17 total):
        0: nose, 1: L eye, 2: R eye, 3: L ear, 4: R ear,
        5: L shoulder, 6: R shoulder, 7: L elbow, 8: R elbow, 9: L wrist, 10: R wrist,
        11: L hip, 12: R hip, 13: L knee, 14: R knee, 15: L ankle, 16: R ankle
        """
        pose_img = np.zeros((height, width), dtype=np.uint8)

        # COCO keypoint skeleton connections (bones linking joints)
        skeleton = [
            (0, 1), (0, 2),           # nose to eyes
            (1, 3), (2, 4),           # eyes to ears
            (0, 5), (0, 6),           # nose to shoulders
            (5, 7), (7, 9),           # left arm (shoulder -> elbow -> wrist)
            (6, 8), (8, 10),          # right arm (shoulder -> elbow -> wrist)
            (5, 11), (6, 12),         # shoulders to hips
            (11, 12),                 # hip to hip
            (11, 13), (13, 15),       # left leg (hip -> knee -> ankle)
            (12, 14), (14, 16)        # right leg (hip -> knee -> ankle)
        ]

        # Line thickness and circle radius scale with image size
        line_thickness = max(2, int(min(width, height) / 100))
        circle_radius = max(3, int(min(width, height) / 80))

        # Bones first, so joint circles are drawn on top of them
        for start_idx, end_idx in skeleton:
            if start_idx < len(keypoints) and end_idx < len(keypoints):
                x1, y1, v1 = keypoints[start_idx]
                x2, y2, v2 = keypoints[end_idx]

                if v1 > 0 and v2 > 0:
                    cv2.line(pose_img, (int(x1), int(y1)), (int(x2), int(y2)),
                             255, line_thickness, cv2.LINE_AA)

        for x, y, v in keypoints:
            if v > 0:
                cv2.circle(pose_img, (int(x), int(y)), circle_radius, 255, -1)

        return pose_img
    
# Validation dataset. Paths and size come from the CONFIG constants so they
# cannot drift from the values defined there.
val_dataset = COCOPoseDataset(
    root_dir=COCO_ROOT,
    split='val',
    image_size=IMAGE_SIZE,
    custom_captions_file=VAL_CAPTIONS_FILE,
    max_samples=None,
    download=True,
)

print(f"Validation samples available: {len(val_dataset)}")

import contextlib

# Honour the configured cap instead of running every caption in the split.
num_samples = min(NUM_VALIDATION_IMAGES, len(val_dataset))
# Each sample's generator is seeded with VALIDATION_SEED + image id, so a given
# image's output is reproducible independently of loop order.
VALIDATION_SEED = 0

pipe.set_progress_bar_config(disable=True)

# The dataset's pose maps are single-channel (drawn on a 2-D uint8 canvas);
# ControlNet's conditioning encoder takes `conditioning_channels` inputs
# (3 for standard checkpoints), so the map is replicated to match.
cond_channels = getattr(pipe.controlnet.config, "conditioning_channels", 3)

for i in tqdm(range(num_samples), desc="Validating"):
    sample = val_dataset[i]

    original = sample["image"]
    caption = sample["captions"][0]
    img_id = sample["image_id"]

    # Undo Normalize([0.5], [0.5]): [-1, 1] -> [0, 1], channels-last for imshow.
    orig_np = np.clip(original.permute(1, 2, 0).cpu().numpy() * 0.5 + 0.5, 0, 1)

    pose_pil = sample["pose"]  # grayscale PIL
    pose_tensor = transforms.ToTensor()(pose_pil).unsqueeze(0)  # [1, 1, H, W]
    if pose_tensor.shape[1] == 1 and cond_channels != 1:
        pose_tensor = pose_tensor.repeat(1, cond_channels, 1, 1)
    pose_tensor = pose_tensor.to(DEVICE)

    generator = torch.Generator(device=DEVICE).manual_seed(VALIDATION_SEED + int(img_id))

    # CUDA autocast only makes sense on GPU; run without it on CPU.
    amp_ctx = torch.autocast("cuda") if DEVICE == "cuda" else contextlib.nullcontext()
    with amp_ctx:
        result = pipe(
            prompt=caption,
            image=pose_tensor,
            num_inference_steps=30,
            guidance_scale=7.5,
            generator=generator,
        )

    generated = result.images[0]

    # SAVE GENERATED IMAGE
    save_path = os.path.join(OUTPUT_DIR, f"{img_id:012d}_gen.png")
    generated.save(save_path)

    # DISPLAY ALL 3
    fig, axs = plt.subplots(1, 3, figsize=(15, 5))
    panels = [
        (orig_np, "Original Image", None),
        (pose_pil, "Pose Skeleton", "gray"),
        (generated, "Generated Image", None),
    ]
    for ax, (img, title, cmap) in zip(axs, panels):
        ax.imshow(img, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")

    fig.suptitle(caption, fontsize=10)
    fig.tight_layout()
    plt.show()
    # Close explicitly so figures don't pile up in pyplot's registry on
    # backends that do not auto-close after show().
    plt.close(fig)
