### Self-Attention

In [23]:
import torch.nn as nn
import torch

class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention.

    Queries, keys and values are all linear projections of the same input.

    Args:
        d_in: Size of the input's last (feature) dimension.
        d_out_kq: Size of the query and key projections. Scores are divided
            by ``sqrt(d_out_kq)`` before the softmax.
        d_out_v: Size of the value projection, i.e. of the output's last
            dimension.
        causal: When True, position ``i`` may only attend to positions
            ``<= i``. Defaults to False (unmasked attention).
    """

    def __init__(self, d_in, d_out_kq, d_out_v, causal=False):
        super().__init__()
        self.d_out_kq = d_out_kq
        self.causal = causal
        self.W_query = nn.Linear(d_in, d_out_kq)
        self.W_key   = nn.Linear(d_in, d_out_kq)
        self.W_value = nn.Linear(d_in, d_out_v)

    def forward(self, x):
        """Attend over the second-to-last axis of ``x``.

        Args:
            x: Tensor of shape ``(..., seq_len, d_in)``; leading batch
                dimensions are optional.

        Returns:
            Tensor of shape ``(..., seq_len, d_out_v)``.
        """
        keys, queries, values = self.W_key(x), self.W_query(x), self.W_value(x)

        # (..., seq_len, seq_len): unnormalized attention scores.
        attn_scores = queries @ keys.transpose(-2, -1)

        if self.causal:
            # Take seq_len from the score matrix itself (not from axis 0,
            # which is the batch axis for batched input).
            seq_len = attn_scores.shape[-1]
            future = torch.triu(
                torch.ones(
                    seq_len, seq_len,
                    dtype=torch.bool, device=attn_scores.device,
                ),
                diagonal=1,
            )
            # The diagonal stays unmasked, so every row keeps at least one
            # finite score and the softmax never produces NaN.
            attn_scores = attn_scores.masked_fill(future, float("-inf"))

        # Normalize over the key axis so each query's weights sum to 1.
        attn_weights = torch.softmax(
            attn_scores / self.d_out_kq**0.5, dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec

In [30]:
# Toy input laid out as (batch, seq_len, n_dim).
embedded_sentence = torch.randn(2, 6, 3)
# d_out_kq and d_out_v are the full projection sizes. SelfAttention uses them
# as given; MultiHeadAttention divides each by num_heads to size every head.
d_in = 3
d_out_kq = 2
d_out_v = 4
sa = SelfAttention(d_in=d_in, d_out_kq=d_out_kq, d_out_v=d_out_v)

### Multi-Head Attention

In [31]:
class MultiHeadAttention(nn.Module):
    """Multi-head attention built from independent ``SelfAttention`` heads.

    Each head projects the full input to ``d_out_kq // num_heads``
    query/key features and ``d_out_v // num_heads`` value features; the head
    outputs are concatenated along the last axis, giving ``d_out_v`` output
    features in total.

    Args:
        d_in: Size of the input's last (feature) dimension.
        d_out_kq: Total query/key size, split across the heads.
        d_out_v: Total value/output size, split across the heads.
        num_heads: Number of attention heads.

    Raises:
        ValueError: If ``num_heads`` is not positive, or does not evenly
            divide both ``d_out_kq`` and ``d_out_v``.
    """

    def __init__(self, d_in, d_out_kq, d_out_v, num_heads):
        super().__init__()
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        # Floor division would otherwise silently shrink the output width,
        # or produce zero-sized heads whose softmax scale is 0.
        if d_out_kq % num_heads or d_out_v % num_heads:
            raise ValueError(
                f"d_out_kq ({d_out_kq}) and d_out_v ({d_out_v}) must both be "
                f"divisible by num_heads ({num_heads})"
            )
        self.heads = nn.ModuleList(
            [SelfAttention(d_in, d_out_kq // num_heads, d_out_v // num_heads)
             for _ in range(num_heads)]
        )

    def forward(self, x):
        """Run every head on ``x`` and concatenate the results on the last axis."""
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [34]:
# Two heads; their outputs are concatenated along the last axis.
mha = MultiHeadAttention(d_in, d_out_kq, d_out_v, 2)
context_vecs = mha(embedded_sentence)


### Temporal Mixing

Temporal mixing consists of two layers, a 3D convolution layer and a temporal attention layer, applied after every spatial layer.

![image.png](../imgs/vid.png)

In [1]:
import torch
import torch.nn as nn
from einops import rearrange
import torch.nn.functional as F
from einops.layers.torch import Rearrange

In [None]:
class Conv3DLayer(nn.Module):
    """Temporal convolution block with a learnable residual blend.

    The input arrives with frames folded into the batch axis,
    ``(b * n_frames, c, h, w)``. It is unfolded to ``(b, c, t, h, w)``, passed
    through two GroupNorm -> SiLU -> Conv3d stages, and folded back. Each
    convolution uses a ``(3, 1, 1)`` kernel with ``(1, 0, 0)`` padding, so it
    mixes information along the frame axis only and keeps the number of
    frames unchanged.

    The result is ``alpha * x + (1 - alpha) * h`` where ``alpha`` is a
    learnable scalar initialized to 1, so the temporal branch has zero weight
    at initialization. Because ``x`` and ``h`` are summed, ``in_dim`` and
    ``out_dim`` must produce matching shapes.

    Args:
        in_dim: Channel count of the input.
        out_dim: Channel count produced by both convolutions.
        n_frames: Number of frames ``t`` folded into the batch axis.
    """

    def __init__(self, in_dim, out_dim, n_frames):
        super().__init__()

        self.to_3d = Rearrange('(b t) c h w -> b c t h w', t=n_frames)
        self.to_2d = Rearrange('b c t h w -> (b t) c h w')

        def norm_act_conv(channels_in):
            # GroupNorm uses 32 groups, so channels_in must be a multiple of 32.
            return nn.Sequential(
                nn.GroupNorm(32, channels_in),
                nn.SiLU(),
                nn.Conv3d(
                    channels_in,
                    out_dim,
                    kernel_size=(3, 1, 1),
                    stride=1,
                    padding=(1, 0, 0),
                ),
            )

        self.block1 = norm_act_conv(in_dim)
        self.block2 = norm_act_conv(out_dim)
        self.alpha = nn.Parameter(torch.ones(1))

    def forward(self, x):
        """Blend ``x`` with its temporally convolved version.

        Args:
            x: Tensor of shape ``(b * n_frames, c, h, w)``.

        Returns:
            Tensor with the same layout as ``x``.
        """
        video = self.to_3d(x)
        mixed = self.to_2d(self.block2(self.block1(video)))

        # Keep the blend weight inside [0, 1]; done outside autograd so the
        # in-place clamp on the parameter is not recorded in the graph.
        with torch.no_grad():
            self.alpha.clamp_(0, 1)
        return self.alpha * x + (1 - self.alpha) * mixed

In [77]:
class TemporalAttentionLayer(nn.Module):
    """Self-attention across frames with a learnable residual blend.

    Every spatial location is treated as its own sequence of ``n_frames``
    tokens whose features are the channels at that location. The attended
    result is blended with the input as ``alpha * q + (1 - alpha) * out``,
    where ``alpha`` is a learnable scalar initialized to 1.

    Args:
        dim: Feature size handed to the attention module as its input,
            query/key and value sizes. It is applied to the channel axis of
            the input.
        n_frames: Number of frames ``t`` folded into the batch axis.
        n_heads: Number of attention heads.
        kv_dim: Accepted but not used by this implementation.
    """

    def __init__(self, dim=64, n_frames=5, n_heads=8, kv_dim=None):
        super().__init__()
        self.n_frames = n_frames
        self.alpha = nn.Parameter(torch.ones(1))
        self.mha = MultiHeadAttention(dim, dim, dim, n_heads)

    def forward(self, q, kv=None, mask=None):
        """Apply temporal self-attention to ``q``.

        Args:
            q: Tensor of shape ``(b * n_frames, c, h, w)``.
            kv: Accepted but not used by this implementation.
            mask: Accepted but not used by this implementation.

        Returns:
            Tensor with the same layout as ``q``.
        """
        # Unpacking four values also rejects inputs that are not 4-D.
        _, _, height, width = q.shape

        # Move the spatial positions into the batch axis so attention runs
        # over the frame axis only.
        frame_tokens = rearrange(q, '(b t) c h w -> (b h w) t c', t=self.n_frames)
        attended = self.mha(frame_tokens)
        attended = rearrange(attended, '(b h w) t c -> (b t) c h w', h=height, w=width)

        # Keep the blend weight inside [0, 1] without recording the clamp.
        with torch.no_grad():
            self.alpha.clamp_(0, 1)
        return self.alpha * q + (1 - self.alpha) * attended

In [80]:
# `dim` is the attention feature size and must equal the channel count of `q`
# (64): the layer rearranges q to (b*h*w, t, c) before attending, so the last
# axis seen by the linear projections is c, not b*h*w.
t = TemporalAttentionLayer(dim=64, n_frames=5)
# bt, c, h, w
q = torch.randn(2*5,64,32,32)
out = t(q)

torch.Size([10, 64, 32, 32])
torch.Size([2048, 5, 64])
torch.Size([10, 64, 32, 32])
