# Visualize the h-space pixel norms with the highest/lowest values

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from sdhelper import SD
import datasets
from tqdm.notebook import tqdm, trange

In [None]:
# load model
sd = SD('SDXL-Turbo')
# many generations are run back to back, so silence the per-call progress bar
sd.pipeline.set_progress_bar_config(disable=True)

# Only the ImageNet class names are needed (they are used as prompts), so read
# them from the dataset builder's metadata instead of loading the train split.
labels = datasets.load_dataset_builder('imagenet-1k').info.features['label'].names

In [None]:
# config
p_norm = 2  # order of the norm
n = 100     # number of images to generate

# per-sample accumulators (one entry per generated image)
norms_list, norm_indices_list, images_list, prompts = [], [], [], []

# generate n images from random labels and record the mid_block norms
for sample_idx in trange(n):
    prompt = np.random.choice(labels)
    output = sd(prompt, steps=2, extract_positions=['mid_block'])
    last_repr = output.representations['mid_block'][-1]
    # p-norm over dim 0 of the last extracted representation
    pixel_norms = last_repr.norm(p=p_norm, dim=0).detach().cpu().numpy()
    # double argsort turns every value into its rank (0 = smallest norm)
    pixel_ranks = np.argsort(np.argsort(pixel_norms, axis=None)).reshape(pixel_norms.shape)
    norms_list.append(pixel_norms)
    norm_indices_list.append(pixel_ranks)
    images_list.append(output.result_image)
    prompts.append(prompt)

# element-wise averages over all samples
avg_norms = np.asarray(norms_list).mean(axis=0)
avg_norm_indices = np.asarray(norm_indices_list).mean(axis=0)

In [None]:
# grid/image plot of average norms and average norm indices

# Both figures are built the same way; only the data and the title differ.
heatmaps = (
    (avg_norms, f'Mean L{p_norm} ({n} samples, {sd.model_name})'),
    (avg_norm_indices, f'Mean L{p_norm} index ({n} samples, {sd.model_name})'),
)
for grid, title in heatmaps:
    plt.imshow(grid)
    plt.title(title)
    plt.axis('off')
    plt.colorbar()
    plt.show()

In [None]:
# barplot of the top 10 average norms
#
# NOTE(review): the pixels are selected and ordered by their mean rank
# (avg_norm_indices), while the bar heights show their mean norm (avg_norms),
# so the bars are not guaranteed to be strictly decreasing. Confirm this is
# the intended ranking criterion.

top_k = 10
idx = avg_norm_indices.flatten().argsort()[-top_k:][::-1]

# Recover (row, column) coordinates from the flat indices using the actual
# grid shape rather than a fixed grid size (assumes a 2D grid, as required by
# the heatmap plots of the same arrays).
rows, cols = np.unravel_index(idx, avg_norm_indices.shape)
tick_labels = [f'({r},{c})' for r, c in zip(rows, cols)]

# len(idx) instead of a constant keeps the plot valid for grids with < top_k pixels
plt.bar(np.arange(len(idx)), avg_norms.flatten()[idx], tick_label=tick_labels)
plt.title(f'Top {top_k} L{p_norm} ({n} samples, {sd.model_name})')
plt.show()

In [None]:
# Plot, for the first few generated images, the image patches that correspond
# to the h-space pixels with the 10 highest and 10 lowest norms in that image.
#
# Dedicated variable names are used here so that the sample count `n` from the
# config cell (which is interpolated into the plot titles of other cells) is
# not overwritten by the grid size.

n_examples = min(10, len(images_list))  # never index past the generated samples
n_patches = 10  # patches shown per row (top row: highest, bottom row: lowest)

for sample_idx in range(n_examples):
    ranks = norm_indices_list[sample_idx]
    # assumes a 2D (rows, columns) grid of h-space pixels
    grid_h, grid_w = ranks.shape
    img = np.array(images_list[sample_idx])
    # size of the image region covered by a single h-space pixel, derived from
    # the actual image instead of a fixed resolution
    patch_h = img.shape[0] // grid_h
    patch_w = img.shape[1] // grid_w

    order = ranks.flatten().argsort()  # flat pixel indices, ascending by norm
    idx_top = order[-n_patches:][::-1]
    idx_bot = order[:n_patches]

    plt.figure(figsize=(18, 5))
    plt.suptitle(f'Patches with highest (top) and lowest (bottom) h-space norm for prompt "{prompts[sample_idx]}" ({sd.model_name})')
    for plot_row, flat_indices in enumerate((idx_top, idx_bot)):
        for col, flat_idx in enumerate(flat_indices):
            # flat index -> (row, column) in the h-space grid
            r, c = divmod(flat_idx, grid_w)
            patch = img[r*patch_h:(r+1)*patch_h, c*patch_w:(c+1)*patch_w, :]
            # subplot positions are 1-based and run row by row
            plt.subplot(2, n_patches, plot_row*n_patches + col + 1)
            plt.imshow(patch)
            plt.title(f'({r},{c})')
            plt.axis('off')

    plt.tight_layout()
    plt.show()