## Setup

In [None]:
from sdhelper import SD
from PIL import Image
import numpy as np
import torch
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from tqdm.autonotebook import trange, tqdm
from collections import defaultdict

In [None]:
# Shared helper instance used by every cell below to extract intermediate
# representations via `sd.img2repr`.
# NOTE(review): 'SD1.5' presumably selects the Stable Diffusion 1.5 weights — confirm against sdhelper.
sd = SD('SD1.5')

## Calculate cosine similarity between all blocks

In [None]:

# Names of every position to extract representations from: each top-level
# block plus all attention / resnet layers inside it.
# (attention count, resnet count) per block index — only for SD1.5 and similar.
layer_counts = {
    'down': [(2, 2), (2, 2), (2, 2), (0, 2)],
    'up': [(0, 3), (3, 3), (3, 3), (3, 3)],
}
pos = [
    'mid_block',
    'mid_block.attentions[0]',
    'mid_block.resnets[0]',
    'mid_block.resnets[1]',
]
for direction, counts in layer_counts.items():
    for block_idx, (n_attn, n_res) in enumerate(counts):
        block = f'{direction}_blocks[{block_idx}]'
        pos.append(block)
        pos.extend(f'{block}.attentions[{a}]' for a in range(n_attn))
        pos.extend(f'{block}.resnets[{b}]' for b in range(n_res))
pos.sort()

# Average the self-similarity of a black image's representation over n runs.
# Each run contributes 1/n of its similarity map, so `cossim[p]` ends up
# holding the mean without keeping the individual samples around.
n = 50
cossim = {}
for _ in trange(n):
    blank = Image.new('RGB', (512, 512), (0, 0, 0))
    reprs = sd.img2repr(blank, pos, 50, output_device='cuda')
    for p in pos:
        contribution = reprs.at(p).cosine_similarity(reprs.at(p)) / n
        acc = cossim.get(p)
        if acc is None:
            # First sample for this position: start from zeros of matching shape/dtype/device.
            acc = cossim[p] = torch.zeros_like(contribution)
        acc += contribution

## Cosine similarity over distance

In [None]:
def plot_similarity_over_distance(cossim, p, distance_metric=1):
    """Scatter cosine similarity against spatial distance and overlay the mean.

    Every pair of grid locations ``(i, j)`` / ``(k, l)`` becomes one scatter
    point whose x-value is the distance between the two locations and whose
    y-value is ``cossim[p][i, j, k, l]``. The orange line is the mean
    similarity over all pairs that share the same distance.

    Args:
        cossim: Mapping from block name to a 4-D similarity tensor indexed as
            ``[i, j, k, l]``.
        p: Key into ``cossim`` selecting the block to plot; also used as the
            axes title.
        distance_metric: Order ``p`` of the p-norm passed to ``torch.cdist``
            (1 gives L1 distance, 2 gives L2 distance).

    Draws into the current matplotlib axes and returns nothing.
    """
    sim = cossim[p]
    shape = sim.shape

    # Coordinates of all source (first two dims) and target (last two dims)
    # grid locations, enumerated in the same row-major order as `flatten()`.
    coords1 = torch.tensor(list(np.ndindex(shape[:2]))).float()
    coords2 = torch.tensor(list(np.ndindex(shape[2:]))).float()

    # One distance matrix serves both the scatter and the mean line, so both
    # use the metric named on the x-axis. Building it in one go costs no more
    # memory than the scatter plot needs anyway.
    distances = torch.cdist(coords1, coords2, p=distance_metric).flatten()
    values = sim.detach().flatten().cpu().double()

    # Group the similarities by (exactly equal) distance and average each
    # group; `torch.unique` returns the distances already sorted ascending.
    unique_distances, group, counts = torch.unique(
        distances, return_inverse=True, return_counts=True,
    )
    sums = torch.zeros(unique_distances.numel(), dtype=torch.float64)
    sums.index_add_(0, group, values)
    means = sums / counts

    plt.scatter(
        distances.numpy(),
        values.numpy(),
        s=0.1,
        alpha=0.5,
    )
    plt.plot(
        unique_distances.numpy(),
        means.numpy(),
        label='Mean per distance',
        color='orange',
    )
    plt.xlabel(f'L{distance_metric} Distance')
    plt.ylabel('Cosine Similarity')
    plt.title(f'{p}')

# Blocks to plot, one row each (left column: L1 distance, right column: L2).
# Wider alternative that only skips up_blocks[3] because of its high resolution:
# pos_tmp = [f'down_blocks[{i}]' for i in range(0,4)] + ['mid_block'] + [f'up_blocks[{i}]' for i in range(3)]
pos_tmp = [
    *(f'down_blocks[{i}]' for i in range(1, 4)),
    'mid_block',
    *(f'up_blocks[{i}]' for i in range(2)),
]
n_rows = len(pos_tmp)

plt.figure(figsize=(8, 3 * n_rows))
plt.suptitle('Cosine Similarity over Distance')
for row, p in enumerate(tqdm(pos_tmp)):
    # Subplot indices are 1-based and run row-major over an n_rows x 2 grid.
    for col, metric in enumerate((1, 2), start=1):
        plt.subplot(n_rows, 2, 2 * row + col)
        plot_similarity_over_distance(cossim, p, metric)
plt.tight_layout()
plt.show()


## Plot cosine similarity distribution for all blocks

In [None]:
# For every block, draw one big image made of side x side tiles: tile (row, col)
# shows the similarity map of reference location (row, col) against all
# locations. Tiles are separated by one-pixel NaN lines, which imshow leaves blank.
for p, sim in cossim.items():
    side = sim.shape[0]
    sim_np = sim.cpu().numpy()
    stride = side + 1  # one tile plus its separator line
    canvas = np.full((side * stride + 1, side * stride + 1), np.nan)
    for row, col in np.ndindex(side, side):
        top = row * stride + 1
        left = col * stride + 1
        canvas[top:top + side, left:left + side] = sim_np[row, col]

    fig, ax = plt.subplots(figsize=(10, 10))

    plt.imshow(canvas, vmin=0, vmax=1)
    plt.title(p)
    # Hide the frame around the image.
    for spine in ax.spines.values():
        spine.set_visible(False)

    # Put one tick at the centre of each tile, labelled with the tile index.
    tile_centers = np.arange(0, stride * side, stride) + side / 2 + .5
    plt.xticks(tile_centers, range(side))
    plt.yticks(tile_centers, range(side))
    plt.xlabel('x-position of reference location')
    plt.ylabel('y-position of reference location')
    plt.show()

## Plot cosine similarity over different resolutions and aspect ratios

In [None]:
# Positions to compare: the top-level down, mid and up blocks.
pos = [f'down_blocks[{i}]' for i in range(4)] + ['mid_block'] + [f'up_blocks[{i}]' for i in range(4)]
res1 = (512, 512)
res_others = [
    (768, 768),
    (256, 512),
    (512, 256),
    (256, 256),
]

# One black reference image plus one black image per comparison resolution.
img1 = Image.new('RGB', res1, (0, 0, 0))
img_others = [Image.new('RGB', res, (0, 0, 0)) for res in res_others]

# Average the reference-vs-target cosine similarity over several runs.
# A running sum is kept per (resolution, position) instead of a list of all
# samples: the 4-D similarity maps of the high-resolution blocks are large,
# and holding every sample until the end multiplies peak memory by the
# number of runs.
n_samples = 50
totals = [{} for _ in res_others]
sim_dtype = None
for _ in trange(n_samples):
    r1 = sd.img2repr(img1, pos, 100, output_device='cuda')
    r_others = [sd.img2repr(img, pos, 100, output_device='cuda') for img in img_others]
    for p in pos:
        for i, r2 in enumerate(r_others):
            sim = r1.at(p).cosine_similarity(r2.at(p)).cpu()
            sim_dtype = sim.dtype
            if p not in totals[i]:
                # Accumulate in at least float32 so reduced-precision outputs
                # do not lose accuracy over the running sum.
                totals[i][p] = torch.zeros_like(
                    sim, dtype=torch.promote_types(sim.dtype, torch.float32),
                )
            totals[i][p] += sim
# Same structure as before: one dict per target resolution mapping position ->
# mean similarity tensor, cast back to the dtype the similarities came in.
cossims = [
    {p: (total / n_samples).to(sim_dtype) for p, total in per_res.items()}
    for per_res in totals
]

In [None]:
# Layout constants (all spacings are in inches).
dpi = 200
h_spacing = 0.5    # gap between neighbouring columns
v_spacing = 0.2    # gap between neighbouring rows
top_spacing = 1.0  # band at the top reserved for the title and column labels

# Derive the figure size from the image resolutions: one column per
# resolution (reference first), one row per position.
all_res = [res1, *res_others]
n_rows, n_cols = len(pos), len(all_res)
pixel_widths = [res[0] for res in all_res]
pixel_heights = [res[1] for res in all_res]
max_width = max(pixel_widths) / dpi
max_height = max(pixel_heights) / dpi
fig_width = sum(pixel_widths) / dpi + h_spacing * (n_cols - 1)
fig_height = n_rows * max_height + v_spacing * (n_rows - 1) + top_spacing

fig = plt.figure(figsize=(fig_width, fig_height), dpi=dpi)

# Two thin header rows (title band, column labels) followed by the image rows.
gs = gridspec.GridSpec(
    n_rows + 2,
    n_cols,
    height_ratios=[top_spacing/2, top_spacing/2] + [max_height] * n_rows,
    width_ratios=[res[0]/dpi for res in all_res],
    hspace=v_spacing/max_height,
    wspace=h_spacing/max_width,
)

# Figure title, placed by hand inside the top band.
fig.text(
    0.5,
    1 - top_spacing/(4*fig_height),
    'Averaged color transfer based on cosine similarity for an empty image',
    fontsize=16,
    ha='center',
    va='center',
)

# Column headers live in the second header row.
for col, res in enumerate(all_res):
    if col == 0:
        label = f'Reference {res1}'
    else:
        label = f'Target {res}'
    ax = fig.add_subplot(gs[1, col])
    ax.text(0.5, 0.5, label, ha='center', va='center', fontsize=14)
    ax.axis('off')

for i, p in enumerate(pos):
    grid_row = i + 2  # skip the two header rows

    # Invisible axes spanning the whole row, used only to anchor the row label.
    ax_row = fig.add_subplot(gs[grid_row, :])
    ax_row.text(-0.01, 0.5, p, va='center', ha='right', fontsize=12, transform=ax_row.transAxes)
    ax_row.axis('off')

    # Reference colouring: red encodes the row, green the column of each
    # location in the reference grid.
    source_shape = cossims[0][p].shape[:2]
    height, width = source_shape
    source_color = np.zeros((height, width, 3))
    source_color[..., 0] = np.linspace(0, 1, height)[:, None]
    source_color[..., 1] = np.linspace(0, 1, width)[None, :]

    for col in range(n_cols):
        ax = fig.add_subplot(gs[grid_row, col])

        if col == 0:
            image = source_color
        else:
            # For every target location pick the reference location with the
            # highest mean similarity and copy that location's colour.
            best_match = cossims[col-1][p].flatten(end_dim=1).argmax(axis=0)
            rows, cols = np.unravel_index(best_match, source_shape)
            image = source_color[rows, cols]

        ax.imshow(image, aspect='equal', interpolation='nearest')
        ax.axis('off')

plt.subplots_adjust(top=1, bottom=0, left=0, right=1)
plt.show()