In [1]:
# Install required libraries
%pip install huggingface_hub requests
import os
import os.path as osp
import random
import time

import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
import requests
import torch
import torchvision

from models import VQVAE, build_vae_var

Note: you may need to restart the kernel to use updated packages.


  from .autonotebook import tqdm as notebook_tqdm
    PyTorch 2.7.0+cu126 with CUDA 1206 (you have 2.7.0+cpu)
    Python  3.10.11 (you have 3.10.0)
  Please reinstall xformers (see https://github.com/facebookresearch/xformers#installing-xformers)
  Memory-efficient attention, SwiGLU, sparse and more won't be available.
  Set XFORMERS_MORE_DETAILS=1 for more details


In [2]:
# Disable default parameter initialization for faster speed.
# Linear / LayerNorm constructors call reset_parameters(); replacing it with a no-op skips that
# random init, so freshly built layers hold uninitialized memory until something else
# (the model's own init or a checkpoint load) writes their values.
torch.nn.Linear.reset_parameters = lambda self: None
torch.nn.LayerNorm.reset_parameters = lambda self: None

In [3]:
# Model depth configuration: selects which pretrained VAR checkpoint is used.
# Only the depths listed in the assert below are accepted.
MODEL_DEPTH = 16  # TODO: =====> please specify MODEL_DEPTH <=====
assert MODEL_DEPTH in (16, 20, 24, 30), "Invalid MODEL_DEPTH value!"

# Define checkpoint URLs: each file is fetched from "<hf_home>/<file name>".
# The VAR file name embeds the chosen depth.
hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'
vae_ckpt = 'vae_ch160v4096z32.pth'
var_ckpt = 'var_d{}.pth'.format(MODEL_DEPTH)

# Function to download files if not already present
def download_file(url, filename, chunk_size=8192, timeout=60):
    """Stream ``url`` into ``filename`` unless ``filename`` already exists.

    The payload is first written to ``filename + '.part'`` and only renamed to
    ``filename`` once the whole body has been received. An interrupted download
    therefore never leaves a truncated file behind. Such a file would otherwise
    pass the existence check on every later run and be loaded as if it were a
    complete checkpoint.

    Args:
        url: HTTP(S) URL to fetch.
        filename: destination path; nothing is fetched if it already exists.
        chunk_size: bytes per streamed chunk (previously hard-coded to 8192).
        timeout: seconds ``requests`` waits for the connection / each read
            before giving up, instead of potentially hanging indefinitely.

    Raises:
        requests.HTTPError: if the server responds with an error status.
    """
    if osp.exists(filename):
        return

    print(f"Downloading {filename}...")
    partial_path = f"{filename}.part"
    try:
        # Using the response as a context manager releases the connection even on error.
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(partial_path, 'wb') as f:
                for chunk in response.iter_content(chunk_size=chunk_size):
                    f.write(chunk)
        # Atomic rename on the same filesystem: `filename` is either absent or complete.
        os.replace(partial_path, filename)
    except BaseException:
        # Remove the half-written file (also on KeyboardInterrupt), then propagate.
        if osp.exists(partial_path):
            os.remove(partial_path)
        raise
    print(f"Downloaded {filename}")

# Download checkpoints.
# NOTE(review): the `checkpoints/` directory and vae_ckpt_path / var_ckpt_path are prepared here
# but not used by the cells shown. The files are downloaded into the current working directory,
# and the loading cell reads them from there too. If `checkpoints/` was the intended location,
# switch the download calls and the torch.load calls to the *_path variables together.
os.makedirs("checkpoints", exist_ok=True)
vae_ckpt_path, var_ckpt_path = (osp.join("checkpoints", ckpt_name) for ckpt_name in (vae_ckpt, var_ckpt))

download_file(f"{hf_home}/{vae_ckpt}", vae_ckpt)
download_file(f"{hf_home}/{var_ckpt}", var_ckpt)

In [4]:
# Auto-select device: use CUDA when torch reports it as available, otherwise run on CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f"Using device: {device}")

# Build models.
# ch / V / Cvae appear to mirror the VAE checkpoint name 'vae_ch160v4096z32.pth', and `depth`
# must match the VAR checkpoint selected via MODEL_DEPTH. Keep them in sync with the downloads.
vae, var = build_vae_var(
    ch=160,
    V=4096,
    Cvae=32,
    share_quant_resi=4,
    depth=MODEL_DEPTH,
    shared_aln=False,
    num_classes=1000,
    patch_nums=(1, 2, 3, 4, 5, 6, 8, 10, 13, 16),
    device=device,
)

Using device: cpu

[constructor]  ==== flash_if_available=True (0/16), fused_if_available=True (fusing_add_ln=0/16, fusing_mlp=0/16) ==== 
    [VAR config ] embed_dim=1024, num_heads=16, depth=16, mlp_ratio=4.0
    [drop ratios ] drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0666667 (tensor([0.0000, 0.0044, 0.0089, 0.0133, 0.0178, 0.0222, 0.0267, 0.0311, 0.0356,
        0.0400, 0.0444, 0.0489, 0.0533, 0.0578, 0.0622, 0.0667]))

[init_weights] VAR with init_std=0.0180422


In [5]:
# Load checkpoints.
# The VAE must match its checkpoint exactly (strict=True).
# weights_only=True (the default since PyTorch 2.6) refuses to unpickle arbitrary Python objects
# from the downloaded files; passing it explicitly keeps older torch versions safe as well.
vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu', weights_only=True), strict=True)

# The VAR checkpoint is loaded non-strictly: only entries whose name exists in the freshly built
# model AND whose shape agrees are kept, so a config mismatch degrades to a partial load.
var_checkpoint = torch.load(var_ckpt, map_location='cpu', weights_only=True)

# Build the model's state dict once; calling var.state_dict() inside the comprehension rebuilt
# the whole dict twice for every checkpoint key.
var_state = var.state_dict()
filtered_checkpoint = {
    k: v for k, v in var_checkpoint.items()
    if k in var_state and var_state[k].shape == v.shape
}

# Report dropped entries instead of discarding them silently.
skipped_keys = [k for k in var_checkpoint if k not in filtered_checkpoint]
if skipped_keys:
    print(f"Skipped checkpoint keys (not in model or shape mismatch): {skipped_keys}")

# Load the filtered checkpoint; the returned key report is shown as the cell output.
var.load_state_dict(filtered_checkpoint, strict=False)

<All keys matched successfully>

In [6]:
# Log skipped parameters.
# Re-applying the same filtered checkpoint writes identical values again; it is repeated only to
# capture the (missing_keys, unexpected_keys) report, which the next cell consumes.
# unexpected_keys should always be empty here, because the filter in the loading cell kept only
# keys present in var.state_dict().
missing_keys, unexpected_keys = var.load_state_dict(
    filtered_checkpoint,
    strict=False,
)

if len(missing_keys) > 0:
    print(f"Missing keys in the checkpoint: {missing_keys}")
if len(unexpected_keys) > 0:
    print(f"Unexpected keys in the checkpoint: {unexpected_keys}")

In [7]:
# Handle parameters the checkpoint did not provide by zero-filling them.
# NOTE(review): build_vae_var already ran VAR's own weight init (its "[init_weights]" log line),
# so zero-filling replaces that initialization rather than restoring a "default". An all-zero
# weight (e.g. a normalization scale) can silence a layer. Confirm zeros are really intended.
# state_dict() returns detached tensors that share storage with the model, so writing into them
# updates the model in place. Build it once instead of three times per key.
var_state = var.state_dict()
for key in missing_keys:
    if key in var_state:
        print(f"Initializing missing key: {key}")
        var_state[key].zero_()

In [8]:
# Inference only: put both models in eval mode and disable gradient tracking on every parameter
# (Module.requires_grad_ applies to all of the module's parameters, recursively).
vae.eval()
vae.requires_grad_(False)
var.eval()
var.requires_grad_(False)

print("Model preparation finished.")

Model preparation finished.


In [9]:
# Sampling / refinement arguments used by the cells below.
seed: int = 0                  # RNG seed; also passed as g_seed to the baseline VAR sampling
num_sampling_steps: int = 250  # NOTE(review): not referenced by any cell shown here
cfg: float = 4                 # guidance strength, passed as cfg= to autoregressive_infer_cfg
class_labels: tuple = (43, 113, 134)  # one class id per generated image
more_smooth: bool = False      # forwarded to autoregressive_infer_cfg in the baseline cell
num_diffusion_steps: int = 50  # iteration count for DiffusionModel.refine
noise_scale: float = 0.1       # multiplier on the standard-normal noise injected per refine step

In [10]:
# Reproducibility: seed every RNG the notebook may draw from (Python, NumPy, torch), then make
# cuDNN use deterministic kernels instead of benchmarking for the fastest one.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [11]:
# TensorFloat-32: trade a little float32 precision for faster matmuls / convolutions.
# Flip tf32 to False for full-precision float32 math.
tf32 = True
if tf32:
    torch.set_float32_matmul_precision('high')
else:
    torch.set_float32_matmul_precision('highest')
torch.backends.cudnn.allow_tf32 = tf32

In [12]:
# Class-conditional batch for the baseline run: one image per entry of class_labels.
label_B = torch.tensor(class_labels, device=device)
B = len(class_labels)

In [13]:
# Baseline: plain VAR autoregressive sampling (no refinement stage), timed for comparison.
# perf_counter is a monotonic high-resolution clock, unlike time.time(), so it suits durations.
start = time.perf_counter()

with torch.inference_mode():
    recon_B3HW = var.autoregressive_infer_cfg(
        B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.95, g_seed=seed, more_smooth=more_smooth
    )
# CUDA kernels launch asynchronously; wait for them so the timer covers the actual work.
if device == 'cuda':
    torch.cuda.synchronize()

end = time.perf_counter()

elapsed_time = end - start

print(f"Inference Time: {elapsed_time:.2f} seconds")

Inference Time: 294.96 seconds


In [14]:
# Visualize the baseline samples: tile the batch (up to 8 per row) into one CHW image, then
# convert it to an HWC 8-bit array for PIL.
grid = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)
# Use out-of-place ops only. For a single-image batch, make_grid can return a view of recon_B3HW.
# The previous in-place mul_ would then rescale the samples themselves, or raise, because they
# were produced under torch.inference_mode. Clamping keeps out-of-range values from wrapping
# around in the uint8 cast.
grid = grid.permute(1, 2, 0).clamp(0, 1).mul(255).cpu().numpy()
PImage.fromarray(grid.astype(np.uint8)).show()

In [15]:
# Pseudo diffusion refinement stage (no learned weights).
class DiffusionModel:
    """Toy refinement stage: repeatedly perturbs an image with Gaussian noise and pulls it back.

    ``device`` is stored on the instance but not used by ``refine``; results stay on whatever
    device ``structure`` lives on.
    """

    def __init__(self, device):
        self.device = device

    def refine(self, structure, num_steps=num_diffusion_steps, noise_scale=noise_scale):
        """Return a noised-then-relaxed copy of ``structure``.

        Each step adds ``noise_scale`` * standard-normal noise, clamps to [0, 1], and then moves
        the result toward the clamped input by the fraction ``0.5 * (1 - step / num_steps)``.
        NOTE(review): that fraction shrinks from 0.5 toward 0.5 / num_steps. The pull back toward
        the input therefore weakens over the iterations, and most of the last steps' noise
        survives. That is the opposite of "gradually reducing noise influence"; confirm the
        intended schedule.

        The defaults for ``num_steps`` / ``noise_scale`` are the notebook-level
        ``num_diffusion_steps`` / ``noise_scale`` values captured when this class is defined.

        Args:
            structure: image tensor; values are clamped to [0, 1] before refinement.
            num_steps: number of perturb/relax iterations; 0 returns the clamped input.
            noise_scale: multiplier applied to the injected standard-normal noise.

        Returns:
            A new tensor with values in [0, 1].
        """
        anchor = structure.clamp(0, 1)
        current = anchor.clone()
        for step in range(num_steps):
            perturbed = (current + noise_scale * torch.randn_like(current)).clamp(0, 1)
            pull = 0.5 * (1 - step / num_steps)
            current = perturbed.lerp(anchor, pull)
        return current

In [16]:
# Define the hybrid AR-Diffusion architecture
class HybridARDiffusion:
    """Two-stage sampler: autoregressive generation followed by a refinement pass.

    Args:
        ar_model: object exposing ``autoregressive_infer_cfg`` (the VAR model in this notebook).
        diffusion_model: object exposing ``refine(structure, num_steps=..., noise_scale=...)``.
        device: device on which the class-label tensor for the AR model is built.
    """

    def __init__(self, ar_model, diffusion_model, device):
        self.ar_model = ar_model
        self.diffusion_model = diffusion_model
        self.device = device

    def generate(self, class_labels, cfg, num_diffusion_steps, top_k=900, top_p=0.95, noise_scale=noise_scale,
                 g_seed=None, more_smooth=None):
        """Sample one image per class label with the AR model, then refine the batch.

        Args:
            class_labels: a single int class id, or a sequence/tensor of ids; one image per id.
            cfg: guidance strength forwarded to ``autoregressive_infer_cfg``.
            num_diffusion_steps: iteration count forwarded to ``diffusion_model.refine``.
            top_k, top_p: sampling filters forwarded to the AR model.
            noise_scale: forwarded to ``diffusion_model.refine``; defaults to the notebook-level
                ``noise_scale`` captured when this class is defined.
            g_seed, more_smooth: forwarded to the AR model only when not None, so by default the
                AR model's own defaults apply, exactly as before.

        Returns:
            Whatever ``diffusion_model.refine`` returns for the AR output.

        Raises:
            ValueError: if ``class_labels`` contains no class ids.
        """
        # Build the batch from the labels actually passed in. The previous version read the
        # notebook globals `B` / `label_B` and silently ignored this argument.
        label_B = torch.as_tensor(class_labels, device=self.device).reshape(-1)
        batch_size = label_B.numel()
        if batch_size == 0:
            raise ValueError("class_labels must contain at least one class id")

        optional_kwargs = {}
        if g_seed is not None:
            optional_kwargs['g_seed'] = g_seed
        if more_smooth is not None:
            optional_kwargs['more_smooth'] = more_smooth

        # Step 1: Generate coarse structure using AR model
        with torch.inference_mode():
            structure = self.ar_model.autoregressive_infer_cfg(
                B=batch_size, label_B=label_B, cfg=cfg, top_k=top_k, top_p=top_p, **optional_kwargs
            )

        # Step 2: Refine the structure using the diffusion model
        return self.diffusion_model.refine(structure, num_steps=num_diffusion_steps, noise_scale=noise_scale)

In [17]:
# Wire the VAR model (coarse sample) and the toy refiner into one hybrid sampler.
hybrid_model = HybridARDiffusion(
    ar_model=var,
    diffusion_model=DiffusionModel(device=device),
    device=device,
)

In [18]:
# Generate refined images with the hybrid sampler and time it, mirroring the baseline cell.
# perf_counter is a monotonic high-resolution clock, unlike time.time().
start = time.perf_counter()

refined_images = hybrid_model.generate(class_labels, cfg, num_diffusion_steps=num_diffusion_steps, noise_scale=noise_scale)
# CUDA kernels launch asynchronously; wait for them so the timer covers the actual work.
if device == 'cuda':
    torch.cuda.synchronize()

end = time.perf_counter()

elapsed_time = end - start

print(f"Inference Time: {elapsed_time:.2f} seconds")

Inference Time: 343.67 seconds


In [19]:
# Visualize the refined batch: same tiling as the baseline plot, then scale [0, 1] floats to
# 8-bit and reorder CHW -> HWC for PIL.
grid = torchvision.utils.make_grid(refined_images, nrow=8, padding=0, pad_value=1.0)
grid = (grid * 255).to(torch.uint8).permute(1, 2, 0).cpu().numpy()
PImage.fromarray(grid).show()