In [1]:
import numpy as np
import matplotlib.pyplot as plt
from sdhelper import SD
from PIL import Image
from datasets import load_dataset
import torch
from tqdm.autonotebook import tqdm, trange

In [2]:
# Seed passed to the representation extraction further down; the same value is
# encoded in the file name of the anomaly table.
seed = 42

# Pre-computed table of high-norm anomalies, loaded with np.load below.
data_path = "high_norm_anomalies_nyuv2_norm_step50_seed42.npy"


In [None]:
# Load the training split of NYUv2.
# NOTE(review): trust_remote_code=True executes the dataset's loading script
# from the hub — confirm the source is trusted.
dataset = load_dataset(
    "0jl/NYUv2",
    split="train",
    trust_remote_code=True,
)

# Show the first image as a quick sanity check.
dataset[0]["image"]

In [None]:
# Anomaly table: the cells below unpack each row as (image index, x, y).
data = np.load(data_path)

# Peek at the first five rows.
data[0:5]

In [None]:
# Number of anomalies per image. Column 0 of `data` is the dataset index of the
# image the anomaly belongs to; np.bincount replaces the per-row Python loop.
counts = np.bincount(data[:, 0].astype(np.intp), minlength=len(dataset))
assert len(counts) == len(dataset), "anomaly table references an image index outside the dataset"

# bars[k] = number of images that contain exactly k anomalies.
bars = np.bincount(counts)

plt.bar(np.arange(len(bars)), bars)
plt.title("Number of anomalies per image")
# plt.yscale("log")
plt.xlabel("Number of anomalies")
plt.ylabel("Number of images")
plt.show()

# bars[0] is the number of images without any anomaly.
print(f'{1-bars[0]/len(dataset):.2%} of images have anomalies')

## Anomaly similarities

In [None]:
sd = SD()

# Extract the 'up_blocks[1]' representation of every dataset image.
images = [sample['image'] for sample in dataset]
raw_reprs = sd.img2repr(images, extract_positions=['up_blocks[1]'], step=50, seed=seed)

# Stack the per-image results into one float32 tensor; squeeze(0) drops the
# leading singleton dimension of each entry.
per_image_reprs = [raw['up_blocks[1]'].squeeze(0) for raw in raw_reprs]
representations = torch.stack(per_image_reprs).float()

# Token norms, taken over dim 1 of the stacked tensor.
norms = torch.linalg.norm(representations, dim=1)

In [None]:
# Norm statistics of the selected anomaly tokens, of the 2x2 patches they anchor,
# and of all tokens.
# Columns of `data` are (image index, x, y); tensors are indexed [image, (channel,) y, x].
img_idx, col_idx, row_idx = data[:, 0], data[:, 1], data[:, 2]

norms_of_selected = norms[img_idx, row_idx, col_idx]
print(f'norms of selected: min {norms_of_selected.min():.2f}, max {norms_of_selected.max():.2f}, mean {norms_of_selected.mean():.2f}')

# Representations of the 2x2 patch whose top-left token is the selected anomaly.
# The four offsets are concatenated along dim 0, so the result holds four rows
# per anomaly (all top-left tokens first, then top-right, bottom-left, bottom-right).
# NOTE(review): the +1 offsets assume no anomaly lies on the last row/column — confirm.
reprs_of_patches = torch.concat([
    representations[img_idx, :, row_idx+0, col_idx+0],
    representations[img_idx, :, row_idx+0, col_idx+1],
    representations[img_idx, :, row_idx+1, col_idx+0],
    representations[img_idx, :, row_idx+1, col_idx+1],
], dim=0)
norms_of_patches = torch.linalg.norm(reprs_of_patches, dim=1)
# Label fixed: the patches are 2x2 (four tokens), not 4x4.
print(f'norms 2x2 patches: min {norms_of_patches.min():.2f}, max {norms_of_patches.max():.2f}, mean {norms_of_patches.mean():.2f}')
print(f'norms of all:      min {norms.min():.2f}, max {norms.max():.2f}, mean {norms.mean():.2f}')


In [49]:
# Persist the mean representation over all anomaly-patch tokens for reuse elsewhere.
mean_patch_repr = reprs_of_patches.mean(dim=0)
torch.save(mean_patch_repr, 'high_norm_anomalies_nyuv2_step50_seed42_reprs_of_patches_mean.pt')


In [None]:
# selected similarity
# Pairwise cosine similarity between all selected anomaly tokens. The mean below
# includes the diagonal (self-similarity of 1).
reprs_of_selected = representations[data[:, 0], :, data[:, 2], data[:, 1]]
similarities = torch.cosine_similarity(reprs_of_selected[:, None], reprs_of_selected[None, :], dim=2)
print(f'mean similarity of all selected: {similarities.mean():.4f}')


# all similarity (random subset of 1000)
# Draw the subset from ALL tokens of ALL images. The previous version permuted
# only `representations.shape[0]` (the number of images) indices, so the
# "random" tokens all came from the first few images.
num_images, _, height, width = representations.shape
num_tokens = num_images * height * width
# Seeded generator so the subset is reproducible across runs.
subset_generator = torch.Generator().manual_seed(seed)
subset_idx = torch.randperm(num_tokens, generator=subset_generator)[:1000]
all_reprs = representations.permute(0, 2, 3, 1).flatten(0, 2)[subset_idx]
similarities = torch.cosine_similarity(all_reprs[:, None], all_reprs[None, :], dim=2)
print(f'mean similarity of random subset of all: {similarities.mean():.4f}')

In [None]:
n = 50  # number of evenly spaced "top k%" steps evaluated in the plot below


def cosine_similarity(x):
    """Return the pairwise cosine-similarity matrix between the rows of ``x``.

    Args:
        x: 2-D tensor with one representation per row.

    Returns:
        Square tensor whose entry ``[i, j]`` is the cosine similarity between
        row ``i`` and row ``j``.
    """
    return torch.cosine_similarity(x[:, None], x[None, :], dim=2)


def euclidean_similarity(x):
    """Return a pairwise similarity in [0, 1] derived from Euclidean distance.

    The distance is the root-mean-square difference between two rows. It is
    divided by the largest pairwise distance and subtracted from 1, so the
    diagonal is 1 and the most distant pair gets 0.

    Args:
        x: 2-D floating point tensor with one representation per row.

    Returns:
        Square tensor of similarities. If every row is identical (largest
        distance is 0) all similarities are 1 instead of NaN; an empty input
        yields an empty matrix.
    """
    distance = ((x[:, None, :] - x[None, :, :])**2).mean(dim=-1)**.5
    if distance.numel() == 0:
        # No rows: nothing to normalise, and .max() would raise on an empty tensor.
        return distance
    max_distance = distance.max()
    if max_distance == 0:
        # All rows coincide; avoid 0/0 and report them as maximally similar.
        return torch.ones_like(distance)
    return 1 - distance / max_distance


# Measure used by the plots below; switch to euclidean_similarity to compare.
similarity_measure = cosine_similarity

def _mean_top_k_similarity(reprs_sorted, measure, n_steps):
    """Mean pairwise similarity among the first k rows, for growing k.

    Args:
        reprs_sorted: 2-D tensor of representations, one row per token, already
            ordered so that the rows of interest come first.
        measure: callable mapping the rows to a square pairwise similarity
            matrix whose diagonal is 1 (self-similarity).
        n_steps: number of evenly spaced prefixes to evaluate; step ``i`` uses
            the first ``int(len(reprs_sorted) * (i + 1) / n_steps)`` rows.

    Returns:
        ``(similarity_matrix, means)`` where ``means`` is a NumPy array of
        length ``n_steps``. Entries whose prefix has fewer than two rows are
        NaN because there is no off-diagonal pair to average.
    """
    similarity_matrix = measure(reprs_sorted)
    total = len(reprs_sorted)
    means = np.full(n_steps, np.nan)
    for step in range(n_steps):
        top_k = int(total * (step + 1) / n_steps)
        if top_k < 2:
            continue  # would divide by zero: no pairs besides the diagonal
        # normalize while accounting for self-similarity (1s on diagonal)
        off_diagonal_sum = similarity_matrix[:top_k, :top_k].sum() - top_k
        means[step] = float(off_diagonal_sum) / (top_k * (top_k - 1))
    return similarity_matrix, means


# Step i covers the top (i+1)/n of the rows, so the x positions are
# 100/n, 200/n, ..., 100 percent (np.linspace(1, 100, n) did not match these).
top_percent = 100 * np.arange(1, n + 1) / n

# single tokens
# Every position of the 2x2 patch is ordered by the norm of the selected
# (top-left) token, so the curves are comparable anomaly by anomaly.
selected_order = torch.argsort(norms_of_selected, descending=True)
for pos_name, (d_row, d_col) in {'top-left': (0, 0), 'top-right': (0, 1), 'bottom-left': (1, 0), 'bottom-right': (1, 1)}.items():
    tmp_reprs = representations[data[:, 0], :, data[:, 2]+d_row, data[:, 1]+d_col]
    tmp_reprs_sorted = tmp_reprs[selected_order]
    similarities_selected, mean_similarities_selected = _mean_top_k_similarity(tmp_reprs_sorted, similarity_measure, n)
    plt.plot(top_percent, mean_similarities_selected, label=pos_name)

# patches
reprs_of_patches_sorted = reprs_of_patches[torch.argsort(norms_of_patches, descending=True)]
similarities_patches, mean_similarities_patches = _mean_top_k_similarity(reprs_of_patches_sorted, similarity_measure, n)
plt.plot(top_percent, mean_similarities_patches, label="patches")

# all (subset)
all_reprs_sorted = all_reprs[torch.argsort(all_reprs.norm(dim=1), descending=True)]
similarities_all, mean_similarities_all = _mean_top_k_similarity(all_reprs_sorted, similarity_measure, n)
plt.plot(top_percent, mean_similarities_all, label="all (random subset)")

plt.title("Mean similarity of top k% by norm")
plt.xlabel("Top k% by norm")
plt.ylabel("Mean similarity")
plt.legend()
plt.show()


In [None]:
# Mean pairwise similarity over all patch tokens, excluding the diagonal
# (each token's similarity of 1 with itself).
num_patch_tokens = len(similarities_patches)
off_diagonal_total = similarities_patches.sum() - num_patch_tokens
off_diagonal_pairs = num_patch_tokens * (num_patch_tokens - 1)
print(f'Mean patch similarity: {off_diagonal_total / off_diagonal_pairs:.4f}')


In [None]:
# similarity to average repr
# For each mean anomaly representation, histogram the cosine similarity of
# three token groups to that mean.

bins = 100
sim_min, sim_max = -0.2, 1.0  # torch.histc ignores values outside this range
bin_width = (sim_max - sim_min) / bins
# Place the bars at the centres of the histc bins. np.linspace(sim_min, sim_max, bins)
# starts at the left edge and uses a slightly different spacing, so it drifts
# relative to the actual bins.
bin_centers = sim_min + bin_width * (np.arange(bins) + 0.5)

repr_means = {
    'patch': reprs_of_patches.mean(dim=0),
    'selected': reprs_of_selected.mean(dim=0),
}

# label -> (representations, colour); drawn in this order
token_groups = {
    'all (subset)': (all_reprs, 'blue'),
    'patches': (reprs_of_patches, 'purple'),
    'selected': (reprs_of_selected, 'orange'),
}

for repr_name, repr_mean in repr_means.items():
    plt.figure(figsize=(12, 6))

    histograms = {}
    for group_label, (group_reprs, group_color) in token_groups.items():
        similarities_to_mean = torch.cosine_similarity(repr_mean[None, :], group_reprs, dim=1)
        group_hist = torch.histc(similarities_to_mean, bins=bins, min=sim_min, max=sim_max).numpy()
        histograms[group_label] = group_hist
        plt.bar(bin_centers, group_hist, width=bin_width, label=group_label, alpha=0.6, color=group_color)

    # scatter plots for better visibility (drawn after all bars so they stay on top)
    for group_label, (_group_reprs, group_color) in token_groups.items():
        plt.scatter(bin_centers, histograms[group_label], c=group_color, s=5)

    plt.title(f"Cosine Similarity to average anomaly {repr_name} repr")
    plt.xlabel("Similarity")
    plt.ylabel("Count")
    plt.legend()
    plt.show()


## Other: norm distributions and anomaly positions


In [None]:
# Overlaid density histograms of token norms: all tokens, the 2x2 anomaly
# patches, and the selected anomaly tokens.
hist_kwargs = dict(bins=80, density=True, alpha=0.5, range=(0, 1200))
norm_groups = {
    'all': representations.norm(dim=1).flatten(),
    'patches': norms_of_patches,
    'selected': norms_of_selected,
}
for group_label, group_norms in norm_groups.items():
    plt.hist(group_norms, label=group_label, **hist_kwargs)

plt.title("Histogram of anomaly norms")
plt.xlabel("Token Norm")
plt.ylabel("Probability Density")
plt.legend()
plt.show()


In [None]:
# Norm map of the first image's representation (norm taken over dim 0 of that
# image's tensor).
first_image_norm_map = representations[0].norm(dim=0)
plt.imshow(first_image_norm_map)
plt.colorbar()
plt.show()

In [None]:
# Count how often each spatial position was selected as an anomaly.
heatmap = torch.zeros(representations.shape[2:])
for _img, col, row in data:
    # rows of `data` are (image index, x, y); the map is indexed [y, x]
    heatmap[row, col] += 1

plt.imshow(heatmap, cmap='hot', interpolation='nearest')
plt.colorbar()
plt.title("Heatmap of selected anomaly positions")
plt.show()