# SC over blocks and noise step

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.autonotebook import tqdm
import datasets

In [None]:
# --- experiment configuration ---
model = 'SDXL'
blocks = ['conv_in','down_blocks[0]','down_blocks[1]','down_blocks[2]','mid_block','up_blocks[0]','up_blocks[1]','up_blocks[2]','conv_out']
# alternative block list with a fourth down/up stage:
# blocks = ['conv_in','down_blocks[0]','down_blocks[1]','down_blocks[2]','down_blocks[3]','mid_block','up_blocks[0]','up_blocks[1]','up_blocks[2]','up_blocks[3]','conv_out']
noise_steps = [0, 10, 25, 50, 75, 100, 150, 200, 300, 500, 800]
base_path = Path(f'PATH_TO_RESULTS/step_over_blocks_{model.replace(".", "")}/')

# --- load one result file per (block, noise step) ---
# `data` has shape (blocks, noise steps, 88328 samples, 12 columns). Combinations
# whose file is missing are reported and stay all-zero.
data = np.zeros((len(blocks), len(noise_steps), 88328, 12), dtype=int)
for block_idx, block in enumerate(tqdm(blocks, desc='blocks')):
    for step_idx, noise_step in enumerate(noise_steps):
        data_path = base_path / f'{model}-{block}-expand_and_resize-512-{noise_step}.npy'
        if data_path.exists():
            # columns: correct, sx, sy, tx, ty, pred_x, pred_y, sn, sm, tn, tm, category_id
            data[block_idx, step_idx, :, :] = np.load(data_path)
        else:
            print(data_path.name, 'missing')



In [3]:
# PCK per (block, noise step): mean of column 0 (the `correct` flag) over all samples
pck = data[..., 0].mean(axis=-1)

In [None]:
# Display the shape of the results array: (blocks, noise steps, samples, columns)
data.shape

In [None]:
# Rank the noise steps by PCK (best first): once for the block that scores highest
# at the first noise step (column 0 of `pck`), once for the mean over all blocks.
best_block_idx = pck[:,0].argmax()
best_block_pck = pck[best_block_idx,:]
ranking = best_block_pck.argsort()[::-1]
print(f'noise step ranking for best block {blocks[best_block_idx]}:')
for place, step_idx in enumerate(ranking, start=1):
    print(f'{place:2}. {noise_steps[step_idx]:3}: {best_block_pck[step_idx]:6.2%}')

print()
print('noise step ranking for average over all blocks:')
mean_pck = pck.mean(axis=0)  # computed once instead of once per printed line
ranking = mean_pck.argsort()[::-1]
for place, step_idx in enumerate(ranking, start=1):
    print(f'{place:2}. {noise_steps[step_idx]:3}: {mean_pck[step_idx]:6.2%}')



In [None]:
# Short display names, e.g. 'down_blocks[0]' -> 'down[0]', 'mid_block' -> 'mid', 'conv_in' -> 'conv-in'
block_names = [name.replace('_blocks', '').replace('_block', '').replace('_', '-') for name in blocks]

# One curve per noise step (a column of `pck`), drawn over the blocks
colors = plt.cm.viridis(np.linspace(0.0, 0.9, len(noise_steps)))
for noise_step, pck_per_block, color in zip(noise_steps, pck.T, colors):
    plt.plot(block_names, pck_per_block*100, color=color, label=noise_step, marker='o', alpha=0.7)
plt.xticks(rotation=45, ha='right')
plt.xlabel('block')
plt.ylabel('PCK [%]')
plt.ylim(0)
plt.legend()
plt.show()

In [7]:
# Write the PCK table (blocks x noise steps) to the working directory;
# dots are removed from the model name in the file name.
pck_file = f'sc_pck_spair_over_blocks_noise_{model.replace(".", "")}.npy'
np.save(pck_file, pck)

In [None]:
# Short display names for the legend
block_names = [name.replace('_blocks', '').replace('_block', '').replace('_', '-') for name in blocks]
colors = plt.cm.rainbow(np.linspace(0.0, 1.0, len(blocks)))

# One curve per block (a row of `pck`), drawn over the noise steps
for name, pck_per_step, color in zip(block_names, pck, colors):
    plt.plot(noise_steps, pck_per_step*100, color=color, label=name, marker='o', alpha=0.7)
plt.xlabel('noise step')
plt.ylabel('PCK [%]')
plt.ylim(0)
plt.legend()
plt.show()


# Position bias

In [None]:
# histogram of sample count over predicted position relative to source position
# not really relevant by itself, potentially in combination with the error rate
#
# Grid layout: one row per block, one column per noise step. Each panel is a
# 2D histogram of the offset (source - prediction), restricted to +-lim px.
# Horizontal axis = x offset, vertical axis = y offset (increasing upwards).

lim = 128
bins = 16

n_rows, n_cols = len(blocks), len(noise_steps)
# figsize is (width, height): one inch per column and one per row
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols*1, n_rows*1))
for i, block in enumerate(tqdm(blocks, desc='blocks')):
    is_last_row = i == n_rows-1
    axs[i,0].text(-0.1, 0.5, block_names[i], ha='right', va='center', transform=axs[i,0].transAxes)
    for j, noise_step in enumerate(noise_steps):
        tmp = data[i, j]
        dx = tmp[:,1] - tmp[:,5]  # sx - pred_x
        dy = tmp[:,2] - tmp[:,6]  # sy - pred_y

        hist, *_ = np.histogram2d(dx, dy, bins=bins, range=[[-lim,lim],[-lim,lim]], density=True)
        # histogram2d puts x along the first array axis, imshow draws the first
        # axis vertically -> transpose so that x ends up on the horizontal axis
        axs[i,j].imshow(hist.T, origin='lower', extent=(-lim, lim, -lim, lim), interpolation='nearest')

        # tick labels only on the bottom row and the left column
        if is_last_row: axs[i,j].tick_params(labelsize=8)
        else: axs[i,j].set_xticks([])
        if j == 0: axs[i,j].tick_params(labelsize=8)
        else: axs[i,j].set_yticks([])
        # noise step as column caption below the bottom row
        if is_last_row: axs[-1,j].text(0.5, -0.5, noise_steps[j], ha='center', va='top', transform=axs[-1,j].transAxes)
plt.show()

In [None]:
# error rate over prediction position relative to source position
#
# For every (block, noise step) the samples are accumulated on a 1024x1024 grid
# indexed by the offset (source - prediction) + 512, then pooled into
# downscale x downscale cells. error_rates[i, j, a, b] is the share of wrong
# predictions in x-cell a / y-cell b; cells without samples are NaN.

lim = 256
downscale = 4
lim_scaled_rel = (512-lim)//downscale

error_rates = np.zeros((len(blocks), len(noise_steps), 2*512//downscale, 2*512//downscale))
n, m, *_ = error_rates.shape
n_cells = 2*512//downscale
for i, block in enumerate(tqdm(blocks, desc='blocks')):
    for j, noise_step in enumerate(noise_steps):
        tmp = data[i, j]
        ix = tmp[:,1] - tmp[:,5] + 512  # sx - pred_x, shifted to be a grid index
        iy = tmp[:,2] - tmp[:,6] + 512  # sy - pred_y, shifted to be a grid index
        wrong = (tmp[:,0] == 0).astype(int)  # same as `not correct` per sample

        # np.add.at accumulates repeated indices, like `+=` in a per-sample loop,
        # but without iterating over the samples in Python
        error_counts = np.zeros((2*512, 2*512), dtype=int)
        total_counts = np.zeros((2*512, 2*512), dtype=int)
        np.add.at(error_counts, (ix, iy), wrong)
        np.add.at(total_counts, (ix, iy), 1)

        # pool to the coarse grid; the ratio of sums equals the ratio of means
        error_count_scaled = error_counts.reshape(n_cells, downscale, n_cells, downscale).sum(axis=(1,3))
        total_count_scaled = total_counts.reshape(n_cells, downscale, n_cells, downscale).sum(axis=(1,3))
        rate = np.full((n_cells, n_cells), np.nan)
        np.divide(error_count_scaled, total_count_scaled, out=rate, where=total_count_scaled > 0)
        error_rates[i,j,:,:] = rate

# plot
lo, hi = lim_scaled_rel, n_cells-lim_scaled_rel  # crop to +-lim px (also valid for lim_scaled_rel == 0)
fig, axs = plt.subplots(n, m, figsize=(m*1, n*1))
for i, block_name in enumerate(block_names):
    axs[i,0].text(-0.8, 0.5, block_name, ha='right', va='center', transform=axs[i,0].transAxes)
    for j, noise_step in enumerate(noise_steps):
        # the first array axis is the x offset; transpose so that imshow draws
        # x horizontally and y vertically (increasing upwards)
        axs[i,j].imshow(error_rates[i,j,lo:hi,lo:hi].T, origin='lower', extent=(-lim, lim, -lim, lim))
        axs[i,j].tick_params(labelsize=8)
        if i != n-1: axs[i,j].set_xticks([])
        else: axs[i,j].set_xticks([-200, 0, 200])
        if j != 0: axs[i,j].set_yticks([])
        else: axs[i,j].set_yticks([-200, 0, 200])
        if i == n-1: axs[-1,j].text(0.5, -0.5, noise_steps[j], ha='center', va='top', transform=axs[-1,j].transAxes)

# x/y-labels
axs[0,0].text(-1.8, -n/2, 'block | error rate', ha='right', va='center', transform=axs[0,0].transAxes, fontsize=12, rotation=90)
axs[-1,0].text(m/2+1, -1.0, 'noise step | distance [px]', ha='center', va='top', transform=axs[-1,0].transAxes, fontsize=12)

plt.show()
# optional export of the maps (array layout is [block, noise step, x-cell, y-cell]):
# np.save('sc_errors_by_relative_position_spair_maps_SD1.5.npy', error_rates)



In [None]:
# error rate over prediction position relative to source position
#
# Radial version: samples are binned by the truncated square root of the
# Euclidean distance between source and predicted position, so bin k covers
# k**2 <= distance < (k+1)**2 px. The last bin additionally collects every
# larger distance. Bins without samples are NaN.

# not used in this cell; kept so the notebook globals stay as before
lim = 128
downscale = 16
lim_scaled_rel = (512-lim)//downscale

n_bins = int((2**.5*512)**.5)

# compute error rates
error_rates = np.zeros((len(blocks), len(noise_steps), n_bins))
for i, block in enumerate(tqdm(blocks, desc='blocks')):
    for j, noise_step in enumerate(noise_steps):
        tmp = data[i, j]

        sq_dist = (tmp[:,1]-tmp[:,5])**2 + (tmp[:,2]-tmp[:,6])**2
        # distances >= n_bins**2 px would index one past the end -> fold into the last bin
        bin_idx = np.minimum((sq_dist**.25).astype(int), n_bins-1)
        wrong = (tmp[:,0] == 0).astype(float)  # same as `not correct` per sample

        total_counts = np.bincount(bin_idx, minlength=n_bins)
        error_counts = np.bincount(bin_idx, weights=wrong, minlength=n_bins)

        rate = np.full(n_bins, np.nan)
        np.divide(error_counts, total_counts, out=rate, where=total_counts > 0)
        error_rates[i,j,:] = rate

# plot
fig, axs = plt.subplots(len(blocks), len(noise_steps), figsize=(len(noise_steps)*1, len(blocks)*1))
for i, block in enumerate(blocks):
    axs[i,0].text(-0.8, 0.5, block_names[i], ha='right', va='center', transform=axs[i,0].transAxes)
    for j, noise_step in enumerate(noise_steps):
        error_rate = error_rates[i,j,:]

        # x position of a bin = its lower distance bound in px
        axs[i,j].plot(np.arange(len(error_rate))**2, error_rate)
        axs[i,j].set_xlim(-5, 105)
        if i != len(blocks)-1: axs[i,j].set_xticks([])
        else: axs[i,j].tick_params(labelsize=8)
        if j != 0: axs[i,j].set_yticks([])
        else: axs[i,j].tick_params(labelsize=8)
        if i == len(blocks)-1: axs[-1,j].text(0.5, -0.5, noise_steps[j], ha='center', va='top', transform=axs[-1,j].transAxes)

# set same ylim for each row (relative to the third lowest value)
for i in range(len(blocks)):
    ymin = 1 - (1 - sorted(ax.get_ylim()[0] for ax in axs[i,:])[2]) * 1.1
    for ax in axs[i,:]:
        ax.set_ylim(ymin, 1 + (1 - ymin) * 0.1)

# x/y-labels
axs[0,0].text(-1.8, -len(blocks)/2, 'block | error rate', ha='right', va='center', transform=axs[0,0].transAxes, fontsize=12, rotation=90)
axs[-1,0].text(len(noise_steps)/2, -1.0, 'noise step | relative distance [px]', ha='center', va='top', transform=axs[-1,0].transAxes, fontsize=12)

plt.show()
# optional export of the curves (array layout is [block, noise step, distance bin]):
# np.save('sc_errors_by_relative_position_spair_lines_SD1.5.npy', error_rates)



In [None]:
# Radial error rate of every block at one noise step (noise_steps[noise_index]),
# plus printed statistics for the bins up to the one with the lowest error rate.
noise_index = 3

block_names = [x.replace('_blocks', '').replace('_block', '').replace('_', '-') for x in blocks]
colors = plt.cm.rainbow(np.linspace(0.0, 1.0, len(blocks)))
# bin k covers k**2 <= distance < (k+1)**2 px; the last bin also takes everything beyond
n_bins = int((2**.5*512)**.5)
fig, ax = plt.subplots()
for i, block in enumerate(blocks):
    tmp = data[i, noise_index]
    sq_dist = (tmp[:,1]-tmp[:,5])**2 + (tmp[:,2]-tmp[:,6])**2
    # distances >= n_bins**2 px would index one past the end -> fold into the last bin
    bin_idx = np.minimum((sq_dist**.25).astype(int), n_bins-1)
    total_counts = np.bincount(bin_idx, minlength=n_bins)
    error_counts = np.bincount(bin_idx, weights=(tmp[:,0] == 0).astype(float), minlength=n_bins)
    error_rate = np.full(n_bins, np.nan)  # bins without samples stay NaN
    np.divide(error_counts, total_counts, out=error_rate, where=total_counts > 0)
    ax.plot(np.arange(len(error_rate))**2, error_rate, label=block_names[i], alpha=0.7, color=colors[i])

    # print error rate stats
    print(block_names[i])
    # last bin excluded; nanargmin skips empty bins (plain argmin would return the first NaN)
    lowest_error_idx = np.nanargmin(error_rate[:-1])
    print(f'min error rate at {lowest_error_idx**2} px')
    # `k` rather than `i`, so the block index of the enclosing loop is not overwritten
    for k in range(lowest_error_idx+1):
        print(f'  {k**2} <= dist < {(k+1)**2} px: error rate {error_rate[k]:.2%}, count {total_counts[k]}')
    # errors in the nearer bins beyond what they would have at the lowest error rate
    extra_errors = sum((error_rate[k]-error_rate[lowest_error_idx])*total_counts[k] for k in range(lowest_error_idx) if total_counts[k] > 0)
    print(f'-> extra errors: {extra_errors} ({extra_errors/total_counts.sum():.4%})')
    print()
ax.set_xlim(-5, 55)
ax.legend(loc='upper right', fontsize=8)
plt.show()



In [None]:
# Radial error rate for a single (block, noise step) combination.
# With the block list active at the top of this notebook, blocks[7] is
# 'up_blocks[2]'; noise_steps[3] is 50.
tmp = data[7, 3]
# bin k covers k**2 <= distance < (k+1)**2 px; the last bin also takes everything beyond
n_bins = int((2**.5*512)**.5)
sq_dist = (tmp[:,1]-tmp[:,5])**2 + (tmp[:,2]-tmp[:,6])**2
# distances >= n_bins**2 px would index one past the end -> fold into the last bin
bin_idx = np.minimum((sq_dist**.25).astype(int), n_bins-1)
total_counts = np.bincount(bin_idx, minlength=n_bins)
error_counts = np.bincount(bin_idx, weights=(tmp[:,0] == 0).astype(float), minlength=n_bins)
error_rate = np.full(n_bins, np.nan)  # bins without samples stay NaN
np.divide(error_counts, total_counts, out=error_rate, where=total_counts > 0)

plt.plot(np.arange(len(error_rate))**2, error_rate)
plt.show()

# last bin excluded; nanargmin skips empty bins (plain argmin would return the first NaN)
lowest_error_idx = np.nanargmin(error_rate[:-1])
print(f'min error rate at {lowest_error_idx**2} px')
for k in range(lowest_error_idx+1):
    print(f'  {k**2} <= dist < {(k+1)**2} px: error rate {error_rate[k]:.2%}, count {total_counts[k]}')
# errors in the nearer bins beyond what they would have at the lowest error rate
extra_errors = sum((error_rate[k]-error_rate[lowest_error_idx])*total_counts[k] for k in range(lowest_error_idx) if total_counts[k] > 0)
print(f'-> extra errors: {extra_errors} ({extra_errors/total_counts.sum():.4%})')
print()

In [None]:
# error rate over prediction position relative to target position
#
# Same construction as the source-relative maps, but the grid is indexed by the
# offset (target - prediction) + 512. One row per block, one column per noise
# step; horizontal axis = x offset, vertical axis = y offset (increasing
# upwards). Cells without samples are NaN.

lim = 128
downscale = 16
lim_scaled_rel = (512-lim)//downscale

n_rows, n_cols = len(blocks), len(noise_steps)
n_cells = 2*512//downscale
lo, hi = lim_scaled_rel, n_cells-lim_scaled_rel  # crop to +-lim px
# figsize is (width, height): one inch per column and one per row
fig, axs = plt.subplots(n_rows, n_cols, figsize=(n_cols*1, n_rows*1))
for i, block in enumerate(tqdm(blocks, desc='blocks')):
    is_last_row = i == n_rows-1
    axs[i,0].text(-0.8, 0.5, block_names[i], ha='right', va='center', transform=axs[i,0].transAxes)
    for j, noise_step in enumerate(noise_steps):
        tmp = data[i, j]
        ix = tmp[:,3] - tmp[:,5] + 512  # tx - pred_x, shifted to be a grid index
        iy = tmp[:,4] - tmp[:,6] + 512  # ty - pred_y, shifted to be a grid index
        wrong = (tmp[:,0] == 0).astype(int)  # same as `not correct` per sample

        # np.add.at accumulates repeated indices, like `+=` in a per-sample loop
        error_counts = np.zeros((2*512, 2*512), dtype=int)
        total_counts = np.zeros((2*512, 2*512), dtype=int)
        np.add.at(error_counts, (ix, iy), wrong)
        np.add.at(total_counts, (ix, iy), 1)

        # pool to the coarse grid; the ratio of sums equals the ratio of means
        error_count_scaled = error_counts.reshape(n_cells, downscale, n_cells, downscale).sum(axis=(1,3))
        total_count_scaled = total_counts.reshape(n_cells, downscale, n_cells, downscale).sum(axis=(1,3))
        error_rate_scaled = np.full((n_cells, n_cells), np.nan)
        np.divide(error_count_scaled, total_count_scaled, out=error_rate_scaled, where=total_count_scaled > 0)

        # the first array axis is the x offset; transpose so that imshow draws x horizontally
        axs[i,j].imshow(error_rate_scaled[lo:hi, lo:hi].T, origin='lower', extent=(-lim, lim, -lim, lim))
        if is_last_row: axs[i,j].tick_params(labelsize=8)
        else: axs[i,j].set_xticks([])
        if j == 0: axs[i,j].tick_params(labelsize=8)
        else: axs[i,j].set_yticks([])
        if is_last_row: axs[-1,j].text(0.5, -0.5, noise_steps[j], ha='center', va='top', transform=axs[-1,j].transAxes)
plt.show()

# Texture bias

In [15]:
# Load two configurations of the '0jl/SPair-71k' dataset: 'data' (records with an
# 'img' entry, used below) and 'pairs' (records with src/trg keypoints and data indices).
# NOTE(review): 'data' is read from split 'train' but 'pairs' from split 'test', and the
# pairs' src_data_index / trg_data_index are later used as positions in `spair_data` --
# confirm that the 'train' split of 'data' holds every image the test pairs refer to.
# NOTE(review): trust_remote_code=True presumably lets the dataset's own loading script
# run locally -- only keep it for a repository you trust.
spair_data = datasets.load_dataset('0jl/SPair-71k', 'data', split='train', trust_remote_code=True)
spair_pairs = datasets.load_dataset('0jl/SPair-71k', 'pairs', split='test', trust_remote_code=True)

In [None]:
# Collect the 'img' entry of every dataset record, in dataset order
spair_images = [record['img'] for record in tqdm(spair_data)]


In [None]:
# Largest extent over all images; img.size[0] is used as the x extent and
# img.size[1] as the y extent. The leading 0 keeps the result 0 for an empty list.
spair_images_max_x = max([0, *(img.size[0] for img in spair_images)])
spair_images_max_y = max([0, *(img.size[1] for img in spair_images)])

# Stack all images into one uint8 array, top-left aligned and zero-padded to the
# largest extent.
# NOTE(review): assumes np.array(img) is (height, width, 3) for every image -- confirm
# there are no grayscale / RGBA images.
stack_shape = (len(spair_images), spair_images_max_y, spair_images_max_x, 3)
spair_images_data = np.zeros(stack_shape, dtype=np.uint8)
for img_idx, img in enumerate(tqdm(spair_images)):
    width, height = img.size[0], img.size[1]
    spair_images_data[img_idx, :height, :width, :] = np.array(img)

In [None]:
# One row per keypoint correspondence:
# (src_data_index, sx, sy, trg_data_index, tx, ty), pairs and keypoints in dataset order
spair_kps_list = [
    (pair['src_data_index'], sx, sy, pair['trg_data_index'], tx, ty)
    for pair in tqdm(spair_pairs)
    for (sx, sy), (tx, ty) in zip(pair['src_kps'], pair['trg_kps'])
]

spair_kps = np.array(spair_kps_list)

In [None]:
# Display the three array shapes side by side; the histograms below mask per-keypoint
# values with data[i, 3, :, 0], so the sample axis of `data` and the number of rows
# in `spair_kps` have to agree.
data.shape, spair_kps.shape, spair_images_data.shape


In [None]:
# Pixel colour under every keypoint. The image stack is indexed [image, y, x],
# while spair_kps stores (index, x, y), hence columns 2/1 and 5/4 in swapped order.
src_img, src_x, src_y = spair_kps[:,0], spair_kps[:,1], spair_kps[:,2]
trg_img, trg_x, trg_y = spair_kps[:,3], spair_kps[:,4], spair_kps[:,5]
src_colors = spair_images_data[src_img, src_y, src_x]
trg_colors = spair_images_data[trg_img, trg_y, trg_x]

src_colors.shape, trg_colors.shape


In [None]:
# Mean absolute colour difference between the source and the target keypoint pixel,
# shown per block as histograms for correctly and wrongly matched samples.
# The colours are uint8: subtracting them directly wraps around modulo 256
# (10 - 20 -> 246), so they are widened to a signed type first.
color_diffs = np.abs(src_colors.astype(np.int16) - trg_colors.astype(np.int16)).mean(axis=-1)
hist_bins = np.linspace(0, 256, 257)  # unit-wide bins over the possible range
noise_idx_50 = noise_steps.index(50)
fig, axs = plt.subplots(len(blocks), 1, figsize=(4,2*len(blocks)))
for i in range(len(blocks)):
    correct_at_50 = data[i,noise_idx_50,:,0]
    axs[i].hist(color_diffs[correct_at_50==1], bins=hist_bins, density=True, label='correct')
    axs[i].hist(color_diffs[correct_at_50==0], bins=hist_bins, density=True, alpha=0.5, label='wrong')
    axs[i].legend()
    axs[i].set_title(block_names[i])
plt.tight_layout()
plt.show()

# if there is some color bias in some layers, we maybe would expect to be more correct for smaller color differences
# this doesn't seem to be the case
# NOTE(review): the observation above was written when the differences were still
# computed in wrapping uint8 arithmetic -- re-check it against the corrected plot.

In [None]:
# Cut a (2*d+1) x (2*d+1) neighbourhood around every source / target keypoint.
# The image stack gets a zero border of d pixels on each side first, so pixel
# (y, x) of the original sits at (y+d, x+d) and the window [y, y+2*d] x [x, x+2*d]
# in the padded stack is centred on the keypoint.
d=7
n_images, stack_h, stack_w = spair_images_data.shape[:3]
spair_images_data_padded = np.zeros((n_images, stack_h+2*d, stack_w+2*d, 3), dtype=spair_images_data.dtype)
spair_images_data_padded[:, d:d+stack_h, d:d+stack_w, :] = spair_images_data

window = 2*d+1
src_regions = np.array([spair_images_data_padded[img_idx, y:y+window, x:x+window] for img_idx, x, y in spair_kps[:, :3]])
trg_regions = np.array([spair_images_data_padded[img_idx, y:y+window, x:x+window] for img_idx, x, y in spair_kps[:, 3:6]])

src_regions.shape, trg_regions.shape

In [None]:
# Mean absolute difference between the pixel neighbourhoods of the source and the
# target keypoint, shown per block as histograms for correctly and wrongly matched
# samples. The regions are uint8: subtracting them directly wraps around modulo
# 256, so they are widened to a signed type first.
texture_diffs = np.abs(src_regions.astype(np.int16) - trg_regions.astype(np.int16)).mean(axis=(1,2,3))
hist_bins = np.linspace(0, 256, 257)  # unit-wide bins over the possible range
noise_idx_50 = noise_steps.index(50)
fig, axs = plt.subplots(len(blocks), 1, figsize=(4,2*len(blocks)))
for i in range(len(blocks)):
    correct_at_50 = data[i,noise_idx_50,:,0]
    # plot the texture differences computed above (not the per-pixel colour differences)
    axs[i].hist(texture_diffs[correct_at_50==1], bins=hist_bins, density=True, label='correct')
    axs[i].hist(texture_diffs[correct_at_50==0], bins=hist_bins, density=True, alpha=0.5, label='wrong')
    axs[i].legend()
    axs[i].set_title(block_names[i])
plt.tight_layout()
plt.show()

# if there is some texture bias in some layers, we maybe would expect to be more correct for smaller texture differences
# this doesn't seem to be the case
# NOTE(review): the observation above was written when this cell still plotted the
# colour differences -- re-check it against the texture-difference plot.