# Do Semantic Correspondence / Dense Matching between images and their rotated versions

In [None]:
from sdhelper import SD
import numpy as np
from tqdm.autonotebook import tqdm, trange
from matplotlib import pyplot as plt
import PIL.Image
import PIL.ImageOps
import torch


In [None]:
# Instantiate the SD helper imported from `sdhelper`; the cells below use it
# through `sd.img2repr(...)` to extract representations from images.
# NOTE(review): presumably this loads a Stable Diffusion pipeline and may be
# slow on first use — confirm against the sdhelper documentation.
sd = SD()

In [None]:
# --- Experiment configuration -------------------------------------------------
n = 512    # every input image is resized to n x n pixels
seed = 42  # forwarded to sd.img2repr as `seed`
step = 50  # forwarded to sd.img2repr as `step` (presumably the diffusion timestep — confirm)
img_paths = ['cat_next_to_house.png', 'cat_going_right.png']

# Load each image and bring it to the common n x n resolution.
imgs = [PIL.Image.open(path).resize((n, n)) for path in img_paths]

# Positions at which representations are extracted. Each entry is a list of
# position names so that several could be combined into one representation;
# here every entry names exactly one block.
poss = (
    [[f'down_blocks[{idx}]'] for idx in range(4)]
    + [['mid_block']]
    + [[f'up_blocks[{idx}]'] for idx in range(4)]
)

# reprs[i][j]: representation of image i at position j.
reprs = [
    [sd.img2repr(image, block, step=step, seed=seed) for block in poss]
    for image in imgs
]
# reprs_flipped[i][j]: the same, but for image i rotated by 180 degrees.
reprs_flipped = [
    [sd.img2repr(image.rotate(180), block, step=step, seed=seed) for block in poss]
    for image in imgs
]
# Set up the colorwheel: an (n, n, 3) uint8 RGB image that encodes each pixel's
# position relative to the image centre. Red and green encode the polar angle
# (a sine and a 90-degree-shifted sine); blue falls off linearly from 1 at the
# centre to 0 at the corners.
#
# The centre of the pixel index range 0..n-1 is (n - 1) / 2. Using that value
# (instead of n/2 + .5, which is one pixel off) makes the wheel exactly
# point-symmetric, so colorwheel[::-1, ::-1] is its true 180-degree rotation.
# Assumes n > 1 (the normalising radius would be zero otherwise).
offset = (n - 1) / 2
# Per-pixel row / column offsets from the centre, computed for the whole grid
# at once instead of in an n*n Python loop.
dy, dx = np.indices((n, n)) - offset
angle = np.arctan2(dy, dx)
# Normalise by the centre-to-corner distance (offset * sqrt(2)); the clip guards
# against tiny negative values from floating-point rounding at the corners.
dist = np.clip(1 - np.hypot(dy, dx) / offset / np.sqrt(2), 0, 1)
colorwheel = (
    np.stack(
        [.5 + .5 * np.sin(angle), .5 + .5 * np.sin(angle + np.pi / 2), dist],
        axis=-1,
    ) * 255
).astype(np.uint8)  # truncating cast, as with element-wise assignment into a uint8 array
# Build the figure grid: two header rows (source, rotated target) followed by
# one row per entry of `poss`, and three columns per input image
# (image, colorwheel, error map).
n_rows = len(poss) + 2
n_cols = len(imgs) * 3
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3 + 1, n_rows * 3))

for idx, image in enumerate(imgs):
    first_col = 3 * idx

    # Row 0: the source image and the unrotated colorwheel. The titles act as
    # column headers for all rows below.
    src_ax, src_wheel_ax, src_err_ax = axs[0, first_col:first_col + 3]
    src_ax.imshow(image)
    src_ax.set_title('(Transferred) Image')
    src_wheel_ax.imshow(colorwheel)
    src_wheel_ax.set_title('(Transferred) Colorwheel')
    src_err_ax.set_title('Error between transferred and target')

    # Row 1: the target, i.e. the image and the colorwheel rotated by 180 degrees.
    tgt_ax, tgt_wheel_ax, _ = axs[1, first_col:first_col + 3]
    tgt_ax.imshow(image.rotate(180))
    tgt_wheel_ax.imshow(colorwheel[::-1, ::-1])

    # Hide ticks and frames on all six header axes (including the empty error cells).
    for header_ax in axs[:2, first_col:first_col + 3].flat:
        header_ax.axis('off')
# Plot transferred images: for every image and every representation position,
# match each cell of the rotated (target) representation to its most similar
# cell in the source representation, then show
#   * the source image re-assembled from the matched tiles,
#   * the colorwheel re-assembled the same way (reveals where tiles came from),
#   * a per-cell error map of how far each match is from the correct location.
for i, img in enumerate(imgs):
    # Convert once per image (not once per position) and leave the loop
    # variable itself untouched.
    img_arr = np.array(img)
    for j, pos in enumerate(poss):
        # NOTE(review): assumed to be a torch tensor of shape (m, m, m, m) indexed
        # [source_row, source_col, target_row, target_col] — confirm against sdhelper.
        # .cpu() is a no-op for CPU tensors and keeps the index arithmetic and
        # the NumPy plotting below working if the result lives on another device.
        similarities = reprs[i][j].cosine_similarity(reprs_flipped[i][j]).cpu()
        m = similarities.shape[0]
        s = n // m  # side length in pixels of one representation cell

        # Flattened index of the best-matching source cell for every target
        # cell. Computed once and shared by the tile transfer and the error
        # map; reshape (unlike view) also accepts non-contiguous tensors.
        indices = similarities.reshape(-1, m, m).argmax(dim=0)
        k_, l_ = indices // m, indices % m
        src_rows, src_cols = k_.tolist(), l_.tolist()

        transferred_img = np.zeros_like(img_arr)
        # Shaped like the colorwheel rather than the image, so images that are
        # not 3-channel (RGBA, grayscale, palette) cannot break the assignment.
        transferred_colorwheel = np.zeros_like(colorwheel)
        for k in range(m):
            for l in range(m):
                r, c = src_rows[k][l], src_cols[k][l]
                src = np.s_[r*s:(r+1)*s, c*s:(c+1)*s]
                dst = np.s_[k*s:(k+1)*s, l*s:(l+1)*s]
                # [::-1, ::-1] rotates each tile by 180 degrees to match the
                # target's orientation; use [::1, ::1] to copy tiles unrotated.
                transferred_img[dst] = img_arr[src][::-1, ::-1]
                transferred_colorwheel[dst] = colorwheel[src][::-1, ::-1]

        # Error map: rotate the matched source coordinates by 180 degrees to get
        # where each match lands in the target, and measure its Euclidean
        # distance (in cells) from the cell it was matched to.
        landed_k, landed_l = m - 1 - k_, m - 1 - l_
        grid_k, grid_l = torch.meshgrid(torch.arange(m), torch.arange(m), indexing='ij')
        errors = ((grid_k - landed_k)**2 + (grid_l - landed_l)**2).float().sqrt()
        # Express each error as the fraction of all cells that lie strictly
        # closer to the target cell than the match does. A value of 0 means a
        # perfect match, since the cell's zero distance to itself is not below 0.
        cells = torch.stack([grid_k.flatten(), grid_l.flatten()], dim=1).float()
        all_distances = torch.cdist(cells, cells)
        percentiles = (all_distances < errors.flatten().unsqueeze(1)).float().mean(dim=1).reshape(errors.shape)
        percentiles = percentiles.numpy()

        axs[j+2, 3*i].imshow(transferred_img)
        axs[j+2, 3*i].axis('off')
        axs[j+2, 3*i+1].imshow(transferred_colorwheel)
        axs[j+2, 3*i+1].axis('off')
        # Perfect matches are mapped to NaN so they are left blank in the error map.
        axs[j+2, 3*i+2].imshow(np.where(percentiles == 0, np.nan, percentiles), cmap='YlOrRd', interpolation='nearest')
        axs[j+2, 3*i+2].axis('off')

# Write a label to the left of the first axes of every row: the two header rows
# first, then one label per representation position (names joined with ' & ').
row_labels = ['Source Image', 'Target Image\n(rotated 180°)']
row_labels.extend(' & '.join(names) for names in poss)
for row, label in enumerate(row_labels):
    label_ax = axs[row, 0]
    # Axes-relative coordinates: slightly left of the axes, vertically centred.
    label_ax.text(-0.1, 0.5, label, va='center', ha='right', transform=label_ax.transAxes)
plt.tight_layout()
