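# CLoRA

This notebook combines a `dog` LoRA and a `cat` LoRA in a single Stable Diffusion generation using `CloraPipeline`. It also includes helpers for plotting per-token attention maps.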

## Visualization Util Functions

In [1]:
import matplotlib.pyplot as plt

import numpy as np
import math
from PIL import Image


def attn_map_to_image(attn_image, size=(256, 256)):
    """Render a 2-D attention map as an RGB heat-map image.

    The map is min-max normalised to [0, 1], coloured with matplotlib's
    ``jet`` colormap, converted to 8-bit RGB and resized with PIL.

    Args:
        attn_image: 2-D numeric array (e.g. one token's attention map).
        size: ``(width, height)`` of the returned image, as passed to
            ``PIL.Image.Image.resize``. Defaults to ``(256, 256)``.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` and shape ``(height, width, 3)``.
    """
    attn_image = np.asarray(attn_image, dtype=np.float32)
    low = attn_image.min()
    value_range = attn_image.max() - low
    if value_range > 0:
        attn_image = (attn_image - low) / value_range
    else:
        # A constant map would divide by zero and produce NaNs; render it as all-zero.
        attn_image = np.zeros_like(attn_image)

    # The colormap returns RGBA floats in [0, 1]; drop alpha and scale to 8-bit.
    rgb = plt.get_cmap("jet")(attn_image)[..., :3]
    rgb = (rgb * 255).astype(np.uint8)

    return np.array(Image.fromarray(rgb).resize(size).convert("RGB"))


def plot_attention_maps(pipeline, attention_maps, prompts, output_dir, step=None):
    """Save one grid figure of per-token attention maps for each prompt.

    Each prompt is tokenised with ``pipeline.tokenizer``. One panel is drawn
    per token, skipping ``<|startoftext|>`` / ``<|endoftext|>``, and the panel
    is titled with the token text (``</w>`` stripped).

    Args:
        pipeline: Object exposing ``.tokenizer`` (called on the prompt and
            used via ``convert_ids_to_tokens``).
        attention_maps: Iterable aligned with ``prompts``. Each element is
            indexed as ``attention_map[:, :, token_idx]`` and must support
            ``.cpu().numpy()``. Presumably a torch tensor of shape
            ``(h, w, n_tokens)``; TODO confirm.
        prompts: Iterable of prompt strings.
        output_dir: Directory (``str`` or ``pathlib.Path``) for the PNGs. It is
            created if missing.
        step: Denoising step the maps belong to. ``None`` means the maps are
            averages; the output file and title are named accordingly.
    """
    from pathlib import Path

    output_dir = Path(output_dir)
    special_tokens = {"<|startoftext|>", "<|endoftext|>"}

    for i, (prompt, attention_map) in enumerate(zip(prompts, attention_maps)):
        # Use the same test for filename and title, so step 0 is labelled as
        # a step rather than as an average.
        if step is not None:
            output_path = output_dir / f"step_{step:06d}_prompt_{i:06d}.png"
            title = f"Attention Maps `{prompt}` at Step {step}"
        else:
            output_path = output_dir / f"average_prompt_{i:06d}.png"
            title = f"Average Attention Maps `{prompt}`"
        output_path.parent.mkdir(exist_ok=True, parents=True)

        ids = pipeline.tokenizer(prompt).input_ids
        tokens = pipeline.tokenizer.convert_ids_to_tokens(ids)

        # One panel per non-special token. Keep at least one panel so an
        # empty prompt does not ask for a 0-row grid.
        n_panels = max(1, len(tokens) - 2)
        n_rows = math.ceil(math.sqrt(n_panels))
        n_cols = math.ceil(n_panels / n_rows)

        # squeeze=False always yields a 2-D array of Axes, so a single row or
        # column is indexed exactly like a full grid.
        fig, axes = plt.subplots(n_rows, n_cols, figsize=(10, 10), squeeze=False)
        for ax in axes.ravel():
            ax.axis("off")

        for idx, token in enumerate(tokens):
            if token in special_tokens:
                continue
            # Index 0 is the start token, so the first real token goes in panel 0.
            row, col = divmod(idx - 1, n_cols)
            attn_image = attn_map_to_image(attention_map[:, :, idx].cpu().numpy())
            # attn_image is already RGB, so no colormap is needed here.
            axes[row, col].imshow(attn_image)
            axes[row, col].set_title(token.replace("</w>", ""))

        fig.suptitle(title)
        fig.tight_layout()
        fig.savefig(output_path)
        plt.close(fig)

## Model Initialization

In [2]:
import torch
from diffusers import AutoencoderKL, DDIMScheduler

import sys
sys.path.append(".")
sys.path.append("..")
from pipeline_clora import CloraPipeline

# Device and precision shared by every model component below.
DEVICE = "cuda"
DTYPE = torch.float16

# Load the fine-tuned MSE VAE first and hand it to `from_pretrained`. This way
# the checkpoint's own VAE is never loaded onto the GPU only to be replaced.
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse", torch_dtype=DTYPE)

pipeline = CloraPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE",  # alternative base: "runwayml/stable-diffusion-v1-5"
    vae=vae,
    torch_dtype=DTYPE,
    use_safetensors=True,
).to(DEVICE)

# Replace the checkpoint's scheduler with DDIM, built from the same config.
schedule_config = dict(pipeline.scheduler.config)
pipeline.scheduler = DDIMScheduler.from_config(schedule_config)

  from .autonotebook import tqdm as notebook_tqdm
Loading pipeline components...: 100%|██████████| 5/5 [00:00<00:00,  7.85it/s]
  deprecate("LoraLoaderMixin", "1.0.0", deprecation_message)


## LoRA Paths

In [3]:
# Map each adapter name to its LoRA weights file. The keys are used as
# `adapter_name` when the weights are loaded into the pipeline.
lora_paths = dict(
    dog="models/dog/pytorch_lora_weights.safetensors",
    cat="models/cat/pytorch_lora_weights.safetensors",
)



### Load LoRA Weights into the Pipeline

In [4]:
# Register every LoRA under its `lora_paths` key so later cells can refer to it by name.
for adapter, weights_file in lora_paths.items():
    pipeline.load_lora_weights(weights_file, adapter_name=adapter)



## Hyperparameters

In [5]:
H = 512  # Output image height in pixels
W = 512  # Output image width in pixels
seed = 53  # Seed for the torch.Generator passed to the pipeline
num_inference_steps = 100  # Number of denoising (inference) steps
guidance_scale = 10.0  # Classifier-free guidance scale
steps_to_save_attention = []  # Steps at which to save attention maps
step_size = 20  # Step size for latent optimization
max_iter_to_alter = 50  # Maximum iterations to alter the latent
iterative_steps = [0, 10, 20]  # Steps at which to apply iterative refinement
iterative_step_steps = 20  # Iterative refinement steps (passed as `iterative_steps_steps`)
latent_update = True  # Whether to update the latent
apply_mask_after = 0  # Apply masks after this step
attn_res = None  # Attention resolution. NOTE(review): not passed to the pipeline call — confirm intent
mask_threshold_alpha = 0.4  # Mask threshold alpha
mask_erode = False  # Apply morphological erosion to masks
mask_dilate = False  # Apply morphological dilation to masks
mask_opening = False  # Apply morphological opening to masks
mask_closing = False  # Apply morphological closing to masks
guidance_rescale = 0.0  # Guidance rescale factor
clip_skip = None  # CLIP skip (None: pipeline default)
kwargs = {"scale": 1.0}  # LoRA scale, passed as `cross_attention_kwargs`
use_text_encoder_lora = True  # Use the LoRAs' text-encoder weights

# Stable Diffusion's VAE downsamples by a factor of 8, so the output size must be a multiple of 8.
assert H % 8 == 0 and W % 8 == 0, "H and W must be divisible by 8"

## Prompts

In [10]:
# Style LoRA and its weight. NOTE(review): presumably "" means no style LoRA — confirm in CloraPipeline.
style_lora = ""
style_lora_weight = 0.8

# Per-prompt LoRA adapter (a `lora_paths` key). Prompt 1 uses "" and no
# subject LoRA — presumably the base model; confirm.
fg_loras = ["", "dog", "cat"]
fg_prompts = [
    "A cat and a dog",
    "A sks dog and a cat",
    "A sks cat and a dog",
]
# The same negative prompt for every prompt. It is derived from fg_prompts so the lists stay aligned.
fg_negative = ["nsfw"] * len(fg_prompts)

assert len(fg_loras) == len(fg_prompts) == len(fg_negative), (
    "need exactly one LoRA entry and one negative prompt per prompt"
)

## Prompt Indices

In [11]:
# Show each prompt's tokens keyed by position. These positions are the values
# used in `important_token_indices` and `mask_indices`.
for prompt in fg_prompts:
    tokens = pipeline.tokenizer.convert_ids_to_tokens(pipeline.tokenizer(prompt).input_ids)
    print(dict(enumerate(tokens)))

{0: '<|startoftext|>', 1: 'a</w>', 2: 'cat</w>', 3: 'and</w>', 4: 'a</w>', 5: 'dog</w>', 6: '<|endoftext|>'}
{0: '<|startoftext|>', 1: 'a</w>', 2: 'sks</w>', 3: 'dog</w>', 4: 'and</w>', 5: 'a</w>', 6: 'cat</w>', 7: '<|endoftext|>'}
{0: '<|startoftext|>', 1: 'a</w>', 2: 'sks</w>', 3: 'cat</w>', 4: 'and</w>', 5: 'a</w>', 6: 'dog</w>', 7: '<|endoftext|>'}


In [12]:


# Indicate which tokens are important for each prompt
important_token_indices = [
    [
        [2],    # Cat in Prompt 1
        [5],    # Cat in Prompt 2
        [2, 3]  # Cat in Prompt 3
    ],
    [
        [6],    # Dog in Prompt 1
        [2,3],  # Dog in Prompt 2
        [6]     # Dog in Prompt 3
    ],
]

mask_indices = [
    [], # Backgroung mask indices from Prompt 1
    [2, 3], # Dog mask indices from Prompt 2
    [2, 3], # Cat mask indices from Prompt 3
]

In [13]:
# Run CLoRA. Each prompt in `prompt_list` is paired with the LoRA at the same
# position in `lora_list`. Besides the images, the pipeline returns two more
# values, presumably the saved attention maps and masks; verify against CloraPipeline.
# NOTE(review): `attn_res` from the hyperparameter cell is not passed here —
# confirm whether CloraPipeline accepts it.
image, attention_maps_to_save, masks_to_save = pipeline(
    prompt_list=fg_prompts,
    negative_prompt_list=fg_negative,
    lora_list=fg_loras,
    style_lora=style_lora,
    style_lora_weight=style_lora_weight,
    important_token_indices=important_token_indices,
    mask_indices=mask_indices,
    steps_to_save_attention=steps_to_save_attention,
    step_size=step_size,
    max_iter_to_alter=max_iter_to_alter,
    iterative_steps=iterative_steps,
    iterative_steps_steps=iterative_step_steps,
    latent_update=latent_update,
    apply_mask_after=apply_mask_after,
    mask_erode=mask_erode,
    mask_dilate=mask_dilate,
    mask_opening=mask_opening,
    mask_closing=mask_closing,
    mask_threshold_alpha=mask_threshold_alpha,
    height=H,
    width=W,
    num_inference_steps=num_inference_steps,
    guidance_scale=guidance_scale,
    generator=torch.Generator(device=pipeline.device).manual_seed(seed),
    cross_attention_kwargs=kwargs,
    guidance_rescale=guidance_rescale,
    clip_skip=clip_skip,
    use_text_encoder_lora=use_text_encoder_lora,
)

# The image is the cell's last expression, so Jupyter renders it inline.
# PIL's `Image.show()` would instead open an external viewer.
image[0]

KeyError: 'unet'