# SC over blocks and noise step

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.autonotebook import tqdm
import datasets

In [None]:
# Experiment grid: one result file exists per (model, resolution, block) combination.
models = ['SD1.5', 'SD2.1', 'SD-Turbo', 'SDXL', 'SDXL-Turbo']
resolutions = [256, 512, 768, 1024]
blocks = [
    'conv_in',
    'down_blocks[0]', 'down_blocks[1]', 'down_blocks[2]', 'down_blocks[3]',
    'mid_block',
    'up_blocks[0]', 'up_blocks[1]', 'up_blocks[2]', 'up_blocks[3]',
    'conv_out',
]
# Short display names, e.g. 'down_blocks[0]' -> 'down[0]', 'mid_block' -> 'mid', 'conv_in' -> 'conv-in'.
block_names = np.array([name.replace('_blocks', '').replace('_block', '').replace('_', '-') for name in blocks])
# Positions of all blocks whose name does not contain '3'; used for the SDXL models when plotting.
sdxl_block_indices = np.array([idx for idx, name in enumerate(blocks) if '3' not in name])
base_path = Path('PATH_TO_RESULTS/step50_vary_all_else/')

# data[model, resolution, block, sample, field]; combinations without a result file stay all-zero.
# The 12 fields of a sample are, in order:
#   correct, sx, sy, tx, ty, pred_x, pred_y, sn, sm, tn, tm, category_id
data = np.zeros((len(models), len(resolutions), len(blocks), 88328, 12), dtype=int)
for i, model in enumerate(tqdm(models, desc='models')):
    for j, resolution in enumerate(resolutions):
        for k, block in enumerate(blocks):
            data_path = base_path / f'{model}-{block}-expand_and_resize-{resolution}-50.npy'
            if data_path.exists():
                data[i, j, k] = np.load(data_path)
            else:
                # report the gap and keep going; the zero-filled slot is left untouched
                print(data_path.name, 'missing')


In [3]:
# Field 0 of every sample is the `correct` flag; averaging it over the sample axis
# gives one PCK value (fraction in [0, 1]) per (model, resolution, block).
pck = np.mean(data[..., 0], axis=-1)

In [None]:
# sanity check: (models, resolutions, blocks, samples, fields)
np.shape(data)

In [None]:
# One panel per model: PCK in percent over the blocks, one line per resolution.
fig, axs = plt.subplots(1, len(models), figsize=(len(models)*2.5, 1*2.5))
colors = plt.cm.viridis(np.linspace(0.0, 0.9, len(resolutions)))
for panel, model, model_pck in zip(axs, models, pck):
    # SDXL models are only plotted for the blocks listed in sdxl_block_indices
    is_sdxl = 'SDXL' in model
    x = block_names[sdxl_block_indices] if is_sdxl else block_names
    y_idx = sdxl_block_indices if is_sdxl else slice(None)
    for color, resolution, resolution_pck in zip(colors, resolutions, model_pck):
        panel.plot(x, resolution_pck[y_idx]*100, color=color, label=resolution, marker='o', markersize=2, alpha=0.7)
    panel.tick_params(labelsize=10, labelrotation=90)
    # panel.set_xlabel('block')
    panel.set_ylim(0)
    panel.set_title(model)
    panel.grid(axis='y', linestyle='--', alpha=0.3)
    # only the left-most panel carries the y-axis label and tick labels
    if panel is axs[0]:
        panel.set_ylabel('PCK [%]')
    else:
        panel.set_yticklabels([])

# Give every panel the same y-range: the union of the individual ranges
lower_bounds, upper_bounds = zip(*(panel.get_ylim() for panel in axs))
y_min = min(lower_bounds)
y_max = max(upper_bounds)
for ax in axs:
    ax.set_ylim(y_min, y_max)

# legend sits outside the last panel, to its upper right
axs[-1].legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
plt.tight_layout()
plt.show()

In [6]:
# persist the PCK table (models x resolutions x blocks) next to the notebook
np.save(file='sc_pck_spair_step50_vary_all_else.npy', arr=pck)

# Position bias

In [None]:
# error rate over prediction position relative to source position

model_idx = 0   # index into `models`
lim = 0.5       # controls the crop: int((1-lim)*resolution)//downscale bins are cut from every border
downscale = 4   # side length in px of the square bins the offset maps are pooled into

n = len(blocks)
m = len(resolutions)
error_rates = []  # one array of shape (n, 2*resolution//downscale, 2*resolution//downscale) per resolution
for i, resolution in enumerate(tqdm(resolutions, desc='resolutions')):
    size = 2*resolution
    size_scaled = size//downscale
    error_rates.append(np.zeros((n, size_scaled, size_scaled)))
    for j, block in enumerate(blocks):
        tmp = data[model_idx, i, j]
        # columns: correct, sx, sy, tx, ty, pred_x, pred_y, sn, sm, tn, tm, category_id
        wrong = (tmp[:, 0] == 0).astype(int)
        # offset source - prediction, shifted by `resolution` so that a zero offset lands at the centre
        dx = tmp[:, 1] - tmp[:, 5] + resolution
        dy = tmp[:, 2] - tmp[:, 6] + resolution
        error_counts = np.zeros((size, size), dtype=int)
        total_counts = np.zeros((size, size), dtype=int)
        # unbuffered accumulation: repeated (dx, dy) pairs are all counted
        np.add.at(error_counts, (dx, dy), wrong)
        np.add.at(total_counts, (dx, dy), 1)

        # pool downscale x downscale px into one bin (the means share a factor, so their ratio is the pooled rate)
        error_count_scaled = error_counts.reshape(size_scaled, downscale, size_scaled, downscale).mean(axis=(1,3))
        total_count_scaled = total_counts.reshape(size_scaled, downscale, size_scaled, downscale).mean(axis=(1,3))
        # bins without any sample become NaN; silence the 0/0 warning
        with np.errstate(divide='ignore', invalid='ignore'):
            error_rates[i][j,:,:] = error_count_scaled / total_count_scaled

# plot
# NOTE(review): the maps are indexed [dx, dy], so imshow puts dx on the vertical axis - confirm this is intended
fig, axs = plt.subplots(n, m, figsize=(m*1, n*1))
for i, block_name in enumerate(block_names):
    axs[i,0].text(-0.8, 0.5, block_name, ha='right', va='center', transform=axs[i,0].transAxes)
    for j, resolution in enumerate(resolutions):
        size_scaled = error_rates[j].shape[-1]
        lim_scaled_rel = int((1-lim)*resolution)//downscale
        # explicit end index so that a crop of 0 bins keeps the full map (`:-0` would be empty)
        crop = slice(lim_scaled_rel, size_scaled - lim_scaled_rel)
        # half-width in px of the region that survives the crop; the axes must span exactly this range
        half = (size_scaled//2 - lim_scaled_rel) * downscale
        axs[i,j].imshow(error_rates[j][i, crop, crop], origin='lower', extent=(-half, half, -half, half))
        axs[i,j].tick_params(labelsize=8)
        if i != n-1: axs[i,j].set_xticks([])
        else: axs[i,j].set_xticks([-half/2, 0, half/2])
        if j != 0: axs[i,j].set_yticks([])
        else: axs[i,j].set_yticks([-half/2, 0, half/2])
        if i == n-1: axs[-1,j].text(0.5, -0.5, resolutions[j], ha='center', va='top', transform=axs[-1,j].transAxes)

# x/y-labels
axs[0,0].text(-1.8, -n/2, 'block | error rate', ha='right', va='center', transform=axs[0,0].transAxes, fontsize=12, rotation=90)
axs[-1,0].text(m/2+1, -1.0, 'resolution | distance [px]', ha='center', va='top', transform=axs[-1,0].transAxes, fontsize=12)

plt.show()
# np.save('sc_errors_by_relative_position_spair_maps_SD1.5.npy', error_rates)


In [None]:
# TODO: adjust

# error rate over prediction distance relative to source position
# Adapted to the (model, resolution, block, sample, field) layout of `data`:
# the columns are now the resolutions (the former `noise_steps` axis no longer exists).

model_idx = 0  # index into `models`

# Distances are binned on a square-root scale: bin k holds the predictions whose Euclidean
# distance d to the source keypoint satisfies k**2 <= d < (k+1)**2. The largest possible
# distance is below sqrt(2)*resolution, hence the "+ 1" so that the last bin exists.
n_bins = int((2**.5 * max(resolutions))**.5) + 1

# compute error rates (NaN where a bin holds no sample)
error_rates = np.full((len(blocks), len(resolutions), n_bins), np.nan)
for i, block in enumerate(tqdm(blocks, desc='blocks')):
    for j, resolution in enumerate(resolutions):
        tmp = data[model_idx, j, i]
        # columns: correct, sx, sy, tx, ty, pred_x, pred_y, sn, sm, tn, tm, category_id
        dist_bins = (((tmp[:, 1] - tmp[:, 5])**2 + (tmp[:, 2] - tmp[:, 6])**2)**.25).astype(int)
        total_counts = np.bincount(dist_bins, minlength=n_bins)
        error_counts = np.bincount(dist_bins, weights=(tmp[:, 0] == 0).astype(float), minlength=n_bins)
        with np.errstate(divide='ignore', invalid='ignore'):
            error_rates[i,j,:] = error_counts / total_counts

# plot
fig, axs = plt.subplots(len(blocks), len(resolutions), figsize=(len(resolutions)*1, len(blocks)*1))
for i, block in enumerate(blocks):
    axs[i,0].text(-0.8, 0.5, block_names[i], ha='right', va='center', transform=axs[i,0].transAxes)
    for j, resolution in enumerate(resolutions):
        error_rate = error_rates[i,j,:]

        # bin k starts at distance k**2 px
        axs[i,j].plot(np.arange(len(error_rate))**2, error_rate)
        axs[i,j].set_xlim(-5, 105)
        if i != len(blocks)-1: axs[i,j].set_xticks([])
        else: axs[i,j].tick_params(labelsize=8)
        if j != 0: axs[i,j].set_yticks([])
        else: axs[i,j].tick_params(labelsize=8)
        if i == len(blocks)-1: axs[-1,j].text(0.5, -0.5, resolutions[j], ha='center', va='top', transform=axs[-1,j].transAxes)

# set same ylim for each row (relative to the third lowest value)
for i in range(len(blocks)):
    ymin = 1 - (1 - sorted(ax.get_ylim()[0] for ax in axs[i,:])[2]) * 1.1
    for ax in axs[i,:]:
        ax.set_ylim(ymin, 1 + (1 - ymin) * 0.1)

# x/y-labels
axs[0,0].text(-1.8, -len(blocks)/2, 'block | error rate', ha='right', va='center', transform=axs[0,0].transAxes, fontsize=12, rotation=90)
axs[-1,0].text(len(resolutions)/2, -1.0, 'resolution | relative distance [px]', ha='center', va='top', transform=axs[-1,0].transAxes, fontsize=12)

plt.show()
# np.save('sc_errors_by_relative_position_spair_lines_SD1.5.npy', error_rates)


In [None]:
# TODO: adjust

# Error rate over the distance between prediction and source keypoint, one curve per block,
# plus per-block statistics on how many errors fall closer than the distance bin with the lowest error rate.
# Adapted to the (model, resolution, block, sample, field) layout of `data`.
# NOTE(review): the former `noise_index = 3` selected noise step 50; this data set is fixed to step 50,
# so a resolution is selected instead. 512 matches the previously hard-coded bin count - confirm.

model_idx = 0      # index into `models`
resolution = 512
resolution_idx = resolutions.index(resolution)

# square-root distance bins: bin k covers k**2 <= d < (k+1)**2 px; "+ 1" so the farthest
# possible prediction (d < sqrt(2)*resolution) still has a bin
n_bins = int((2**.5 * resolution)**.5) + 1

# `block_names` from the setup cell is reused on purpose: rebinding it to a list here would
# break `block_names[sdxl_block_indices]` in the PCK plot when that cell is re-run.
colors = plt.cm.rainbow(np.linspace(0.0, 1.0, len(blocks)))
fig, ax = plt.subplots()
ax.set_xlim(-5, 55)
for i, block in enumerate(blocks):
    tmp = data[model_idx, resolution_idx, i]
    # columns: correct, sx, sy, tx, ty, pred_x, pred_y, sn, sm, tn, tm, category_id
    dist_bins = (((tmp[:, 1] - tmp[:, 5])**2 + (tmp[:, 2] - tmp[:, 6])**2)**.25).astype(int)
    total_counts = np.bincount(dist_bins, minlength=n_bins)
    error_counts = np.bincount(dist_bins, weights=(tmp[:, 0] == 0).astype(float), minlength=n_bins)
    with np.errstate(divide='ignore', invalid='ignore'):
        error_rate = error_counts / total_counts  # NaN for empty bins
    ax.plot(np.arange(len(error_rate))**2, error_rate, label=block_names[i], alpha=0.7, color=colors[i])

    # print error rate stats
    print(block_names[i])
    # nanargmin: a plain argmin would return the first empty (NaN) bin instead of the minimum
    lowest_error_idx = np.nanargmin(error_rate[:-1])
    print(f'min error rate at {lowest_error_idx**2} px')
    for k in range(lowest_error_idx+1):
        print(f'  {k**2} <= dist < {(k+1)**2} px: error rate {error_rate[k]:.2%}, count {total_counts[k]}')
    # errors in the closer bins beyond what the lowest error rate would produce; empty bins contribute nothing
    extra_errors = sum((error_rate[k]-error_rate[lowest_error_idx])*total_counts[k] for k in range(lowest_error_idx) if total_counts[k])
    print(f'-> extra errors: {extra_errors} ({extra_errors/total_counts.sum():.4%})')
    print()
ax.legend(loc='upper right', fontsize=8)
plt.show()


In [None]:
# TODO: adjust

# Error rate over the distance between prediction and source keypoint for a single block.
# Adapted to the (model, resolution, block, sample, field) layout of `data`.
# NOTE(review): the old code used `data[7, 3]` with the comment "up_blocks[3], noise_step 50".
# In the current `blocks` list index 7 is 'up_blocks[1]', so the block is looked up by name;
# the resolution 512 matches the previously hard-coded bin count - confirm both choices.

model_idx = 0      # index into `models`
resolution = 512
tmp = data[model_idx, resolutions.index(resolution), blocks.index('up_blocks[3]')]

# square-root distance bins: bin k covers k**2 <= d < (k+1)**2 px; "+ 1" so the farthest
# possible prediction (d < sqrt(2)*resolution) still has a bin
n_bins = int((2**.5 * resolution)**.5) + 1
# columns: correct, sx, sy, tx, ty, pred_x, pred_y, sn, sm, tn, tm, category_id
dist_bins = (((tmp[:, 1] - tmp[:, 5])**2 + (tmp[:, 2] - tmp[:, 6])**2)**.25).astype(int)
total_counts = np.bincount(dist_bins, minlength=n_bins)
error_counts = np.bincount(dist_bins, weights=(tmp[:, 0] == 0).astype(float), minlength=n_bins)
with np.errstate(divide='ignore', invalid='ignore'):
    error_rate = error_counts / total_counts  # NaN for empty bins

plt.plot(np.arange(len(error_rate))**2, error_rate)
plt.show()

# nanargmin: a plain argmin would return the first empty (NaN) bin instead of the minimum
lowest_error_idx = np.nanargmin(error_rate[:-1])
print(f'min error rate at {lowest_error_idx**2} px')
for i in range(lowest_error_idx+1):
    print(f'  {i**2} <= dist < {(i+1)**2} px: error rate {error_rate[i]:.2%}, count {total_counts[i]}')
# errors in the closer bins beyond what the lowest error rate would produce; empty bins contribute nothing
extra_errors = sum((error_rate[i]-error_rate[lowest_error_idx])*total_counts[i] for i in range(lowest_error_idx) if total_counts[i])
print(f'-> extra errors: {extra_errors} ({extra_errors/total_counts.sum():.4%})')
print()

In [None]:
# TODO: adjust

# error rate over prediction position relative to target position
# Adapted to the (model, resolution, block, sample, field) layout of `data`:
# the columns are now the resolutions (the former `noise_steps` axis no longer exists).

model_idx = 0    # index into `models`
lim = 128        # half-width in px of the offset window that is shown
downscale = 16   # side length in px of the square bins the offset maps are pooled into

# figsize is (width, height): one unit per column (resolution) and per row (block)
fig, axs = plt.subplots(len(blocks), len(resolutions), figsize=(len(resolutions)*1, len(blocks)*1))
for i, block in enumerate(tqdm(blocks, desc='blocks')):
    axs[i,0].text(-0.8, 0.5, block_names[i], ha='right', va='center', transform=axs[i,0].transAxes)
    for j, resolution in enumerate(resolutions):
        tmp = data[model_idx, j, i]
        size = 2*resolution
        size_scaled = size//downscale

        # columns: correct, sx, sy, tx, ty, pred_x, pred_y, sn, sm, tn, tm, category_id
        wrong = (tmp[:, 0] == 0).astype(int)
        # offset target - prediction, shifted by `resolution` so that a zero offset lands at the centre
        dx = tmp[:, 3] - tmp[:, 5] + resolution
        dy = tmp[:, 4] - tmp[:, 6] + resolution
        error_counts = np.zeros((size, size), dtype=int)
        total_counts = np.zeros((size, size), dtype=int)
        # unbuffered accumulation: repeated (dx, dy) pairs are all counted
        np.add.at(error_counts, (dx, dy), wrong)
        np.add.at(total_counts, (dx, dy), 1)

        error_count_scaled = error_counts.reshape(size_scaled, downscale, size_scaled, downscale).mean(axis=(1,3))
        total_count_scaled = total_counts.reshape(size_scaled, downscale, size_scaled, downscale).mean(axis=(1,3))
        # bins without any sample become NaN; silence the 0/0 warning
        with np.errstate(divide='ignore', invalid='ignore'):
            error_rate_scaled = error_count_scaled / total_count_scaled

        # crop to the +-lim window; explicit end index so that a crop of 0 bins keeps the full map
        lim_scaled_rel = max(resolution-lim, 0)//downscale
        crop = slice(lim_scaled_rel, size_scaled - lim_scaled_rel)
        half = (size_scaled//2 - lim_scaled_rel) * downscale  # px actually covered by the crop
        # NOTE(review): the map is indexed [dx, dy], so imshow puts dx on the vertical axis - confirm this is intended
        axs[i,j].imshow(error_rate_scaled[crop, crop], origin='lower', extent=(-half, half, -half, half))
        if i != len(blocks)-1: axs[i,j].set_xticks([])
        else: axs[i,j].tick_params(labelsize=8)
        if j != 0: axs[i,j].set_yticks([])
        else: axs[i,j].tick_params(labelsize=8)
        if i == len(blocks)-1: axs[-1,j].text(0.5, -0.5, resolutions[j], ha='center', va='top', transform=axs[-1,j].transAxes)
plt.show()

# Texture bias

In [15]:
# Load the image table ('data' config, 'train' split) and the pair annotations
# ('pairs' config, 'test' split) of the SPair-71k dataset repository.
# NOTE(review): trust_remote_code=True lets the dataset repository run its own loading
# script on this machine - only keep it for repositories you trust.
spair_data, spair_pairs = (
    datasets.load_dataset('0jl/SPair-71k', config, split=split, trust_remote_code=True)
    for config, split in (('data', 'train'), ('pairs', 'test'))
)

In [None]:
# pull the 'img' entry out of every dataset row (with a progress bar)
spair_images = []
for sample in tqdm(spair_data):
    spair_images.append(sample['img'])


In [None]:
# Largest extent over all images; `img.size` is used as (x-extent, y-extent)
# (presumably PIL's (width, height) - verify against the dataset's image type).
spair_images_max_x = max((img.size[0] for img in spair_images), default=0)
spair_images_max_y = max((img.size[1] for img in spair_images), default=0)

# Stack all images into one uint8 array indexed [image, y, x, channel];
# images smaller than the maximum are zero-padded at the bottom/right.
spair_images_data = np.zeros((len(spair_images), spair_images_max_y, spair_images_max_x, 3), dtype=np.uint8)
for i, img in enumerate(tqdm(spair_images)):
    extent_x, extent_y = img.size[0], img.size[1]
    spair_images_data[i, :extent_y, :extent_x, :] = np.array(img)

In [None]:
# One row per annotated keypoint correspondence:
#   (source image index, sx, sy, target image index, tx, ty)
spair_kps_list = [
    (pair['src_data_index'], sx, sy, pair['trg_data_index'], tx, ty)
    for pair in tqdm(spair_pairs)
    for (sx, sy), (tx, ty) in zip(pair['src_kps'], pair['trg_kps'])
]

spair_kps = np.array(spair_kps_list)

In [None]:
# sanity check: shapes of the result table, the keypoint table and the image stack
tuple(arr.shape for arr in (data, spair_kps, spair_images_data))


In [None]:
# TODO: adjust

# Colour of the pixel under every source / target keypoint.
# spair_kps columns are (src image, sx, sy, trg image, tx, ty) while the image stack is
# indexed [image, y, x], hence the column orders [0,2,1] and [3,5,4].
src_index = tuple(spair_kps[:, [0, 2, 1]].T)
trg_index = tuple(spair_kps[:, [3, 5, 4]].T)
src_colors = spair_images_data[src_index]
trg_colors = spair_images_data[trg_index]

src_colors.shape, trg_colors.shape


In [None]:
# TODO: adjust

# Histogram of the colour difference between source and target keypoint, split into
# correctly and wrongly matched samples, one panel per block.
# Adapted to the (model, resolution, block, sample, field) layout of `data`.
# NOTE(review): the old code selected noise step 50 via index 3; this data set is fixed to
# step 50, so a model and a resolution are selected instead - confirm the choice.
# Assumes the rows of `data` are in the same order as `spair_kps` - TODO confirm.

model_idx = 0                           # index into `models`
resolution_idx = resolutions.index(512)

# cast before subtracting: the colours are uint8 and a uint8 difference wraps around modulo 256
color_diffs = np.abs(src_colors.astype(np.int16) - trg_colors.astype(np.int16)).mean(axis=-1)
hist_bins = np.linspace(0, 256, 257)
fig, axs = plt.subplots(len(blocks), 1, figsize=(4,2*len(blocks)))
for i in range(len(blocks)):
    correct = data[model_idx, resolution_idx, i, :, 0]
    axs[i].hist(color_diffs[correct==1], bins=hist_bins, density=True, label='correct')
    axs[i].hist(color_diffs[correct==0], bins=hist_bins, density=True, alpha=0.5, label='wrong')
    axs[i].legend()
    axs[i].set_title(block_names[i])
plt.tight_layout()
plt.show()

# if there is some color bias in some layers, we maybe would expect to be more correct for smaller color differences
# this doesn't seem to be the case

In [None]:
# TODO: adjust

# Cut a (2*d+1) x (2*d+1) patch centred on every source / target keypoint.
d = 7
# zero-pad every image by d pixels on each side so that patches near the border stay inside the array
padded_shape = (spair_images_data.shape[0], spair_images_data.shape[1] + 2*d, spair_images_data.shape[2] + 2*d, 3)
spair_images_data_padded = np.zeros(padded_shape, dtype=spair_images_data.dtype)
spair_images_data_padded[:, d:-d, d:-d, :] = spair_images_data


def _crop_patches(kps):
    """Return the stacked patches for rows of (image index, x, y); coordinates refer to the unpadded images."""
    # in the padded array the patch around (x, y) starts at (x, y) and spans 2*d+1 pixels
    return np.array([spair_images_data_padded[img_idx, y:y + 2*d + 1, x:x + 2*d + 1] for img_idx, x, y in kps])


src_regions = _crop_patches(spair_kps[:, [0, 1, 2]])
trg_regions = _crop_patches(spair_kps[:, [3, 4, 5]])

src_regions.shape, trg_regions.shape

In [None]:
# TODO: adjust

# Histogram of the texture difference (mean absolute difference of the patches around the
# source and target keypoint), split into correctly and wrongly matched samples, one panel per block.
# Adapted to the (model, resolution, block, sample, field) layout of `data`.
# NOTE(review): the old code selected noise step 50 via index 3; this data set is fixed to
# step 50, so a model and a resolution are selected instead - confirm the choice.
# Assumes the rows of `data` are in the same order as `spair_kps` - TODO confirm.

model_idx = 0                           # index into `models`
resolution_idx = resolutions.index(512)

# cast before subtracting: the patches are uint8 and a uint8 difference wraps around modulo 256
texture_diffs = np.abs(src_regions.astype(np.int16) - trg_regions.astype(np.int16)).mean(axis=(1, 2, 3))
hist_bins = np.linspace(0, 256, 257)
fig, axs = plt.subplots(len(blocks), 1, figsize=(4,2*len(blocks)))
for i in range(len(blocks)):
    correct = data[model_idx, resolution_idx, i, :, 0]
    # plot the texture differences computed above (the colour differences were plotted here by mistake)
    axs[i].hist(texture_diffs[correct==1], bins=hist_bins, density=True, label='correct')
    axs[i].hist(texture_diffs[correct==0], bins=hist_bins, density=True, alpha=0.5, label='wrong')
    axs[i].legend()
    axs[i].set_title(block_names[i])
plt.tight_layout()
plt.show()

# if there is some texture bias in some layers, we maybe would expect to be more correct for smaller texture differences
# NOTE(review): the earlier conclusion "this doesn't seem to be the case" was drawn from a plot of the
# colour differences; re-check it against the corrected plot