In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import torch
import matplotlib.pyplot as plt

from svtr.model import custom_layers, custom_blocks
from svtr.model import model

### Create some dummy data

In [None]:
# Dummy input: a batch of 5 random images in channels-first (C, H, W) layout.
img_shape = (3, 32, 200)  # channels first

# A dedicated, seeded generator makes the batch identical on every
# Restart & Run All without altering the global RNG state.
data_rng = torch.Generator().manual_seed(0)
image_batch = torch.rand(size=(5, *img_shape), generator=data_rng)

# imshow expects channels last, hence the permute to (H, W, C).
fig, ax = plt.subplots()
ax.imshow(image_batch[0].permute(1, 2, 0))
ax.set_title("Dummy input image (batch index 0)");

### Test individual components of the model

In [None]:
# Patch encoding: build the PatchEmbedding block for the dummy image shape,
# then push the whole batch through it. The trailing expression displays the
# shape of the result.
patch_embedding = custom_blocks.PatchEmbedding(
    img_shape,
    hdim1=32,
    hdim2=64,
)

x0 = patch_embedding(image_batch)
x0.shape

In [None]:
# Positional embedding: a lookup table with one row per patch index and the
# same width as the patch embedding (hdim2).
pos_embedding = torch.nn.Embedding(
    num_embeddings=patch_embedding.nr_patches,
    embedding_dim=patch_embedding.hdim2,
)

# Look up every patch index and add the result to the patch embeddings.
# NOTE(review): the addition relies on broadcasting against x0 — presumably
# x0 is (batch, nr_patches, hdim2); confirm against PatchEmbedding.
emb_indices = torch.arange(patch_embedding.nr_patches, dtype=torch.int32)
x1 = x0 + pos_embedding(emb_indices)
x1.shape

In [None]:
# Multi-head attention layers: one configured with mixing_type='local' and a
# 7x11 window, one with mixing_type='global'. Both take the patch grid size
# (out_h, out_w) reported by patch_embedding.
# embed_dim is read from patch_embedding.hdim2 instead of repeating the
# literal 64, so it stays in sync with the width of x1 if hdim2 is changed.
mha_local = custom_layers.WindowedMultiheadAttention(
    embed_dim=patch_embedding.hdim2,
    num_heads=4,
    mixing_type='local',
    in_hw=[patch_embedding.out_h, patch_embedding.out_w],
    window_shape=[7, 11],
)
mha_global = custom_layers.WindowedMultiheadAttention(
    embed_dim=patch_embedding.hdim2,
    num_heads=4,
    mixing_type='global',
    in_hw=[patch_embedding.out_h, patch_embedding.out_w],
)

In [None]:
# Run both attention variants on the position-encoded patches (local first,
# then global) and report the two output shapes, one per line.
x2_loc = mha_local(x1)
x2_glob = mha_global(x1)
print(x2_loc.shape, x2_glob.shape, sep="\n")

In [None]:
# Mixing block: a single MixingBlock with mixing_type='local', applied to the
# position-encoded patches x1. The trailing expression displays the output
# shape.
# embed_dim is read from patch_embedding.hdim2 instead of repeating the
# literal 64, so it stays in sync with the width of x1 if hdim2 is changed.
mixing_block = custom_blocks.MixingBlock(
    embed_dim=patch_embedding.hdim2,
    num_heads=4,
    mixing_type='local',
    window_shape=[7, 11],
    in_hw=[patch_embedding.out_h, patch_embedding.out_w],
    mlp_hidden_dim_factor=4,
    attn_dropout=0.5,
    linear_dropout=0.5,
    act=torch.nn.GELU,
)

x3 = mixing_block(x1)
x3.shape

In [None]:
# Mixing blocks + merging: a stage of four mixing blocks (alternating local
# and global) configured with out_dim=128, applied to x1. The trailing
# expression displays the output shape.
# embed_dim is read from patch_embedding.hdim2 instead of repeating the
# literal 64, so it stays in sync with the width of x1 if hdim2 is changed.
stage_merging = custom_blocks.MixingBlocksMerging(
    embed_dim=patch_embedding.hdim2,
    out_dim=128,
    num_heads=4,
    mixing_type_list=['local', 'global', 'local', 'global'],
    window_shape=[7, 11],
    in_hw=[patch_embedding.out_h, patch_embedding.out_w],
    mlp_hidden_dim_factor=4,
    attn_dropout=0.5,
    linear_dropout=0.5,
    act=torch.nn.GELU,
)

x4 = stage_merging(x1)
x4.shape

In [None]:
# Mixing blocks + combining: a stage of four mixing blocks (alternating local
# and global) that consumes the merging stage's output x4.
# embed_dim=128 equals the out_dim given to stage_merging, and the input grid
# size is taken from stage_merging.out_h / stage_merging.out_w.
combining_config = dict(
    embed_dim=128,
    out_dim=192,
    num_heads=4,
    mixing_type_list=['local', 'global', 'local', 'global'],
    window_shape=[7, 11],
    in_hw=[stage_merging.out_h, stage_merging.out_w],
    mlp_hidden_dim_factor=4,
    attn_dropout=0.5,
    linear_dropout=0.5,
    act=torch.nn.GELU,
)
stage_combining = custom_blocks.MixingBlocksCombining(**combining_config)

x5 = stage_combining(x4)
x5.shape

### Test full model

In [None]:
# Full model: build the 'tiny' SVTR architecture and run the dummy batch
# through it. The trailing expression displays the output shape.
# img_shape is reused from the dummy-data cell (converted to a list, matching
# the type of the original literal) so the model's expected input shape and
# image_batch cannot drift apart.
svtr = model.SVTR(architecture='tiny', img_shape=list(img_shape))

out = svtr(image_batch)
out.shape