### 🚀 For an interactive experience, head over to our [demo platform](https://var.vision/demo) and dive right in! 🌟

In [1]:
################## 1. Download checkpoints and build models
import os
import os.path as osp
import torch, torchvision
import random
import numpy as np
import PIL.Image as PImage, PIL.ImageDraw as PImageDraw
setattr(torch.nn.Linear, 'reset_parameters', lambda self: None)     # disable default parameter init for faster speed
setattr(torch.nn.LayerNorm, 'reset_parameters', lambda self: None)  # disable default parameter init for faster speed
from models import VQVAE, build_vae_var

MODEL_DEPTH = 16    # TODO: =====> please specify MODEL_DEPTH <=====
assert MODEL_DEPTH in {16, 20, 24, 30}


# download checkpoint
hf_home = 'https://huggingface.co/FoundationVision/var/resolve/main'
vae_ckpt, var_ckpt = 'vae_ch160v4096z32.pth', f'var_d{MODEL_DEPTH}.pth'
if not osp.exists(vae_ckpt): os.system(f'wget {hf_home}/{vae_ckpt}')
if not osp.exists(var_ckpt): os.system(f'wget {hf_home}/{var_ckpt}')

# build vae, var
patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
if 'vae' not in globals() or 'var' not in globals():
    vae, var = build_vae_var(
        V=4096, Cvae=32, ch=160, share_quant_resi=4,    # hard-coded VQVAE hyperparameters
        device=device, patch_nums=patch_nums,
        num_classes=1000, depth=MODEL_DEPTH, shared_aln=False,
    )

# load checkpoints
vae.load_state_dict(torch.load(vae_ckpt, map_location='cpu'), strict=True)
var.load_state_dict(torch.load(var_ckpt, map_location='cpu'), strict=True)
vae.eval(), var.eval()
for p in vae.parameters(): p.requires_grad_(False)
for p in var.parameters(): p.requires_grad_(False)
print(f'prepare finished.')

  from .autonotebook import tqdm as notebook_tqdm



[constructor]  ==== flash_if_available=True (0/16), fused_if_available=True (fusing_add_ln=0/16, fusing_mlp=0/16) ==== 
    [VAR config ] embed_dim=1024, num_heads=16, depth=16, mlp_ratio=4.0
    [drop ratios ] drop_rate=0.0, attn_drop_rate=0.0, drop_path_rate=0.0666667 (tensor([0.0000, 0.0044, 0.0089, 0.0133, 0.0178, 0.0222, 0.0267, 0.0311, 0.0356,
        0.0400, 0.0444, 0.0489, 0.0533, 0.0578, 0.0622, 0.0667]))

[init_weights] VAR with init_std=0.0180422
prepare finished.


In [40]:
############################# 2. Sample with classifier-free guidance

# set args
seed = 2137 #@param {type:"number"}
torch.manual_seed(seed)
num_sampling_steps = 250 #@param {type:"slider", min:0, max:1000, step:1} # IRRELEVANT -- variable not used
cfg = 4 #@param {type:"slider", min:1, max:10, step:0.1}
class_labels = (980, 980, 437, 437, 22, 22, 562, 562)  #@param {type:"raw"}
more_smooth = False # True for more smooth output

# seed
torch.manual_seed(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# run faster
tf32 = True
torch.backends.cudnn.allow_tf32 = bool(tf32)
torch.backends.cuda.matmul.allow_tf32 = bool(tf32)
torch.set_float32_matmul_precision('high' if tf32 else 'highest')

# sample
B = len(class_labels)
label_B: torch.LongTensor = torch.tensor(class_labels, device=device)
with torch.inference_mode():
    with torch.autocast('cuda', enabled=True, dtype=torch.float16, cache_enabled=True):    # using bfloat16 can be faster
        recon_B3HW = var.autoregressive_infer_cfg(B=B, label_B=label_B, cfg=cfg, top_k=900, top_p=0.95, g_seed=seed, more_smooth=more_smooth)

level 0 x: x.shape=torch.Size([16, 1, 1024])

level 1 x: x.shape=torch.Size([16, 4, 1024])

level 2 x: x.shape=torch.Size([16, 9, 1024])

level 3 x: x.shape=torch.Size([16, 16, 1024])

level 4 x: x.shape=torch.Size([16, 25, 1024])

level 5 x: x.shape=torch.Size([16, 36, 1024])

level 6 x: x.shape=torch.Size([16, 64, 1024])

level 7 x: x.shape=torch.Size([16, 100, 1024])

level 8 x: x.shape=torch.Size([16, 169, 1024])

level 9 x: x.shape=torch.Size([16, 256, 1024])



In [41]:

chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.show()

kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"


In [6]:
# chw = torchvision.utils.make_grid(recon_B3HW, nrow=8, padding=0, pad_value=1.0)
# chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
# chw = PImage.fromarray(chw.astype(np.uint8))
# chw.show()
# recon_B3HW.

kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found


kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"


In [7]:
x = recon_B3HW.type(torch.float32).to(device)#.add_(-0.5).mul_(2)
f = vae.quant_conv(vae.encoder(x))
# image_pyramid = vae.img_to_idxBl(optimized_image, patch_nums)

# image_pyramid

In [8]:
# reconstructed = vae.idxBl_to_img(image_pyramid, same_shape=False)

In [7]:
# reconstructed[6].shape

### Encoding the input 

In [8]:
recon_B3HW.shape

torch.Size([8, 3, 256, 256])

In [9]:
device

'cuda'

In [10]:
# x = recon_B3HW
# x = torch.rand(8, 3, 256, 256, device=device)

x = recon_B3HW.type(torch.float32).to(device)#.add_(-0.5).mul_(2)
f = vae.quant_conv(vae.encoder(x))

# image = vae.fhat_to_img(f)
# f = vae.quant_conv(vae.encoder(image))

# image = vae.fhat_to_img(f)
# f = vae.quant_conv(vae.encoder(image))

# image = vae.fhat_to_img(f)
# f = vae.quant_conv(vae.encoder(image))

image = vae.fhat_to_img(f).clamp_(0, 1)

# x = recon_B3HW.type(torch.float32).to(device)#.add_(-0.5).mul_(2)
# f = vae.quant_conv(vae.encoder(x))
# image = vae.fhat_to_img(f).add_(1).mul_(0.5)

chw = torchvision.utils.make_grid(image, nrow=8, padding=0, pad_value=1.0)
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.show()

# ls_f_hat_BChw = vae.quantize.f_to_idxBl_or_fhat(f, to_fhat=False, v_patch_nums=patch_nums)

kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"


In [11]:
x = recon_B3HW.type(torch.float32).to(device)#.add_(-0.5).mul_(2)
f = vae.quant_conv(vae.encoder(x))
f_hat, _, _ = vae.quantize(f)


image = vae.fhat_to_img(f_hat)

# x = recon_B3HW.type(torch.float32).to(device)#.add_(-0.5).mul_(2)
# f = vae.quant_conv(vae.encoder(x))
# image = vae.fhat_to_img(f).add_(1).mul_(0.5)

chw = torchvision.utils.make_grid(image, nrow=8, padding=0, pad_value=1.0)
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.show()


kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"


In [12]:
# vae.quantize.forward(f)
# assert False    

In [13]:
from torch.nn import functional as F

patch_nums = (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)
last_patch_id = -1

quant_pyramid, f_hat = vae.quantize.f_to_quant_pyramid_and_f_hat(f, patch_nums, last_patch_id)

residual = f - f_hat

In [14]:
print(f"{len(quant_pyramid)=})")
if quant_pyramid:
    print(f"{quant_pyramid[0].shape=}")
# assert False

len(quant_pyramid)=0)


### Using the pyramid as the input for transformer

In [15]:
predicted_residual = var.autoregressive_single_step_prediction(
    quant_pyramid=quant_pyramid, 
    f_hat=f_hat, 
    label_B=label_B,) 

current_level=-1, next_level=0
next_token_map lvl current_level=-1 next_level=0: next_token_map.shape=torch.Size([16, 1, 1024])



In [None]:
f.shape

torch.Size([8, 32, 16, 16])

In [None]:
predicted_residual.shape

torch.Size([8, 32, 16, 16])

### Optimisation POC in latent space

In [3]:
from torch.nn import functional as F


def get_downsampling_regularisation_loss(f_residual, patch_size):
    downsampled = F.interpolate(f_residual, size=(patch_size, patch_size), mode='area', align_corners=False)

    return downsampled.mean() # L1 regularisation


def get_next_step_loss(quant_pyramid, f_residual, f_hat, label_B):
    predicted_residual = var.autoregressive_single_step_prediction(
        quant_pyramid=quant_pyramid, 
        f_hat=f_hat, 
        label_B=label_B, 
        cfg=cfg, 
        top_k=900, 
        top_p=0.95)
    
    return F.mse_loss(predicted_residual, f_residual)

def get_optimisation_loss(f, last_patch_id):
    quant_pyramid, f_hat = vae.quantize.f_to_quant_pyramid_and_f_hat(f, patch_nums, last_patch_id)

    f_residual = f - f_hat

    next_step_loss = get_next_step_loss(quant_pyramid, f_residual, f_hat, label_B)

    regularisation_losses = [
        get_downsampling_regularisation_loss(f_residual, patch_size)
        for patch_size in patch_nums[:last_patch_id]
    ]

    return next_step_loss + sum(regularisation_losses)

In [17]:

f = torch.rand([2, 32, 16, 16], requires_grad=True, device=device)

class_labels = (22, 562)  #@param {type:"raw"}
label_B: torch.LongTensor = torch.tensor(class_labels, device=device)

optimizer = torch.optim.Adam([f], lr=0.1)

steps_per_coarsness = 1

for coarsness_step in range(-1, len((patch_nums)) - 1):
    for i in range(steps_per_coarsness):
        optimizer.zero_grad()
        loss = get_optimisation_loss(f, coarsness_step)
        patch_weight = patch_nums[coarsness_step]**2 / patch_nums[len(patch_nums)-1]**2
        step_weight = (steps_per_coarsness - i) /steps_per_coarsness
        
        loss*= patch_weight * step_weight

        if i % 500 == 0:
            print(f'coarsness_step: {coarsness_step}, loss: {loss.item()}')

        
        loss.backward()
        optimizer.step()


current_level=-1, next_level=0
next_token_map lvl current_level=-1 next_level=0: next_token_map.shape=torch.Size([4, 1, 1024])

coarsness_step: -1, loss: 0.39253973960876465
current_level=0, next_level=1
next_token_map lvl current_level=0 next_level=1: next_token_map.shape=torch.Size([4, 4, 1024])

coarsness_step: 0, loss: 0.0009515469428151846
current_level=1, next_level=2
next_token_map lvl current_level=1 next_level=2: next_token_map.shape=torch.Size([4, 9, 1024])

coarsness_step: 1, loss: 0.002446998842060566
current_level=2, next_level=3
next_token_map lvl current_level=2 next_level=3: next_token_map.shape=torch.Size([4, 16, 1024])

coarsness_step: 2, loss: 0.004130854737013578
current_level=3, next_level=4
next_token_map lvl current_level=3 next_level=4: next_token_map.shape=torch.Size([4, 25, 1024])

coarsness_step: 3, loss: 0.006683871150016785
current_level=4, next_level=5
next_token_map lvl current_level=4 next_level=5: next_token_map.shape=torch.Size([4, 36, 1024])

coarsnes

In [None]:
assert False

AssertionError: 

kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"
kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"


In [18]:
var.patch_nums

(1, 2, 3, 4, 5, 6, 8, 10, 13, 16)

In [19]:
image = vae.fhat_to_img(f).add_(1).mul_(0.5)

In [20]:
chw = torchvision.utils.make_grid(image, nrow=8, padding=0, pad_value=1.0)
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.show()

kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"


## Random coarsness step


In [40]:
def get_optimisation_loss(f, last_patch_id, f_hat=None, quant_pyramid=None):
    if f_hat is None or quant_pyramid is None:
        quant_pyramid, f_hat = vae.quantize.f_to_quant_pyramid_and_f_hat(f, patch_nums, last_patch_id)

    f_residual = f - f_hat

#     predicted_residual = var.autoregressive_single_step_prediction(
#         quant_pyramid=quant_pyramid, 
#         f_hat=f_hat, 
#         label_B=label_B, 
#         cfg=cfg, 
#         top_k=900, 
#         top_p=0.95)
    
#     # TODO add stability loss 
#     # Ensure f_residual scaled down to the "current" patch size should be close to zero (?)

#     return F.l1_loss(predicted_residual, f_residual)


In [6]:
from torch.nn import functional as F


def get_downsampling_regularisation_loss(f_residual, patch_size):
    downsampled = F.interpolate(f_residual, size=(patch_size, patch_size), mode='area')

    return F.mse_loss(downsampled, torch.zeros_like(downsampled)) # L1 regularisation


def get_next_step_loss(quant_pyramid, f_residual, f_hat, label_B):
    predicted_residual = var.autoregressive_single_step_prediction(
        quant_pyramid=quant_pyramid, 
        f_hat=f_hat, 
        label_B=label_B, 
        cfg=cfg, 
        top_k=900, 
        top_p=0.95)
    
    return F.mse_loss(predicted_residual, f_residual)

def get_optimisation_loss(f, last_patch_id, f_hat=None, quant_pyramid=None, regularisation_weight=0):
    if f_hat is None or quant_pyramid is None:
        quant_pyramid, f_hat = vae.quantize.f_to_quant_pyramid_and_f_hat(f, patch_nums, last_patch_id)

    f_residual = f - f_hat

    next_step_loss = get_next_step_loss(quant_pyramid, f_residual, f_hat, label_B)

    total_loss = next_step_loss


    if regularisation_weight > 0:
        regularisation_losses = [
            get_downsampling_regularisation_loss(f_residual, patch_size)
            for patch_size in patch_nums[:last_patch_id]
        ]

        total_loss += sum(regularisation_losses) * regularisation_weight

        
    return total_loss

In [86]:
from tqdm import tqdm
# f = torch.randn([4, 32, 16, 16], requires_grad=True, device=device) 
# class_labels = (980, 437, 22, 562)  #@param {type:"raw"}


img = recon_B3HW.float().clone().clamp_(0, 1).to(device)
f = torch.tensor(vae.quant_conv(vae.encoder(img).detach()), requires_grad=True, device=device)
class_labels = (980, 980, 437, 437, 22, 22, 562, 562)  #@param {type:"raw"}

label_B: torch.LongTensor = torch.tensor(class_labels, device=device)

optimizer = torch.optim.Adam([f], lr=0.01)
total_steps = 1000

min_coarsness = len(patch_nums) - 2 #-1 #len(patch_nums) - 4
max_coarsness = len(patch_nums) - 2 



# quant_pyramid, f_hat = vae.quantize.f_to_quant_pyramid_and_f_hat(f, patch_nums, max_coarsness)
quant_pyramid, f_hat = None, None

for i in tqdm(range(total_steps), total=total_steps):
    coarsness_step = np.random.randint(min_coarsness, max_coarsness + 1)

    loss = get_optimisation_loss(f, coarsness_step, f_hat, quant_pyramid)

    patch_grad_scale = patch_nums[coarsness_step]**2 / patch_nums[len(patch_nums)-1]**2
    loss*= patch_grad_scale
    
    step_weight = (total_steps - i) / total_steps
    loss*= step_weight

    level_ratio = (coarsness_step + 2)/ (len(patch_nums) + 1) 
    level_weight = np.sqrt(level_ratio)
    loss*= level_weight

    # print(f'{level_ratio=}, {level_weight=}, {step_weight=}, {patch_grad_scale=}, {loss.item()=}')

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

f = f.detach()

    # if i % 10 == 0:
    #     print(f'coarsness_step: {coarsness_step}, loss: {loss.item()}')

  f = torch.tensor(vae.quant_conv(vae.encoder(img).detach()), requires_grad=True, device=device)
100%|██████████| 1000/1000 [00:36<00:00, 27.25it/s]


In [87]:
image = vae.fhat_to_img(f.detach()).clamp_(0, 1)
chw = torchvision.utils.make_grid(image, nrow=8, padding=0, pad_value=1.0)
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.show()

kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"


In [None]:
len(patch_nums)

10

In [50]:
print("patch_sizes = ", patch_nums)


patch_sizes =  (1, 2, 3, 4, 5, 6, 8, 10, 13, 16)


# Fixed F-hat incremetal

In [47]:

f = torch.zeros([2, 32, 16, 16], requires_grad=True, device=device)

class_labels = (22, 562)  #@param {type:"raw"}
label_B: torch.LongTensor = torch.tensor(class_labels, device=device)

optimizer = torch.optim.Adam([f], lr=0.05)

steps_per_coarsness = 500

for coarsness_level in range(-1, len((patch_nums)) - 1):
    losses = []

    quant_pyramid, f_hat = vae.quantize.f_to_quant_pyramid_and_f_hat(f, patch_nums, coarsness_level)

    f_clone = f.clone().detach()
    
    for i in range(steps_per_coarsness):
        optimizer.zero_grad()
        loss = get_optimisation_loss(f, coarsness_level, f_clone, quant_pyramid, regularisation_weight=0)
        #patch_weight = 1 #patch_nums[coarsness_level]**2 / patch_nums[len(patch_nums)-1]**2
        #step_weight = (steps_per_coarsness - i) /steps_per_coarsness
        
        #loss*= patch_weight * step_weight
        losses.append(loss.item())

        loss.backward()
        optimizer.step()

    print(f'coarsness_level: {coarsness_level}, loss: {np.mean(losses)}')


coarsness_level: -1, loss: 0.0006090999025363968
coarsness_level: 0, loss: 0.00017674557398780655
coarsness_level: 1, loss: 7.816978732463561e-05
coarsness_level: 2, loss: 0.000136359025294917
coarsness_level: 3, loss: 0.00022669045838961596
coarsness_level: 4, loss: 0.00021566336376152353
coarsness_level: 5, loss: 0.0003899050076426731
coarsness_level: 6, loss: 0.0004408489482989985
coarsness_level: 7, loss: 0.0005911831812500716
coarsness_level: 8, loss: 0.001166960721561788


In [48]:
image = vae.fhat_to_img(f.detach()).clamp_(0, 1)
chw = torchvision.utils.make_grid(image, nrow=8, padding=0, pad_value=1.0)
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.show()

kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"


# All at once optimization

In [38]:

from tqdm import tqdm

f = torch.zeros([2, 32, 16, 16], requires_grad=True, device=device)

class_labels = (980, 437)  #@param {type:"raw"}
label_B: torch.LongTensor = torch.tensor(class_labels, device=device)

optimizer = torch.optim.Adam([f], lr=0.05)

total_steps = 2000

for i in tqdm(range(total_steps), total=total_steps):
    loss = 0
    optimizer.zero_grad()

    final_patch = int(i/total_steps * len(patch_nums)) 

    for coarsness_level in range(-1, final_patch):
        patch_weight = patch_nums[coarsness_level+1]**2 / patch_nums[len(patch_nums)-1]**2
        loss += get_optimisation_loss(f, coarsness_level, regularisation_weight=0) * patch_weight
        

    if i % 100 == 0:
        print(f'loss: {loss.item()}')

    #step_weight = (total_steps - i) / total_steps



    #loss*= step_weight
    loss.backward()
    optimizer.step()


# for coarsness_step in range(-1, len((patch_nums)) - 1):
#     for i in range(steps_per_coarsness):
#         optimizer.zero_grad()
#         loss = get_optimisation_loss(f, coarsness_step)
#         patch_weight = patch_nums[coarsness_step]**2 / patch_nums[len(patch_nums)-1]**2
#         step_weight = (steps_per_coarsness - i) /steps_per_coarsness
        
#         loss*= patch_weight * step_weight

#         if i % 500 == 0:
#             print(f'coarsness_step: {coarsness_step}, loss: {loss.item()}')

        
#         loss.backward()
#         optimizer.step()


  0%|          | 0/2000 [00:00<?, ?it/s]

  2%|▎         | 50/2000 [00:00<00:07, 252.99it/s]

loss: 0.0002106762840412557


  7%|▋         | 136/2000 [00:00<00:06, 274.59it/s]

loss: 2.1594459553853085e-09


 11%|█         | 221/2000 [00:00<00:07, 223.78it/s]

loss: 0.00039762482629157603


 16%|█▌        | 320/2000 [00:01<00:11, 151.66it/s]

loss: 7.952936721267179e-05


 21%|██        | 412/2000 [00:02<00:12, 127.24it/s]

loss: 0.0012372417841106653


 26%|██▌       | 510/2000 [00:03<00:15, 93.53it/s] 

loss: 0.0016484935767948627


 30%|███       | 610/2000 [00:04<00:16, 84.09it/s]

loss: 0.00352624268271029


 36%|███▌      | 712/2000 [00:05<00:18, 68.97it/s]

loss: 0.003063205862417817


 41%|████      | 811/2000 [00:07<00:19, 62.23it/s]

loss: 0.004616558086127043


 45%|████▌     | 909/2000 [00:09<00:20, 54.39it/s]

loss: 0.006852190941572189


 50%|█████     | 1005/2000 [00:11<00:19, 51.24it/s]

loss: 0.007261977531015873


 55%|█████▌    | 1106/2000 [00:13<00:20, 44.59it/s]

loss: 0.012157270684838295


 60%|██████    | 1206/2000 [00:15<00:18, 42.31it/s]

loss: 0.016481220722198486


 65%|██████▌   | 1305/2000 [00:18<00:18, 37.69it/s]

loss: 0.015166565775871277


 70%|███████   | 1405/2000 [00:20<00:16, 35.47it/s]

loss: 0.02793954312801361


 75%|███████▌  | 1505/2000 [00:24<00:15, 31.21it/s]

loss: 0.025798745453357697


 80%|████████  | 1605/2000 [00:27<00:13, 28.57it/s]

loss: 0.04659782350063324


 85%|████████▌ | 1704/2000 [00:31<00:12, 23.94it/s]

loss: 0.04217388853430748


 90%|█████████ | 1803/2000 [00:35<00:08, 22.06it/s]

loss: 0.10323582589626312


 95%|█████████▌| 1904/2000 [00:40<00:05, 18.89it/s]

loss: 0.14244210720062256


100%|██████████| 2000/2000 [00:45<00:00, 43.48it/s]


In [39]:
image = vae.fhat_to_img(f.detach()).clamp_(0, 1)
chw = torchvision.utils.make_grid(image, nrow=8, padding=0, pad_value=1.0)
chw = chw.permute(1, 2, 0).mul_(255).cpu().numpy()
chw = PImage.fromarray(chw.astype(np.uint8))
chw.show()

kf.service.services: KApplicationTrader: mimeType "x-scheme-handler/file" not found
kf.i18n.kuit: "Unknown subcue ':whatsthis,' in UI marker in context {@info:whatsthis, %1 the action's text}."
org.kde.kdegraphics.gwenview.lib: Unresolved raw mime type  "image/x-samsung-srw"


# VAR prediction


In [None]:
for coarsness_level in range(-1, len((patch_nums)) - 1):
    losses = []

    quant_pyramid, f_hat = vae.quantize.f_to_quant_pyramid_and_f_hat(f, patch_nums, coarsness_level)
    