# MaskGIT Demo

## Clone Repository

In [None]:
is_colab = True
if is_colab:
    !git clone https://github.com/valeoai/Halton-MaskGIT.git
    %cd Halton-MaskGIT
    %pip install omegaconf>=2.0.0 einops>=0.3.0 webdataset>=2.0 huggingface_hub clean-fid torch-fidelity torchmetrics

In [None]:
import random
import matplotlib.pyplot as plt
import torch
import torchvision.utils as vutils
from huggingface_hub import hf_hub_download

import torchvision.transforms as T
from PIL import Image
import numpy as np
import torch.nn.functional as F

from Utils.utils import load_args_from_file
from Utils.viz import show_images_grid
from Trainer.cls_trainer import MaskGIT
from Sampler.halton_sampler import HaltonSampler
from Sampler.confidence_sampler import ConfidenceSampler

## Download Pretrained Models

# MaskGIT Initialisation

In [None]:
# Read the base class-to-image configuration and run on the GPU when available.
config_path = "Config/base_cls2img.yaml"
args = load_args_from_file(config_path)
use_cuda = torch.cuda.is_available()
args.device = torch.device("cuda" if use_cuda else "cpu")

# Override the loaded settings here to select a different network.
args.img_size = 384       # 256 or 384
args.vit_size = "large"   # "tiny", "small", "base", "large"
args.dtype = "float32"    # bfloat16 is faster
args.compile = False      # enabling compile is faster
args.resume = True
# Checkpoint path; its file name is derived from the image size and ViT size.
args.vit_folder = f"./saved_networks/ImageNet_{args.img_size}_{args.vit_size}.pth"

if is_colab:
    # Download the selected checkpoint and the vq_ds16_c2i.pt file from the
    # Hugging Face Hub into ./saved_networks (the directory args.vit_folder
    # points into).
    for repo_id, filename in (
        ("llvictorll/Halton-Maskgit", f"ImageNet_{args.img_size}_{args.vit_size}.pth"),
        ("FoundationVision/LlamaGen", "vq_ds16_c2i.pt"),
    ):
        hf_hub_download(repo_id=repo_id, filename=filename, local_dir="./saved_networks")

maskgit = MaskGIT(args)

In [None]:
def viz(x, nrow=10, pad=2, size=(18, 18)):
    """
    Visualize a grid of images.

    Every image is min-max rescaled to [0, 1] independently (its minimum and
    maximum are taken over its last three dimensions), then the images are
    tiled into a grid and displayed with matplotlib.

    Args:
        x (torch.Tensor): Input images to visualize
            (presumably shaped (B, C, H, W) -- TODO confirm with callers).
        nrow (int): Number of images in each row of the grid.
        pad (int): Padding between the images in the grid.
        size (tuple): Size of the visualization figure.

    Returns:
        None. The figure is shown as a side effect.
    """
    # Normalise in float32 on the CPU so half-precision inputs do not lose
    # precision (and so the epsilon below does not underflow to zero).
    x = x.detach().float().cpu()

    # Per-image extrema over the last three dimensions, kept broadcastable.
    min_norm = x.amin(dim=(-3, -2, -1), keepdim=True)
    max_norm = x.amax(dim=(-3, -2, -1), keepdim=True)

    # A constant image has max == min; clamp the range so the division yields
    # zeros instead of NaN pixels.
    x = (x - min_norm) / (max_norm - min_norm).clamp_min(1e-8)

    grid = vutils.make_grid(x, nrow=nrow, padding=pad, normalize=False)
    plt.figure(figsize=size)
    plt.axis('off')
    # make_grid returns channels-first; imshow expects channels-last.
    plt.imshow(grid.permute(1, 2, 0))
    plt.show()

def decoding_viz(gen_code, mask, maskgit):
    """
    Plot the masks and the decoded images for a sequence of generated codes.

    A colour-coded strip of the masks is drawn first, then the decoded images
    are shown with `viz`, both with one column per entry of `gen_code`.

    Args:
        gen_code (sequence of torch.Tensor): Codes to decode; they are stacked
            along a new first dimension and squeezed before decoding.
        mask (sequence of torch.Tensor): Masks; they are stacked and reshaped
            to (-1, 1, input_size, input_size).
        maskgit (MaskGIT): Provides `input_size`, `ae.decode_code` and
            `args.mask_value`.
    """
    side = maskgit.input_size
    n_cols = len(gen_code)

    def flat_colour(rgb, gain):
        # Constant-colour plane of shape (1, 3, side, side), scaled by `gain`.
        return torch.FloatTensor(rgb).view(1, 3, 1, 1).expand(1, 3, side, side) * gain

    colour_off = flat_colour([1, 1, 1], 0.8)                                # used where mask == 0
    colour_on = flat_colour([0.01953125, 0.30078125, 0.08203125], 1.4)      # used where mask == 1

    codes = torch.stack((gen_code), dim=0).squeeze()
    masks = torch.stack((mask), dim=0).view(-1, 1, side, side).cpu()

    # Clamp the codes into [0, mask_value] before handing them to the decoder.
    with torch.no_grad():
        decoded = maskgit.ae.decode_code(torch.clamp(codes, 0, maskgit.args.mask_value))

    # Blend the two colours according to the mask and tile the result in one row
    # per `n_cols` entries; permute to channels-last for plotting.
    coloured = (1 - masks) * colour_off + masks * colour_on
    mask_grid = vutils.make_grid(coloured, nrow=n_cols, padding=1, pad_value=0.4, normalize=False)
    mask_grid = mask_grid.permute(1, 2, 0)

    plt.figure(figsize=(18, 2))
    # pcolormesh draws row 0 at the bottom; flip the axis so it appears on top.
    plt.gca().invert_yaxis()
    plt.pcolormesh(mask_grid, edgecolors='w', linewidth=.5)
    plt.axis('off')
    plt.show()

    viz(decoded, nrow=n_cols)

# MaskGIT Sampling With the Halton Scheduler

In [None]:
# Halton sampler hyper-parameters.
sm_temp_min = 1.1    # Minimum softmax temperature for sampling.
sm_temp_max = 1.1    # Maximum softmax temperature for sampling.
temp_pow = 1         # Exponent for temperature scheduling.
temp_warmup = 1      # Number of initial steps where temperature is reduced.
top_k = -1           # If > 0, applies top-k sampling for token selection.
w = 2                # CFG weight
sched_pow = 2        # Power factor for the progressive unmasking schedule.
step = 32            # Number of steps to sample an image
randomize = True     # If True, applies random shifts to the Halton sequence for diverse sampling.

sampler = HaltonSampler(
    sm_temp_min=sm_temp_min,
    sm_temp_max=sm_temp_max,
    temp_pow=temp_pow,
    temp_warmup=temp_warmup,
    w=w,
    sched_pow=sched_pow,
    step=step,
    randomize=randomize,
    top_k=top_k,
)

# Class ids to generate, one image per id.
# NOTE(review): the previous comment listed nine class names (goldfish, chicken,
# tiger cat, hourglass, ship, dog, race car, airliner, teddy bear) for these
# eight ids -- verify the id/name pairing against the ImageNet class index.
class_ids = [1, 7, 282, 604, 724, 179, 681, 850]
labels = torch.tensor(class_ids, dtype=torch.long).to(args.device)

# Generate one sample per label and display them in a grid.
gen_sample, gen_code, l_mask = sampler(maskgit, nb_sample=labels.size(0), labels=labels)
viz(gen_sample, nrow=10, size=(18, 18))

In [None]:
nb_class = 8             # Number of consecutive class ids per generated grid
nb_sample = 2            # Number of images generated for each class id
nb_row = 8               # Number of images per row in the displayed grid
nb_total_class = 1000    # Class ids 0 .. nb_total_class - 1 are generated

# Generate all ImageNet Classes, nb_class consecutive ids at a time.
for first_class in range(0, nb_total_class, nb_class):
    # Cap the last chunk so no label exceeds the valid range when nb_class
    # does not divide nb_total_class.
    last_class = min(first_class + nb_class, nb_total_class)
    # The chunk of ids is repeated nb_sample times: [c0..ck, c0..ck, ...].
    labels = list(range(first_class, last_class)) * nb_sample
    labels = torch.LongTensor(labels).to(args.device)
    # Generate sample
    gen_sample, gen_code, l_mask = sampler(maskgit, nb_sample=len(labels), labels=labels)
    # viz only displays the figure and returns None, so nothing is assigned.
    viz(gen_sample, nrow=nb_row, size=(18, 18))
    