# Lab 4.1.2: Image Generation - SOLUTIONS

This notebook contains complete solutions for the Style Transfer Pipeline challenge.

---

In [None]:
# Setup
import gc
import time
import json
from pathlib import Path
from typing import List, Dict, Any, Optional
from datetime import datetime

import torch
import numpy as np
from PIL import Image, PngImagePlugin
import matplotlib.pyplot as plt
import cv2

from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

print("Loading ControlNet pipeline...")

# Canny-edge ControlNet weights, loaded in bfloat16 to match the base model below.
controlnet = ControlNetModel.from_pretrained(
    "diffusers/controlnet-canny-sdxl-1.0", torch_dtype=torch.bfloat16
)

# Wrap SDXL base with the ControlNet and move the whole pipeline to the GPU
# in one chained expression.
pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.bfloat16,
    controlnet=controlnet,
).to("cuda")
pipe.enable_vae_slicing()

print("‚úÖ Ready!")

---

## Challenge Solution: Style Transfer Pipeline

In [None]:
def get_canny_edges(
    image: Image.Image,
    low_threshold: int = 100,
    high_threshold: int = 200
) -> Image.Image:
    """Extract a Canny edge map from an image.

    Args:
        image: Source image in any PIL mode; it is converted to RGB first.
        low_threshold: First (lower) threshold passed to ``cv2.Canny``.
        high_threshold: Second (upper) threshold passed to ``cv2.Canny``.

    Returns:
        A PIL image with the same width and height as ``image`` whose three
        channels each hold the single-channel edge mask.
    """
    # Normalise the mode before handing the array to OpenCV: COLOR_RGB2GRAY
    # rejects the 2-D array produced by a mode "L" image, and a palette ("P")
    # image would otherwise expose palette indices instead of colours.
    img_array = np.array(image.convert("RGB"))
    gray = cv2.cvtColor(img_array, cv2.COLOR_RGB2GRAY)
    edges = cv2.Canny(gray, low_threshold, high_threshold)
    # Replicate the mask across three channels so the result is a 3-channel image.
    edges_rgb = np.stack([edges] * 3, axis=-1)
    return Image.fromarray(edges_rgb)


def save_with_metadata(
    image: Image.Image,
    filepath: Path,
    prompt: str,
    seed: int,
    style_name: str,
) -> None:
    """Write ``image`` to ``filepath`` with generation details as PNG text chunks.

    Args:
        image: Image to save.
        filepath: Destination path handed to ``image.save``.
        prompt: Prompt text stored under the ``prompt`` key.
        seed: Seed stored (as a string) under the ``seed`` key.
        style_name: Style label stored under the ``style`` key.
    """
    # Insertion order of this dict is the order the chunks are added in.
    text_chunks = {
        "prompt": prompt,
        "seed": str(seed),
        "style": style_name,
        "timestamp": datetime.now().isoformat(),
        "model": "SDXL ControlNet",
    }

    png_info = PngImagePlugin.PngInfo()
    for key, value in text_chunks.items():
        png_info.add_text(key, value)

    image.save(filepath, pnginfo=png_info)


def create_comparison_grid(
    images: List[Image.Image],
    labels: List[str],
    cols: int = 3,
) -> Image.Image:
    """Create a labeled comparison grid on a white background.

    Every cell is sized to the largest width and height found in ``images``;
    each image is resized to that cell size (so differently shaped images are
    stretched) and its label is drawn, truncated to 30 characters, in a strip
    above it.

    Args:
        images: Images to lay out, left to right then top to bottom.
        labels: One label per image. Images and labels are paired with
            ``zip``, so surplus entries on either side are ignored.
        cols: Number of columns in the grid; must be at least 1.

    Returns:
        A new RGB image containing the grid.

    Raises:
        ValueError: If ``images`` is empty or ``cols`` is less than 1.
    """
    from PIL import ImageDraw, ImageFont

    # Fail with a clear message instead of an opaque max()/division error.
    if not images:
        raise ValueError("create_comparison_grid() requires at least one image")
    if cols < 1:
        raise ValueError(f"cols must be >= 1, got {cols}")

    n = len(images)
    rows = (n + cols - 1) // cols  # ceiling division

    # Get max dimensions
    max_w = max(img.width for img in images)
    max_h = max(img.height for img in images)

    padding = 10
    label_height = 30

    grid_width = cols * max_w + (cols + 1) * padding
    grid_height = rows * (max_h + label_height) + (rows + 1) * padding

    grid = Image.new('RGB', (grid_width, grid_height), (255, 255, 255))
    draw = ImageDraw.Draw(grid)

    try:
        font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", 18)
    except (OSError, ImportError):
        # Font file missing/unreadable, or FreeType support unavailable:
        # fall back to Pillow's built-in font. Anything else (including
        # KeyboardInterrupt) is allowed to propagate.
        font = ImageFont.load_default()

    for idx, (img, label) in enumerate(zip(images, labels)):
        row, col = divmod(idx, cols)

        # Top-left corner of this cell (label strip first, image below it).
        x = padding + col * (max_w + padding)
        y = padding + row * (max_h + label_height + padding)

        # Paste image
        grid.paste(img.resize((max_w, max_h)), (x, y + label_height))

        # Add label, horizontally centred over the cell
        text_x = x + max_w // 2
        text_y = y + 5
        draw.text((text_x, text_y), label[:30], fill=(0, 0, 0), font=font, anchor="mt")

    return grid


def style_transfer_pipeline(
    reference_image: Image.Image,
    styles: List[Dict[str, str]],
    output_dir: str = "style_transfer_outputs",
    controlnet_scale: float = 0.5,
    num_inference_steps: int = 30,
    base_seed: int = 42,
) -> Dict[str, Any]:
    """
    Apply multiple styles to a reference image using ControlNet.

    The reference image is reduced to a Canny edge map, which then conditions
    one generation per style through the module-level ``pipe``. The edge map,
    every styled image, a comparison grid and a ``metadata.json`` summary are
    written to ``output_dir``.

    Args:
        reference_image: Input image to transform
        styles: List of dicts with 'name' and 'prompt' keys
        output_dir: Directory to save results (created if missing)
        controlnet_scale: ControlNet conditioning scale
        num_inference_steps: Number of denoising steps
        base_seed: Base random seed (incremented for each style)

    Returns:
        Dictionary with keys ``reference``, ``edge_map``, ``styles`` (per-style
        records), ``images`` and ``labels`` (edge map first, then one entry per
        style), ``grid`` and ``output_dir``.

    Raises:
        KeyError: If any style dict lacks the 'name' or 'prompt' key.
    """
    # Validate every style up front so a malformed entry fails before any
    # generation time has been spent on the earlier ones.
    styles = list(styles)
    for position, style in enumerate(styles):
        missing = [key for key in ("name", "prompt") if key not in style]
        if missing:
            raise KeyError(
                f"styles[{position}] is missing required key(s): {', '.join(missing)}"
            )

    output_path = Path(output_dir)
    output_path.mkdir(parents=True, exist_ok=True)

    # Extract edge map
    print("üìê Extracting edge map...")
    edge_map = get_canny_edges(reference_image)

    # Save edge map
    edge_map.save(output_path / "edge_map.png")

    # Generate styled images
    results = {
        "reference": reference_image,
        "edge_map": edge_map,
        "styles": [],
        "images": [edge_map],  # Include edge map in comparison
        "labels": ["Edge Map"],
    }

    negative_prompt = "blurry, low quality, distorted, ugly, watermark, text"

    for i, style in enumerate(styles):
        style_name = style['name']
        prompt = style['prompt']
        seed = base_seed + i

        # Only show an ellipsis when the prompt was actually cut short.
        prompt_preview = prompt if len(prompt) <= 50 else f"{prompt[:50]}..."
        print(f"\nüé® Generating: {style_name}")
        print(f"   Prompt: {prompt_preview}")

        # perf_counter is monotonic, so the elapsed time cannot be skewed by
        # system clock adjustments.
        start_time = time.perf_counter()

        generator = torch.Generator(device="cuda").manual_seed(seed)

        image = pipe(
            prompt=prompt,
            negative_prompt=negative_prompt,
            image=edge_map,
            controlnet_conditioning_scale=controlnet_scale,
            num_inference_steps=num_inference_steps,
            generator=generator,
        ).images[0]

        elapsed = time.perf_counter() - start_time
        print(f"   Generated in {elapsed:.1f}s")

        # Save with metadata. The style name is reduced to filename-safe
        # characters so names containing path separators or other special
        # characters cannot escape or break the output path.
        slug = "".join(
            ch if ch.isalnum() or ch in "-_." else "_"
            for ch in style_name.lower()
        )
        filename = f"{i+1:02d}_{slug}.png"
        save_with_metadata(image, output_path / filename, prompt, seed, style_name)

        # Store results
        results['styles'].append({
            'name': style_name,
            'prompt': prompt,
            'seed': seed,
            'generation_time': elapsed,
            'filename': filename,
        })
        results['images'].append(image)
        results['labels'].append(style_name)

    # Create comparison grid
    print("\nüìä Creating comparison grid...")
    grid = create_comparison_grid(results['images'], results['labels'], cols=3)
    grid.save(output_path / "comparison_grid.png")

    # Save metadata JSON, including everything needed to reproduce the run.
    metadata = {
        'timestamp': datetime.now().isoformat(),
        'controlnet_scale': controlnet_scale,
        'num_inference_steps': num_inference_steps,
        'base_seed': base_seed,
        'negative_prompt': negative_prompt,
        'styles': results['styles'],
    }

    with open(output_path / "metadata.json", 'w', encoding="utf-8") as f:
        json.dump(metadata, f, indent=2)

    results['grid'] = grid
    results['output_dir'] = str(output_path)

    print(f"\n‚úÖ Complete! Results saved to: {output_path}")

    return results

In [None]:
# Test the pipeline

# Create a simple reference image using SDXL
from diffusers import StableDiffusionXLPipeline

# First generate a reference image
print("Creating reference image...")
base_pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    torch_dtype=torch.bfloat16,
)
base_pipe = base_pipe.to("cuda")

# Fixed seed so the reference image is the same on every run.
generator = torch.Generator(device="cuda").manual_seed(42)
reference = base_pipe(
    prompt="A simple house with a tree in front, clear sky, daytime",
    negative_prompt="complex, busy, cluttered",
    num_inference_steps=25,
    generator=generator,
).images[0]

# Free base pipeline. Run the garbage collector before emptying the CUDA
# cache: empty_cache() can only release blocks whose tensors have already been
# destroyed, and objects kept alive by reference cycles are not destroyed
# until a collection happens.
del base_pipe
gc.collect()
torch.cuda.empty_cache()

# Define styles
# Each (name, prompt) pair below becomes one {"name": ..., "prompt": ...} entry.
styles = [
    {"name": style_label, "prompt": style_prompt}
    for style_label, style_prompt in (
        (
            "Van Gogh",
            "A house with a tree, in the style of Van Gogh's Starry Night, oil painting, swirling brushstrokes, vibrant colors",
        ),
        (
            "Studio Ghibli",
            "A house with a tree, Studio Ghibli anime style, Hayao Miyazaki, beautiful detailed scene, soft colors",
        ),
        (
            "Cyberpunk",
            "A futuristic house with a holographic tree, cyberpunk style, neon lights, rain, night scene",
        ),
        (
            "Watercolor",
            "A house with a tree, watercolor painting, soft washes of color, artistic, delicate brushwork",
        ),
        (
            "Low Poly",
            "A house with a tree, low poly 3D art style, geometric shapes, vibrant colors, minimalist",
        ),
    )
]

# Run the pipeline on the generated reference image with every style defined above.
results = style_transfer_pipeline(
    reference_image=reference,
    styles=styles,
    output_dir="style_transfer_outputs",
    controlnet_scale=0.5,
    num_inference_steps=30,
)

In [None]:
# Display the comparison grid using an explicit figure/axes pair
fig, ax = plt.subplots(figsize=(20, 12))
ax.imshow(results['grid'])
ax.axis('off')
ax.set_title("Style Transfer Results", fontsize=16)
fig.tight_layout()
plt.show()

# Print summary: one line per style with its generation time and seed
print("\nüìä Generation Summary:")
print("=" * 60)
for style in results['styles']:
    print(f"  {style['name']:15} - {style['generation_time']:.1f}s - seed: {style['seed']}")

---

## Cleanup

In [None]:
# Drop the last references to the models, run the garbage collector so their
# tensors are actually destroyed, and only then hand the cached CUDA blocks
# back to the driver. Emptying the cache first would leave any memory freed by
# the collector sitting in PyTorch's caching allocator.
del pipe, controlnet
gc.collect()
torch.cuda.empty_cache()
print("‚úÖ Cleanup complete!")