In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
class TimeSFormerConfig:
  """Hyper-parameters for the TimeSFormer model defined in this notebook.

  Every argument is optional; calling ``TimeSFormerConfig()`` with no
  arguments yields the same values the notebook originally hard-coded.

  Args:
    in_channels: number of input channels fed to the patch-projection conv.
    embed_dim: token width (conv output channels, LayerNorm / attention size).
    img_size: frame side length in pixels; only used to derive ``num_patch``.
    patch_size: side length of a patch (conv kernel size and stride).
    dropout: dropout probability shared by all dropout / attention layers.
    num_heads: number of heads in each multi-head attention layer.
    num_frames: number of frames covered by the temporal positional embedding.
    transformer_units: how many DivideSpaceTimeBlock layers are stacked.
    num_class: number of outputs of the classification head.

  Attributes:
    num_patch: derived value ``(img_size // patch_size) ** 2``.
  """
  def __init__(self,in_channels=3,embed_dim=768,img_size=256,patch_size=32,
               dropout=0.1,num_heads=4,num_frames=10,transformer_units=1,num_class=10):
    self.in_channels=in_channels
    self.embed_dim=embed_dim
    self.img_size=img_size
    self.patch_size=patch_size
    # Derived, not configurable: patches per frame for a square image.
    self.num_patch=(self.img_size//self.patch_size)**2
    self.dropout=dropout
    self.num_heads=num_heads
    self.num_frames=num_frames
    self.transformer_units=transformer_units
    self.num_class=num_class

In [None]:
class PatchEmbedding(nn.Module):
  """Embed every frame of a clip as a sequence of non-overlapping patch tokens.

  Input:  x of shape (B, C, T, H, W).
  Output: tensor of shape (B, T, N, E_D), where N is the number of patches
          per frame and E_D is ``config.embed_dim``. Token ``[b, t, n]`` is the
          projection of the n-th patch (row-major) of frame t of clip b.
  """
  # The annotation is a forward reference (string) so defining this class does
  # not require the config class to already exist in the kernel.
  def __init__(self,config:"TimeSFormerConfig"):
    super().__init__()
    # kernel_size == stride: one conv output position per non-overlapping patch.
    self.proj=nn.Conv2d(in_channels=config.in_channels,out_channels=config.embed_dim,kernel_size=config.patch_size,stride=config.patch_size)
  def forward(self,x):
    B,C,T,H,W=x.shape
    # Fold time into the batch axis so the 2-D conv sees individual frames.
    x=x.transpose(1,2).reshape(B*T,C,H,W)  #B*T,C,H,W
    x=self.proj(x).flatten(2)   #B*T,E_D,N
    # Move channels last BEFORE reshaping. Reshaping (B*T,E_D,N) directly to
    # (B,T,N,E_D) would reinterpret memory and mix channel and patch values.
    x=x.transpose(1,2)   #B*T,N,E_D
    _,N,E_D=x.shape
    x=x.reshape(B,T,N,E_D)
    return x


In [None]:
class FF(nn.Module):
  """Feed-forward sub-layer: Linear -> GELU -> Dropout -> Linear -> Dropout.

  The hidden layer is four times as wide as ``config.embed_dim``; input and
  output both have ``config.embed_dim`` features on the last axis.
  """
  def __init__(self,config:TimeSFormerConfig):
    super().__init__()
    width=config.embed_dim
    hidden_width=4*width
    self.fc1=nn.Linear(width,hidden_width)
    self.fc2=nn.Linear(hidden_width,width)
    self.dropout=nn.Dropout(config.dropout)
  def forward(self,x):
    # Expand, apply the non-linearity, then regularise.
    hidden=F.gelu(self.fc1(x))
    hidden=self.dropout(hidden)
    # Project back to the model width and regularise once more.
    projected=self.fc2(hidden)
    return self.dropout(projected)


In [None]:
"""For one block: apply temporal attention then spatial attention (or vice versa).
   Temporal attention: attend across T for each patch (shape: B*N, T, D)
   Spatial attention: attend across patches for each frame (shape: B*T, N, D)
"""

'For one block: apply temporal attention then spatial attention (or vice versa).\n   Temporal attention: attend across T for each patch (shape: B*N, T, D)\n   Spatial attention: attend across patches for each frame (shape: B*T, N, D)\n'

In [None]:
class DivideSpaceTimeBlock(nn.Module):
  """One transformer block with divided space-time attention.

  Three pre-norm residual sub-layers are applied in order:
    1. temporal attention - each patch attends over the T frames,
    2. spatial attention  - within each frame, the patches plus a prepended
       CLS token attend to one another,
    3. feed-forward network on every patch token.

  forward(x, cls_token):
    x:         patch tokens of shape (B, T, N, D).
    cls_token: classification token of shape (B, D); it is copied to every frame.
  Returns:
    (x, cls_token) where x has shape (B, T, N, D) and cls_token holds one
    updated CLS vector per frame, shape (B, T, D).
  """
  def __init__(self,config:TimeSFormerConfig):
    super().__init__()
    self.ln1=nn.LayerNorm(config.embed_dim)
    self.ln2=nn.LayerNorm(config.embed_dim)
    self.ln3=nn.LayerNorm(config.embed_dim)
    self.temp_attn=nn.MultiheadAttention(embed_dim=config.embed_dim,num_heads=config.num_heads,dropout=config.dropout,batch_first=True)
    self.spat_attn=nn.MultiheadAttention(embed_dim=config.embed_dim,num_heads=config.num_heads,dropout=config.dropout,batch_first=True)
    self.ff=FF(config)
    self.dropout=nn.Dropout(config.dropout)
  def forward(self,x,cls_token):
    B,T,N,D=x.shape

    #Temporal Attn: sequences run over frames, one sequence per (clip, patch).
    xt=x.transpose(1,2).reshape(B*N,T,D)
    xt_norm=self.ln1(xt)
    # Attention weights are never used, so skip computing them.
    temp_attn_out,_=self.temp_attn(xt_norm,xt_norm,xt_norm,need_weights=False)
    # Single residual: x + dropout(attn). The attention output must not be
    # added a second time on top of its dropped-out copy.
    xt=self.dropout(temp_attn_out).view(B,N,T,D).transpose(1,2)
    x=x+xt

    #Spatial Attn: sequences run over patches, one sequence per (clip, frame).
    # reshape (not view): x may be non-contiguous after the addition above.
    xs=x.reshape(B*T,N,D)
    # Give every frame of a clip its own copy of that clip's CLS token.
    cls_rep=cls_token.unsqueeze(1).expand(B,T,D).reshape(B*T,1,D)
    seq=torch.cat((cls_rep,xs),dim=1)
    seq_norm=self.ln2(seq)
    attn_spat_out,_=self.spat_attn(seq_norm,seq_norm,seq_norm,need_weights=False)
    seq=seq+self.dropout(attn_spat_out)

    cls_token=seq[:,0,:].reshape(B,T,D)
    # `seq` already carries the residual for the patch tokens (it was built
    # from x), so x is NOT added again here.
    x=seq[:,1:,:].reshape(B,T,N,D)

    #Feed-forward: pre-norm input, residual connection around the FF.
    x=x+self.ff(self.ln3(x))

    return x,cls_token



In [None]:
class TimeSFormer(nn.Module):
  """Video classifier built from patch embedding and divided space-time blocks.

  Pipeline: PatchEmbedding -> add temporal positional embedding -> dropout ->
  ``config.transformer_units`` DivideSpaceTimeBlock layers -> LayerNorm on the
  CLS token -> linear classification head.

  forward(x):
    x: video tensor of shape (B, C, T, H, W) (the layout PatchEmbedding unpacks).
  Returns:
    Logits of shape (B, config.num_class).
  """
  def __init__(self,config:TimeSFormerConfig):
    super().__init__()
    self.patch_embed=PatchEmbedding(config)
    # Learnable CLS token shared by all clips, shape (1, embed_dim).
    self.cls_token=nn.Parameter(torch.randn(1,config.embed_dim))
    # Learnable per-frame (temporal) positional embedding, shape (1, num_frames, embed_dim).
    # NOTE(review): there is no spatial positional embedding, so patch tokens
    # carry no information about where in the frame they came from - confirm
    # whether that is intended.
    self.pos_emb=nn.Parameter(torch.randn(1,config.num_frames,config.embed_dim))
    self.dropout=nn.Dropout(config.dropout)
    self.blocks=nn.ModuleList([DivideSpaceTimeBlock(config) for _ in range(config.transformer_units)])
    self.ln=nn.LayerNorm(config.embed_dim)
    self.head=nn.Linear(config.embed_dim,config.num_class)
    self.__init_weights()

  def __init_weights(self):
    """Re-initialise the positional embedding and CLS token (truncated normal, std 0.2)."""
    nn.init.trunc_normal_(self.pos_emb,std=0.2)
    nn.init.trunc_normal_(self.cls_token,std=0.2)


  def forward(self,x):
    x=self.patch_embed(x)
    B,T,N,D=x.shape
    cls_token=self.cls_token.expand(B,D)

    # Use only the first T frame positions so a clip shorter than num_frames
    # is not silently broadcast up to num_frames frames; unsqueeze(2) shares
    # the same temporal embedding across all N patches of a frame.
    x=x+self.pos_emb[:,:T].unsqueeze(2)
    x=self.dropout(x)

    for block in self.blocks:
      x,cls_per_frame=block(x,cls_token)
      # Collapse the per-frame CLS vectors (B,T,D) into one per clip (B,D).
      cls_token=cls_per_frame.mean(dim=1)


    cls=self.ln(cls_token)
    out=self.head(cls)
    return out


In [None]:
# Smoke test: push one random batch through a default-configured model and
# display the shape of the result.
config=TimeSFormerConfig()
model=TimeSFormer(config)
# Dimensions are (batch, channels, frames, height, width).
t=torch.randn(32,3,10,256,256)
out=model(t)
out.shape

torch.Size([32, 768])


torch.Size([32, 10])