In [1]:
import os

# Move the working directory up two levels (presumably to the repository root,
# so the `models.*` imports in later cells resolve -- confirm against the repo
# layout).
#
# The path is relative, so executing this cell a second time in the same kernel
# would climb two *more* levels. The resolved directory is therefore recorded
# the first time and simply re-entered on later runs, making the cell safe to
# re-execute. A kernel restart clears both the recorded value and the process
# working directory, so the first-run branch is taken again.
if 'PROJECT_ROOT' not in globals():
    os.chdir('../../')
    PROJECT_ROOT = os.getcwd()
else:
    os.chdir(PROJECT_ROOT)

In [2]:
import torch 
import torch.nn as nn
import torch.nn.functional as F

In [3]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Test `MultiheadAttention`

In [4]:
from models.transformer.attention import MultiheadAttention

In [29]:
# Sizes of the dummy input used by the attention cells below; the tensor is
# built as (batch_size, temporal_dim, spatial_dim, token_dim).
batch_size, temporal_dim, spatial_dim = 3, 4, 5

# Last-axis width of the input and the head count handed to the attention layer.
token_dim, num_heads = 16, 8

In [30]:
# Standard-normal dummy input for the attention layers, laid out as
# (batch_size, temporal_dim, spatial_dim, token_dim).
x = torch.randn(
    batch_size,
    temporal_dim,
    spatial_dim,
    token_dim,
)

x.shape

torch.Size([3, 4, 5, 16])

In [28]:
# Two attention layers with the same token width and head count: one using the
# constructor's default mode, one with `spatial=False` (presumably temporal
# attention, going by the `_t` suffix -- confirm in models/transformer/attention.py).
multihead_attention = MultiheadAttention(
    token_dim,
    num_heads,
)
multihead_attention_t = MultiheadAttention(
    token_dim,
    num_heads,
    spatial=False,
)

In [20]:
# Forward pass through the default attention layer.
y = multihead_attention(x)

# The layer is expected to preserve the input layout. Report the actual shape on
# failure so a mismatch can be diagnosed from the traceback alone.
assert y.shape == (batch_size, temporal_dim, spatial_dim, token_dim), f"Got: {y.shape}"

y.shape

torch.Size([3, 4, 5, 16])

In [34]:
# Same input through the `spatial=False` attention layer.
y = multihead_attention_t(x)

# Output must keep the (batch, temporal, spatial, token) layout of the input.
assert y.shape == (
    batch_size,
    temporal_dim,
    spatial_dim,
    token_dim,
), f"Got: {y.shape}"

y.shape

torch.Size([3, 4, 5, 16])

# Test `LateTemporalTokenizer`

In [12]:
from models.tokenizer.late_temporal_tokenizer import LateTemporalTokenizer

In [16]:
# Input side: the tokenizer is fed a tensor of shape
# (batch_size, num_frames, H * W, C).
batch_size = 3
num_frames = 10
H = W = 8
C = 4

# Output side: the tokenizer result is checked against
# (batch_size, temporal_dim, spatial_dim, token_dim).
temporal_dim, spatial_dim, token_dim = 4, 5, 16

# Not referenced by the tokenizer cells in this section.
num_heads = 8

In [21]:
# Standard-normal dummy input for the tokenizer: `num_frames` entries per
# sample, each with H * W positions and C values per position.
x = torch.randn(
    batch_size,
    num_frames,
    H * W,
    C,
)

x.shape

torch.Size([3, 10, 64, 4])

In [23]:
# Arguments are passed positionally; the constructor's parameter names are not
# visible from this notebook.
late_temporal_tokenizer = LateTemporalTokenizer(
    C,
    spatial_dim,
    temporal_dim,
    token_dim,
)

In [24]:
# Tokenize the (batch_size, num_frames, H * W, C) input.
y = late_temporal_tokenizer(x)

# The tokenizer is expected to emit a (batch, temporal, spatial, token) grid.
# Report the actual shape on failure so a mismatch can be diagnosed from the
# traceback alone.
assert y.shape == (batch_size, temporal_dim, spatial_dim, token_dim), f"Got: {y.shape}"

y.shape

torch.Size([3, 4, 5, 16])

# Test `TransformerEncoder`

In [26]:
from models.transformer.encoder import TransformerEncoder

In [27]:
# Encoder smoke-test sizes. The input tensor is built as
# (batch_size, temporal_dim, spatial_dim, token_dim); `num_heads` is forwarded
# to the encoder constructor.
(batch_size, temporal_dim, spatial_dim, token_dim) = (3, 4, 5, 16)
num_heads = 8

In [35]:
# Standard-normal dummy input for the encoder cells, laid out as
# (batch_size, temporal_dim, spatial_dim, token_dim).
x = torch.randn(
    batch_size,
    temporal_dim,
    spatial_dim,
    token_dim,
)

x.shape

torch.Size([3, 4, 5, 16])

In [36]:
# One encoder per value of the `spatial` flag, otherwise configured identically.
encoder = TransformerEncoder(
    token_dim,
    spatial=True,
    num_heads=num_heads,
)
encoder_t = TransformerEncoder(
    token_dim,
    spatial=False,
    num_heads=num_heads,
)

In [37]:
# Forward pass through the `spatial=True` encoder.
y = encoder(x)

# The encoder is expected to preserve the input layout. Report the actual shape
# on failure so a mismatch can be diagnosed from the traceback alone.
assert y.shape == (batch_size, temporal_dim, spatial_dim, token_dim), f"Got: {y.shape}"

y.shape

torch.Size([3, 4, 5, 16])

In [38]:
# Forward pass through the `spatial=False` encoder.
y = encoder_t(x)

# The encoder is expected to preserve the input layout. Report the actual shape
# on failure so a mismatch can be diagnosed from the traceback alone.
assert y.shape == (batch_size, temporal_dim, spatial_dim, token_dim), f"Got: {y.shape}"

y.shape

torch.Size([3, 4, 5, 16])

# Test `DividedTransformer`

In [41]:
from models.transformer.layer import DividedTransformer

In [42]:
# Input layout for the divided-transformer cells:
# (batch_size, temporal_dim, spatial_dim, token_dim).
batch_size, temporal_dim, spatial_dim, token_dim = 3, 4, 5, 16

# Not referenced by the DividedTransformer cells in this section.
num_heads = 8

# Layer specifications handed to DividedTransformer. The meaning of 0 vs 1 is
# defined by the model (presumably selecting the attention type per layer --
# confirm in models/transformer/layer.py).
layers_1 = [0, 1]            # one of each
layers_2 = [0] * 3 + [1] * 3  # three 0s followed by three 1s
layers_3 = [0, 1] * 3         # alternating
layers_4 = [0] * 4            # a single kind only

In [43]:
# Standard-normal dummy input shared by all DividedTransformer configurations,
# laid out as (batch_size, temporal_dim, spatial_dim, token_dim).
x = torch.randn(
    batch_size,
    temporal_dim,
    spatial_dim,
    token_dim,
)

x.shape

torch.Size([3, 4, 5, 16])

In [44]:
# One model per layer specification, built in the order layers_1 .. layers_4.
t_1, t_2, t_3, t_4 = (
    DividedTransformer(token_dim, layer_spec)
    for layer_spec in (layers_1, layers_2, layers_3, layers_4)
)

In [46]:
# Run the same input through every layer configuration.
y_1 = t_1(x)
y_2 = t_2(x)
y_3 = t_3(x)
y_4 = t_4(x)

# Every configuration is expected to preserve the input layout. Checking in a
# loop with a labelled message shows which configuration failed and what shape
# it produced, instead of four anonymous assertions.
expected_shape = (batch_size, temporal_dim, spatial_dim, token_dim)
for label, output in (("y_1", y_1), ("y_2", y_2), ("y_3", y_3), ("y_4", y_4)):
    assert output.shape == expected_shape, f"{label}: Got: {output.shape}"

y_1.shape

torch.Size([3, 4, 5, 16])

# Test `DividedVideoTransformer`

In [49]:
from models.video_transformer import DividedVideoTransformer

In [56]:
# Input side: the model is fed a tensor of shape
# (batch_size, num_frames, C, H, W).
batch_size = 3
num_frames = 10
C = 3
H = W = 32

# Model configuration passed to DividedVideoTransformer.
spatial_dim, temporal_dim, token_dim = 5, 4, 16
num_classes = 5  # width of the expected (batch_size, num_classes) output

# Defined here but not referenced by the cells below in this section.
num_heads = 8
layers = [0, 1]

In [62]:
# Standard-normal dummy input for the video model, laid out as
# (batch_size, num_frames, C, H, W).
x = torch.randn(
    batch_size,
    num_frames,
    C,
    H,
    W,
)

x.shape

torch.Size([3, 10, 3, 32, 32])

In [65]:
# The first three arguments are positional; only `num_classes` is passed by
# keyword. Everything else falls back to the constructor's defaults.
divided_transformer = DividedVideoTransformer(
    spatial_dim,
    temporal_dim,
    token_dim,
    num_classes=num_classes,
)

In [66]:
# End-to-end forward pass on the (batch_size, num_frames, C, H, W) input.
y = divided_transformer(x)

# The model is expected to return one row of `num_classes` values per sample.
# Report the actual shape on failure so a mismatch can be diagnosed from the
# traceback alone.
assert y.shape == (batch_size, num_classes), f"Got: {y.shape}"

y.shape

torch.Size([3, 5])