In [None]:
from sdhelper import SD
import numpy as np
from tqdm.autonotebook import tqdm, trange
from matplotlib import pyplot as plt
import PIL.Image
import PIL.ImageOps
import torch

In [None]:
# Instantiate the sdhelper Stable Diffusion wrapper once; the later cells call
# sd.img2repr on this object. (Presumably this loads model weights, which is why it
# lives in its own cell and should not be re-run needlessly — TODO confirm.)
sd = SD()

In [None]:
# Dense correspondence from Stable Diffusion features: for every feature cell of the
# target image, find the most cosine-similar feature cell of the source image (one
# U-Net block at a time) and visualise
#   * the source image re-assembled from the matched cells ("transferred image"),
#   * a position-encoding colour wheel re-assembled the same way (shows where each
#     target cell was taken from), and
#   * how far each match lands from the same grid position, as a percentile.
n = 512      # both images are resized to n x n pixels
seed = 42    # passed to sd.img2repr
step = 50    # passed to sd.img2repr
img_paths = [
    ('cat_going_left.png', 'alligator in a street.jpg'),
    ('cat_going_left.png', 'picture of a cow standing on the field.jpg'),
    # ('split_cat_3.jpg', 'split_cat_4.jpg'),
    # ('split_cat_white_1.jpg', 'split_cat_white_2.jpg'),
]


def _load_rgb(path, size):
    """Load ``path`` as a 3-channel RGB image resized to ``size`` x ``size``.

    Converting to RGB keeps array shapes consistent with the 3-channel colour wheel
    (PNG inputs may be RGBA or palette images, which would otherwise break the tile
    copies below); the ``with`` block closes the underlying file handle.
    """
    with PIL.Image.open(path) as im:
        return im.convert('RGB').resize((size, size))


def _make_colorwheel(size):
    """Return a ``(size, size, 3)`` uint8 image in which colour encodes position.

    Red/green are the sine/cosine of the angle around the image centre (scaled to
    0..255); blue is highest near the centre and 0 at the corners. Requires size >= 2.
    """
    center = (size - 1) / 2  # true centre of a pixel grid indexed 0..size-1
    rows, cols = np.indices((size, size), dtype=np.float64)
    dy, dx = rows - center, cols - center
    angle = np.arctan2(dy, dx)
    dist = 1 - np.hypot(dy, dx) / center / np.sqrt(2)
    wheel = np.stack([.5 + .5 * np.sin(angle), .5 + .5 * np.sin(angle + np.pi / 2), dist], axis=-1)
    # clip guards against tiny negative round-off at the corners before the uint8 cast
    return (np.clip(wheel, 0, 1) * 255).astype(np.uint8)


def _transfer_cells(source, src_rows, src_cols, cell):
    """Assemble an image from ``cell`` x ``cell`` pixel tiles of ``source``.

    Output tile (k, l) is copied from source tile (src_rows[k, l], src_cols[k, l]);
    pixels outside the covered m*cell x m*cell area stay zero.
    """
    out = np.zeros_like(source)
    m = src_rows.shape[0]
    for k in range(m):
        for l in range(m):
            k_, l_ = src_rows[k, l], src_cols[k, l]
            out[k*cell:(k+1)*cell, l*cell:(l+1)*cell] = source[k_*cell:(k_+1)*cell, l_*cell:(l_+1)*cell]
    return out


imgs = [[_load_rgb(path, n) for path in entry] for entry in img_paths]
poss = [
    ['down_blocks[0]'],
    ['down_blocks[1]'],
    ['down_blocks[2]'],
    ['down_blocks[3]'],
    ['mid_block'],
    ['up_blocks[0]'],
    ['up_blocks[1]'],
    ['up_blocks[2]'],
    ['up_blocks[3]'],
]
reprs1 = [[sd.img2repr(img1, pos, step=step, seed=seed) for pos in poss] for img1, _ in imgs]
reprs2 = [[sd.img2repr(img2, pos, step=step, seed=seed) for pos in poss] for _, img2 in imgs]
colorwheel = _make_colorwheel(n)

# setup figure: two header rows (source / target) + one row per block,
# three columns (image, colour wheel, error) per image pair
n_rows, n_cols = len(poss) + 2, len(imgs) * 3
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols * 3 + 1, n_rows * 3))
for ax in axs.flat:
    ax.axis('off')

# header rows: original source image (row 0) and target image (row 1)
for i, (img1, img2) in enumerate(imgs):
    axs[0, 3*i].imshow(img1)
    axs[0, 3*i].set_title('(Transferred) Image')
    axs[0, 3*i+1].imshow(colorwheel)
    axs[0, 3*i+1].set_title('(Transferred) Colorwheel')
    axs[0, 3*i+2].set_title('Error between transferred and target')
    axs[1, 3*i].imshow(img2)
    axs[1, 3*i+1].imshow(colorwheel)

# one row per U-Net block: transferred image, transferred colour wheel, match error
for i, (img1, _) in enumerate(imgs):
    img = np.array(img1)
    for j, pos in enumerate(poss):
        similarities = reprs1[i][j].cosine_similarity(reprs2[i][j])
        m = similarities.shape[0]
        s = n // m  # pixels per feature cell; assumes m divides n
        # Best source cell for every target cell (k, l): the first two similarity
        # axes index the source (img1) representation. Moved to CPU so the grid
        # tensors below and the numpy conversions share one device.
        indices = similarities.reshape(-1, m, m).argmax(dim=0).cpu()
        src_rows, src_cols = indices // m, indices % m

        transferred_img = _transfer_cells(img, src_rows.numpy(), src_cols.numpy(), s)
        transferred_colorwheel = _transfer_cells(colorwheel, src_rows.numpy(), src_cols.numpy(), s)

        # Euclidean distance (in cells) between each target cell and its match.
        k, l = torch.meshgrid(torch.arange(m), torch.arange(m), indexing='ij')
        errors = ((k - src_rows)**2 + (l - src_cols)**2)**.5
        # Fraction of grid cells lying closer to the target cell than its match:
        # 0 for a same-position match, near 1 when the match is among the farthest
        # cells. Direct (non-matmul) cdist keeps distances exact for the tie test.
        grid = torch.stack([k.flatten(), l.flatten()], dim=1).float()
        all_distances = torch.cdist(grid, grid, compute_mode='donot_use_mm_for_euclid_dist')
        percentiles = (all_distances < errors.flatten().unsqueeze(1)).float().mean(dim=1).reshape(errors.shape).numpy()

        axs[j+2, 3*i].imshow(transferred_img)
        axs[j+2, 3*i+1].imshow(transferred_colorwheel)
        # perfect matches (percentile 0) are masked out; a fixed 0..1 range keeps
        # colours comparable across rows
        axs[j+2, 3*i+2].imshow(np.where(percentiles == 0, np.nan, percentiles),
                               cmap='YlOrRd', interpolation='nearest', vmin=0, vmax=1)

for i, label in enumerate(['source image', 'target image'] + [' & '.join(pos) for pos in poss]):
    axs[i, 0].text(-0.1, 0.5, label, va='center', ha='right', transform=axs[i, 0].transAxes)
plt.tight_layout()


In [None]:
# Deliberate stop: "Run All" halts on this exception so the legacy cells in the
# "# OLD" section below are not executed. Run those cells individually if needed.
raise Exception('This is the end')

---

# OLD: earlier single-image-pair version (kept for reference; not reached by Run All)

---

In [None]:
# Legacy single pair: the source image plus a second image resized to match it.
# (An earlier variant used the colour-inverted source instead:
#  img2 = PIL.ImageOps.invert(img1))
source_path, target_path = 'cat_going_left.png', 'alligator in a street.jpg'
img1 = PIL.Image.open(source_path)
img2 = PIL.Image.open(target_path).resize(img1.size)
print(img1.size)

In [None]:
# Legacy run: only the mid block and the first three up blocks.
poss = [[block] for block in ('mid_block', 'up_blocks[0]', 'up_blocks[1]', 'up_blocks[2]')]


def _reprs_for(image):
    """Representations of ``image`` at every position in ``poss`` (step 50, seed 42)."""
    return [sd.img2repr(image, pos, step=50, seed=42) for pos in poss]


reprs1 = _reprs_for(img1)
reprs2 = _reprs_for(img2)

In [None]:
def _show_original_pair(first, second):
    """Show the two input images side by side under a common title."""
    plt.figure(figsize=(9, 4.5))
    plt.suptitle('Original images')
    for slot, picture in enumerate((first, second), start=1):
        plt.subplot(1, 2, slot)
        plt.imshow(picture)
        plt.axis('off')
    plt.tight_layout()
    plt.show()


def _best_source_cell(similarity, row, col, grid):
    """(row, col) of the source cell most similar to target cell (row, col)."""
    flat = similarity[:, :, row, col].flatten().argmax()
    return flat // grid, flat % grid


_show_original_pair(img1, img2)

for layer_idx, pos in enumerate(poss):
    r1, r2 = reprs1[layer_idx], reprs2[layer_idx]
    sim = r1.cosine_similarity(r2)
    n = sim.shape[0]
    cell = img1.size[0] // n  # pixels per feature cell (taken from the image width)
    img_arr = np.array(img1)

    # Cell-level colour circle: red/green from the angle around the centre,
    # blue from the distance to it.
    colorcircle = np.zeros_like(img_arr)
    half = n / 2
    for row in range(n):
        for col in range(n):
            dy, dx = row - half + .5, col - half + .5
            angle = np.arctan2(dy, dx)
            dist = 1 - np.sqrt(dy**2 + dx**2) / half / np.sqrt(2)
            colorcircle[row*cell:(row+1)*cell, col*cell:(col+1)*cell] = np.array([.5+.5*np.sin(angle), .5+.5*np.sin(angle+np.pi/2), dist]) * 255

    # One pass over target cells: copy the matched source tile into both transferred
    # images and record how far (in cells) the match lies from the same position.
    transferred_img = np.zeros_like(img1)
    transferred_colorcircle = np.zeros_like(img1)
    dist_compared_to_flipped = np.zeros((n, n))
    for row in range(n):
        for col in range(n):
            k, l = _best_source_cell(sim, row, col, n)
            dst = (slice(row*cell, (row+1)*cell), slice(col*cell, (col+1)*cell))
            src = (slice(k*cell, (k+1)*cell), slice(l*cell, (l+1)*cell))
            transferred_img[dst] = img_arr[src]
            transferred_colorcircle[dst] = colorcircle[src]
            dist_compared_to_flipped[row, col] = np.sqrt((row - k)**2 + (col - l)**2)

    plt.figure(figsize=(9, 12))
    plt.suptitle(f'{pos} (shape {n}x{n})')
    for slot, panel in enumerate((img1, transferred_img, colorcircle, transferred_colorcircle), start=1):
        plt.subplot(3, 2, slot)
        plt.imshow(panel)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

    # To inspect the match distances:
    # plt.imshow(dist_compared_to_flipped, cmap='viridis', interpolation='nearest'); plt.axis('off')
    # plt.colorbar(); plt.tight_layout(); plt.show()