In [None]:
from tqdm.auto import tqdm
import os
from typing import List
import random
import json

import torch
from PIL import Image
import pandas as pd

from pipeline import Pipeline
from utils import AttentionStore, register_attention_control

%load_ext autoreload

In [None]:
# Sanity check: the cell output shows whether PyTorch can see a CUDA device.
import torch
torch.cuda.is_available()

In [None]:
# Model identifier handed to load_model() when the pipeline is created.
sd_version = "stabilityai/stable-diffusion-xl-base-1.0"
# Read as a global by run_on_prompt() and forwarded to the pipeline call.
num_inference_steps = 50
# NOTE(review): not read by any code visible in this notebook; run_on_prompt()
# takes its `is_sd` argument instead — confirm whether this flag is still needed.
run_standard_sd = False

In [None]:
def load_model(sd_version):
    """Load the pipeline for ``sd_version`` and move it to the available device.

    ``cuda:0`` is used when CUDA is available, otherwise the CPU. Weights are
    requested in float16 on CUDA (unchanged behavior); on CPU float32 is used
    instead, because half precision is poorly supported by CPU kernels.

    Args:
        sd_version: Model identifier or path forwarded to
            ``Pipeline.from_pretrained``.

    Returns:
        The object returned by ``Pipeline.from_pretrained(...).to(device)``.
    """
    use_cuda = torch.cuda.is_available()
    device = torch.device('cuda:0') if use_cuda else torch.device('cpu')
    # Keep dtype and device consistent: fp16 only where a GPU will run it.
    dtype = torch.float16 if use_cuda else torch.float32
    stable = Pipeline.from_pretrained(sd_version, torch_dtype=dtype).to(device)
    return stable

def run_on_prompt(prompt: List[str],
                  model: Pipeline,
                  controller: AttentionStore,
                  generator: torch.Generator,
                  seed: int,
                  mita:int,
                  multi: int,
                  loss_orders,
                  is_sd=False) -> Image.Image:
    """Run ``model`` once on ``prompt`` and return the first generated image.

    When ``controller`` is not ``None`` it is first attached to the model via
    ``register_attention_control``. The number of inference steps is read from
    the notebook-level ``num_inference_steps`` global.

    Args:
        prompt: Prompt forwarded to the model (the notebook passes a single string).
        model: Callable pipeline to invoke.
        controller: Attention store, forwarded as ``attention_store``; may be ``None``.
        generator: Forwarded as ``generator``.
        seed: Forwarded as ``seed``.
        mita: Forwarded as ``max_iter_to_alter``.
        multi: Forwarded as ``multiplier``.
        loss_orders: Forwarded as ``loss_orders``.
        is_sd: Forwarded as ``run_standard_sd``.

    Returns:
        ``outputs.images[0]`` of the model call.
    """
    wants_attention_hooks = controller is not None
    if wants_attention_hooks:
        register_attention_control(model, controller)

    # Collect every keyword first so the mapping from the short argument names
    # to the pipeline's keyword names is visible in one place.
    call_kwargs = dict(
        prompt=prompt,
        attention_store=controller,
        generator=generator,
        seed=seed,
        max_iter_to_alter=mita,
        multiplier=multi,
        num_inference_steps=num_inference_steps,
        run_standard_sd=is_sd,
        loss_orders=loss_orders,
    )
    result = model(**call_kwargs)
    return result.images[0]

def save_image(image, output_path, prompt, seed, unique_id, index):
    """Save ``image`` as ``<output_path>/<prompt>/<unique_id>.png``.

    The per-prompt directory is created if it does not exist yet.

    Args:
        image: Object exposing ``save(path)``.
        output_path: Root directory for outputs.
        prompt: Used verbatim as the sub-directory name.
        seed: Accepted for call compatibility; not used here.
        unique_id: File name without the ``.png`` extension.
        index: Accepted for call compatibility; not used here.
    """
    prompt_dir = f"{output_path}/{prompt}"
    os.makedirs(prompt_dir, exist_ok=True)
    destination = f"{output_path}/{prompt}/{unique_id}.png"
    image.save(destination)

In [None]:
# Load the pipeline once; get_indices() and the cells below read this global.
sd_model = load_model(sd_version)

In [None]:
def find_best_prefix(strings, word):
    """Return the 1-based position of the longest entry that is a prefix of ``word``.

    Every entry of ``strings`` that ``word`` starts with is a candidate. The
    longest candidate wins; among equally long candidates the one latest in
    ``strings`` wins.

    Args:
        strings: Candidate prefixes (token strings).
        word: The word to match against.

    Returns:
        The 1-based position of the winning entry, or ``0`` when no entry is a
        prefix of ``word``.
    """
    # (length, position) pairs order by length first, then by position, so
    # max() picks the longest prefix and breaks ties towards the later entry.
    candidates = [
        (len(token), position)
        for position, token in enumerate(strings, start=1)
        if word.startswith(token)
    ]
    if not candidates:
        return 0
    _, winning_position = max(candidates)
    return winning_position

def get_indices(sentence,obj1,obj2):
    """Locate two object words among the tokens of ``sentence``.

    The sentence is tokenized with ``sd_model.tokenizer`` and each input id is
    decoded separately. The first and last decoded entries are dropped
    (presumably the start/end special tokens — verify for this tokenizer). The
    remaining tokens are printed and searched with ``find_best_prefix``.

    Args:
        sentence: Prompt text to tokenize.
        obj1: First word to locate.
        obj2: Second word to locate.

    Returns:
        A tuple ``(idx1, idx2)`` of the values ``find_best_prefix`` returns for
        ``obj1`` and ``obj2``.
    """
    tokenizer = sd_model.tokenizer
    input_ids = tokenizer(sentence)["input_ids"]
    inner_tokens = [tokenizer.decode(token_id) for token_id in input_ids][1:-1]
    print(inner_tokens)
    first_index = find_best_prefix(inner_tokens, obj1)
    second_index = find_best_prefix(inner_tokens, obj2)
    return first_index, second_index

In [None]:
# Example prompt; also reused by the generation cell further down.
prompt = "a dog to the left of a cat"
# Decode every input id on its own so each token is shown as a separate string.
tokens = [sd_model.tokenizer.decode(t) for t in sd_model.tokenizer(prompt)["input_ids"]]
# Drop the first and last entries (presumably start/end special tokens — verify).
tokens = tokens[1:-1]
# Cell output: tokens numbered from 1, for picking `indices_to_alter` by hand.
list(enumerate(tokens, start=1))

In [None]:
# A fresh random seed per run; it is embedded in the output file name via unique_id.
seed =  random.randrange(1000000)
# NOTE(review): indices 2 and 8 presumably point at "dog" and "cat" in the
# 1-based token listing of `prompt` — confirm against the enumerated tokens.
loss_ord = [{'indices_to_alter':[2,8],'direction':'left'}]

is_sd = False
# Create the generator on the same kind of device load_model() selects, so the
# cell does not crash on machines without CUDA.
g = torch.Generator('cuda' if torch.cuda.is_available() else 'cpu').manual_seed(seed)
controller = AttentionStore()
# Forwarded by run_on_prompt() as max_iter_to_alter and multiplier respectively.
mita = 10
multi = 1000
image = run_on_prompt(prompt=prompt,
                        model=sd_model,
                        controller=controller,
                        generator=g,
                        seed=seed,
                        mita=mita,
                        multi=multi,
                        loss_orders=loss_ord,is_sd=is_sd)
# The file name records seed and both tuning values so a run can be traced back.
save_image(image=image, output_path='<path>', prompt=prompt, seed=seed, unique_id=f"{seed}-{mita}-{multi}", index=seed)