### 🚀 For an interactive experience, head over to our [demo platform](https://var.vision/demo) and dive right in! 🌟

In [1]:
################## 1. Download checkpoints and build models
import os
import os.path as osp
import torch, torchvision
import random
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
from models import VQVAE, build_vae_var

MODEL_DEPTH = 16    # TODO: =====> please specify MODEL_DEPTH <=====
assert MODEL_DEPTH in {16, 20, 24, 30}


# download checkpoint
hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'
vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}.pth'
if not osp.exists(vae_ckpt): os.system(f'wget {hf_home}/{vae_ckpt}')
if not osp.exists(var_ckpt): os.system(f'wget {hf_home}/{var_ckpt}')

# build vae, var
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if 'vae' not in globals() or 'var' not in globals():
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters
        device=device, patch_nums=patch_nums,
        num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
    )

# load checkpoints
vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)
vae.eval(), var.eval()
for p in vae.parameters(): p.requires_grad_(False)
for p in var.parameters(): p.requires_grad_(False)
print(f'prepare finished.')

  from .autonotebook import tqdm as notebook_tqdm



[constructor]  ==== flash_if_available=True (0/16), fused_if_available=True (fusing_add_ln=0/16, fusing_mlp=0/16) ==== 
    [VAR config ] embed_dim=1024, num_heads=16, depth=16, mlp_ratio=4.0
    [drop ratios ] drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0666667 (tensor([0.0000, 0.0044, 0.0089, 0.0133, 0.0178, 0.0222, 0.0267, 0.0311, 0.0356,
        0.0400, 0.0444, 0.0489, 0.0533, 0.0578, 0.0622, 0.0667]))

[init_weights] VAR with init_std=0.0180422
prepare finished.


In [2]:
############################# 2. Sample with classifier-free guidance

# set args
seed = 2137 #@param {type:"number"}
torch.manual_seed(seed)
num_sampling_steps = 250 #@param {type:"slider", min:0, max:1000, step:1} # IRRELEVANT -- variable not used
cfg = 4 #@param {type:"slider", min:1, max:10, step:0.1}
class_labels = (980, 980, 437, 437, 22, 22, 562, 562)  #@param {type:"raw"}
more_smooth = False # True for more smooth output

# seed
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# run faster
tf32 = True
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
torch.set_float32_matmul_precision('high' if tf32 else 'highest')

# sample
B = len(class_labels)
label_B: torch.LongTensor = torch.tensor(class_labels, device=device)
with torch.inference_mode():
    with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True):    # using bfloat16 can be faster
        recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.95, g_seed=seed, more_smooth=more_smooth)

next_token_map: next_token_map.shape=torch.Size([16, 1, 1024])
 next_token_map.dtype=torch.float32
 next_token_map.device=device(type='cuda', index=0)

[autoregressive_infer_cfg] f_hat.shape=torch.Size([8, 32, 16, 16])
, next_token_map.shape=torch.Size([16, 1, 1024])

level 0 x: x.shape=torch.Size([16, 1, 1024])

level 1 x: x.shape=torch.Size([16, 4, 1024])

level 2 x: x.shape=torch.Size([16, 9, 1024])

level 3 x: x.shape=torch.Size([16, 16, 1024])

level 4 x: x.shape=torch.Size([16, 25, 1024])

level 5 x: x.shape=torch.Size([16, 36, 1024])

level 6 x: x.shape=torch.Size([16, 64, 1024])

level 7 x: x.shape=torch.Size([16, 100, 1024])

level 8 x: x.shape=torch.Size([16, 169, 1024])

level 9 x: x.shape=torch.Size([16, 256, 1024])



  return F.conv2d(input, weight, bias, self.stride,


In [3]:

chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.show()

In [4]:
# chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)
# chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
# chw = PImage.fromarray(chw.astype(np.uint8))
# chw.show()
# recon_B3HW.

In [5]:
# optimized_image = torch.rand_like(recon_B3HW, requires_grad=True)
# image_pyramid = vae.img_to_idxBl(optimized_image, patch_nums)

# image_pyramid

In [6]:
# reconstructed = vae.idxBl_to_img(image_pyramid, same_shape=False)

In [7]:
# reconstructed[6].shape

### Encoding the input 

In [8]:
recon_B3HW.shape

torch.Size([8, 3, 256, 256])

In [9]:
device

'cuda'

In [10]:
# x = recon_B3HW
# x = torch.rand(8, 3, 256, 256, device=device)

x = recon_B3HW.type(torch.float32).to(device)#.add_(-0.5).mul_(2)
f = vae.quant_conv(vae.encoder(x))

image = vae.fhat_to_img(f)
f = vae.quant_conv(vae.encoder(image))

# image = vae.fhat_to_img(f)
# f = vae.quant_conv(vae.encoder(image))

# image = vae.fhat_to_img(f)
# f = vae.quant_conv(vae.encoder(image))

image = vae.fhat_to_img(f)

# x = recon_B3HW.type(torch.float32).to(device)#.add_(-0.5).mul_(2)
# f = vae.quant_conv(vae.encoder(x))
# image = vae.fhat_to_img(f).add_(1).mul_(0.5)

chw = torchvision.utils.make_grid(image, nrow=8, padding=0, pad_value=1.0)
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.show()

# ls_f_hat_BChw = vae.quantize.f_to_idxBl_or_fhat(f, to_fhat=False, v_patch_nums=patch_nums)

kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found


In [11]:
from torch.nn import functional as F

patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
last_patch_id = -1

quant_pyramid, f_hat = vae.quantize.f_to_quant_pyramid_and_f_hat(f, patch_nums, last_patch_id)

residual = f - f_hat

In [12]:
print(f"{len(quant_pyramid)=})")
if quant_pyramid:
    print(f"{quant_pyramid[0].shape=}")
# assert False

len(quant_pyramid)=0)


### Using the pyramid as the input for transformer

In [13]:
predicted_residual = var.autoregressive_single_step_prediction(
    quant_pyramid=quant_pyramid, 
    f_hat=f_hat, 
    label_B=label_B,) 

current_level=-1, next_level=0
next_token_map lvl current_level=-1 next_level=0: next_token_map.shape=torch.Size([16, 1, 1024])



In [14]:
f.shape

torch.Size([8, 32, 16, 16])

In [15]:
predicted_residual.shape

torch.Size([8, 32, 16, 16])

### Optimisation POC in latent space

In [16]:
def get_optimisation_loss(f, last_patch_id):
    quant_pyramid, f_hat = vae.quantize.f_to_quant_pyramid_and_f_hat(f, patch_nums, last_patch_id)

    f_residual = f - f_hat

    predicted_residual = var.autoregressive_single_step_prediction(
        quant_pyramid=quant_pyramid, 
        f_hat=f_hat, 
        label_B=label_B, 
        cfg=cfg, 
        top_k=900, 
        top_p=0.95)

    return F.mse_loss(predicted_residual, f_residual)


In [17]:
patch_nums

(1, 2, 3, 4, 5, 6, 8, 10, 13, 16)

In [18]:

f = torch.rand([2, 32, 16, 16], requires_grad=True, device=device)

class_labels = (22, 562)  #@param {type:"raw"}
label_B: torch.LongTensor = torch.tensor(class_labels, device=device)

optimizer = torch.optim.Adam([f], lr=0.1)

steps_per_coarsness = 1

for coarsness_step in range(-1, len((patch_nums)) - 1):
    for i in range(steps_per_coarsness):
        optimizer.zero_grad()
        loss = get_optimisation_loss(f, coarsness_step)
        patch_weight = patch_nums[coarsness_step]**2 / patch_nums[len(patch_nums)-1]**2
        step_weight = (steps_per_coarsness - i) /steps_per_coarsness
        
        loss*= patch_weight * step_weight

        if i % 500 == 0:
            print(f'coarsness_step: {coarsness_step}, loss: {loss.item()}')

        
        loss.backward()
        optimizer.step()


current_level=-1, next_level=0
next_token_map lvl current_level=-1 next_level=0: next_token_map.shape=torch.Size([4, 1, 1024])

coarsness_step: -1, loss: 0.39253973960876465
current_level=0, next_level=1
next_token_map lvl current_level=0 next_level=1: next_token_map.shape=torch.Size([4, 4, 1024])

coarsness_step: 0, loss: 0.0009515469428151846
current_level=1, next_level=2
next_token_map lvl current_level=1 next_level=2: next_token_map.shape=torch.Size([4, 9, 1024])

coarsness_step: 1, loss: 0.002446998842060566
current_level=2, next_level=3
next_token_map lvl current_level=2 next_level=3: next_token_map.shape=torch.Size([4, 16, 1024])

coarsness_step: 2, loss: 0.004130854737013578
current_level=3, next_level=4
next_token_map lvl current_level=3 next_level=4: next_token_map.shape=torch.Size([4, 25, 1024])

coarsness_step: 3, loss: 0.006683871150016785
current_level=4, next_level=5
next_token_map lvl current_level=4 next_level=5: next_token_map.shape=torch.Size([4, 36, 1024])

coarsnes

In [19]:
assert False

AssertionError: 

kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"
kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"


In [None]:
var.patch_nums

(1, 2, 3, 4, 5, 6, 8, 10, 13, 16)

In [None]:
image = vae.fhat_to_img(f).add_(1).mul_(0.5)

In [None]:
chw = torchvision.utils.make_grid(image, nrow=8, padding=0, pad_value=1.0)
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.show()

kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"


## Random coarsness step


In [32]:
from tqdm import tqdm
# f = torch.randn([4, 32, 16, 16], requires_grad=True, device=device) 
# class_labels = (980, 437, 22, 562)  #@param {type:"raw"}


img = recon_B3HW.float()
f = torch.tensor(vae.quant_conv(vae.encoder(img)), requires_grad=True, device=device)
class_labels = (980, 980, 437, 437, 22, 22, 562, 562)  #@param {type:"raw"}

label_B: torch.LongTensor = torch.tensor(class_labels, device=device)

optimizer = torch.optim.Adam([f], lr=0.00001)
total_steps = 1

min_coarsness = len(patch_nums) - 2
max_coarsness = len(patch_nums) - 1

for i in tqdm(range(total_steps), total=total_steps):
    coarsness_step = np.random.randint(min_coarsness, max_coarsness)

    loss = get_optimisation_loss(f, coarsness_step)

    patch_grad_scale = patch_nums[coarsness_step]**2 / patch_nums[len(patch_nums)-1]**2
    loss*= patch_grad_scale
    
    step_weight = (total_steps - i) / total_steps
    loss*= step_weight

    level_ratio = (coarsness_step + 2)/ (len(patch_nums) + 1) 
    level_weight = np.sqrt(level_ratio)
    loss*= level_weight

    print(f'{level_ratio=}, {level_weight=}, {step_weight=}, {patch_grad_scale=}, {loss.item()=}')

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

f = f.detach()

    # if i % 10 == 0:
    #     print(f'coarsness_step: {coarsness_step}, loss: {loss.item()}')

  f = torch.tensor(vae.quant_conv(vae.encoder(img)), requires_grad=True, device=device)
100%|██████████| 1/1 [00:00<00:00, 25.00it/s]

current_level=8, next_level=9
next_token_map lvl current_level=8 next_level=9: next_token_map.shape=torch.Size([16, 256, 1024])

level_ratio=0.9090909090909091, level_weight=0.9534625892455924, step_weight=1.0, patch_grad_scale=0.66015625, loss.item()=0.2066161185503006





In [30]:
image = vae.fhat_to_img(f.detach())#.add_(1).mul_(0.5)
chw = torchvision.utils.make_grid(image, nrow=8, padding=0, pad_value=1.0)
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.show()

kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"


In [24]:
len(patch_nums)

10