# Overall

In order to train the SinVAR, normally, it includes to train the VAE first. However, given the time-horizon for the current project, 
and to stay with the original VAR paper, I will use the pretrained VQVAE as they have used. However, I will look into fine-tune the VQVAE if possible. 

Then, after the VQVAE, we have to train the VAR.

In [1]:
import torch
import torch.nn as nn
from typing import Tuple

In [2]:
from vqvae import VQVAE
from var import VAR


def build_vae_var(
        # Shared args
        device, patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),  # 10 steps by default
        # VQVAE args
        V=4096, Cvae=32, ch=160, share_quant_resi=4,
        # VAR args
        depth=16, attn_l2_norm=True,
        init_head=0.02, init_std=-1,  # init_std < 0: automated
) -> Tuple[VQVAE, VAR]:
    heads = depth
    width = depth * 64
    dpr = 0.1 * depth / 24

    # disable built-in initialization for speed
    for clz in (nn.Linear, nn.LayerNorm, nn.BatchNorm2d, nn.SyncBatchNorm, nn.Conv1d, nn.Conv2d, nn.ConvTranspose1d,
                nn.ConvTranspose2d):
        setattr(clz, 'reset_parameters', lambda self: None)

    # build models
    vae_local = VQVAE(vocab_size=V, z_channels=Cvae, ch=ch, test_mode=True, share_quant_resi=share_quant_resi,
                      v_patch_nums=patch_nums).to(device)
    var_wo_ddp = VAR(
        vae_local=vae_local,
        depth=depth, embed_dim=width, num_heads=heads, drop_rate=0., attn_drop_rate=0., drop_path_rate=dpr,
        norm_eps=1e-6,
        attn_l2_norm=attn_l2_norm,
        patch_nums=patch_nums,
    ).to(device)
    var_wo_ddp.init_weights(init_head=init_head, init_std=init_std)

    return vae_local, var_wo_ddp

## VQVAE

- Loading from the Hugging Face
- Trying to Reconstruct on Random Crop

In [3]:
vae_var_config = {
    # Shared
    "device": torch.device("mps" if torch.mps.is_available() else "cpu"),
    "patch_nums" : (1, 2, 3, 4, 5, 6, 8, 10, 13, 16),

    # VAR config (customizable for your setup)
    "depth": 16,  # VAR transformer depth
    "attn_l2_norm": True,  

    # Initialisation options (irrelevant for non-adaptive setup)
    "init_head": 0.02,
    "init_std": -1,  # Use default init scheme
}

In [4]:
vae_local, var_wo_ddp = build_vae_var(
    V=4096, Cvae=32, ch=160, share_quant_resi=4,  # hard-coded VQVAE hyperparameters
    device=vae_var_config['device'],
    depth=vae_var_config['depth'], attn_l2_norm=vae_var_config['attn_l2_norm'],
    init_head=vae_var_config['init_head'], init_std=vae_var_config['init_std'],
)


[constructor]  ==== UNCONDITIONAL VAR: using SelfAttnBlock (16 blocks, no class label) ====
    [VAR config ] embed_dim=1024, num_heads=16, depth=16
    [drop ratios ] drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0666667 (tensor([0.0000, 0.0044, 0.0089, 0.0133, 0.0178, 0.0222, 0.0267, 0.0311, 0.0356,
        0.0400, 0.0444, 0.0489, 0.0533, 0.0578, 0.0622, 0.0667]))

[init_weights] VAR with init_std=0.0180422


In [5]:
# download checkpoint for VQVAE
import os
import os.path as osp

hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'
vae_ckpt = 'vae_ch160v4096z32.pth'
if not osp.exists(vae_ckpt): os.system(f'wget {hf_home}/{vae_ckpt}')

--2025-04-19 18:40:03--  https://huggingface.co/FoundationVision/var/resolve/main/vae_ch160v4096z32.pth
Resolving huggingface.co (huggingface.co)... 18.67.93.22, 18.67.93.63, 18.67.93.102, ...
Connecting to huggingface.co (huggingface.co)|18.67.93.22|:443... connected.
HTTP request sent, awaiting response... 302 Found
Location: https://cdn-lfs-us-1.hf.co/repos/27/33/2733ebd65833f8330005ce942c9195c84a3d385ac604d32ebe5d6d9a79385456/7c3ec27ae28a3f87055e83211ea8cc8558bd1985d7b51742d074fb4c2fcf186c?response-content-disposition=inline%3B+filename*%3DUTF-8%27%27vae_ch160v4096z32.pth%3B+filename%3D%22vae_ch160v4096z32.pth%22%3B&Expires=1745055603&Policy=eyJTdGF0ZW1lbnQiOlt7IkNvbmRpdGlvbiI6eyJEYXRlTGVzc1RoYW4iOnsiQVdTOkVwb2NoVGltZSI6MTc0NTA1NTYwM319LCJSZXNvdXJjZSI6Imh0dHBzOi8vY2RuLWxmcy11cy0xLmhmLmNvL3JlcG9zLzI3LzMzLzI3MzNlYmQ2NTgzM2Y4MzMwMDA1Y2U5NDJjOTE5NWM4NGEzZDM4NWFjNjA0ZDMyZWJlNWQ2ZDlhNzkzODU0NTYvN2MzZWMyN2FlMjhhM2Y4NzA1NWU4MzIxMWVhOGNjODU1OGJkMTk4NWQ3YjUxNzQyZDA3NGZiNGMyZmNmMTg2Yz9yZXNwb2