In [1]:
import torch
import numpy as np
from models import mar
from models.vae import AutoencoderKL
from torchvision.utils import save_image
from util import download
from PIL import Image
from IPython.display import display
torch.set_grad_enabled(False)
device = "cuda" if torch.cuda.is_available() else "cpu"
if device == "cpu":
    print("GPU not found. Using CPU instead.")

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# model_type = "mar_huge" #@param ["mar_base", "mar_large", "mar_huge"]
model_type = "mar_base" #@param ["mar_base", "mar_large", "mar_huge"]
num_sampling_steps_diffloss = 100 #@param {type:"slider", min:1, max:1000, step:1}
if model_type == "mar_base":
  download.download_pretrained_marb(overwrite=False)
  diffloss_d = 6
  diffloss_w = 1024
elif model_type == "mar_large":
  download.download_pretrained_marl(overwrite=False)
  diffloss_d = 8
  diffloss_w = 1280
elif model_type == "mar_huge":
  download.download_pretrained_marh(overwrite=False)
  diffloss_d = 12
  diffloss_w = 1536
else:
  raise NotImplementedError
model = mar.__dict__[model_type](
  buffer_size=64,
  diffloss_d=diffloss_d,
  diffloss_w=diffloss_w,
  num_sampling_steps=str(num_sampling_steps_diffloss)
).to(device)
state_dict = torch.load("pretrained_models/mar/{}/checkpoint-last.pth".format(model_type))["model_ema"]
model.load_state_dict(state_dict)
model.eval() # important!
vae = AutoencoderKL(embed_dim=16, ch_mult=(1, 1, 2, 2, 4), ckpt_path="pretrained_models/vae/kl16.ckpt").cuda().eval()
     

Working with z of shape (1, 16, 16, 16) = 4096 dimensions.
Loading pre-trained KL-VAE
Missing keys:
[]
Unexpected keys:
[]
Restored from pretrained_models/vae/kl16.ckpt


In [4]:
# Set user inputs:
seed = 0 #@param {type:"number"}
torch.manual_seed(seed)
np.random.seed(seed)
num_ar_steps = 64 #@param {type:"slider", min:1, max:256, step:1}
cfg_scale = 4 #@param {type:"slider", min:1, max:10, step:0.1}
cfg_schedule = "constant" #@param ["linear", "constant"]
temperature = 1.0 #@param {type:"slider", min:0.9, max:1.1, step:0.01}

# class_labels = 207, 360, 388, 113, 355, 980, 323, 979 #@param {type:"raw"}
# samples_per_row = 4 #@param {type:"number"}

class_labels = [207] #@param {type:"raw"}
samples_per_row = 1 #@param {type:"number"}

with torch.cuda.amp.autocast():
  sampled_tokens = model.sample_tokens(
      bsz=len(class_labels), num_iter=num_ar_steps,
      cfg=cfg_scale, cfg_schedule=cfg_schedule,
      labels=torch.Tensor(class_labels).long().cuda(),
      temperature=temperature, progress=True)
  print(sampled_tokens.shape)
  sampled_images = vae.decode(sampled_tokens / 0.2325)

# Save and display images:
save_image(sampled_images, "sample.png", nrow=int(samples_per_row), normalize=True, value_range=(-1, 1))
samples = Image.open("sample.png")
display(samples)

  0%|          | 0/64 [00:00<?, ?it/s]

> [0;32m/export/home/visual_tokenizer_sfr_intern/mar/models/mar.py[0m(324)[0;36msample_tokens[0;34m()[0m
[0;32m    322 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    323 [0;31m[0;34m[0m[0m
[0m[0;32m--> 324 [0;31m            [0mcur_tokens[0m[0;34m[[0m[0mmask_to_pred[0m[0;34m.[0m[0mnonzero[0m[0;34m([0m[0mas_tuple[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0msampled_token_latent[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    325 [0;31m            [0mtokens[0m [0;34m=[0m [0mcur_tokens[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    326 [0;31m[0;34m[0m[0m
[0m
*** NameError: name 'num_itr' is not defined
64
64
tensor([[ True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
          True,  True,  True,  True,  True,  True,  True,  True,  True,  True,
  

  2%|▏         | 1/64 [05:18<5:34:10, 318.27s/it]

> [0;32m/export/home/visual_tokenizer_sfr_intern/mar/models/mar.py[0m(324)[0;36msample_tokens[0;34m()[0m
[0;32m    322 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    323 [0;31m[0;34m[0m[0m
[0m[0;32m--> 324 [0;31m            [0mcur_tokens[0m[0;34m[[0m[0mmask_to_pred[0m[0;34m.[0m[0mnonzero[0m[0;34m([0m[0mas_tuple[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0msampled_token_latent[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    325 [0;31m            [0mtokens[0m [0;34m=[0m [0mcur_tokens[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    326 [0;31m[0;34m[0m[0m
[0m
torch.Size([1, 16])
torch.Size([1, 16])
tensor([[False, False, False, False, False, False, False, False, False, False,
         False, False, False, False, False, False, False, False, False, False,
         Fals

  3%|▎         | 2/64 [07:16<3:27:30, 200.81s/it]

> [0;32m/export/home/visual_tokenizer_sfr_intern/mar/models/mar.py[0m(324)[0;36msample_tokens[0;34m()[0m
[0;32m    322 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    323 [0;31m[0;34m[0m[0m
[0m[0;32m--> 324 [0;31m            [0mcur_tokens[0m[0;34m[[0m[0mmask_to_pred[0m[0;34m.[0m[0mnonzero[0m[0;34m([0m[0mas_tuple[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0msampled_token_latent[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    325 [0;31m            [0mtokens[0m [0;34m=[0m [0mcur_tokens[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    326 [0;31m[0;34m[0m[0m
[0m
torch.Size([1, 16])
torch.Size([1, 16])


  5%|▍         | 3/64 [07:42<2:02:46, 120.76s/it]

> [0;32m/export/home/visual_tokenizer_sfr_intern/mar/models/mar.py[0m(324)[0;36msample_tokens[0;34m()[0m
[0;32m    322 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    323 [0;31m[0;34m[0m[0m
[0m[0;32m--> 324 [0;31m            [0mcur_tokens[0m[0;34m[[0m[0mmask_to_pred[0m[0;34m.[0m[0mnonzero[0m[0;34m([0m[0mas_tuple[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0msampled_token_latent[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    325 [0;31m            [0mtokens[0m [0;34m=[0m [0mcur_tokens[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    326 [0;31m[0;34m[0m[0m
[0m


  6%|▋         | 4/64 [07:45<1:14:20, 74.35s/it] 

> [0;32m/export/home/visual_tokenizer_sfr_intern/mar/models/mar.py[0m(324)[0;36msample_tokens[0;34m()[0m
[0;32m    322 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    323 [0;31m[0;34m[0m[0m
[0m[0;32m--> 324 [0;31m            [0mcur_tokens[0m[0;34m[[0m[0mmask_to_pred[0m[0;34m.[0m[0mnonzero[0m[0;34m([0m[0mas_tuple[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0msampled_token_latent[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    325 [0;31m            [0mtokens[0m [0;34m=[0m [0mcur_tokens[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    326 [0;31m[0;34m[0m[0m
[0m


  8%|▊         | 5/64 [07:46<47:08, 47.94s/it]  

> [0;32m/export/home/visual_tokenizer_sfr_intern/mar/models/mar.py[0m(324)[0;36msample_tokens[0;34m()[0m
[0;32m    322 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    323 [0;31m[0;34m[0m[0m
[0m[0;32m--> 324 [0;31m            [0mcur_tokens[0m[0;34m[[0m[0mmask_to_pred[0m[0;34m.[0m[0mnonzero[0m[0;34m([0m[0mas_tuple[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0msampled_token_latent[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    325 [0;31m            [0mtokens[0m [0;34m=[0m [0mcur_tokens[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    326 [0;31m[0;34m[0m[0m
[0m
torch.Size([1, 16])
torch.Size([1, 16])


  9%|▉         | 6/64 [07:54<33:11, 34.34s/it]

> [0;32m/export/home/visual_tokenizer_sfr_intern/mar/models/mar.py[0m(324)[0;36msample_tokens[0;34m()[0m
[0;32m    322 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    323 [0;31m[0;34m[0m[0m
[0m[0;32m--> 324 [0;31m            [0mcur_tokens[0m[0;34m[[0m[0mmask_to_pred[0m[0;34m.[0m[0mnonzero[0m[0;34m([0m[0mas_tuple[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0msampled_token_latent[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    325 [0;31m            [0mtokens[0m [0;34m=[0m [0mcur_tokens[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    326 [0;31m[0;34m[0m[0m
[0m


 11%|█         | 7/64 [07:55<22:17, 23.46s/it]

> [0;32m/export/home/visual_tokenizer_sfr_intern/mar/models/mar.py[0m(324)[0;36msample_tokens[0;34m()[0m
[0;32m    322 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    323 [0;31m[0;34m[0m[0m
[0m[0;32m--> 324 [0;31m            [0mcur_tokens[0m[0;34m[[0m[0mmask_to_pred[0m[0;34m.[0m[0mnonzero[0m[0;34m([0m[0mas_tuple[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0msampled_token_latent[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    325 [0;31m            [0mtokens[0m [0;34m=[0m [0mcur_tokens[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    326 [0;31m[0;34m[0m[0m
[0m


 12%|█▎        | 8/64 [07:56<15:17, 16.39s/it]

> [0;32m/export/home/visual_tokenizer_sfr_intern/mar/models/mar.py[0m(324)[0;36msample_tokens[0;34m()[0m
[0;32m    322 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    323 [0;31m[0;34m[0m[0m
[0m[0;32m--> 324 [0;31m            [0mcur_tokens[0m[0;34m[[0m[0mmask_to_pred[0m[0;34m.[0m[0mnonzero[0m[0;34m([0m[0mas_tuple[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0msampled_token_latent[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    325 [0;31m            [0mtokens[0m [0;34m=[0m [0mcur_tokens[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    326 [0;31m[0;34m[0m[0m
[0m


 14%|█▍        | 9/64 [07:57<10:35, 11.55s/it]

> [0;32m/export/home/visual_tokenizer_sfr_intern/mar/models/mar.py[0m(324)[0;36msample_tokens[0;34m()[0m
[0;32m    322 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    323 [0;31m[0;34m[0m[0m
[0m[0;32m--> 324 [0;31m            [0mcur_tokens[0m[0;34m[[0m[0mmask_to_pred[0m[0;34m.[0m[0mnonzero[0m[0;34m([0m[0mas_tuple[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0msampled_token_latent[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    325 [0;31m            [0mtokens[0m [0;34m=[0m [0mcur_tokens[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    326 [0;31m[0;34m[0m[0m
[0m
torch.Size([1, 16])
torch.Size([1, 16])


 16%|█▌        | 10/64 [08:08<10:12, 11.35s/it]

> [0;32m/export/home/visual_tokenizer_sfr_intern/mar/models/mar.py[0m(324)[0;36msample_tokens[0;34m()[0m
[0;32m    322 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    323 [0;31m[0;34m[0m[0m
[0m[0;32m--> 324 [0;31m            [0mcur_tokens[0m[0;34m[[0m[0mmask_to_pred[0m[0;34m.[0m[0mnonzero[0m[0;34m([0m[0mas_tuple[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0msampled_token_latent[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    325 [0;31m            [0mtokens[0m [0;34m=[0m [0mcur_tokens[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    326 [0;31m[0;34m[0m[0m
[0m


 17%|█▋        | 11/64 [08:09<07:15,  8.21s/it]

> [0;32m/export/home/visual_tokenizer_sfr_intern/mar/models/mar.py[0m(324)[0;36msample_tokens[0;34m()[0m
[0;32m    322 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    323 [0;31m[0;34m[0m[0m
[0m[0;32m--> 324 [0;31m            [0mcur_tokens[0m[0;34m[[0m[0mmask_to_pred[0m[0;34m.[0m[0mnonzero[0m[0;34m([0m[0mas_tuple[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0msampled_token_latent[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    325 [0;31m            [0mtokens[0m [0;34m=[0m [0mcur_tokens[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    326 [0;31m[0;34m[0m[0m
[0m


 19%|█▉        | 12/64 [08:10<05:10,  5.97s/it]

> [0;32m/export/home/visual_tokenizer_sfr_intern/mar/models/mar.py[0m(324)[0;36msample_tokens[0;34m()[0m
[0;32m    322 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    323 [0;31m[0;34m[0m[0m
[0m[0;32m--> 324 [0;31m            [0mcur_tokens[0m[0;34m[[0m[0mmask_to_pred[0m[0;34m.[0m[0mnonzero[0m[0;34m([0m[0mas_tuple[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0msampled_token_latent[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    325 [0;31m            [0mtokens[0m [0;34m=[0m [0mcur_tokens[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    326 [0;31m[0;34m[0m[0m
[0m


 20%|██        | 13/64 [08:13<04:13,  4.97s/it]

> [0;32m/export/home/visual_tokenizer_sfr_intern/mar/models/mar.py[0m(324)[0;36msample_tokens[0;34m()[0m
[0;32m    322 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    323 [0;31m[0;34m[0m[0m
[0m[0;32m--> 324 [0;31m            [0mcur_tokens[0m[0;34m[[0m[0mmask_to_pred[0m[0;34m.[0m[0mnonzero[0m[0;34m([0m[0mas_tuple[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0msampled_token_latent[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    325 [0;31m            [0mtokens[0m [0;34m=[0m [0mcur_tokens[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    326 [0;31m[0;34m[0m[0m
[0m


 22%|██▏       | 14/64 [08:15<03:18,  3.97s/it]

> [0;32m/export/home/visual_tokenizer_sfr_intern/mar/models/mar.py[0m(324)[0;36msample_tokens[0;34m()[0m
[0;32m    322 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    323 [0;31m[0;34m[0m[0m
[0m[0;32m--> 324 [0;31m            [0mcur_tokens[0m[0;34m[[0m[0mmask_to_pred[0m[0;34m.[0m[0mnonzero[0m[0;34m([0m[0mas_tuple[0m[0;34m=[0m[0;32mTrue[0m[0;34m)[0m[0;34m][0m [0;34m=[0m [0msampled_token_latent[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    325 [0;31m            [0mtokens[0m [0;34m=[0m [0mcur_tokens[0m[0;34m.[0m[0mclone[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    326 [0;31m[0;34m[0m[0m
[0m
torch.Size([3, 16])
torch.Size([6, 768])


 22%|██▏       | 14/64 [11:22<40:36, 48.73s/it]


In [7]:
vae
# reconstruct with vae:
reconstructed = vae(sampled_images)

AutoencoderKL(
  (encoder): Encoder(
    (conv_in): Conv2d(3, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down): ModuleList(
      (0-1): 2 x Module(
        (block): ModuleList(
          (0-1): 2 x ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
            (norm2): GroupNorm(32, 128, eps=1e-06, affine=True)
            (dropout): Dropout(p=0.0, inplace=False)
            (conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          )
        )
        (attn): ModuleList()
        (downsample): Downsample(
          (conv): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2))
        )
      )
      (2): Module(
        (block): ModuleList(
          (0): ResnetBlock(
            (norm1): GroupNorm(32, 128, eps=1e-06, affine=True)
            (conv1): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
         