In [None]:
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import numpy as np
import torch

from autoseg.config import read_config
from autoseg.models import Model

In [None]:
config = read_config("autoseg/examples/unetr")

In [None]:
model = Model(config)
model.load()

In [None]:
print(f"Num params: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}")

In [None]:
from autoseg.datasets import GunpowderZarrDataset
from torch.utils.data import DataLoader
from autoseg.datasets.utils import multisample_collate as collate

In [None]:
dataset = GunpowderZarrDataset(
  config=config["pipeline"],
  input_image_shape=config["model"]["input_image_shape"],
  output_image_shape=config["model"]["output_image_shape"],
)

In [None]:
dataloader = DataLoader(
    dataset=dataset,
    collate_fn=collate,
    batch_size=config["training"]["train_dataloader"]["batch_size"],
    pin_memory=False,
)

In [None]:
dataloader_it = iter(dataloader)

In [None]:
sample_image = next(dataloader_it)

In [None]:
sample_image[0].shape

In [None]:
raw_img = sample_image[0][0][0][24]

In [None]:
d3_raw_img = np.sum(sample_image[0][0][0][16:32], axis=0) / 16
d3_raw_img.shape
plt.imshow(d3_raw_img, cmap="gray")

In [None]:
import matplotlib.pyplot as plt
plt.imshow(raw_img, cmap="gray")

In [None]:
raw = torch.tensor(sample_image[0])
raw = raw.to("cuda")
raw.shape

In [None]:
208*208*48/(16*16*16)

In [None]:
model = model.to("cuda")

In [None]:
layer = model.model.unet.transformer.layer[0]

In [None]:
layer

In [None]:
patch_embeddings = model.model.unet.transformer.embeddings.patch_embeddings(raw)

In [None]:
patch_embeddings.flatten(2).shape

In [None]:
patch_embeddings.shape

In [None]:
def to_np(t):
  """Return the tensor `t` as a NumPy array.

  The tensor is detached from the autograd graph and moved to the CPU
  before conversion.
  """
  detached = t.detach()
  on_cpu = detached.cpu()
  return on_cpu.numpy()

In [None]:
patch_embeddings = to_np(patch_embeddings)

In [None]:
13*16

In [None]:
# batch, emb_channel, z
plt.imshow(patch_embeddings[0,39,1])

In [None]:
def patch_to_original_size(patch, patch_size=16):
  """Upsample a 2D per-patch array to pixel resolution.

  Every element is repeated `patch_size` times along axis 0 and axis 1, so an
  (H, W) array becomes (H * patch_size, W * patch_size).

  Args:
    patch: NumPy array with at least two dimensions.
    patch_size: Edge length, in pixels, that one patch covers. Defaults to 16,
      the value previously hard-coded here.

  Returns:
    The nearest-neighbour upsampled array.
  """
  return patch.repeat(patch_size, axis=0).repeat(patch_size, axis=1)

In [None]:
from itertools import cycle

In [None]:
values = cycle([0,1])
overlay = np.array([[next(values) for i in range(13)] for j in range(13)])

In [None]:
plt.imshow(overlay)

In [None]:
#plt.imshow(patch_to_original_size(2*patch_embeddings[0,39,1]) + raw_img,cmap="gray")
plt.imshow(patch_to_original_size(overlay) + 12*d3_raw_img,cmap="gray", vmin=d3_raw_img.min()*12, vmax=d3_raw_img.max()*12+1)
#plt.imshow(patch_to_original_size(overlay) + 4*raw_img,cmap="gray")

In [None]:
plt.imshow(d3_raw_img[0*16:2*16,10*16:12*16], cmap="gray", vmin=d3_raw_img.min(), vmax=d3_raw_img.max())

In [None]:
#num_heads * emb_dim
12*(768/12)
# (507, 768) * (768)

In [None]:
embeddings = model.model.unet.transformer.embeddings(raw)

In [None]:
embeddings.shape

In [None]:
first_layer_out = layer.attn(layer.attention_norm(embeddings), return_raw_scores=True)


In [None]:
attn_scores = first_layer_out[2]

In [None]:
attn_scores = to_np(attn_scores)

In [None]:
attn_scores.shape

In [None]:
print(attn_scores.min(), attn_scores.max(), attn_scores.std())

In [None]:
first_map_attn = attn_scores[0,s:s+c,0]

In [None]:
torch.unflatten(torch.tensor(first_map_attn),dim=1,sizes=(3,13,13)).numpy().shape

In [None]:
def attn_for_patch_i(attn, x, y, z, head_start=0, head_count=None, grid=(3, 13, 13)):
  """Attention scores of one query patch, folded back onto the patch grid.

  Args:
    attn: Array indexed as (batch, head, query_token, key_token), where the
      token axes enumerate the patch grid in z-major, then y, then x order.
      Only batch item 0 is read.
    x, y, z: Grid coordinates of the query patch.
    head_start: First index along axis 1 to include. Defaults to 0.
    head_count: Number of indices along axis 1 to include; None means
      everything from `head_start` onwards.
    grid: Patch-grid shape as (depth, height, width). Defaults to (3, 13, 13).

  Returns:
    A new array of shape (heads, height, width, depth): for each selected
    head, the score from the query patch to every key patch.

  Raises:
    IndexError: If (x, y, z) lies outside `grid`.
  """
  depth, height, width = grid
  # Reject out-of-range coordinates explicitly: e.g. x == width would otherwise
  # silently address the first patch of the next row.
  if not (0 <= x < width and 0 <= y < height and 0 <= z < depth):
    raise IndexError(f"patch ({x}, {y}, {z}) is outside grid {grid}")
  i = z*height*width + y*width + x
  stop = None if head_count is None else head_start + head_count
  # Read from the `attn` argument (not a notebook global) and copy, so callers
  # can accumulate into the result in place without modifying `attn`.
  p_attn = np.array(attn[0, head_start:stop, i])
  folded = p_attn.reshape(p_attn.shape[0], depth, height, width)
  # (head, z, y, x) -> (head, y, x, z)
  return folded.transpose(0, 2, 3, 1)

In [None]:
# The helper takes grid coordinates (x, y, z), not a flat token index.
# Flat index 52 = z*13*13 + y*13 + x  ->  (x, y, z) = (0, 4, 0).
attn_for_patch_i(attn_scores, 0, 4, 0).shape

In [None]:
import numpy as np

In [None]:
# Query patch at the centre of the 13x13 grid, in the last of the three z layers.
x, y, z = 6, 6, 2
patch_attn = attn_for_patch_i(attn_scores, x, y, z)
# The helper returns (head, y, x, z): show the first selected head's scores
# towards key patches at z index 1.
plt.imshow(patch_attn[0, :, :, 1], cmap="hot", vmin=-2, vmax=2)

In [None]:
xs = 4
xe = 8 
ys = 8 
ye = 12
tot_attn = None
ct = 0
for x in range(xs,xe+1):
  for y in range(ys,ye+1):
    ct += 1
    if tot_attn is None:
      tot_attn = attn_for_patch_i(attn_scores, x,y,1)
    else:
      tot_attn += attn_for_patch_i(attn_scores, x,y,1)
tot_attn /= ct

In [None]:
plt.imshow(d3_raw_img[ys*16:ye*16,xs*16:xe*16], cmap="gray", vmin=d3_raw_img.min(), vmax=d3_raw_img.max())

In [None]:
# tot std: 0.08
# std only white patch: 0.28
# std white/black patch 0.13
print(tot_attn.min(), tot_attn.max(), tot_attn.std())

In [None]:
plt.imshow(tot_attn[0,:,:,1], cmap="hot", vmin=-2,vmax=2)