# ControlNet Inference on Training Data

Run the trained ControlNet on samples from its own training data to check that it learned to follow the pose conditioning.

**Files used:**
- `config.json` - ControlNet model configuration
- `diffusion_pytorch_model.safetensors` - Trained ControlNet weights
- `train_caption.json` - Training captions, keyed by image filename

In [None]:
# Import Required Libraries
import torch
import torch.nn.functional as F
import numpy as np
import json
import os
from pathlib import Path
from PIL import Image
import matplotlib.pyplot as plt
from tqdm import tqdm
import torchvision.transforms as transforms
import cv2

# Diffusers and transformers
from diffusers import StableDiffusionControlNetPipeline, ControlNetModel, DDPMScheduler
from transformers import CLIPTextModel, CLIPTokenizer
from safetensors.torch import load_file

print("✓ All libraries imported successfully")

In [None]:
# Mount Google Drive and Extract Weights-ControlNet Folder
from google.colab import drive
from pathlib import Path
import zipfile
import shutil

print("Mounting Google Drive...")
drive.mount('/content/drive')

# Path to the zip file
ZIP_PATH = Path('/content/drive/My Drive/Weights-ControlNet.zip')

# Extract location
EXTRACT_PATH = Path('/content/weights_extracted')
WEIGHTS_FOLDER = EXTRACT_PATH / "Weights-ControlNet"

# Files expected in the extracted folder
CONFIG_PATH = WEIGHTS_FOLDER / "config.json"
WEIGHTS_PATH = WEIGHTS_FOLDER / "diffusion_pytorch_model.safetensors"
CAPTIONS_PATH = WEIGHTS_FOLDER / "train_caption.json"
REQUIRED_PATHS = (CONFIG_PATH, WEIGHTS_PATH, CAPTIONS_PATH)

if all(p.exists() for p in REQUIRED_PATHS):
    # Re-running this cell should not unzip the archive again.
    # Delete EXTRACT_PATH to force a fresh extraction (e.g. after re-uploading).
    print(f"✓ Files already extracted in {WEIGHTS_FOLDER}; skipping extraction")
elif ZIP_PATH.exists():
    print(f"✓ Found zip file: {ZIP_PATH}")
    print(f"  File size: {ZIP_PATH.stat().st_size / (1024**3):.2f} GB")

    # Extract the zip file
    print("\nExtracting Weights-ControlNet.zip...")
    with zipfile.ZipFile(ZIP_PATH, 'r') as zip_ref:
        zip_ref.extractall(EXTRACT_PATH)
    print("✓ Extraction complete!")
else:
    print(f"❌ Zip file not found at: {ZIP_PATH}")
    print("   Make sure 'Weights-ControlNet.zip' exists in your Google Drive root")

# Create COCO data directory in Colab runtime
BASE_DIR = Path('/content')
COCO_DATA_DIR = BASE_DIR / "coco_data"

# Device: use fp16 on GPU to halve memory, fp32 on CPU (fp16 is poorly supported there)
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
DTYPE = torch.float16 if torch.cuda.is_available() else torch.float32

print(f"\n✓ Device: {DEVICE}")
print(f"✓ Data type: {DTYPE}")
print(f"✓ Extracted folder: {WEIGHTS_FOLDER}")
print(f"\nFile paths:")
print(f"  Config: {CONFIG_PATH} (exists: {CONFIG_PATH.exists()})")
print(f"  Weights: {WEIGHTS_PATH} (exists: {WEIGHTS_PATH.exists()})")
print(f"  Captions: {CAPTIONS_PATH} (exists: {CAPTIONS_PATH.exists()})")

# Verify all files exist
if not all(p.exists() for p in REQUIRED_PATHS):
    print("\n⚠️ WARNING: Not all required files found!")
    print(f"Expected files in: {WEIGHTS_FOLDER}")

    # List what's in the extracted folder to help diagnose a wrong zip layout
    print(f"\nContents of {WEIGHTS_FOLDER}:")
    if WEIGHTS_FOLDER.exists():
        try:
            for item in WEIGHTS_FOLDER.iterdir():
                size_info = ""
                if item.is_file():
                    size_info = f" ({item.stat().st_size / (1024**2):.2f} MB)"
                print(f"  - {item.name}{size_info}")
        except OSError as e:
            print(f"  Error listing contents: {e}")
    else:
        print(f"  Folder not found!")
        print(f"\nContents of {EXTRACT_PATH}:")
        try:
            for item in EXTRACT_PATH.iterdir():
                print(f"  - {item.name}")
        except OSError as e:
            # EXTRACT_PATH itself is missing when nothing was ever extracted
            print(f"  Error listing contents: {e}")
else:
    print("\n✓ All required files found! Ready to proceed.")

In [None]:
# Load Configuration File
# Reads the ControlNet config.json into `config` and echoes the fields that
# matter for inference; missing keys are shown as 'N/A'.
print("Loading configuration...")
config = json.loads(CONFIG_PATH.read_text())

print(f"✓ Config loaded")
for label, key in (
    ("Conditioning channels", "conditioning_channels"),
    ("Cross attention dim", "cross_attention_dim"),
    ("Block out channels", "block_out_channels"),
    ("Diffusers version", "_diffusers_version"),
):
    print(f"  {label}: {config.get(key, 'N/A')}")

In [None]:
# Load Training Captions
# `train_captions` maps each training image filename to its caption string.
print("Loading training captions...")
with CAPTIONS_PATH.open('r') as caption_file:
    train_captions = json.load(caption_file)

print(f"✓ Loaded {len(train_captions)} training captions")

# Show sample captions (first three entries, truncated to 150 characters)
print("\nSample captions:")
preview = list(train_captions.items())[:3]
for rank, (img_name, caption) in enumerate(preview, start=1):
    print(f"\n{rank}. Image: {img_name}")
    print(f"   Caption (first 150 chars): {caption[:150]}...")

In [None]:
# Load Trained ControlNet Weights
# Builds a ControlNetModel from `config`, loads the trained safetensors weights
# into it (strictly: every key must match), and moves it to DEVICE/DTYPE.
print("Loading trained ControlNet model...")

# Verify file size
weights_size_mb = WEIGHTS_PATH.stat().st_size / (1024**2)
print(f"Weights file size: {weights_size_mb:.2f} MB")

if weights_size_mb < 100:
    print("\n⚠️ WARNING: File size is suspiciously small!")
    print("   The safetensors file may be corrupted or incomplete during upload.")
    print("   Try uploading again or using a different method.")

# Try to load weights from safetensors
try:
    state_dict = load_file(str(WEIGHTS_PATH))
    print(f"✓ Loaded {len(state_dict)} weight tensors")
    print(f"  Sample keys: {list(state_dict.keys())[:5]}")
except Exception as e:
    print(f"\n❌ Error loading safetensors: {e}")
    print("\nTroubleshooting:")
    print("1. Re-upload the diffusion_pytorch_model.safetensors file")
    print("2. Make sure the file uploaded completely (check file size)")
    print("3. Ensure the file is not corrupted")
    print("\nFor now, let's verify the file integrity...")

    # Check file header (a text/HTML prefix means the file is not safetensors)
    try:
        with open(WEIGHTS_PATH, 'rb') as f:
            header = f.read(100)
            print(f"File starts with: {header[:50]}")
    except OSError:
        print("Cannot read file - it may be completely corrupted")

    # Re-raise the original load error so the notebook stops here
    raise

# Load ControlNet from config
controlnet = ControlNetModel.from_config(config)
print("✓ ControlNet model created from config")

# Load the trained weights into ControlNet (strict=True by default: a key
# mismatch raises instead of silently leaving layers untrained)
controlnet.load_state_dict(state_dict)
print("✓ Trained weights loaded into ControlNet")

# load_state_dict copied the tensors into the model's parameters; release
# the separate CPU copy so it does not occupy host RAM for the whole session.
del state_dict

# Move to device
controlnet = controlnet.to(DEVICE, dtype=DTYPE)
controlnet.eval()
print(f"✓ ControlNet moved to {DEVICE}")

In [None]:
# Initialize Diffusion Pipeline
print("Initializing Stable Diffusion + ControlNet pipeline...")

# Base Stable Diffusion v1.5 checkpoint. The former "runwayml/stable-diffusion-v1-5"
# Hub repository was taken down; this id is the mirror the diffusers docs now use.
# Replace with a local path if you have the base weights cached.
BASE_MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Use the trained ControlNet with Stable Diffusion v1.5
pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL_ID,
    controlnet=controlnet,
    torch_dtype=DTYPE,
    safety_checker=None,  # Disable safety checker for faster inference
)
pipe = pipe.to(DEVICE)

print("✓ Pipeline initialized")

# Enable memory-efficient attention if xformers is available
try:
    pipe.enable_xformers_memory_efficient_attention()
    print("✓ XFormers memory efficient attention enabled")
except Exception:
    # e.g. xformers not installed or unsupported on this GPU. A bare `except:`
    # would also swallow KeyboardInterrupt/SystemExit.
    print("⚠️  XFormers not available (this is okay, just uses more memory)")

print("\n✓ Ready for inference!")

In [None]:
# Helper function to create pose skeleton from COCO keypoints
def create_pose_skeleton(keypoints, width, height):
    """Rasterize person keypoints into a single-channel skeleton image.

    Args:
        keypoints: Sequence of ``(x, y, v)`` triples indexed by joint (the
            caller builds this via ``reshape(-1, 3)`` on a COCO annotation).
            A joint is drawn only when ``v > 0``; a bone only when both of
            its joints are drawn and both indices exist in ``keypoints``.
        width: Output image width in pixels.
        height: Output image height in pixels.

    Returns:
        ``np.ndarray`` of shape ``(height, width)``, dtype ``uint8``, with
        bones and joints drawn as 255 on a 0 background.
    """
    canvas = np.zeros((height, width), dtype=np.uint8)

    # Limb connections as 0-based joint index pairs (COCO joint ordering)
    bones = (
        (0, 1), (0, 2), (1, 3), (2, 4), (0, 5), (0, 6),
        (5, 7), (7, 9), (6, 8), (8, 10), (5, 11), (6, 12),
        (11, 12), (11, 13), (13, 15), (12, 14), (14, 16),
    )

    # Stroke sizes scale with the shorter image side, with fixed minimums
    short_side = min(width, height)
    thickness = max(2, int(short_side / 100))
    radius = max(3, int(short_side / 80))
    n_joints = len(keypoints)

    # Bones first, so joint dots are painted on top
    for a, b in bones:
        if a >= n_joints or b >= n_joints:
            continue
        xa, ya, va = keypoints[a]
        xb, yb, vb = keypoints[b]
        if not (va > 0 and vb > 0):
            continue
        cv2.line(canvas, (int(xa), int(ya)), (int(xb), int(yb)),
                 255, thickness, cv2.LINE_AA)

    centers = ((int(x), int(y)) for x, y, v in keypoints if v > 0)
    for center in centers:
        cv2.circle(canvas, center, radius, 255, -1)

    return canvas

print("✓ Helper function loaded")

In [None]:
# Download and Load COCO dataset
from pycocotools.coco import COCO
import requests
from tqdm import tqdm
import zipfile

print("Setting up COCO dataset...")

# Create COCO data directory if it doesn't exist
COCO_DATA_DIR.mkdir(parents=True, exist_ok=True)
ann_dir = COCO_DATA_DIR / "annotations"
train_img_dir = COCO_DATA_DIR / "train2017"

ann_dir.mkdir(parents=True, exist_ok=True)
train_img_dir.mkdir(parents=True, exist_ok=True)

# Download COCO annotations if not present
ann_file = ann_dir / "person_keypoints_train2017.json"

if not ann_file.exists():
    print("Downloading COCO 2017 pose annotations (~252MB)...")
    ann_url = "http://images.cocodataset.org/annotations/annotations_trainval2017.zip"
    zip_path = COCO_DATA_DIR / "annotations.zip"

    # Use a (connect, read) timeout and fail loudly on HTTP errors; otherwise an
    # error page is saved as the "zip" and only surfaces later as BadZipFile.
    # The `with` closes the streamed connection.
    with requests.get(ann_url, stream=True, timeout=(10, 60)) as response:
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))

        with open(zip_path, 'wb') as f:
            with tqdm(total=total_size, unit='B', unit_scale=True, desc="Downloading annotations") as pbar:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
                    pbar.update(len(chunk))

    # Only the person-keypoints file is used; skip the other annotation sets.
    # The archive stores it under "annotations/", matching ann_file's location.
    print("Extracting annotations...")
    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
        zip_ref.extract(f"annotations/{ann_file.name}", COCO_DATA_DIR)

    zip_path.unlink()
    print("✓ Annotations downloaded and extracted")
else:
    print(f"✓ Annotations already exist at {ann_file}")

# Load COCO annotations
print("Loading COCO annotations...")
coco = COCO(str(ann_file))
print(f"✓ COCO annotations loaded")

In [None]:
# Helper function to download COCO image
def download_coco_image(img_info, img_dir):
    """Return the local path of a COCO image, downloading it on first use.

    Args:
        img_info: COCO image record; ``'file_name'`` and ``'coco_url'`` are read.
        img_dir: Cache directory (``pathlib.Path``) the image is stored under.

    Returns:
        ``pathlib.Path`` of the cached image, or ``None`` if the download
        failed (the failure is printed rather than raised, so callers can
        continue without the original image).
    """
    img_path = img_dir / img_info['file_name']

    # An empty file can only be a failed earlier download; fetch it again.
    if img_path.exists() and img_path.stat().st_size > 0:
        return img_path

    tmp_path = img_path.with_name(img_path.name + '.part')
    try:
        response = requests.get(img_info['coco_url'], timeout=10)
        # Without this, a 404/5xx error page would be saved as the image and
        # then served from the cache on every later call.
        response.raise_for_status()
        img_path.parent.mkdir(parents=True, exist_ok=True)
        # Write under a temporary name and rename, so an interrupted write never
        # leaves a truncated file at img_path.
        tmp_path.write_bytes(response.content)
        tmp_path.replace(img_path)
        return img_path
    except (requests.RequestException, OSError) as e:
        print(f"  ⚠️  Failed to download {img_info['file_name']}: {e}")
        return None

# Generate images for training data samples
print("="*80)
print("GENERATING IMAGES FOR TRAINING DATA")
print("="*80)

# Test the first `num_samples` entries of the caption file (not a random draw)
num_samples = 5
sample_keys = list(train_captions.keys())[:num_samples]

# The "person" category ids are the same for every image; look them up once.
person_cat_ids = coco.getCatIds(catNms=['person'])

print(f"\nGenerating images for {len(sample_keys)} training samples...")

for sample_idx, img_filename in enumerate(sample_keys):
    print(f"\n{'─'*80}")
    print(f"Sample {sample_idx + 1}/{len(sample_keys)}: {img_filename}")
    print(f"{'─'*80}")

    # Get caption
    caption = train_captions[img_filename]
    print(f"Caption: {caption[:100]}...")

    # Try to get pose from COCO
    try:
        # Assumes the file name stem is the (zero-padded) COCO image id,
        # e.g. "000000391895.jpg"; int() accepts the leading zeros as-is.
        img_id = int(Path(img_filename).stem)

        # Get image info and pose
        img_info = coco.loadImgs(img_id)[0]
        img_width, img_height = img_info['width'], img_info['height']

        ann_ids = coco.getAnnIds(imgIds=img_id, catIds=person_cat_ids, iscrowd=False)
        anns = coco.loadAnns(ann_ids)

        # Use the first person annotation that has any labeled keypoints
        keypoints = None
        for ann in anns:
            if 'keypoints' in ann and ann.get('num_keypoints', 0) > 0:
                keypoints = np.array(ann['keypoints']).reshape(-1, 3)
                break

        if keypoints is None:
            print("⚠️  No pose keypoints found, skipping...")
            continue

        # Draw the skeleton at source resolution, then resize to the 512x512
        # generation size (aspect ratio is not preserved).
        pose_map = create_pose_skeleton(keypoints, img_width, img_height)
        pose_map_pil = Image.fromarray(pose_map).resize((512, 512))

        # ToTensor on this grayscale image gives (1, 512, 512) in [0, 1];
        # unsqueeze adds the batch dim.
        # NOTE(review): single-channel conditioning — confirm it matches
        # config['conditioning_channels'] used during training.
        pose_tensor = transforms.ToTensor()(pose_map_pil).unsqueeze(0).to(DEVICE, dtype=DTYPE)

        print(f"✓ Pose skeleton created: {pose_tensor.shape}")

        # Generate image with ControlNet (fixed seed for comparable runs)
        print("Generating image...")
        with torch.no_grad():
            generator = torch.Generator(device=DEVICE).manual_seed(42)
            output = pipe(
                prompt=caption,
                image=pose_tensor,
                num_inference_steps=30,
                generator=generator,
                guidance_scale=7.5,
            ).images[0]

        print("✓ Image generated successfully!")

        # Download and load original image
        print("Downloading training image...")
        img_path = download_coco_image(img_info, train_img_dir)

        if img_path:
            # convert() returns an in-memory copy, so the file can be closed
            with Image.open(img_path) as img_file:
                original_img = img_file.convert('RGB')

            # Display results
            fig, axes = plt.subplots(1, 4, figsize=(20, 5))

            # Original image
            axes[0].imshow(original_img)
            axes[0].set_title('Original Image', fontsize=12, fontweight='bold')
            axes[0].axis('off')

            # Pose skeleton
            axes[1].imshow(pose_map, cmap='gray')
            axes[1].set_title('Pose Skeleton', fontsize=12, fontweight='bold')
            axes[1].axis('off')

            # Generated image
            axes[2].imshow(output)
            axes[2].set_title('Generated Image\n(with trained ControlNet)', fontsize=12, fontweight='bold')
            axes[2].axis('off')

            # Overlay (both at source resolution, so they align)
            axes[3].imshow(original_img)
            axes[3].imshow(pose_map, cmap='hot', alpha=0.3)
            axes[3].set_title('Original + Pose', fontsize=12, fontweight='bold')
            axes[3].axis('off')

            plt.suptitle(f'Training Sample: {img_filename}\n\"{caption[:80]}...\"',
                         fontsize=14, fontweight='bold', y=1.00)
            plt.tight_layout()
            plt.show()
        else:
            # Just show pose and generated image
            fig, axes = plt.subplots(1, 2, figsize=(12, 5))

            axes[0].imshow(pose_map, cmap='gray')
            axes[0].set_title('Pose Skeleton', fontsize=12, fontweight='bold')
            axes[0].axis('off')

            axes[1].imshow(output)
            axes[1].set_title('Generated Image', fontsize=12, fontweight='bold')
            axes[1].axis('off')

            plt.suptitle(f'{img_filename}\n\"{caption[:80]}...\"', fontsize=14, fontweight='bold')
            plt.tight_layout()
            plt.show()

    except Exception as e:
        # Notebook boundary: report and move on to the next sample.
        # The type name matters, e.g. a bare KeyError prints only the key.
        print(f"❌ Error processing sample: {type(e).__name__}: {e}")

print(f"\n{'='*80}")
print("✓ Inference complete!")

In [None]:
# Batch inference on multiple training samples
print("="*80)
print("BATCH INFERENCE ON TRAINING DATA")
print("="*80)

# Generate images for a batch of training samples.
# Walks sample_keys[:batch_size * num_batches] in consecutive slices.
batch_size = 3
num_batches = 2

# The "person" category ids are the same for every image; look them up once.
person_cat_ids = coco.getCatIds(catNms=['person'])

# Each entry: filename, caption, original (PIL image resized to 512x512, or None),
# pose (uint8 skeleton at source resolution), generated (PIL image)
results = []

for batch_idx in range(num_batches):
    print(f"\nBatch {batch_idx + 1}/{num_batches}")
    print("─"*80)

    batch_keys = sample_keys[batch_idx*batch_size:(batch_idx+1)*batch_size]
    batch_results = []

    for img_filename in batch_keys:
        caption = train_captions[img_filename]

        try:
            # Get pose from COCO. Assumes the file name stem is the
            # (zero-padded) COCO image id; int() accepts leading zeros.
            img_id = int(Path(img_filename).stem)
            img_info = coco.loadImgs(img_id)[0]
            img_width, img_height = img_info['width'], img_info['height']

            ann_ids = coco.getAnnIds(imgIds=img_id, catIds=person_cat_ids, iscrowd=False)
            anns = coco.loadAnns(ann_ids)

            # Use the first person annotation that has any labeled keypoints
            keypoints = None
            for ann in anns:
                if 'keypoints' in ann and ann.get('num_keypoints', 0) > 0:
                    keypoints = np.array(ann['keypoints']).reshape(-1, 3)
                    break

            if keypoints is None:
                print(f"  ⚠️  Skipped {img_filename}: no pose keypoints")
                continue

            # Create pose skeleton (source resolution), resized copy for conditioning
            pose_map = create_pose_skeleton(keypoints, img_width, img_height)
            pose_map_pil = Image.fromarray(pose_map).resize((512, 512))
            pose_tensor = transforms.ToTensor()(pose_map_pil).unsqueeze(0).to(DEVICE, dtype=DTYPE)

            # Download original training image
            print(f"  Downloading {img_filename}...")
            img_path = download_coco_image(img_info, train_img_dir)
            original_img = None
            if img_path:
                # convert()/resize() return in-memory copies, so the file can be closed
                with Image.open(img_path) as img_file:
                    original_img = img_file.convert('RGB').resize((512, 512))

            # Generate image (one seed per batch)
            with torch.no_grad():
                generator = torch.Generator(device=DEVICE).manual_seed(42 + batch_idx)
                output = pipe(
                    prompt=caption,
                    image=pose_tensor,
                    num_inference_steps=30,
                    generator=generator,
                    guidance_scale=7.5,
                ).images[0]

            batch_results.append({
                'filename': img_filename,
                'caption': caption,
                'original': original_img,
                'pose': pose_map,
                'generated': output
            })

            results.append(batch_results[-1])

        except Exception as e:
            print(f"  ⚠️  Skipped {img_filename}: {type(e).__name__}: {str(e)[:50]}")

print(f"\n✓ Generated {len(results)} images successfully")

In [None]:
# Visualize batch results
# Uses `results` from the batch-inference cell: one figure per generated sample.
if results:
    print("Displaying batch results...\n")

    for idx, result in enumerate(results):
        caption_short = result['caption'][:90]
        title = f"Result {idx+1}: {result['filename']}\n\"{caption_short}...\""
        original = result['original']

        if original is not None:
            # 4-panel comparison: Original, Pose, Generated, Overlay
            fig, axes = plt.subplots(1, 4, figsize=(20, 5))

            # Original image
            axes[0].imshow(original)
            axes[0].set_title('Original Training Image', fontsize=12, fontweight='bold')
            axes[0].axis('off')

            # Pose skeleton
            axes[1].imshow(result['pose'], cmap='gray')
            axes[1].set_title('Pose Skeleton Conditioning', fontsize=12, fontweight='bold')
            axes[1].axis('off')

            # Generated image
            axes[2].imshow(result['generated'])
            axes[2].set_title('ControlNet Generated Image', fontsize=12, fontweight='bold')
            axes[2].axis('off')

            # Overlay: Original + Pose. The stored original may be resized
            # (the batch cell makes it 512x512) while the pose map keeps the
            # source resolution, so stretch the pose over the original's pixel
            # extent; otherwise the two layers are drawn at different scales.
            orig_w, orig_h = original.size
            axes[3].imshow(original)
            axes[3].imshow(result['pose'], cmap='hot', alpha=0.3,
                           extent=(-0.5, orig_w - 0.5, orig_h - 0.5, -0.5))
            axes[3].set_title('Original + Pose Overlay', fontsize=12, fontweight='bold')
            axes[3].axis('off')
        else:
            # 2-panel comparison: Pose and Generated (original unavailable)
            fig, axes = plt.subplots(1, 2, figsize=(12, 5))

            # Pose skeleton
            axes[0].imshow(result['pose'], cmap='gray')
            axes[0].set_title('Pose Skeleton Conditioning', fontsize=12, fontweight='bold')
            axes[0].axis('off')

            # Generated image
            axes[1].imshow(result['generated'])
            axes[1].set_title('ControlNet Generated Image', fontsize=12, fontweight='bold')
            axes[1].axis('off')

        plt.suptitle(title, fontsize=13, fontweight='bold', y=0.98)
        plt.tight_layout()
        plt.show()
else:
    print("⚠️  No results to display")

In [None]:
# Performance metrics and analysis
# Summarizes the batch-inference cell: it attempted the samples in
# sample_keys[:batch_size * num_batches] and stored successes in `results`.
print("="*80)
print("MODEL PERFORMANCE ANALYSIS")
print("="*80)

num_attempted = len(sample_keys[:batch_size * num_batches])
num_succeeded = len(results)
success_rate = 100.0 * num_succeeded / num_attempted if num_attempted else 0.0

print(f"\nTotal training samples tested: {num_attempted}")
print(f"Successful generations: {num_succeeded}")
print(f"Success rate: {success_rate:.1f}%")

print("\nModel Configuration:")
print(f"  - Base model: Stable Diffusion v1.5")
print(f"  - ControlNet conditioning channels: {config.get('conditioning_channels')}")
print(f"  - Cross-attention dimensions: {config.get('cross_attention_dim')}")
print(f"  - Training data captions: {len(train_captions)}")

print("\nInference Settings:")
print(f"  - Device: {DEVICE}")
print(f"  - Dtype: {DTYPE}")
print(f"  - Inference steps: 30")
print(f"  - Guidance scale: 7.5")

print("\n" + "="*80)
print("✓ Inference complete!")
print("="*80)