# Finding Anomalies by Looking at Representation Norms

In [None]:
from sdhelper import SD
from pathlib import Path
from PIL import Image
import torch
import matplotlib.pyplot as plt
from tqdm.autonotebook import tqdm, trange
from matplotlib.colors import LogNorm
import numpy as np
from datasets import load_dataset
from collections import Counter, defaultdict


In [None]:
# Helper object used throughout this notebook: `sd.img2repr` extracts the representations
# and `sd.available_extract_positions` lists the positions they can be taken from.
# NOTE(review): the default constructor arguments decide which model is loaded — confirm in sdhelper.
sd = SD()

In [None]:
# Images to analyse: the train split of the `JonasLoos/imagenet_subset` dataset.
# Alternative source (a local folder of generated images):
# dataset_path = Path('../random_images_flux/')
# dataset = [Image.open(p) for p in sorted(dataset_path.glob('*.jpg'))]
dataset = [
    sample['image']
    for sample in load_dataset('JonasLoos/imagenet_subset', split='train')
]
len(dataset)

In [None]:
# Extract representations at every available position, once per noise step (0, 50, 200).
# The three extractions run in this order and are unpacked into one variable each.
representations_noise_0, representations_noise_50, representations_noise_200 = (
    sd.img2repr(dataset, extract_positions=sd.available_extract_positions, step=step)
    for step in (0, 50, 200)
)

## Highest Norms

In [None]:
pos = 'up_blocks[1]'
top_n = 20
n_cols = 5
representations = representations_noise_50

# Per-token L2 norm over the channel dimension: (num_images, H, W)
# (assuming each r[pos] is (1, C, H, W) — squeeze(0) drops the leading dim)
norms = torch.stack([r[pos].squeeze(0) for r in representations]).norm(dim=1)
_, h, w = norms.shape

# Flat indices (over image, row, column) of the top_n tokens with the highest norms
top_n_indices = norms.flatten().argsort(descending=True)[:top_n]

# Plot the images containing these tokens and mark each token's location.
# squeeze=False keeps a 2-D array of axes even for a 1x1 grid, so .flatten() always works.
n_rows = int(np.ceil(top_n / n_cols))
fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols*2, n_rows*2+1), squeeze=False)
axes = axes.flatten()
fig.suptitle(f"Top {top_n} images with highest norms in {pos}")

for ax, idx in zip(axes, top_n_indices.tolist()):
    # Decompose the flat index using H and W separately so non-square maps also work.
    img_idx, token_idx = divmod(idx, h * w)
    row, col = divmod(token_idx, w)

    ax.imshow(dataset[img_idx], extent=(0, 1, 0, 1))
    ax.set_title(f"{img_idx} - {norms[img_idx, row, col].item():.2f}")
    # The image spans [0, 1] x [0, 1] with row 0 at the top: mark the centre of the token's patch.
    ax.plot((col + 0.5) / w, 1 - (row + 0.5) / h,
            'ro', markersize=10, markeredgecolor='white')

# Hide the frames of all axes, including unused ones when top_n is not a multiple of n_cols.
for ax in axes:
    ax.axis('off')

plt.tight_layout()
plt.show()


In [None]:
# plot high norm spot sizes
# For every image, start at its highest-norm token and grow the 4-connected region of tokens
# whose norm is at least `norm_threshold * max_norm`. Then plot a histogram of region sizes,
# one set of bars per threshold.
# NOTE: `images` (spot size -> list of image indices) is rebuilt for every threshold. After this
# cell it therefore holds the grouping for the LAST entry of `norm_thresholds`.

pos = 'up_blocks[1]'
representations = representations_noise_50
norm_thresholds = [0.7, 0.8, 0.9, 0.95, 0.98]
device = 'cuda' if torch.cuda.is_available() else 'cpu'

plt.figure(figsize=(10, 5))
reprs_tmp = torch.stack([r[pos].squeeze(0) for r in representations]).to(device)
norms = reprs_tmp.norm(dim=1).cpu()  # (num_images, H, W)
num_images, h, w = norms.shape

# get all connected components of high norm tokens
for norm_threshold in norm_thresholds:
    counter = Counter()
    images = defaultdict(list)
    for i in trange(num_images):
        norm_map = norms[i]
        start = divmod(norm_map.argmax().item(), w)  # (row, col) of the highest-norm token
        cutoff = norm_map.max() * norm_threshold
        size = 0
        stack = [start]
        visited = {start}
        while stack:
            k, l = stack.pop()
            # below-threshold tokens are not part of the spot and are not expanded
            if norm_map[k, l] < cutoff:
                continue
            size += 1
            for k2, l2 in ((k - 1, l), (k + 1, l), (k, l - 1), (k, l + 1)):
                if 0 <= k2 < h and 0 <= l2 < w and (k2, l2) not in visited:
                    visited.add((k2, l2))
                    stack.append((k2, l2))
        counter[size] += 1
        images[size].append(i)
    plt.bar(list(counter.keys()), list(counter.values()), alpha=0.6, label=f'norm threshold {norm_threshold}')
    # plt.yscale('log')
plt.xlim(0, 50)
# several thresholds are plotted, so the title does not name a single one
plt.title(f'High norm spot sizes for {pos}')
plt.xlabel('Spot size')
plt.ylabel('Frequency')
plt.legend()
plt.show()


In [None]:
import random

# For every spot size found by the spot-size cell, show one randomly chosen image with that
# spot size next to its token-norm map.
# NOTE: uses `images` from that cell, which holds the grouping for its last norm threshold.
# It also needs `norms` to be the (num_images, H, W) tensor from that cell. The later
# scatter-plot cells overwrite `norms` with a flattened (num_images, H*W) tensor.
rng = random.Random(0)  # fixed seed so the chosen example images are reproducible
for spot_size, image_indices in sorted(images.items()):
    fig, axs = plt.subplots(1, 3, figsize=(2*2+1, 2+1), width_ratios=[1, 1, 0.1])
    i = rng.choice(image_indices)
    fig.suptitle(f'High norm spot of size {spot_size} ({len(image_indices)} images)', fontsize=16)
    axs[0].imshow(dataset[i])
    axs[0].axis('off')
    im = axs[1].imshow(norms[i])
    axs[1].axis('off')

    # Make colorbar as high as imshow
    cbar = fig.colorbar(im, cax=axs[2], label='Norm')
    cbar.ax.set_box_aspect(10)

    plt.tight_layout()
    plt.show()


In [None]:
# norm-norm scatter plot
# For a token, find the most cosine-similar token in any *other* image. Then compare the norms
# of the two tokens, each relative to the maximum token norm within its own image.
# First series: randomly sampled tokens. Second series: the highest-norm token of every image.

pos = 'up_blocks[1]'
num_samples = 10000
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def most_similar_in_other_images(token, img_idx, reprs):
    """Return the flat index of the token in `reprs` most cosine-similar to `token`.

    `reprs` is (num_images, num_tokens, channels). Every token of image `img_idx` is excluded
    by masking its similarities to -inf. The returned index therefore refers to
    `reprs.flatten(0, 1)` as a whole (not to a subset with the image removed). It can be used
    directly on per-token data flattened in the same order.
    """
    sims = torch.cosine_similarity(token[None, :], reprs.flatten(0, 1), dim=1)
    sims.view(reprs.shape[0], reprs.shape[1])[img_idx] = float('-inf')
    return sims.argmax().item()


axs = plt.subplots(1, 2, figsize=(8, 5), width_ratios=[2, 1])[1]
representations = representations_noise_50

# one row per spatial token: (num_images, num_tokens, channels)
reprs_tmp = torch.stack([r[pos].squeeze(0).flatten(start_dim=1).T for r in representations]).to(device)
num_images, num_tokens = reprs_tmp.shape[:2]
norms = reprs_tmp.norm(dim=2).cpu()
# each token's norm divided by the max token norm of its image, flattened like reprs_tmp.flatten(0, 1)
relative_norms = (norms / norms.max(dim=1, keepdim=True).values).flatten()

# Plot (selection of) all norms
x, y = [], []
for flat_idx in tqdm(torch.randint(num_images * num_tokens, (num_samples,)).tolist()):
    img_idx, token_idx = divmod(flat_idx, num_tokens)
    match = most_similar_in_other_images(reprs_tmp[img_idx, token_idx], img_idx, reprs_tmp)
    x.append(relative_norms[flat_idx].item())
    y.append(relative_norms[match].item())

axs[0].scatter(x, y, alpha=0.25, s=2)
axs[1].hist(y, bins=50, orientation='horizontal', density=True)


# Plot highest norms (relative norm of the first token is 1 by definition)
x, y = [], []
for img_idx in trange(num_images):
    token_idx = norms[img_idx].argmax().item()
    match = most_similar_in_other_images(reprs_tmp[img_idx, token_idx], img_idx, reprs_tmp)
    x.append(1)
    y.append(relative_norms[match].item())

axs[0].scatter(x, y, alpha=0.25, s=2)
axs[1].hist(y, bins=50, orientation='horizontal', density=True, alpha=0.5)


# configure plot
axs[0].set_xlabel('Relative Norm of first')
axs[0].set_ylabel('Relative Norm of second')
# axs[0].set_yscale('log')
# axs[0].set_xscale('log')
axs[0].set_ylim(0, 1.05)
axs[0].set_xlim(0, 1.05)
axs[0].set_title('Norms of tokens and their most similar matches')
axs[0].legend(['all', 'highest norm only'])

axs[1].set_xlabel('Relative Frequency')
axs[1].set_yticks([])
axs[1].set_ylim(0, 1.05)
axs[1].set_title('Distribution of norms')

plt.tight_layout()
plt.show()



In [None]:
# plot norms of tokens that match with highest norm tokens
# For the highest-norm token of every image, find the most cosine-similar token in any *other*
# image and scatter the absolute norms of the two tokens on log-log axes.

pos = 'up_blocks[1]'
representations = representations_noise_50


def most_similar_in_other_images(token, img_idx, reprs):
    """Return the flat index of the token in `reprs` most cosine-similar to `token`.

    `reprs` is (num_images, num_tokens, channels). Every token of image `img_idx` is excluded
    by masking its similarities to -inf. The returned index therefore refers to
    `reprs.flatten(0, 1)` as a whole (not to a subset with the image removed). It can be used
    directly on per-token data flattened in the same order.
    """
    sims = torch.cosine_similarity(token[None, :], reprs.flatten(0, 1), dim=1)
    sims.view(reprs.shape[0], reprs.shape[1])[img_idx] = float('-inf')
    return sims.argmax().item()


# one row per spatial token: (num_images, num_tokens, channels)
reprs_tmp = torch.stack([r[pos].squeeze(0).flatten(start_dim=1).T for r in representations])
norms = reprs_tmp.norm(dim=2).cpu()
norms_flat = norms.flatten()  # same order as reprs_tmp.flatten(0, 1)

x, y = [], []
for img_idx in trange(len(reprs_tmp)):
    token_idx = norms[img_idx].argmax().item()
    match = most_similar_in_other_images(reprs_tmp[img_idx, token_idx], img_idx, reprs_tmp)
    x.append(norms[img_idx, token_idx].item())
    y.append(norms_flat[match].item())

plt.figure(figsize=(5, 5))
plt.scatter(x, y)
plt.xlabel('Norm of first')
plt.ylabel('Norm of second')
plt.yscale('log')
plt.xscale('log')
plt.show()


# Norms vs. Cosine Similarity

In [None]:
# For every extraction position and each of the three noise levels, accumulate a 2-D histogram of
#   x: cosine similarity between a token of one image and a token of another image,
#   y: average norm of those two tokens.
# Image pairs are randomly subsampled at some positions (see below) to keep the run time manageable.
n = 100  # number of histogram bins per axis

histograms: dict[str, list[torch.Tensor|None]] = {pos: [None, None, None] for pos in sd.available_extract_positions}
for pos in tqdm(sd.available_extract_positions):
    for i, representations in enumerate([representations_noise_0, representations_noise_50, representations_noise_200]):
        # calculate norms
        # per-token norms, one row per image and one column per token
        # (assuming each r[pos] is (1, C, H, W) — TODO confirm)
        norms = torch.stack([r[pos].squeeze(0).flatten(start_dim=1).T for r in representations]).norm(dim=2).cpu()
        # cosine-similarity axis range; NOTE(review): similarities below -0.5 lie outside it and
        # are presumably not counted by histogramdd — confirm this is intended
        x_min, x_max = -0.5, 1.0
        # norm axis range: smallest to largest token norm at this position and noise level
        y_min, y_max = norms.min().item(), norms.max().item()

        # calculate histogram
        # do this in a loop to avoid memory issues
        histogram = torch.zeros(n, n)
        for r1, n1 in zip(tqdm(representations), norms):
            # all ordered image pairs, including each image paired with itself
            for r2, n2 in zip(representations, norms):
                # random subsampling of image pairs: keep ~1% for conv_in / up_blocks[2] / up_blocks[3] /
                # conv_out, ~10% for down_blocks[0] / up_blocks[1], and every pair for all other positions
                if (np.random.random() > 0.01 and pos in ['conv_in', 'up_blocks[2]', 'up_blocks[3]', 'conv_out']) or (np.random.random() > 0.1 and pos in ['down_blocks[0]', 'up_blocks[1]']):
                    continue
                # TODO: correct normalization of histogram
                # NOTE(review): element [a, b] of (n1[None,:]+n2[:,None]) is n2[a] + n1[b]. This lines up
                # with the similarity matrix only if `r1.at(pos).cosine_similarity(r2.at(pos))` is
                # indexed [token of r2, token of r1] — confirm against the sdhelper API.
                data = torch.stack([
                    r1.at(pos).cosine_similarity(r2.at(pos)).flatten().cpu(),
                    (n1[None,:]+n2[:,None]).flatten()/2
                ], dim=1)
                # `.hist` is indexed [x bin, y bin]; transpose to [y bin, x bin] and flip so the
                # highest-norm bins come first (top row when displayed with imshow)
                histogram += torch.histogramdd(data, bins=n, range=(x_min, x_max, y_min, y_max)).hist.T.flip(0)
        histograms[pos][i] = histogram

In [None]:
# Plot the (cosine similarity, average norm) histograms built by the histogram cell: one figure
# per extraction position, one panel per noise level, and a shared log-scale colorbar.
# Each histogram is normalized to sum to 1 before plotting.
noise_settings = [(representations_noise_0, 0), (representations_noise_50, 50), (representations_noise_200, 200)]
for pos in sd.available_extract_positions:

    plt.figure(figsize=(15, 4))  # Increased width to accommodate colorbar
    for i, (representations, noise_level) in enumerate(noise_settings):
        histogram = histograms[pos][i]
        # skip missing histograms before the (expensive) norm computation below
        if histogram is None:
            continue

        # Recompute the axis ranges exactly as when the histogram was built, so the imshow
        # extent matches the histogram's bin edges.
        norms = torch.stack([r[pos].squeeze(0).flatten(start_dim=1).T for r in representations]).norm(dim=2).cpu()
        x_min, x_max = -0.5, 1.0
        y_min, y_max = norms.min().item(), norms.max().item()

        # plot histogram
        plt.subplot(1, 3, i+1)
        plt.imshow(histogram/histogram.sum(), norm=LogNorm(vmin=1e-12, vmax=1e-2), cmap='YlOrRd', extent=(x_min, x_max, y_min, y_max), aspect='auto')
        plt.ylabel('Average Norm')
        plt.xlabel('Cosine Similarity')
        plt.title(f'{pos} - Noise: {noise_level}')
    # shared colorbar with the same log scale as every panel
    sm = plt.cm.ScalarMappable(cmap='YlOrRd', norm=LogNorm(vmin=1e-12, vmax=1e-2))
    sm.set_array([])
    fig = plt.gcf()
    cbar_ax = fig.add_axes([0.92, 0.15, 0.015, 0.75])  # [left, bottom, width, height]
    cbar = plt.colorbar(sm, cax=cbar_ax)
    cbar.set_label('Normalized Frequency')
    plt.tight_layout(rect=[0, 0, 0.9, 1])  # Adjust layout to make room for colorbar
    plt.show()

In [None]:
# torch.save(torch.stack([torch.stack(h) for h in histograms.values()]), '../histograms_norm_cossim_sd15.pt')

In [None]:
# Record the (min, max) token norm for every extraction position and noise level: these are the
# y ranges over which the norm-vs-cosine-similarity histograms are binned.
noise_runs = (representations_noise_0, representations_noise_50, representations_noise_200)
y_limits = torch.zeros(len(sd.available_extract_positions), len(noise_runs), 2)
for pos_i, pos in enumerate(sd.available_extract_positions):
    for i, representations in enumerate(noise_runs):
        norms = torch.stack([r[pos].squeeze(0).flatten(start_dim=1).T for r in representations]).norm(dim=2).cpu()
        y_min, y_max = norms.min().item(), norms.max().item()
        y_limits[pos_i, i] = torch.tensor([y_min, y_max])
# torch.save(y_limits, '../y_limits_norm_cossim_sd15.pt')