In [2]:
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch
import adaptive_tokenizers
from utils import misc

# Demo configuration -- set these values to point at your own image / checkpoint.
# The device falls back to CPU when no CUDA GPU is visible, so the notebook still
# runs (slowly) on machines without a GPU instead of failing on the first .to().
args = {
    'image_path': 'assets/custom_images/birds/000000.png',
    'device': 'cuda:0' if torch.cuda.is_available() else 'cpu',
    'input_size': 256,  # used for both the Resize and the CenterCrop of the input image
    'model': 'alit_small',  # key looked up in adaptive_tokenizers.__dict__ to build the model
    'base_tokenizer': 'vae',  # passed as the "id" of the base tokenizer
    'ckpt': 'adaptive_tokenizers/pretrained_models/imagenet100/alit_small_vae_continuous_latents.pth',
    'quantize_latent': False,
}
args = misc.Args(**args)


# Load the input image. convert() returns an independent in-memory copy, so the
# underlying file handle can be closed immediately instead of being leaked.
with Image.open(args.image_path) as pil_image:
    image = pil_image.convert("RGB")

# Preprocessing: resize (bicubic, antialiased) with args.input_size, center-crop to
# args.input_size x args.input_size, then convert to a float tensor via ToTensor().
transform_val = transforms.Compose([
    transforms.Resize(args.input_size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
    transforms.CenterCrop(args.input_size),
    transforms.ToTensor(),
])
# Move to the target device and add a leading batch dimension.
image_tensor = transform_val(image).to(args.device).unsqueeze(0)

# Show the preprocessed input; matplotlib wants channels last, hence the permute.
plt.imshow(image_tensor[0].permute(1, 2, 0).cpu().numpy())
plt.show()

# Build the adaptive tokenizer selected by args.model. The base tokenizer is
# configured with is_requires_grad=False (presumably frozen -- confirm in
# adaptive_tokenizers).
base_tokenizer_args = dict(id=args.base_tokenizer, is_requires_grad=False)
tokenizer_constructor = adaptive_tokenizers.__dict__[args.model]
adaptive_tokenizer = tokenizer_constructor(
    base_tokenizer_args=base_tokenizer_args,
    quantize_latent=args.quantize_latent,
    train_stage="full_finetuning",
)
adaptive_tokenizer.to(args.device)

# NOTE: torch.load unpickles the file -- only load checkpoints from a trusted source.
# Tensors are mapped to CPU first; load_state_dict then copies them into the
# model's existing parameters. The "ema" entry presumably holds EMA weights.
checkpoint = torch.load(args.ckpt, map_location='cpu')
adaptive_tokenizer.load_state_dict(checkpoint['ema'], strict=True)

# Put the model in evaluation mode for inference.
adaptive_tokenizer.eval()

def _to_displayable(image_batch):
    """Return the first image of a batched CHW tensor as an HWC numpy array clipped to [0, 1]."""
    # matplotlib expects channels last and float values in [0, 1].
    return np.clip(image_batch[0].permute(1, 2, 0).cpu().numpy(), 0., 1.)


with torch.no_grad():
    # Automatically sample the minimum-length representation for the image.
    # Currently the only supported automatic Token Selection Criterion (TSC) is
    # "reconstruction loss < threshold".
    reconstruction_loss_threshold = 0.05
    # Announce the (potentially slow) sampling step before running it.
    print("Sampling Minimum Length Encoding...")
    min_length_embed, min_length_reconstruction = adaptive_tokenizer.encode(
        image_tensor,
        return_min_length_embedding=True,
        token_selection_criteria="reconstruction_loss",
        threshold=reconstruction_loss_threshold,
        return_embedding_type="latent_tokens",
    )
    min_token_count = min_length_embed.shape[1]
    print(
        f"The minimum token count under the provided token selection criteria "
        f"(reconstruction loss < {reconstruction_loss_threshold}) is: {min_token_count}"
    )
    plt.imshow(_to_displayable(min_length_reconstruction))
    plt.show()

    # Sample all (multiple) representations per image. Longer representations
    # require more processing and more memory.
    print("Sampling All Length Encodings...")
    all_embeds, all_reconstructions = adaptive_tokenizer.encode(image_tensor, return_min_length_embedding=False)
    # Only the reconstructions are visualized; zip keeps the original pairing/length.
    for _, reconstruction_image in zip(all_embeds, all_reconstructions):
        plt.imshow(_to_displayable(reconstruction_image))
        plt.show()