# Semantic Correspondence

Simple semantic correspondence with visualisations based on metadata

## Setup

In [1]:
from sdhelper import SD
import torch
import numpy as np
import datasets
from tqdm.autonotebook import tqdm, trange
from matplotlib import pyplot as plt
import PIL.Image
from typing import Callable
import random
from collections import defaultdict
from typing import Any

In [None]:
# Model wrapper used below to extract image representations
# ('SD15' presumably selects Stable Diffusion 1.5 - confirm against sdhelper).
sd = SD('SD15')

# Two configs of the same dataset repo: `data` rows are read for their 'img' field,
# `pairs` rows reference them through 'src_data_index' / 'trg_data_index'.
_load_kwargs = dict(trust_remote_code=True)
data = datasets.load_dataset('0jl/SPair-71k', 'data', split='train', **_load_kwargs)
pairs = datasets.load_dataset('0jl/SPair-71k', 'pairs', split='test', **_load_kwargs)

In [None]:
# Pair-level fields copied into every keypoint record.
metadata_keys = [
    'src_bndbox',
    'trg_bndbox',
    'category',
    'viewpoint_variation',
    'scale_variation',
]

# One record per (source keypoint, target keypoint) couple, in the order the pairs
# and their keypoints are iterated. 'i' is the index of the owning pair in `pairs`.
metadata = []
for pair_index, pair in enumerate(tqdm(pairs)):
    for src_kp, trg_kp in zip(pair['src_kps'], pair['trg_kps']):
        record = {
            'i': pair_index,
            'src_kp': src_kp,
            'trg_kp': trg_kp,
            'src_size': pair['src_img'].size,
            'trg_size': pair['trg_img'].size,
        }
        metadata.append(record | {key: pair[key] for key in metadata_keys})

## Calculate semantic correspondence

In [None]:
def transform_img(img):
    """Resize `img` so that its longer side becomes 512, keeping the aspect ratio.

    The shorter side is scaled by the same factor using integer floor division.
    """
    longest_side = max(img.size)
    target_size = tuple(np.array(img.size) * 512 // longest_side)
    return img.resize(target_size)


# precalculate representations
# One entry per row of `data`, in dataset order, so it can be indexed with the
# 'src_data_index' / 'trg_data_index' fields of a pair.
representations = []
for sample in tqdm(data, desc='Calculating representations'):
    prompt = ''  # alternative: 'a photo of a ' + sample['name'].split('/')[0]
    r = sd.img2repr(transform_img(sample['img']), ['up_blocks[1]'], 50, prompt=prompt, seed=42)
    # normalize: divide by the norm taken along dim 0
    r = r.apply(lambda feat: feat / torch.norm(feat, dim=0, keepdim=True))
    representations.append(r)

In [None]:
# calculate percentage of correct keypoints at 10% of the bounding box (PCK@0.1_bbox)
# `correct` receives one flag per keypoint and `positions` the matching predicted
# (x, y) in target-image pixels, both in the order pairs and keypoints are iterated.
correct = []
positions = []
t = tqdm(pairs, desc='Calculating SC')
for pair_idx, pair in enumerate(t, start=1):
    a = representations[pair['src_data_index']].concat()
    b = representations[pair['trg_data_index']].concat()
    src_w, src_h = pair['src_img'].size
    trg_w, trg_h = pair['trg_img'].size
    trg_box = pair['trg_bndbox']
    # longer side of the target bounding box; distances are measured relative to it
    tbb_max = max(trg_box[2] - trg_box[0], trg_box[3] - trg_box[1])
    for (sx, sy), (tx, ty) in zip(pair['src_kps'], pair['trg_kps']):
        # Map the source keypoint to its feature-map cell. Clamp so that a keypoint
        # lying exactly on the far image edge does not index one past the map.
        row = min(sy * a.shape[1] // src_h, a.shape[1] - 1)
        col = min(sx * a.shape[2] // src_w, a.shape[2] - 1)
        src_repr = a[:, row, col]
        # dot product along dim 0 between the source vector and every target cell
        cossim = (b * src_repr[:, None, None]).sum(dim=0)
        y_max, x_max = np.unravel_index(cossim.argmax().cpu(), cossim.shape)
        # +.5 places the prediction at the centre of the best-matching cell
        x_max_pixel = (x_max + .5) * trg_w / b.shape[2]
        y_max_pixel = (y_max + .5) * trg_h / b.shape[1]
        relative_distance = ((x_max_pixel - tx)**2 + (y_max_pixel - ty)**2) ** 0.5 / tbb_max
        correct.append(relative_distance < 0.1)
        positions.append((x_max_pixel, y_max_pixel))
    # Refresh the running PCK every 100 pairs. Counting pairs instead of keypoints makes
    # the refresh regular, and the emptiness check avoids averaging an empty list.
    if pair_idx % 100 == 0 and correct:
        t.set_postfix(pck=np.mean(correct)*100)

In [None]:
# overall PCK@0.1_bbox in percent (share of keypoints flagged as correct)
np.asarray(correct).mean() * 100

## Visualise semantic correspondence

In [None]:
# plot PCK over category
def plot_pck_over_category(correct, metadata):
    """Draw a bar chart of PCK (in percent) for every category found in `metadata`.

    Args:
        correct: per-keypoint correctness flags, aligned with `metadata`.
        metadata: per-keypoint records; only the 'category' entry is read.

    Each bar is annotated with its PCK value (above) and the number of
    keypoints in that category (inside the bar).
    """
    # Sorting gives a deterministic bar order, and labelling the ticks with these same
    # values guarantees that every label belongs to the bar drawn above it.
    categories = sorted({m['category'] for m in metadata})
    slot_of = {category: slot for slot, category in enumerate(categories)}
    bins = [[] for _ in categories]
    for m, c in zip(metadata, correct):
        bins[slot_of[m['category']]].append(c)
    # A bin can only be empty if `correct` is shorter than `metadata`; use NaN then.
    pcks = [np.mean(b)*100 if b else float('nan') for b in bins]
    plt.figure()
    plt.bar(range(len(categories)), pcks)
    plt.xticks(range(len(categories)), [str(category) for category in categories], rotation=90)
    for i, pck in enumerate(pcks):
        if bins[i]:
            plt.text(i, pck, f'{pck:.2f}', ha='center', va='bottom', fontsize=8)
        plt.text(i+0.07, 3, f'{len(bins[i])}', ha='center', va='bottom', rotation=90, color='white', fontsize=8)
    plt.xlabel('Category')
    plt.ylabel('PCK@$0.1_{bbox}$')
    plt.title('PCK@$0.1_{bbox}$ over category')
    plt.show()

plot_pck_over_category(correct, metadata)

In [None]:
# plot PCK over viewpoint variation
def plot_pck_over_viewpoint_variation(correct, metadata):
    """Bar chart of PCK (in percent) for each of the three viewpoint-variation levels.

    `metadata[k]['viewpoint_variation']` is used directly as the bin index
    (0, 1, 2 -> 'low', 'medium', 'high') for the flag `correct[k]`.
    """
    levels = [m['viewpoint_variation'] for m in metadata]
    bins = [[], [], []]
    for level, flag in zip(levels, correct):
        bins[level].append(flag)
    pcks = [np.mean(members)*100 for members in bins]

    ticks = range(3)
    plt.figure()
    plt.bar(ticks, pcks)
    plt.xticks(ticks, ['low', 'medium', 'high'])
    for pos, (pck, members) in enumerate(zip(pcks, bins)):
        # PCK value on top of the bar, sample count written inside it
        plt.text(pos, pck, f'{pck:.2f}', ha='center', va='bottom')
        plt.text(pos, 3, f'{len(members)}', ha='center', va='bottom', rotation=90, color='white')
    plt.xlabel('Viewpoint variation')
    plt.ylabel('PCK@$0.1_{bbox}$')
    plt.title('PCK@$0.1_{bbox}$ over viewpoint variation')
    plt.show()

plot_pck_over_viewpoint_variation(correct, metadata)

In [None]:
# plot PCK over bounding box scale variation
def plot_pck_over_scale_variation(correct, metadata):
    """Bar chart of PCK (in percent) for each of the three scale-variation levels.

    `metadata[k]['scale_variation']` is used directly as the bin index
    (0, 1, 2 -> 'low', 'medium', 'high') for the flag `correct[k]`.
    """
    levels = [m['scale_variation'] for m in metadata]
    bins = [[], [], []]
    for level, flag in zip(levels, correct):
        bins[level].append(flag)
    pcks = [np.mean(members)*100 for members in bins]

    ticks = range(3)
    plt.figure()
    plt.bar(ticks, pcks)
    plt.xticks(ticks, ['low', 'medium', 'high'])
    for pos, (pck, members) in enumerate(zip(pcks, bins)):
        # PCK value on top of the bar, sample count written inside it
        plt.text(pos, pck, f'{pck:.2f}', ha='center', va='bottom')
        plt.text(pos, 3, f'{len(members)}', ha='center', va='bottom', rotation=90, color='white')
    plt.xlabel('Scale variation')
    plt.ylabel('PCK@$0.1_{bbox}$')
    plt.title('PCK@$0.1_{bbox}$ over scale variation')
    plt.show()

plot_pck_over_scale_variation(correct, metadata)

In [None]:
# plot PCK over src and trg bounding box size
def plot_pck_over_bbox_size(correct, metadata, num_bins=10):
    """Plot PCK (in percent) against bounding-box size, once for src and once for trg.

    Args:
        correct: per-keypoint correctness flags, aligned with `metadata`.
        metadata: per-keypoint records; 'src_bndbox' and 'trg_bndbox' are read as
            4-element boxes whose size is (b[2]-b[0]) * (b[3]-b[1]).
        num_bins: number of equally wide bins over the square root of the size.

    Bins without any keypoint are drawn as gaps (NaN) and get no PCK label,
    instead of averaging an empty list.
    """
    src_shapes = np.array([m['src_bndbox'] for m in metadata])
    src_sizes = (src_shapes[:,2] - src_shapes[:,0]) * (src_shapes[:,3] - src_shapes[:,1])
    trg_shapes = np.array([m['trg_bndbox'] for m in metadata])
    trg_sizes = (trg_shapes[:,2] - trg_shapes[:,0]) * (trg_shapes[:,3] - trg_shapes[:,1])
    for name, sizes in [('src', src_sizes), ('trg', trg_sizes)]:
        bins = [[] for _ in range(num_bins)]
        min_size = sizes.min() ** .5  # sqrt for more useful bins (relative to side length instead of area)
        max_size = sizes.max() ** .5 + 1  # +1 keeps the largest box inside the last bin
        for size, c in zip(sizes, correct):
            idx = int((size**.5 - min_size) / (max_size - min_size) * num_bins)
            bins[min(idx, num_bins - 1)].append(c)
        pcks = [np.mean(b)*100 if b else float('nan') for b in bins]
        # bin edges in side-length units; squared again for the tick labels
        edges = [min_size + i*(max_size - min_size)/num_bins for i in range(num_bins + 1)]
        plt.figure()
        plt.bar(range(num_bins), pcks)
        plt.xticks(range(num_bins), [f'{edges[i]**2/1000:.0f}k-{edges[i+1]**2/1000:.0f}k px' for i in range(num_bins)], rotation=45)
        for i, pck in enumerate(pcks):
            if bins[i]:
                plt.text(i, pck, f'{pck:.2f}', ha='center', va='bottom')
            plt.text(i, 3, f'{len(bins[i])}', ha='center', va='bottom', rotation=90, color='white')
        plt.xlabel(f'{name} bounding box size')
        plt.ylabel('PCK@$0.1_{bbox}$')
        plt.title(f'PCK@$0.1_{{bbox}}$ over {name} bounding box size')
        plt.show()

plot_pck_over_bbox_size(correct, metadata)

In [None]:
# plot PCK over difference in bounding box sizes
def plot_pck_over_bbox_size_diff(correct, metadata, num_bins=10):
    """Plot PCK (in percent) against |src box size - trg box size|.

    Args:
        correct: per-keypoint correctness flags, aligned with `metadata`.
        metadata: per-keypoint records; 'src_bndbox' and 'trg_bndbox' are read as
            4-element boxes whose size is (b[2]-b[0]) * (b[3]-b[1]).
        num_bins: number of equally wide bins over the square root of the difference.

    Bins without any keypoint are drawn as gaps (NaN) and get no PCK label,
    instead of averaging an empty list.
    """
    src_shapes = np.array([m['src_bndbox'] for m in metadata])
    src_sizes = (src_shapes[:,2] - src_shapes[:,0]) * (src_shapes[:,3] - src_shapes[:,1])
    trg_shapes = np.array([m['trg_bndbox'] for m in metadata])
    trg_sizes = (trg_shapes[:,2] - trg_shapes[:,0]) * (trg_shapes[:,3] - trg_shapes[:,1])
    diffs = np.abs(src_sizes - trg_sizes)
    bins = [[] for _ in range(num_bins)]
    min_size = diffs.min() ** .5  # sqrt for more useful bins (relative to side length instead of area)
    max_size = diffs.max() ** .5 + 1  # +1 keeps the largest difference inside the last bin
    for size, c in zip(diffs, correct):
        idx = int((size**.5 - min_size) / (max_size - min_size) * num_bins)
        bins[min(idx, num_bins - 1)].append(c)
    pcks = [np.mean(b)*100 if b else float('nan') for b in bins]
    # bin edges in side-length units; squared again for the tick labels
    edges = [min_size + i*(max_size - min_size)/num_bins for i in range(num_bins + 1)]
    plt.figure()
    plt.bar(range(num_bins), pcks)
    plt.xticks(range(num_bins), [f'{edges[i]**2/1000:.0f}k-{edges[i+1]**2/1000:.0f}k px' for i in range(num_bins)], rotation=45)
    for i, pck in enumerate(pcks):
        if bins[i]:
            plt.text(i, pck, f'{pck:.2f}', ha='center', va='bottom')
        plt.text(i, 3, f'{len(bins[i])}', ha='center', va='bottom', rotation=90, color='white')
    plt.xlabel('Difference in bounding box sizes')
    plt.ylabel('PCK@$0.1_{bbox}$')
    plt.title('PCK@$0.1_{bbox}$ over difference in bounding box sizes')
    plt.show()

plot_pck_over_bbox_size_diff(correct, metadata)

In [None]:
# plot PCK over bounding box aspect ratio
def plot_pck_over_bbox_aspect_ratio(correct, metadata):
    """Plot PCK (in percent) against bounding-box aspect ratio, for src and trg boxes.

    The ratio is (b[2]-b[0]) / (b[3]-b[1]) of each 4-element box (presumably
    width / height for [xmin, ymin, xmax, ymax] boxes - confirm against the dataset).
    Ratios are grouped into six power-of-two bins; each bin includes its lower
    bound, so a ratio of exactly 1 falls into '1:1 - 2:1'.

    Args:
        correct: per-keypoint correctness flags, aligned with `metadata`.
        metadata: per-keypoint records; 'src_bndbox' and 'trg_bndbox' are read.
    """
    bin_names = ['< 1:4', '1:4 - 1:2', '1:2 - 1:1', '1:1 - 2:1', '2:1 - 4:1', '> 4:1']
    last_bin = len(bin_names) - 1
    src_shapes = np.array([m['src_bndbox'] for m in metadata])
    src_ratios = (src_shapes[:,2] - src_shapes[:,0]) / (src_shapes[:,3] - src_shapes[:,1])
    trg_shapes = np.array([m['trg_bndbox'] for m in metadata])
    trg_ratios = (trg_shapes[:,2] - trg_shapes[:,0]) / (trg_shapes[:,3] - trg_shapes[:,1])
    for name, ratios in [('src', src_ratios), ('trg', trg_ratios)]:
        bins = [[] for _ in bin_names]
        for ratio, c in zip(ratios, correct):
            # floor(log2(ratio)) is -2 on [1/4, 1/2), -1 on [1/2, 1), 0 on [1, 2), 1 on [2, 4);
            # the +3 offset lines these up with bin_names[1..4]. Clipping before the int
            # conversion sends everything below/above (including +-inf) to the outer bins.
            idx = int(np.clip(np.floor(np.log2(ratio)) + 3, 0, last_bin))
            bins[idx].append(c)
        # the extreme bins are often empty: show them as gaps rather than averaging nothing
        pcks = [np.mean(b)*100 if b else float('nan') for b in bins]
        plt.figure()
        plt.bar(range(len(bin_names)), pcks)
        plt.xticks(range(len(bin_names)), bin_names, rotation=45)
        for i, pck in enumerate(pcks):
            if bins[i]:
                plt.text(i, pck, f'{pck:.2f}', ha='center', va='bottom')
            plt.text(i, 3, f'{len(bins[i])}', ha='center', va='bottom', rotation=90, color='white')
        plt.xlabel(f'{name} bounding box aspect ratio')
        plt.ylabel('PCK@$0.1_{bbox}$')
        plt.title(f'PCK@$0.1_{{bbox}}$ over {name} bounding box aspect ratio')
        plt.show()

plot_pck_over_bbox_aspect_ratio(correct, metadata)

In [None]:
# plot PCK over src and trg keypoint positions
def plot_pck_over_kp_position(correct, metadata, scale_factor=2):
    """Show keypoint counts and PCK as heat maps over the keypoint position.

    Keypoint coordinates are floor-divided by `scale_factor`, so every heat-map
    cell covers a `scale_factor` x `scale_factor` patch of pixels. Two figures
    (count, then PCK as a fraction in [0, 1]) are produced for the source
    keypoints and two more for the target keypoints. Cells without keypoints
    come out as NaN in the PCK map (0 / 0).
    """
    keypoint_sets = [
        ('src', np.array([m['src_kp'] for m in metadata])),
        ('trg', np.array([m['trg_kp'] for m in metadata])),
    ]
    for name, kps in keypoint_sets:
        xs, ys = kps[:,0], kps[:,1]
        min_x, max_x = xs.min() // scale_factor, xs.max() // scale_factor
        min_y, max_y = ys.min() // scale_factor, ys.max() // scale_factor
        # last axis: [number of correct keypoints, number of keypoints]
        matrix = np.zeros((max_y - min_y + 1, max_x - min_x + 1, 2))
        for (x, y), c in zip(kps, correct):
            row = y//scale_factor - min_y
            col = x//scale_factor - min_x
            matrix[row, col, 0] += c
            matrix[row, col, 1] += 1
        hits, totals = matrix[:,:,0], matrix[:,:,1]
        pck_matrix = hits / totals

        # plot count
        plt.figure()
        plt.imshow(totals, cmap='viridis', interpolation='nearest')
        plt.colorbar()
        plt.xlabel(f'{name} keypoint x')
        plt.ylabel(f'{name} keypoint y')
        plt.title(f'Count of keypoints over {name} keypoint position')
        plt.show()

        # plot pck matrix
        plt.figure()
        plt.imshow(pck_matrix, cmap='viridis', interpolation='nearest', vmin=0, vmax=1)
        plt.colorbar()
        plt.xlabel(f'{name} keypoint x')
        plt.ylabel(f'{name} keypoint y')
        plt.title(f'PCK@$0.1_{{bbox}}$ over {name} keypoint position')
        plt.show()

plot_pck_over_kp_position(correct, metadata, 20)

In [None]:
# plot PCK over src and trg keypoint positions taken modulo `size`
# might be useful to see whether certain pixel offsets within the size x size regions are ignored
def plot_pck_over_kp_position(correct, metadata, size=32):
    """Show PCK as a heat map over the keypoint position modulo `size`.

    Every keypoint is folded into a `size` x `size` tile via (y % size, x % size),
    and the fraction of correct keypoints per tile cell is drawn, once for the
    source and once for the target keypoints. Cells without keypoints come out
    as NaN (0 / 0).

    Note: this reuses the name of the grid-based function defined in the previous
    cell, so it replaces that function once this cell has run.
    """
    keypoint_sets = [
        ('src', np.array([m['src_kp'] for m in metadata])),
        ('trg', np.array([m['trg_kp'] for m in metadata])),
    ]
    for name, kps in keypoint_sets:
        # last axis: [number of correct keypoints, number of keypoints]
        matrix = np.zeros((size, size, 2))
        for (x, y), c in zip(kps, correct):
            row, col = y%size, x%size
            matrix[row, col, 0] += c
            matrix[row, col, 1] += 1
        pck_matrix = matrix[:,:,0] / matrix[:,:,1]
        plt.figure()
        plt.imshow(pck_matrix, cmap='viridis', interpolation='nearest', vmin=0, vmax=1)
        plt.colorbar()
        plt.xlabel(f'{name} keypoint x (mod {size})')
        plt.ylabel(f'{name} keypoint y (mod {size})')
        plt.title(f'PCK@$0.1_{{bbox}}$ over {name} keypoint position')
        plt.show()

plot_pck_over_kp_position(correct, metadata, 16)

In [None]:
def plot_failures(get_distance: Callable[[bool,Any,int,int], int|None], normalize_values: Callable[[int], float], normalize_distance: Callable[[int], float] = lambda x: x):
    """Scatter-plot a normalised histogram of per-keypoint distance bins.

    Reads the notebook globals `correct`, `metadata` and `positions` in lockstep.

    Args:
        get_distance: called as get_distance(correct_flag, metadata_record, pred_x, pred_y);
            returns the bin key for that keypoint, or None to leave it out.
        normalize_values: applied to each bin count before it is divided by the
            total number of counted keypoints.
        normalize_distance: applied to (bin key + 1) to obtain the x coordinate.

    Draws into the current figure and sets the axis labels; it does not call
    plt.show(), so several calls can share one plot.
    """
    count_over_dist = defaultdict(int)
    for c, m, (x, y) in zip(correct, metadata, positions):
        dist = get_distance(c, m, x, y)
        if dist is not None:
            count_over_dist[dist] += 1
    counts_total = sum(count_over_dist.values())

    xs = []
    ys = []
    for distance in sorted(count_over_dist):
        ys.append(normalize_values(count_over_dist[distance]) / counts_total)
        xs.append(normalize_distance(distance+1))
    plt.scatter(xs, ys, s=2, alpha=0.5)
    plt.xlabel('Distance to source keypoint (pixels)')
    plt.ylabel('Fraction of failed keypoints per pixel')

# both axes: Euclidean distance between the predicted position and the source keypoint.
# NOTE(review): the prediction is in target-image pixels while 'src_kp' is in
# source-image pixels, so this compares raw coordinates across two images - confirm intended.
# get_distance bins the squared distance in steps of 25 (`**2//25`), so each bin is a ring
# of area 25*pi; normalize_values divides the count by that area, and normalize_distance
# turns the (bin key + 1) handed over by plot_failures back into a radius.
# The first call keeps failed keypoints only (`if not c else None`), the second keeps all.
plt.xlim(0, 200)
plot_failures(lambda c, m, x, y: np.hypot(m['src_kp'][0]-x, m['src_kp'][1]-y)**2//25 if not c else None, lambda x: x/np.pi/25, lambda x: (x*25)**.5)
plot_failures(lambda c, m, x, y: np.hypot(m['src_kp'][0]-x, m['src_kp'][1]-y)**2//25, lambda x: x/np.pi/25, lambda x: (x*25)**.5)
plt.legend(['Failed KP', 'All KP'])
plt.title('Distance of failed keypoint guess to source keypoint')
plt.show()

# x: absolute horizontal offset, floored to whole pixels (`//1`).
# Counts are halved (`x/2`), presumably because abs() folds left and right offsets
# into the same bin - confirm.
plt.xlim(0, 200)
plot_failures(lambda c, m, x, y: abs(m['src_kp'][0] - x)//1 if not c else None, lambda x: x/2)
plot_failures(lambda c, m, x, y: abs(m['src_kp'][0] - x)//1, lambda x: x/2)
plt.legend(['Failed KP', 'All KP'])
plt.title('X-Distance of failed keypoint guess to source keypoint')
plt.show()

# y: same as above for the absolute vertical offset.
plt.xlim(0, 200)
plot_failures(lambda c, m, x, y: abs(m['src_kp'][1] - y)//1 if not c else None, lambda x: x/2)
plot_failures(lambda c, m, x, y: abs(m['src_kp'][1] - y)//1, lambda x: x/2)
plt.legend(['Failed KP', 'All KP'])
plt.title('Y-Distance of failed keypoint guess to source keypoint')
plt.show()

In [None]:
# PCK over keypoint distance
# The distance between source and target keypoint coordinates is divided by the longest
# side of either image and sorted into n equal bins; anything beyond goes into the last bin.
n = 50
distances = {k: [] for k in range(n)}    # bin -> correctness flags of its keypoints
counts = dict.fromkeys(range(n), 0)      # bin -> number of keypoints
for c, m, _ in zip(correct, metadata, positions):
    max_size = max(*m['src_size'], *m['trg_size'])
    dx = (m['src_kp'][0] - m['trg_kp'][0]) / max_size
    dy = (m['src_kp'][1] - m['trg_kp'][1]) / max_size
    distance = np.hypot(dx, dy)
    key = min(int(distance*n), n-1)
    distances[key].append(c)
    counts[key] += 1

# left axis: keypoints per bin
plt.bar(counts.keys(), counts.values(), color='lightblue')
plt.ylabel('Number of keypoints', color='lightblue')
plt.xlabel('Distance between keypoints (normalized)')
tick_positions = range(0,n,n//10)
plt.xticks(tick_positions, [f'{pos/n:.1f}' for pos in tick_positions])
# right axis: PCK per bin
plt.twinx()
pck_per_bin = [np.mean(flags)*100 for flags in distances.values()]
plt.plot(distances.keys(), pck_per_bin, 'red')
plt.ylabel('PCK@$0.1_{bbox}$ (%)', color='red')
plt.show()

In [None]:
# correlate representation magnitude with PCK

# One map per entry of `representations`: mean absolute value along dim 0
# of the concatenated representation.
magnitudes = []
for rep in tqdm(representations):
    magnitudes.append(rep.concat().abs().mean(0))

def get_magnitues(m, p, c):
    """Look up the magnitude-map values belonging to one keypoint.

    Reads the notebook globals `magnitudes` and `pairs`.

    Args:
        m: metadata record of the keypoint ('i', 'src_kp', 'trg_kp', 'src_size', 'trg_size').
        p: predicted (x, y) position in target-image pixels.
        c: correctness flag, passed through unchanged.

    Returns:
        (magnitude at the source keypoint, magnitude at the target keypoint,
         magnitude at the predicted position, c)
    """
    pair = pairs[m['i']]
    src_mag = magnitudes[pair['src_data_index']]
    trg_mag = magnitudes[pair['trg_data_index']]
    src_w, src_h = m['src_size'][0], m['src_size'][1]
    trg_w, trg_h = m['trg_size'][0], m['trg_size'][1]

    # image pixel -> magnitude-map cell (axis 0 is y, axis 1 is x)
    src_col = m['src_kp'][0] * src_mag.shape[1] // src_w
    src_row = m['src_kp'][1] * src_mag.shape[0] // src_h
    trg_col = m['trg_kp'][0] * trg_mag.shape[1] // trg_w
    trg_row = m['trg_kp'][1] * trg_mag.shape[0] // trg_h
    # predictions are floats, hence the explicit int() after the floor division
    pred_col = int(p[0] * trg_mag.shape[1] // trg_w)
    pred_row = int(p[1] * trg_mag.shape[0] // trg_h)

    return (
        src_mag[src_row, src_col],
        trg_mag[trg_row, trg_col],
        trg_mag[pred_row, pred_col],
        c,
    )

def get_all_magnitudes():
    """Run `get_magnitues` for every keypoint in parallel and stack the results.

    Reads the notebook globals `metadata`, `positions` and `correct`.

    Returns:
        Array with one row per keypoint and the four values returned by
        `get_magnitues` as columns. Row k belongs to `metadata[k]`, so the
        columns can safely be zipped with `correct` or `metadata` afterwards.
    """
    from concurrent.futures import ProcessPoolExecutor
    import multiprocessing
    cpu_count = multiprocessing.cpu_count()
    print(f'Using {cpu_count} cores')
    with ProcessPoolExecutor(max_workers=cpu_count) as executor:
        futures = [executor.submit(get_magnitues, *args) for args in zip(metadata, positions, correct)]
        # Collect in submission order. Collecting in completion order would shuffle the
        # rows relative to `metadata` / `correct`.
        results = [future.result() for future in tqdm(futures, total=len(futures))]
    return np.array(results)

magnitudes_per_keypoint = get_all_magnitudes()

In [None]:
# magnitude-pck correlation
# columns 0..2 hold the source / target / predicted magnitudes, column 3 the correctness flag
pck_column = magnitudes_per_keypoint[:,3]
for column, kind in enumerate(('source', 'target', 'predicted')):
    print(f'Correlation between {kind} keypoint magnitude and PCK: {np.corrcoef(magnitudes_per_keypoint[:,column], pck_column)[0,1]}')

In [None]:
def plot_pck_over_magnitudes(magnitudes, name, num_bars=20):
    """Bar chart of PCK (as a fraction) over equally wide magnitude bins.

    Args:
        magnitudes: one magnitude per keypoint; must be ordered like the notebook
            global `correct`, which is read in lockstep with it.
        name: which keypoint the magnitudes belong to; used in the x-axis label.
        num_bars: number of bins between the smallest and largest magnitude.

    The keypoint count of every bin is written at the foot of its bar. Bins
    without keypoints are drawn with height 0.
    """
    magnitudes = np.asarray(magnitudes, dtype=float)
    lowest = magnitudes.min()
    highest = magnitudes.max()
    span = highest - lowest
    counts = np.zeros(num_bars)
    hits = np.zeros(num_bars)
    for magnitude, c in zip(magnitudes, correct):
        # all-equal magnitudes (span 0) go into the first bin instead of dividing by zero
        idx = int(num_bars * (magnitude - lowest) / span) if span > 0 else 0
        idx = min(idx, num_bars - 1)  # the maximum itself belongs to the last bin
        counts[idx] += 1
        hits[idx] += c
    # exact ratio where a bin has samples, 0 elsewhere
    pcks = np.divide(hits, counts, out=np.zeros(num_bars), where=counts > 0)
    width = span / num_bars if span > 0 else 1.0
    # bars are centred on their x value, so use the bin centres
    centers = lowest + (np.arange(num_bars) + 0.5) * width
    plt.bar(centers, pcks, width=width)
    for i in range(num_bars):
        plt.text(centers[i], 0.0, f' {counts[i]: 6.0f}', ha='center', va='bottom', rotation=90, color='white', fontfamily='monospace')
    plt.xlabel(f'Mean abs. {name} KP Repr. Magnitude')
    plt.ylabel('PCK@$0.1_{bbox}$')
    plt.title('PCK@$0.1_{bbox}$ over representation magnitude')
    plt.show()

plot_pck_over_magnitudes(magnitudes_per_keypoint[:,0], 'Source')
plot_pck_over_magnitudes(magnitudes_per_keypoint[:,1], 'Target')
plot_pck_over_magnitudes(magnitudes_per_keypoint[:,2], 'Predicted')

### visualize random failures

In [None]:
# visualize random failures
# Pick uniformly among the failed keypoints. This cannot index past the end of the lists
# (random.randint includes its upper bound) and cannot loop forever when nothing failed.
failed_indices = [k for k in range(len(metadata)) if not correct[k]]
if not failed_indices:
    print('No failed keypoints to visualise.')
else:
    i = random.choice(failed_indices)
    meta = metadata[i]
    pair = pairs[meta['i']]  # fetch the pair once for both images
    pred_x, pred_y = positions[i]
    trg_x, trg_y = meta['trg_kp'][0], meta['trg_kp'][1]
    trg_box = meta['trg_bndbox']

    plt.figure(figsize=(12, 6))
    # left: source image with the annotated keypoint (green)
    plt.subplot(1, 2, 1)
    plt.imshow(pair['src_img'])
    plt.scatter(meta['src_kp'][0], meta['src_kp'][1], c='g', marker='x', s=200)
    # right: target image with the annotated keypoint (green) and the prediction (red)
    plt.subplot(1, 2, 2)
    plt.imshow(pair['trg_img'])
    plt.scatter(trg_x, trg_y, c='g', marker='x', s=200)
    plt.scatter(pred_x, pred_y, c='r', marker='x', s=200)
    # prediction error relative to the longer side of the target bounding box
    relative_distance = ((pred_x - trg_x)**2 + (pred_y - trg_y)**2) ** 0.5 / max(trg_box[2] - trg_box[0], trg_box[3] - trg_box[1])
    plt.title(f'PCK@0.1_bbox: {relative_distance:.2f}')
    plt.show()