# Vision Transformers for Edge Devices

The core MobileViT step: rearranging Conv2d activation maps into the (B, T, C) layout that transformer layers expect.

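The key difference from a standard ViT is where the patch pixels go. A standard ViT merges each patch's pixels into the channel dimension, i.e. `(B, C, H, W) -> (B, (h*w), (ph*pw*C))`, so every token is one flattened patch of raw pixel values. Here `h = H/ph` and `w = W/pw` are the number of patches along each axis.

MobileViT instead merges the within-patch pixel positions into the batch dimension, i.e. `(B, C, H, W) -> (B*ph*pw, (h*w), C)`. Each resulting sequence collects the pixel at one fixed position inside the patch from every patch in the image, so attention runs between patches, and the `C` features come from the preceding convolutional layers rather than from raw pixel values as in a ViT.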

**The reshape does two things at once: it parallelises patch-to-patch attention across the enlarged batch dimension, and it lets the transformer learn from convolutional features rather than raw pixel values.**

$$X_{G}(p)=\text{Transformer}(X_{U}(p))$$

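Here $X_U(p) \in \mathbb{R}^{N \times C}$ holds, for one image, the pixel at position $p$ from each of the $N = \frac{H}{p_h} \cdot \frac{W}{p_w}$ patches, and the transformer is applied independently for each $p \in \{1, \dots, P\}$, where $P = p_h p_w$ is the number of pixels per patch. By reshaping to an effective batch size of $B \times P$, standard PyTorch transformer layers can process all $P$ relative pixel locations across all $B$ images in parallel.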

```python
import torch
from einops import rearrange

H, W = 128, 128
C = 64
B = 8

x = torch.randn(B, C, H, W)
print(f"Original tensor = {x.shape}")  # (B, C, H, W)

print("---" * 10)

# UNFOLD
# Move the ph*pw pixel positions within each patch into the batch dimension,
# giving one (N, C) sequence per (image, position-in-patch) pair.
ph, pw = 4, 4
# (B, C, (h ph), (w pw)) -> (B, C, N, P)
tx = rearrange(x, 'b c (h ph) (w pw) -> b c (h w) (ph pw)', ph=ph, pw=pw)
# (B, C, N, P) -> (B*P, N, C), where P = ph*pw pixels per patch,
# N = (H//ph) * (W//pw) patches, and C is the number of channels.
tx = rearrange(tx, 'b c n p -> (b p) n c')
print(f"Unfolded tensor = {tx.shape}")   # (B*ph*pw, h*w, C), i.e. (B', T, C) with B' = B*P
# NB: For a plain ViT we would instead use
# rearrange(x, 'b c (h ph) (w pw) -> b (h w) (ph pw c)', ph=ph, pw=pw)
# to get (B, T, C) directly, flattening each patch into the feature dimension.


# FOLD
# (B*P, N, C) -> (B, C, N, P) -> (B, C, (h ph), (w pw))
x2 = rearrange(tx, '(b p) (h w) c -> b c (h w) p', h=H//ph, w=W//pw, p=ph*pw)
x2 = rearrange(x2, 'b c (h w) (ph pw) -> b c (h ph) (w pw)', ph=ph, pw=pw, h=H//ph)
print(f"Folded tensor = {x2.shape}")  # (B, C, H, W)
assert torch.equal(x, x2), "Two-step fold does not match the original!"

print("---" * 10)

# The same transform with a single rearrange per direction:
# UNFOLD: (B, C, H, W) -> (B*P, N, C)
# B*P = b * ph * pw
# N = h * w (where h = H//ph and w = W//pw)
tx = rearrange(x, 'b c (h ph) (w pw) -> (b ph pw) (h w) c', ph=ph, pw=pw)
print(f"Unfolded tensor = {tx.shape}")

# FOLD: (B*P, N, C) -> (B, C, H, W)
x2 = rearrange(tx, '(b ph pw) (h w) c -> b c (h ph) (w pw)', b=B, ph=ph, pw=pw, h=H//ph, w=W//pw)
print(f"Folded tensor = {x2.shape}")

# Sanity check: rearrange only permutes elements, so the round trip is exact
assert torch.equal(x, x2), "Folded tensor does not match the original!"
```

```text
Original tensor = torch.Size([8, 64, 128, 128])
------------------------------
Unfolded tensor = torch.Size([128, 1024, 64])
Folded tensor = torch.Size([8, 64, 128, 128])
------------------------------
Unfolded tensor = torch.Size([128, 1024, 64])
Folded tensor = torch.Size([8, 64, 128, 128])
```
