# Evaluate semantic/dense correspondence performance under horizontal flips

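WIP / open question: is the position-dependent normalization factor used for the error visualization (see `plot_error_visualization` and `plot_normalizer` below) a good choice?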

In [None]:
from pathlib import Path
from sdhelper import SD
import torch
from tqdm.autonotebook import tqdm, trange
import PIL.Image
import PIL.ImageOps
import matplotlib.pyplot as plt
import numpy as np


In [None]:
img_path = Path('../../random_images_flux')

# Sort the paths: glob order is filesystem-dependent, and the image ids
# printed further down are plain indices into this list.
images = []
for path in sorted(img_path.glob('*.jpg')):
    # PIL opens files lazily and keeps the handle open until the pixel data
    # is read. Load inside the context manager so the handle is released
    # immediately; the decoded image stays usable after the `with` block.
    with PIL.Image.open(path) as img:
        img.load()
    images.append(img)
if not images:
    # Fail here with a clear message instead of an IndexError on reprs[0] later.
    raise FileNotFoundError(f'no *.jpg images found in {img_path.resolve()}')

sd = SD('sd15', disable_progress_bar=True)
# Extract representations at every available position, once for each image
# and once for its horizontally mirrored copy (same order as `images`).
reprs = [sd.img2repr(img, extract_positions=sd.available_extract_positions, step=50) for img in tqdm(images, desc='extracting representations')]
reprs_flipped = [sd.img2repr(PIL.ImageOps.mirror(img), extract_positions=sd.available_extract_positions, step=50) for img in tqdm(images, desc='extracting representations')]
del sd  # the model object is not used after extraction

In [None]:
# Calculate error as in show_flip_sc_fails.ipynb and average it over each image
#
# For every extraction position and every image, each grid cell is matched to
# its most similar cell in the other (mirrored) representation. The matched
# column is mirrored back, and the Euclidean distance between the cell and
# its match is averaged over the grid after dividing by the grid diagonal.
# The result is stored in errors[image_index, position_index].
# NOTE(review): assumes square feature maps (n is taken from the last axis
# only) - confirm for all extraction positions.

num_positions = len(reprs[0].data)
fig = plt.figure(figsize=(12, 3*num_positions))

gs = fig.add_gridspec(num_positions, 2, width_ratios=[2, 1], hspace=0.4, wspace=0.3)
errors = np.zeros((len(images), num_positions))

for idx, p in enumerate(tqdm(reprs[0].data, desc='calculating errors')):
    n = reprs[0].data[p].shape[-1]

    # Both the index grid and the largest possible distance depend only on n,
    # so compute them once per position instead of once per image.
    k, l = torch.meshgrid(torch.arange(n), torch.arange(n), indexing='ij')
    max_possible_dist = ((n-1)**2 + (n-1)**2)**0.5

    # `rep` rather than `repr`: a top-level loop variable named `repr` would
    # replace the builtin for every later cell of the notebook.
    for i, (rep, rep_flipped) in enumerate(zip(reprs, reprs_flipped)):
        # presumably pairwise similarities between all spatial locations of
        # the two representations - TODO confirm against sdhelper
        similarities = rep.at(p).cosine_similarity(rep_flipped.at(p))

        # Flatten the leading axes, then take the best match per grid cell;
        # the flat index is decomposed back into (row, column).
        indices = similarities.view(-1, n, n).argmax(dim=0)
        k_, l_ = indices // n, indices % n
        l_ = n - 1 - l_  # flip

        dist = ((k - k_)**2 + (l - l_)**2).sqrt()
        errors[i, idx] = (dist / max_possible_dist).mean()

    # Error distribution plot
    ax_dist = fig.add_subplot(gs[idx, 0])
    ax_dist.hist(errors[:, idx], bins=20, color='skyblue', edgecolor='black')
    ax_dist.set_title(f'Error Distribution for {p}')
    ax_dist.set_xlabel('Error')
    ax_dist.set_ylabel('Frequency')

    # Violin plot
    ax_violin = fig.add_subplot(gs[idx, 1])
    ax_violin.violinplot(errors[:, idx], showmeans=True, showextrema=True, showmedians=True)
    ax_violin.set_title(f'Error Violin Plot for {p}')
    ax_violin.set_ylabel('Error')

plt.tight_layout()
plt.show()

In [None]:
def plot_error_visualization(image, repr, repr_flipped):
    """Show an image next to a heat map of its normalized flip-correspondence error.

    Each grid cell is matched to its most similar cell in the other
    representation, the matched column is mirrored back, and the Euclidean
    distance between the cell and its match is divided by the mean distance
    from that cell to all cells of the grid. With this normalizer the value
    does not scale with the grid resolution.

    Args:
        image: Image shown in the left panel (passed to ``Axes.imshow``).
        repr: Representation of the original image; must provide
            ``cosine_similarity``.
        repr_flipped: Representation of the mirrored image.

    NOTE(review): assumes the similarity tensor describes a square n x n
    grid, with n read from its last axis - confirm against sdhelper.
    """
    # Calculate similarities and indices
    similarities = repr.cosine_similarity(repr_flipped)
    n = similarities.shape[-1]
    indices = similarities.view(-1, n, n).argmax(dim=0)
    k_, l_ = indices // n, indices % n
    l_ = n - 1 - l_  # flip

    # Calculate distance
    k, l = torch.meshgrid(torch.arange(n), torch.arange(n), indexing='ij')
    dist = ((k - k_)**2 + (l - l_)**2).sqrt()

    # Summed-area table of the distances to the centre of a (2n x 2n) window.
    # float64 keeps the four-corner difference below accurate.
    tmp = torch.arange(-n, n, dtype=torch.float64).unsqueeze(0)**2
    error_normalization_matrix = torch.cumsum(torch.cumsum((tmp + tmp.T).sqrt(), dim=1), dim=0)

    # The n x n window selected for cell (k, l) holds exactly the distances
    # from (k, l) to every grid cell; the four-corner difference is their sum.
    x_start, y_start = n - k - 1, n - l - 1
    x_end, y_end = x_start + n, y_start + n
    total_dist = (error_normalization_matrix[x_end, y_end] + error_normalization_matrix[x_start, y_start] -
                  error_normalization_matrix[x_end, y_start] - error_normalization_matrix[x_start, y_end])
    # The window has n * n cells, so the mean distance is the sum over n**2.
    # (Dividing by n alone would give n times the mean.)
    normalizer = total_dist / n**2
    normalized_dist = dist / normalizer

    # Create the plot
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(7.5, 3))
    ax1.imshow(image)
    ax1.axis('off')
    ax1.set_title('Original Image')
    im = ax2.imshow(normalized_dist, cmap='YlOrRd', interpolation='nearest', aspect='equal')
    ax2.axis('off')
    ax2.set_title('Error Visualization')
    fig.colorbar(im, ax=ax2, label='Normalized Error')

    plt.tight_layout()

    # Ensure both subplots have the same size and are well-aligned
    ax1_pos = ax1.get_position()
    ax2_pos = ax2.get_position()
    ax2.set_position([ax2_pos.x0, ax1_pos.y0, ax2_pos.width, ax1_pos.height])

    plt.show()

# plot highest and lowest error images
# The images are ranked by their average error at one extraction position,
# so the visualization is restricted to that same position via `.at(...)`.
position = 'up_blocks[1]'
pos_idx = list(reprs[0].data).index(position)
for error_type, error_func in [("highest", np.argmax), ("lowest", np.argmin)]:
    error_index = error_func(errors[:, pos_idx]).item()
    print(f"Image with {error_type} average error (id: {error_index}): {errors[error_index, pos_idx]:.4f}")
    plot_error_visualization(
        images[error_index],
        reprs[error_index].at(position),
        reprs_flipped[error_index].at(position),
    )


In [None]:
def plot_normalizer(n):
    """Plot the position-dependent distance normalizer for an n x n grid.

    The value at cell (k, l) is the mean Euclidean distance from (k, l) to
    all cells of the grid.

    Args:
        n: Side length of the square grid.
    """
    k, l = torch.meshgrid(torch.arange(n), torch.arange(n), indexing='ij')

    # Summed-area table of the distances to the centre of a (2n x 2n) window.
    # float64 keeps the four-corner difference below accurate.
    tmp = torch.arange(-n, n, dtype=torch.float64).unsqueeze(0)**2
    error_normalization_matrix = torch.cumsum(torch.cumsum((tmp + tmp.T).sqrt(), dim=1), dim=0)

    # The n x n window selected for cell (k, l) holds exactly the distances
    # from (k, l) to every grid cell; the four-corner difference is their sum.
    x_start, y_start = n - k - 1, n - l - 1
    x_end, y_end = x_start + n, y_start + n
    total_dist = (error_normalization_matrix[x_end, y_end] + error_normalization_matrix[x_start, y_start] -
                  error_normalization_matrix[x_end, y_start] - error_normalization_matrix[x_start, y_end])
    # The window has n * n cells, so the mean distance is the sum over n**2.
    # (Dividing by n alone would give n times the mean.)
    normalizer = total_dist / n**2

    # Plot the normalizer
    plt.figure(figsize=(10, 8))
    plt.imshow(normalizer.numpy(), cmap='plasma')
    plt.colorbar(label='Normalization Factor')
    plt.title('Normalizer')
    plt.xlabel('X')
    plt.ylabel('Y')
    plt.show()

plot_normalizer(reprs[0].data['up_blocks[1]'].shape[-1])