# Semantic Correspondence

Semantic correspondence with degradations applied to the target image, such as color changes, texture overlays, and blurring.

## Setup

In [1]:
from sdhelper import SD
import torch
import numpy as np
import datasets
from tqdm.autonotebook import tqdm, trange
from matplotlib import pyplot as plt
import PIL.Image
from typing import Callable
import random
from collections import defaultdict
from typing import Any
import torch.nn.functional as F

In [7]:
# SPair-71k from the Hugging Face hub: the image table ('data') and the annotated image pairs ('pairs').
# NOTE(review): images are read from the 'train' split and pairs from the 'test' split; the pairs refer
# to images via src/trg_data_index, presumably into this image table — confirm against the dataset card.
data, pairs = (
    datasets.load_dataset('0jl/SPair-71k', config, split=split, trust_remote_code=True)
    for config, split in (('data', 'train'), ('pairs', 'test'))
)

In [None]:
# Per-pair annotation fields copied into every keypoint's metadata entry.
metadata_keys = ['src_bndbox', 'trg_bndbox', 'category', 'viewpoint_variation', 'scale_variation']

# One entry per (source, target) keypoint pair, in the same order the SC loop visits them:
# pair index, both keypoints, both image sizes, plus the pair-level fields above.
metadata = []
for pair_idx, pair in enumerate(tqdm(pairs)):
    for src_kp, trg_kp in zip(pair['src_kps'], pair['trg_kps']):
        entry = {
            'i': pair_idx,
            'src_kp': src_kp,
            'trg_kp': trg_kp,
            'src_size': pair['src_img'].size,
            'trg_size': pair['trg_img'].size,
        }
        entry.update((key, pair[key]) for key in metadata_keys)
        metadata.append(entry)

## Calculate semantic correspondence

In [26]:
# Load the 'SD1.5' model through sdhelper; progress bars are disabled to keep per-image output quiet.
sd = SD('SD1.5', disable_progress_bar=True)
# Every feature-extraction position the model exposes; the plots below iterate over these.
all_blocks = sd.available_extract_positions

In [10]:
interpolation_steps = np.linspace(0, 1, 5)  # degradation strengths applied to the target image
# Results per extraction position and interpolation step: correct[block][j] holds per-keypoint
# hit flags, positions[block][j] the matched (x, y) pixel positions in the target image.
correct, positions = (
    {block: [[] for _ in interpolation_steps] for block in sd.available_extract_positions}
    for _ in range(2)
)

In [None]:
# Only the first few extraction positions are evaluated; the representation and SC cells iterate over `blocks`.
n_eval_blocks = 6
blocks = sd.available_extract_positions[:n_eval_blocks]
blocks

In [None]:
def rgb2gbr(arr: np.ndarray, interpolation: float):
    """Blend an RGB image towards its GBR channel permutation.

    Args:
        arr: image array with exactly three channels on the last axis, in RGB order.
        interpolation: blend weight; 0 keeps `arr`, 1 gives the fully permuted image.

    Returns:
        The blended array (float when `interpolation` is a float), same shape as `arr`.
    """
    # (R, G, B) -> (G, B, R). A plain reversal `arr[..., ::-1]` would produce BGR and leave
    # the green channel untouched, which is not what the function name promises.
    permuted = arr[..., [1, 2, 0]]
    return permuted * interpolation + arr * (1 - interpolation)

# Texture used by `add_grass`. PIL.Image.open reads lazily and keeps the file open,
# so load it inside a context manager to release the handle once the pixels are copied.
with PIL.Image.open('../data/grass.png') as _grass_file:
    grass_image = np.array(_grass_file.convert('RGB'))
def add_grass(img: np.ndarray, interpolation: float):
    tmp = 1 - interpolation/2
    return img * tmp + grass_image[:img.shape[0],:img.shape[1],:] * (1-tmp)

def edges_only(arr: np.ndarray, interpolation: float):
    gray_tensor = torch.from_numpy(np.dot(arr[...,:3], [0.2989, 0.5870, 0.1140]))[None,None,:,:]  # convert to grayscale
    
    # Sobel kernels
    kx = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float64)
    ky = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=torch.float64)
    
    dx = F.conv2d(gray_tensor, kx, padding=1)
    dy = F.conv2d(gray_tensor, ky, padding=1)
    
    edges = torch.sqrt(dx**2 + dy**2)[0,0].numpy()
    edges **= .5  # increase sensitivity
    edges = (edges / edges.max() * 255).astype(np.uint8)  # normalize
    return np.stack([edges] * 3, axis=-1) * interpolation + arr * (1-interpolation)

def blur(arr: np.ndarray, interpolation: float):
    """Gaussian-blur an image with radius 16 * interpolation pixels.

    Args:
        arr: image array accepted by PIL.Image.fromarray (e.g. uint8 RGB).
        interpolation: degradation strength; 0 means radius 0.

    Returns:
        The blurred image as a numpy array.
    """
    # `import PIL.Image` does not guarantee that the PIL.ImageFilter submodule is loaded,
    # so import it explicitly rather than relying on another library having imported it.
    from PIL import Image, ImageFilter
    return np.array(Image.fromarray(arr).filter(ImageFilter.GaussianBlur(radius=16 * interpolation)))

conversion_function = blur  # degradation applied to every image at each interpolation step


def transform_img(img):
    """Resize a PIL image so that its longer side becomes 512 px, keeping the aspect ratio."""
    new_size = tuple(np.array(img.size) * 512 // max(img.size))
    return img.resize(new_size)


# precalculate representations: representations[k] collects, for every dataset image in order,
# the normalized representation of that image degraded with interpolation_steps[k]
representations = [[] for _ in interpolation_steps]
for sample in tqdm(data, desc='Calculating representations'):
    prompt = ''  # an image-specific alternative would be sample['name'].split('/')[0]
    resized = np.array(transform_img(sample['img']))
    degraded = [conversion_function(resized, step).astype(np.uint8) for step in interpolation_steps]
    reprs = sd.img2repr(degraded, blocks, 50, prompt=prompt, seed=42)
    # scale to unit L2 norm along dim 0 (presumably the channel axis) so dot products act as cosine similarities
    unit_reprs = [rep.apply(lambda feat: feat / torch.norm(feat, dim=0, keepdim=True)) for rep in reprs]
    for step_idx, rep in enumerate(unit_reprs):
        representations[step_idx].append(rep)

In [22]:
# Free the diffusion model and its GPU memory: the SC evaluation below only needs the precomputed
# `representations`. Later code that reads `sd` (e.g. its model name) must not assume it still exists.
del sd
torch.cuda.empty_cache()

In [None]:
# calculate percentage of correct keypoints at 10% of the bounding box (PCK@0.1_bbox)
# Source features always come from the undegraded images (step 0); target features come from
# every interpolation step j, so correct[block][j] measures robustness to degrading the target.

for block in tqdm(blocks, desc='Blocks'):
    # Start from empty lists so that re-running this cell does not append duplicate results.
    correct[block] = [[] for _ in interpolation_steps]
    positions[block] = [[] for _ in interpolation_steps]
    representations_concat = [[r.at(block).concat().cuda() for r in rs] for rs in representations]
    for x in tqdm(pairs, desc=f'Calculating SC for {block}'):
        if x['src_data_index'] >= len(representations_concat[0]): continue
        a = representations_concat[0][x['src_data_index']]  # indexed below as (channels, h, w)
        src_w, src_h = x['src_img'].size
        trg_w, trg_h = x['trg_img'].size
        tbb_max = max(x['trg_bndbox'][2] - x['trg_bndbox'][0], x['trg_bndbox'][3] - x['trg_bndbox'][1])
        # source keypoints (pixels) -> feature-map cell indices (floor)
        src_xs = [sx * a.shape[2] // src_w for sx, sy in x['src_kps']]
        src_ys = [sy * a.shape[1] // src_h for sx, sy in x['src_kps']]
        for j, modified_repr in enumerate(representations_concat):
            if x['trg_data_index'] >= len(modified_repr): continue
            b = modified_repr[x['trg_data_index']]
            feat_h, feat_w = b.shape[1], b.shape[2]
            # similarity of every source keypoint feature with every target location -> (num_kps, h*w)
            similarity = (b[:, None, :, :] * a[:, src_ys, src_xs, None, None]).sum(0).flatten(1)
            argmaxes = similarity.argmax(1).cpu()
            for argmax, (tx, ty) in zip(argmaxes, x['trg_kps']):
                y_max, x_max = divmod(int(argmax), feat_w)
                # Map the matched cell back to the centre of the pixel area it covers. Source pixels
                # are floored into cells above, so using the cell's left/top edge would bias every
                # prediction by half a cell towards the image origin.
                x_max_pixel = (x_max + 0.5) * trg_w / feat_w
                y_max_pixel = (y_max + 0.5) * trg_h / feat_h
                relative_distance = ((x_max_pixel - tx)**2 + (y_max_pixel - ty)**2) ** 0.5 / tbb_max
                correct[block][j].append(relative_distance < 0.1)
                positions[block][j].append((x_max_pixel, y_max_pixel))

    del representations_concat
    torch.cuda.empty_cache()

In [None]:
# Mean PCK per interpolation step for every extraction position
# (NaN, with a numpy warning, for positions that were not evaluated).
for block_name in correct:
    print(block_name, np.mean(correct[block_name], axis=1))



In [27]:
# PCK per (extraction position, interpolation step); rows follow `correct`, columns follow
# interpolation_steps. Positions that were not evaluated have empty result lists and get NaN
# explicitly instead of triggering "mean of empty slice" warnings.
pcks = np.array([[np.mean(step_results) if len(step_results) else np.nan for step_results in per_step]
                 for per_step in correct.values()])
# `sd` is deleted after the representations are precomputed, so only read its name if it still exists.
# NOTE(review): 'SD1.5' is the id passed to SD(...) when loading the model; confirm it equals sd.model_name.
model_name = sd.model_name if 'sd' in globals() else 'SD1.5'
# Name the file after the degradation actually applied instead of hard-coding 'blur'.
np.save(f'sc_pck_{conversion_function.__name__}_on_trg_{model_name}.npy', pcks)

In [None]:
np.shape(pcks)  # (number of extraction positions, number of interpolation steps)

In [None]:
# Top: absolute PCK per extraction position over the interpolation step.
# Bottom: PCK change relative to the undegraded (step 0) PCK.
# sharex hides the top axis' tick labels without calling set_xticklabels on unfixed ticks.
fig, axs = plt.subplots(2, 1, figsize=(8, 6), sharex=True)
# one colour per extraction position, so colours stay stable regardless of which ones were evaluated
colors = plt.cm.rainbow(np.linspace(0, 1, len(all_blocks)))
for block_idx, (block, color) in enumerate(zip(all_blocks, colors)):
    block_pck = pcks[block_idx]
    if np.all(np.isnan(block_pck)):
        continue  # not evaluated: keep it out of the plot and the legend
    axs[0].plot(interpolation_steps, block_pck, label=block, color=color)
    init_acc = block_pck[0]
    if init_acc > 0:  # a zero baseline has no defined relative change
        axs[1].plot(interpolation_steps, (block_pck - init_acc) / init_acc, label=block, color=color)
axs[0].legend(bbox_to_anchor=(1.01, 0.5), loc='center left')
axs[0].set_ylabel('PCK@$0.1_{bbox}$')
axs[1].set_xlabel('Interpolation step')
axs[1].set_ylabel('Rel. Change in PCK@$0.1_{bbox}$')

plt.tight_layout()
plt.show()


## Visualise semantic correspondence

In [None]:
# plot PCK over category
def plot_pck_over_category(correct, metadata):
    """Draw a bar chart of PCK@0.1_bbox (%) per object category into the current axes.

    Unlike the other plotting helpers this does not open or show a figure, so it can be
    used inside subplot grids. White numbers at the bar base give the keypoint count.

    Args:
        correct: flat sequence of per-keypoint hit flags, aligned element-wise with `metadata`.
        metadata: per-keypoint dicts with a 'category' entry.
    """
    hits_by_category = defaultdict(list)
    for m, c in zip(metadata, correct):
        hits_by_category[m['category']].append(c)
    # Sort so the bar order is deterministic and each bar gets its own label. The previous
    # list(set(...)) order is arbitrary (for strings it changes between runs) and did not
    # match the class-name order used for the tick labels.
    categories = sorted(hits_by_category)
    # NOTE(review): integer categories are assumed to index the dataset's 'category_id'
    # class names; any other value is used as its own label.
    class_names = pairs.features['category_id'].names
    labels = [class_names[cat] if isinstance(cat, (int, np.integer)) else str(cat) for cat in categories]
    pcks = [np.mean(hits_by_category[cat]) * 100 for cat in categories]
    plt.bar(range(len(categories)), pcks)
    plt.xticks(range(len(categories)), labels, rotation=90)
    for i, (cat, pck) in enumerate(zip(categories, pcks)):
        plt.text(i, pck, f'{pck:.2f}', ha='center', va='bottom', fontsize=8)
        plt.text(i+0.07, 3, f'{len(hits_by_category[cat])}', ha='center', va='bottom', rotation=90, color='white', fontsize=8)
    plt.xlabel('Category')
    plt.ylabel('PCK@$0.1_{bbox}$')

# Grid of per-category PCK bar charts: one row per evaluated extraction position, one column per
# interpolation step. Positions that were never evaluated (all result lists empty) are skipped
# instead of producing empty, NaN-only panels.
evaluated = {block: per_step for block, per_step in correct.items() if any(len(r) for r in per_step)}
n_rows, n_cols = len(evaluated), len(interpolation_steps)
if n_rows:
    plt.figure(figsize=(n_cols*4, n_rows*4))
    for row, (block, per_step) in enumerate(evaluated.items()):
        for col, (interpolation_step, step_correct) in enumerate(zip(interpolation_steps, per_step)):
            plt.subplot(n_rows, n_cols, row*n_cols + col + 1)
            plt.title(f'{block} {interpolation_step:.2f}')
            plot_pck_over_category(step_correct, metadata)
    plt.tight_layout()
    plt.show()

In [None]:
# plot PCK over viewpoint variation
def plot_pck_over_viewpoint_variation(correct, metadata):
    """Plot PCK@0.1_bbox (%) per viewpoint-variation level in a new figure.

    `correct` is a flat sequence of per-keypoint hit flags aligned with `metadata`;
    each entry's 'viewpoint_variation' value indexes one of the low/medium/high bars.
    White numbers at the bar base give the keypoint count per level.
    """
    level_of_kp = [entry['viewpoint_variation'] for entry in metadata]
    hits_by_level = [[], [], []]
    for level, hit in zip(level_of_kp, correct):
        hits_by_level[level].append(hit)
    percent = [np.mean(hits) * 100 for hits in hits_by_level]

    plt.figure()
    plt.bar(range(3), percent)
    plt.xticks(range(3), ['low', 'medium', 'high'])
    for pos, value in enumerate(percent):
        plt.text(pos, value, f'{value:.2f}', ha='center', va='bottom')
        plt.text(pos, 3, f'{len(hits_by_level[pos])}', ha='center', va='bottom', rotation=90, color='white')
    plt.xlabel('Viewpoint variation')
    plt.ylabel('PCK@$0.1_{bbox}$')
    plt.title('PCK@$0.1_{bbox}$ over viewpoint variation')
    plt.show()

# `correct` is nested as correct[block][interpolation step]; the helper expects one flat list of
# per-keypoint results aligned with `metadata` (zipping the dict would iterate over block names).
# Show the first evaluated extraction position on the undegraded targets (step 0).
plot_pck_over_viewpoint_variation(correct[blocks[0]][0], metadata)

In [None]:
# plot PCK over bounding box scale variation
def plot_pck_over_scale_variation(correct, metadata):
    """Plot PCK@0.1_bbox (%) per bounding-box scale-variation level in a new figure.

    `correct` is a flat sequence of per-keypoint hit flags aligned with `metadata`;
    each entry's 'scale_variation' value indexes one of the low/medium/high bars.
    White numbers at the bar base give the keypoint count per level.
    """
    level_names = ['low', 'medium', 'high']
    scale_levels = [m['scale_variation'] for m in metadata]
    grouped = [[] for _ in level_names]
    for lvl, ok in zip(scale_levels, correct):
        grouped[lvl].append(ok)
    scores = [np.mean(g) * 100 for g in grouped]
    xs = range(len(level_names))

    plt.figure()
    plt.bar(xs, scores)
    plt.xticks(xs, level_names)
    for k, (score, g) in enumerate(zip(scores, grouped)):
        plt.text(k, score, f'{score:.2f}', ha='center', va='bottom')
        plt.text(k, 3, f'{len(g)}', ha='center', va='bottom', rotation=90, color='white')
    plt.xlabel('Scale variation')
    plt.ylabel('PCK@$0.1_{bbox}$')
    plt.title('PCK@$0.1_{bbox}$ over scale variation')
    plt.show()

# `correct` is nested as correct[block][interpolation step]; the helper expects one flat list of
# per-keypoint results aligned with `metadata` (zipping the dict would iterate over block names).
# Show the first evaluated extraction position on the undegraded targets (step 0).
plot_pck_over_scale_variation(correct[blocks[0]][0], metadata)

In [None]:
# plot PCK over src and trg bounding box size
def plot_pck_over_bbox_size(correct, metadata):
    """Plot PCK@0.1_bbox (%) over source and target bounding-box area (one figure each).

    Areas are grouped into 10 equal-width bins of sqrt(area), i.e. of side length, and the
    tick labels show the corresponding area range in thousands of pixels.
    `correct` is a flat sequence of per-keypoint hit flags aligned with `metadata`.
    """
    def box_areas(key):
        boxes = np.array([entry[key] for entry in metadata])
        widths = boxes[:, 2] - boxes[:, 0]
        heights = boxes[:, 3] - boxes[:, 1]
        return widths * heights

    num_bins = 10
    areas_by_name = {'src': box_areas('src_bndbox'), 'trg': box_areas('trg_bndbox')}
    for name, sizes in areas_by_name.items():
        lo = sizes.min() ** .5
        hi = sizes.max() ** .5 + 1  # +1 keeps the largest box strictly inside the last bin
        span = hi - lo
        grouped = [[] for _ in range(num_bins)]
        for size, hit in zip(sizes, correct):
            grouped[int((size**.5 - lo) / span * num_bins)].append(hit)

        def edge(k):
            return (lo + k*span/num_bins)**2/1000

        percentages = [np.mean(group)*100 for group in grouped]
        plt.figure()
        plt.bar(range(num_bins), percentages)
        plt.xticks(range(num_bins), [f'{edge(k):.0f}k-{edge(k + 1):.0f}k px' for k in range(num_bins)], rotation=45)
        for k, (pct, group) in enumerate(zip(percentages, grouped)):
            plt.text(k, pct, f'{pct:.2f}', ha='center', va='bottom')
            plt.text(k, 3, f'{len(group)}', ha='center', va='bottom', rotation=90, color='white')
        plt.xlabel(f'{name} bounding box size')
        plt.ylabel('PCK@$0.1_{bbox}$')
        plt.title(f'PCK@$0.1_{{bbox}}$ over {name} bounding box size')
        plt.show()

# `correct` is nested as correct[block][interpolation step]; the helper expects one flat list of
# per-keypoint results aligned with `metadata` (zipping the dict would iterate over block names).
# Show the first evaluated extraction position on the undegraded targets (step 0).
plot_pck_over_bbox_size(correct[blocks[0]][0], metadata)

In [None]:
# plot PCK over difference in bounding box sizes
def plot_pck_over_bbox_size_diff(correct, metadata):
    """Plot PCK@0.1_bbox (%) over the absolute area difference between src and trg boxes.

    Differences are grouped into 10 equal-width bins of sqrt(difference); tick labels show
    the corresponding difference range in thousands of pixels.
    `correct` is a flat sequence of per-keypoint hit flags aligned with `metadata`.
    """
    def box_areas(key):
        boxes = np.array([entry[key] for entry in metadata])
        return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])

    diffs = np.abs(box_areas('src_bndbox') - box_areas('trg_bndbox'))
    num_bins = 10
    lo = diffs.min() ** .5
    hi = diffs.max() ** .5 + 1  # +1 keeps the largest difference strictly inside the last bin
    span = hi - lo
    grouped = [[] for _ in range(num_bins)]
    for diff, hit in zip(diffs, correct):
        grouped[int((diff**.5 - lo) / span * num_bins)].append(hit)

    def edge(k):
        return (lo + k*span/num_bins)**2/1000

    percentages = [np.mean(group)*100 for group in grouped]
    plt.figure()
    plt.bar(range(num_bins), percentages)
    plt.xticks(range(num_bins), [f'{edge(k):.0f}k-{edge(k + 1):.0f}k px' for k in range(num_bins)], rotation=45)
    for k, (pct, group) in enumerate(zip(percentages, grouped)):
        plt.text(k, pct, f'{pct:.2f}', ha='center', va='bottom')
        plt.text(k, 3, f'{len(group)}', ha='center', va='bottom', rotation=90, color='white')
    plt.xlabel('Difference in bounding box sizes')
    plt.ylabel('PCK@$0.1_{bbox}$')
    plt.title('PCK@$0.1_{bbox}$ over difference in bounding box sizes')
    plt.show()

# `correct` is nested as correct[block][interpolation step]; the helper expects one flat list of
# per-keypoint results aligned with `metadata` (zipping the dict would iterate over block names).
# Show the first evaluated extraction position on the undegraded targets (step 0).
plot_pck_over_bbox_size_diff(correct[blocks[0]][0], metadata)

In [None]:
# plot PCK over bounding box aspect ratio
def plot_pck_over_bbox_aspect_ratio(correct, metadata):
    """Plot PCK@0.1_bbox (%) over the width:height ratio of the src and trg bounding boxes.

    One figure per side ('src', 'trg'). Ratios are binned on a log2 scale into the six
    ranges named in `bin_names`; white numbers give the keypoint count per bin.

    Args:
        correct: flat sequence of per-keypoint hit flags aligned with `metadata`.
        metadata: per-keypoint dicts with 'src_bndbox'/'trg_bndbox', indexed here as
            (x0, y0, x1, y1).
    """
    bin_names = ['< 1:4', '1:4 - 1:2', '1:2 - 1:1', '1:1 - 2:1', '2:1 - 4:1', '> 4:1']
    src_shapes = np.array([x['src_bndbox'] for x in metadata])
    src_ratios = (src_shapes[:,2] - src_shapes[:,0]) / (src_shapes[:,3] - src_shapes[:,1])
    trg_shapes = np.array([x['trg_bndbox'] for x in metadata])
    trg_ratios = (trg_shapes[:,2] - trg_shapes[:,0]) / (trg_shapes[:,3] - trg_shapes[:,1])
    for name, ratios in [('src', src_ratios), ('trg', trg_ratios)]:
        bins = [[] for _ in bin_names]
        for ratio, c in zip(ratios, correct):
            if np.isnan(ratio):
                continue  # zero-area box: aspect ratio undefined
            # Bin k holds ratios in [2**(k-3), 2**(k-2)); the first and last bins are open-ended.
            # The offset must be +3: the previous +2 put every ratio in [1:4, 8:1) one bin to the
            # left of its label.
            idx = int(np.clip(np.floor(np.log2(ratio)) + 3, 0, len(bin_names) - 1))
            bins[idx].append(c)
        pcks = [np.mean(b)*100 if b else np.nan for b in bins]
        plt.figure()
        plt.bar(range(len(bin_names)), pcks)
        plt.xticks(range(len(bin_names)), bin_names, rotation=45)
        for i, pck in enumerate(pcks):
            if bins[i]:  # no PCK label for empty bins
                plt.text(i, pck, f'{pck:.2f}', ha='center', va='bottom')
            plt.text(i, 3, f'{len(bins[i])}', ha='center', va='bottom', rotation=90, color='white')
        plt.xlabel(f'{name} bounding box aspect ratio')
        plt.ylabel('PCK@$0.1_{bbox}$')
        plt.title(f'PCK@$0.1_{{bbox}}$ over {name} bounding box aspect ratio')
        plt.show()

# `correct` is nested as correct[block][interpolation step]; the helper expects one flat list of
# per-keypoint results aligned with `metadata` (zipping the dict would iterate over block names).
# Show the first evaluated extraction position on the undegraded targets (step 0).
plot_pck_over_bbox_aspect_ratio(correct[blocks[0]][0], metadata)

In [None]:
# plot PCK over src and trg keypoint positions
def plot_pck_over_kp_position(correct, metadata, scale_factor=2):
    """Show keypoint counts and PCK@0.1_bbox as heat maps over keypoint position.

    Keypoint coordinates are grouped into square cells of `scale_factor` pixels. For both
    the source ('src_kp') and target ('trg_kp') keypoints two figures are drawn: the number
    of keypoints per cell and the fraction of correct matches per cell (cells without
    keypoints are NaN and render blank).

    Args:
        correct: flat sequence of per-keypoint hit flags aligned with `metadata`.
        metadata: per-keypoint dicts with 'src_kp' and 'trg_kp' (x, y) entries.
        scale_factor: cell size in pixels.
    """
    src_kps = np.array([x['src_kp'] for x in metadata])
    trg_kps = np.array([x['trg_kp'] for x in metadata])
    for name, kps in [('src', src_kps), ('trg', trg_kps)]:
        # integer cell coordinates; float keypoints would otherwise give float shapes and indices
        cells = (kps // scale_factor).astype(int)
        min_x, min_y = cells[:, 0].min(), cells[:, 1].min()
        max_x, max_y = cells[:, 0].max(), cells[:, 1].max()
        matrix = np.zeros((max_y - min_y + 1, max_x - min_x + 1, 2))  # [..., 0] hits, [..., 1] count
        for (x, y), c in zip(cells, correct):
            matrix[y - min_y, x - min_x, :] += c, 1
        counts = matrix[:, :, 1]
        # NaN where a cell holds no keypoints, without a 0/0 RuntimeWarning
        pck_matrix = np.divide(matrix[:, :, 0], counts, out=np.full(counts.shape, np.nan), where=counts > 0)

        # plot count
        plt.figure()
        plt.imshow(counts, cmap='viridis', interpolation='nearest')
        plt.colorbar()
        plt.xlabel(f'{name} keypoint x')
        plt.ylabel(f'{name} keypoint y')
        plt.title(f'Count of keypoints over {name} keypoint position')
        plt.show()

        # plot pck matrix
        plt.figure()
        plt.imshow(pck_matrix, cmap='viridis', interpolation='nearest', vmin=0, vmax=1)
        plt.colorbar()
        plt.xlabel(f'{name} keypoint x')
        plt.ylabel(f'{name} keypoint y')
        plt.title(f'PCK@$0.1_{{bbox}}$ over {name} keypoint position')
        plt.show()

# `correct` is nested as correct[block][interpolation step]; the helper expects one flat list of
# per-keypoint results aligned with `metadata` (zipping the dict would iterate over block names).
# Show the first evaluated extraction position on the undegraded targets (step 0).
plot_pck_over_kp_position(correct[blocks[0]][0], metadata, 20)

In [None]:
# plot PCK over src and trg keypoint positions
# might be useful to see if certain pixels in the regions are ignored
def plot_pck_over_kp_position(correct, metadata, size=32):
    """Show PCK@0.1_bbox over keypoint position modulo `size` as a heat map.

    Folding coordinates modulo `size` reveals periodic effects, e.g. whether keypoints at
    certain offsets inside a feature-map cell are matched worse. One figure each for the
    source ('src_kp') and target ('trg_kp') keypoints; empty cells are NaN and render blank.

    Args:
        correct: flat sequence of per-keypoint hit flags aligned with `metadata`.
        metadata: per-keypoint dicts with 'src_kp' and 'trg_kp' (x, y) entries.
        size: period in pixels (integer).
    """
    src_kps = np.array([x['src_kp'] for x in metadata])
    trg_kps = np.array([x['trg_kp'] for x in metadata])
    for name, kps in [('src', src_kps), ('trg', trg_kps)]:
        # integer offsets; float keypoints would otherwise give float indices
        offsets = (kps % size).astype(int)
        matrix = np.zeros((size, size, 2))  # [..., 0] hits, [..., 1] count
        for (x, y), c in zip(offsets, correct):
            matrix[y, x, :] += c, 1
        counts = matrix[:, :, 1]
        # NaN where no keypoint falls, without a 0/0 RuntimeWarning
        pck_matrix = np.divide(matrix[:, :, 0], counts, out=np.full(counts.shape, np.nan), where=counts > 0)
        plt.figure()
        plt.imshow(pck_matrix, cmap='viridis', interpolation='nearest', vmin=0, vmax=1)
        plt.colorbar()
        plt.xlabel(f'{name} keypoint x (mod {size})')
        plt.ylabel(f'{name} keypoint y (mod {size})')
        plt.title(f'PCK@$0.1_{{bbox}}$ over {name} keypoint position')
        plt.show()

# `correct` is nested as correct[block][interpolation step]; the helper expects one flat list of
# per-keypoint results aligned with `metadata` (zipping the dict would iterate over block names).
# Show the first evaluated extraction position on the undegraded targets (step 0).
plot_pck_over_kp_position(correct[blocks[0]][0], metadata, 16)

In [None]:
def plot_failures(get_distance: Callable[[bool, Any, float, float], float | None], normalize_values: Callable[[int], float], normalize_distance: Callable[[int], float] = lambda x: x, *, block: str | None = None, step: int = 0):
    """Scatter the normalized distribution of a distance measure over keypoint guesses.

    Args:
        get_distance: maps (is_correct, metadata entry, guess_x, guess_y) to a hashable,
            sortable distance bucket, or None to skip that keypoint.
        normalize_values: maps a bucket's count to the plotted value; the result is further
            divided by the total number of counted keypoints.
        normalize_distance: maps bucket + 1 to the plotted x coordinate.
        block: extraction position whose results are used; defaults to `blocks[0]`.
        step: index into `interpolation_steps`; defaults to 0 (undegraded target).
    """
    if block is None:
        block = blocks[0]
    # `correct`/`positions` are nested as [block][step]; zipping the dicts themselves would
    # iterate over block names instead of per-keypoint results.
    count_over_dist = defaultdict(int)
    for c, m, (x, y) in zip(correct[block][step], metadata, positions[block][step]):
        dist = get_distance(c, m, x, y)
        if dist is None:
            continue
        count_over_dist[dist] += 1
    counts_total = sum(count_over_dist.values())
    xs = []
    ys = []
    for distance, count in sorted(count_over_dist.items()):
        ys.append(normalize_values(count) / counts_total)
        xs.append(normalize_distance(distance + 1))
    plt.scatter(xs, ys, s=2, alpha=0.5)
    plt.xlabel('Distance to source keypoint (pixels)')
    plt.ylabel('Fraction of failed keypoints per pixel')

# Each figure compares the distance distribution of failed guesses with that of all guesses,
# measured from the predicted target position to the *source* keypoint location.
def _radial_bucket(m, x, y):
    # squared Euclidean distance, bucketed in steps of 25 px^2
    return np.hypot(m['src_kp'][0] - x, m['src_kp'][1] - y) ** 2 // 25


def _axis_bucket(axis):
    # absolute distance along one axis (0 = x, 1 = y), in 1 px buckets
    def bucket(m, x, y):
        return abs(m['src_kp'][axis] - (x, y)[axis]) // 1
    return bucket


_failure_plots = [
    ('Distance of failed keypoint guess to source keypoint', _radial_bucket,
     (lambda v: v/np.pi/25, lambda d: (d*25)**.5)),
    ('X-Distance of failed keypoint guess to source keypoint', _axis_bucket(0), (lambda v: v/2,)),
    ('Y-Distance of failed keypoint guess to source keypoint', _axis_bucket(1), (lambda v: v/2,)),
]
for title, bucket, normalizers in _failure_plots:
    plt.xlim(0, 200)
    plot_failures(lambda c, m, x, y, bucket=bucket: None if c else bucket(m, x, y), *normalizers)
    plot_failures(lambda c, m, x, y, bucket=bucket: bucket(m, x, y), *normalizers)
    plt.legend(['Failed KP', 'All KP'])
    plt.title(title)
    plt.show()

In [None]:
# PCK over keypoint distance
# Bins keypoints by the distance between their source and target coordinates (normalized by the
# longest side of either image) and plots the count (bars) and PCK (line) per bin.
# `correct` is nested as correct[block][interpolation step]; use the first evaluated extraction
# position on undegraded targets (zipping the dict itself would iterate over block names).
correct_flat = correct[blocks[0]][0]
n = 50
distances = {i: [] for i in range(n)}
counts = {i: 0 for i in range(n)}
for c, m in zip(correct_flat, metadata):
    max_size = max(max(m['src_size']), max(m['trg_size']))
    distance = np.hypot((m['src_kp'][0] - m['trg_kp'][0]) / max_size, (m['src_kp'][1] - m['trg_kp'][1]) / max_size)
    key = min(int(distance*n), n-1)  # distances >= 1 fall into the last bin
    distances[key].append(c)
    counts[key] += 1

plt.bar(counts.keys(), counts.values(), color='lightblue')
plt.ylabel('Number of keypoints', color='lightblue')
plt.xlabel('Distance between keypoints (normalized)')
plt.xticks(range(0,n,n//10), [f'{i/n:.1f}' for i in range(0,n,n//10)])
plt.twinx()
# empty bins have no PCK: NaN leaves a gap instead of a "mean of empty slice" warning
plt.plot(distances.keys(), [np.mean(d)*100 if d else np.nan for d in distances.values()], 'red')
plt.ylabel('PCK@$0.1_{bbox}$ (%)', color='red')
plt.show()

### Visualize random failures

In [None]:
# visualize random failures
# `correct`/`positions` are nested as [block][interpolation step]; inspect one flat result list.
vis_block, vis_step = blocks[0], 0
vis_correct = correct[vis_block][vis_step]
vis_positions = positions[vis_block][vis_step]
failed = [idx for idx, hit in enumerate(vis_correct) if not hit]
if not failed:
    # avoid resampling forever when every keypoint was matched correctly
    print(f'No failed keypoints for block {vis_block} at interpolation step {vis_step}')
else:
    # random.choice always yields a valid index (random.randint's upper bound is inclusive)
    i = random.choice(failed)
    pair = pairs[metadata[i]['i']]
    src_kp, trg_kp = metadata[i]['src_kp'], metadata[i]['trg_kp']
    guess_x, guess_y = vis_positions[i]
    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(pair['src_img'])
    plt.scatter(src_kp[0], src_kp[1], c='g', marker='x', s=200)
    plt.subplot(1, 2, 2)
    plt.imshow(pair['trg_img'])
    plt.scatter(trg_kp[0], trg_kp[1], c='g', marker='x', s=200)
    plt.scatter(guess_x, guess_y, c='r', marker='x', s=200)
    bbox = metadata[i]['trg_bndbox']
    relative_distance = ((guess_x - trg_kp[0])**2 + (guess_y - trg_kp[1])**2) ** 0.5 / max(bbox[2] - bbox[0], bbox[3] - bbox[1])
    # distance relative to the target bbox's longer side (a hit if < 0.1), not a PCK value
    plt.title(f'Relative distance: {relative_distance:.2f} (correct if < 0.1)')
    plt.show()