In [1]:
from sdhelper import SD
import torch
import torch.nn.functional as F
import numpy as np
from tqdm.autonotebook import tqdm, trange
from matplotlib import pyplot as plt
import PIL.Image
import random
from PIL import Image, ImageDraw, ImageOps
import json
import pickle
import pandas as pd
from pathlib import Path

In [2]:
# Cutout foregrounds: each `<name>.png` has a sibling `<name>.json` with its
# keypoint annotations (see `_load_keypoints`).
image_names = [
    'cutouts/cat_going_left_cutout',
    'cutouts/cat_going_right_cutout',
    'cutouts/cow_cutout',
    'cutouts/dog_cutout',
]

num_shifts = 16        # foreground positions per axis (num_shifts x num_shifts grid)
background_size = 512  # side length (px) of the square composite images
foreground_size = 128  # side length (px) of the padded square foregrounds

# Converted to RGBA: each foreground is later pasted using itself as the mask,
# which needs an alpha-carrying mode (RGB or palette PNGs may be rejected).
cutout_images = [PIL.Image.open(f"{name}.png").convert('RGBA') for name in image_names]

# Plain white canvas, sized from background_size rather than a literal.
background_images = [
    PIL.Image.new('RGB', (background_size, background_size), color=(255, 255, 255)),
]

labels = [
    'left ear',
    'right ear',
    'left eye',
    'right eye',
    'nose',
    'left front paw',
    'right front paw',
    'left rear paw',
    'right rear paw',
]


def _load_keypoints(name):
    """Read `<name>.json` and return exactly one entry per label.

    The JSON is a list whose entries are dicts with 'x'/'y' (cutout pixel
    coordinates) or null. It is padded with None / truncated to len(labels);
    None means the keypoint is not annotated.
    """
    with open(f"{name}.json") as f:
        positions = json.load(f)
    return (positions + [None] * len(labels))[:len(labels)]


raw_positions = [_load_keypoints(name) for name in image_names]


# Square foregrounds: aspect-preserving resize + transparent padding.
foregrounds = [
    ImageOps.pad(img, (foreground_size, foreground_size), color=(0,0,0,0))
    for img in cutout_images
]

# Backgrounds cropped/resized to the composite size.
backgrounds = [
    ImageOps.fit(img, (background_size, background_size))
    for img in background_images
]


In [None]:
# Build the synthetic dataset: every foreground is pasted onto every background
# at a num_shifts x num_shifts grid of positions, plus a horizontally mirrored
# copy of each composite. Entries are ((fg_idx, bg_idx, k, l, mirrored), image, kps),
# where kps[m] is the (x, y) pixel position of labels[m] in the composite, or
# None when that keypoint is unannotated or lies outside the composite.
#
# Foregrounds are centred on the grid-cell centres. The previous offset,
# `num_shifts // 2 - foreground_size // 2`, mixed a count with a pixel size: the
# grid overhung the left/top edge by 56 px but the right/bottom by only 40 px,
# and mirrored copies landed 16 px off the unmirrored grid.
# NOTE: cached prediction pickles built from the old layout must be regenerated.
shift_offset = background_size // num_shifts // 2 - foreground_size // 2


def _project_keypoints(positions, cutout_size, x, y):
    """Map keypoints from cutout pixels to composite pixels.

    Follows the centred `ImageOps.pad` geometry: the cutout is scaled by
    foreground_size / max(cutout_size) and centred along its shorter side,
    then the padded square is pasted with its top-left corner at (x, y).
    Returns one (x, y) tuple or None per entry of `positions`; keypoints
    outside [0, background_size) on either axis become None.
    """
    longest = max(cutout_size)
    pad_x = (longest - cutout_size[0]) / 2
    pad_y = (longest - cutout_size[1]) / 2
    projected = []
    for pos in positions:
        if pos is None:
            projected.append(None)
            continue
        kp_x = (pos['x'] + pad_x) * foreground_size / longest + x
        kp_y = (pos['y'] + pad_y) * foreground_size / longest + y
        inside = 0 <= kp_x < background_size and 0 <= kp_y < background_size
        projected.append((kp_x, kp_y) if inside else None)
    return projected


def _mirror_keypoints(kps):
    """Mirror keypoints horizontally to match `ImageOps.mirror` of the composite.

    x -> background_size - x. A keypoint at exactly x == 0 would map to
    background_size (outside the image), so it becomes None.
    """
    mirrored = []
    for kp in kps:
        if kp is None or not 0 <= background_size - kp[0] < background_size:
            mirrored.append(None)
        else:
            mirrored.append((background_size - kp[0], kp[1]))
    return mirrored


combinations: list[tuple] = []
for i, (foreground, cutout, positions) in enumerate(zip(foregrounds, cutout_images, raw_positions)):
    for j, background in enumerate(backgrounds):
        for k in range(num_shifts):
            for l in range(num_shifts):
                x = k * background_size // num_shifts + shift_offset
                y = l * background_size // num_shifts + shift_offset
                composite = background.copy()
                composite.paste(foreground, (x, y), foreground)
                kps = _project_keypoints(positions, cutout.size, x, y)
                combinations.append(((i, j, k, l, False), composite, kps))
                # Mirrored copy; label order is kept, so mirrored entries are only
                # comparable with other mirrored entries.
                combinations.append((
                    (i, j, k, l, True),
                    ImageOps.mirror(composite),
                    _mirror_keypoints(kps),
                ))

len(combinations)

In [None]:
# Sanity check: show a random composite with its visible keypoints.
i = random.randint(0, len(combinations)-1)
_, sample_image, sample_kps = combinations[i]
plt.imshow(sample_image)
# reshape(-1, 2) keeps the array 2-D even when no keypoint is visible, so the
# column slicing below cannot fail with an IndexError.
kps = np.array([kp for kp in sample_kps if kp is not None], dtype=float).reshape(-1, 2)
plt.scatter(kps[:, 0], kps[:, 1], c='red')
plt.axis('off')
plt.show()

In [None]:
# Stable Diffusion helper from sdhelper; `sd.model_name` keys the cached
# prediction file and the saved result tensors further down.
sd = SD()

In [None]:
# Extract features for every composite (presumably UNet activations at
# 'up_blocks[1]' — TODO confirm against sdhelper). The result is indexed like
# `combinations`, and each entry's `.concat()` is used later.
representations = sd.img2repr(
    [image for _, image, _ in combinations],
    extract_positions=['up_blocks[1]'],
    step=50,
    seed=42,
    batch_size=20,
)


In [None]:
# Shuffled copy of the dataset; `combinations` itself keeps its order.
combinations_shuffled = random.sample(combinations, k=len(combinations))
print(len(combinations_shuffled))

In [None]:
# Count all valid (source image, target image, keypoint label) triples: ordered
# pairs of distinct images with the same mirror flag in which both images have
# that label's keypoint. A label visible in c images of one mirror group yields
# c * (c - 1) such ordered pairs, so counting per label gives the same total as
# looping over all ~len(combinations)**2 image pairs.
label_image_counts = {}  # (mirrored, label index) -> number of images showing that keypoint
for (_, _, _, _, mirrored), _, entry_kps in combinations:
    for idx, kp in enumerate(entry_kps):
        if kp is not None:
            key = (mirrored, idx)
            label_image_counts[key] = label_image_counts.get(key, 0) + 1
count = sum(c * (c - 1) for c in label_image_counts.values())

print(count)

In [None]:
def get_random_plot_data():
    """Pick a random (source, target) image pair and one keypoint both show.

    Redraws until the two images are distinct, share the mirror flag and have
    at least one keypoint label visible in both.

    Returns:
        dict with 'r1', 'r2' (concatenated representations of source/target),
        'img1', 'img2' (composite images) and 'kp1', 'kp2' (the chosen keypoint
        in each image as (x, y) pixel coordinates).
    """
    while True:
        i = random.randint(0, len(combinations)-1)
        j = random.randint(0, len(combinations)-1)
        (_,_,_,_,mirrored1), img1, kps1 = combinations[i]
        (_,_,_,_,mirrored2), img2, kps2 = combinations[j]
        if mirrored1 != mirrored2 or i == j:
            continue  # only compare distinct images with the same mirroring
        kps = [(kp1, kp2) for kp1, kp2 in zip(kps1, kps2) if kp1 is not None and kp2 is not None]
        if kps:
            break
    kp1, kp2 = random.choice(kps)
    return {
        'r1': representations[i].concat(),
        'r2': representations[j].concat(),
        'img1': img1,
        'img2': img2,
        'kp1': kp1,
        'kp2': kp2,
    }

def plot_example(plot_data):
    """Plot a source keypoint, the target with true and predicted keypoint, and the similarity map.

    `plot_data` is a dict as returned by `get_random_plot_data`. Values are read
    by key, so dicts restored from the pickle work regardless of key order.
    The prediction is the target feature cell with the highest cosine similarity
    to the source keypoint's cell. Assumes r1/r2 are (channels, n, n) feature
    maps — TODO confirm against sdhelper.
    """
    r1, r2 = plot_data['r1'], plot_data['r2']
    img1, img2 = plot_data['img1'], plot_data['img2']
    kp1, kp2 = plot_data['kp1'], plot_data['kp2']
    n = r1.shape[1]
    # Feature-map cell containing the source keypoint.
    kp1_x = int(kp1[0]) * n // background_size
    kp1_y = int(kp1[1]) * n // background_size
    # Similarity of the source feature to every target cell (computed once, used twice).
    similarities = F.cosine_similarity(r1[:, kp1_y, kp1_x, None, None], r2, dim=0)
    best_idx = int(similarities.flatten().argmax())
    y_trg_base, x_trg_base = np.unravel_index(best_idx, (n, n))
    # Centre of the best-matching cell in image pixels.
    x_trg = (x_trg_base + 0.5) * background_size // n
    y_trg = (y_trg_base + 0.5) * background_size // n
    fig, axs = plt.subplots(1,3, figsize=(9,3))
    axs[0].imshow(img1)
    axs[0].set_title('source image')
    axs[0].scatter(*kp1, c='red', marker='x', alpha=0.8, label='source keypoint')
    axs[0].set_xticks([])
    axs[0].set_yticks([])
    axs[0].legend()
    axs[1].imshow(img2)
    axs[1].set_title('target image')
    axs[1].scatter(*kp2, c='red', marker='x', alpha=0.8, label='target keypoint')
    axs[1].scatter(x_trg, y_trg, c='green', marker='+', alpha=0.8, label='predicted target')
    axs[1].legend()
    axs[1].set_xticks([])
    axs[1].set_yticks([])
    axs[2].imshow(similarities.cpu().view(n, n))
    axs[2].set_title('target similarities')
    axs[2].axis('off')
    plt.show()

plot_data = get_random_plot_data()
# Saved so the same example can be re-plotted without recomputing features.
with open('sc_errors_over_position_example_plot_data.pkl', 'wb') as f:
    pickle.dump(plot_data, f)
plot_example(plot_data)

In [None]:
def plot_random_prediction_debug(plot_probability=0.01):
    """Plot one randomly sampled keypoint prediction together with feature diagnostics.

    Visits ordered (source, target) pairs of distinct images with the same
    mirror flag in random order. For each source keypoint whose label is also
    visible in the target, the prediction is the target feature cell with the
    highest cosine similarity to the source keypoint's cell. Each such
    prediction is plotted with probability `plot_probability`, and the function
    returns after the first plot.

    Panels: source image + keypoint; target image with true (red) and predicted
    (green) keypoint; per-cell feature norm of the source; similarity map over
    the target with the best cell (green). Requires CUDA (`.cuda()` below).

    This was previously a second `def plot_example`, which shadowed the
    `plot_example(plot_data)` helper defined in the earlier cell.
    """
    for i in tqdm(random.sample(range(len(combinations)), len(combinations))):
        (_,_,_,_,mirrored1), _, kps1 = combinations[i]
        # Source keypoints do not depend on the target: compute them once per source.
        kp1_idxs = [idx for idx, kp in enumerate(kps1) if kp is not None]
        if not kp1_idxs:
            continue
        r1 = representations[i].concat().cuda()
        n = r1.shape[1]
        # Feature-map cells containing the source keypoints.
        x_src_bases = [int(kps1[idx][0]) * n // background_size for idx in kp1_idxs]
        y_src_bases = [int(kps1[idx][1]) * n // background_size for idx in kp1_idxs]
        for j in random.sample(range(len(combinations)), len(combinations)):
            (_,_,_,_,mirrored2), _, kps2 = combinations[j]
            if mirrored1 != mirrored2 or i == j:
                continue  # ignore different mirrorization and same image comparisons
            r2 = representations[j].concat().cuda()
            # One similarity map per source keypoint, flattened -> best target cell per keypoint.
            best_indices = F.cosine_similarity(
                r1[:, y_src_bases, x_src_bases, None, None], r2[:, None, :, :], dim=0
            ).flatten(1, 2).argmax(dim=1).cpu()
            for idx, x_src_base, y_src_base, best_idx in zip(kp1_idxs, x_src_bases, y_src_bases, best_indices):
                kp1, kp2 = kps1[idx], kps2[idx]
                if kp2 is None:
                    continue
                if random.random() >= plot_probability:
                    continue
                y_trg_base, x_trg_base = np.unravel_index(best_idx, (n, n))
                # Centre of the predicted cell in image pixels.
                x_trg = (x_trg_base + 0.5) * background_size // n
                y_trg = (y_trg_base + 0.5) * background_size // n
                fig, axs = plt.subplots(1, 4, figsize=(16, 4))
                axs[0].imshow(combinations[i][1])
                print(kp1, kp2)
                axs[0].scatter(*kp1, c='red')
                axs[0].set_title('source')
                axs[0].axis('off')
                axs[1].imshow(combinations[j][1])
                axs[1].scatter(*kp2, c='red')
                axs[1].set_title('target')
                axs[1].scatter(x_trg, y_trg, c='green')
                axs[1].axis('off')
                axs[2].imshow(r1.norm(dim=0).cpu())
                axs[2].scatter(x_src_base, y_src_base, c='red')
                axs[2].set_title('source norm')
                axs[2].axis('off')
                axs[3].imshow(F.cosine_similarity(r1[:, y_src_base, x_src_base, None, None], r2, dim=0).cpu().view(n, n))
                axs[3].scatter(x_trg_base, y_trg_base, c='green')
                axs[3].set_title('target similarities')
                axs[3].axis('off')
                plt.show()
                return

plot_random_prediction_debug()


In [None]:
# Nearest-neighbour keypoint transfer over all ordered pairs of distinct images
# with the same mirror flag. Each source keypoint is predicted at the centre of
# the target feature cell with the highest cosine similarity. Results are cached
# per model in `path_name`. Pairs are visited in random order, presumably so an
# interrupted run (Ctrl-C) still yields a representative subset.
path_name = Path(f'sc_predictions_artificial_position_dataset_with_mirror_{sd.model_name}.pkl')
if path_name.exists():
    # pickle.load can execute arbitrary code: only load caches this notebook wrote.
    with open(path_name, 'rb') as f:
        predictions = pickle.load(f)
else:
    predictions_list = []
    completed = False
    try:
        for i in tqdm(random.sample(range(len(combinations)), len(combinations))):
            (_,_,_,_,mirrored1), _, kps1 = combinations[i]
            # Source keypoints are independent of the target: compute once per source.
            kp1_idxs = [idx for idx, kp in enumerate(kps1) if kp is not None]
            if not kp1_idxs:
                continue
            r1 = representations[i].concat().cuda()
            n = r1.shape[1]
            x_src_bases = [int(kps1[idx][0]) * n // background_size for idx in kp1_idxs]
            y_src_bases = [int(kps1[idx][1]) * n // background_size for idx in kp1_idxs]
            for j in random.sample(range(len(combinations)), len(combinations)):
                (_,_,_,_,mirrored2), _, kps2 = combinations[j]
                if mirrored1 != mirrored2 or i == j: continue  # ignore different mirrorization and same image comparisons
                r2 = representations[j].concat().cuda()
                # Similarities of the source keypoint features to every target cell -> best cell per keypoint.
                best_indices = F.cosine_similarity(
                    r1[:, y_src_bases, x_src_bases, None, None], r2[:, None, :, :], dim=0
                ).flatten(1, 2).argmax(dim=1).cpu()
                for idx, best_idx in zip(kp1_idxs, best_indices):
                    kp1, kp2 = kps1[idx], kps2[idx]
                    if kp2 is None: continue
                    y_trg_base, x_trg_base = np.unravel_index(best_idx, (n,n))
                    x_trg = (x_trg_base + 0.5) * background_size // n
                    y_trg = (y_trg_base + 0.5) * background_size // n
                    distance = ((x_trg - kp2[0])**2 + (y_trg - kp2[1])**2)**0.5
                    # Correct when within one feature-cell width of the true keypoint.
                    correct = distance < background_size / n
                    predictions_list.append((i, j, *kp1, *kp2, x_trg, y_trg, distance, correct))
        completed = True
    except KeyboardInterrupt:
        print(f'Interrupted: keeping {len(predictions_list)} partial predictions in memory (not cached).')
    finally:
        predictions = pd.DataFrame(predictions_list, columns=['i','j','x1','y1','x2','y2','x_trg','y_trg','distance','correct'])
        # Cache only complete runs; a partial cache would later be loaded as if it were complete.
        if completed:
            with open(path_name, 'wb') as f:
                pickle.dump(predictions, f)


In [None]:
# Prediction quality: distance (px) between predicted and true target keypoint,
# plus the share and count of predictions flagged as correct.
print(f"Mean distance: {predictions['distance'].mean():.2f}")
print(f"Std distance: {predictions['distance'].std():.2f}")
print(f"Mean correct: {predictions['correct'].mean():.2f}")
print(f"Num correct: {sum(predictions['correct'])}")



In [None]:
# Show one randomly chosen prediction: source keypoint (red) on the left; true
# target keypoint (red) and predicted target (green) on the right. Then the
# distribution of all prediction distances.
rand = random.randint(0, len(predictions)-1)
p = predictions.iloc[rand]
fig, axs = plt.subplots(1,2, figsize=(8,4))
panels = ((p.i, p.x1, p.y1, 'source'), (p.j, p.x2, p.y2, 'target'))
for ax, (img_idx, kp_x, kp_y, panel_title) in zip(axs, panels):
    ax.imshow(combinations[img_idx][1])
    ax.scatter(kp_x, kp_y, c='red')
    ax.set_title(panel_title)
    if panel_title == 'target':
        ax.scatter(p.x_trg, p.y_trg, c='green')
    ax.axis('off')
plt.show()

print(f"distance: {predictions['distance'].iloc[rand]:.2f}")

hist_fig, hist_ax = plt.subplots()
hist_ax.hist(predictions.distance, bins=100)
hist_ax.set_title('distance histogram')
plt.show()



In [None]:
# Where do wrong predictions land? 2-D histogram of the predicted target
# position of every incorrect prediction, in composite pixel coordinates.
wrong_predictions = predictions[~predictions.correct.astype(bool)]
xs = wrong_predictions.x_trg.tolist()
ys = wrong_predictions.y_trg.tolist()

fig, ax = plt.subplots(figsize=(10,8))
*_, hist_image = ax.hist2d(
    xs, ys, bins=32, cmap='viridis',
    range=[[0, background_size], [0, background_size]],
)
ax.scatter(0,0, c='red', marker='+', s=100)
ax.set_aspect('equal')
# Image coordinates grow downwards; flip so the map matches the composites.
ax.invert_yaxis()

plt.colorbar(hist_image, label='count')
ax.set_title('error over position')
ax.set_xlabel('x')
ax.set_ylabel('y')
plt.show()


In [None]:
# Offset of each wrong prediction from a uniformly random pixel position —
# presumably a chance-level reference for the offset maps relative to the
# source keypoint.
xs = []
ys = []
for p in tqdm(predictions.itertuples(), total=len(predictions)):
    if not p.correct:
        # randrange excludes background_size: positions lie in [0, background_size).
        xs.append(random.randrange(background_size) - p.x_trg)
        ys.append(random.randrange(background_size) - p.y_trg)

fig, ax = plt.subplots(figsize=(10,8))
*_, hist_image = ax.hist2d(
    xs, ys, bins=32, cmap='viridis',
    range=[[-background_size, background_size], [-background_size, background_size]],
)
ax.scatter(0,0, c='red', marker='+', s=100)
ax.set_aspect('equal')
# Match image coordinates (y grows downwards).
ax.invert_yaxis()

plt.colorbar(hist_image, label='count')
ax.set_title('error over position relative to random keypoint')
plt.show()


# Plot mean error

In [None]:
# Accumulate prediction statistics by the offset between source keypoint and
# predicted target, (dx, dy) = (x1 - x_trg, y1 - y_trg), on a 2w x 2w grid of
# (background_size / w)-px bins covering [-background_size, background_size).
# Grids are indexed [dx_bin, dy_bin] (first axis = x offset). They are saved
# below, so keep that layout.
w = 64
dx = (predictions.x1 - predictions.x_trg).to_numpy(dtype=float)
dy = (predictions.y1 - predictions.y_trg).to_numpy(dtype=float)
# Bin index floor(d * w / background_size) + w, i.e. offset 0 falls into bin w.
x_bins = (dx * w // background_size + w).astype(np.int64)
y_bins = (dy * w // background_size + w).astype(np.int64)


def _accumulate_offset_grid(values, x_bins, y_bins, size):
    """Sum `values` into a (size, size) grid at [x_bins, y_bins]; float64 accumulation, float32 result."""
    grid = np.zeros((size, size))
    np.add.at(grid, (x_bins, y_bins), values)
    return torch.from_numpy(grid).float()


test_distances = _accumulate_offset_grid(predictions.distance.to_numpy(dtype=float), x_bins, y_bins, 2 * w)
test_counts = _accumulate_offset_grid(np.ones(len(predictions)), x_bins, y_bins, 2 * w)
test_error_counts = _accumulate_offset_grid(
    (~predictions.correct.to_numpy(dtype=bool)).astype(float), x_bins, y_bins, 2 * w
)




In [15]:
# Persist the per-offset accumulators (for the thesis figures), one file each.
for tensor_to_save, save_path in (
    (test_distances, f'sc_errors_by_position_distances_{sd.model_name}.pt'),
    (test_counts, f'sc_errors_by_position_counts_{sd.model_name}.pt'),
    (test_error_counts, f'sc_errors_by_position_error_counts_{sd.model_name}.pt'),
):
    torch.save(tensor_to_save, save_path)


In [None]:
# Mean prediction distance by offset (source keypoint - predicted target).
# test_distances / test_counts are indexed [dx_bin, dy_bin]; imshow draws the
# first axis vertically, so the grid is transposed to put dx on the x axis.
fig, ax = plt.subplots(figsize=(12,10))
tmp = torch.full_like(test_distances, np.nan)
# Keep only the inner bins (15 outermost per side stay NaN); empty bins are 0/0 = NaN.
tmp[15:-15, 15:-15] = (test_distances / test_counts)[15:-15, 15:-15]
# Average 4x4 blocks of bins; any NaN inside a block makes the pooled cell NaN.
tmp = tmp.reshape(2*w//4,4,2*w//4,4).mean(dim=(1,3))
im = ax.imshow(tmp.T)
ax.set_aspect('equal')
ax.set_xlabel('x')
ax.set_ylabel('y')
ticks = np.linspace(-512, 512, 9)
# Cell c spans [c - 0.5, c + 0.5] in imshow coordinates, so the offset range
# [-512, 512] maps onto [-0.5, len(tmp) - 0.5].
tick_positions = np.linspace(-0.5, len(tmp) - 0.5, 9)
ax.set_xticks(tick_positions)
ax.set_xticklabels([str(int(x)) for x in ticks])
ax.set_yticks(tick_positions)
ax.set_yticklabels([str(int(y)) for y in ticks])


ax.set_title('error over position relative to source keypoint')
plt.colorbar(im, label='mean distance')
plt.show()


In [None]:
# Error rate by offset (source keypoint - predicted target).
# test_error_counts / test_counts are indexed [dx_bin, dy_bin]; imshow draws
# the first axis vertically, so the grid is transposed to put dx on the x axis.
fig, ax = plt.subplots(figsize=(12,10))
tmp = torch.full_like(test_error_counts, np.nan)
# Keep only the inner bins (15 outermost per side stay NaN); empty bins are 0/0 = NaN.
tmp[15:-15, 15:-15] = (test_error_counts / test_counts)[15:-15, 15:-15]
# Average 4x4 blocks of bins; any NaN inside a block makes the pooled cell NaN.
tmp = tmp.reshape(2*w//4,4,2*w//4,4).mean(dim=(1,3))
im = ax.imshow(tmp.T)
ax.set_aspect('equal')
ax.set_xlabel('x')
ax.set_ylabel('y')
ticks = np.linspace(-512, 512, 9)
# Cell c spans [c - 0.5, c + 0.5] in imshow coordinates, so the offset range
# [-512, 512] maps onto [-0.5, len(tmp) - 0.5].
tick_positions = np.linspace(-0.5, len(tmp) - 0.5, 9)
ax.set_xticks(tick_positions)
ax.set_xticklabels([str(int(x)) for x in ticks])
ax.set_yticks(tick_positions)
ax.set_yticklabels([str(int(y)) for y in ticks])


ax.set_title('error over position relative to source keypoint')
plt.colorbar(im, label='error rate')
plt.show()


In [None]:
import matplotlib.colors
# Number of predictions per (un-pooled) offset bin on a log colour scale.
# test_counts is indexed [dx_bin, dy_bin]; transposed so dx runs horizontally.
fig, ax = plt.subplots(figsize=(12,10))
im = ax.imshow(test_counts.T, norm=matplotlib.colors.LogNorm())
ax.set_aspect('equal')
ax.set_xlabel('x')
ax.set_ylabel('y')
ticks = np.linspace(-512, 512, 9)
# Ticks must follow this grid's own size (not a pooled grid from another cell);
# bin b spans [b - 0.5, b + 0.5], so [-512, 512] maps onto [-0.5, len - 0.5].
tick_positions = np.linspace(-0.5, len(test_counts) - 0.5, 9)
ax.set_xticks(tick_positions)
ax.set_xticklabels([str(int(x)) for x in ticks])
ax.set_yticks(tick_positions)
ax.set_yticklabels([str(int(y)) for y in ticks])


ax.set_title('number of predictions over position relative to source keypoint')
plt.colorbar(im, label='count')
plt.show()
