## Image Tokenization
Usage examples for the Quadtree image tokenizer and the vanilla ViT tokenizer.

The tokenizers turn input images into token sequences for a standard Transformer model: \
they pass the patch pixels through an encoding layer, add sinusoidal position embeddings \
based on patch locations, and prepend the `cls_token`.

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from torch import nn

from mixed_res.patch_scorers.random_patch_scorer import RandomPatchScorer
from mixed_res.quadtree_impl.quadtree_z_curve import ZCurveQuadtreeRunner
from mixed_res.tokenization.patch_embed import FlatPatchEmbed, PatchEmbed
from mixed_res.tokenization.tokenizers import QuadtreeTokenizer, VanillaTokenizer

In [3]:
# Experiment configuration: every value a reader might want to tune lives here.
device = "cuda" if torch.cuda.is_available() else "cpu"
image_size = 256
channels = 3
min_patch_size = 16  # passed to the quadtree runner as its min patch size; also the vanilla ViT patch size
max_patch_size = 64  # passed to the quadtree runner as its max patch size
quadtree_num_patches = 100  # number of patch tokens the quadtree tokenizer emits per image
batch_size = 5
embed_dim = 384
seed = 0

# Seed torch's global RNG (all devices) so the random images / cls_token below are
# reproducible under Restart & Run All. RandomPatchScorer presumably draws from the same
# RNG as well — TODO confirm.
torch.manual_seed(seed)

images = torch.randn(batch_size, channels, image_size, image_size, device=device)
# Allocate the tensor on the target device *before* wrapping it in nn.Parameter:
# `Parameter.to(device)` returns a plain, non-leaf tensor whenever the device changes,
# so on CUDA the result would no longer be a trainable nn.Parameter.
cls_token = nn.Parameter(torch.randn(embed_dim, device=device))

### Tokenize images with a Quadtree tokenizer

In [4]:
# Build the tokenizer components. In a real model these would be created once,
# typically inside your ViT's __init__ method.
patch_embed = FlatPatchEmbed(img_size=image_size, patch_size=min_patch_size, embed_dim=embed_dim).to(device)
quadtree_runner = ZCurveQuadtreeRunner(quadtree_num_patches, min_patch_size, max_patch_size)
patch_scorer = RandomPatchScorer()
quadtree_tokenizer = QuadtreeTokenizer(patch_embed, cls_token, quadtree_runner, patch_scorer)

# The tokenization call itself belongs in your forward method.
token_embeds = quadtree_tokenizer.tokenize(images)

# Expect the prepended cls token plus `quadtree_num_patches` patch tokens per image.
assert token_embeds.shape == (batch_size, 1 + quadtree_num_patches, embed_dim), (
    f"unexpected token shape {tuple(token_embeds.shape)}"
)
token_embeds.shape  # [batch_size, 1 + quadtree_num_patches, embed_dim]

torch.Size([5, 101, 384])

### Tokenize images with a vanilla ViT tokenizer

In [5]:
# Build the tokenizer components. In a real model these would be created once,
# typically inside your ViT's __init__ method.
patch_embed = PatchEmbed(img_size=image_size, patch_size=min_patch_size, embed_dim=embed_dim).to(device)
vanilla_tokenizer = VanillaTokenizer(patch_embed, cls_token)

# The tokenization call itself belongs in your forward method.
token_embeds = vanilla_tokenizer.tokenize(images)

# Expect the prepended cls token plus one token per min_patch_size x min_patch_size grid cell.
# Integer division keeps the expected count an int (256 // 16 = 16 -> 16**2 = 256 patches).
assert token_embeds.shape == (batch_size, 1 + (image_size // min_patch_size) ** 2, embed_dim), (
    f"unexpected token shape {tuple(token_embeds.shape)}"
)
token_embeds.shape  # [batch_size, 1 + (image_size // min_patch_size)**2, embed_dim]

torch.Size([5, 257, 384])