In [None]:
# Restrict this kernel to one GPU. This has to run before CUDA is initialised;
# the index 5 is presumably specific to the original machine — adjust as needed.
%env CUDA_VISIBLE_DEVICES=5
# Reload edited modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
import sys
# Make the parent directory importable (presumably where the `utils` package
# used by this notebook lives — confirm against the repository layout).
sys.path.append("..")

In [None]:
import torch
from PIL import Image
from matplotlib import pyplot as plt

from utils.dataset import read_dataset_config, load_dataset, Preprocessor
from utils.model import read_model_config, load_model

In [None]:
# Dataset config file and the name of the entry inside it to load.
config_path = 'dataset_config.yaml'
dataset_name = 'SPair-71k'
# Image size handed to the Preprocessor; presumably (height, width) — TODO confirm.
image_size = (768, 768)
# Model config file and the name of the entry inside it to load.
model_config = 'eval_config.yaml'
model_name = 'diff_sd2-1_hook'

In [None]:
# Parse the dataset config file and select the entry for the chosen dataset.
dataset_config = read_dataset_config(config_path)
config = dataset_config[dataset_name]

# Preprocessor with the rescale / flip / normalize options all disabled and an
# image range of [-1, 1].
preprocess = Preprocessor(
    image_size=image_size,
    image_range=[-1, 1],
    rescale_data=False,
    flip_data=False,
    normalize_image=False,
)
dataset = load_dataset(dataset_name, config, preprocess)

# Only datasets that expose `category_to_id` get the mapping built.
if hasattr(dataset, 'category_to_id'):
    dataset.create_category_to_id()

In [None]:
# Parse the model config file and select the entry for `model_name`.
# The parsed entry gets its own name so that `model_config` keeps holding the
# file path: rebinding it to the parsed entry made this cell fail when re-run,
# because `read_model_config` was then handed that entry instead of the path.
model_settings = read_model_config(model_config)[model_name]
model = load_model(model_name, model_settings)
# Inference only: switch to eval mode and move the model to the GPU.
model.eval()
# Trailing semicolon suppresses the repr of the returned object in the cell output.
model.to("cuda");

In [None]:
def attn(Q, K, V=None):
    """Return scaled dot-product attention weights with the last two axes swapped.

    The scores ``Q @ K^T`` are divided by ``sqrt(Q.shape[-1])`` and passed
    through a softmax over the key axis. The result is then transposed, so it
    has shape ``(..., num_keys, num_queries)``.

    Args:
        Q: Query tensor of shape ``(..., num_queries, dim)``.
        K: Key tensor of shape ``(..., num_keys, dim)``.
        V: Accepted for signature compatibility; not used.

    Returns:
        Tensor of attention weights, shape ``(..., num_keys, num_queries)``.
    """
    scale = torch.sqrt(torch.tensor(Q.shape[-1]).float())
    scores = torch.matmul(Q, K.transpose(-1, -2)) / scale
    weights = torch.nn.functional.softmax(scores, dim=-1)
    return weights.transpose(-1, -2)

def plot_attention_maps(sample, prompt_prefix='a photo of a ', block=0, map_size=(768, 768)):
    """Plot the source image next to one attention heat map per prompt token.

    The prompt is ``prompt_prefix + sample['source_category']``. Features come
    from the notebook-level ``model``; entries ``block*3``, ``block*3 + 1`` and
    ``block*3 + 2`` of the returned features are used as Q, K and V.

    Args:
        sample: Dataset item providing ``'source_image'`` (a tensor without a
            batch dimension), ``'source_image_path'`` and ``'source_category'``.
        prompt_prefix: Text placed in front of the category name.
        block: Index of the attention block to visualise.
        map_size: ``(height, width)`` each attention map is upsampled to
            before plotting.

    Returns:
        None. A new matplotlib figure is created; the caller is expected to
        show it (e.g. with ``plt.show()``).
    """
    prompt = prompt_prefix + sample['source_category']

    with torch.no_grad():
        features = model.get_features(sample['source_image'].unsqueeze(0).to("cuda"), [prompt])

    Q = features[block * 3]
    K = features[block * 3 + 1]
    V = features[block * 3 + 2]
    print(Q.shape, K.shape, V.shape)

    tokenizer = model.extractor.pipe.tokenizer
    # encode() wraps the prompt in a start and an end token; drop both so that
    # only the prompt's own tokens get a panel.
    token_ids = tokenizer.encode(prompt)[1:-1]
    # Direct id -> token lookup instead of inverting the whole vocabulary on every call.
    tokens = tokenizer.convert_ids_to_tokens(token_ids)

    # squeeze=False keeps a 2-D axes array even when there is only one column,
    # so indexing works when the prompt yields no tokens.
    fig, axes = plt.subplots(1, len(tokens) + 1, figsize=(20, 5), squeeze=False)
    axes = axes[0]

    # Close the image file as soon as it has been handed to matplotlib.
    with Image.open(sample['source_image_path']) as source_image:
        axes[0].imshow(source_image)
    axes[0].axis('off')

    # float() is a no-op for float32 features and keeps the normalisation
    # well-conditioned for reduced-precision ones.
    A = attn(Q, K).float()
    # Assumes the query positions form a square grid — the reshape below fails otherwise.
    side = int(A.shape[-1] ** .5)
    tiny = torch.finfo(A.dtype).tiny
    for i, t in enumerate(tokens):
        # Row 0 of A belongs to the start token, hence the +1 offset.
        attn_map = A[0, i + 1].reshape(side, side)
        # Min-max normalise to [0, 1]; the clamp avoids 0/0 (NaN) for a constant
        # map and leaves every non-degenerate map unchanged.
        attn_map = attn_map - attn_map.min()
        attn_map = attn_map / attn_map.max().clamp_min(tiny)
        attn_map = torch.nn.functional.interpolate(
            attn_map[None, None], size=tuple(map_size), mode='bilinear'
        )[0, 0]
        axes[i + 1].imshow(attn_map.cpu().numpy(), cmap='hot')
        axes[i + 1].set_title(t)
        axes[i + 1].axis('off')

In [None]:
# Number of attention blocks to visualise (block indices 0 .. num_blocks - 1).
num_blocks = 9
# Fetch the sample once instead of re-indexing the dataset on every iteration.
sample = dataset[0]
for block_idx in range(num_blocks):
    plot_attention_maps(sample, prompt_prefix='front side of a ', block=block_idx)
    plt.show()