In [None]:
from diffusers import StableDiffusionPipeline
from matplotlib import pyplot as plt
import numpy as np
import time
import torch
import random
import daam

def set_seed(seed, device='cuda:3'):
    """Seed the RNGs used in this notebook and return a seeded torch generator.

    Args:
        seed: Integer seed applied to Python's ``random`` module, NumPy's
            global RNG, and torch (via ``torch.manual_seed`` and
            ``torch.cuda.manual_seed``).
        device: Device on which the returned ``torch.Generator`` is created.
            Defaults to ``'cuda:3'``, the device the pipeline is moved to.

    Returns:
        A ``torch.Generator`` on ``device`` whose state was set with
        ``manual_seed(seed)``.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed(seed)
    torch.manual_seed(seed)

    # Seed the generator from the argument itself; reading a notebook-level
    # variable here would ignore the caller's seed (or fail if it is unset).
    return torch.Generator(device=device).manual_seed(seed)


# Load the 'stabilityai/stable-diffusion-2-base' pipeline and bind the result
# of moving it to 'cuda:3' (the same device string set_seed uses for its
# generator).
model = StableDiffusionPipeline.from_pretrained(
    'stabilityai/stable-diffusion-2-base'
).to('cuda:3')

In [None]:
def make_im_subplots(*args):
    """Create a subplot grid for showing images, with all axis ticks removed.

    Args:
        *args: Positional arguments forwarded unchanged to ``plt.subplots``.

    Returns:
        The ``(fig, ax)`` pair exactly as returned by ``plt.subplots``.
    """
    fig, ax = plt.subplots(*args)

    # plt.subplots may hand back a single Axes object rather than an array;
    # np.ravel gives a flat iterable in both cases.
    for ax_ in np.ravel(ax):
        ax_.set_xticks([])
        ax_.set_yticks([])

    return fig, ax

s = 0
gen = set_seed(s)

# NOTE(review): `gen` is used for both prompts without being reseeded in
# between — confirm that the two images are not meant to share a seed.
two_object = model('a car and a boy', num_inference_steps=20, generator=gen).images[0]
with daam.trace(model, save_heads=True) as trc:
    masa_objects = model('a blue car and a running boy', num_inference_steps=20, generator=gen).images[0]
    # Build the global heat map once and derive both word maps from it.
    global_heat_map = trc.compute_global_heat_map()
    car_map = global_heat_map.compute_word_heat_map('car')
    boy_map = global_heat_map.compute_word_heat_map('boy')

plt.rcParams['figure.figsize'] = (8, 8)
fig, ax = make_im_subplots(2, 2)

# Top row: the two generated images, labelled with their prompts.
ax[0, 0].imshow(two_object)
ax[0, 0].set_title('a car and a boy')
ax[0, 1].imshow(masa_objects)
ax[0, 1].set_title('a blue car and a running boy')

# Bottom row: word heat maps overlaid on the traced (second) image.
car_map.plot_overlay(masa_objects, ax=ax[1, 0])
ax[1, 0].set_title("'car' heat map")
boy_map.plot_overlay(masa_objects, ax=ax[1, 1])
ax[1, 1].set_title("'boy' heat map")
plt.show()

In [None]:
# NOTE(review): a helper with this name is also defined in an earlier cell;
# this definition shadows it. Consider keeping a single definition.
def make_im_subplots(*args):
    """Create a subplot grid for showing images, with all axis ticks removed.

    Args:
        *args: Positional arguments forwarded unchanged to ``plt.subplots``.

    Returns:
        The ``(fig, ax)`` pair exactly as returned by ``plt.subplots``.
    """
    fig, ax = plt.subplots(*args)

    # plt.subplots may hand back a single Axes object rather than an array;
    # np.ravel gives a flat iterable in both cases.
    for ax_ in np.ravel(ax):
        ax_.set_xticks([])
        ax_.set_yticks([])

    return fig, ax

def trace_word_heat_map(prompt, word, generator, **trace_kwargs):
    """Generate one image under a DAAM trace and return it with a word heat map.

    Args:
        prompt: Text prompt passed to the pipeline ``model``.
        word: Word passed to ``compute_word_heat_map`` on the trace's global
            heat map.
        generator: Seeded generator passed to the pipeline call.
        **trace_kwargs: Keyword arguments forwarded to ``daam.trace``
            (this notebook passes ``save_heads=True`` or ``load_heads=True``).

    Returns:
        A ``(image, word_heat_map)`` tuple: the first generated image and the
        heat map computed for ``word``.
    """
    with daam.trace(model, **trace_kwargs) as tracer:
        image = model(prompt, num_inference_steps=20, generator=generator).images[0]
        word_map = tracer.compute_global_heat_map().compute_word_heat_map(word)
    return image, word_map


# The seed comes from the clock, so it differs on every run; print it so a
# particular run can be repeated by hard-coding the value.
s = int(time.time())
print(f'seed: {s}')

# Reseed before each prompt so every generation starts from the same seed.
gen = set_seed(s)
blue_image, blue_map = trace_word_heat_map(
    'a blue car driving down the street', 'blue', gen, save_heads=True
)

# NOTE(review): load_heads=True presumably reuses the heads stored by the
# save_heads=True trace above — confirm against the DAAM documentation.
gen = set_seed(s)
green_image, green_map = trace_word_heat_map(
    'a green car driving down the street', 'green', gen, load_heads=True
)

gen = set_seed(s)
red_image, red_map = trace_word_heat_map(
    'a red car driving down the street', 'red', gen, load_heads=True
)

In [None]:
plt.rcParams['figure.figsize'] = (8, 8)
fig, ax = make_im_subplots(2, 2)

# Top-left: heat map for the word 'green' overlaid on the green-prompt image.
green_map.plot_overlay(green_image, ax=ax[0, 0])
ax[0, 0].set_title("'green' heat map")

# Remaining panels: the generated images, labelled by prompt colour so each
# panel can be identified without reading the code.
ax[0, 1].imshow(blue_image)
ax[0, 1].set_title('prompt: blue car')
ax[1, 0].imshow(green_image)
ax[1, 0].set_title('prompt: green car')
ax[1, 1].imshow(red_image)
ax[1, 1].set_title('prompt: red car')

plt.show()