# Image Generation with IP-Adapter (SD 1.5)

This notebook demonstrates how to use the DEGIS package to:
1. Load a trained color-head model
2. Set up IP-Adapter with ControlNet for image generation
3. Generate images that take their colors from one dataset image and their layout (edge map) from another

Adapted from the ablation notebook, refactored to use the new `degis` package structure.


## 1. Set Up and Install Dependencies


In [None]:
# Install IP-Adapter and dependencies.
# Remove any previously installed copies first so the fork below is installed fresh
# (presumably to avoid mixing with a stale upstream `ip-adapter`/`diffusers` build — confirm).
%pip uninstall -y ip-adapter diffusers
# TODO(review): `@main` is a moving target; pin a commit hash for reproducible runs.
%pip install --no-cache-dir git+https://github.com/Ahmed-Sherif-ASA/IP-Adapter@main
# TODO(review): unpinned; pin the diffusers version this notebook was tested with.
%pip install diffusers
# If these packages were already imported in this kernel, restart it so the
# freshly installed versions are picked up.


## 2. Imports and Setup


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision import transforms
from IPython.display import display
import os
import glob

# Import the DEGIS package
import degis
from degis.data.dataset import UnifiedImageDataset
from degis.config import CSV_PATH, HF_XL_EMBEDDINGS_TARGET_PATH, COLOR_HIST_PATH_HCL_514, EDGE_MAPS_PATH

# Import IP-Adapter
import ip_adapter
from ip_adapter import IPAdapter
from diffusers import ControlNetModel, StableDiffusionControlNetPipeline

# Prefer the GPU when one is visible; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(f"Using device: {device}")


## 3. Load Data and Models


In [None]:
# Load the image metadata and wrap it as a dataset. The `file_df` mode is fed the
# CSV with its `local_path` column renamed to `file_path`.
df = pd.read_csv(CSV_PATH)
_file_df = df.rename(columns={"local_path": "file_path"})
colour_dataset = UnifiedImageDataset(
    _file_df,
    mode="file_df",
    size=(224, 224),
    subset_ratio=1.0,
)


def _load_float32(path):
    """Memory-map a .npy file and return it as float32.

    ``astype(..., copy=False)`` returns the memmap itself only when the stored
    dtype is already float32; otherwise numpy materializes a full in-memory copy.
    """
    return np.load(path, mmap_mode="r").astype(np.float32, copy=False)


# Precomputed arrays: CLIP embeddings, colour histograms, and edge maps.
embeddings = _load_float32(HF_XL_EMBEDDINGS_TARGET_PATH)
histograms = _load_float32(COLOR_HIST_PATH_HCL_514)
edge_maps = np.load(EDGE_MAPS_PATH, mmap_mode="r")

for _label, _arr in (
    ("embeddings", embeddings),
    ("histograms", histograms),
    ("edge maps", edge_maps),
):
    print(f"Loaded {_label}: {_arr.shape}")

# Pick the color-head checkpoint: the most recently written
# `best_color_head_tmp.pth` under `runs/*/`, else a default file in the CWD.
RUNS_DIR = "runs"
RUN_CHECKPOINT_NAME = "best_color_head_tmp.pth"
DEFAULT_CHECKPOINT = "best_color_head.pth"

# Glob for the checkpoint file itself (not just `runs/*`) so stray files and run
# directories whose training never wrote a checkpoint are skipped.
run_checkpoints = glob.glob(os.path.join(RUNS_DIR, "*", RUN_CHECKPOINT_NAME))
if run_checkpoints:
    # getmtime = last write of the checkpoint; on POSIX getctime is the inode
    # change time, not the creation time, so it is a poor "latest" signal.
    checkpoint_path = max(run_checkpoints, key=os.path.getmtime)
    print(f"Using checkpoint: {checkpoint_path}")
else:
    checkpoint_path = DEFAULT_CHECKPOINT
    print(f"Using default checkpoint: {checkpoint_path}")

# Fail early with an actionable message instead of an error from inside the loader.
if not os.path.isfile(checkpoint_path):
    raise FileNotFoundError(
        f"Color-head checkpoint not found: {checkpoint_path!r}. "
        f"Train a model (expected at {RUNS_DIR}/*/{RUN_CHECKPOINT_NAME}) "
        f"or place {DEFAULT_CHECKPOINT} in the working directory."
    )

# Load the trained color head; its dimensions are taken from the precomputed
# embedding and histogram arrays.
color_head = degis.load_trained_color_head(
    checkpoint_path=checkpoint_path,
    clip_dim=embeddings.shape[1],
    hist_dim=histograms.shape[1],
    device=device,
)
print("✓ Color head loaded successfully")


## 4. Set Up the IP-Adapter Pipeline


In [None]:
# Keep all Hugging Face / torch downloads under one root: use /data/hf-cache
# when a /data directory exists, otherwise a local ./hf-cache folder.
HF_CACHE = "./hf-cache"
if os.path.exists("/data"):
    HF_CACHE = "/data/hf-cache"
os.makedirs(HF_CACHE, exist_ok=True)

# NOTE(review): `diffusers` (and with it huggingface_hub) is already imported in
# the imports cell, and those libraries read these variables at import time, so
# these assignments likely do not change their default cache. The explicit
# `cache_dir=HF_CACHE` passed to setup_pipeline is what actually redirects
# downloads. Move this block above the imports to make the env vars effective.
os.environ.update({
    "HF_HOME": HF_CACHE,
    "HUGGINGFACE_HUB_CACHE": os.path.join(HF_CACHE, "hub"),
    "TRANSFORMERS_CACHE": os.path.join(HF_CACHE, "transformers"),
    "DIFFUSERS_CACHE": os.path.join(HF_CACHE, "diffusers"),
    "TORCH_HOME": os.path.join(HF_CACHE, "torch"),
})

print(f"Using cache directory: {HF_CACHE}")

# Base SD 1.5 weights. The original `runwayml/stable-diffusion-v1-5` repo has been
# removed from the Hugging Face Hub; this is the maintained mirror of the same model.
SD15_MODEL_ID = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Local IP-Adapter SD 1.5 weights. Override with the IP_ADAPTER_SD15_CKPT
# environment variable instead of editing the hard-coded path.
IP_ADAPTER_CKPT = os.environ.get("IP_ADAPTER_SD15_CKPT", "/data/thesis/models/ip-adapter_sd15.bin")
if not os.path.isfile(IP_ADAPTER_CKPT):
    raise FileNotFoundError(
        f"IP-Adapter checkpoint not found: {IP_ADAPTER_CKPT!r}. "
        "Set IP_ADAPTER_SD15_CKPT to the path of ip-adapter_sd15.bin."
    )

# Use fp16 only on CUDA: `device` falls back to the CPU when no GPU is available,
# and half precision is poorly supported (and slow) for many CPU ops.
PIPELINE_DTYPE = torch.float16 if device.type == "cuda" else torch.float32

# Create IP-Adapter generator
generator = degis.IPAdapterGenerator(device=device)

# Set up the Stable Diffusion + Canny ControlNet + IP-Adapter pipeline.
generator.setup_pipeline(
    model_id=SD15_MODEL_ID,
    controlnet_id="lllyasviel/control_v11p_sd15_canny",
    ip_ckpt=IP_ADAPTER_CKPT,
    image_encoder_path="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
    cache_dir=HF_CACHE,
    torch_dtype=PIPELINE_DTYPE,
)

print("✓ IP-Adapter pipeline setup complete")


## 5. Image Generation Functions


In [None]:
# Default negative prompt shared by all generations.
# NOTE(review): stock diffusers pipelines pass prompts to the CLIP tokenizer as-is,
# so the A1111-style weight in "anime:1.4" is presumably read as plain tokens, not
# as a weight, unless degis/IP-Adapter adds prompt weighting — confirm.
DEFAULT_NEGATIVE_PROMPT = (
    "monochrome, lowres, bad anatomy, worst quality, low quality, blurry, "
    "sketch, cartoon, drawing, anime:1.4, comic, illustration, posterized, "
    "mosaic, stained glass, abstract, surreal, psychedelic, trippy, texture artifact, "
    "embroidery, knitted, painting, oversaturated, unrealistic, bad shading"
)


def generate_from_dataset_id(
    colour_index: int,
    layout_index: int,
    prompt: str = "a cat playing with a ball",
    guidance_scale: float = 7.5,
    steps: int = 30,
    controlnet_conditioning_scale: float = 1.0,
    num_samples: int = 1,
    scale: float = 0.8,
    negative_prompt: str = DEFAULT_NEGATIVE_PROMPT,
    seed: "int | None" = None,
):
    """Generate images that take their colour from one dataset image and their layout from another.

    The colour condition is the color-head embedding of the precomputed CLIP
    embedding at ``colour_index``. The layout condition is the edge map at
    ``layout_index``, rendered as a 512px control image. The source image, the
    control image and the generated samples are displayed as a comparison grid.

    Args:
        colour_index: Row of ``colour_dataset`` / ``embeddings`` supplying the colour condition.
        layout_index: Row of ``edge_maps`` supplying the layout (edge) condition.
        prompt: Text prompt.
        guidance_scale: Guidance scale forwarded to ``generator.generate``.
        steps: Number of inference steps (``num_inference_steps``).
        controlnet_conditioning_scale: Strength of the ControlNet edge conditioning.
        num_samples: Number of images to generate.
        scale: Forwarded as ``scale`` (presumably the IP-Adapter conditioning strength).
        negative_prompt: Negative prompt; defaults to ``DEFAULT_NEGATIVE_PROMPT``.
        seed: If given, seeds torch's global RNG with ``torch.manual_seed`` before
            generating so runs are reproducible. ``None`` leaves the RNG untouched.
            Has no effect if ``generator.generate`` seeds its own RNG internally — confirm.

    Returns:
        Whatever ``generator.generate`` returns (presumably a list of PIL images).
    """
    # Source image, used only for the comparison display.
    img_t, _ = colour_dataset[colour_index]
    pil_img = transforms.ToPILImage()(img_t)

    # Colour condition. torch.tensor copies the row out of the (possibly read-only)
    # memmap; torch.as_tensor may try to share memory with it and warn.
    z_clip = torch.tensor(embeddings[colour_index], dtype=torch.float32, device=device).unsqueeze(0)
    with torch.no_grad():  # pure inference; don't build an autograd graph
        color_embedding = degis.get_color_embedding(color_head, z_clip)

    # Layout condition built from the edge map.
    control_image = degis.create_edge_control_image(edge_maps[layout_index], size=512)

    if seed is not None:
        torch.manual_seed(seed)

    images = generator.generate(
        color_embedding=color_embedding,
        control_image=control_image,
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_samples=num_samples,
        guidance_scale=guidance_scale,
        num_inference_steps=steps,
        controlnet_conditioning_scale=controlnet_conditioning_scale,
        scale=scale,
    )

    comparison = degis.display_comparison_grid(
        original=pil_img,
        control=control_image,
        generated=images,
        cols=3,
    )
    display(comparison)

    return images


print("✓ Generation function defined")


## 6. Generate Images


In [None]:
# Three example generations. All reuse layout (edge map) 33 and vary the colour
# source, prompt and sampling settings.
print("Generating images with IP-Adapter (SD 1.5)...")

EXAMPLE_RUNS = [
    # Example 1: Cat with ball
    dict(
        colour_index=1000, layout_index=33, prompt="a cat playing with a ball",
        guidance_scale=7.5, steps=30, controlnet_conditioning_scale=1.0,
        num_samples=1, scale=0.8,
    ),
    # Example 2: Dog on hoodie
    dict(
        colour_index=1008, layout_index=33, prompt="a dog on the hoodie",
        guidance_scale=7.5, steps=30, controlnet_conditioning_scale=1.0,
        num_samples=1, scale=0.8,
    ),
    # Example 3: Different style, with stronger guidance, more steps, and
    # lower ControlNet / IP-Adapter scales
    dict(
        colour_index=1003, layout_index=33, prompt="A cat on the hoodie, digital art style",
        guidance_scale=13.0, steps=50, controlnet_conditioning_scale=0.8,
        num_samples=1, scale=0.6,
    ),
]

images1, images2, images3 = [
    generate_from_dataset_id(**run_kwargs) for run_kwargs in EXAMPLE_RUNS
]

print("✓ Image generation complete!")
