In [1]:
################## 1. Download checkpoints and build models
import os
import os.path as osp
import torch, torchvision
import random
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
from models import VQVAE, build_vae_var

MODEL_DEPTH = 16    # TODO: =====> please specify MODEL_DEPTH <=====
assert MODEL_DEPTH in {16, 20, 24, 30}


# download checkpoint
# Each checkpoint is fetched only when it is not already present in the working directory.
# torch.hub.download_url_to_file replaces the previous `os.system('wget ...')` call:
#   * a failed download raises here instead of being ignored and surfacing later in torch.load;
#   * no external `wget` binary is required.
hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'
vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}.pth'
for ckpt_name in (vae_ckpt, var_ckpt):
    if not osp.exists(ckpt_name):
        torch.hub.download_url_to_file(f'{hf_home}/{ckpt_name}', ckpt_name)

# build vae, var
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Models left in the kernel by an earlier run of this cell are reused, but only when they were
# built for the current MODEL_DEPTH. Previously a changed MODEL_DEPTH kept the old `var`, and the
# strict load_state_dict below then failed on the mismatching checkpoint.
if (
    'vae' not in globals()
    or 'var' not in globals()
    or globals().get('_built_model_depth') != MODEL_DEPTH
):
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters
        device=device, patch_nums=patch_nums,
        num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
    )
    _built_model_depth = MODEL_DEPTH    # remembered so that a later depth change triggers a rebuild

# load checkpoints
# Weights are read onto the CPU first; load_state_dict copies them into the already-placed modules.
vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)

# inference only: eval mode, and no parameter tracks gradients
for net in (vae, var):
    net.eval()
    for p in net.parameters():
        p.requires_grad_(False)
print('prepare finished.')

  from .autonotebook import tqdm as notebook_tqdm



[constructor]  ==== flash_if_available=True (0/16), fused_if_available=True (fusing_add_ln=0/16, fusing_mlp=0/16) ==== 
    [VAR config ] embed_dim=1024, num_heads=16, depth=16, mlp_ratio=4.0
    [drop ratios ] drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0666667 (tensor([0.0000, 0.0044, 0.0089, 0.0133, 0.0178, 0.0222, 0.0267, 0.0311, 0.0356,
        0.0400, 0.0444, 0.0489, 0.0533, 0.0578, 0.0622, 0.0667]))

[init_weights] VAR with init_std=0.0180422
prepare finished.


In [2]:
############################# 2. Sample with classifier-free guidance

# set args
seed = 1 #2137 #@param {type:"number"}

cfg = 4 #@param {type:"slider", min:1, max:10, step:0.1}
class_labels = (980, 980, 437, 437, 22, 22, 562, 562)  #@param {type:"raw"}
more_smooth = False # True for more smooth output

# seed
# All RNGs are seeded in this one place (torch was previously seeded twice with the same value).
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# run faster
tf32 = True
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
torch.set_float32_matmul_precision('high' if tf32 else 'highest')

# sample
# One image is generated per entry of `class_labels`; the result is kept in `recon_B3HW`
# for the cells below.
B = len(class_labels)
label_B: torch.LongTensor = torch.tensor(class_labels, device=device)
with torch.inference_mode():
    with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True):    # using bfloat16 can be faster
        recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.95, g_seed=seed, more_smooth=more_smooth)

level 0 x: x.shape=torch.Size([16, 1, 1024])

level 1 x: x.shape=torch.Size([16, 4, 1024])

level 2 x: x.shape=torch.Size([16, 9, 1024])

level 3 x: x.shape=torch.Size([16, 16, 1024])

level 4 x: x.shape=torch.Size([16, 25, 1024])

level 5 x: x.shape=torch.Size([16, 36, 1024])

level 6 x: x.shape=torch.Size([16, 64, 1024])

level 7 x: x.shape=torch.Size([16, 100, 1024])

level 8 x: x.shape=torch.Size([16, 169, 1024])

level 9 x: x.shape=torch.Size([16, 256, 1024])



  return F.conv2d(input, weight, bias, self.stride,


In [3]:

# Tile the sampled batch into a grid (8 images per row, no padding), scale it by 255,
# convert it to a uint8 height x width x channel array and open it with PIL's viewer.
chw = PImage.fromarray(
    torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)
    .permute(1, 2, 0)   # channel-first -> channel-last, as PIL expects
    .mul_(255)
    .cpu()
    .numpy()
    .astype(np.uint8)
)
chw.show()

In [4]:
# Encode the sampled images; in this cell `f` is only used as the template (shape/dtype/device)
# for the all-zero starting feature map.
# (a rescaling `.add_(-0.5).mul_(2)` of the input was tried here and is currently disabled)
x = recon_B3HW.float().to(device)
f = vae.quant_conv(vae.encoder(x))
f_hat = torch.zeros_like(f)

# Walk over every scale: the model is asked for the contribution of `next_level`
# (passing the index of the level before it, -1 on the first step) and that contribution
# is accumulated into f_hat.
for next_level, patch_num in enumerate(patch_nums):
    current_level = next_level - 1

    print(f"{next_level=}, patch_nums[next_level]={patch_num}")

    h = var.autoregressive_single_step_prediction(
        current_level=current_level,
        f_hat=f_hat,
        label_B=label_B,
    )
    f_hat = f_hat + h

# NOTE(review): the other cells of this notebook post-process fhat_to_img with `.add_(1).mul_(0.5)`,
# whereas this one clamps to [0, 1] - confirm which range is intended here.
image = vae.fhat_to_img(f_hat.detach()).clamp_(0, 1)

# `image` is not rendered in this cell.

next_level=0, patch_nums[next_level]=1
next_level=1, patch_nums[next_level]=2
current_level=0 cur_L: cur_L=0
next_level=2, patch_nums[next_level]=3
current_level=1 cur_L: cur_L=1
next_level=3, patch_nums[next_level]=4
current_level=2 cur_L: cur_L=5
next_level=4, patch_nums[next_level]=5
current_level=3 cur_L: cur_L=14
next_level=5, patch_nums[next_level]=6
current_level=4 cur_L: cur_L=30
next_level=6, patch_nums[next_level]=8
current_level=5 cur_L: cur_L=55
next_level=7, patch_nums[next_level]=10
current_level=6 cur_L: cur_L=91
next_level=8, patch_nums[next_level]=13
current_level=7 cur_L: cur_L=155
next_level=9, patch_nums[next_level]=16
current_level=8 cur_L: cur_L=255


# Manual sampling: generating scale by scale

In [5]:
def visualise(f_hat):
    """Decode a feature map with the global `vae` and open the result as one image grid.

    The decoded batch is rescaled in place with ``(x + 1) * 0.5`` (presumably mapping a
    [-1, 1] decoder output to [0, 1] - confirm against ``vae.fhat_to_img``), tiled 8 images
    per row without padding, converted to uint8 and shown with PIL's viewer.

    Args:
        f_hat: feature map accepted by ``vae.fhat_to_img``; it is detached before decoding.

    Returns:
        None. The image is only displayed.
    """
    decoded = vae.fhat_to_img(f_hat.detach())
    decoded.add_(1).mul_(0.5)

    grid = torchvision.utils.make_grid(decoded, nrow=8, padding=0, pad_value=1.0)
    # channel-first -> channel-last, then 0..255 bytes for PIL
    pixels = grid.permute(1, 2, 0).mul_(255).cpu().numpy().astype(np.uint8)
    PImage.fromarray(pixels).show()

In [19]:
# ---- configuration ----
class_labels = (980, 980, 437, 437, 22, 22, 562, 562)
cfg = 4
batch_size = len(class_labels)
label_B = torch.tensor(class_labels, device=device)

# ---- per-scale records: the loop appends one entry to each of the first three ----
images, token_maps, kv_pairs = [], [], []
attention_maps = []  # created here but never filled in this cell

# ---- starting state ----
# `f` comes from an earlier cell and only serves as the template for the all-zero feature map.
f_hat = torch.zeros_like(f)
token_map = var.get_initial_token_map(label_B)
class_conditioning = var.get_class_conditioning(label_B)

# kv_caching(False) is applied to every block before the loop and once more after it.
for block in var.blocks:
    block.attn.kv_caching(False)

for map_size_index, patch_num in enumerate(patch_nums):
    # The first scale uses the initial token map; every later scale derives its input from f_hat.
    if map_size_index > 0:
        token_map = var.prepare_token_map(f_hat, map_size_index=map_size_index)
    token_maps.append(token_map)

    # Progress through the scales: 0.0 on the first one, 1.0 on the last one.
    level_ratio = map_size_index / (len(patch_nums) - 1)
    print(f"{map_size_index=}, {level_ratio=}")

    logits = var.token_map_to_logits(token_map, class_conditioning, map_size_index=map_size_index)
    # Recorded immediately after the call above, before anything else touches `var`.
    kv_pairs.append(var.last_kv_pairs)
    indices = var.logits_to_indices(logits, cfg=cfg, batch_size=batch_size, level_ratio=level_ratio) # TODO: remove rng
    h = var.indices_to_h(indices, batch_size=batch_size, level_ratio=level_ratio, patch_num=patch_num)

    # Accumulate this scale's contribution and keep a decoded snapshot of the running result.
    f_hat = f_hat + h
    decoded = vae.fhat_to_img(f_hat.detach())
    images.append(decoded.add_(1).mul_(0.5))

for block in var.blocks:
    block.attn.kv_caching(False)

visualise(f_hat)

map_size_index=0, level_ratio=0.0
map_size_index=1, level_ratio=0.1111111111111111
map_size_index=2, level_ratio=0.2222222222222222
map_size_index=3, level_ratio=0.3333333333333333
map_size_index=4, level_ratio=0.4444444444444444
map_size_index=5, level_ratio=0.5555555555555556
map_size_index=6, level_ratio=0.6666666666666666
map_size_index=7, level_ratio=0.7777777777777778
map_size_index=8, level_ratio=0.8888888888888888
map_size_index=9, level_ratio=1.0


kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"
kf.coreaddons: Expected a KPluginFactory, got a KIOPluginForMetaData


In [7]:
# Render the final feature map. This cell used to repeat the body of `visualise` line for line;
# it now calls the helper so the rendering logic lives in one place.
visualise(f_hat)

In [8]:
# Inspect the value range of the accumulated feature map: its smallest and largest element.
f_hat.min(), f_hat.max()

(tensor(-3.0807, device='cuda:0'), tensor(2.9216, device='cuda:0'))

In [9]:
# Concatenate the snapshots collected after every scale along the batch axis and show them
# as one grid with 8 images per row (one grid row per scale if each snapshot holds 8 images -
# TODO confirm the snapshot batch size).
images_stack = torch.cat(images, dim=0)

chw = PImage.fromarray(
    torchvision.utils.make_grid(images_stack, nrow=8, padding=0, pad_value=1.0)
    .permute(1, 2, 0)   # channel-first -> channel-last, as PIL expects
    .mul_(255)
    .cpu()
    .numpy()
    .astype(np.uint8)
)
chw.show()

kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"


In [10]:
# Shape of the token map recorded for scale index 3 (patch_nums[3] = 4; the recorded output
# has 16 = 4 * 4 on its middle axis).
token_maps[3].shape

torch.Size([16, 16, 1024])

In [11]:
# Peek at the top-left 10 x 10 corner of the [0][0] slice of the model's attention bias.
var.attn_bias_for_masking[0][0][:10,:10]

tensor([[0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0.]], device='cuda:0')

In [12]:
def get_attn_bias_for_masking(var, map_size_index):
    """Return the top-left square of ``var.attn_bias_for_masking`` for one scale.

    The square has side ``var.patch_nums[map_size_index] ** 2`` and is cut from the last
    two of the four indexed axes; the first two axes are kept whole.

    NOTE(review): the crop always starts at position 0, so for ``map_size_index > 0`` it
    covers the first tokens of the sequence rather than the tokens belonging to that scale
    (the cumulative lengths printed elsewhere in this notebook are 1, 5, 14, ...).
    Confirm this is the intended region.

    Args:
        var: object exposing ``patch_nums`` and a 4-axis ``attn_bias_for_masking``.
        map_size_index: index into ``var.patch_nums``.

    Returns:
        The sliced attention bias (a view produced by basic slicing).
    """
    side = var.patch_nums[map_size_index]
    crop = slice(None, side ** 2)
    return var.attn_bias_for_masking[:, :, crop, crop]


In [13]:
# Try the helper on scale index 1 (patch_nums[1] = 2, so the result is 4 x 4 and the
# trailing [:10,:10] slice leaves it unchanged).
get_attn_bias_for_masking(var, 1)[0][0][:10,:10]

tensor([[0., -inf, -inf, -inf],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], device='cuda:0')

In [14]:
# Shape of the token map recorded for the last scale (the recorded output has
# 256 = 16 * 16 on its middle axis, matching patch_nums[-1] = 16).
token_maps[-1].shape

torch.Size([16, 256, 1024])

In [15]:
# Show the per-scale patch numbers; elsewhere in this notebook each entry is squared to get
# the number of tokens of that scale.
patch_nums

(1, 2, 3, 4, 5, 6, 8, 10, 13, 16)

In [16]:
# Shape of the 'key' tensor in the first entry of the key/value record saved after scale index 2.
kv_pairs[2][0]['key'].shape

# 16 x 16 x length x 64  (recorded output: length = 14 = 1 + 4 + 9, the squared patch_nums of scales 0..2)

torch.Size([16, 16, 14, 64])

In [17]:

# Print, for every scale, how many tokens have accumulated up to and including that scale
# (running sum of patch_num squared).
current_length = 0
for map_size_index, patch_num in enumerate(patch_nums):
    current_length += patch_num * patch_num
    print(f"map_size_index={map_size_index}, current_length={current_length}")

map_size_index=0, current_length=1
map_size_index=1, current_length=5
map_size_index=2, current_length=14
map_size_index=3, current_length=30
map_size_index=4, current_length=55
map_size_index=5, current_length=91
map_size_index=6, current_length=155
map_size_index=7, current_length=255
map_size_index=8, current_length=424
map_size_index=9, current_length=680


kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"
kf.coreaddons: Expected a KPluginFactory, got a KIOPluginForMetaData
