### 🚀 For an interactive experience, head over to our [demo platform](https://var.vision/demo) and dive right in! 🌟

In [1]:
################## 1. Download checkpoints and build models
import os
import os.path as osp
import torch, torchvision
import random
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
from models import VQVAE, build_vae_var

MODEL_DEPTH = 16    # TODO: =====> please specify MODEL_DEPTH <=====
assert MODEL_DEPTH in {16, 20, 24, 30}, f'unsupported MODEL_DEPTH={MODEL_DEPTH}'


# Download the checkpoints that are not already present in the working directory.
# torch.hub.download_url_to_file raises on HTTP/network errors and writes to a temporary
# file before moving it into place, so a failed download is not mistaken for a valid
# checkpoint on the next run (and no external `wget` binary is required).
hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'
vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}.pth'
if not osp.exists(vae_ckpt):
    torch.hub.download_url_to_file(f'{hf_home}/{vae_ckpt}', vae_ckpt)
if not osp.exists(var_ckpt):
    torch.hub.download_url_to_file(f'{hf_home}/{var_ckpt}', var_ckpt)

# build vae, var
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Models already living in the kernel are reused so that re-running this cell is cheap.
# NOTE: the guard does not look at MODEL_DEPTH -- after changing it, restart the kernel
# (or `del vae, var`) so that the models are rebuilt with the new depth.
if 'vae' not in globals() or 'var' not in globals():
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters
        device=device, patch_nums=patch_nums,
        num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
    )

# load checkpoints
# NOTE: torch.load unpickles the file; only load checkpoints obtained from a trusted source.
vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)

# inference only: switch to eval mode and freeze every parameter
vae.eval()
var.eval()
for p in vae.parameters(): p.requires_grad_(False)
for p in var.parameters(): p.requires_grad_(False)
print('prepare finished.')


[constructor]  ==== flash_if_available=True (0/16), fused_if_available=True (fusing_add_ln=0/16, fusing_mlp=0/16) ==== 
    [VAR config ] embed_dim=1024, num_heads=16, depth=16, mlp_ratio=4.0
    [drop ratios ] drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0666667 (tensor([0.0000, 0.0044, 0.0089, 0.0133, 0.0178, 0.0222, 0.0267, 0.0311, 0.0356,
        0.0400, 0.0444, 0.0489, 0.0533, 0.0578, 0.0622, 0.0667]))

[init_weights] VAR with init_std=0.0180422
prepare finished.


In [2]:
############################# 2. Sample with classifier-free guidance

# set args
seed = 0 #@param {type:"number"}
num_sampling_steps = 250 #@param {type:"slider", min:0, max:1000, step:1} # IRRELEVANT -- variable not used
cfg = 4 #@param {type:"slider", min:1, max:10, step:0.1}
class_labels = (980, 980, 437, 437, 22, 22, 562, 562)  #@param {type:"raw"}
more_smooth = False # True for more smooth output

# seed every RNG used by the notebook (torch, random, numpy) exactly once
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# run faster: allow TF32 kernels for cuDNN and matmul
tf32 = True
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
torch.set_float32_matmul_precision('high' if tf32 else 'highest')

# sample: one class label per generated image
B = len(class_labels)
# dtype is spelled out so the labels are int64 regardless of what the tuple contains
label_B: torch.LongTensor = torch.tensor(class_labels, dtype=torch.long, device=device)
# with torch.inference_mode():
#     with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True):    # using bfloat16 can be faster
#         recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.95, g_seed=seed, more_smooth=more_smooth)

# chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)
# chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
# chw = PImage.fromarray(chw.astype(np.uint8))
# chw.show()


In [3]:
# chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)
# chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
# chw = PImage.fromarray(chw.astype(np.uint8))
# chw.show()
# recon_B3HW.

In [4]:
# optimized_image = torch.rand_like(recon_B3HW, requires_grad=True)
# image_pyramid = vae.img_to_idxBl(optimized_image, patch_nums)

# image_pyramid

In [5]:
# reconstructed = vae.idxBl_to_img(image_pyramid, same_shape=False)

In [6]:
# reconstructed[6].shape

### Encoding the input 

In [7]:
# x = recon_B3HW
# Random stand-in batch (8 x 3 x 256 x 256), created on the device the VAE was built on;
# a CPU tensor would fail against CUDA weights.
x = torch.rand(8, 3, 256, 256, device=device)

# Continuous latent produced by the VAE encoder followed by quant_conv (not yet quantised).
f = vae.quant_conv(vae.encoder(x))

# ls_f_hat_BChw = vae.quantize.f_to_idxBl_or_fhat(f, to_fhat=False, v_patch_nums=patch_nums)

In [8]:
import torch.nn.functional as F

# Scale schedule handed to the quantiser, and the scale index used for this demo.
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
coarsness_step = 4

# Quantise the latent `f`; the call yields a pyramid plus a reconstruction `f_hat`.
# (presumably the pyramid is truncated at `coarsness_step` -- confirm in vae.quantize)
quant_pyramid, f_hat = vae.quantize.f_to_quant_pyramid_and_f_hat(f, patch_nums, coarsness_step)

# Part of the latent that the reconstruction does not account for.
residual = torch.sub(f, f_hat)

In [9]:
# Number of entries in the pyramid; the recorded output for this cell is 4.
len(quant_pyramid)

4

### Using the pyramid as the input for transformer

In [12]:
# Single call into the transformer, conditioned on the pyramid, the current reconstruction
# and the class labels, with the classifier-free-guidance scale and top-k / top-p settings.
predicted_residual = var.autoregressive_single_step_prediction(
    quant_pyramid=quant_pyramid, f_hat=f_hat, label_B=label_B,
    cfg=cfg, top_k=900, top_p=0.95,
)

In [13]:
# Shape of the latent `f`; the recorded output for this cell is torch.Size([8, 32, 16, 16]).
f.shape

torch.Size([8, 32, 16, 16])

In [60]:
# Shape of the prediction; the recorded output for this cell is torch.Size([8, 32, 16, 16]).
predicted_residual.shape

torch.Size([8, 32, 16, 16])

### Optimisation POC in latent space

In [14]:
def get_optimisation_loss(f, coarsness_step, labels=None, guidance=None, top_k=900, top_p=0.95):
    """MSE between the residual predicted by ``var`` and the quantisation residual of ``f``.

    The latent is quantised via ``vae.quantize.f_to_quant_pyramid_and_f_hat`` using the
    notebook-level ``patch_nums``; the residual ``f - f_hat`` is then compared with the
    output of ``var.autoregressive_single_step_prediction``.

    Args:
        f: latent tensor to score (the tensor being optimised).
        coarsness_step: scale index forwarded to the quantiser.
        labels: class labels forwarded as ``label_B``. Defaults to the notebook-level
            ``label_B``, looked up at call time.
        guidance: classifier-free-guidance scale forwarded as ``cfg``. Defaults to the
            notebook-level ``cfg``, looked up at call time.
        top_k: forwarded unchanged to the prediction call.
        top_p: forwarded unchanged to the prediction call.

    Returns:
        The value of ``F.mse_loss(predicted_residual, f - f_hat)``.
    """
    # Fall back to the notebook globals so that existing two-argument calls keep working,
    # while callers can pass labels / guidance explicitly instead of relying on whichever
    # cell rebound `label_B` last.
    labels = label_B if labels is None else labels
    guidance = cfg if guidance is None else guidance

    quant_pyramid, f_hat = vae.quantize.f_to_quant_pyramid_and_f_hat(f, patch_nums, coarsness_step)

    # what the quantised reconstruction is still missing
    f_residual = f - f_hat

    predicted_residual = var.autoregressive_single_step_prediction(
        quant_pyramid=quant_pyramid,
        f_hat=f_hat,
        label_B=labels,
        cfg=guidance,
        top_k=top_k,
        top_p=top_p)

    return F.mse_loss(predicted_residual, f_residual)


In [57]:

# Latent being optimised: a leaf tensor created directly on the device the models and
# labels live on (a CPU latent would not match CUDA models).
f = torch.rand([2, 32, 16, 16], device=device, requires_grad=True)

class_labels = (22, 562)  #@param {type:"raw"}
label_B: torch.LongTensor = torch.tensor(class_labels, device=device)

optimizer = torch.optim.Adam([f], lr=0.1)

steps_per_coarsness = 100

# Visit scale indices 1 .. len(patch_nums) - 2 in order; index 0 and the last index are skipped.
for coarsness_step in range(1, len(patch_nums) - 1):
    # Loop-invariant weight: squared patch count of this scale relative to the last scale.
    scale_weight = patch_nums[coarsness_step]**2 / patch_nums[-1]**2
    for i in range(steps_per_coarsness):
        optimizer.zero_grad()
        loss = get_optimisation_loss(f, coarsness_step)
        # apply the scale weight, then a weight that decays linearly over the steps of this scale
        loss = loss * scale_weight
        loss = loss * ((steps_per_coarsness - i) / steps_per_coarsness)

        loss.backward()
        optimizer.step()

        # the reported value is the weighted loss, not the raw MSE
        if i % 10 == 0:
            print(f'coarsness_step: {coarsness_step}, loss: {loss.item()}')

coarsness_step: 1, loss: 0.004371508955955505
coarsness_step: 1, loss: 0.0003481769817881286
coarsness_step: 1, loss: 0.00036811825702898204
coarsness_step: 1, loss: 0.00014163559535518289
coarsness_step: 1, loss: 0.00015670272114221007
coarsness_step: 1, loss: 0.00011510938202263787
coarsness_step: 1, loss: 9.264543768949807e-05
coarsness_step: 1, loss: 6.400592974387109e-05
coarsness_step: 1, loss: 3.6390963941812515e-05
coarsness_step: 1, loss: 1.6867006706888787e-05
coarsness_step: 2, loss: 0.0005777326296083629
coarsness_step: 2, loss: 0.0005680759786628187
coarsness_step: 2, loss: 0.0006122048944234848
coarsness_step: 2, loss: 0.0005135508254170418
coarsness_step: 2, loss: 0.0003322355914860964
coarsness_step: 2, loss: 0.0001490399445174262
coarsness_step: 2, loss: 0.0001059482601704076
coarsness_step: 2, loss: 5.8833087678067386e-05
coarsness_step: 2, loss: 4.348806032794528e-05
coarsness_step: 2, loss: 1.7639913494349457e-05
coarsness_step: 3, loss: 0.00029077508952468634
coars

In [58]:
# Decode for visualisation only. `f` requires grad, so without no_grad the decoder pass
# would record an autograd graph that is never used.
with torch.no_grad():
    # shift/scale the decoder output by (x + 1) / 2
    # (assumes fhat_to_img returns values in [-1, 1] -- TODO confirm)
    image = vae.fhat_to_img(f).add_(1).mul_(0.5)

In [59]:
# Tile the batch into a single row-major grid (up to 8 per row, no padding), move channels
# last, scale by 255 and convert to an 8-bit PIL image, then open it in the default viewer.
chw = PImage.fromarray(
    torchvision.utils.make_grid(image, nrow=8, padding=0, pad_value=1.0)
    .permute(1, 2, 0)
    .mul_(255)
    .cpu()
    .numpy()
    .astype(np.uint8)
)
chw.show()

Python(9189) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(9190) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.


## Random coarseness step


In [47]:
from tqdm import tqdm

# Latent being optimised: a leaf tensor created directly on the device the models and
# labels live on (a CPU latent would not match CUDA models).
f = torch.randn([8, 32, 16, 16], device=device, requires_grad=True)

class_labels = (980, 980, 437, 437, 22, 22, 562, 562)  #@param {type:"raw"}
label_B: torch.LongTensor = torch.tensor(class_labels, device=device)

optimizer = torch.optim.Adam([f], lr=0.1)
total_steps = 400

for i in tqdm(range(total_steps), total=total_steps):
    # pick a scale index uniformly from 1 .. len(patch_nums) - 2 (upper bound is exclusive)
    coarsness_step = np.random.randint(1, len(patch_nums) - 1)

    optimizer.zero_grad()
    loss = get_optimisation_loss(f, coarsness_step)

    # weight by the squared patch count of this scale relative to the last scale
    loss *= patch_nums[coarsness_step]**2 / patch_nums[-1]**2
    loss *= (total_steps - i) / total_steps # linearly decrease the loss weight
    loss.backward()
    optimizer.step()

    if i % 10 == 0:
        # tqdm.write prints the message without breaking the live progress bar
        tqdm.write(f'coarsness_step: {coarsness_step}, loss: {loss.item()}')

  0%|          | 1/400 [00:05<34:39,  5.21s/it]

coarsness_step: 4, loss: 0.09834317117929459


  3%|▎         | 11/400 [03:48<7:03:56, 65.39s/it]

coarsness_step: 7, loss: 0.10095315426588058


  5%|▌         | 21/400 [04:00<19:12,  3.04s/it]  

coarsness_step: 5, loss: 0.015652334317564964


  8%|▊         | 31/400 [16:47<5:45:27, 56.17s/it]  

coarsness_step: 5, loss: 0.009880476631224155


 10%|█         | 41/400 [17:23<26:52,  4.49s/it]  

coarsness_step: 4, loss: 0.006207524798810482


 13%|█▎        | 51/400 [17:39<08:37,  1.48s/it]

coarsness_step: 2, loss: 0.0025906700175255537


 15%|█▌        | 61/400 [18:07<17:46,  3.14s/it]

coarsness_step: 7, loss: 0.018752446398139


 18%|█▊        | 71/400 [18:24<09:02,  1.65s/it]

coarsness_step: 5, loss: 0.0070991660468280315


 20%|██        | 81/400 [18:39<09:29,  1.78s/it]

coarsness_step: 6, loss: 0.00950113870203495


 23%|██▎       | 91/400 [18:58<04:56,  1.04it/s]

coarsness_step: 1, loss: 0.0011801017681136727


 25%|██▌       | 101/400 [19:17<13:33,  2.72s/it]

coarsness_step: 7, loss: 0.02104043960571289


 28%|██▊       | 111/400 [19:31<05:30,  1.14s/it]

coarsness_step: 3, loss: 0.00307333841919899


 30%|███       | 121/400 [19:51<11:12,  2.41s/it]

coarsness_step: 8, loss: 0.03673836961388588


 33%|███▎      | 131/400 [20:00<03:04,  1.45it/s]

coarsness_step: 5, loss: 0.0057668788358569145


 35%|███▌      | 141/400 [20:12<03:28,  1.24it/s]

coarsness_step: 2, loss: 0.002143984194844961


 38%|███▊      | 151/400 [20:27<04:23,  1.06s/it]

coarsness_step: 5, loss: 0.0038930713199079037


 40%|████      | 161/400 [20:51<07:53,  1.98s/it]

coarsness_step: 7, loss: 0.012571165338158607


 43%|████▎     | 171/400 [21:13<10:25,  2.73s/it]

coarsness_step: 8, loss: 0.027296559885144234


 45%|████▌     | 181/400 [21:26<03:53,  1.07s/it]

coarsness_step: 1, loss: 0.002047841204330325


 48%|████▊     | 191/400 [21:35<01:43,  2.02it/s]

coarsness_step: 2, loss: 0.003128730459138751


 50%|█████     | 201/400 [21:46<03:29,  1.05s/it]

coarsness_step: 1, loss: 0.0016092625446617603


 53%|█████▎    | 211/400 [21:59<02:30,  1.25it/s]

coarsness_step: 4, loss: 0.0026412273291498423


 55%|█████▌    | 221/400 [22:08<01:54,  1.56it/s]

coarsness_step: 2, loss: 0.0027278035413473845


 58%|█████▊    | 231/400 [22:21<04:43,  1.68s/it]

coarsness_step: 6, loss: 0.004148900043219328


 60%|██████    | 241/400 [22:35<03:48,  1.44s/it]

coarsness_step: 1, loss: 0.0016260698903352022


 63%|██████▎   | 251/400 [22:52<02:56,  1.18s/it]

coarsness_step: 1, loss: 0.0018258620984852314


 65%|██████▌   | 261/400 [23:07<03:01,  1.31s/it]

coarsness_step: 6, loss: 0.007293058093637228


 68%|██████▊   | 271/400 [23:27<05:12,  2.42s/it]

coarsness_step: 8, loss: 0.017124123871326447


 70%|███████   | 281/400 [23:45<03:44,  1.88s/it]

coarsness_step: 6, loss: 0.005782019346952438


 73%|███████▎  | 291/400 [23:58<02:24,  1.33s/it]

coarsness_step: 6, loss: 0.004763308446854353


 75%|███████▌  | 301/400 [24:14<03:01,  1.83s/it]

coarsness_step: 5, loss: 0.002259922446683049


 78%|███████▊  | 311/400 [24:26<01:22,  1.08it/s]

coarsness_step: 4, loss: 0.0020377703476697206


 80%|████████  | 321/400 [24:41<02:18,  1.76s/it]

coarsness_step: 6, loss: 0.0020869560539722443


 83%|████████▎ | 331/400 [24:54<02:33,  2.22s/it]

coarsness_step: 8, loss: 0.009092291817069054


 85%|████████▌ | 341/400 [25:07<00:56,  1.04it/s]

coarsness_step: 1, loss: 0.0007735262042842805


 88%|████████▊ | 351/400 [25:20<01:04,  1.32s/it]

coarsness_step: 6, loss: 0.0013377207797020674


 90%|█████████ | 361/400 [25:35<00:54,  1.39s/it]

coarsness_step: 6, loss: 0.0010890368139371276


 93%|█████████▎| 371/400 [25:58<01:17,  2.67s/it]

coarsness_step: 7, loss: 0.0014432725729420781


 95%|█████████▌| 381/400 [26:15<00:23,  1.26s/it]

coarsness_step: 4, loss: 0.0003595100424718112


 98%|█████████▊| 391/400 [26:28<00:09,  1.05s/it]

coarsness_step: 4, loss: 0.00018257004558108747


100%|██████████| 400/400 [26:41<00:00,  4.00s/it]


In [48]:
# Decode for visualisation only. `f` requires grad, so without no_grad the decoder pass
# would record an autograd graph that is never used.
with torch.no_grad():
    # shift/scale the decoder output by (x + 1) / 2
    # (assumes fhat_to_img returns values in [-1, 1] -- TODO confirm)
    image = vae.fhat_to_img(f).add_(1).mul_(0.5)

In [49]:
# Tile the batch into a single row-major grid (up to 8 per row, no padding), move channels
# last, scale by 255 and convert to an 8-bit PIL image, then open it in the default viewer.
chw = PImage.fromarray(
    torchvision.utils.make_grid(image, nrow=8, padding=0, pad_value=1.0)
    .permute(1, 2, 0)
    .mul_(255)
    .cpu()
    .numpy()
    .astype(np.uint8)
)
chw.show()

Python(92571) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
Python(92584) MallocStackLogging: can't turn off malloc stack logging because it was not enabled.
