# The Best CLoRA for SDXL: Usage

In [ ]:
# !pip install diffusers transformers accelerate safetensors pytorch-metric-learning
# Ensure you have a 'utils.py' with AttentionStore from the original CLoRA repo in the same directory.

## 1. Model Initialization

In [ ]:
import torch
from pipeline_clora_xl import CloraXLPipeline
from diffusers import AutoencoderKL

model_id = "stabilityai/stable-diffusion-xl-base-1.0"

# Recommended: use the fp16-fixed VAE for better quality.
# It is loaded first so it can be handed to from_pretrained below; that way
# the checkpoint's stock VAE is never loaded just to be thrown away.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)

# The custom pipeline inherits from the official SDXL pipeline,
# so we can load it directly with .from_pretrained. Passing `vae=` overrides
# that component; the final .to("cuda") moves every component, the VAE included.
pipeline = CloraXLPipeline.from_pretrained(
    model_id,
    vae=vae,
    torch_dtype=torch.float16,
    use_safetensors=True,
    variant="fp16",
).to("cuda")

## 2. Load SDXL-Compatible LoRAs
**IMPORTANT**: You MUST use LoRAs trained specifically for SDXL. LoRAs for SD 1.5 will not work.

In [ ]:
# Point these at your own downloaded or trained SDXL LoRA files.
lora_dog_path = "./models/dog_xl.safetensors"
lora_cat_path = "./models/cat_xl.safetensors"

# Register each LoRA under the adapter name used to refer to it later.
for adapter_name, weights_path in (("dog", lora_dog_path), ("cat", lora_cat_path)):
    pipeline.load_lora_weights(weights_path, adapter_name=adapter_name)

## 3. Define Prompts and Hyperparameters

In [ ]:
# -- CLoRA Setup --
# LoRA for the background (empty string), then one for each concept
fg_loras = ["", "dog", "cat"]

# Prompt for the full scene, then one for each concept isolated with a trigger
# word (e.g., 'sks'). Listed in the same order as fg_loras.
fg_prompts = [
    "A photo of a cat and a dog in a garden, cinematic lighting",
    "A photo of a sks dog in a garden, cinematic lighting",
    "A photo of a sks cat in a garden, cinematic lighting",
]

# One negative prompt per entry in fg_prompts.
fg_negative = ["blurry, low quality, cartoon, anime"] * len(fg_prompts)

# -- Token Indices for Loss and Masking --
# IMPORTANT: Re-calculate these for your specific prompts!
# Print each prompt followed by its {position: token} map so the indices below
# can be checked against what the tokenizer actually produces.
for prompt in fg_prompts:
    ids = pipeline.tokenizer(prompt).input_ids
    tokens = pipeline.tokenizer.convert_ids_to_tokens(ids)
    print(prompt)
    print(dict(enumerate(tokens)))

# Example Indices (MUST BE UPDATED BASED ON ABOVE OUTPUT)
# The values below assume position 0 is the tokenizer's start-of-text token and
# every word (including the trigger word) is a single token, which gives:
#   prompt 1: 5 = 'cat', 8 = 'dog'
#   prompt 2: 5 = 'sks', 6 = 'dog'
#   prompt 3: 5 = 'sks', 6 = 'cat'
# Group tokens for the same concept together: one row per concept, one entry
# per prompt (same order as fg_prompts).
# NOTE(review): [] marks a prompt that does not mention the concept - confirm
# that CloraXLPipeline accepts empty entries.
important_token_indices = [
    [[5], [], [5, 6]],  # Concept 1: Cat ('cat' in prompt 1, absent from prompt 2, 'sks cat' in prompt 3)
    [[8], [5, 6], []],  # Concept 2: Dog ('dog' in prompt 1, 'sks dog' in prompt 2, absent from prompt 3)
]

# Which tokens' attention maps to use for creating the spatial masks
# (one entry per prompt, same order as fg_prompts)
mask_indices = [
    [],       # Background - uses what's left over
    [5, 6],   # Dog mask from prompt 2 ('sks dog')
    [5, 6],   # Cat mask from prompt 3 ('sks cat')
]

## 4. Run the Pipeline

In [ ]:
# Fixed seed so the run is repeatable.
generator = torch.Generator(device="cuda")
generator.manual_seed(42)

# Collect every call argument in one place, then unpack into the pipeline call.
clora_call_kwargs = {
    # Prompts, adapters and token bookkeeping defined in the previous cell
    "prompt_list": fg_prompts,
    "lora_list": fg_loras,
    "negative_prompt_list": fg_negative,
    "important_token_indices": important_token_indices,
    "mask_indices": mask_indices,
    # Output size and sampling settings
    "height": 1024,
    "width": 1024,
    "num_inference_steps": 40,
    "guidance_scale": 8.0,
    # Latent-update settings
    "latent_update": True,
    "max_iter_to_alter": 25,
    "step_size": 0.02,  # Corresponds to latent update step size
    "mask_threshold_alpha": 0.4,
    "generator": generator,
}
result = pipeline(**clora_call_kwargs)

# Leave the first image as the cell's last expression so the notebook renders it.
image = result.images[0]
image