## Visualize h-space pixel norms using histograms

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from sdhelper import SD
from tqdm.notebook import tqdm, trange

In [None]:
p_norm = 2  # order of the vector norm taken over dim 0 of each representation
samples = 20  # number of seeds (independent generations) per model
# Toggle to compare several models:
# for model in ['SD-Turbo', 'SDXL-Turbo', 'SD1.5']:#, 'SD2.1']:
for model in ['SDXL-Turbo']:
    # Release the previously loaded pipeline (if any) before constructing the
    # next one, so two models never have to be resident in memory at once.
    sd = None
    sd = SD(model)
    sd.pipeline.set_progress_bar_config(disable=True)
    print(model)

    all_norms = []
    for i in trange(samples):
        result = sd('a cat', seed=i, extract_positions=['mid_block'])
        # One mid_block representation per denoising step; reduce over dim 0 so
        # every remaining (spatial) position gets its own norm.
        # NOTE(review): assumes each representation is (channels, H, W) — the
        # 4-axis transpose below implies 3-D inputs; TODO confirm with sdhelper.
        # (Loop variable deliberately not named `repr` to avoid shadowing the
        # builtin in the notebook's global namespace.)
        step_norms = [
            torch.linalg.vector_norm(rep.to(dtype=torch.float32), ord=p_norm, dim=0)
            for rep in result.representations['mid_block']
        ]
        all_norms.append(torch.stack(step_norms))
    # Stacked as (samples, steps, ...); swap the first two axes so that
    # iterating all_norms yields one array per step covering every sample.
    all_norms = torch.stack(all_norms).cpu().numpy().transpose(1, 0, 2, 3)

    # One histogram per step, coloured along the rainbow colormap from the
    # first step to the last.
    step_colors = plt.cm.rainbow(np.linspace(0, 1, len(all_norms)))
    for step, (step_norm, color) in enumerate(zip(all_norms, step_colors)):
        plt.hist(step_norm.flatten(), bins=30, alpha=.7, label=f'step {step}', color=color)
    plt.title(f'{model} h-space L{p_norm} norm distribution ({samples} samples)')
    plt.xlabel(f'L{p_norm} norm')
    plt.ylabel('count')

    # Show every step in the legend when there are at most `legend_every`
    # steps; otherwise only every `legend_every`-th entry, to keep it readable.
    legend_every = 10
    if len(all_norms) <= legend_every:
        plt.legend()
    else:
        handles, labels = plt.gca().get_legend_handles_labels()
        plt.legend(handles[::legend_every], labels[::legend_every])
    plt.show()