# Semantic Correspondence

In [None]:
from sdhelper import SD
import torch
import numpy as np
from tqdm.autonotebook import tqdm, trange
from matplotlib import pyplot as plt
import PIL.Image
import random
from PIL import Image, ImageDraw

# Generate Images

In [None]:
def create_circular_cutout_with_gradient(image):
    """Attach a circular, radially fading alpha channel to `image`.

    The alpha channel is written into `image` itself through ``putalpha``
    (the argument is modified in place) and that same object is returned.
    Nothing is loaded from or saved to disk here.
    """
    width, height = image.size
    center_x, center_y = width // 2, height // 2
    radius = min(width, height) // 2

    # Hard-edged disc: 255 inside the centered circle, 0 everywhere else.
    disc = Image.new("L", (width, height), 0)
    ImageDraw.Draw(disc).ellipse(
        (center_x - radius, center_y - radius, center_x + radius, center_y + radius),
        fill=255,
    )

    # Radial falloff: largest at the center, reaching 0 at the circle's edge.
    # Pixels at or beyond the radius keep the initial value 0.
    falloff = Image.new("L", (width, height), 0)
    for row in range(height):
        for col in range(width):
            dist = ((col - center_x)**2 + (row - center_y)**2)**0.5
            if dist >= radius:
                continue
            level = int(255 * radius/200 * (1-dist/radius))
            falloff.putpixel((col, row), max(level, 0))

    # Blend disc and falloff, using the falloff itself as the blend weight.
    alpha_channel = Image.composite(disc, falloff, falloff)

    image.putalpha(alpha_channel)
    return image


# Disabled one-off generation of example images (never runs as written).
if False:
    sd = SD('sdxl-lightning-4step')
    cat_image = sd('a photo of a cat sitting on a mat').result_image
    background = sd('a blurred photo of a room').result_image

    # Example usage: cut the cat out and paste it onto the background,
    # using the cutout's own alpha channel as the paste mask.
    cat_image_crop = create_circular_cutout_with_gradient(cat_image)
    cat_small = cat_image_crop.resize((512,512))
    new_img = background.copy().convert("RGBA")
    new_img.paste(cat_small, (0, 0), cat_small)
    # new_img.save('cat_on_bg.png')

# Evaluate Images

In [None]:
# Model wrapper with default settings; sc() below calls sd.img2repr on it.
sd = SD()

In [None]:
# Load the foreground cutouts and the shared background from disk.
# All images are assumed to be 1024x1024.
foregrounds = [PIL.Image.open(path) for path in ("cat.png", "cat2.png")]
background = PIL.Image.open("cat_bg.png")

In [None]:
# Hand-annotated (x, y) pixel positions, one dict per entry of `foregrounds`
# (same order). Both dicts use the same keypoint names.
keypoints = [
    {
        'nose': (516, 430),
        'mouth': (519, 480),
        'left_eye': (445, 340),
        'right_eye': (590, 340),
        'left_ear': (335, 142),
        'right_ear': (690, 130),
        'left_front_paw': (470, 850),
        'right_front_paw': (630, 850),
    },
    {
        'nose': (532, 342),
        'mouth': (528, 395),
        'left_eye': (475, 285),
        'right_eye': (591, 292),
        'left_ear': (410, 130),
        'right_ear': (680, 150),
        'left_front_paw': (410, 635),
        'right_front_paw': (550, 770),
    },
]

# Preview: one panel per foreground with its keypoints drawn on top.
n_panels = len(foregrounds)
plt.figure(figsize=(5*n_panels, 5))
for i, (foreground, kps) in enumerate(zip(foregrounds, keypoints), 1):
    plt.subplot(1, n_panels, i)
    plt.imshow(foreground)
    for name, (x, y) in kps.items():
        plt.scatter(x, y, label=name)
# plt.legend()
plt.show()

In [None]:
# define image/kp transformation functions

def place_fg_on_bg(fg: PIL.Image.Image, bg: PIL.Image.Image, x: int, y: int, rotation: float, scale: float) -> PIL.Image.Image:
    """Return a copy of `bg` with `fg` rotated, scaled and pasted centered at (x, y).

    `fg` is rotated by `rotation` degrees with an expanded canvas, then resized
    by `scale`, and finally pasted using itself as the paste mask. Neither
    input image is modified.
    """
    rotated = fg.rotate(rotation, expand=True)
    target_size = (int(rotated.width * scale), int(rotated.height * scale))
    sprite = rotated.resize(target_size)

    # Paste position is the top-left corner, so shift by half the sprite size.
    top_left = (x - sprite.width // 2, y - sprite.height // 2)
    canvas = bg.copy()
    canvas.paste(sprite, top_left, sprite)
    return canvas

def transform_keypoints(keypoints: dict[str, tuple[int, int]], x: int, y: int, rotation: float, scale: float, center: tuple[float, float] = (512, 512)) -> dict[str, tuple[float, float]]:
    """Map keypoints through the same rotate/scale/translate as `place_fg_on_bg`.

    Args:
        keypoints: name -> (x, y) position in the untransformed foreground image.
        x, y: position in the output image where the foreground center ends up.
        rotation: rotation in degrees, same sign convention as the value
            passed to `fg.rotate(...)` in `place_fg_on_bg`.
        scale: scale factor applied after the rotation.
        center: pivot of the rotation in foreground coordinates. Defaults to
            (512, 512), i.e. the center of a 1024x1024 foreground.

    Returns:
        name -> (x, y) position in the output image, as floats.
    """
    # The angle is negated relative to the value given to the image rotation.
    # NOTE(review): this matches a counter-clockwise image rotation in a
    # coordinate system whose y axis points down -- confirm against Pillow.
    angle = -np.deg2rad(rotation)
    cos_a, sin_a = np.cos(angle), np.sin(angle)
    cx, cy = center

    result = {}
    for name, (kx, ky) in keypoints.items():
        # Rotate around the pivot, then scale, then move to the target position.
        dx, dy = kx - cx, ky - cy
        rx = dx * cos_a - dy * sin_a
        ry = dx * sin_a + dy * cos_a
        result[name] = (rx * scale + x, ry * scale + y)
    return result

# Visual sanity check: the transformed keypoints should land on the pasted cat.
i = 0
placement = (432, 600, 420, 0.7)  # x, y, rotation, scale
test_img = place_fg_on_bg(foregrounds[i], background, *placement)
test_kp = transform_keypoints(keypoints[i], *placement)

plt.imshow(test_img)
for name, (x, y) in test_kp.items():
    plt.scatter(x, y, label=name)
# plt.legend()
plt.show()

In [None]:
# semantic correspondence

def sc(x_a: int, y_a: int, rot_a: float, scale_a: float, x_b: int, y_b: int, rot_b: float, scale_b: float, plot=False):
    """Predict the keypoints of image B from image A via representation similarity.

    Image A is `foregrounds[0]` and image B is `foregrounds[1]`, each placed on
    `background` with its own (x, y, rotation, scale). For every keypoint of A,
    the representation vector at that location is compared (cosine similarity)
    with every location of B's representation; the best match is the prediction.

    Args:
        x_a, y_a, rot_a, scale_a: placement of the source foreground.
        x_b, y_b, rot_b, scale_b: placement of the target foreground.
        plot: if true, print error statistics and show both images with the
            source keypoints (left) and predicted keypoints (right).

    Returns:
        (kp_pred, errors): `kp_pred` maps keypoint name -> predicted (x, y) in
        1024-pixel image coordinates (plain floats); `errors` is an array with
        one Euclidean distance per keypoint, as a fraction of 1024.
    """
    img_size = 512
    args_a = (x_a, y_a, rot_a, scale_a)
    args_b = (x_b, y_b, rot_b, scale_b)
    img_a = place_fg_on_bg(foregrounds[0], background, *args_a)
    img_b = place_fg_on_bg(foregrounds[1], background, *args_b)
    repr_a = sd.img2repr(img_a.resize((img_size,img_size)), ['up_blocks[1]'], step=100).concat()
    repr_b = sd.img2repr(img_b.resize((img_size,img_size)), ['up_blocks[1]'], step=100).concat()
    # Representations are indexed [channel, row (y), column (x)].
    rows_a, cols_a = repr_a.shape[1], repr_a.shape[2]
    kp_a = transform_keypoints(keypoints[0], *args_a)
    kp_b = transform_keypoints(keypoints[1], *args_b)

    kp_pred = {}
    for name, (x, y) in kp_a.items():
        # Clamp to the grid: a negative index would otherwise silently wrap
        # around to the opposite edge, and a too-large one would raise.
        row = min(max(int(y * rows_a // 1024), 0), rows_a - 1)
        col = min(max(int(x * cols_a // 1024), 0), cols_a - 1)
        query = repr_a[:, row, col, None, None]
        similarities = torch.cosine_similarity(query, repr_b, dim=0)
        # .item() yields a plain int regardless of the tensor's device, so the
        # predictions can be fed to numpy and matplotlib directly.
        best = torch.argmax(similarities).item()
        best_row, best_col = divmod(best, similarities.shape[-1])
        # +0.5 places the prediction at the center of the matched cell.
        pred_x = (best_col + 0.5) / similarities.shape[-1] * 1024
        pred_y = (best_row + 0.5) / similarities.shape[-2] * 1024
        kp_pred[name] = (pred_x, pred_y)

    # Pair prediction and ground truth by keypoint name, not by dict order.
    predicted = np.array([kp_pred[name] for name in kp_pred])
    expected = np.array([kp_b[name] for name in kp_pred])
    differences = (predicted - expected) / 1024
    errors = np.linalg.norm(differences, axis=1)

    if plot:
        print(f"Max    relative error: {np.max(errors):6.2%} of img, or {np.max(errors)/scale_b:6.2%} of trg obj")
        print(f"Mean   relative error: {np.mean(errors):6.2%} of img, or {np.mean(errors)/scale_b:6.2%} of trg obj")
        print(f"Median relative error: {np.median(errors):6.2%} of img, or {np.median(errors)/scale_b:6.2%} of trg obj")
        print(f'Min    relative error: {np.min(errors):6.2%} of img, or {np.min(errors)/scale_b:6.2%} of trg obj')

        plt.figure(figsize=(9, 3))
        plt.subplot(1, 2, 1)
        plt.imshow(img_a)
        for name, (x, y) in kp_a.items():
            plt.scatter(x, y, label=name)
        plt.axis('off')
        plt.subplot(1, 2, 2)
        plt.imshow(img_b)
        for name, (x, y) in kp_pred.items():
            plt.scatter(x, y, label=name)
        plt.axis('off')
        plt.show()

    return kp_pred, errors

# Two example runs with plotting enabled: identical placements for both
# foregrounds, then differing position, rotation and scale.
sc(432, 600, 0, 0.8, 432, 600, 0, 0.8, plot=True);
sc(500, 500, -20, 0.7, 400, 555, 100, 0.6, plot=True);

# SC over positions

In [None]:
# Sweep the source foreground over an n x n grid of positions while the target
# stays centered at (512, 512), and record error statistics per grid cell.
scale = 0.25
n = 16 - 4  # number of representations minus start+end
samples = 4

start = scale * 512
end = 1024 - scale * 512
positions = np.linspace(start, end, n)  # shared by both axes
matrix_max = np.zeros((n, n))
matrix_median = np.zeros((n, n))
matrix_mean = np.zeros((n, n))
matrix_min = np.zeros((n, n))

# Matrices are indexed [x index, y index]; each cell holds the average of the
# statistic over `samples` repeated calls.
with tqdm(total=n*n) as tq:
    for i, x in enumerate(positions):
        for j, y in enumerate(positions):
            for _ in range(samples):
                kp_pred, errors = sc(int(x), int(y), 0, scale, 512, 512, 0, scale)
                matrix_max[i, j] += np.max(errors) / samples
                matrix_median[i, j] += np.median(errors) / samples
                matrix_mean[i, j] += np.mean(errors) / samples
                matrix_min[i, j] += np.min(errors) / samples
            tq.update(1)

for name, matrix in zip(['max', 'median', 'mean', 'min'], [matrix_max, matrix_median, matrix_mean, matrix_min]):
    # imshow draws the first array axis vertically, so transpose the [x, y]
    # matrix to get x running horizontally and y vertically, like the image.
    plt.imshow(matrix.T, interpolation='nearest', vmin=0)
    plt.colorbar()
    plt.xlabel('source x position [grid index]')
    plt.ylabel('source y position [grid index]')
    plt.title(f'{name} relative error')
    plt.show()

In [None]:
# plot failure cases: re-run sc with plotting at the grid cell where each
# statistic is largest (duplicates collapse because a set is used)
grid = np.linspace(start, end, n)
failure_positions = {
    np.unravel_index(m.argmax(), m.shape)
    for m in (matrix_max, matrix_median, matrix_mean, matrix_min)
}
for x, y in failure_positions:
    sc(int(grid[x]), int(grid[y]), 0, scale, 512, 512, 0, scale, plot=True)

# Statistics for sc with random transformations

In [None]:
# calculate sc errors for random transformations
def _random_placement(fg_scale):
    """Draw a random (x, y, rotation, scale) placement for a foreground.

    Positions are integers kept at least int(fg_scale*512) away from the
    0 and 1024 borders; the rotation is uniform in [-180, 180].
    """
    margin = int(fg_scale*512)
    pos_x = random.randint(margin, 1024-margin)
    pos_y = random.randint(margin, 1024-margin)
    angle = random.uniform(-180, 180)
    return pos_x, pos_y, angle, fg_scale


results = []
for _ in trange(500):
    # Scales are fixed at 0.5 for now; random.uniform(0.1, 1) is the alternative.
    source = _random_placement(0.5)
    target = _random_placement(0.5)
    args = (*source, *target)
    # Each entry is (args, kp_pred, errors).
    results.append((args, *sc(*args)))



In [None]:
# calculate extra attributes
# Each result's argument tuple is rebuilt as the eight placement values plus
# three derived ones. Only the first eight entries are read, so running this
# cell again simply recomputes the derived values.
for idx, (args, kp_pred, errors) in enumerate(results):
    x_a, y_a, rot_a, scale_a, x_b, y_b, rot_b, scale_b = args[:8]
    offset = np.array([x_a, y_a]) - np.array([x_b, y_b])
    position_difference = np.linalg.norm(offset)
    # Smallest angle between the two rotations, in [0, 180].
    wrapped = abs(rot_a - rot_b) % 360
    rotation_difference = min(wrapped, 360 - wrapped)
    fg_scale_factor = np.log2(scale_b / scale_a)
    extended = (*args[:8], position_difference, rotation_difference, fg_scale_factor)
    results[idx] = (extended, kp_pred, errors)

In [None]:
# plot calculated errors
# One bar chart per attribute: results are grouped into equal-width bins of
# the attribute value, and per bin the mean of the per-sample max, overall
# mean, and mean of the per-sample min keypoint error are drawn.
num_bins = 10
attribute_names = (
    [f'{a} {b}' for a in ['source','target'] for b in ['x position [center,px]', 'y position [center,px]', 'rotation [°]', 'foreground object scale']]
    + ['position difference [px]', 'rotation difference [°]', 'foreground scale factor [log2]']
)
for col, name in enumerate(attribute_names):
    values = [a[col] for a, _, _ in results]
    # skip if all values are the same (or there is no data at all)
    if len(set(values)) <= 1:
        continue
    # create bins; lo/hi are used instead of min/max so the builtins stay usable
    lo = np.min(values)
    hi = np.max(values)
    bins = [[] for _ in range(num_bins)]
    for value, (_, _, err) in zip(values, results):
        # the small epsilon keeps the largest value inside the last bin
        bins[int((value - lo) / (hi - lo + 1e-6) * num_bins)].append(err)
    # plot; empty bins are drawn with height 0 instead of raising
    plt.bar(range(num_bins), [np.max(b, axis=1).mean() if b else 0.0 for b in bins])
    plt.bar(range(num_bins), [np.mean(b) if b else 0.0 for b in bins])
    plt.bar(range(num_bins), [np.min(b, axis=1).mean() if b else 0.0 for b in bins])
    # annotate each bar with the number of samples in its bin
    for bin_idx, b in enumerate(bins):
        plt.text(bin_idx, 0, f'{len(b)}', ha='center', va='bottom', color='white')
    plt.xlabel(name)
    plt.ylabel('Mean error distance [fraction of img size]')
    # tick labels show the lower edge of each bin
    plt.xticks(range(num_bins), [f'{lo + k*(hi-lo)/10:.{2 if "scale" in name else 0}f}' for k in range(num_bins)], rotation=45)
    plt.legend(['Max KP', 'Mean KP', 'Min KP'])
    plt.title(f'Error over {name.split("[")[0].strip()}')
    plt.show()