In [1]:
import sys, os
import numpy as np
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torch.nn.functional as F
import torch
import adaptive_tokenizers
from utils import misc

In [12]:
# Visualize which image regions each latent token of the adaptive 1D tokenizer attends to.
#
# Notes on the setup used for the paper figures:
# - All token visualizations in the paper use the alit_small_vqgan_quantized_latents.pth model,
#   with attention maps taken from decoder layer index 6 (`layer_to_visualize` below).
# - Token-object binding is most interesting on images with more than one object, e.g. COCO
#   images (even though the model is trained only on ImageNet-100).
# - Last-iteration attention maps are the crispest and show the best token-object binding:
#   with more tokens (memory) available in the last iteration, some tokens can specialize
#   and attend to individual objects.
# Try visualizing the few COCO images attached in the assets/ directory.

args = {
    'image_path': 'assets/custom_images/coco/val2017_000000000632.jpg',
    # Prefer 'cuda:1' (original setup) but fall back so the notebook still runs on
    # single-GPU or CPU-only machines.
    'device': 'cuda:1' if torch.cuda.device_count() > 1 else ('cuda:0' if torch.cuda.is_available() else 'cpu'),
    'input_size': 256,
    'model': 'alit_small',
    'base_tokenizer': 'vqgan',
    'ckpt': 'adaptive_tokenizers/pretrained_models/imagenet100/alit_small_vqgan_quantized_latents.pth',
    'quantize_latent': True, # 1D quantized models have better object-token binding.
    'num_layers': 8, # 8 for small, 12 for base, 16 for large
    'layer_to_visualize': 6,
    'iter_to_visualize': 7 # 0<=iter_to_visualize<=7; set to None to plot every iteration
}
args = misc.Args(**args)

NUM_ITERS = 8  # number of encoding iterations whose attention maps are logged (0..7)
assert args.iter_to_visualize is None or 0 <= args.iter_to_visualize < NUM_ITERS, \
    "iter_to_visualize must be None or in [0, {}]".format(NUM_ITERS - 1)

image = Image.open(args.image_path).convert("RGB")
transform_val = transforms.Compose([
    transforms.Resize(args.input_size, interpolation=transforms.InterpolationMode.BICUBIC, antialias=True),
    transforms.CenterCrop(args.input_size),
    transforms.ToTensor()
])
image_tensor = transform_val(image).to(args.device)[None]

# Overlay background: the exact resized + center-cropped view fed to the model, as (H, W, 3)
# floats in [0, 1]. Overlaying the raw PIL image instead would misalign the attention maps
# for non-square inputs (e.g. most COCO images).
img = image_tensor[0].permute(1, 2, 0).cpu().numpy()

plt.imshow(img)
plt.show()

base_tokenizer_args = {
    "id": args.base_tokenizer,
    "is_requires_grad": False
}
adaptive_tokenizer = adaptive_tokenizers.__dict__[args.model](
    base_tokenizer_args=base_tokenizer_args, quantize_latent=args.quantize_latent,
    train_stage="full_finetuning", visualize_decoder_attn_weights=True)

adaptive_tokenizer.to(args.device)
# torch.load unpickles the checkpoint: only load checkpoints from trusted sources.
checkpoint = torch.load(args.ckpt, map_location='cpu')
adaptive_tokenizer.load_state_dict(checkpoint['ema'], strict=True)
adaptive_tokenizer.eval()

with torch.no_grad():
    # Automatically sample the minimum-length representation for the image.
    # Currently only "Reconstruction Loss < Threshold" is supported as the automatic
    # Token Selection Criterion (TSC).
    reconstruction_loss_threshold = 0.05
    _, _, all_logs = adaptive_tokenizer.encode(image_tensor, return_min_length_embedding=False, token_selection_criteria="reconstruction_loss", threshold=reconstruction_loss_threshold, return_embedding_type="latent_tokens")

# Layout of the decoder attention maps, as implied by the indexing used below.
# NOTE(review): assumes a 16x16 patch grid (input_size 256 with a downsampling factor of 16),
# that query rows 1..256 are the image patches, and that key column 256 + k is the k-th
# latent token -- confirm against the decoder's sequence layout.
PATCH_GRID = 16
NUM_PATCHES = PATCH_GRID * PATCH_GRID
GRID_COLS = 8  # tokens per figure row

num_tokens = 32    # how many latent tokens to plot
start_tokens = 0   # index of the first latent token to plot
n_rows = -(-num_tokens // GRID_COLS)  # ceil division

iters_to_plot = range(NUM_ITERS) if args.iter_to_visualize is None else [args.iter_to_visualize]
for iter_idx in iters_to_plot:
    # Stack the logged attention once per iteration (it does not depend on the token).
    attn_weight = np.concatenate(all_logs[iter_idx]["decoded_attn_weights_{}".format(iter_idx)])

    fig, axes = plt.subplots(n_rows, GRID_COLS, figsize=(20, 2.5 * n_rows), squeeze=False)
    fig.subplots_adjust(wspace=0, hspace=0)
    fig.suptitle("Iteration {}, decoder layer {}: tokens {}-{}".format(
        iter_idx, args.layer_to_visualize, start_tokens, start_tokens + num_tokens - 1))

    for offset in range(num_tokens):
        token_idx = start_tokens + offset
        ax = axes[offset // GRID_COLS, offset % GRID_COLS]

        # Attention from every image-patch query to this latent token, laid out on the patch grid
        # and bilinearly upsampled to the displayed image size.
        attn_map = attn_weight[args.layer_to_visualize, 1:NUM_PATCHES + 1, NUM_PATCHES + token_idx]
        attn_map = torch.tensor(attn_map.reshape(PATCH_GRID, PATCH_GRID), dtype=torch.float32)[None, None]
        attn_map = F.interpolate(attn_map, size=(args.input_size, args.input_size),
                                 mode='bilinear', align_corners=False).squeeze()

        min_val, max_val = attn_map.min(), attn_map.max()
        # max_val = 0.03 # for vae based model, setting optimal max_val leads to better visualization
        # Min-max normalize to [0, 1]. The clamp_min guards a flat map (zero range -> NaN), and
        # clamp(max=1) saturates values above a hand-tuned max_val.
        attn_norm = ((attn_map - min_val) / (max_val - min_val).clamp_min(1e-12)).clamp(max=1)

        ax.imshow(attn_norm.numpy(), cmap='jet', aspect='auto')
        ax.imshow(img, alpha=0.4, aspect='auto')
        ax.axis('off')

    # Hide any unused cells when num_tokens is not a multiple of GRID_COLS.
    for ax in axes.flat[num_tokens:]:
        ax.axis('off')

    plt.show()
    plt.close(fig)  # avoid accumulating open figures across iterations