In [3]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataclasses import dataclass   
from typing import List,Tuple,Optional,Callable

In [None]:
# The dataclass below stores per-block configuration such as the number of heads, channel sizes, and kernel/stride sizes.
# The typing annotations document what type of data each field should hold.


In [4]:
@dataclass
class MSBlockConfig:
    """Hyper-parameters describing a single multiscale block.

    Attributes:
        num_heads: Number of attention heads.
        input_channels: Channel width of the tokens entering the block.
        output_channels: Channel width of the tokens leaving the block.
        kernel_q: Pooling kernel applied to the queries.
        kernel_kv: Pooling kernel applied to the keys and values.
        stride_q: Pooling stride applied to the queries (also used for the
            skip-connection pooling in ``MultiscaleBlock``).
        stride_kv: Pooling stride applied to the keys and values.
    """

    num_heads: int
    input_channels: int
    output_channels: int
    kernel_q: List[int]
    kernel_kv: List[int]
    stride_q: List[int]
    stride_kv: List[int]
    
    

In [5]:
def _prod(s:List[int])->int:
  product=1
  for v in s:
   product*=v;
  return product

def _unsqueeze(x: torch.Tensor,target_dim:int,expand_dim:int):
    tensor_dim=x.dim()
    if tensor_dim==target_dim-1:
        x=x.unsqueeze(expand_dim)

    elif tensor_dim!=target_dim:
        raise ValueError(f"unsupported input dimension {x.shape}")
    return x,tensor_dim

def _squeeze(x:torch.tensor,target_dim:int,expand_dim:int,tensor_dim:int):
    if tensor_dim==target_dim-1:
        x=x.squeeze(expand_dim)
    return x   

    


In [None]:
## expand_dim is the index at which the new dimension is inserted.
## _squeeze and _unsqueeze are used inside the Pool block.
## They are paired operations: first we unsqueeze, then pool, then squeeze the tensor back.

In [6]:
class Pool(nn.Module):
    """Pools the non-class tokens of a token sequence with a 3-D pooling module.

    Args:
        pool: Module applied to the tokens after they are reshaped to
            ``(B * N, C, T, H, W)``; called with that single tensor.
        norm: Optional normalisation module.
        activation: Optional activation applied after ``norm``.
        norm_before_pool: If True, ``norm``/``activation`` run on the reshaped
            5-D tensor before pooling; otherwise they run on the token
            sequence after pooling.
    """

    def __init__(
        self,
        pool: nn.Module,
        norm: Optional[nn.Module],
        activation: Optional[nn.Module] = None,
        norm_before_pool: bool = False,
    ) -> None:
        super().__init__()
        self.pool = pool

        # Chain whichever of norm / activation were supplied, in that order.
        layers: List[nn.Module] = []
        if norm is not None:
            layers.append(norm)
        if activation is not None:
            layers.append(activation)
        self.norm_act: Optional[nn.Module] = nn.Sequential(*layers) if layers else None

        self.norm_before_pool = norm_before_pool

    def forward(
        self, x: torch.Tensor, thw: Tuple[int, int, int]
    ) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        """Pool the tokens of ``x`` and return them with the new (T, H, W).

        Args:
            x: 3-D or 4-D tensor; a 3-D input gets a singleton dimension
                inserted at index 1 and removed again before returning. After
                that step the token axis is dim 2 and its first entry is kept
                aside as the class token.
            thw: Temporal/height/width extent of the non-class tokens; their
                count must equal ``T * H * W`` for the reshape to succeed.

        Returns:
            ``(tokens, (T, H, W))`` where ``(T, H, W)`` is read from the pooled
            tensor's trailing three dimensions.
        """
        x, tensor_dim = _unsqueeze(x, 4, 1)

        # Keep the class token out of the pooling and lay the rest out as a volume.
        class_token, x = torch.tensor_split(x, indices=(1,), dim=2)
        x = x.transpose(2, 3)
        B, N, C = x.shape[:3]
        x = x.reshape((B * N, C) + tuple(thw))

        if self.norm_before_pool and self.norm_act is not None:
            x = self.norm_act(x)

        # Conv3d / MaxPool3d take one tensor; the new size comes from the output shape.
        x = self.pool(x)
        T, H, W = x.shape[2:]

        # Flatten back to a token sequence and re-attach the class token.
        x = x.reshape(B, N, C, -1).transpose(2, 3)
        x = torch.cat((class_token, x), dim=2)

        if not self.norm_before_pool and self.norm_act is not None:
            x = self.norm_act(x)

        x = _squeeze(x, 4, 1, tensor_dim)
        return x, (T, H, W)


In [None]:
## super().__init__() initializes the nn.Module base class.
## It has to run before any submodule is assigned as an attribute.

In [7]:
def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]):
    """Pooling forward pass written for ``Pool``.

    NOTE(review): this is a module-level function. It only acts as a method if
    it is placed in the ``Pool`` class body or bound explicitly, e.g.
    ``Pool.forward = forward``.

    Args:
        self: Object exposing ``pool``, ``norm_act`` and ``norm_before_pool``.
        x: 3-D or 4-D tensor; a 3-D input gets a singleton dimension inserted
            at index 1 and removed again before returning. After that step the
            token axis is dim 2 and its first entry is kept aside as the class
            token.
        thw: Temporal/height/width extent of the non-class tokens; their count
            must equal ``T * H * W`` for the reshape to succeed.

    Returns:
        ``(tokens, (T, H, W))`` where ``(T, H, W)`` is read from the pooled
        tensor's trailing three dimensions.
    """
    x, tensor_dim = _unsqueeze(x, 4, 1)

    # Keep the class token out of the pooling and lay the rest out as a volume.
    class_token, x = torch.tensor_split(x, indices=(1,), dim=2)
    x = x.transpose(2, 3)
    B, N, C = x.shape[:3]
    x = x.reshape((B * N, C) + tuple(thw))

    if self.norm_before_pool and self.norm_act is not None:
        x = self.norm_act(x)

    # Conv3d / MaxPool3d accept a single tensor and return a single tensor, so
    # the pooled (T, H, W) has to be read from the output shape.
    x = self.pool(x)
    T, H, W = x.shape[2:]

    # Flatten back to a token sequence and re-attach the class token.
    x = x.reshape(B, N, C, -1).transpose(2, 3)
    x = torch.cat((class_token, x), dim=2)

    if not self.norm_before_pool and self.norm_act is not None:
        x = self.norm_act(x)

    x = _squeeze(x, 4, 1, tensor_dim)

    return x, (T, H, W)

In [None]:
## Pool receives a 3-D or 4-D token tensor from attention and reshapes it into a 5-D (B*N, C, T, H, W) tensor so that 3-D pooling can be applied.


In [8]:
class MultiscaleAttention(nn.Module):
    """Sets up the layers of a pooled multi-head attention module.

    The constructor creates a fused QKV projection, an output projection
    (optionally followed by dropout), optional depthwise ``Conv3d`` pooling
    for the queries and for the keys/values, and optional relative positional
    tables for height, width and time.

    NOTE(review): no ``forward`` is defined in this cell — confirm it is
    provided elsewhere before this module is called.

    Args:
        input_size: ``[T, H, W]`` extent of the incoming tokens.
        embed_dim: Channel width of the input tokens.
        output_dim: Channel width of the attention output.
        num_heads: Number of heads; ``head_dim = output_dim // num_heads``.
        kernel_q: Pooling kernel for the queries.
        kernel_kv: Pooling kernel for the keys and values.
        stride_q: Pooling stride for the queries.
        stride_kv: Pooling stride for the keys and values.
        residual_pool: Stored on the instance unchanged.
        residual_with_cls_embed: Stored on the instance unchanged.
        rel_pos_embed: If True, allocate the relative positional tables.
        dropout: Dropout probability after the output projection.
        norm_layer: Factory for the normalisation used inside each pool.
    """

    def __init__(
        self,
        input_size: List[int],
        embed_dim: int,
        output_dim: int,
        num_heads: int,
        kernel_q: List[int],
        kernel_kv: List[int],
        stride_q: List[int],
        stride_kv: List[int],
        residual_pool: bool,
        residual_with_cls_embed: bool,
        rel_pos_embed: bool,
        dropout: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()

        # Plain bookkeeping attributes.
        self.embed_dim = embed_dim
        self.output_dim = output_dim
        self.num_heads = num_heads
        self.head_dim = output_dim // num_heads
        self.scaler = 1.0 / math.sqrt(self.head_dim)
        self.residual_pool = residual_pool
        self.residual_with_cls_embed = residual_with_cls_embed

        # Fused projection producing queries, keys and values in one matmul.
        self.qkv = nn.Linear(embed_dim, 3 * output_dim)

        # Output projection, with dropout only when it is actually enabled.
        projection: List[nn.Module] = [nn.Linear(output_dim, output_dim)]
        if dropout > 0.0:
            projection.append(nn.Dropout(dropout, inplace=True))
        self.project = nn.Sequential(*projection)

        def _depthwise_pool(kernel: List[int], stride: List[int]) -> nn.Module:
            # One filter per channel (groups == channels), "same"-style padding.
            conv = nn.Conv3d(
                self.head_dim,
                self.head_dim,
                kernel,
                stride=stride,
                padding=[k // 2 for k in kernel],
                groups=self.head_dim,
                bias=False,
            )
            return Pool(conv, norm_layer(self.head_dim))

        # Query pooling is only built when the kernel or stride is non-trivial.
        pool_queries = _prod(kernel_q) > 1 or _prod(stride_q) > 1
        self.pool_q: Optional[nn.Module] = (
            _depthwise_pool(kernel_q, stride_q) if pool_queries else None
        )

        # Keys and values share a configuration but get separate weights.
        pool_keys_values = _prod(kernel_kv) > 1 or _prod(stride_kv) > 1
        self.pool_k: Optional[nn.Module] = (
            _depthwise_pool(kernel_kv, stride_kv) if pool_keys_values else None
        )
        self.pool_v: Optional[nn.Module] = (
            _depthwise_pool(kernel_kv, stride_kv) if pool_keys_values else None
        )

        # Relative positional tables (left as None when disabled).
        if rel_pos_embed:
            longest_side = max(input_size[1:])
            q_extent = longest_side // stride_q[1] if len(stride_q) > 0 else longest_side
            kv_extent = longest_side // stride_kv[1] if len(stride_kv) > 0 else longest_side

            spatial_entries = 2 * max(q_extent, kv_extent) - 1
            temporal_entries = 2 * input_size[0] - 1

            self.rel_pos_h: Optional[nn.Parameter] = nn.Parameter(
                torch.zeros(spatial_entries, self.head_dim)
            )
            self.rel_pos_w: Optional[nn.Parameter] = nn.Parameter(
                torch.zeros(spatial_entries, self.head_dim)
            )
            self.rel_pos_t: Optional[nn.Parameter] = nn.Parameter(
                torch.zeros(temporal_entries, self.head_dim)
            )

            for table in (self.rel_pos_h, self.rel_pos_w, self.rel_pos_t):
                nn.init.trunc_normal_(table, std=0.02)
        else:
            self.rel_pos_h = None
            self.rel_pos_w = None
            self.rel_pos_t = None


In [9]:
class MultiscaleBlock(nn.Module):
    """Transformer block: norm -> pooled attention -> residual -> norm -> MLP -> residual.

    NOTE(review): ``MLP`` and ``StochasticDepth`` are not imported in the
    visible cells — presumably ``torchvision.ops``; confirm they are in scope
    before instantiating this class.

    Args:
        input_size: ``[T, H, W]`` extent of the tokens entering the block.
        cnf: Per-block configuration (heads, channels, kernels, strides).
        residual_pool: Forwarded to ``MultiscaleAttention``.
        residual_with_cls_embed: Forwarded to ``MultiscaleAttention``.
        rel_pos_embed: Forwarded to ``MultiscaleAttention``.
        proj_after_attn: If True the attention already outputs
            ``cnf.output_channels`` and the residual is projected before the
            first addition; otherwise the projection happens on the second
            residual, before the MLP output is added.
        dropout: Dropout probability for attention and MLP.
        stochastic_depth_prob: Drop probability passed to ``StochasticDepth``.
        norm_layer: Factory for the two normalisation layers.
    """

    def __init__(
        self,
        input_size: List[int],
        cnf: MSBlockConfig,
        residual_pool: bool,
        residual_with_cls_embed: bool,
        rel_pos_embed: bool,
        proj_after_attn: bool,
        dropout: float = 0.0,
        stochastic_depth_prob: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()

        self.proj_after_attn = proj_after_attn

        # Skip-connection pooling: needed only when the queries are strided, so
        # that the residual has as many tokens as the attention output.
        self.pool_skip: Optional[nn.Module] = None
        if _prod(cnf.stride_q) > 1:
            kernel_skip = [s + 1 if s > 1 else s for s in cnf.stride_q]
            padding_skip = [k // 2 for k in kernel_skip]
            self.pool_skip = Pool(
                nn.MaxPool3d(kernel_skip, stride=cnf.stride_q, padding=padding_skip),
                None,
            )

        attn_dim = cnf.output_channels if proj_after_attn else cnf.input_channels

        self.norm1 = norm_layer(cnf.input_channels)
        self.norm2 = norm_layer(attn_dim)

        # BatchNorm1d normalises dim 1, so forward() swaps dims 1 and 2 around it.
        self.needs_transposal = isinstance(self.norm1, nn.BatchNorm1d)

        self.attn = MultiscaleAttention(
            input_size,
            cnf.input_channels,
            attn_dim,
            cnf.num_heads,
            kernel_q=cnf.kernel_q,
            kernel_kv=cnf.kernel_kv,
            stride_q=cnf.stride_q,
            stride_kv=cnf.stride_kv,
            residual_pool=residual_pool,
            residual_with_cls_embed=residual_with_cls_embed,
            rel_pos_embed=rel_pos_embed,
            dropout=dropout,
            norm_layer=norm_layer,
        )

        self.mlp = MLP(
            attn_dim,
            [4 * attn_dim, cnf.output_channels],
            activation_layer=nn.GELU,
            dropout=dropout,
            inplace=None,
        )

        self.stochastic_depth = StochasticDepth(stochastic_depth_prob, "row")

        # Residual projection, only when the channel width changes.
        self.project: Optional[nn.Module] = None
        if cnf.input_channels != cnf.output_channels:
            self.project = nn.Linear(cnf.input_channels, cnf.output_channels)

    def forward(
        self, x: torch.Tensor, thw: Tuple[int, int, int]
    ) -> Tuple[torch.Tensor, Tuple[int, int, int]]:
        """Run the block and return ``(tokens, thw_new)``.

        Args:
            x: Input tokens (assumed ``(batch, tokens, channels)`` — confirm
                against the caller).
            thw: Temporal/height/width extent of the non-class tokens.

        Returns:
            The block output and the (T, H, W) reported by the attention layer.
        """
        # --- attention sub-layer ---
        if self.needs_transposal:
            x_norm1 = self.norm1(x.transpose(1, 2)).transpose(1, 2)
        else:
            x_norm1 = self.norm1(x)

        x_attn, thw_new = self.attn(x_norm1, thw)

        # The residual must have the same channel width as x_attn, so project
        # it first when the attention output is already widened.
        if self.project is not None and self.proj_after_attn:
            x = self.project(x_norm1)

        # Pool returns (tensor, thw); [0] keeps just the tensor.
        x_skip = x if self.pool_skip is None else self.pool_skip(x, thw)[0]
        x = x_skip + self.stochastic_depth(x_attn)

        # --- MLP sub-layer ---
        if self.needs_transposal:
            x_norm2 = self.norm2(x.transpose(1, 2)).transpose(1, 2)
        else:
            x_norm2 = self.norm2(x)

        # When the widening was deferred, project the residual to the MLP's width.
        if self.project is not None and not self.proj_after_attn:
            x_res = self.project(x_norm2)
        else:
            x_res = x

        return x_res + self.stochastic_depth(self.mlp(x_norm2)), thw_new


In [10]:
def forward(self, x: torch.Tensor, thw: Tuple[int, int, int]):
    """Block forward pass written for ``MultiscaleBlock``.

    NOTE(review): this is a module-level function. It only acts as a method if
    it is placed in the ``MultiscaleBlock`` class body or bound explicitly,
    e.g. ``MultiscaleBlock.forward = forward``.

    Args:
        self: Object exposing ``norm1``, ``norm2``, ``needs_transposal``,
            ``attn``, ``project``, ``proj_after_attn``, ``pool_skip``,
            ``stochastic_depth`` and ``mlp``.
        x: Input tokens (assumed ``(batch, tokens, channels)`` — confirm
            against the caller).
        thw: Temporal/height/width extent passed on to attention and pooling.

    Returns:
        ``(tokens, thw_new)`` where ``thw_new`` is whatever ``self.attn``
        reports as the new extent.
    """
    # --- attention sub-layer ---
    if self.needs_transposal:
        # The norm expects channels in dim 1, so swap dims 1 and 2 around it.
        x_norm1 = self.norm1(x.transpose(1, 2)).transpose(1, 2)
    else:
        x_norm1 = self.norm1(x)

    x_attn, thw_new = self.attn(x_norm1, thw)

    # The residual has to match x_attn's channel width, so project it first
    # when the attention output is already widened.
    if self.project is not None and self.proj_after_attn:
        x_res = self.project(x_norm1)
    else:
        x_res = x

    # Pool the (possibly projected) residual so its token count matches x_attn.
    # pool_skip returns (tensor, thw); [0] keeps just the tensor.
    if self.pool_skip is not None:
        x_skip = self.pool_skip(x_res, thw)[0]
    else:
        x_skip = x_res

    x = x_skip + self.stochastic_depth(x_attn)

    # --- MLP sub-layer ---
    if self.needs_transposal:
        x_norm2 = self.norm2(x.transpose(1, 2)).transpose(1, 2)
    else:
        x_norm2 = self.norm2(x)

    # When the widening was deferred, project the residual to the MLP's width.
    if self.project is not None and not self.proj_after_attn:
        x_res = self.project(x_norm2)
    else:
        x_res = x

    x = x_res + self.stochastic_depth(self.mlp(x_norm2))

    return x, thw_new


In [11]:
class PositionalEncoding(nn.Module):
    """Prepends a learnable class token and optionally adds absolute positions.

    When ``rel_pos_embed`` is False the module owns separable spatial and
    temporal position tables plus a position vector for the class token, and
    adds their combination to the tokens. When it is True no absolute position
    parameters are created and only the class token is prepended.

    Args:
        embed_size: Channel width of the tokens.
        spatial_size: ``(H, W)`` grid size of the patch tokens.
        temporal_size: Number of temporal positions ``T``.
        rel_pos_embed: True to skip the absolute positional parameters.
    """

    def __init__(self, embed_size: int, spatial_size: Tuple[int, int], temporal_size: int, rel_pos_embed: bool) -> None:
        super().__init__()

        self.spatial_size = spatial_size
        self.temporal_size = temporal_size

        # Learnable class token, one vector shared by every sample.
        self.class_token = nn.Parameter(torch.zeros(embed_size))

        # Absolute position tables exist only when relative ones are not used.
        use_absolute = not rel_pos_embed
        num_spatial = spatial_size[0] * spatial_size[1]
        self.spatial_pos: Optional[nn.Parameter] = (
            nn.Parameter(torch.zeros(num_spatial, embed_size)) if use_absolute else None
        )
        self.temporal_pos: Optional[nn.Parameter] = (
            nn.Parameter(torch.zeros(temporal_size, embed_size)) if use_absolute else None
        )
        self.class_pos: Optional[nn.Parameter] = (
            nn.Parameter(torch.zeros(embed_size)) if use_absolute else None
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` with a class token at index 0 of dim 1, plus positions if enabled."""
        batch = x.size(0)
        cls = self.class_token.expand(batch, -1).unsqueeze(1)  # (B, 1, C)
        tokens = torch.cat((cls, x), dim=1)

        # Nothing more to do when absolute positions are disabled.
        if self.spatial_pos is None or self.temporal_pos is None or self.class_pos is None:
            return tokens

        hw_size, embed_size = self.spatial_pos.shape

        # Each temporal row is repeated for every spatial location ...
        temporal_part = torch.repeat_interleave(self.temporal_pos, hw_size, dim=0)
        # ... while the whole spatial table is tiled once per temporal step.
        spatial_part = (
            self.spatial_pos.unsqueeze(0)
            .expand(self.temporal_size, -1, -1)
            .reshape(-1, embed_size)
        )
        patch_pos = temporal_part + spatial_part

        # Put the class-token position first so it lines up with `tokens`.
        pos_embedding = torch.cat((self.class_pos.unsqueeze(0), patch_pos), dim=0).unsqueeze(0)

        return tokens + pos_embedding


In [22]:
class MViT(nn.Module):
    """Video classifier: Conv3d patch embedding -> multiscale blocks -> linear head.

    Args:
        input_channels: Number of channels of the input video.
        input_size: ``[T, H, W]`` size of the input video.
        embed_dim: Channel width produced by the patch embedding.
        block_setting: One ``MSBlockConfig`` per multiscale block, in order.
        num_classes: Number of output classes.
        rel_pos_embed: Use relative positional tables inside attention instead
            of absolute positional embeddings on the tokens.
        dropout: Dropout probability for the blocks and the head.
        norm_layer: Factory for the normalisation layers.
    """

    def __init__(
        self,
        input_channels: int,
        input_size: List[int],
        embed_dim: int,
        block_setting: List[MSBlockConfig],
        num_classes: int,
        rel_pos_embed: bool = True,
        dropout: float = 0.0,
        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
    ) -> None:
        super().__init__()

        # Patch-embedding geometry, kept in one place so the token-grid sizes
        # below cannot drift from the convolution that produces them.
        patch_kernel = (3, 4, 4)
        patch_stride = (2, 4, 4)
        patch_padding = (1, 1, 1)

        self.patch_embed = nn.Conv3d(
            in_channels=input_channels,
            out_channels=embed_dim,
            kernel_size=patch_kernel,
            stride=patch_stride,
            padding=patch_padding,
            bias=False,
        )

        T, H, W = input_size

        # Number of tokens along T, H and W after the patch embedding
        # (standard convolution output-size formula). Their product is the
        # number of patch tokens.
        self.embed_T, self.embed_H, self.embed_W = (
            (size + 2 * pad - kernel) // stride + 1
            for size, kernel, stride, pad in zip((T, H, W), patch_kernel, patch_stride, patch_padding)
        )
        self.embed_dim = embed_dim

        self.pos_encoding = PositionalEncoding(
            embed_size=embed_dim,
            spatial_size=(self.embed_H, self.embed_W),
            temporal_size=self.embed_T,
            rel_pos_embed=rel_pos_embed,
        )

        self.blocks = nn.ModuleList()
        current_size = [self.embed_T, self.embed_H, self.embed_W]

        for cnf in block_setting:
            block = MultiscaleBlock(
                input_size=current_size,
                cnf=cnf,
                residual_pool=True,
                residual_with_cls_embed=True,
                rel_pos_embed=rel_pos_embed,
                proj_after_attn=True,
                dropout=dropout,
                stochastic_depth_prob=0.0,
                norm_layer=norm_layer,
            )
            self.blocks.append(block)

            # Track how query pooling shrinks the token grid for the next block.
            t, h, w = current_size
            if len(cnf.stride_q) > 0:
                t = t // cnf.stride_q[0]
                h = h // cnf.stride_q[1]
                w = w // cnf.stride_q[2]
            current_size = [t, h, w]

        self.final_size = current_size

        # The blocks may change the channel width, so the final norm and head
        # are sized from the last block's output rather than from embed_dim.
        final_dim = block_setting[-1].output_channels if len(block_setting) > 0 else embed_dim

        self.norm = norm_layer(final_dim)
        self.head = nn.Sequential(
            nn.Dropout(dropout),
            nn.Linear(final_dim, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Classify a batch of videos.

        Args:
            x: Input of shape ``(B, input_channels, T, H, W)``.

        Returns:
            Logits of shape ``(B, num_classes)``.
        """
        # Convert the video into patch embeddings: (B, C, T', H', W').
        x = self.patch_embed(x)

        # Flatten the grid into a token sequence: (B, T'*H'*W', C).
        B, C, T, H, W = x.shape
        x = x.reshape(B, C, T * H * W).transpose(1, 2)

        # Prepend the class token (and add absolute positions if enabled).
        x = self.pos_encoding(x)

        thw = (T, H, W)
        for block in self.blocks:
            x, thw = block(x, thw)

        x = self.norm(x)

        # Classification uses the class token only.
        cls_token = x[:, 0]

        return self.head(cls_token)
        
            
