In [1]:
import numpy as np
from sdhelper import SD
from PIL import Image
import torch
from tqdm.autonotebook import tqdm, trange
import matplotlib.pyplot as plt
from datasets import load_dataset
import torch
import random

# Allow TF32 / reduced-precision float32 matmuls; improves speed and silences the torch.compile warning.
torch.set_float32_matmul_precision(precision='high')


In [10]:
# Dataset to analyze plus precomputed high-norm anomaly positions for it.
# Each anomaly array holds one row per anomaly: (img_idx, w_idx, h_idx); img_idx indexes the
# images of `data` (representations are extracted from `data` in the same order below).
# The file names encode "step50_seed42", which presumably must match the step/seed passed to
# sd.img2repr further down — TODO confirm when changing either.
# Alternative dataset (NYUv2):
#   data = load_dataset("0jl/NYUv2", trust_remote_code=True, split="train")
#   up1_anomalies = np.load("../other/data_labeler/high_norm_anomalies_nyuv2_norm_step50_seed42.npy")
ANOMALY_DIR = "../data/data_labeler"

data = load_dataset("JonasLoos/imagenet_subset", split="train")
up1_anomalies = np.load(f"{ANOMALY_DIR}/high_norm_anomalies_imagenet_subset_step50_seed42_heavy_only.npy")
convin_anomalies = np.load(f"{ANOMALY_DIR}/high_norm_anomalies_imagenet_subset_step50_seed42_conv_in.npy")

# Fail fast here instead of with an opaque unpacking error in the norm computation.
for _name, _arr in (("up1_anomalies", up1_anomalies), ("convin_anomalies", convin_anomalies)):
    assert _arr.size == 0 or (_arr.ndim == 2 and _arr.shape[1] == 3), \
        f"{_name}: expected rows of (img_idx, w_idx, h_idx), got shape {_arr.shape}"


In [11]:
# UNet submodules whose outputs are extracted, grouped by top-level UNet block (one inner list
# per group). The plotting cell uses this grouping for the group separators and labels on the x-axis.
# Names follow the SD1.x/2.x UNet layout: diffusers' UNet2DConditionModel builds every up block
# with layers_per_block + 1 (= 3) resnets, so up_blocks[0].resnets[2] is listed as well.
all_blocks_separated = [
    [
        'conv_in',
    ],
    [
        'down_blocks[0].resnets[0]',
        'down_blocks[0].attentions[0]',
        'down_blocks[0].resnets[1]',
        'down_blocks[0].attentions[1]',
        'down_blocks[0].downsamplers[0]',
    ],
    [
        'down_blocks[1].resnets[0]',
        'down_blocks[1].attentions[0]',
        'down_blocks[1].resnets[1]',
        'down_blocks[1].attentions[1]',
        'down_blocks[1].downsamplers[0]',
    ],
    [
        'down_blocks[2].resnets[0]',
        'down_blocks[2].attentions[0]',
        'down_blocks[2].resnets[1]',
        'down_blocks[2].attentions[1]',
        'down_blocks[2].downsamplers[0]',
    ],
    [
        'down_blocks[3].resnets[0]',
        'down_blocks[3].resnets[1]',
    ],
    [
        'mid_block.resnets[0]',
        'mid_block.attentions[0]',
        'mid_block.resnets[1]',
    ],
    [
        'up_blocks[0].resnets[0]',
        'up_blocks[0].resnets[1]',
        'up_blocks[0].resnets[2]',
        'up_blocks[0].upsamplers[0]',
    ],
    [
        'up_blocks[1].resnets[0]',
        'up_blocks[1].attentions[0]',
        'up_blocks[1].resnets[1]',
        'up_blocks[1].attentions[1]',
        'up_blocks[1].resnets[2]',
        'up_blocks[1].attentions[2]',
        'up_blocks[1].upsamplers[0]',
    ],
    [
        'up_blocks[2].resnets[0]',
        'up_blocks[2].attentions[0]',
        'up_blocks[2].resnets[1]',
        'up_blocks[2].attentions[1]',
        'up_blocks[2].resnets[2]',
        'up_blocks[2].attentions[2]',
        'up_blocks[2].upsamplers[0]',
    ],
    [
        'up_blocks[3].resnets[0]',
        'up_blocks[3].attentions[0]',
        'up_blocks[3].resnets[1]',
        'up_blocks[3].attentions[1]',
        'up_blocks[3].resnets[2]',
        'up_blocks[3].attentions[2]',
    ],
    [
        'conv_out',
    ],
]
# Flat list of all block names, in the order listed above.
all_blocks = [b for blocks_list in all_blocks_separated for b in blocks_list]

In [None]:
# Stable Diffusion helper used for feature extraction (sdhelper's default configuration).
# NOTE(review): which checkpoint SD() loads is not visible here — confirm it has the SD1.x/2.x
# UNet layout that the block names in `all_blocks` assume.
sd = SD()

In [None]:
# Blocks to extract. The x-axis tick labels in the plotting cell use `blocks`, while the group
# separators/labels use `all_blocks_separated`, so keep this equal to `all_blocks` unless the
# plot grouping is adapted too.
blocks = all_blocks

# step/seed passed to sd.img2repr. The anomaly files loaded above are named "..._step50_seed42_...",
# so these presumably must stay in sync with them — TODO confirm when changing.
EXTRACT_STEP = 50
EXTRACT_SEED = 42

# One entry per image of `data` (same order), each indexable by block name.
representations_raw = sd.img2repr(
    [x['image'] for x in data],
    extract_positions=blocks,
    step=EXTRACT_STEP,
    seed=EXTRACT_SEED,
)

In [None]:
def _norm_map(img_idx, block):
    """Per-position L2 norm over the channel dim of one stored representation.

    Takes `representations_raw[img_idx][block]`, drops the leading dim of size 1, casts to
    float32 and returns `.norm(dim=0)`, i.e. an (h, w) tensor for a (C, h, w) representation.
    """
    feats = representations_raw[img_idx][block].squeeze(0).to(dtype=torch.float32)
    assert feats.dim() == 3, f"{block}: expected (C, h, w) after squeeze, got {tuple(feats.shape)}"
    return feats.norm(dim=0)


def _anomaly_relative_norms(anomalies, ref_block, patch_size=None):
    """Norm at each anomaly position divided by the image's mean norm, for every block in `blocks`.

    `anomalies` rows are (img_idx, w_idx, h_idx), given in the spatial resolution of `ref_block`.
    Positions are rescaled to each block by the ratio of feature-map heights (the same factor is
    applied to both axes, so square maps are assumed). With `patch_size=None` the single rescaled
    position is used; otherwise the mean norm over a square patch of side
    max(int(patch_size * scale), 1) starting at that position (clipped at the border by slicing).

    Returns a float array of shape (len(blocks), len(anomalies)).
    """
    rel_norms = np.zeros((len(blocks), len(anomalies)))
    # Loop-invariant: reference height taken from image 0, so all images are assumed to share a size.
    ref_h = representations_raw[0][ref_block].shape[2]
    for j, (img_idx, w_idx, h_idx) in enumerate(tqdm(anomalies.tolist())):
        for i, block in enumerate(blocks):
            norms = _norm_map(img_idx, block)
            scale = norms.shape[0] / ref_h
            h_pos = int(h_idx * scale)
            w_pos = int(w_idx * scale)
            if patch_size is None:
                anomaly_norm = norms[h_pos, w_pos]
            else:
                side = max(int(patch_size * scale), 1)
                anomaly_norm = norms[h_pos:h_pos + side, w_pos:w_pos + side].mean()
            rel_norms[i, j] = (anomaly_norm / norms.mean()).item()
    return rel_norms


# up[1] anomalies are 2x2 patches at the resolution of up_blocks[1].upsamplers[0];
# conv-in anomalies are single positions at the resolution of conv_in.
up1_anomaly_norms = _anomaly_relative_norms(up1_anomalies, 'up_blocks[1].upsamplers[0]', patch_size=2)
convin_anomaly_norms = _anomaly_relative_norms(convin_anomalies, 'conv_in')

# Per-image statistics for every block: absolute mean norm, and corner/border norms relative to it.
n_images = len(representations_raw)
corner_norms = np.zeros((len(blocks), n_images))
border_norms = np.zeros((len(blocks), n_images))
mean_norms = np.zeros((len(blocks), n_images))

for j in trange(n_images):
    for i, block in enumerate(blocks):
        norms = _norm_map(j, block)
        mean_norm = norms.mean()
        mean_norms[i, j] = mean_norm.item()

        # Top-left 2x2 patch only. NOTE(review): not all four corners — confirm this is intended.
        corner_norms[i, j] = (norms[[0, 0, 1, 1], [0, 1, 0, 1]].mean() / mean_norm).item()

        # First/last row and column; the four corner positions are counted twice by the concatenation.
        edge = torch.cat([norms[0, :], norms[-1, :], norms[:, 0], norms[:, -1]])
        border_norms[i, j] = (edge.mean() / mean_norm).item()

In [None]:

def _plot_mean_std(ax, values, label):
    """Plot the per-block mean of `values` (shape (n_blocks, n_samples)) with a +/-1 std band."""
    positions = range(values.shape[0])
    mean = values.mean(axis=1)
    std = values.std(axis=1)
    ax.fill_between(positions, mean - std, mean + std, alpha=0.2)
    ax.plot(positions, mean, label=label)


def _layer_kind(block):
    """Short x-tick label for a block name, based on the first matching submodule type."""
    for key, label in (('attentions', 'attn'), ('resnets', 'res'), ('downsamplers', 'down'),
                       ('upsamplers', 'up'), ('conv', 'conv')):
        if key in block:
            return label
    return '?'


def _group_name(block_list):
    """Label for a group of blocks: 'mid', 'in'/'out' for conv_in/conv_out, else e.g. 'down0'/'up2'."""
    first = block_list[0]
    if 'mid' in first:
        return 'mid'
    if 'conv' in first:
        return first[5:]  # 'conv_in' -> 'in', 'conv_out' -> 'out'
    prefix, index, *_ = first.split('[')
    return prefix.replace('_blocks', '') + index.split(']')[0]  # 'down_blocks[0]...' -> 'down0'


# Tick labels come from `blocks`, group separators from `all_blocks_separated`; they must line up.
assert len(blocks) == sum(len(g) for g in all_blocks_separated), \
    "blocks does not match all_blocks_separated; the x-axis grouping would be misaligned"

fig, ax1 = plt.subplots(figsize=(10,4))

_plot_mean_std(ax1, up1_anomaly_norms, 'up[1] anomalies')
_plot_mean_std(ax1, convin_anomaly_norms, 'conv-in anomalies')
# corner_norms / border_norms (also relative to the mean norm) can be added the same way, e.g.
# _plot_mean_std(ax1, corner_norms, 'corner')

# Reference: every value is relative to the image mean, so "all" is the constant 1.
ax1.plot(range(len(blocks)), np.ones(len(blocks)), color='black', linestyle='--', label='all')

ax1.set_title("Relative norm of anomaly over layers")
ax1.set_ylabel("mean norm relative to all")

# Per-layer x ticks.
x = np.arange(len(blocks))
ticks = [_layer_kind(block) for block in blocks]
ax1.set_xticks(x)
ax1.set_xticklabels(ticks, rotation=90)

# Group names and the index of each group's first block.
main_blocks = [_group_name(block_list) for block_list in all_blocks_separated]
main_block_positions = []
pos = 0
for block_list in all_blocks_separated:
    main_block_positions.append(pos)
    pos += len(block_list)

# Separator lines between groups (one color only; previously both color= and its alias c= were passed).
for p in main_block_positions[1:]:
    ax1.axvline(x=p-0.5, linestyle='--', color='lightgray')
ax_x3 = ax1.secondary_xaxis(location=0)
ax_x3.set_xticks([p-0.5 for p in main_block_positions[1:]], labels=[])
ax_x3.tick_params(axis='x', length=34, width=1.5, color='lightgray')

# Group labels centered under each group, pushed below the per-layer labels by newlines.
ax_x2 = ax1.secondary_xaxis(location=0)
ax_x2.set_xticks([p+len(bl)/2-0.5 for p, bl in zip(main_block_positions, all_blocks_separated)], labels=[f'\n\n\n{b}' for b in main_blocks], ha='center')
ax_x2.tick_params(length=0)

ax1.legend()
ax1.set_yscale('log')
y_ticks = [0.5, 1.0, 2.0, 4.0]
ax1.set_yticks(y_ticks, labels=[f'{t:.1f}' for t in y_ticks])
plt.tight_layout()
plt.show()

# Show the conv_in norm map of one conv-in anomaly image with the anomaly position marked in red.
# A fixed-seed RNG keeps the chosen example identical across re-runs.
example_rng = random.Random(42)
img_idx, w_idx, h_idx = example_rng.choice(convin_anomalies).tolist()
convin_norm_map = representations_raw[img_idx]['conv_in'].squeeze(0).to(dtype=torch.float32).norm(dim=0)

example_fig, example_ax = plt.subplots()
example_ax.imshow(convin_norm_map.cpu().numpy())
example_ax.scatter(w_idx, h_idx, color='red')
example_ax.set_title(f'conv_in norm map, image {img_idx} (anomaly at h={h_idx}, w={w_idx})')
example_ax.axis('off')
plt.show()