# Semantic Correspondence with position subtraction

Train a position estimator on the extracted SD representations, then optimize those representations so that the estimator can no longer recover the position from them.

## Setup

In [None]:
from sdhelper import SD
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import datasets
from tqdm.autonotebook import tqdm, trange
from matplotlib import pyplot as plt
import PIL.Image
from typing import Callable
import random
from collections import defaultdict


In [None]:
# Model wrapper (id 'SD15') whose `img2repr` extracts the representations below.
sd = SD('SD15')
# 'data' config: image entries, read below through their 'img' and 'name' fields.
data = datasets.load_dataset('0jl/SPair-71k', 'data', split='train', trust_remote_code=True)
# 'pairs' config: source/target image pairs with keypoints; each pair points into
# `data` through 'src_data_index' / 'trg_data_index'.
pairs = datasets.load_dataset('0jl/SPair-71k', 'pairs', split='test', trust_remote_code=True)

In [None]:
# Pair-level fields that are copied unchanged into every keypoint record.
metadata_keys = [
    'src_bndbox',
    'trg_bndbox',
    'category',
    'viewpoint_variation',
    'scale_variation',
]


def _keypoint_records(pair_iter):
    """Yield one flat record per (source, target) keypoint correspondence.

    Each record carries the index of its pair ('i'), both keypoints, both
    image sizes and the pair-level fields listed in ``metadata_keys``.
    """
    for pair_idx, pair in enumerate(pair_iter):
        for src_kp, trg_kp in zip(pair['src_kps'], pair['trg_kps']):
            record = {
                'i': pair_idx,
                'src_kp': src_kp,
                'trg_kp': trg_kp,
                'src_size': pair['src_img'].size,
                'trg_size': pair['trg_img'].size,
            }
            record.update({k: pair[k] for k in metadata_keys})
            yield record


# Same order as the evaluation loops: pairs in order, keypoints in order.
metadata = list(_keypoint_records(tqdm(pairs)))

## Calculate semantic correspondence

In [None]:
# Name of the model block whose output is used as representation
# (requested from `sd.img2repr` and used to index its result).
p = 'up_blocks[1]'
# Side length in pixels of the square image handed to the model.
img_size = 512

In [None]:
# precalculate representations

def expand_and_resize(x: PIL.Image.Image, size, border_pad=True):
    """Pad an image to a centered square and resize it to ``size`` x ``size``.

    The image is pasted into the middle of a square canvas whose side equals
    the longer image edge. With ``border_pad`` the empty strips on both sides
    are filled by stretching the outermost pixel row/column of the image;
    otherwise they keep the canvas default color.

    Args:
        x: Input image.
        size: Side length in pixels of the returned square image.
        border_pad: Whether to fill the padding by replicating the image border.

    Returns:
        A new ``size`` x ``size`` RGB image.
    """
    n, m = x.size  # width, height
    s = max(n, m)
    left, top = (s - n) // 2, (s - m) // 2
    r = PIL.Image.new('RGB', (s, s))
    r.paste(x, (left, top))
    if border_pad:
        # When the size difference is odd, the far-side strip is one pixel
        # larger than the near-side one; size each strip on its own so no
        # unpainted row/column is left over.
        right = s - n - left
        bottom = s - m - top
        # Empty strips are skipped: a resize to zero width/height is invalid.
        if top > 0:
            r.paste(x.crop((0, 0, n, 1)).resize((n, top)), (0, 0))
        if bottom > 0:
            r.paste(x.crop((0, m - 1, n, m)).resize((n, bottom)), (0, top + m))
        if left > 0:
            r.paste(x.crop((0, 0, 1, m)).resize((left, m)), (0, 0))
        if right > 0:
            r.paste(x.crop((n - 1, 0, n, m)).resize((right, m)), (left + n, 0))
    return r.resize((size, size))

def transform_img(img):
    """Square-pad ``img`` with border replication and scale it to ``img_size``."""
    return expand_and_resize(img, img_size, True)


def _unit_norm(t):
    """Divide ``t`` by its norm taken along dim 0."""
    # NOTE(review): this is a per-location channel normalization only if `apply`
    # hands over (C, H, W) tensors; with a leading batch dim it would normalize
    # over the batch instead - confirm against sdhelper.
    return t / torch.norm(t, dim=0, keepdim=True)


representations = []
for x in tqdm(data, desc='Calculating representations'):
    # the text before the first '/' of the entry name serves as prompt
    category = x['name'].partition('/')[0]
    r = sd.img2repr(transform_img(x['img']), [p], 100, prompt=category)
    r = r.apply(_unit_norm)
    representations.append(r)

In [None]:
# Read channel count and spatial size from the first stored representation
# (the unpacking requires a 4-dimensional tensor; the leading dim is ignored).
_, num_channels, H, W = representations[0][p].shape
# The position classifier is built with a single size for both axes,
# so the feature map has to be square.
assert H == W

class PositionClassifier(nn.Module):
    """Two independent linear heads that score the x and the y position of a feature vector."""

    def __init__(self, num_channels: int, size: int):
        """Create the two heads.

        Args:
            num_channels: Length of each input feature vector.
            size: Number of outputs (position classes) per axis.
        """
        super().__init__()
        # One linear head per axis; both read the same features.
        self.layer_x = nn.Linear(in_features=num_channels, out_features=size)
        self.layer_y = nn.Linear(in_features=num_channels, out_features=size)

    def forward(self, repr):
        """Return the pair ``(x_scores, y_scores)`` for features whose last dim is ``num_channels``."""
        return self.layer_x(repr), self.layer_y(repr)
    
# Prefer the GPU, but fall back to the CPU when none is visible so that the
# classifier can still be built and trained.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# One position class per row/column of the (square) feature map.
model = PositionClassifier(num_channels, H).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [None]:
# plot training progress
from trainplot.trainplot import TrainPlotPlotlyExperimental as TrainPlot
# Plot helper; the training loop calls it with keyword metrics
# (accuracy_x / accuracy_y) after every optimizer step.
tp = TrainPlot(threaded=True)
# Optional: switch the plot to a logarithmic y-axis.
# tp.fig.update_yaxes(type="log");

In [None]:
# train position classifier
batch_size = 32
model.train()

# Flattening an (H, W) map row-major puts column w of row h at index h * W + w,
# so these are the x / y class labels of every flattened location of one image.
x_labels = torch.arange(W, device=device).repeat(H)
y_labels = torch.arange(H, device=device).repeat_interleave(W)

for epoch in trange(20):
    # new random order of the images for every epoch
    indices = np.random.permutation(len(representations))
    for i in range(0, len(representations), batch_size):
        batch = [representations[j] for j in indices[i:i+batch_size]]
        # each [1, C, H, W] tensor becomes [H*W, C]; stacked: [len(batch), H*W, C]
        features = torch.stack([x[p].squeeze(0).flatten(1,2).T for x in batch]).to(device=device, dtype=torch.float32)
        # same labels for every image of the batch (the last batch may be smaller)
        y_x = x_labels.repeat(len(batch))
        y_y = y_labels.repeat(len(batch))
        pred_x, pred_y = model(features)
        # independent classification losses for the two axes
        loss = F.cross_entropy(pred_x.flatten(0,1), y_x) + F.cross_entropy(pred_y.flatten(0,1), y_y)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # log the per-axis accuracy on the current training batch
        with torch.no_grad():
            tp(
                accuracy_x = (pred_x.argmax(dim=2).flatten() == y_x).float().mean().item(),
                accuracy_y = (pred_y.argmax(dim=2).flatten() == y_y).float().mean().item(),
            )

In [None]:
# remove positional information from representations
# using dynamic learning rate, which increases if convergence is slow
representations_unpositioned = []
model.eval()
for r in tqdm(representations):
    # [1, C, H, W] -> [1, H*W, C]: one feature vector per spatial location
    r = r[p].squeeze(0).flatten(1,2).T.unsqueeze(0).to(device=device, dtype=torch.float32)
    prev_loss = 1
    lr = 1.
    for i in range(50):
        # start every step from a fresh leaf so only this step is differentiated
        r = r.detach().requires_grad_()
        x, y = model(r)
        # push both predicted position distributions towards the uniform distribution
        position_loss = F.mse_loss(F.softmax(x, dim=-1), torch.full_like(x, 1/W)) + F.mse_loss(F.softmax(y, dim=-1), torch.full_like(y, 1/H))
        # NOTE(review): this term measures the magnitude of r, not its distance to
        # the original representation; it is currently switched off by the 0.0 weight.
        nochange_loss = F.mse_loss(r, torch.zeros_like(r))
        loss = position_loss + 0.0 * nochange_loss
        if prev_loss / loss < 1.05:  # if loss decreases slowly, increase lr
            lr *= 2
        if prev_loss < loss:  # if loss increases, reset lr
            lr = 1.
        prev_loss = loss.detach()
        # Differentiate with respect to the representation only; backward() would
        # also accumulate gradients into the classifier parameters on every step.
        (grad,) = torch.autograd.grad(loss, r)
        r = r - lr * grad
        # print(loss.item(), lr)
    # back to [C, H, W], stored on the CPU in half precision
    r = r.detach().to(device='cpu', dtype=torch.float16).squeeze(0).T.reshape(num_channels, H, W)
    representations_unpositioned.append(r)

In [None]:
# Stack the per-image [C, H, W] tensors into one [N, C, H, W] tensor, then divide
# every spatial location by its Euclidean norm along dim 1 (the channel axis).
representations_unpositioned = torch.stack(representations_unpositioned)
representations_unpositioned = representations_unpositioned / torch.sqrt(
    torch.sum(representations_unpositioned ** 2, dim=1, keepdim=True)
)

In [None]:
# calculate percentage of correct keypoints at 10% of the bounding box (PCK@0.1_bbox)
correct = []    # one flag per keypoint: prediction within 0.1 * bbox size of the target
positions = []  # predicted (x, y) per keypoint, in target-image pixel coordinates
t = tqdm(pairs, desc='Calculating SC')
for pair_idx, pair in enumerate(t, start=1):
    a = representations_unpositioned[pair['src_data_index']]
    b = representations_unpositioned[pair['trg_data_index']]
    # per-pair constants, looked up once instead of once per keypoint
    src_w, src_h = pair['src_img'].size
    trg_w, trg_h = pair['trg_img'].size
    src_side, trg_side = max(src_w, src_h), max(trg_w, trg_h)
    bbox = pair['trg_bndbox']
    tbb_max = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
    for (sx, sy), (tx, ty) in zip(pair['src_kps'], pair['trg_kps']):
        # source keypoint -> feature cell, accounting for the centered square padding
        src_repr = a[:,
            int((sy + (src_side - src_h)/2) * a.shape[1] / src_side),
            int((sx + (src_side - src_w)/2) * a.shape[2] / src_side),
        ]
        # the representations were renormalized along the channel axis, so the
        # channel-wise dot product is the cosine similarity
        cossim = (b * src_repr[:,None,None]).sum(dim=0)
        y_max, x_max = np.unravel_index(cossim.argmax().cpu(), cossim.shape)
        # best cell -> target pixel coordinates, undoing the padding offset
        # NOTE(review): this is the top-left corner of the cell, not its center
        # (no +0.5 offset) - confirm that this is intended.
        x_max_pixel = x_max / b.shape[2] * trg_side - (trg_side - trg_w) / 2
        y_max_pixel = y_max / b.shape[1] * trg_side - (trg_side - trg_h) / 2
        relative_distance = ((x_max_pixel - tx)**2 + (y_max_pixel - ty)**2) ** 0.5 / tbb_max
        correct.append(relative_distance < 0.1)
        positions.append((x_max_pixel, y_max_pixel))
    # Refresh the running PCK every 100 pairs. (Testing the keypoint total for a
    # multiple of 100 only fires when the count lands exactly on one.)
    if correct and pair_idx % 100 == 0:
        t.set_postfix(pck=np.mean(correct)*100)

In [None]:
# overall PCK@0.1_bbox in percent for the position-removed representations
np.mean(correct)*100

In [None]:
# reference

# calculate percentage of correct keypoints at 10% of the bounding box (PCK@0.1_bbox)
correct_reference = []    # one flag per keypoint: prediction within 0.1 * bbox size of the target
positions_reference = []  # predicted (x, y) per keypoint, in target-image pixel coordinates
t = tqdm(pairs, desc='Calculating SC')
for pair_idx, pair in enumerate(t, start=1):
    # unmodified representations, as extracted from the model
    a = representations[pair['src_data_index']].concat()
    b = representations[pair['trg_data_index']].concat()
    # per-pair constants, looked up once instead of once per keypoint
    src_w, src_h = pair['src_img'].size
    trg_w, trg_h = pair['trg_img'].size
    src_side, trg_side = max(src_w, src_h), max(trg_w, trg_h)
    bbox = pair['trg_bndbox']
    tbb_max = max(bbox[2] - bbox[0], bbox[3] - bbox[1])
    for (sx, sy), (tx, ty) in zip(pair['src_kps'], pair['trg_kps']):
        # source keypoint -> feature cell, accounting for the centered square padding
        src_repr = a[:,
            int((sy + (src_side - src_h)/2) * a.shape[1] / src_side),
            int((sx + (src_side - src_w)/2) * a.shape[2] / src_side),
        ]
        # channel-wise dot product; equals the cosine similarity provided the
        # stored representations are unit length along the channel axis
        cossim = (b * src_repr[:,None,None]).sum(dim=0)
        y_max, x_max = np.unravel_index(cossim.argmax().cpu(), cossim.shape)
        # best cell -> target pixel coordinates, undoing the padding offset
        # NOTE(review): this is the top-left corner of the cell, not its center
        # (no +0.5 offset) - confirm that this is intended.
        x_max_pixel = x_max / b.shape[2] * trg_side - (trg_side - trg_w) / 2
        y_max_pixel = y_max / b.shape[1] * trg_side - (trg_side - trg_h) / 2
        relative_distance = ((x_max_pixel - tx)**2 + (y_max_pixel - ty)**2) ** 0.5 / tbb_max
        correct_reference.append(relative_distance < 0.1)
        positions_reference.append((x_max_pixel, y_max_pixel))
    # Refresh the running PCK every 100 pairs. (Testing the keypoint total for a
    # multiple of 100 only fires when the count lands exactly on one.)
    if correct_reference and pair_idx % 100 == 0:
        t.set_postfix(pck=np.mean(correct_reference)*100)

In [None]:
# overall PCK@0.1_bbox in percent for the unmodified (reference) representations
np.mean(correct_reference)*100

## Evaluate

In [None]:
def plot_failures(f: Callable[['Any'], tuple[int,int]], only_fails=False):
    """Scatter how densely the predicted positions lie around a reference point.

    Reads the notebook-level ``correct``, ``metadata`` and ``positions`` lists.
    Predictions are binned by squared distance to the reference point in steps
    of 25, i.e. into rings of equal area (25 * pi pixels), and each ring is
    drawn at its outer radius.

    Args:
        f: Maps a metadata record to the reference point ``(x, y)``.
        only_fails: If true, predictions counted as correct are left out.
    """
    ring_counts = defaultdict(int)
    for is_correct, record, (pred_x, pred_y) in zip(correct, metadata, positions):
        if not (only_fails and is_correct):
            ref_x, ref_y = f(record)
            squared = np.hypot(ref_x - pred_x, ref_y - pred_y) ** 2
            ring_counts[squared // 25] += 1
    total = sum(ring_counts.values())
    rings = sorted(ring_counts.items())
    # outer radius of every ring
    radii = [((ring + 1) * 25) ** .5 for ring, _ in rings]
    # share of the counted keypoints per pixel of ring area
    density = [hits / np.pi / 25 / total for _, hits in rings]
    plt.scatter(radii, density, s=2, alpha=0.5)
    plt.xlabel('Distance to source keypoint (pixels)')
    plt.ylabel('Normalized number of failed keypoints per pixel')

# Fix the x-range first; both scatter series below are drawn into the same axes.
plt.xlim(0, 200)
# Series 1: only the incorrect predictions, measured against the source keypoint.
plot_failures(lambda x: (x['src_kp'][0], x['src_kp'][1]), only_fails=True)
# Series 2: all predictions, for comparison.
plot_failures(lambda x: (x['src_kp'][0], x['src_kp'][1]))
# Alternative reference point: a random position instead of the source keypoint.
# plot_failures(lambda x: (random.randint(100,400), random.randint(100,300)))
# Legend entries follow the order of the two calls above.
plt.legend(['Failed KP', 'All KP'])
plt.title('Distance of failed keypoint guess to source keypoint')
plt.show()

In [None]:
# failure distance
def plot_failures(f: Callable[['Any', float, float], float], only_fails=False):
    """Scatter how the predicted positions are distributed over a scalar distance.

    Reads the notebook-level ``correct``, ``metadata`` and ``positions`` lists.
    This definition replaces the earlier ``plot_failures``: here ``f`` receives
    the prediction as well and returns a single distance.

    Args:
        f: Called as ``f(record, x, y)`` with a metadata record and the
            predicted position; returns the distance to bin.
        only_fails: If true, predictions counted as correct are left out.
    """
    bin_counts = defaultdict(int)
    for is_correct, record, (pred_x, pred_y) in zip(correct, metadata, positions):
        if only_fails and is_correct:
            continue
        # bins of width 1, keyed by the floored distance
        bin_counts[f(record, pred_x, pred_y)//1] += 1
    total = sum(bin_counts.values())
    bins = sorted(bin_counts.items())
    # every bin is drawn at its upper edge
    upper_edges = [lower + 1 for lower, _ in bins]
    # an absolute distance d is reached on two sides (+d and -d), hence the /2
    share = [hits/2 / total for _, hits in bins]
    plt.scatter(upper_edges, np.array(share), s=3, alpha=0.5)
    plt.xlabel('Distance to source keypoint (pixels)')
    plt.ylabel('Normalized number of failed keypoints per row/column')

# x: horizontal distance between the prediction and the source keypoint
# Fix the x-range first; both scatter series are drawn into the same axes.
plt.xlim(0, 200)
plot_failures(lambda tmp, x, y: abs(tmp['src_kp'][0] - x), only_fails=True)
plot_failures(lambda tmp, x, y: abs(tmp['src_kp'][0] - x))
# Legend entries follow the order of the two calls above.
plt.legend(['Failed KP', 'All KP'])
plt.title('X-Distance of failed keypoint guess to source keypoint')
plt.show()

# y: vertical distance between the prediction and the source keypoint
plt.xlim(0, 200)
plot_failures(lambda tmp, x, y: abs(tmp['src_kp'][1] - y), only_fails=True)
plot_failures(lambda tmp, x, y: abs(tmp['src_kp'][1] - y))
plt.legend(['Failed KP', 'All KP'])
plt.title('Y-Distance of failed keypoint guess to source keypoint')
plt.show()

In [None]:
# PCK over keypoint distance
n = 50  # number of distance bins
distances = {i: [] for i in range(n)}  # bin -> correctness flags of its keypoints
for c, m in zip(correct, metadata):
    # offset between the source and the target keypoint coordinates, relative
    # to the longest image edge of the pair
    max_size = max(max(m['src_size']), max(m['trg_size']))
    dist = np.hypot((m['src_kp'][0] - m['trg_kp'][0]) / max_size, (m['src_kp'][1] - m['trg_kp'][1]) / max_size)
    # everything at or beyond a distance of 1 lands in the last bin
    key = min(int(dist*n), n-1)
    distances[key].append(c)
counts = {key: len(flags) for key, flags in distances.items()}
# An empty bin has no defined PCK: mark it as NaN (a gap in the line) instead
# of taking the mean of an empty list, which only produces a warning.
pck = [np.mean(flags)*100 if flags else float('nan') for flags in distances.values()]
tick_step = max(1, n//10)  # keeps the range step valid for n < 10

plt.bar(list(counts.keys()), list(counts.values()), color='lightblue')
plt.ylabel('Number of keypoints', color='lightblue')
plt.xlabel('Distance between keypoints (normalized)')
plt.xticks(range(0,n,tick_step), [f'{i/n:.1f}' for i in range(0,n,tick_step)])
# second y-axis for the PCK curve on top of the histogram
plt.twinx()
plt.plot(list(distances.keys()), pck, 'red')
plt.ylabel('PCK@$0.1_{bbox}$ (%)', color='red')
plt.show()

### visualize random failures

In [None]:
# visualize random failures
# Draw uniformly from the failed keypoints. (random.randint includes its upper
# bound, so sampling indices with randint(0, len(metadata)) can run one past
# the end, and rejection sampling never terminates when nothing failed.)
failed = [k for k, ok in enumerate(correct) if not ok]
if not failed:
    raise ValueError('no failed keypoints to visualize')
i = random.choice(failed)
kp = metadata[i]
pair = pairs[kp['i']]  # look the pair up once; both images come from this entry
pred_x, pred_y = positions[i]
bbox = kp['trg_bndbox']
plt.figure(figsize=(12, 6))
# left: source image with the source keypoint
plt.subplot(1, 2, 1)
plt.imshow(pair['src_img'])
plt.scatter(kp['src_kp'][0], kp['src_kp'][1], c='g', marker='x', s=200)
# right: target image with the target keypoint (green) and the prediction (red)
plt.subplot(1, 2, 2)
plt.imshow(pair['trg_img'])
plt.scatter(kp['trg_kp'][0], kp['trg_kp'][1], c='g', marker='x', s=200)
plt.scatter(pred_x, pred_y, c='r', marker='x', s=200)
# error distance relative to the longer side of the target bounding box;
# this is the number shown in the title
relative_distance = ((pred_x - kp['trg_kp'][0])**2 + (pred_y - kp['trg_kp'][1])**2) ** 0.5 / max(bbox[2] - bbox[0], bbox[3] - bbox[1])
plt.title(f'PCK@0.1_bbox: {relative_distance:.2f}')
plt.show()