# Semantic Correspondence - failures due to positional embedding

Semantic correspondence on SPair-71k, plus tests of how the position of the source keypoint (kp) influences failures

## Setup

In [1]:
from sdhelper import SD
import torch
import numpy as np
import datasets
from tqdm.autonotebook import tqdm, trange
from matplotlib import pyplot as plt
import PIL.Image
from typing import Callable
import random
from collections import defaultdict
from typing import Any

In [None]:
# Stable Diffusion wrapper used further down to extract image representations.
sd = SD('SD1.5')
# Two configurations of the same SPair-71k repository:
#   data  - rows read via 'img' and 'name', addressed through the
#           'src_data_index' / 'trg_data_index' fields of the pair rows
#   pairs - the image pairs with their keypoints and bounding boxes
data, pairs = (
    datasets.load_dataset('0jl/SPair-71k', config, split=split, trust_remote_code=True)
    for config, split in (('data', 'train'), ('pairs', 'test'))
)

In [None]:
# Pair-level fields that are copied verbatim into every keypoint record.
metadata_keys = ['src_bndbox', 'trg_bndbox', 'category', 'viewpoint_variation', 'scale_variation']


def _iter_keypoint_records(pair_rows, keys):
    """Yield one flat record per (source keypoint, target keypoint) match.

    Every record stores the index of its pair ('i'), both keypoints, both
    image sizes and the pair-level fields named in ``keys``.
    """
    for i, pair in enumerate(pair_rows):
        for src_kp, trg_kp in zip(pair['src_kps'], pair['trg_kps']):
            record = {
                'i': i,
                'src_kp': src_kp,
                'trg_kp': trg_kp,
                'src_size': pair['src_img'].size,
                'trg_size': pair['trg_img'].size,
            }
            record.update((k, pair[k]) for k in keys)
            yield record


# One entry per keypoint, in the order the keypoints appear while iterating pairs.
metadata = list(_iter_keypoint_records(tqdm(pairs), metadata_keys))

## Calculate semantic correspondence

In [4]:
def expand_and_resize(x: PIL.Image.Image, size, border_pad=True):
    """Pad an image to a centered square and resize it to ``size`` x ``size``.

    The image is pasted into the middle of a square RGB canvas whose side is
    the longer side of ``x``. With an odd size difference the extra pixel of
    padding ends up on the bottom/right side.

    Args:
        x: Image to pad and resize.
        size: Side length in pixels of the returned square image.
        border_pad: If true, the padding is filled by stretching the outermost
            pixel row (wide images) or column (tall images) of ``x``. If false,
            the padding keeps the canvas' default fill (black).

    Returns:
        A new RGB image of ``size`` x ``size`` pixels.
    """
    n, m = x.size  # PIL reports (width, height)
    s = max(n, m)
    left, top = (s - n) // 2, (s - m) // 2
    r = PIL.Image.new('RGB', (s, s))
    r.paste(x, (left, top))
    if border_pad:
        if n > m:
            # Wide image: stretch the first/last pixel row over the top/bottom gap.
            # The bottom gap is computed from what is actually left so that an odd
            # difference does not leave an unfilled black line.
            bottom = s - m - top
            if top:  # a difference of 1 leaves no top gap; resize() rejects a zero size
                r.paste(x.crop((0, 0, n, 1)).resize((n, top)), (0, 0))
            r.paste(x.crop((0, m - 1, n, m)).resize((n, bottom)), (0, top + m))
        elif m > n:
            # Tall image: same idea with the first/last pixel column.
            right = s - n - left
            if left:
                r.paste(x.crop((0, 0, 1, m)).resize((left, m)), (0, 0))
            r.paste(x.crop((n - 1, 0, n, m)).resize((right, m)), (left + n, 0))
    return r.resize((size, size))

In [None]:
def _unit_normalize(features):
    """Divide a tensor by its L2 norm taken along dim 0."""
    # presumably dim 0 is the channel axis of a (C, H, W) feature map - TODO confirm
    return features / torch.norm(features, dim=0, keepdim=True)


# Square-pad (with border replication) and resize every dataset image to 512x512.
images_expanded = [
    expand_and_resize(sample['img'], 512, True)
    for sample in tqdm(data, desc='Expanding images')
]
# Text before the first '/' of each sample name.
prompts = [sample['name'].partition('/')[0] for sample in data]
# Representations taken from 'up_blocks[1]'.
# NOTE(review): the positional 50 looks like a step/timestep setting - confirm against sdhelper.
representations_raw = sd.img2repr(images_expanded, ['up_blocks[1]'], 50, seed=42)
# Normalized copies, so that a plain dot product between two feature vectors
# is their cosine similarity.
representations = [
    raw.apply(_unit_normalize)
    for raw in tqdm(representations_raw, desc='Normalizing representations')
]

In [None]:
# calculate percentage of correct keypoints at 10% of the bounding box (PCK@0.1_bbox)
#
# For every source keypoint, the feature vector at its location in the source
# representation is compared with every location of the target representation;
# the best-matching location is mapped back to target pixel coordinates.
# Results, one entry per keypoint:
#   correct   - whether the prediction lies within 0.1 * the longer side of the
#               target bounding box from the annotated target keypoint
#   positions - predicted (x, y) in target-image pixel coordinates
correct = []
positions = []
t = tqdm(pairs, desc='Calculating SC')
for pair_idx, x in enumerate(t):
    a = representations[x['src_data_index']].concat()
    b = representations[x['trg_data_index']].concat()

    # Everything below is constant for the pair, so compute it once.
    src_w, src_h = x['src_img'].size
    trg_w, trg_h = x['trg_img'].size
    src_side = max(src_w, src_h)  # side of the square the source image was padded to
    trg_side = max(trg_w, trg_h)
    tbb_max = max(x['trg_bndbox'][2] - x['trg_bndbox'][0], x['trg_bndbox'][3] - x['trg_bndbox'][1])

    for ([sx, sy], [tx, ty]) in zip(x['src_kps'], x['trg_kps']):
        # Source pixel -> padded-square coordinate -> feature-map cell.
        # Clamped because a keypoint lying exactly on the far image edge would
        # otherwise index one past the last cell.
        src_row = min(int((sy + (src_side - src_h) / 2) * a.shape[1] / src_side), a.shape[1] - 1)
        src_col = min(int((sx + (src_side - src_w) / 2) * a.shape[2] / src_side), a.shape[2] - 1)
        src_repr = a[:, src_row, src_col]

        # Dot product over dim 0; a cosine similarity given unit-norm features.
        cossim = (b * src_repr[:, None, None]).sum(dim=0)
        y_max, x_max = np.unravel_index(cossim.argmax().cpu(), cossim.shape)

        # Feature-map cell -> padded-square coordinate -> target pixel.
        # NOTE(review): this yields the top-left corner of the winning cell, not
        # its center; left unchanged so results stay comparable.
        x_max_pixel = x_max / b.shape[2] * trg_side - (trg_side - trg_w) / 2
        y_max_pixel = y_max / b.shape[1] * trg_side - (trg_side - trg_h) / 2

        relative_distance = ((x_max_pixel - tx)**2 + (y_max_pixel - ty)**2) ** 0.5 / tbb_max
        correct.append(relative_distance < 0.1)
        positions.append((x_max_pixel, y_max_pixel))

    # Refresh the running PCK every 100 pairs. Testing the keypoint count for a
    # multiple of 100 would almost never match, as it grows by several per pair.
    if (pair_idx + 1) % 100 == 0 and correct:
        t.set_postfix(pck=np.mean(correct)*100)

In [None]:
# Share of entries in `correct` that are true, in percent (overall PCK@0.1_bbox).
np.mean(correct)*100

## Evaluate

In [None]:
def plot_failures(f: Callable[['Any'], tuple[int,int]]):
    """Scatter the density of predictions over their distance to a reference point.

    For each keypoint record in ``metadata`` the reference point is ``f(record)``
    and the distance is measured to the matching entry of ``positions``.
    Distances are grouped by ``dist**2 // 25``, i.e. into rings that each cover
    an area of 25*pi, and the count per ring is divided by that area. Each ring
    is plotted at its outer radius on the current matplotlib axes.

    Args:
        f: Maps a metadata record to the reference point ``(x, y)``.
    """
    ring_counts = defaultdict(int)
    for record, (pred_x, pred_y) in zip(metadata, positions):
        ref_x, ref_y = f(record)
        ring = np.hypot(ref_x - pred_x, ref_y - pred_y)**2 // 25
        ring_counts[ring] += 1

    ordered = sorted(ring_counts.items())
    outer_radii = [((ring + 1) * 25)**.5 for ring, _ in ordered]
    per_pixel = [hits / np.pi / 25 for _, hits in ordered]

    plt.scatter(outer_radii, per_pixel, s=2, alpha=0.5)
    plt.xlabel('Distance to source keypoint (pixels)')
    plt.ylabel('Number of failed keypoints per pixel')

# Compare two reference points for the same predictions:
# the source keypoint itself versus a randomly drawn point.
# Fix the x-range before plotting so both series share the 0-200 pixel window.
plt.xlim(0, 200)
# Series 1: distance between each predicted position and the source keypoint coordinates.
plot_failures(lambda x: (x['src_kp'][0], x['src_kp'][1]))
# Series 2: baseline using a random reference point with x in [100, 400] and y in [100, 300].
plot_failures(lambda x: (random.randint(100,400), random.randint(100,300)))
# Legend entries follow the order in which the series were plotted.
plt.legend(['Source KP', 'Random KP'])
plt.title('Distance of failed keypoint guess to source keypoint')
plt.show()

In [None]:
def plot_failures(f: Callable[[Any], tuple[int,int]], only_fails=False):
    """Scatter the normalized density of predictions over their distance to a reference point.

    Works on the parallel lists ``correct``, ``metadata`` and ``positions``.
    Distances between ``f(record)`` and the predicted position are grouped by
    ``dist**2 // 25`` (rings of area 25*pi). Each ring's count is divided by the
    ring area and by the number of keypoints that were counted, so that series
    of different size are comparable. Rings are plotted at their outer radius on
    the current matplotlib axes.

    Args:
        f: Maps a metadata record to the reference point ``(x, y)``.
        only_fails: If true, keypoints whose ``correct`` flag is set are skipped.
    """
    ring_counts = defaultdict(int)
    for is_correct, record, (pred_x, pred_y) in zip(correct, metadata, positions):
        if not only_fails or not is_correct:
            ref_x, ref_y = f(record)
            ring = np.hypot(ref_x - pred_x, ref_y - pred_y)**2 // 25
            ring_counts[ring] += 1

    total = sum(ring_counts.values())
    ordered = sorted(ring_counts.items())
    outer_radii = [((ring + 1) * 25)**.5 for ring, _ in ordered]
    density = [hits / np.pi / 25 / total for _, hits in ordered]

    plt.scatter(outer_radii, density, s=2, alpha=0.5)
    plt.xlabel('Distance to source keypoint (pixels)')
    plt.ylabel('Normalized number of failed keypoints per pixel')

# Compare failed predictions with all predictions, both measured against the
# source keypoint coordinates.
# Fix the x-range before plotting so both series share the 0-200 pixel window.
plt.xlim(0, 200)
# Series 1: only keypoints whose `correct` flag is false.
plot_failures(lambda x: (x['src_kp'][0], x['src_kp'][1]), only_fails=True)
# Series 2: every keypoint.
plot_failures(lambda x: (x['src_kp'][0], x['src_kp'][1]))
# Disabled: random reference point baseline.
# plot_failures(lambda x: (random.randint(100,400), random.randint(100,300)))
# Legend entries follow the order in which the series were plotted.
plt.legend(['Failed KP', 'All KP'])
plt.title('Distance of failed keypoint guess to source keypoint')
plt.show()

In [None]:
# failure distance
def plot_failures(f: Callable[[Any,Any,Any], float], only_fails=False):
    """Scatter a normalized histogram of a per-keypoint scalar distance.

    Works on the parallel lists ``correct``, ``metadata`` and ``positions``.
    ``f(record, x, y)`` is floored into bins of width 1. Each bin's count is
    halved and divided by the number of keypoints that were counted. Each bin is
    plotted at its upper edge on the current matplotlib axes.

    Args:
        f: Maps a metadata record and the predicted ``x``, ``y`` to a distance.
        only_fails: If true, keypoints whose ``correct`` flag is set are skipped.
    """
    bin_counts = defaultdict(int)
    for is_correct, record, (pred_x, pred_y) in zip(correct, metadata, positions):
        if not only_fails or not is_correct:
            bin_counts[f(record, pred_x, pred_y)//1] += 1

    total = sum(bin_counts.values())
    ordered = sorted(bin_counts.items())
    upper_edges = [lower + 1 for lower, _ in ordered]
    shares = np.array([hits/2 / total for _, hits in ordered])

    plt.scatter(upper_edges, shares, s=3, alpha=0.5)
    plt.xlabel('Distance to source keypoint (pixels)')
    plt.ylabel('Normalized number of failed keypoints per row/column')

# Horizontal component only: |source keypoint x - predicted x|,
# first for failed keypoints, then for all keypoints.
# The x-range is fixed before plotting so both series share the 0-200 pixel window.
plt.xlim(0, 200)
plot_failures(lambda tmp, x, y: abs(tmp['src_kp'][0] - x), only_fails=True)
plot_failures(lambda tmp, x, y: abs(tmp['src_kp'][0] - x))
plt.legend(['Failed KP', 'All KP'])
plt.title('X-Distance of failed keypoint guess to source keypoint')
plt.show()

# Vertical component only: |source keypoint y - predicted y|,
# again for failed keypoints and for all keypoints.
plt.xlim(0, 200)
plot_failures(lambda tmp, x, y: abs(tmp['src_kp'][1] - y), only_fails=True)
plot_failures(lambda tmp, x, y: abs(tmp['src_kp'][1] - y))
plt.legend(['Failed KP', 'All KP'])
plt.title('Y-Distance of failed keypoint guess to source keypoint')
plt.show()

In [None]:
# PCK over keypoint distance
#
# Bins every keypoint by the distance between its source and target pixel
# coordinates, normalized by the longest side of either image, and plots the
# number of keypoints (bars) and the PCK (line) per bin.
n = 50  # number of bins; bin k covers normalized distances [k/n, (k+1)/n)
distances = {i: [] for i in range(n)}  # bin -> correctness flags of its keypoints
counts = {i: 0 for i in range(n)}  # bin -> number of keypoints
for c, m in zip(correct, metadata):
    max_size = max(max(m['src_size']), max(m['trg_size']))
    dist = np.hypot((m['src_kp'][0] - m['trg_kp'][0]) / max_size, (m['src_kp'][1] - m['trg_kp'][1]) / max_size)
    # everything at or beyond a normalized distance of 1 goes into the last bin
    key = min(int(dist*n), n-1)
    distances[key].append(c)
    counts[key] += 1

plt.bar(counts.keys(), counts.values(), color='lightblue')
plt.ylabel('Number of keypoints', color='lightblue')
plt.xlabel('Distance between keypoints (normalized)')
plt.xticks(range(0,n,n//10), [f'{i/n:.1f}' for i in range(0,n,n//10)])
plt.twinx()  # second y-axis so counts and percentages keep independent scales
# Empty bins are given NaN explicitly (drawn as gaps in the line);
# np.mean([]) would produce the same value but also a RuntimeWarning.
plt.plot(distances.keys(), [np.mean(d)*100 if d else np.nan for d in distances.values()], 'red')
plt.ylabel('PCK@$0.1_{bbox}$ (%)', color='red')
plt.show()

### Visualize random failures

In [None]:
# visualize random failures
#
# Picks one failed keypoint at random and shows the source image with its
# keypoint (green) next to the target image with the annotated keypoint (green)
# and the predicted position (red).
# Drawing directly from the failed indices avoids two problems of rejection
# sampling with random.randint(0, len(metadata)): the upper bound is inclusive
# and can index past the end, and the loop never ends when nothing failed.
failed_indices = [idx for idx, ok in enumerate(correct) if not ok]
if not failed_indices:
    raise ValueError('no failed keypoints to visualize')
i = random.choice(failed_indices)
kp_info = metadata[i]
pair = pairs[kp_info['i']]  # fetch the pair row once instead of once per image

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(pair['src_img'])
plt.scatter(kp_info['src_kp'][0], kp_info['src_kp'][1], c='g', marker='x', s=200)
plt.subplot(1, 2, 2)
plt.imshow(pair['trg_img'])
plt.scatter(kp_info['trg_kp'][0], kp_info['trg_kp'][1], c='g', marker='x', s=200)
plt.scatter(positions[i][0], positions[i][1], c='r', marker='x', s=200)
# Prediction error relative to the longer side of the target bounding box;
# values of 0.1 or more are the ones counted as failures.
relative_distance = ((positions[i][0] - kp_info['trg_kp'][0])**2 + (positions[i][1] - kp_info['trg_kp'][1])**2) ** 0.5 / max(kp_info['trg_bndbox'][2] - kp_info['trg_bndbox'][0], kp_info['trg_bndbox'][3] - kp_info['trg_bndbox'][1])
plt.title(f'PCK@0.1_bbox: {relative_distance:.2f}')
plt.show()