# Image Generation with IP-Adapter XL (SDXL)

This notebook demonstrates how to use the DEGIS package to:
1. Load trained color head models
2. Set up IP-Adapter XL with ControlNet for high-quality image generation
3. Generate images using color and layout control with SDXL

This notebook is based on the ablation notebook, but uses IP-Adapter XL (SDXL) for higher-quality results.


## 1. Setup and Install Dependencies


In [None]:
# Install IP-Adapter and dependencies
# Existing copies of both packages are removed first, then reinstalled below.
# If either package was already imported in this kernel, restart the kernel
# after this cell so the freshly installed versions are the ones loaded.
%pip uninstall -y ip-adapter diffusers
# IP-Adapter is installed from the `main` branch of a fork, bypassing pip's cache.
# NOTE(review): a branch is a moving target — consider pinning a commit hash.
%pip install --no-cache-dir git+https://github.com/Ahmed-Sherif-ASA/IP-Adapter@main
# NOTE(review): diffusers is installed without a version pin, so the resolved
# version can change between runs — TODO pin a version known to work with the fork.
%pip install diffusers


## 2. Imports and Setup


In [None]:
# Standard library
import glob
import os
from typing import Optional

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from IPython.display import display
from PIL import Image
from torchvision import transforms

# Import the DEGIS package
import degis
from degis.data.dataset import UnifiedImageDataset
from degis.config import CSV_PATH, HF_XL_EMBEDDINGS_TARGET_PATH, COLOR_HIST_PATH_HCL_514, EDGE_MAPS_PATH

# Import IP-Adapter XL
import ip_adapter
from ip_adapter import IPAdapterXL
from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

# Pick the compute device: CUDA when PyTorch reports it as available, else the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


## 3. Load Data and Models


In [None]:
# Load datasets
df = pd.read_csv(CSV_PATH)
colour_dataset = UnifiedImageDataset(
    df.rename(columns={"local_path": "file_path"}),
    mode="file_df",
    size=(224, 224),
    subset_ratio=1.0
)

# Load precomputed data.
# mmap_mode="r" maps each file read-only instead of reading it eagerly.
# astype(..., copy=False) returns the mapped array unchanged when it is already
# float32 and only materialises an in-memory copy when the stored dtype differs.
embeddings = np.load(HF_XL_EMBEDDINGS_TARGET_PATH, mmap_mode="r").astype(np.float32, copy=False)
histograms = np.load(COLOR_HIST_PATH_HCL_514, mmap_mode="r").astype(np.float32, copy=False)
edge_maps = np.load(EDGE_MAPS_PATH, mmap_mode="r")

print(f"Loaded embeddings: {embeddings.shape}")
print(f"Loaded histograms: {histograms.shape}")
print(f"Loaded edge maps: {edge_maps.shape}")

# The generation function indexes `colour_dataset` and `embeddings` with the same
# integer, so a row-count mismatch would silently pair an image with the wrong
# embedding. Warn rather than fail: the notebook may still be usable on a subset.
if len(df) != embeddings.shape[0]:
    print(
        f"WARNING: CSV has {len(df)} rows but embeddings has {embeddings.shape[0]}; "
        "dataset indices may not line up with embedding rows"
    )

# Find the latest trained model.
# Walk run directories newest-first and take the first one that actually holds a
# checkpoint. Plain files under runs/ are ignored, and a newest run that has no
# checkpoint no longer hides an older, usable one.
CHECKPOINT_NAME = "best_color_head_tmp.pth"
FALLBACK_CHECKPOINT = "best_color_head.pth"

run_dirs = sorted(
    (path for path in glob.glob("runs/*") if os.path.isdir(path)),
    key=os.path.getctime,
    reverse=True,
)

checkpoint_path = None
for latest_run in run_dirs:
    candidate = os.path.join(latest_run, CHECKPOINT_NAME)
    if os.path.isfile(candidate):
        checkpoint_path = candidate
        break

if checkpoint_path is not None:
    print(f"Using checkpoint: {checkpoint_path}")
else:
    # Fallback to a default path
    checkpoint_path = FALLBACK_CHECKPOINT
    print(f"Using default checkpoint: {checkpoint_path}")

# Load trained color head; its input/output widths are taken from the loaded arrays.
color_head = degis.load_trained_color_head(
    checkpoint_path=checkpoint_path,
    clip_dim=embeddings.shape[1],
    hist_dim=histograms.shape[1],
    device=device
)
print("✓ Color head loaded successfully")


## 4. Setup IP-Adapter XL Pipeline


In [None]:
# Setup cache directory: use /data when it exists, otherwise a local folder.
HF_CACHE = "/data/hf-cache" if os.path.exists("/data") else "./hf-cache"
os.makedirs(HF_CACHE, exist_ok=True)

# NOTE(review): the Hugging Face libraries are already imported by the time this
# cell runs, and they appear to read their cache locations at import time, so
# these variables likely only affect subprocesses and later imports — confirm.
# The explicit `cache_dir=HF_CACHE` passed to setup_pipeline below does not
# depend on them.
os.environ["HF_HOME"] = HF_CACHE
os.environ["HUGGINGFACE_HUB_CACHE"] = os.path.join(HF_CACHE, "hub")
os.environ["TRANSFORMERS_CACHE"] = os.path.join(HF_CACHE, "transformers")
os.environ["DIFFUSERS_CACHE"] = os.path.join(HF_CACHE, "diffusers")
os.environ["TORCH_HOME"] = os.path.join(HF_CACHE, "torch")

print(f"Using cache directory: {HF_CACHE}")

# IP-Adapter weights. The default is the original hard-coded location; set the
# IP_ADAPTER_XL_CKPT environment variable to point elsewhere (for example on
# machines without /data) instead of editing this cell.
IP_CKPT_PATH = os.environ.get("IP_ADAPTER_XL_CKPT", "/data/thesis/models/ip-adapter_sdxl.bin")
if not os.path.isfile(IP_CKPT_PATH):
    # Surface the problem before the (slow) pipeline setup starts.
    print(
        f"WARNING: IP-Adapter checkpoint not found at {IP_CKPT_PATH}; "
        "set IP_ADAPTER_XL_CKPT or update the path"
    )

# Half precision is only requested on CUDA. Many CPU kernels do not support
# float16, so fall back to float32 when `device` resolved to the CPU.
PIPELINE_DTYPE = torch.float16 if device.type == "cuda" else torch.float32

# Create IP-Adapter XL generator
generator = degis.IPAdapterXLGenerator(device=device)

# Setup the pipeline
generator.setup_pipeline(
    model_id="stabilityai/stable-diffusion-xl-base-1.0",
    controlnet_id="diffusers/controlnet-canny-sdxl-1.0",
    ip_ckpt=IP_CKPT_PATH,
    image_encoder_path="laion/CLIP-ViT-bigG-14-laion2B-39B-b160k",
    cache_dir=HF_CACHE,
    torch_dtype=PIPELINE_DTYPE,
)

print("✓ IP-Adapter XL pipeline setup complete")


## 5. Image Generation Functions


In [None]:
# Negative prompt used for every generation unless the caller passes its own.
# NOTE(review): "anime:1.4" looks like a prompt-weighting syntax; unless the
# generator parses weights it is passed through as literal text — confirm.
DEFAULT_NEGATIVE_PROMPT = (
    "monochrome, lowres, bad anatomy, worst quality, low quality, blurry, "
    "sketch, cartoon, drawing, anime:1.4, comic, illustration, posterized, "
    "mosaic, stained glass, abstract, surreal, psychedelic, trippy, texture artifact, "
    "embroidery, knitted, painting, oversaturated, unrealistic, bad shading"
)


def generate_from_dataset_id_xl(
    colour_index: int,
    layout_index: int,
    prompt: str = "a cat playing with a ball",
    guidance_scale: float = 6.5,
    steps: int = 40,
    controlnet_conditioning_scale: float = 0.8,
    num_samples: int = 1,
    attn_ip_scale: float = 0.8,
    text_token_scale: float = 1.0,
    ip_token_scale: Optional[float] = None,
    ip_uncond_scale: float = 0.0,
    zero_ip_in_uncond: bool = True,
    negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
    control_size: int = 512,
):
    """Generate images using IP-Adapter XL with advanced controls.

    The colour reference and the layout reference are chosen independently:
    ``colour_index`` selects the dataset image / CLIP embedding that drives the
    colour head, and ``layout_index`` selects the edge map used as the
    ControlNet input. A comparison grid is displayed as a side effect.

    Args:
        colour_index: Row used for both ``colour_dataset`` and ``embeddings``.
        layout_index: Row of ``edge_maps`` turned into the control image.
        prompt: Text prompt forwarded to the generator.
        guidance_scale: Forwarded to ``generator.generate``.
        steps: Forwarded as ``num_inference_steps``.
        controlnet_conditioning_scale: Forwarded to ``generator.generate``.
        num_samples: Forwarded to ``generator.generate``.
        attn_ip_scale: Forwarded to ``generator.generate``.
        text_token_scale: Forwarded to ``generator.generate``.
        ip_token_scale: Forwarded to ``generator.generate``; ``None`` is passed
            through unchanged.
        ip_uncond_scale: Forwarded to ``generator.generate``.
        zero_ip_in_uncond: Forwarded to ``generator.generate``.
        negative_prompt: Negative prompt; defaults to ``DEFAULT_NEGATIVE_PROMPT``.
        control_size: ``size`` passed to ``degis.create_edge_control_image``.
            NOTE(review): SDXL base models are typically run around 1024 px;
            whether this size determines the output resolution depends on
            ``generator.generate`` — confirm before raising it.

    Returns:
        Whatever ``generator.generate`` returns (presumably the generated
        images — confirm against the DEGIS API).
    """

    # Get original image for display
    img_t, _ = colour_dataset[colour_index]
    pil_img = transforms.ToPILImage()(img_t)

    # Get CLIP embedding and compute color embedding.
    # torch.tensor always copies, so the tensor never aliases the (possibly
    # read-only, memory-mapped) embeddings array and PyTorch's non-writable
    # NumPy array warning is avoided.
    z_clip = torch.tensor(
        embeddings[colour_index], dtype=torch.float32, device=device
    ).unsqueeze(0)
    color_embedding = degis.get_color_embedding(color_head, z_clip)

    # Create control image from edge data
    control_image = degis.create_edge_control_image(edge_maps[layout_index], size=control_size)

    # Generate images with IP-Adapter XL
    images = generator.generate(
        color_embedding=color_embedding,
        control_image=control_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_samples=num_samples,
        guidance_scale=guidance_scale,
        num_inference_steps=steps,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        # IP-Adapter XL specific parameters
        attn_ip_scale=attn_ip_scale,
        text_token_scale=text_token_scale,
        ip_token_scale=ip_token_scale,
        ip_uncond_scale=ip_uncond_scale,
        zero_ip_in_uncond=zero_ip_in_uncond,
    )

    # Display results: colour reference, control image, and generated output(s).
    comparison = degis.display_comparison_grid(
        original=pil_img,
        control=control_image,
        generated=images,
        cols=3
    )
    display(comparison)

    return images

print("✓ IP-Adapter XL generation function defined")


## 6. Generate High-Quality Images


In [None]:
# Run three example generations with IP-Adapter XL.
print("Generating images with IP-Adapter XL (SDXL)...")

# Settings that are identical across all three examples.
shared_settings = dict(
    layout_index=33,
    num_samples=1,
    ip_uncond_scale=0.0,
    zero_ip_in_uncond=True,
)

# Example 1: Cat with ball (high quality)
images1 = generate_from_dataset_id_xl(
    colour_index=1000,
    prompt="a cat playing with a ball, high quality, detailed",
    guidance_scale=6.5,
    steps=40,
    controlnet_conditioning_scale=0.8,
    attn_ip_scale=0.8,
    text_token_scale=1.0,
    ip_token_scale=0.5,
    **shared_settings,
)

# Example 2: Dog on hoodie (artistic style)
images2 = generate_from_dataset_id_xl(
    colour_index=1008,
    prompt="a dog on the hoodie, artistic style, professional photography",
    guidance_scale=7.5,
    steps=50,
    controlnet_conditioning_scale=0.9,
    attn_ip_scale=0.6,
    text_token_scale=1.1,
    ip_token_scale=0.4,
    **shared_settings,
)

# Example 3: Creative composition
images3 = generate_from_dataset_id_xl(
    colour_index=1003,
    prompt="A cat on the hoodie, digital art, vibrant colors, masterpiece",
    guidance_scale=8.0,
    steps=60,
    controlnet_conditioning_scale=0.7,
    attn_ip_scale=0.7,
    text_token_scale=1.2,
    ip_token_scale=0.6,
    **shared_settings,
)

print("✓ High-quality image generation complete!")
