# Lab 4.1.2 Solutions: Image Generation

This notebook contains solutions to the exercises in the Image Generation notebook.

---

## Challenge Solution: Style Variation Generator

The challenge was to create a function that generates the same scene in multiple artistic styles.

### Approach
We use SDXL with:
1. **Consistent Base Prompt**: The core scene description stays the same
2. **Style Modifiers**: Each style is appended to the base prompt
3. **Seeded Generation**: Using sequential seeds (`seed + i`, where `i` is the style's index) gives each style a different starting point while keeping results reproducible

### Key Design Decisions
- **Seed Strategy**: We increment the seed for each style to get different compositions while keeping results reproducible
- **Prompt Structure**: `"{base}, {style}, masterpiece, highly detailed"` works well for quality
- **Grid Display**: Visual comparison makes it easy to evaluate styles

### Trade-offs
- Using the same seed for all styles would start every image from the same initial noise, which tends to keep the composition similar across styles
- Random seeds would make comparisons harder
- Our incremented seed approach balances variety with reproducibility

In [None]:
import torch
import gc
from PIL import Image
from typing import List, Optional
import matplotlib.pyplot as plt

def clear_gpu_memory():
    """Run the garbage collector and, when CUDA is available, empty its cache.

    After emptying the CUDA cache the function calls
    ``torch.cuda.synchronize()``. On machines without CUDA only the Python
    garbage collection step runs.
    """
    gc.collect()

    # Nothing GPU-related to do on CPU-only machines.
    if not torch.cuda.is_available():
        return

    torch.cuda.empty_cache()
    torch.cuda.synchronize()

In [None]:
from diffusers import StableDiffusionXLPipeline, DPMSolverMultistepScheduler

# Load the SDXL base pipeline into the module-level `pipe` used by the cells below.
print("Loading SDXL...")
pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    # NOTE(review): the "fp16" weight variant is requested below while the
    # load dtype here is bfloat16 — confirm this pairing is intentional.
    torch_dtype=torch.bfloat16,
    use_safetensors=True,
    variant="fp16"
)
# Swap in a DPM-Solver multistep scheduler built from the current scheduler's config.
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
# Move the pipeline to the GPU; the generation helpers in this file create CUDA generators.
pipe = pipe.to("cuda")
print("Loaded!")

In [None]:
def generate_style_variations(
    base_prompt: str,
    styles: List[str],
    seed: int = 42,
    num_steps: int = 25,
    guidance_scale: float = 7.5,
    negative_prompt: str = "ugly, blurry, low quality, distorted"
) -> List[Image.Image]:
    """
    Render one scene once per style modifier using the module-level ``pipe``.

    For every entry in ``styles`` the prompt is built as
    ``"{base_prompt}, {style}, masterpiece, highly detailed"``. The style at
    0-based index ``k`` is sampled with a CUDA generator seeded with
    ``seed + k``: each style gets its own seed, and calling the function again
    with the same arguments uses the same seeds.

    Args:
        base_prompt: The core scene description shared by all styles.
        styles: Style modifiers (e.g., 'oil painting', 'anime', 'photorealistic').
        seed: Seed for the first style; later styles use seed + 1, seed + 2, ...
        num_steps: Number of denoising steps passed to the pipeline.
        guidance_scale: CFG scale passed to the pipeline.
        negative_prompt: Negative prompt passed to the pipeline.

    Returns:
        One image per style (the first image of each pipeline call), in the
        same order as ``styles``.
    """
    total = len(styles)
    rendered = []

    for position, style in enumerate(styles, start=1):
        print(f"\nGenerating [{position}/{total}]: {style}")

        # The base scene comes first, then the style, then fixed quality tags.
        styled_prompt = f"{base_prompt}, {style}, masterpiece, highly detailed"

        # position is 1-based, so subtract 1 to get the seed + index offset.
        rng = torch.Generator(device="cuda")
        rng.manual_seed(seed + position - 1)

        with torch.inference_mode():
            output = pipe(
                prompt=styled_prompt,
                negative_prompt=negative_prompt,
                num_inference_steps=num_steps,
                guidance_scale=guidance_scale,
                generator=rng,
            )

        rendered.append(output.images[0])
        print("  Done!")

    return rendered


def display_style_grid(images: List[Image.Image], styles: List[str], title: str = ""):
    """
    Display generated images in a labelled grid with at most four columns.

    Args:
        images: Images to show, in display order.
        styles: Label for each image, used as the subplot title. If there are
            fewer labels than images, the remaining images are shown untitled;
            extra labels are ignored.
        title: Optional figure-level title; skipped when empty.

    Returns:
        None. The figure is rendered with ``plt.show()``. When ``images`` is
        empty the function returns without creating a figure.
    """
    n = len(images)
    if n == 0:
        # With no images the column count would be 0 and the row computation
        # below would divide by zero, so there is nothing to draw.
        return

    cols = min(n, 4)
    rows = (n + cols - 1) // cols  # ceiling division

    # squeeze=False always returns a 2-D array of Axes, even for a 1x1 grid,
    # so no special-casing of the single-image layout is needed.
    fig, axes = plt.subplots(rows, cols, figsize=(4 * cols, 4 * rows), squeeze=False)
    if title:
        fig.suptitle(title, fontsize=14, fontweight='bold')

    axes = axes.ravel()

    # Pad the labels so an image without a matching style is still displayed
    # instead of being silently dropped by zip().
    labels = list(styles)[:n]
    labels += [""] * (n - len(labels))

    for ax, img, label in zip(axes, images, labels):
        ax.imshow(img)
        ax.set_title(label, fontsize=10)
        ax.axis('off')

    # Hide the unused cells in the last row.
    for ax in axes[n:]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Confirm that the helper functions defined above have been executed.
print("Functions ready!")

In [None]:
# Example usage: Dragon flying over a castle

# Scene description shared by every style variation.
base = "A dragon flying over a medieval castle at sunset, dramatic sky"

# One style modifier per image; each is appended to the base prompt.
styles = [
    "oil painting style, classical art",
    "anime style, studio ghibli",
    "photorealistic, 8k photography",
    "pixel art, 16-bit retro game"
]

print(f"Base prompt: {base}")
print(f"Generating {len(styles)} style variations...")

# Style i is generated with seed 42 + i; the grid uses the style strings as labels.
variations = generate_style_variations(base, styles, seed=42)
display_style_grid(variations, styles, title=base)

In [None]:
# Another example: Portrait

# Scene description shared by every style variation.
base2 = "Portrait of a wise old wizard with a long white beard"

# One style modifier per image; each is appended to the base prompt.
styles2 = [
    "Renaissance painting style, Da Vinci",
    "modern digital art, trending on artstation",
    "watercolor illustration, soft colors",
    "comic book style, bold lines"
]

# Style i is generated with seed 123 + i; the grid uses the style strings as labels.
variations2 = generate_style_variations(base2, styles2, seed=123)
display_style_grid(variations2, styles2, title=base2)

In [None]:
# Advanced: Compare same seed vs different seeds

def compare_seed_effect(prompt: str, style: str, seeds: List[int]) -> List[Image.Image]:
    """
    Render a fixed prompt+style once per seed so the seed's influence can be compared.

    The prompt sent to the module-level ``pipe`` is ``"{prompt}, {style}"``.
    Every call uses 25 inference steps and the negative prompt
    ``"ugly, blurry"``; only the generator seed changes between images.

    Args:
        prompt: Scene description.
        style: Style modifier appended to the prompt after a comma.
        seeds: Seeds to sample with; one image is produced per seed.

    Returns:
        The first image of each pipeline call, in the same order as ``seeds``.
    """
    combined_prompt = f"{prompt}, {style}"

    def render(seed_value):
        # Announce the seed, then sample with a CUDA generator fixed to it.
        print(f"Generating with seed {seed_value}...")
        rng = torch.Generator(device="cuda")
        rng.manual_seed(seed_value)

        with torch.inference_mode():
            output = pipe(
                prompt=combined_prompt,
                negative_prompt="ugly, blurry",
                num_inference_steps=25,
                generator=rng,
            )
        return output.images[0]

    return [render(seed_value) for seed_value in seeds]

# Show seed variations
# The seed list is defined once so the generated images and their grid labels
# cannot drift apart when the list is edited.
demo_seeds = [1, 42, 123, 999]

seed_variations = compare_seed_effect(
    "A cozy cabin in a snowy forest",
    "photorealistic, winter landscape",
    seeds=demo_seeds
)

display_style_grid(
    seed_variations,
    [f"Seed {s}" for s in demo_seeds],
    title="Same prompt, different seeds"
)

In [None]:
# Cleanup
# Drop the module-level reference to the pipeline first, then run the helper
# that collects garbage and empties the CUDA cache.
del pipe
clear_gpu_memory()
print("Solutions notebook complete!")