In [1]:
import warnings
warnings.filterwarnings("ignore")

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
import diffusers

from src.llm_objects import spot_objects
from src.diffusion_models import SDXLEditingPipeline

# GPU assignment: `device` hosts the LLM and `device1` hosts the SDXL pipeline, so both
# models can stay resident at once. Edit these indices to match your machine.
LLM_GPU_INDEX = 1
DIFFUSION_GPU_INDEX = 2


def pick_device(index: int) -> torch.device:
    """Return ``cuda:<index>`` if that GPU exists, otherwise a usable fallback device.

    Falls back to ``cuda:0`` when CUDA is available but has fewer than ``index + 1`` GPUs,
    and to the CPU when CUDA is unavailable. A message is printed on fallback, because the
    global warnings filter would swallow a ``warnings.warn``.

    Args:
        index: Preferred CUDA device ordinal.

    Returns:
        The device to place a model on.
    """
    if not torch.cuda.is_available():
        print(f"CUDA unavailable; using CPU instead of cuda:{index}.")
        return torch.device("cpu")
    n_gpus = torch.cuda.device_count()
    if index >= n_gpus:
        # Hardcoding a missing ordinal would only fail later, at .to(device) time.
        print(f"cuda:{index} not present ({n_gpus} GPU(s) visible); using cuda:0.")
        return torch.device("cuda:0")
    return torch.device(f"cuda:{index}")


device = pick_device(LLM_GPU_INDEX)
device1 = pick_device(DIFFUSION_GPU_INDEX)

In [2]:
# Instruction-tuned Gemma checkpoint used to parse each prompt into objects/attributes.
model_name = "google/gemma-7b-it"

# Gemma weights are published in bfloat16. Without `torch_dtype`, transformers loads them in
# float32, which doubles host and GPU memory for an inference-only model.
# Set this to torch.float32 to reproduce a full-precision run.
# NOTE(review): bf16 needs a GPU with native bf16 support (e.g. Ampere or newer) to be fast.
LLM_DTYPE = torch.bfloat16

# Load weights in LLM_DTYPE, switch to eval mode, and move the model to the LLM device.
model = (
    AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=LLM_DTYPE)
    .eval()
    .to(device)
)
tokenizer = AutoTokenizer.from_pretrained(model_name)

`config.hidden_act` is ignored, you should use `config.hidden_activation` instead.
Gemma's activation function will be set to `gelu_pytorch_tanh`. Please, use
`config.hidden_activation` if you want to override this behaviour.
See https://github.com/huggingface/transformers/pull/29402 for more details.
Loading checkpoint shards: 100%|██████████| 4/4 [00:02<00:00,  1.55it/s]


# Extract objects from each prompt with an LLM

In [3]:
# Extra keyword arguments forwarded to `spot_objects` on every call.
model_params = dict(max_new_tokens=200)

# Text-to-image prompts to decompose into per-object entries.
prompts = [
    "a realistic cartoon-style image with a princess and four dwarfs",
    "a vivid photo with a woman on the right and a clown on the left walking in a dirty alley",
    "a monkey sitting above a green motorcycle on the left and another raccoon sitting above a blue motorcycle on the right",
    "a photo of a giant macaron and a croissant splashing in the Seine with the Eiffel Tower in the background",
    "a DSLR photo of a meatball and a donut falling from the clouds onto a neighborhood",
]

# One `spot_objects` call per prompt; results[i] corresponds to prompts[i].
# An explicit loop (not a comprehension) keeps `prompt`/`result` bound to the last item afterwards.
results = []
for prompt in prompts:
    result = spot_objects(tokenizer, model, prompt, device, **model_params)
    results += [result]

In [4]:
# Index into `prompts`/`results` to inspect; 2 selects the monkey/raccoon prompt.
idx = 2
# Bare final expression: Jupyter renders the parsed result as this cell's output.
results[idx]

{'objects': [('monkey', [None]),
  ('motorcycle', ['green']),
  ('raccoon', [None]),
  ('motorcycle', ['blue'])],
 'bg_prompt': 'Unknown',
 'neg_prompt': ''}

# Image generation

In [5]:
# When True, replace the pipeline's default scheduler with DDPM after loading.
use_ddpm = True

# Load SDXL base in fp16 through the project's editing pipeline and place it on `device1`,
# which is separate from the LLM's `device`.
base = SDXLEditingPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    use_safetensors=True,
    torch_dtype=torch.float16,
    variant="fp16",
    use_onnx=False,
)
base.to(device1)

if use_ddpm:
    print('Using DDPM as scheduler.')
    # Build the DDPM scheduler from the currently loaded scheduler's config.
    base.scheduler = diffusers.DDPMScheduler.from_config(base.scheduler.config)

Loading pipeline components...: 100%|██████████| 7/7 [00:01<00:00,  5.50it/s]


Using DDPM as scheduler.
