In [1]:
################## 1. Download checkpoints and build models
import os
import os.path as osp
import torch, torchvision
import random
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
from varoptimizer.models import VQVAE, build_vae_var

MODEL_DEPTH = 16    # TODO: =====> please specify MODEL_DEPTH <=====
assert MODEL_DEPTH in {16, 20, 24, 30}


# download checkpoint
hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'
vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}.pth'
if not osp.exists(vae_ckpt): os.system(f'wget {hf_home}/{vae_ckpt}')
if not osp.exists(var_ckpt): os.system(f'wget {hf_home}/{var_ckpt}')

# build vae, var
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if 'vae' not in globals() or 'var' not in globals():
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters
        device=device, patch_nums=patch_nums,
        num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
    )

# load checkpoints
vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)
vae.eval(), var.eval()
for p in vae.parameters(): p.requires_grad_(False)
for p in var.parameters(): p.requires_grad_(False)
print(f'prepare finished.')

  from .autonotebook import tqdm as notebook_tqdm
--2024-07-17 17:44:28--  https://huggingface.co/FoundationVision/var/resolve/main/vae_ch160v4096z32.pth
Resolving huggingface.co (huggingface.co)... 3.160.150.119, 3.160.150.7, 3.160.150.2, ...
Connecting to huggingface.co (huggingface.co)|3.160.150.119|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.huggingface.co/repos/27/33/2733ebd65833f8330005ce942c9195c84a3d385ac604d32ebe5d6d9a79385456/7c3ec27ae28a3f87055e83211ea8cc8558bd1985d7b51742d074fb4c2fcf186c?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27vae_ch160v4096z32.pth%3B+filename%3D%22vae_ch160v4096z32.pth%22%3B&Expires=1721490268&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTcyMTQ5MDI2OH19LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmh1Z2dpbmdmYWNlLmNvL3JlcG9zLzI3LzMzLzI3MzNlYmQ2NTgzM2Y4MzMwMDA1Y2U5NDJjOTE5NWM4NGEzZDM4NWFjNjA0ZDMyZWJlNWQ2ZDlhNzkzODU0NTYvN2MzZWMyN2FlMjhhM2Y4


[constructor]  ==== flash_if_available=True (0/16), fused_if_available=True (fusing_add_ln=0/16, fusing_mlp=0/16) ==== 
    [VAR config ] embed_dim=1024, num_heads=16, depth=16, mlp_ratio=4.0
    [drop ratios ] drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0666667 (tensor([0.0000, 0.0044, 0.0089, 0.0133, 0.0178, 0.0222, 0.0267, 0.0311, 0.0356,
        0.0400, 0.0444, 0.0489, 0.0533, 0.0578, 0.0622, 0.0667]))

[init_weights] VAR with init_std=0.0180422
prepare finished.


In [2]:
############################# 2. Sample with classifier-free guidance

# set args
seed = 412 #2137 #@param {type:"number"}
torch.manual_seed(seed)

cfg = 4 #@param {type:"slider", min:1, max:10, step:0.1}
class_labels = (980, 980, 437, 437, 22, 22, 562, 562)  #@param {type:"raw"}
more_smooth = False # True for more smooth output

# seed
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# run faster
tf32 = True
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
torch.set_float32_matmul_precision('high' if tf32 else 'highest')

# sample
B = len(class_labels)
label_B: torch.LongTensor = torch.tensor(class_labels, device=device)
with torch.inference_mode():
    with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True):    # using bfloat16 can be faster
        # recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=1, top_p=0.15, g_seed=seed, more_smooth=more_smooth)
        recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.95, g_seed=seed, more_smooth=more_smooth)

level 0 x: x.shape=torch.Size([16, 1, 1024])

level 1 x: x.shape=torch.Size([16, 4, 1024])

level 2 x: x.shape=torch.Size([16, 9, 1024])

level 3 x: x.shape=torch.Size([16, 16, 1024])

level 4 x: x.shape=torch.Size([16, 25, 1024])

level 5 x: x.shape=torch.Size([16, 36, 1024])

level 6 x: x.shape=torch.Size([16, 64, 1024])

level 7 x: x.shape=torch.Size([16, 100, 1024])

level 8 x: x.shape=torch.Size([16, 169, 1024])

level 9 x: x.shape=torch.Size([16, 256, 1024])



In [3]:

chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.show()

In [4]:
# x = recon_B3HW.type(torch.float32).to(device)#.add_(-0.5).mul_(2)
# f = vae.quant_conv(vae.encoder(x))
# f_hat = torch.zeros_like(f)

# for next_level in range(len(patch_nums)):
#     current_level = next_level - 1

#     print(f"{next_level=}, patch_nums[next_level]={patch_nums[next_level]}")

#     h = var.autoregressive_single_step_prediction(
#         current_level=current_level, 
#         f_hat=f_hat, 
#         label_B=label_B,
#     ) 

#     f_hat = f_hat + h


# image = vae.fhat_to_img(f_hat.detach()).clamp_(0, 1)

# chw = torchvision.utils.make_grid(image, nrow=8, padding=0, pad_value=1.0)
# chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
# chw = PImage.fromarray(chw.astype(np.uint8))
# chw.show()

# Manual sample

In [5]:
def visualise(f_hat):
    image = vae.fhat_to_img(f_hat.detach()).add_(1).mul_(0.5)

    chw = torchvision.utils.make_grid(image, nrow=8, padding=0, pad_value=1.0)
    chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
    chw = PImage.fromarray(chw.astype(np.uint8))
    chw.show()

In [6]:
x = recon_B3HW.type(torch.float32).to(device)#.add_(-0.5).mul_(2)
f = vae.quant_conv(vae.encoder(x))
f_hat = torch.zeros_like(f)

class_labels = (980, 980, 437, 437, 22, 22, 562, 562) 
cfg = 4
images = []
token_maps = []

for b in var.blocks: b.attn.kv_caching(False)

f_hat = torch.zeros_like(f)
batch_size = len(class_labels)

label_B = torch.tensor(class_labels, device=device)
step_token_map = var.get_initial_token_map(label_B)

full_token_map = step_token_map.clone()


class_conditioning = var.get_class_conditioning(label_B)

for map_size_index, patch_num, in enumerate(patch_nums):
    # print(f"{map_size_index=}, {patch_num=}")

    if map_size_index != 0:
        step_token_map = var.prepare_token_map(f_hat, map_size_index=map_size_index)
        full_token_map = torch.cat([full_token_map, step_token_map], dim=1)

    level_ratio = map_size_index / (len(patch_nums) - 1)

    logits = var.token_map_to_relevant_logits(full_token_map.clone(), class_conditioning, map_size_index=map_size_index, masked=True)

    indices = var.logits_to_indices(logits, cfg=cfg, batch_size=batch_size, level_ratio=level_ratio)
    h = var.indices_to_h(indices, batch_size=batch_size, level_ratio=level_ratio, patch_num=patch_num)

    f_hat = f_hat + h

    images.append(vae.fhat_to_img(f_hat.detach()).add_(1).mul_(0.5))

for b in var.blocks: b.attn.kv_caching(False)

# visualise(f_hat)

kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"


# Now let's make it a single_step inference

We need to make the whole full_token_map at a single step

In [7]:
x = recon_B3HW.type(torch.float32).to(device)#.add_(-0.5).mul_(2)
f = vae.quant_conv(vae.encoder(x))
f_hat = torch.zeros_like(f)

class_labels = (980, 980, 437, 437, 22, 22, 562, 562) 
cfg = 5
top_k=50
top_p=0.4

for b in var.blocks: b.attn.kv_caching(False)

f_hat = torch.zeros_like(f)
batch_size = len(class_labels)

label_B = torch.tensor(class_labels, device=device)
class_conditioning = var.get_class_conditioning(label_B)


def get_full_token_map(map_size_index, label_B, f_hat, last_grad_only=True):
    full_token_map = var.get_initial_token_map(label_B)
    
    if map_size_index == 0:
        return full_token_map
    
    with torch.no_grad() if last_grad_only else torch.enable_grad():
        for i in range(1, map_size_index): # for each level up to the second to last level
            step_token_map = var.prepare_token_map(f_hat, map_size_index=i)
            full_token_map = torch.cat([full_token_map, step_token_map], dim=1)

    # for the last level we need grad anyway
    last_step_token_map = var.prepare_token_map(f_hat, map_size_index=map_size_index) 
    full_token_map = torch.cat([full_token_map, last_step_token_map], dim=1)

    return full_token_map


for map_size_index, patch_num, in enumerate(patch_nums):
    full_token_map = get_full_token_map(map_size_index, label_B, f_hat, last_grad_only=True)

    level_ratio = map_size_index / (len(patch_nums) - 1)

    logits = var.token_map_to_relevant_logits(full_token_map.clone(), class_conditioning, map_size_index=map_size_index, masked=True)

    indices = var.logits_to_indices(logits, cfg=cfg, batch_size=batch_size, level_ratio=level_ratio, top_k=top_k, top_p=top_p)
    h = var.indices_to_h(indices, batch_size=batch_size, level_ratio=level_ratio, patch_num=patch_num)

    f_hat = f_hat + h


for b in var.blocks: b.attn.kv_caching(False)

visualise(f_hat)

In [8]:
x = recon_B3HW.type(torch.float32).to(device)#.add_(-0.5).mul_(2)
f = vae.quant_conv(vae.encoder(x))
f_hat = torch.zeros_like(f)

class_labels = (980, 980, 437, 437, 22, 22, 562, 562) 
cfg = 4

f_hat = torch.zeros_like(f)
label_B = torch.tensor(class_labels, device=device)

for map_size_index, patch_num, in enumerate(patch_nums):
    f_hat += var.predict_single_step_residual(f_hat, label_B, map_size_index=map_size_index, cfg=cfg)    

# visualise(f_hat)

kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found


# And now the optimization

In [9]:
map_size_index = -1 # last patch num that will be included

quant_pyramid, f_hat = vae.quantize.f_to_quant_pyramid_and_f_hat(f, patch_nums, map_size_index)

token_map = vae.quantize.limited_quant_pyramid_to_var_input(quant_pyramid) # omits the class conditioning
# len(quant_pyramid)

# vae.quantize.limited_quant_pyramid_to_var_input
# var.predict_single_step_from_quant_pyramid(quant_pyramid=quant_pyramid, label_B=label_B, cfg=cfg).shape

In [10]:
var.quant_pyramid_to_var_inputs(quant_pyramid, label_B, batch_size=8).shape

torch.Size([16, 1, 1024])

In [11]:
# var.predict_single_step_from_quant_pyramid(quant_pyramid=quant_pyramid, label_B=label_B, cfg=cfg).shape
var.predict_single_step_from_quant_pyramid(quant_pyramid=[], label_B=label_B, cfg=cfg, batch_size=8).shape

torch.Size([8, 32, 16, 16])

In [12]:
for map_size_index in range(len(patch_nums)):
    current_length = sum([
        patch_size * patch_size
        for patch_size
        in patch_nums[:map_size_index + 1]
    ])  

    patch_size = patch_nums[map_size_index]
    print(f"{map_size_index=}, {patch_size}x{patch_size}, {current_length=}")
    


map_size_index=0, 1x1, current_length=1
map_size_index=1, 2x2, current_length=5
map_size_index=2, 3x3, current_length=14
map_size_index=3, 4x4, current_length=30
map_size_index=4, 5x5, current_length=55
map_size_index=5, 6x6, current_length=91
map_size_index=6, 8x8, current_length=155
map_size_index=7, 10x10, current_length=255
map_size_index=8, 13x13, current_length=424
map_size_index=9, 16x16, current_length=680


In [13]:
from torch.nn import functional as F

def downsample_residual(residual, patch_size):
    if patch_size == patch_nums[-1]:
        return residual

    return F.interpolate(residual, size=(patch_size, patch_size), mode='area')


def get_optimization_loss(f, predicted_map_size_index, batch_size=8, cfg=4, top_k=900, top_p=0.95, scale_by_inv_area=False, seed=None):
    f_hat_map_size_index = predicted_map_size_index - 1
    predicted_map_size = patch_nums[predicted_map_size_index]

    rng = None

    if seed is not None:
        rng = torch.Generator(device='cuda')
        rng.manual_seed(seed)

    with torch.no_grad():
        quant_pyramid, f_hat = vae.quantize.f_to_quant_pyramid_and_f_hat(f, patch_nums, f_hat_map_size_index) 
        predicted_residual = var.predict_single_step_from_quant_pyramid(quant_pyramid=quant_pyramid, label_B=label_B, cfg=cfg, batch_size=batch_size, top_k=top_k, top_p=top_p, rng=rng)

    real_residual = f - f_hat

    loss = F.mse_loss(
        downsample_residual(predicted_residual, predicted_map_size), 
        downsample_residual(real_residual, predicted_map_size)
    )

    if scale_by_inv_area:
        loss /= predicted_map_size * predicted_map_size


    return loss



In [14]:
from tqdm import tqdm
# f = torch.randn([4, 32, 16, 16], requires_grad=True, device=device) 
# class_labels = (980, 437, 22, 562)  #@param {type:"raw"}


img = recon_B3HW.float().clone().clamp_(0, 1).to(device)
f = torch.tensor(vae.quant_conv(vae.encoder(img).detach()), requires_grad=True, device=device)
f = torch.zeros_like(f, requires_grad=True, device=device)
class_labels = (980, 980, 437, 437, 22, 22, 562, 562)  #@param {type:"raw"}

cfg = 4
top_k = 500 #700
top_p = 0.25 #0.95 #0.95
lr = 0.2
scale_by_inv_area = False


label_B: torch.LongTensor = torch.tensor(class_labels, device=device)

optimizer = torch.optim.Adam([f], lr=lr)
total_steps = 500

min_next_map_size_index = 0 #len(patch_nums) - 2 #-1 #len(patch_nums) - 4
max_next_map_size_index = len(patch_nums) - 1 


for i in tqdm(range(total_steps), total=total_steps):
    next_map_size_index = np.random.randint(min_next_map_size_index, max_next_map_size_index + 1)

    loss = get_optimization_loss(f, next_map_size_index, batch_size=8, cfg=cfg, top_k=top_k, top_p=top_p, scale_by_inv_area=scale_by_inv_area)

    
    # patch_weight_bias = .001
    # patch_grad_scale = (patch_nums[len(patch_nums)-1]**2 + patch_weight_bias)  / (patch_nums[next_map_size_index]**2 + patch_weight_bias)
    # loss*= np.sqrt(patch_grad_scale)
    
    # step_weight = (total_steps - i) / total_steps
    # loss*= step_weight

    # level_ratio = (next_map_size_index + 1)/ (len(patch_nums)) 
    # level_weight = np.power(level_ratio, 1/4)
    # loss*= level_weight

    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

f = f.detach()
visualise(f)

  f = torch.tensor(vae.quant_conv(vae.encoder(img).detach()), requires_grad=True, device=device)
  0%|          | 0/500 [00:00<?, ?it/s]kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"
100%|██████████| 500/500 [00:18<00:00, 27.01it/s]


In [15]:

from tqdm import tqdm

# f = torch.zeros_like(f, requires_grad=True, device=device)


class_labels = (980, 980, 437, 437, 22, 22, 562, 562)  #@param {type:"raw"}

initial_noise = 0.1
cfg = 7
top_k = 600 #700
top_p = 0.65 #0.95 #0.95
lr = 0.02
steps_per_size = 250
seed = 42

scale_by_inv_area = False

label_B: torch.LongTensor = torch.tensor(class_labels, device=device)


f_initial = torch.randn_like(f, device=device) * initial_noise # std=0.5
f = f_initial.clone().detach().requires_grad_(True)

optimizer = torch.optim.Adam([f], lr=lr)

for max_next_map_size_index in range(len(patch_nums)):

    for i in tqdm(range(steps_per_size), total=steps_per_size):
        
        optimizer.zero_grad()

        # max_next_map_size_index = int(np.sqrt(i/total_steps) * len(patch_nums)) 
        
        loss = 0
        loss += get_optimization_loss(f, max_next_map_size_index, batch_size=8, cfg=cfg, top_k=top_k, top_p=top_p, scale_by_inv_area=scale_by_inv_area, seed=seed)

        random_size_index = np.random.randint(0, max_next_map_size_index + 1)
        loss += get_optimization_loss(f, random_size_index, batch_size=8, cfg=cfg, top_k=top_k, top_p=top_p, scale_by_inv_area=scale_by_inv_area, seed=seed)    

        # if i % 100 == 0:
        #     print(f'loss: {loss.item()}')

        step_weight = np.power((total_steps - i + 1) / total_steps, 1/4)

        loss*= step_weight
        

        #loss*= step_weight
        loss.backward()
        optimizer.step()


visualise(f)

  0%|          | 0/250 [00:00<?, ?it/s]

 10%|█         | 26/250 [00:00<00:01, 126.07it/s]kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
 26%|██▌       | 65/250 [00:00<00:01, 126.77it/s]kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"
100%|██████████| 250/250 [00:01<00:00, 125.77it/s]
100%|██████████| 250/250 [00:02<00:00, 114.45it/s]
100%|██████████| 250/250 [00:02<00:00, 97.46it/s]
100%|██████████| 250/250 [00:03<00:00, 78.31it/s]
100%|██████████| 250/250 [00:04<00:00, 59.03it/s]
100%|██████████| 250/250 [00:06<00:00, 39.57it/s]
100%|██████████| 250/250 [00:09<00:00, 25.81it/s]
100%|██████████| 250/250 [00:13<00:00, 19.19it/s]
100%|██████████| 250/250 [00:22<00:00, 11.36it/s]
100%|██████████| 250/250 [00:38<00:00,  6.55it/s]


In [16]:
f.std()

kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"


tensor(0.3930, device='cuda:0', grad_fn=<StdBackward0>)

In [17]:


assert False

AssertionError: 

In [None]:
torch.cat(token_maps, dim=1).shape

torch.Size([16, 91, 1024])

In [None]:
for b in var.blocks: b.attn.kv_caching(False)

self=var
map_size_index = map_size_id

current_length = sum([
    patch_size * patch_size
    for patch_size
    in self.patch_nums[:map_size_index + 1]
])

full_token_map = torch.cat(token_maps, dim=1).clone()

output_length = self.patch_nums[map_size_index] ** 2
attn_bias = self.attn_bias_for_masking[:, :, :current_length, :current_length]

kv_pairs = []

for b in self.blocks:
    token_map = b(x=full_token_map, cond_BD=class_conditioning, attn_bias=attn_bias)#, attn_bias=attn_bias)
    kv_pairs.append({
        'key': b.attn.cached_k,
        'value': b.attn.cached_v
    })

self.last_kv_pairs = kv_pairs

output_length = self.patch_nums[map_size_index] ** 2

relevant_output_tokens = token_map[:, -output_length:]

logits = self.get_logits(relevant_output_tokens, class_conditioning)








In [None]:
logits.shape

torch.Size([16, 36, 4096])

In [None]:
patch_nums[5]

6

In [None]:
image = vae.fhat_to_img(f_hat.detach()).add_(1).mul_(0.5)

chw = torchvision.utils.make_grid(image, nrow=8, padding=0, pad_value=1.0)
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.show()

In [None]:
f_hat.min(), f_hat.max()

(tensor(-3.0807, device='cuda:0'), tensor(2.9216, device='cuda:0'))

In [None]:
images_stack = torch.cat(images, dim=0)


chw = torchvision.utils.make_grid(images_stack, nrow=8, padding=0, pad_value=1.0)
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.show()

kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"


In [None]:
token_maps[3].shape

torch.Size([16, 16, 1024])

In [None]:
var.attn_bias_for_masking[0][0][:10,:10]

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0')

In [None]:
def get_attn_bias_for_masking(var, map_size_index):
    input_length = var.patch_nums[map_size_index] ** 2
    attn_bias = var.attn_bias_for_masking[:, :, :input_length, :input_length]
    return attn_bias


In [None]:
get_attn_bias_for_masking(var, 1)[0][0][:10,:10]

tensor([[0., -inf, -inf, -inf],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], device='cuda:0')

In [None]:
token_maps[-1].shape

torch.Size([16, 256, 1024])

In [None]:
patch_nums

(1, 2, 3, 4, 5, 6, 8, 10, 13, 16)

In [None]:
kv_pairs[2][0]['key'].shape

# 16 x 16 x length x 64

torch.Size([16, 16, 14, 64])

In [None]:

for map_size_index, patch_num in enumerate(patch_nums):
    current_length = sum([
                patch_size * patch_size
                for patch_size
                in patch_nums[:map_size_index + 1]
            ])
    print(f"map_size_index={map_size_index}, current_length={current_length}")

map_size_index=0, current_length=1
map_size_index=1, current_length=5
map_size_index=2, current_length=14
map_size_index=3, current_length=30
map_size_index=4, current_length=55
map_size_index=5, current_length=91
map_size_index=6, current_length=155
map_size_index=7, current_length=255
map_size_index=8, current_length=424
map_size_index=9, current_length=680


kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"
kf.coreaddons: Expected a KPluginFactory, got a KIOPluginForMetaData
