### 🚀 For an interactive experience, head over to our [demo platform](https://var.vision/demo) and dive right in! 🌟

In [1]:
################## 1. Download checkpoints and build models
# Cell 1: imports, model-depth selection, checkpoint download and model construction.
import os
# Colab only: cd into the cloned repo, presumably so `from models import ...` below resolves.
if os.path.exists('/content/VAR'): os.chdir('/content/VAR')
import os.path as osp
import torch, torchvision
import random
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
# Globally replace reset_parameters with a no-op so building the models skips the
# default weight init; the weights are overwritten by load_state_dict(strict=True)
# later in this cell. NOTE: the patch stays active for every Linear/LayerNorm created
# afterwards in this process.
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
from models import VQVAE, build_vae_var

# we recommend using imagenet-512-d36 model to do the in-painting & out-painting & class-condition editing task
MODEL_DEPTH = 24    # TODO: =====> please specify MODEL_DEPTH <=====

# Only these depths are accepted; presumably the ones with a released var_d{depth}.pth — confirm.
assert MODEL_DEPTH in {16, 20, 24, 30, 36}


# download checkpoint
hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'

# --- MODIFICATION START ---
# Local directory holding (or receiving) the checkpoints. Set the VAR_CKPT_DIR
# environment variable to override this machine-specific default.
checkpoint_dir = os.environ.get('VAR_CKPT_DIR', '/Users/gilliam/Desktop/493G1/VAR/VideoVAR/')

vae_ckpt_name = 'vae_ch160v4096z32.pth'
var_ckpt_name = f'var_d{MODEL_DEPTH}.pth' # e.g. 'var_d36.pth' when MODEL_DEPTH is 36

# Use these variables with the correct local paths
local_vae_ckpt_path = osp.join(checkpoint_dir, vae_ckpt_name)
local_var_ckpt_path = osp.join(checkpoint_dir, var_ckpt_name)


def _ensure_checkpoint(ckpt_name, local_path):
    """Download `{hf_home}/{ckpt_name}` to `local_path` unless that file already exists.

    The data is written to `local_path + '.part'` and only renamed into place once
    the download finished, so a failed or interrupted download never leaves an
    empty/partial file that a later run would mistake for a valid checkpoint.
    Download errors (e.g. urllib.error.HTTPError) propagate to the caller.
    """
    import urllib.request  # stdlib; avoids depending on an external `wget` binary

    if osp.exists(local_path):
        return
    print(f"{local_path} not found. Downloading {ckpt_name} to {checkpoint_dir}...")
    os.makedirs(checkpoint_dir, exist_ok=True) # Ensure directory exists
    tmp_path = local_path + '.part'
    try:
        urllib.request.urlretrieve(f'{hf_home}/{ckpt_name}', tmp_path)
        os.replace(tmp_path, local_path)  # publish only a completed download
    except BaseException:
        # Remove the partial file so the next run retries, then surface the error.
        if osp.exists(tmp_path):
            os.remove(tmp_path)
        raise


# Optional: Keep the download logic if the files might not be there
_ensure_checkpoint(vae_ckpt_name, local_vae_ckpt_path)
_ensure_checkpoint(var_ckpt_name, local_var_ckpt_path)
# --- MODIFICATION END ---

# build vae, var
# Depth 36 is the 512px model: it uses the longer patch schedule below and is
# built with shared_aln=True; every other depth uses the shorter schedule.
FOR_512_px = MODEL_DEPTH == 36
if FOR_512_px:
    patch_nums = (1, 2, 3, 4, 6, 9, 13, 18, 24, 32)
else:
    patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
vae, var = build_vae_var(
    V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters
    device=device, patch_nums=patch_nums,
    num_classes=1000, depth=MODEL_DEPTH, shared_aln=FOR_512_px,
)

# load checkpoints
# NOTE(review): torch.load unpickles the file, and these checkpoints may have been
# downloaded from hf_home. Consider weights_only=True (torch>=1.13) so a tampered
# file cannot execute code on load.
vae.load_state_dict(torch.load(local_vae_ckpt_path, map_location='cpu'), strict=True)
var.load_state_dict(torch.load(local_var_ckpt_path, map_location='cpu'), strict=True)

# Inference only: switch both models to eval mode and freeze every parameter.
vae.eval()
var.eval()
vae.requires_grad_(False)
var.requires_grad_(False)
print('preparation finished.')


[constructor]  ==== flash_if_available=True (0/24), fused_if_available=True (fusing_add_ln=0/24, fusing_mlp=0/24) ==== 
    [VAR config ] embed_dim=1536, num_heads=24, depth=24, mlp_ratio=4.0
    [drop ratios ] drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.1 (tensor([0.0000, 0.0043, 0.0087, 0.0130, 0.0174, 0.0217, 0.0261, 0.0304, 0.0348,
        0.0391, 0.0435, 0.0478, 0.0522, 0.0565, 0.0609, 0.0652, 0.0696, 0.0739,
        0.0783, 0.0826, 0.0870, 0.0913, 0.0957, 0.1000]))

[init_weights] VAR with init_std=0.0147314
preparation finished.


In [None]:
############################# 2. Sample with classifier-free guidance
# Sample one image per entry of `class_labels` with classifier-free guidance and
# show them as a single grid. Uses `var` and `device` built in cell 1.

# set args
seed = 0 #@param {type:"number"}
# NOTE(review): num_sampling_steps is not passed to the sampler call below.
num_sampling_steps = 250 #@param {type:"slider", min:0, max:1000, step:1}
cfg = 4 #@param {type:"slider", min:1, max:10, step:0.1}
class_labels = (980, 980, 980, 980, 980, 980, 980, 980)  #@param {type:"raw"}
more_smooth = False # True for more smooth output

# seed every RNG once (the sampler additionally receives g_seed=seed)
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# run faster
tf32 = True
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
torch.set_float32_matmul_precision('high' if tf32 else 'highest')

# sample
B = len(class_labels)
label_B: torch.LongTensor = torch.tensor(class_labels, device=device)
with torch.inference_mode():
    with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True):    # using bfloat16 can be faster
        recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.95, g_seed=seed, more_smooth=more_smooth)

chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)
# Out-of-place ops on purpose: with a single image make_grid returns a view of the
# inference-mode tensor, and in-place updates to inference tensors are rejected
# outside torch.inference_mode(). Clamp so the uint8 cast cannot wrap around.
chw = chw.permute(1, 2, 0).mul(255).clamp(0, 255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.show()
