# The Best CLoRA for SDXL: Usage

In [None]:
# Ensure you have the necessary libraries installed
!pip install diffusers transformers accelerate safetensors pytorch-metric-learning matplotlib

In [None]:
# This notebook needs a 'utils.py' file in the same directory. It has to provide
# 'AttentionStore' and 'register_attention_control'; the original CLoRA
# repository ships such a file.
#
# If you do not have it yet, this cell creates a minimal stand-in from the source
# text below. An existing 'utils.py' is never overwritten, so a copy taken from
# the original repository stays intact.
#
# NOTE(review): limitations of the minimal version that are visible in its source:
#   * every 'Attention' module is registered with is_cross=True, so self-attention
#     layers are reported to the controller as if they were cross-attention;
#   * 'scale' is popped from the cross-attention kwargs but never applied;
#   * 'aggregate_attention' concatenates whatever maps were stored and raises
#     if none were stored.
utils_code = """
import torch

class AttentionStore:
    def __init__(self, res, min_res=16):
        self.res = res
        self.min_res = min_res
        self.reset()

    def reset(self):
        self.attention_maps = {}

    def B_transform(self, a, b):
        return torch.einsum('b i j, b j d -> b i d', a, b)

    def __call__(self, attn, is_cross, place_in_unet):
        if is_cross:
            if attn.shape[1] == (self.res[0] * self.res[1]):
                self.attention_maps[place_in_unet] = attn
            
    def aggregate_attention(self, places_in_unet=('down', 'mid', 'up')):
        att_map = [self.attention_maps[place] for place in places_in_unet if place in self.attention_maps]
        att_map = torch.cat(att_map, dim=1)
        return att_map

def register_attention_control(model, controller):

    def ca_forward(self, attn, is_cross, place_in_unet):
        def forward(hidden_states, encoder_hidden_states=None, attention_mask=None, **cross_attention_kwargs):
            
            # The `cross_attention_kwargs` should be empty here since we are processing the input
            # there are no kwargs to forward passed along and we don't want to break the code
            # without passing along the kwargs. That is why we need to pop them from the code.
            
            if cross_attention_kwargs is not None and len(cross_attention_kwargs) > 0:
                cross_attention_kwargs = cross_attention_kwargs.copy()
                lora_scale = cross_attention_kwargs.pop('scale', 1.0)
            else:
                lora_scale = 1.0

            if self.group_norm is not None:
                hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

            query = self.to_q(hidden_states)
            query = self.head_to_batch_dim(query)

            if encoder_hidden_states is None:
                encoder_hidden_states = hidden_states
            
            key = self.to_k(encoder_hidden_states)
            value = self.to_v(encoder_hidden_states)
            
            key = self.head_to_batch_dim(key)
            value = self.head_to_batch_dim(value)

            attention_probs = self.get_attention_scores(query, key, attention_mask)
            
            controller(attention_probs, is_cross, place_in_unet)
            
            hidden_states = torch.bmm(attention_probs, value)
            hidden_states = self.batch_to_head_dim(hidden_states)

            # linear proj
            hidden_states = self.to_out[0](hidden_states)
            # dropout
            hidden_states = self.to_out[1](hidden_states)

            return hidden_states
        return forward

    def register_recr(net_, count, place_in_unet):
        if net_.__class__.__name__ == 'Attention':
            net_.forward = ca_forward(net_, True, place_in_unet)
            return count + 1
        elif hasattr(net_, 'children'):
            for net__ in net_.children():
                count = register_recr(net__, count, place_in_unet)
        return count

    cross_att_count = 0
    sub_nets = model.named_children()
    for net in sub_nets:
        if "down" in net[0]:
            cross_att_count += register_recr(net[1], 0, "down")
        elif "up" in net[0]:
            cross_att_count += register_recr(net[1], 0, "up")
        elif "mid" in net[0]:
            cross_att_count += register_recr(net[1], 0, "mid")

"""

# Mode "x" creates the file only if it does not exist yet and raises
# FileExistsError otherwise. This keeps a user-supplied 'utils.py' (for example
# the full version from the CLoRA repository) from being clobbered.
try:
    with open("utils.py", "x", encoding="utf-8") as f:
        f.write(utils_code)
except FileExistsError:
    print(
        "utils.py already exists; keeping the existing file. "
        "Delete it and re-run this cell to generate the minimal version."
    )

## 1. Model Initialization

In [None]:
import torch
from pipeline_clora_xl import CloraXLPipeline
from diffusers import AutoencoderKL

model_id = "stabilityai/stable-diffusion-xl-base-1.0"

# CloraXLPipeline exposes .from_pretrained, so it is loaded the same way as a
# stock diffusers pipeline: fp16 safetensors weights, then moved to the GPU.
pipeline = CloraXLPipeline.from_pretrained(
    model_id,
    variant="fp16",
    use_safetensors=True,
    torch_dtype=torch.float16,
)
pipeline = pipeline.to("cuda")

# Recommended: swap in the fp16-fixed SDXL VAE for better quality.
vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix",
    torch_dtype=torch.float16,
)
vae = vae.to("cuda")
pipeline.vae = vae

## 2. Load SDXL-Compatible LoRAs
**IMPORTANT**: You MUST use LoRAs trained specifically for SDXL. LoRAs for SD 1.5 will not work.
The next cell creates the `./models` directory if it does not exist yet; download your SDXL LoRA `.safetensors` files and place them inside it.

In [None]:
import os

# Create the directory that is expected to hold the LoRA files.
os.makedirs("models", exist_ok=True)

# NOTE: You must download SDXL LoRAs and place them here,
# for example from Civitai or the Hugging Face Hub.
# Missing files and load errors are reported below instead of raising, so
# check this cell's output before running the rest of the notebook.
lora_dog_path = "./models/dog_xl.safetensors"
lora_cat_path = "./models/cat_xl.safetensors"

# Each adapter is handled on its own, so a problem with one file neither hides
# the state of the other nor prevents it from being loaded.
for adapter_name, lora_path in (("dog", lora_dog_path), ("cat", lora_cat_path)):
    if not os.path.isfile(lora_path):
        print(
            f"Could not load LoRA '{adapter_name}': '{lora_path}' does not exist. "
            "Please download it and place it in the 'models' directory."
        )
        continue
    try:
        pipeline.load_lora_weights(lora_path, adapter_name=adapter_name)
    except Exception as e:
        # Deliberately best-effort: report the real cause (incompatible file,
        # adapter already loaded, ...) rather than a generic download hint.
        print(f"Could not load LoRA '{adapter_name}' from '{lora_path}'.")
        print(f"Error: {e}")

## 3. Define Prompts and Hyperparameters

In [None]:
# -- CLoRA Setup --
# One adapter name per prompt: the empty string (no LoRA) for the background /
# full-scene prompt, followed by the adapter of each concept.
fg_loras = ["", "dog", "cat"]

# The full-scene prompt first, then one prompt per concept in which that concept
# is isolated with a trigger word (e.g., 'sks').
fg_prompts = [
    "A photo of a cat and a dog in a garden, cinematic lighting",
    "A photo of a sks dog in a garden, cinematic lighting",
    "A photo of a sks cat in a garden, cinematic lighting",
]

# The same negative prompt is used for each of the three prompts.
fg_negative = ["blurry, low quality, cartoon, anime" for _ in range(3)]

# -- Token Indices for Loss and Masking --
# IMPORTANT: the indices further down depend on how your prompts tokenize.
# Print an {index: token} table per prompt so they can be re-derived.
print("--- Tokenization for Prompts ---")
for prompt in fg_prompts:
    ids = pipeline.tokenizer(prompt).input_ids
    tokens = pipeline.tokenizer.convert_ids_to_tokens(ids)
    print(dict(enumerate(tokens)))

# Example indices (MUST BE UPDATED BASED ON THE OUTPUT ABOVE).
# Tokens that belong to the same concept are grouped together; each group holds
# one index list per prompt.
# NOTE(review): these are placeholder values; confirm against the printed
# tables that every index points at the intended concept token.
important_token_indices = [
    [[4], [7], [4]],  # Concept 1: Cat
    [[7], [4], [7]],  # Concept 2: Dog
]

# Tokens whose attention maps are used to build the spatial masks, one entry
# per prompt.
mask_indices = [
    [],   # Background - uses what's left over
    [4],  # Dog mask from prompt 2 ('sks dog')
    [4],  # Cat mask from prompt 3 ('sks cat')
]

## 4. Run the Pipeline

In [None]:
# Seeded CUDA generator (seed 42) handed to the pipeline.
generator = torch.Generator(device="cuda")
generator.manual_seed(42)

result = pipeline(
    # CLoRA inputs defined in the previous section
    prompt_list=fg_prompts,
    negative_prompt_list=fg_negative,
    lora_list=fg_loras,
    important_token_indices=important_token_indices,
    mask_indices=mask_indices,
    # Output size and sampling settings
    height=1024,
    width=1024,
    num_inference_steps=40,
    guidance_scale=8.0,
    generator=generator,
    # Latent-update and masking hyperparameters
    latent_update=True,
    max_iter_to_alter=25,
    step_size=0.02,
    mask_threshold_alpha=0.4,
)

# Take the first generated image.
image = result.images[0]
# .show() opens the image in a separate window; evaluate just 'image' as the
# last expression of a cell to display it inline instead.
image.show()