In [1]:
import numpy as np
from sdhelper import SD
from PIL import Image
import torch
from tqdm.autonotebook import tqdm, trange
import matplotlib.pyplot as plt
from datasets import load_dataset
import torch
import torch.nn.functional as F
import random


In [2]:
# Load the "train" split of the jonasloos/imagenet_subset dataset.
data = load_dataset("jonasloos/imagenet_subset", split="train")

In [None]:
# Create the sdhelper SD instance with its default arguments; later cells use it
# to extract representations via sd.img2repr.
sd = SD()


In [4]:
# Collect the 'image' entry of every dataset record into a plain list.
images = list(record['image'] for record in data)

In [None]:

# Extraction position whose representations are analysed in the cells below.
block = 'mid_block'

# Extract the representations at `block` for every image.
# NOTE(review): step=50 is presumably the diffusion time step forwarded to sdhelper — confirm.
representations = sd.img2repr(images, extract_positions=[block], step=50)

In [None]:
# Pixel width of one representation token, derived from the first image and the
# last dimension of its representation.
w, h = images[0].size
token_size = w // representations[0][block].shape[-1]

# Remove a one-token-wide frame from every image, then extract representations
# of the cropped images with the same settings as for the full images.
crop_box = (token_size, token_size, w - token_size, h - token_size)
images_cropped = [img.crop(crop_box) for img in images]
representations_cropped = sd.img2repr(images_cropped, extract_positions=[block], step=50)

In [None]:
# Compare the representation shapes of the full and the cropped first image.
base_shape = representations[0][block].shape
cropped_shape = representations_cropped[0][block].shape
print(f'base: {base_shape}')
print(f'cropped: {cropped_shape}')

In [None]:
# Token-to-token cosine similarity of every full-image representation with itself.
maps_full = []
for rep in tqdm(representations):
    maps_full.append(rep.cosine_similarity(rep))
similarity_maps = torch.stack(maps_full)

# Same for the representations of the cropped images.
maps_cropped = []
for rep in tqdm(representations_cropped):
    maps_cropped.append(rep.cosine_similarity(rep))
similarity_maps_cropped = torch.stack(maps_cropped)



In [None]:
n = similarity_maps.shape[1]
fig, (ax_full, ax_cropped) = plt.subplots(1, 2)

# Left: mean similarity of token (1, 1) of the full image to every other token,
# with a black outline around the region that remains after dropping the
# outermost ring of tokens.
mean_full = similarity_maps[:, 1, 1, :, :].mean(dim=0)
outline_x = [.5, .5, n-1.5, n-1.5, .5]
outline_y = [.5, n-1.5, n-1.5, .5, .5]
ax_full.imshow(mean_full)
ax_full.plot(outline_x, outline_y, 'k-')
ax_full.axis('off')
ax_full.set_title('Full image')

# Right: mean similarity of token (0, 0) of the cropped image, padded with one
# ring of NaN so that it lines up with the full-image panel.
mean_cropped = similarity_maps_cropped[:, 0, 0, :, :].mean(dim=0)
ax_cropped.imshow(F.pad(mean_cropped, (1, 1, 1, 1), value=torch.nan))
ax_cropped.axis('off')
ax_cropped.set_title('Cropped image')

plt.show()

In [None]:
# Token groups on the square token grid, each given as a list of
# (first-axis index, second-axis index) pairs of ints/slices.
corner_slices = [(a,b) for a in [0,-1] for b in [0,-1]]
border_slices = [(slice(1,-1), 0), (slice(1,-1), -1), (0, slice(1,-1)), (-1, slice(1,-1))]
other_slices  = [(slice(1,-1), slice(1,-1))]
slices = [
    ('corner', corner_slices),
    ('border', border_slices),
    ('other', other_slices),
]

# Negative padding trims one token from each side of the four token-grid
# dimensions of the full-image similarity maps (the pad value is never used).
similarity_maps_repr_cropped = F.pad(similarity_maps, [-1]*8, value=torch.nan)

for sim_maps in [similarity_maps_cropped, similarity_maps_repr_cropped]:
    # The grid size must come from the maps that are actually averaged: using the
    # size of the uncropped `similarity_maps` over-counts border/other pairs and
    # therefore deflates their mean similarity.
    n = sim_maps.shape[1]
    m = len(slices)
    num_images = len(sim_maps)

    # Number of tokens selected by each group on an n x n grid.
    grid = torch.ones((n, n))
    group_sizes = [sum(grid[s1, s2].numel() for s1, s2 in group) for _, group in slices]

    result = torch.zeros((m, m))
    for i, (name1, slices1) in enumerate(slices):
        for j, (name2, slices2) in enumerate(slices):
            # Number of (token, token) pairs between the two groups over all images.
            count = group_sizes[i] * group_sizes[j] * num_images
            # Within a group every token is also paired with itself; those pairs are
            # removed from both the sum and the count, assuming a self-similarity of 1.
            self_similarities = group_sizes[i] * num_images if i == j else 0
            total = torch.stack([sim_maps[:,s11,s12,s21,s22].sum() for s11, s12 in slices1 for s21, s22 in slices2]).sum()
            result[i,j] = (total - self_similarities) / (count - self_similarities)

    # Heat map of the mean similarity between groups, annotated with the values.
    plt.imshow(result, vmin=0)
    names = [x[0] for x in slices]
    plt.xticks(ticks=range(len(slices)), labels=names, rotation=0)
    plt.tick_params(top=True, bottom=False, labeltop=True, labelbottom=False)

    plt.yticks(ticks=range(len(slices)), labels=names)
    for i in range(m):
        for j in range(m):
            # White text on dark cells, black text on bright cells.
            plt.text(j, i, f'{result[i,j]:.4f}', ha='center', va='center', color='white' if result[i,j] < result.max()/2 else 'black')

    plt.colorbar()
    plt.show()



In [None]:
# Index pairs selecting the three token groups on a square token grid.
inner = slice(1, -1)  # everything except the first and last index along an axis

corner_slices = [(row, col) for row in (0, -1) for col in (0, -1)]
border_slices = [(inner, 0), (inner, -1), (0, inner), (-1, inner)]
other_slices = [(inner, inner)]

# (group name, index pairs) in the order used for the rows/columns of the results.
slices = list(zip(
    ['corner', 'border', 'other'],
    [corner_slices, border_slices, other_slices],
))

def _group_similarity_matrix(sim_maps, groups):
    """Average the pairwise token similarities between every pair of token groups.

    Args:
        sim_maps: Tensor indexed as (image, i1, j1, i2, j2) that holds the
            similarity between grid token (i1, j1) and grid token (i2, j2).
            The token grid is assumed to be square.
        groups: List of (name, index_list) pairs; every index_list contains
            (i, j) tuples of ints/slices that select tokens on the grid.

    Returns:
        Tensor of shape (len(groups), len(groups)) with the mean similarity
        between the groups. On the diagonal the pairs of a token with itself
        are excluded, assuming their similarity is exactly 1.
    """
    # The grid size is taken from the maps being averaged, so that the pair
    # counts below match the entries that are actually summed.
    n = sim_maps.shape[1]
    num_images = len(sim_maps)

    # Number of tokens each group selects on an n x n grid.
    grid = torch.ones((n, n))
    group_sizes = [sum(grid[s1, s2].numel() for s1, s2 in index_list) for _, index_list in groups]

    result = torch.zeros((len(groups), len(groups)))
    for i, (_, slices1) in enumerate(groups):
        for j, (_, slices2) in enumerate(groups):
            count = group_sizes[i] * group_sizes[j] * num_images
            self_similarities = group_sizes[i] * num_images if i == j else 0
            total = torch.stack([sim_maps[:, s11, s12, s21, s22].sum() for s11, s12 in slices1 for s21, s22 in slices2]).sum()
            result[i, j] = (total - self_similarities) / (count - self_similarities)
    return result


def calc_group_similarities(block, images):
    """Compute group-wise mean token similarities for one extraction position.

    Representations are extracted twice: from the images cropped by one token
    on each side, and from the full images, whose similarity maps are then
    trimmed by one token on each side.

    Args:
        block: Extraction position passed to ``sd.img2repr``.
        images: Images exposing ``.size`` and ``.crop``; the token size is
            derived from the first image.

    Returns:
        List of two (groups x groups) tensors, ordered as
        [cropped images, cropped representations of the full images], with
        groups ordered as in the module-level ``slices``.
    """
    representations = sd.img2repr(images, extract_positions=[block], step=50)
    w, h = images[0].size
    token_size = w // representations[0][block].shape[-1]
    representations_cropped = sd.img2repr([img.crop((token_size, token_size, w-token_size, h-token_size)) for img in images], extract_positions=[block], step=50)
    similarity_maps = torch.stack([x.cosine_similarity(x) for x in representations])
    similarity_maps_cropped = torch.stack([x.cosine_similarity(x) for x in representations_cropped])
    # Negative padding trims one token from each side of the four grid dimensions.
    similarity_maps_repr_cropped = F.pad(similarity_maps, [-1]*8, value=torch.nan)

    return [
        _group_similarity_matrix(sim_maps, slices)
        for sim_maps in [similarity_maps_cropped, similarity_maps_repr_cropped]
    ]


# Compute the group-similarity matrices for every available extraction position.
# The loop variable is deliberately not called `block`, so that the notebook-level
# `block` used by earlier cells is not overwritten when this cell runs.
results = {}
for position in tqdm(sd.available_extract_positions):
    results[position] = calc_group_similarities(position, images)

In [None]:
# 'conv_out' is left out of the per-block comparison.
results_filtered = {k: v for k, v in results.items() if k not in ['conv_out']}
xs = range(len(results_filtered))

# For each group, plot the within-group similarity obtained from the cropped
# images (x[0]) relative to the one obtained from the cropped representations
# of the full images (x[1]).
for i, group_name in enumerate(['corner', 'border', 'other']):
    plt.plot(xs, [x[0][i,i]/x[1][i,i] for x in results_filtered.values()], label=group_name)

# Reference line at a ratio of 1. Only the line style is given in the format
# string; a colour there ('k--') would clash with the color= keyword and
# trigger a matplotlib warning.
plt.plot(xs, [1]*len(xs), '--', color='gray')
plt.ylabel('Relative Similarity')
plt.yscale('log')
plt.yticks([0.5, 1, 2], ['0.5', '1', '2'])
plt.gca().yaxis.set_minor_formatter(plt.NullFormatter())
# Shorten the block names for the tick labels.
block_names = [x.replace('_blocks', '').replace('_block', '').replace('_', '-') for x in results_filtered.keys()]
plt.xticks(ticks=xs, labels=block_names, rotation=90)
plt.title('Similarity between groups of tokens')
plt.legend()
plt.show()

# This shows that, for many blocks, tokens are more similar to each other when they are in the corners than otherwise.
