# Semantic Correspondence

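Simple semantic correspondence evaluation on SPair-71k (PCK@0.1 relative to the target bounding box) using PCA-reduced Stable Diffusion 2.1 features, with visualisations of the results broken down by pair metadata.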

## Setup

In [None]:
from sdhelper import SD
import torch
import numpy as np
import datasets
from tqdm.autonotebook import tqdm, trange
from matplotlib import pyplot as plt
import PIL.Image
from typing import Callable

In [None]:
# Model wrapper used to extract image representations
sd = SD('SD2.1')

# NOTE(review): trust_remote_code=True runs the loading script shipped with the
# dataset repository - only keep it for repositories you trust.
data = datasets.load_dataset(
    '0jl/SPair-71k',
    name='data',
    split='train',
    trust_remote_code=True,
)
pairs = datasets.load_dataset(
    '0jl/SPair-71k',
    name='pairs',
    split='test',
    trust_remote_code=True,
)

In [None]:
# Per-pair fields that are copied onto every keypoint entry
metadata_keys = [
    'src_bndbox',
    'trg_bndbox',
    'category',
    'viewpoint_variation',
    'scale_variation',
]

# One flat entry per (source keypoint, target keypoint) couple, in the same
# order in which the pairs and their keypoints are iterated.
metadata = [
    {'src_kp': src_kp, 'trg_kp': trg_kp, **{key: pair[key] for key in metadata_keys}}
    for pair in tqdm(pairs)
    for src_kp, trg_kp in zip(pair['src_kps'], pair['trg_kps'])
]

## Calculate semantic correspondence

In [None]:
def compute_pca_basis(data, n_components):
    """Fit a PCA basis on ``data`` (rows are samples, columns are features).

    Args:
        data: 2D tensor of shape (n_samples, n_features).
        n_components: number of leading principal directions to keep.

    Returns:
        Tuple ``(pca_basis, data_mean)`` where ``pca_basis`` has shape
        (n_features, n_components) with the directions as columns, ordered by
        decreasing eigenvalue, and ``data_mean`` is the per-feature mean.
    """
    feature_mean = torch.mean(data, dim=0)
    centered = data - feature_mean

    # Sample covariance (n_samples - 1 in the denominator)
    n_samples = centered.size(0)
    covariance = torch.mm(centered.T, centered) / (n_samples - 1)

    # eigh is used because the covariance matrix is symmetric
    eigvals, eigvecs = torch.linalg.eigh(covariance)

    # Reorder the eigenvector columns from largest to smallest eigenvalue
    order = torch.argsort(eigvals, descending=True)
    ordered_eigvecs = eigvecs[:, order]

    return ordered_eigvecs[:, :n_components], feature_mean

def transform_data(data, pca_basis, data_mean):
    """Project ``data`` onto ``pca_basis`` after subtracting ``data_mean``.

    Args:
        data: 2D tensor of shape (n_samples, n_features).
        pca_basis: 2D tensor of shape (n_features, n_components).
        data_mean: mean that is subtracted from every row of ``data``.

    Returns:
        2D tensor of shape (n_samples, n_components).
    """
    return torch.mm(data - data_mean, pca_basis)


In [None]:
def sc(transform_img: Callable = lambda x: x, n_components=500):
    """Evaluate semantic correspondence and return per-keypoint correctness.

    Reads the module-level ``sd``, ``data`` and ``pairs`` objects. A
    representation is extracted for every image in ``data``, normalised along
    its first dimension and projected onto a PCA basis fitted on all
    representations. For every keypoint of every pair, the target location with
    the highest dot product to the source keypoint's feature is the prediction.
    It counts as correct when its distance to the annotated target keypoint is
    below 10% of the larger target bounding-box extent (PCK@0.1_bbox).

    Args:
        transform_img: applied to each image before feature extraction.
        n_components: number of PCA components to keep.

    Returns:
        Flat list with one boolean per keypoint, in the iteration order of
        ``pairs`` and their keypoints.
    """
    # precalculate representations
    representations = []
    for x in tqdm(data, desc='Calculating representations'):
        # NOTE(review): the positional 100 is presumably the diffusion timestep - confirm against sdhelper
        r = sd.img2repr(transform_img(x['img']), ['up_blocks[0]','up_blocks[1].attentions[2]'], 100, prompt=x['name'].split('/')[0])
        r = r.apply(lambda feat: feat / torch.norm(feat, dim=0, keepdim=True))  # normalize
        r = r.concat().float()
        representations.append(r)

    # pca transform (fitted on all spatial positions of all images)
    print('calculating PCA basis')
    num_features = representations[0].shape[0]
    pca_basis, data_mean = compute_pca_basis(torch.cat([r.view((num_features,-1)) for r in representations], dim=1).T, n_components)
    print('transforming representations with PCA')
    representations_pca = [transform_data(r.view((num_features,-1)).T, pca_basis, data_mean).T.reshape((-1,*r.shape[1:])) for r in representations]

    # calculate percentage of correct keypoints at 10% of the bounding box (PCK@0.1_bbox)
    correct = []
    t = tqdm(pairs, desc='Calculating SC')
    for num_pairs, x in enumerate(t, start=1):
        a = representations_pca[x['src_data_index']]
        b = representations_pca[x['trg_data_index']]
        # assumes image .size is (width, height), as for PIL images
        src_w, src_h = x['src_img'].size[0], x['src_img'].size[1]
        trg_w, trg_h = x['trg_img'].size[0], x['trg_img'].size[1]
        tbb_max = max(x['trg_bndbox'][2] - x['trg_bndbox'][0], x['trg_bndbox'][3] - x['trg_bndbox'][1])
        for ([sx, sy],[tx,ty]) in zip(x['src_kps'], x['trg_kps']):
            # Map pixel coordinates to feature-map cells; clamp so a keypoint on the
            # last row/column edge cannot index one past the end of the map.
            row = min(sy*a.shape[1]//src_h, a.shape[1] - 1)
            col = min(sx*a.shape[2]//src_w, a.shape[2] - 1)
            src_repr = a[:, row, col]
            # Dot product in PCA space (features were normalised before the projection)
            cossim = (b * src_repr[:,None,None]).sum(dim=0)
            y_max, x_max = np.unravel_index(cossim.argmax().cpu(), cossim.shape)
            x_max_pixel = x_max * trg_w / b.shape[2]
            y_max_pixel = y_max * trg_h / b.shape[1]
            relative_distance = ((x_max_pixel - tx)**2 + (y_max_pixel - ty)**2) ** 0.5 / tbb_max
            correct.append(relative_distance < 0.1)
        # Refresh the running PCK every 100 pairs. Counting pairs (not keypoints)
        # makes the update reliable, since several keypoints are added per pair.
        if num_pairs % 100 == 0 and correct:
            t.set_postfix(pck=np.mean(correct)*100)

    return correct

# Scale every image so that its longer side becomes 768 px (integer division
# keeps the aspect ratio up to rounding) and keep 256 PCA components.
correct = sc(
    transform_img=lambda img: img.resize(np.array(img.size) * 768 // max(img.size)),
    n_components=256,
)

In [None]:
# Overall PCK@0.1_bbox in percent
100 * np.mean(correct)

## Visualise semantic correspondence

In [None]:
# plot PCK over category
def plot_pck_over_category(correct, metadata):
    """Bar-plot the PCK (in percent) per object category.

    Bars are drawn in ascending category order and each bar is labelled with
    the class name looked up for that category, so labels and bars stay aligned
    even if some category does not occur in ``metadata``.

    Args:
        correct: per-keypoint correctness values, aligned with ``metadata``.
        metadata: per-keypoint dicts with a ``'category'`` entry; assumed to be
            an integer index into ``pairs.features['category'].names``.
    """
    category_names = pairs.features['category'].names
    categories = sorted({x['category'] for x in metadata})
    bins = {category: [] for category in categories}
    for x, c in zip(metadata, correct):
        bins[x['category']].append(c)
    pcks = [np.mean(bins[category])*100 for category in categories]
    counts = [len(bins[category]) for category in categories]
    plt.figure()
    plt.bar(range(len(categories)), pcks)
    plt.xticks(range(len(categories)), [category_names[category] for category in categories], rotation=90)
    for i, (pck, count) in enumerate(zip(pcks, counts)):
        # PCK value on top of the bar, sample count near its base
        plt.text(i, pck, f'{pck:.2f}', ha='center', va='bottom', fontsize=8)
        plt.text(i+0.07, 3, f'{count}', ha='center', va='bottom', rotation=90, color='white', fontsize=8)
    plt.xlabel('Category')
    plt.ylabel('PCK@$0.1_{bbox}$')
    plt.title('PCK@$0.1_{bbox}$ over category')
    plt.show()

# Render the per-category PCK chart
plot_pck_over_category(correct=correct, metadata=metadata)

In [None]:
# plot PCK over viewpoint variation
def plot_pck_over_viewpoint_variation(correct, metadata):
    """Bar-plot the PCK (in percent) for the three viewpoint-variation levels.

    Args:
        correct: per-keypoint correctness values, aligned with ``metadata``.
        metadata: per-keypoint dicts whose ``'viewpoint_variation'`` entry is
            used as an index (0, 1 or 2) into the three bins.

    A level without any sample yields NaN (mean of an empty list).
    """
    level_names = ['low', 'medium', 'high']
    per_level = [[], [], []]
    levels = [x['viewpoint_variation'] for x in metadata]
    for level, is_correct in zip(levels, correct):
        per_level[level].append(is_correct)
    pcks = [np.mean(hits)*100 for hits in per_level]

    positions = range(3)
    plt.figure()
    plt.bar(positions, pcks)
    plt.xticks(positions, level_names)
    for pos, (pck, hits) in enumerate(zip(pcks, per_level)):
        # PCK value on top of the bar, sample count near its base
        plt.text(pos, pck, f'{pck:.2f}', ha='center', va='bottom')
        plt.text(pos, 3, f'{len(hits)}', ha='center', va='bottom', rotation=90, color='white')
    plt.xlabel('Viewpoint variation')
    plt.ylabel('PCK@$0.1_{bbox}$')
    plt.title('PCK@$0.1_{bbox}$ over viewpoint variation')
    plt.show()

# Render the PCK chart grouped by viewpoint variation
plot_pck_over_viewpoint_variation(correct=correct, metadata=metadata)

In [None]:
# plot PCK over bounding box scale variation
def plot_pck_over_scale_variation(correct, metadata):
    """Bar-plot the PCK (in percent) for the three scale-variation levels.

    Args:
        correct: per-keypoint correctness values, aligned with ``metadata``.
        metadata: per-keypoint dicts whose ``'scale_variation'`` entry is used
            as an index (0, 1 or 2) into the three bins.

    A level without any sample yields NaN (mean of an empty list).
    """
    level_names = ['low', 'medium', 'high']
    per_level = [[], [], []]
    levels = [x['scale_variation'] for x in metadata]
    for level, is_correct in zip(levels, correct):
        per_level[level].append(is_correct)
    pcks = [np.mean(hits)*100 for hits in per_level]

    positions = range(3)
    plt.figure()
    plt.bar(positions, pcks)
    plt.xticks(positions, level_names)
    for pos, (pck, hits) in enumerate(zip(pcks, per_level)):
        # PCK value on top of the bar, sample count near its base
        plt.text(pos, pck, f'{pck:.2f}', ha='center', va='bottom')
        plt.text(pos, 3, f'{len(hits)}', ha='center', va='bottom', rotation=90, color='white')
    plt.xlabel('Scale variation')
    plt.ylabel('PCK@$0.1_{bbox}$')
    plt.title('PCK@$0.1_{bbox}$ over scale variation')
    plt.show()

# Render the PCK chart grouped by scale variation
plot_pck_over_scale_variation(correct=correct, metadata=metadata)

In [None]:
# plot PCK over src and trg bounding box size
def plot_pck_over_bbox_size(correct, metadata):
    """Bar-plot the PCK (in percent) over bounding-box area, once for the
    source and once for the target boxes.

    Areas are split into 10 bins that are equally wide in sqrt(area), i.e.
    relative to the side length rather than the area. Tick labels show the bin
    limits as areas in thousands of pixels.

    Args:
        correct: per-keypoint correctness values, aligned with ``metadata``.
        metadata: per-keypoint dicts with ``'src_bndbox'`` and ``'trg_bndbox'``
            entries; the area is (box[2] - box[0]) * (box[3] - box[1]).
    """
    def box_areas(key):
        boxes = np.array([x[key] for x in metadata])
        return (boxes[:,2] - boxes[:,0]) * (boxes[:,3] - boxes[:,1])

    areas_by_name = [('src', box_areas('src_bndbox')), ('trg', box_areas('trg_bndbox'))]
    for name, sizes in areas_by_name:
        lower = sizes.min() ** .5
        upper = sizes.max() ** .5 + 1  # +1 keeps the largest box inside the last bin
        width = upper - lower

        bins = [[] for _ in range(10)]
        for size, c in zip(sizes, correct):
            bins[int((size**.5 - lower) / width * 10)].append(c)
        pcks = [np.mean(b)*100 for b in bins]

        # Bin limits in sqrt(area) units; squared again for the labels
        edges = [lower + i*width/10 for i in range(11)]
        labels = [f'{edges[i]**2/1000:.0f}k-{edges[i+1]**2/1000:.0f}k px' for i in range(10)]

        plt.figure()
        plt.bar(range(10), pcks)
        plt.xticks(range(10), labels, rotation=45)
        for i, (pck, b) in enumerate(zip(pcks, bins)):
            plt.text(i, pck, f'{pck:.2f}', ha='center', va='bottom')
            plt.text(i, 3, f'{len(b)}', ha='center', va='bottom', rotation=90, color='white')
        plt.xlabel(f'{name} bounding box size')
        plt.ylabel('PCK@$0.1_{bbox}$')
        plt.title(f'PCK@$0.1_{{bbox}}$ over {name} bounding box size')
        plt.show()

# Render the PCK charts grouped by source / target bounding-box size
plot_pck_over_bbox_size(correct=correct, metadata=metadata)

In [None]:
# plot PCK over difference in bounding box sizes
def plot_pck_over_bbox_size_diff(correct, metadata):
    """Bar-plot the PCK (in percent) over the absolute difference between the
    source and target bounding-box areas.

    Differences are split into 10 bins that are equally wide in
    sqrt(difference). Tick labels show the bin limits in thousands of pixels.

    Args:
        correct: per-keypoint correctness values, aligned with ``metadata``.
        metadata: per-keypoint dicts with ``'src_bndbox'`` and ``'trg_bndbox'``
            entries; the area is (box[2] - box[0]) * (box[3] - box[1]).
    """
    def box_areas(key):
        boxes = np.array([x[key] for x in metadata])
        return (boxes[:,2] - boxes[:,0]) * (boxes[:,3] - boxes[:,1])

    diffs = np.abs(box_areas('src_bndbox') - box_areas('trg_bndbox'))
    lower = diffs.min() ** .5
    upper = diffs.max() ** .5 + 1  # +1 keeps the largest difference inside the last bin
    width = upper - lower

    bins = [[] for _ in range(10)]
    for diff, c in zip(diffs, correct):
        bins[int((diff**.5 - lower) / width * 10)].append(c)
    pcks = [np.mean(b)*100 for b in bins]

    # Bin limits in sqrt units; squared again for the labels
    edges = [lower + i*width/10 for i in range(11)]
    labels = [f'{edges[i]**2/1000:.0f}k-{edges[i+1]**2/1000:.0f}k px' for i in range(10)]

    plt.figure()
    plt.bar(range(10), pcks)
    plt.xticks(range(10), labels, rotation=45)
    for i, (pck, b) in enumerate(zip(pcks, bins)):
        plt.text(i, pck, f'{pck:.2f}', ha='center', va='bottom')
        plt.text(i, 3, f'{len(b)}', ha='center', va='bottom', rotation=90, color='white')
    plt.xlabel('Difference in bounding box sizes')
    plt.ylabel('PCK@$0.1_{bbox}$')
    plt.title('PCK@$0.1_{bbox}$ over difference in bounding box sizes')
    plt.show()

# Render the PCK chart grouped by source/target bounding-box size difference
plot_pck_over_bbox_size_diff(correct=correct, metadata=metadata)

In [None]:
# plot PCK over bounding box aspect ratio
def plot_pck_over_bbox_aspect_ratio(correct, metadata):
    """Bar-plot the PCK (in percent) over bounding-box aspect ratio, once for
    the source and once for the target boxes.

    The ratio is (box[2] - box[0]) / (box[3] - box[1]) and is sorted into six
    bins with limits at 1:4, 1:2, 1:1, 2:1 and 4:1.

    Args:
        correct: per-keypoint correctness values, aligned with ``metadata``.
        metadata: per-keypoint dicts with ``'src_bndbox'`` and ``'trg_bndbox'``
            entries.
    """
    src_shapes = np.array([x['src_bndbox'] for x in metadata])
    src_ratios = (src_shapes[:,2] - src_shapes[:,0]) / (src_shapes[:,3] - src_shapes[:,1])
    trg_shapes = np.array([x['trg_bndbox'] for x in metadata])
    trg_ratios = (trg_shapes[:,2] - trg_shapes[:,0]) / (trg_shapes[:,3] - trg_shapes[:,1])
    bin_names = ['< 1:4', '1:4 - 1:2', '1:2 - 1:1', '1:1 - 2:1', '2:1 - 4:1', '> 4:1']
    last_bin = len(bin_names) - 1
    for name, ratios in [('src', src_ratios), ('trg', trg_ratios)]:
        bins = [[] for _ in bin_names]
        # floor(log2(ratio)) is -2 on [1/4, 1/2), -1 on [1/2, 1), 0 on [1, 2) and
        # 1 on [2, 4). An offset of 3 maps these to bins 1..4, matching bin_names;
        # everything below 1:4 or above 4:1 is clipped into the outer bins.
        bin_indices = np.clip(np.floor(np.log2(ratios)) + 3, 0, last_bin).astype(int)
        for idx, c in zip(bin_indices, correct):
            bins[idx].append(c)
        pcks = [np.mean(b)*100 for b in bins]
        plt.figure()
        plt.bar(range(len(bin_names)), pcks)
        plt.xticks(range(len(bin_names)), bin_names, rotation=45)
        for i, pck in enumerate(pcks):
            # PCK value on top of the bar, sample count near its base
            plt.text(i, pck, f'{pck:.2f}', ha='center', va='bottom')
            plt.text(i, 3, f'{len(bins[i])}', ha='center', va='bottom', rotation=90, color='white')
        plt.xlabel(f'{name} bounding box aspect ratio')
        plt.ylabel('PCK@$0.1_{bbox}$')
        plt.title(f'PCK@$0.1_{{bbox}}$ over {name} bounding box aspect ratio')
        plt.show()

# Render the PCK charts grouped by source / target bounding-box aspect ratio
plot_pck_over_bbox_aspect_ratio(correct=correct, metadata=metadata)

In [None]:
# plot PCK over src and trg keypoint positions
def plot_pck_over_kp_position(correct, metadata, scale_factor=2):
    """Show keypoint counts and PCK as heat maps over the keypoint position,
    once for the source and once for the target keypoints.

    Keypoint coordinates are integer-divided by ``scale_factor`` to form grid
    cells. Cells without any keypoint give NaN in the PCK map (0 / 0).

    Args:
        correct: per-keypoint correctness values, aligned with ``metadata``.
        metadata: per-keypoint dicts with ``'src_kp'`` and ``'trg_kp'`` entries,
            each an (x, y) couple.
        scale_factor: edge length of one grid cell in keypoint coordinates.
    """
    keypoint_sets = [
        ('src', np.array([x['src_kp'] for x in metadata])),
        ('trg', np.array([x['trg_kp'] for x in metadata])),
    ]
    for name, kps in keypoint_sets:
        xs, ys = kps[:,0], kps[:,1]
        min_x = xs.min() // scale_factor
        max_x = xs.max() // scale_factor
        min_y = ys.min() // scale_factor
        max_y = ys.max() // scale_factor

        # Last axis: [number of correct keypoints, number of keypoints]
        matrix = np.zeros((max_y - min_y + 1, max_x - min_x + 1, 2))
        for (x, y), c in zip(kps, correct):
            cell = matrix[y//scale_factor-min_y, x//scale_factor-min_x]  # view into matrix
            cell[0] += c
            cell[1] += 1
        counts = matrix[:,:,1]
        pck_matrix = matrix[:,:,0] / counts

        # plot count
        plt.figure()
        plt.imshow(counts, cmap='viridis', interpolation='nearest')
        plt.colorbar()
        plt.xlabel(f'{name} keypoint x')
        plt.ylabel(f'{name} keypoint y')
        plt.title(f'Count of keypoints over {name} keypoint position')
        plt.show()

        # plot pck matrix
        plt.figure()
        plt.imshow(pck_matrix, cmap='viridis', interpolation='nearest', vmin=0, vmax=1)
        plt.colorbar()
        plt.xlabel(f'{name} keypoint x')
        plt.ylabel(f'{name} keypoint y')
        plt.title(f'PCK@$0.1_{{bbox}}$ over {name} keypoint position')
        plt.show()

# Heat maps over keypoint position with 20-unit grid cells
plot_pck_over_kp_position(correct, metadata, scale_factor=20)

In [None]:
# plot PCK over src and trg keypoint positions taken modulo `size`
# might be useful to see if certain pixel offsets within each size x size region are ignored
def plot_pck_over_kp_position(correct, metadata, size=32):
    """Show the PCK as a heat map over the keypoint position modulo ``size``,
    once for the source and once for the target keypoints.

    NOTE(review): this reuses the name of the grid-based function defined in
    the previous cell and therefore replaces it once this cell has run.

    Offsets without any keypoint give NaN in the PCK map (0 / 0).

    Args:
        correct: per-keypoint correctness values, aligned with ``metadata``.
        metadata: per-keypoint dicts with ``'src_kp'`` and ``'trg_kp'`` entries,
            each an (x, y) couple.
        size: period of the modulo, i.e. edge length of the resulting map.
    """
    keypoint_sets = [
        ('src', np.array([x['src_kp'] for x in metadata])),
        ('trg', np.array([x['trg_kp'] for x in metadata])),
    ]
    for name, kps in keypoint_sets:
        # Last axis: [number of correct keypoints, number of keypoints]
        matrix = np.zeros((size, size, 2))
        for (x, y), c in zip(kps, correct):
            cell = matrix[y%size, x%size]  # view into matrix
            cell[0] += c
            cell[1] += 1
        pck_matrix = matrix[:,:,0] / matrix[:,:,1]

        plt.figure()
        plt.imshow(pck_matrix, cmap='viridis', interpolation='nearest', vmin=0, vmax=1)
        plt.colorbar()
        plt.xlabel(f'{name} keypoint x (mod {size})')
        plt.ylabel(f'{name} keypoint y (mod {size})')
        plt.title(f'PCK@$0.1_{{bbox}}$ over {name} keypoint position')
        plt.show()

# Heat maps over keypoint position modulo 16
plot_pck_over_kp_position(correct, metadata, size=16)