In [1]:
from sdhelper import SD
from pathlib import Path
from PIL import Image
import torch
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm, trange
import numpy as np


In [None]:
# Create the shared SD helper instance (default constructor arguments).
# Every later cell calls this object for generation and representation extraction.
# NOTE(review): presumably this loads model weights and is slow — confirm in sdhelper.
sd = SD()

In [None]:
# Text prompts; one image is generated per prompt.
prompts = [
    "a photo of a cat",
    "a photo of a dog",
    "a photo of a bird",
    "a painting of a house",
    "a drawing of a frog",
    "a beautiful landscape painting",
    "a cyberpunk cityscape",
]

# Parallel lists, index-aligned with `prompts`.
imgs = []             # result image of each generation run
reprs_generated = []  # representations returned by the generation run itself
reprs_extracted = []  # representations obtained by feeding the result image back through img2repr

for prompt in tqdm(prompts):
    # Generate an image and record representations at every available extract position.
    result = sd(prompt, extract_positions=sd.available_extract_positions)
    imgs.append(result.result_image)
    reprs_generated.append(result.representations)
    # Re-extract representations from the finished image at the same positions.
    # NOTE(review): step=100 is passed straight to img2repr; presumably a noise
    # timestep — confirm its meaning against sdhelper.
    reprs_extracted.append(sd.img2repr(result.result_image, extract_positions=sd.available_extract_positions, step=100))

In [4]:
# Second extraction pass over the already-generated images, using exactly the
# same img2repr arguments as the first pass, so the two passes can be compared.
reprs_extracted2 = list(map(
    lambda image: sd.img2repr(image, extract_positions=sd.available_extract_positions, step=100),
    imgs,
))

In [None]:
# Extract position whose representations are compared in this figure.
block = 'down_blocks[1]'


def _show_heatmap(fig, ax, data, title):
    """Draw `data` once as an image on `ax`, hide the axis, set `title`, and attach a colorbar."""
    im = ax.imshow(data)
    ax.axis('off')
    ax.set_title(title)
    fig.colorbar(im, ax=ax, orientation='vertical', fraction=0.046, pad=0.04)


# One row per prompt, five panels per row. squeeze=False keeps `axs` two-dimensional
# even when there is only a single prompt, so `axs[i, k]` indexing always works.
fig, axs = plt.subplots(len(prompts), 5, figsize=(15, 3*len(prompts)), squeeze=False)
for i, (img, repr_generated, repr_extracted, repr_extracted2) in enumerate(zip(imgs, reprs_generated, reprs_extracted, reprs_extracted2)):

    # image
    axs[i, 0].imshow(img)
    axs[i, 0].axis('off')
    axs[i, 0].set_title(prompts[i])

    # heatmap of diffs between generated and extracted reprs
    # NOTE(review): assumes the generated repr is indexed (step, batch, channel, H, W)
    # and the extracted repr (batch, channel, H, W), so the subtraction broadcasts
    # over the step axis — confirm against sdhelper.
    diffs = (repr_generated[block][:,0].cpu() - repr_extracted[block][0])
    # Fifth entry from the end along the first axis, reduced with a norm over dim 0.
    diff = diffs[-5].norm(dim=0).cpu().numpy()
    _show_heatmap(fig, axs[i, 1], diff, "Diffs gen.-extr.")

    # bar chart of mean diffs (one bar per entry along the first axis of `diffs`)
    mean_diffs = diffs.norm(dim=1).flatten(1).mean(dim=1).cpu().numpy()
    axs[i, 2].bar(np.arange(len(diffs)), mean_diffs)
    axs[i, 2].set_title("Diffs mean")
    axs[i, 2].set_xlabel("step")
    # Highlight the smallest mean diff by overdrawing that bar in red.
    min_index = np.argmin(mean_diffs)
    axs[i, 2].bar(min_index, mean_diffs[min_index], color='red')

    # heatmap of diffs between extracted reprs (two extraction passes of the same image)
    diffs2 = (repr_extracted[block][0] - repr_extracted2[block][0])
    diff2 = diffs2.norm(dim=0).cpu().numpy()
    _show_heatmap(fig, axs[i, 3], diff2, "Diff extr.-extr.")

    # heatmap of diffs between extracted reprs of different images
    # (this image vs. the next one in the list, wrapping around at the end)
    diffs3 = (repr_extracted[block][0] - reprs_extracted[(i+1)%len(imgs)][block][0])
    diff3 = diffs3.norm(dim=0).cpu().numpy()
    _show_heatmap(fig, axs[i, 4], diff3, "Diff extr.-extr. (next)")

plt.tight_layout()
plt.show()

In [None]:
# TODO: statistics and different layers
# WIP

# One row per extract position. Four columns are allocated, but only the
# first two are drawn into so far.
n_rows = len(sd.available_extract_positions)
fig, axs = plt.subplots(n_rows, 4, figsize=(12, 3 * n_rows))

for i, block in enumerate(sd.available_extract_positions):
    # Stack this block's tensors across all prompts (new leading axis = prompt).
    generated_per_prompt = [r[block][:, 0, :, :, :].cpu() for r in reprs_generated]
    extracted_per_prompt = [r[block][:, :, :, :] for r in reprs_extracted]
    reprs_generated_block = torch.stack(generated_per_prompt)
    reprs_extracted_block = torch.stack(extracted_per_prompt)

    # Norm over dim 2 (presumably the channel axis — TODO confirm), then
    # average over the stacked prompt axis.
    delta = reprs_generated_block - reprs_extracted_block
    diffs = delta.norm(dim=2).mean(dim=0)

    ax_map, ax_bar = axs[i, 0], axs[i, 1]

    # Map for the fifth entry from the end along the first axis of `diffs`.
    ax_map.imshow(diffs[-5])
    ax_map.axis('off')
    ax_map.set_title("Mean diff")

    # One bar per entry along the first axis, averaged over the remaining axes.
    ax_bar.bar(np.arange(len(diffs)), diffs.flatten(1).mean(dim=1))
    ax_bar.set(title="Mean diff", xlabel="step")

plt.tight_layout()
plt.show()

In [None]:
# Generate a single image and plot, for every extract position (columns) and
# every recorded entry along the first axis (rows), a norm map of the representation.
result = sd('a cat', extract_positions=sd.available_extract_positions)
reprs = result.representations

positions = list(sd.available_extract_positions)
# Take the number of rows from the recorded data instead of hard-coding 50, so the
# cell keeps working if the pipeline runs a different number of steps.
# NOTE(review): assumes every block records the same number of entries as the first.
n_steps = len(reprs[positions[0]])

# squeeze=False keeps `axs` two-dimensional even for a single row or column.
fig, axs = plt.subplots(n_steps, len(positions), figsize=(len(positions)*2, n_steps*2), squeeze=False)
for i, block in enumerate(positions):
    for j in range(n_steps):
        # Entry j, first element of the next axis, reduced with a norm over dim 0.
        axs[j, i].imshow(reprs[block][j][0].norm(dim=0).cpu().numpy())
        axs[j, i].axis('off')
    # Label each column with its extract position on the top row only.
    axs[0, i].set_title(block)

plt.tight_layout()
plt.show()