# Norm Distribution maps for different norms and different layers

In [None]:
from sdhelper import SD
import numpy as np
from tqdm.autonotebook import tqdm, trange
from matplotlib import pyplot as plt
import PIL.Image
import PIL.ImageOps
import torch
from pathlib import Path
import matplotlib.colors


In [None]:
# Build the sdhelper `SD` wrapper with its default arguments; the cells below
# use it to extract representations and to list the available extract positions.
sd = SD()

In [None]:
# Load every JPEG found directly inside the dataset directory.
dataset_path = Path('../random_images_flux/')
dataset = []
# glob() yields paths in filesystem order; sort them so the image order (and
# therefore the order of anything computed per image) is reproducible.
for image_path in sorted(dataset_path.glob('*.jpg')):
    # PIL.Image.open is lazy and keeps the file handle open until the pixel
    # data is read. Decode inside the context manager so the handle is released
    # immediately instead of leaking one descriptor per image.
    with PIL.Image.open(image_path) as image:
        image.load()
    dataset.append(image)
len(dataset)

In [None]:
# Extract representations for all dataset images at every extract position the
# SD helper exposes. Later cells iterate over `reprs` and index each entry by
# position name (`r[p]`).
# NOTE(review): `step=50` is presumably the diffusion timestep used for the
# extraction — confirm against the sdhelper documentation.
reprs = sd.img2repr(dataset, extract_positions=sd.available_extract_positions, step=50)

In [None]:
# Mean norm maps: one figure per norm order, one panel per extract position.
for norm_type in (1, 2, np.inf):
    # For each position, take the norm over dim 0 of every image's
    # representation (presumably the channel axis) and average over images.
    norms = {}
    for p in sd.available_extract_positions:
        per_image_maps = [torch.linalg.norm(r[p][0], ord=norm_type, dim=0) for r in reprs]
        norms[p] = torch.stack(per_image_maps).mean(dim=0)

    n_panels = len(sd.available_extract_positions)
    plt.figure(figsize=(3 * n_panels, 10))
    for panel, p in enumerate(sd.available_extract_positions, start=1):
        plt.subplot(1, n_panels, panel)
        plt.title(p)
        plt.imshow(norms[p], cmap='gray')
        plt.axis('off')
    plt.show()

In [None]:
# norm over position like in "Vision Transformers Need Registers", fig. 4a
# TODO: could use gridspec to align the colorbars and twinx to have a native y-axis

# config
bins = 200
n_yticks = 11  # labelled y ticks, evenly spaced in log-norm from 0 to max_norm

# plot
positions = sd.available_extract_positions
plt.figure(figsize=(3*5, 5))
# Each entry: (norm order, display name, upper end of the histogram range in log space).
for i, (norm_type, norm_name, max_norm) in enumerate([(1, 'L1', 11), (2, 'L2', 8), (np.inf, r'L$\infty$', 6)]):
    # Build one histogram of log-norms per extract position.
    rows = []
    for p in positions:
        log_norms = torch.stack([torch.linalg.norm(r[p][0], ord=norm_type, dim=0) for r in reprs]).flatten().float().log()
        counts = torch.histogram(log_norms, bins=bins, range=(0, max_norm))[0]
        # Flip so the largest norms land in the top image row, and divide by the
        # number of norm values so every column holds frequencies. Using the
        # actual element count avoids assuming square feature maps.
        rows.append(counts.flip(0) / log_norms.numel())
    norms = torch.stack(rows)

    plt.subplot(1, 3, i+1)
    plt.title(f'{norm_name} Norm Distribution over blocks')
    plt.imshow(norms.T, cmap='YlOrRd', norm=matplotlib.colors.LogNorm(), aspect=len(positions)/bins, interpolation='nearest')
    plt.colorbar(extend='min')
    plt.xlabel('Position')
    plt.ylabel(f'{norm_name} Norm')
    plt.xticks(ticks=range(len(positions)), labels=positions, rotation=90, ha='center')
    # Image row y spans y-0.5 .. y+0.5 and shows the bin whose lower edge is
    # (bins-1-y) * max_norm / bins, so a log-norm v sits at
    # y = bins - 0.5 - v * bins / max_norm. Ticks and labels are derived from the
    # same values so they cannot drift apart.
    tick_log_norms = np.linspace(0, max_norm, n_yticks)
    plt.yticks(ticks=bins - 0.5 - tick_log_norms * bins / max_norm, labels=[f'{v:.2e}' for v in np.exp(tick_log_norms)])

plt.tight_layout()
plt.show()

In [None]:
# simple histogram plot: one row per extract position, one column per norm type

bins = 30

# (norm order, display name, extra value); the third entry is not used by this plot.
norm_types = [(1, 'L1', 11), (2, 'L2', 8), (np.inf, r'L$\infty$', 6)]
# squeeze=False keeps `axes` two-dimensional even when there is only a single
# extract position, so the axes[j, i] lookup below always works.
fig, axes = plt.subplots(len(sd.available_extract_positions), len(norm_types), figsize=(len(norm_types)*5, len(sd.available_extract_positions)*5), squeeze=False)

for i, (norm_type, norm_name, _) in enumerate(norm_types):
    for j, p in enumerate(sd.available_extract_positions):
        # All norm values of all images for this position, one row per image.
        # Cast to float32 so that half-precision representations can be binned too.
        norms = torch.stack([torch.linalg.norm(r[p][0], ord=norm_type, dim=0).flatten() for r in reprs]).float()
        ax = axes[j, i]

        # Use torch.histogram instead of plt.hist; bins span the observed value range.
        hist = torch.histogram(norms, bins=bins, range=(norms.min().item(), norms.max().item()))
        ax.bar(hist.bin_edges[:-1], hist.hist, width=hist.bin_edges[1]-hist.bin_edges[0], edgecolor='black', align='edge')

        ax.set_yscale('log')
        # ax.set_xscale('log')

        # Add row/column labels (extract positions / norm types)
        if i == 0: ax.set_ylabel(f'{p}\n(Norm Frequency)', fontsize=12)
        if j == 0: ax.set_title(f'{norm_name} Norm', fontsize=12)

plt.tight_layout()
plt.show()
