The goal here is to build an FSQ-VAE that produces latent vectors for the spatio-temporal "tubelets" described in ViViT. Sequences of tubelet latent vectors are then to be modeled by a Transformer decoder.

![](https://i.imgur.com/9G7QTfV.png)

In [1]:
import torch
from torch import nn

In [2]:
# Plan: cut a video tensor of shape (B, C, T, H, W) into non-overlapping
# patches of shape (t, p, p), a.k.a. "tubelets". Rather than embedding each
# tubelet with a single Conv3d whose kernel dims equal the tubelet dims, every
# tubelet is fed to a small VAE so that the resulting latents can be mapped
# back to pixel space.

# Video dimensions: batch, channels, frames, height, width.
B, C, T, H, W = 4, 3, 64, 256, 256

# Tubelet dimensions: t frames deep, p x p pixels per frame.
t, p = 8, 16

# Each video axis has to be an exact multiple of the matching tubelet axis,
# otherwise the reshape into tubelets is impossible.
assert all(extent % size == 0 for extent, size in ((T, t), (H, p), (W, p)))

# Number of tubelets along time, height and width.
n_t = T // t
n_h = H // p
n_w = W // p

# Random stand-in for a batch of videos.
vid = torch.randn((B, C, T, H, W))

In [None]:
# Split each of T, H and W into a (count, size) pair of axes:
# (B, C, T, H, W) -> (B, C, n_t, t, n_h, p, n_w, p)
tubelets = torch.reshape(vid, (B, C, n_t, t, n_h, p, n_w, p))
print(tubelets.size())

In [None]:
# Move the three tubelet-count axes next to the batch axis and push the
# per-tubelet content (C, t, p, p) to the end:
# (B, C, n_t, t, n_h, p, n_w, p) -> (B, n_t, n_h, n_w, C, t, p, p)
tubelets = tubelets.permute((0, 2, 4, 6, 1, 3, 5, 7))
# Materialise the permuted layout in memory.
tubelets = tubelets.contiguous()
print(tubelets.size())

In [None]:
# Fold batch and tubelet-count axes into one leading axis so that every
# tubelet becomes its own sample:
# (B, n_t, n_h, n_w, C, t, p, p) -> ((B * n_t * n_h * n_w), C, t, p, p)
tubelets = torch.reshape(tubelets, (-1, C, t, p, p))
print(tubelets.size())

In [None]:
# Encoder stage 1: 3x3x3 conv with stride 2 on every axis, C -> 32 channels.
# With the (t, p, p) = (8, 16, 16) tubelets built in the cells above, the
# spatio-temporal extent goes (8, 16, 16) -> (4, 8, 8).
conv3d_1 = nn.Conv3d(
    C,
    32,
    kernel_size=(3, 3, 3),
    stride=(2, 2, 2),
    padding=(1, 1, 1),
)
out1 = conv3d_1(tubelets)
print(out1.size())

In [None]:
# Encoder stage 2: same 3x3x3 / stride-2 conv, 32 -> 32 channels.
# Following stage 1 on (8, 16, 16) tubelets: (4, 8, 8) -> (2, 4, 4).
conv3d_2 = nn.Conv3d(
    32,
    32,
    kernel_size=(3, 3, 3),
    stride=(2, 2, 2),
    padding=(1, 1, 1),
)
out2 = conv3d_2(out1)
print(out2.size())

In [None]:
# Encoder stage 3: same 3x3x3 / stride-2 conv, 32 -> 32 channels.
# Following stage 2 on (8, 16, 16) tubelets: (2, 4, 4) -> (1, 2, 2).
conv3d_3 = nn.Conv3d(
    32,
    32,
    kernel_size=(3, 3, 3),
    stride=(2, 2, 2),
    padding=(1, 1, 1),
)
out3 = conv3d_3(out2)
print(out3.size())

In [None]:
# Encoder stage 4: 3x3x3 conv, 32 -> 32 channels, stride 1 along time and
# stride 2 along height/width, so only the spatial axes are reduced.
# Following stage 3 on (8, 16, 16) tubelets: (1, 2, 2) -> (1, 1, 1).
conv3d_4 = nn.Conv3d(
    32,
    32,
    kernel_size=(3, 3, 3),
    stride=(1, 2, 2),
    padding=(1, 1, 1),
)
out4 = conv3d_4(out3)
print(out4.size())