In [1]:
# !pip install torchvision==0.12.0 numpy==1.19.2 albumentations==0.4.3 diffusers opencv-python==4.1.2.30 pudb==2019.2 invisible-watermark imageio==2.9.0 imageio-ffmpeg==0.4.2 pytorch-lightning==1.4.2 omegaconf==2.1.1
# !pip install test-tube>=0.7.5 streamlit>=0.73.1 einops==0.3.0 torch-fidelity==0.3.0 torchmetrics==0.6.0 kornia==0.6

# !pip install ftfy ipywidgets matplotlib pyrallis torch==1.12.0 diffusers==0.12.1 transformers==4.26.0 accelerate
%load_ext autoreload
%autoreload 2

DEVICE = 2
import os
os.environ["CUDA_VISIBLE_DEVICES"] = str(DEVICE)
print("Using GPU: {}".format(DEVICE))


Using GPU: 2


In [2]:

import sys
from typing import Dict, List, Optional

import torch
from diffusers import DPMSolverMultistepScheduler

# Make the repo-local modules (pipeline, config, run, utils) importable from the notebook folder.
sys.path.append(".")
sys.path.append("..")

from pipeline_attend_and_excite import AttendAndExcitePipeline
from config import RunConfig
from run import run_on_prompt, get_indices_to_alter, read_associated_indices
from utils import vis_utils
from utils.ptp_utils import AttentionStore


# Load Model Weights (may take a few minutes)

In [3]:

NUM_DIFFUSION_STEPS = 50
GUIDANCE_SCALE = 7.5
MAX_NUM_WORDS = 77
# NOTE(review): the three constants above are not referenced anywhere else in this notebook
# (run_and_display / RunConfig never receive them), so editing them changes nothing here.

device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# Load the weights directly in float16. The previous `.to(device)` followed by
# `.to(torch.float16)` first materialized the full fp32 pipeline on the device, roughly
# doubling the memory needed while loading on a GPU that is shared with other jobs.
stable = AttendAndExcitePipeline.from_pretrained("CompVis/stable-diffusion-v1-4",
                                                 torch_dtype=torch.float16).to(device)
# Optional faster sampler. Schedulers are rebuilt from an existing config with `from_config`.
# stable.scheduler = DPMSolverMultistepScheduler.from_config(stable.scheduler.config)
tokenizer = stable.tokenizer

safety_checker/pytorch_model.fp16.safetensors not found


Fetching 21 files:   0%|          | 0/21 [00:00<?, ?it/s]

The config attributes {'scaling_factor': 0.18215} were passed to AutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.


# Pipeline Wrapper

In [4]:
# Pipeline wrapper: one call = one generation with the module-level `stable` pipeline.
# Key knobs (see RunConfig for all parameters):
#   scale_factor      - intensity of the latent shift along the loss gradient
#   thresholds        - iterative refinement: iteration number -> attention threshold
#   max_iter_to_alter - last inference timestep at which Attend-and-Excite is applied
@torch.autocast("cuda")
def run_and_display(prompts: List[str],
                    controller: AttentionStore,
                    indices_to_alter: List[int],
                    groups: Optional[List[List[int]]], # EDIT
                    generator: torch.Generator,
                    run_standard_sd: bool = False,
                    scale_factor: int = 20,
                    thresholds: Optional[Dict[int, float]] = None,
                    max_iter_to_alter: int = 25,
                    display_output: bool = False,
                    loss_type: str = "cos",
                    ae_ratio: float = 0.7,
                    height: int = 1024,
                    width: int = 1024,
                    ):
    """Generate an image with the module-level ``stable`` pipeline and return it.

    Builds a ``RunConfig`` from ``prompts[0]`` and the tuning arguments, then hands it,
    together with the remaining arguments, to ``run_on_prompt``.

    Args:
        prompts: Prompt list forwarded to ``run_on_prompt``; only ``prompts[0]`` is put
            into the ``RunConfig``.
        controller: ``AttentionStore`` passed to the pipeline; the calling cells read it
            back with ``vis_utils.show_cross_attention``.
        indices_to_alter: Token indices to strengthen.
        groups: Token-index groups for the contrastive losses, or ``None`` for regular
            Attend-and-Excite (as in the baseline cells).
        generator: Seeded ``torch.Generator``, forwarded to ``run_on_prompt`` as ``seed``.
        run_standard_sd: Stored in the ``RunConfig``; the Stable Diffusion baseline cell
            sets it to True.
        scale_factor: Intensity of the latent shift along the loss gradient.
        thresholds: Iterative-refinement schedule mapping iteration number -> attention
            threshold. ``None`` (default) means ``{0: 0.05, 10: 0.5, 20: 0.8}``.
        max_iter_to_alter: Last inference timestep at which Attend-and-Excite is applied.
        display_output: If True, render the result inline with IPython's ``display``.
        loss_type: Loss applied over ``groups``; this notebook uses ``"cos"`` and ``"dc"``.
        ae_ratio: Forwarded to ``run_on_prompt``. Presumably weights the Attend-and-Excite
            term against the group loss — TODO confirm in ``run_on_prompt``.
        height, width: Output size in pixels, forwarded to ``run_on_prompt``.
            NOTE(review): SD v1-4 was trained at 512x512 (per its model card). At 1024x1024
            the attention in the highest-resolution layers is much larger, and the Stable
            Diffusion cell's OOM (a 4 GiB allocation) is likely related. Consider 512 when
            GPU memory is tight.

    Returns:
        The image returned by ``run_on_prompt``.
    """
    if thresholds is None:
        # Build a fresh dict per call. A dict literal as the default would be one object
        # shared by every call, so any in-place change made downstream would leak into
        # later runs.
        thresholds = {0: 0.05, 10: 0.5, 20: 0.8}

    config = RunConfig(prompt=prompts[0],
                       run_standard_sd=run_standard_sd,
                       scale_factor=scale_factor,
                       thresholds=thresholds,
                       max_iter_to_alter=max_iter_to_alter,
                       loss_type=loss_type,
                       )

    image = run_on_prompt(model=stable,
                          prompt=prompts,
                          controller=controller,
                          token_indices=indices_to_alter,
                          groups=groups, # EDIT
                          seed=generator,
                          config=config,
                          ae_ratio=ae_ratio,
                          height=height,
                          width=width,
                          )
    if display_output:
        display(image)
    return image

# Show Cross-Attention Per Strengthened Token

## Define your seeds, prompt and the indices to strengthen

In [5]:
# read_associated_indices yields three parallel lists: prompts, their token-index groups,
# and the indices to strengthen. They must line up one-to-one, which is checked right after.
all_prompts, all_groups, all_indices_to_alter = read_associated_indices(
    path='../multi_obj_prompts_with_association.csv',
    group_split_char='|',
)
assert len(all_prompts) == len(all_groups) and len(all_groups) == len(all_indices_to_alter)

# Show a slice of the data (rows 5..14) as a sanity check.
print("Some sample prompts and index groups:")
(all_groups[5:15], all_prompts[5:15])

Some sample prompts and index groups:


([[[2, 3], [6, 7]],
  [[2, 3], [6, 7], [11, 12], [15, 16]],
  [[2, 3], [6, 7], [11, 14], [17]],
  [[2, 3], [6, 7], [10, 11], [15, 16], [19, 20]],
  [[2, 3], [6, 7], [10, 11], [15, 16, 17], [20, 21]],
  [[2, 3, 4], [7, 8, 9], [12, 13], [17, 18], [21]],
  [[2, 3, 4], [7, 8, 9], [12, 13], [17, 18, 19], [22, 24, 25]],
  [[2, 3, 4], [7, 8], [14, 15, 16], [20, 21, 22]],
  [[2, 3, 4], [6, 7, 8], [11, 14, 12], [17, 18, 19], [23, 24]],
  [[2, 3], [5, 6], [8, 9], [12, 13, 15], [18, 20, 21, 22, 23], [27, 28, 29]]],
 ['A purple car in a dark garage.',
  'A red apple, a yellow banana, and a green pear in a fruit basket.',
  'A blue pen, a yellow highlighter, and a white piece of paper on a desk.',
  'A brown chair, a black table, a green plant, and a white wall in a living room.',
  'A silver spoon, a white plate, a blue napkin, and a red wine glass on a dining table.',
  'A red stop sign, a yellow traffic light, a green tree, and a blue sky on a street.',
  "A brown teddy bear, a yellow rubber duc

In [6]:
i = 2
print(f"Selecting prompt {i} for visualization & comparison")

# Run config: take the i-th prompt together with its strengthened-token indices and
# its token groups (from the parallel lists loaded above).
prompt, token_indices, token_group = all_prompts[i], all_indices_to_alter[i], all_groups[i]
seeds = [21]  # one generation per seed in every comparison cell below

prompt, token_group

Selecting prompt 2 for visualization & comparison


('A yellow flower in a blue vase.', [[2, 3], [6, 7]])

## Stable Diffusion

In [8]:
# Baseline: plain Stable Diffusion (run_standard_sd=True, no groups), followed by its
# cross-attention maps. The same seeds are reused by the cells below, so outputs are comparable.
try:
    for seed in seeds:
        # Create the generator on the same device as the model (cuda:0, or CPU fallback)
        # instead of hard-coding 'cuda'.
        g = torch.Generator(device=device).manual_seed(seed)
        prompts = [prompt]
        controller = AttentionStore()
        image = run_and_display(prompts=prompts,
                                controller=controller,
                                indices_to_alter=token_indices,
                                groups=None,
                                generator=g,
                                run_standard_sd=True,
                                display_output=True)
        vis_utils.show_cross_attention(attention_store=controller,
                                       prompt=prompt,
                                       tokenizer=tokenizer,
                                       res=16,
                                       from_where=("up", "down", "mid"),
                                       indices_to_alter=token_indices,
                                       orig_image=image)
finally:
    # Return cached CUDA blocks even if a run fails (e.g. with OOM), matching the later cells.
    torch.cuda.empty_cache()

  0%|          | 0/50 [00:00<?, ?it/s]

OutOfMemoryError: CUDA out of memory. Tried to allocate 4.00 GiB (GPU 0; 39.43 GiB total capacity; 15.51 GiB already allocated; 1.07 GiB free; 16.59 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

## Attend-and-Excite

In [None]:
# Regular Attend-and-Excite (groups=None), with the same seeds as the baseline.
try:
    for seed in seeds:
        # Generator lives on the model's device rather than a hard-coded 'cuda'.
        g = torch.Generator(device=device).manual_seed(seed)
        prompts = [prompt]
        controller = AttentionStore()
        image = run_and_display(prompts=prompts,
                                controller=controller,
                                indices_to_alter=token_indices,
                                groups=None, # regular AE
                                generator=g,
                                run_standard_sd=False,
                                display_output=True)
        vis_utils.show_cross_attention(attention_store=controller,
                                       prompt=prompt,
                                       tokenizer=tokenizer,
                                       res=16,
                                       from_where=("up", "down", "mid"),
                                       indices_to_alter=token_indices,
                                       orig_image=image)
finally:
    # Return cached CUDA blocks even if a run fails (e.g. with OOM).
    torch.cuda.empty_cache()

# Our contrastive/adversarial loss

### Cosine loss

In [None]:
# Contrastive run with the cosine loss over the token groups.
# `run_on_prompt` is already imported in the imports cell and kept current by %autoreload 2,
# so it is not re-imported here.
loss_type = "cos"
ae_ratio = 0  # presumably disables the Attend-and-Excite term — TODO confirm in run_on_prompt

try:
    for seed in seeds:
        # Generator lives on the model's device rather than a hard-coded 'cuda'.
        g = torch.Generator(device=device).manual_seed(seed)
        prompts = [prompt]
        controller = AttentionStore()
        image = run_and_display(prompts=prompts,
                                controller=controller,
                                indices_to_alter=token_indices,
                                groups=token_group, # new losses
                                generator=g,
                                run_standard_sd=False,
                                display_output=True,
                                loss_type=loss_type,
                                ae_ratio=ae_ratio)
        vis_utils.show_cross_attention(attention_store=controller,
                                       prompt=prompt,
                                       tokenizer=tokenizer,
                                       res=16,
                                       from_where=("up", "down", "mid"),
                                       indices_to_alter=token_indices,
                                       orig_image=image)
finally:
    # Return cached CUDA blocks even if a run fails (e.g. with OOM).
    torch.cuda.empty_cache()

### Wasserstein loss

In [None]:
# NOTE: Disabled — with only one sample the covariance matrix is 0, so the Wasserstein loss does not work here.
# from run import run_on_prompt
# loss_type = "wasserstein"

# for seed in seeds:
#     g = torch.Generator('cuda').manual_seed(seed)
#     prompts = [prompt]
#     controller = AttentionStore()
#     image = run_and_display(prompts=prompts,
#                             controller=controller,
#                             indices_to_alter=token_indices,
#                             groups=token_group, # new losses
#                             generator=g,
#                             run_standard_sd=False,
#                             display_output=True,
#                             loss_type=loss_type)
#     vis_utils.show_cross_attention(attention_store=controller,
#                                    prompt=prompt,
#                                    tokenizer=tokenizer,
#                                    res=16,
#                                    from_where=("up", "down", "mid"),
#                                    indices_to_alter=token_indices,
#                                    orig_image=image)
# torch.cuda.empty_cache()

### Distance correlation

In [None]:
# Contrastive run with the distance-correlation ("dc") loss over the token groups.
# `run_on_prompt` is already imported in the imports cell and kept current by %autoreload 2,
# so it is not re-imported here.
loss_type = "dc"
ae_ratio = 0.5  # presumably the mix between the Attend-and-Excite term and the dc loss — TODO confirm

try:
    for seed in seeds:
        # Generator lives on the model's device rather than a hard-coded 'cuda'.
        g = torch.Generator(device=device).manual_seed(seed)
        prompts = [prompt]
        controller = AttentionStore()
        image = run_and_display(prompts=prompts,
                                controller=controller,
                                indices_to_alter=token_indices,
                                groups=token_group, # new losses
                                generator=g,
                                run_standard_sd=False,
                                display_output=True,
                                loss_type=loss_type,
                                ae_ratio=ae_ratio)
        vis_utils.show_cross_attention(attention_store=controller,
                                       prompt=prompt,
                                       tokenizer=tokenizer,
                                       res=16,
                                       from_where=("up", "down", "mid"),
                                       indices_to_alter=token_indices,
                                       orig_image=image)
finally:
    # Return cached CUDA blocks even if a run fails (e.g. with OOM).
    torch.cuda.empty_cache()