In [28]:
import torch
from torch import nn
import numpy as np
from typing import Callable, List, Optional
from timm.models.layers import DropPath

In [75]:
class Attention(nn.Module):
    """Multi-head self-attention over a batch-first token sequence.

    The input is projected to queries, keys and values with a single linear
    layer, scaled dot-product attention is computed independently per head,
    the heads are concatenated and then mixed by the output projection.

    Args:
        dim: Channel size ``C`` of the input tokens; must be divisible by
            ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused q/k/v projection has a bias term.
        qk_scale: Optional override for the attention scale; when falsy the
            scale is ``head_dim ** -0.5``.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.
        attn_mask: Accepted only to keep the constructor signature stable; it
            is not used. Pass the mask to ``forward`` instead.

    Raises:
        ValueError: If ``dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads: int = 8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, attn_mask=None):
        super().__init__()

        # Fail at construction time rather than with an opaque reshape error in forward().
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")

        self.num_heads = num_heads
        head_dim = dim // num_heads

        self.scale = qk_scale or head_dim**-0.5

        self.qkv = nn.Linear(in_features=dim, out_features=dim*3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x, attn_mask=None):
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, N, C)`` (batch, tokens, channels).
            attn_mask: Optional mask broadcastable to ``(B, num_heads, N, N)``.
                A floating-point mask is added to the attention scores (use
                ``-inf`` to block a position); a boolean mask blocks the
                positions where it is ``True``.

        Returns:
            Tensor of shape ``(B, N, C)``.
        """
        B, N, C = x.shape # B = Batch size, N = Number of Tokens, C = Channels or Feature Dimensions
        head_dim = C // self.num_heads

        # (B, N, 3C) -> (B, N, 3, num_heads, head_dim) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2] # each (B, num_heads, N, head_dim)

        attn = (q @ k.transpose(-2, -1)) * self.scale # Attention scores: (B, num_heads, N, N)
        if attn_mask is not None:
            if attn_mask.dtype == torch.bool:
                attn = attn.masked_fill(attn_mask, float("-inf"))
            else:
                attn = attn + attn_mask
        attn = attn.softmax(dim=-1) # Softmax across key tokens (last dim, N)
        attn = self.attn_drop(attn)

        # (B, num_heads, N, head_dim) -> (B, N, num_heads, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # Mix information across heads; without this the `proj` weights are dead parameters.
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
    
# Smoke test: run a random batch of 6 sequences of 49 tokens with 768 channels
# through the layer and display the output shape.
attention = Attention(dim = 768)
x = torch.rand(6, 49, 768)
attention(x).shape

torch.Size([6, 49, 768])

In [3]:
class MLP(nn.Module):
    """Two-layer feed-forward network: linear -> activation -> linear -> dropout.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output; falls back to
            ``in_features`` when falsy.
        hidden_features: Width of the hidden layer; falls back to
            ``in_features`` when falsy.
        act_layer: Zero-argument constructor for the activation module.
        drop: Dropout probability applied to the output of the second linear layer.
    """

    def __init__(self, in_features, out_features = None, hidden_features = None, act_layer = nn.GELU, drop = 0.0):
        super().__init__()

        # Any falsy value (None or 0) selects the input width.
        hidden_width = hidden_features if hidden_features else in_features
        output_width = out_features if out_features else in_features

        self.fc1 = nn.Linear(in_features, hidden_width)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_width, output_width)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        """Map ``(..., in_features)`` to ``(..., out_features)``."""
        hidden = self.act(self.fc1(x))
        return self.drop(self.fc2(hidden))

# Smoke test: a batch of 3 vectors with 224 features goes through a 5-unit
# hidden layer and comes out with 200 features.
x = torch.rand(3, 224)
mlp = MLP(in_features=224, hidden_features=5, out_features=200)
mlp(x).shape

torch.Size([3, 200])

In [81]:
class MultiheadAttention(nn.MultiheadAttention):
    """`nn.MultiheadAttention` specialised to self-attention.

    ``forward`` takes a single tensor that serves as query, key and value, and
    returns only the attended output (the attention weights are discarded).
    """

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor):
        """Return the self-attention output for ``x`` under ``attn_mask``."""
        attn_output, _attn_weights = super().forward(query=x, key=x, value=x, attn_mask=attn_mask)
        return attn_output

# Smoke test for the self-attention wrapper. The input is laid out as
# (tokens, batch, channels) = (49, 5, 768), the (N, B, C) order the wrapper is fed elsewhere in this notebook.
attn_layer = MultiheadAttention(embed_dim=768, num_heads=8)
x = torch.rand(49, 5, 768)

# Build an additive float mask: entries strictly above the diagonal become
# -inf, every other entry becomes 0.
attn_mask = torch.triu(torch.ones(49, 49), diagonal=1)
attn_mask = attn_mask.masked_fill(attn_mask == 1, float('-inf'))
attn_mask = attn_mask.masked_fill(attn_mask == 0, float(0.0))

output = attn_layer(x, attn_mask)
output.shape


torch.Size([49, 5, 768])

In [None]:
class BlockWithMasking(nn.Module):
    """Pre-norm transformer block with masked attention, stochastic depth and optional layer scale.

    The block computes::

        x = x + drop_path(attn(norm1(x), attn_mask)) * gamma1
        x = x + drop_path(mlp(norm2(x)))             * gamma2

    where ``gamma1``/``gamma2`` are learnable scales that are only applied
    when ``layer_scale_type`` is set.

    Args:
        dim: Channel size of the tokens.
        attn_target: Either an ``nn.Module`` instance used directly as the
            attention layer, or a zero-argument callable that builds one.
            Passing a factory avoids sharing one attention module between
            several blocks.
        mlp_ratio: Hidden width of the MLP as a multiple of ``dim``.
        act_layer: Constructor for the MLP activation.
        norm_layer: Constructor for the two normalisation layers; called with ``dim``.
        ffn_dropout_rate: Dropout probability inside the MLP.
        drop_path: Stochastic-depth probability; ``0.0`` disables it.
        layer_scale_type: ``None``, ``"per_channel"`` or ``"scalar"``.
        layer_scale_init_value: Initial value of the layer-scale parameters.
    """

    def __init__(self, dim, attn_target: Callable, mlp_ratio: int = 4, act_layer: Callable = nn.GELU, 
                 norm_layer: Callable = nn.LayerNorm, ffn_dropout_rate: float = 0.0, drop_path: float = 0.0, 
                 layer_scale_type = None, layer_scale_init_value: float = 1e-4):

        super().__init__()

        # Validate the argument itself rather than an unrelated notebook-global.
        # A non-module callable is treated as a factory and invoked once here.
        if not isinstance(attn_target, nn.Module) and callable(attn_target):
            attn_target = attn_target()
        assert isinstance(attn_target, nn.Module), "attn_target should be an nn.Module or a Callable that builds one."

        self.attn = attn_target
        self.drop_path = DropPath(drop_prob=drop_path) if drop_path > 0.0 else nn.Identity()

        self.norm1 = norm_layer(dim)
        mlp_hidden_dim = int(mlp_ratio * dim)
        self.mlp = MLP(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            out_features=dim,
            act_layer=act_layer,
            drop=ffn_dropout_rate,
        )

        self.norm2 = norm_layer(dim)

        self.layer_scale_type = layer_scale_type
        if self.layer_scale_type is not None:
            assert self.layer_scale_type in ["per_channel", "scalar"], f"Found layer_scale_type to be {self.layer_scale_type}; should be either `per_channel`, `scalar`"
            # One scale per channel, or a single scalar broadcast over (B, N, C).
            gamma_shape = [1, 1, dim] if self.layer_scale_type == "per_channel" else [1, 1, 1]
            self.layer_scale_gamma1 = nn.Parameter(
                torch.ones(size=gamma_shape) * layer_scale_init_value,
                requires_grad=True,
            )
            self.layer_scale_gamma2 = nn.Parameter(
                torch.ones(size=gamma_shape) * layer_scale_init_value,
                requires_grad=True,
            )
        else:
            # Placeholders only; they are never multiplied with a tensor.
            self.layer_scale_gamma1 = nn.Identity()
            self.layer_scale_gamma2 = nn.Identity()

    def _attend(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        """Run the attention layer on a batch-first ``(B, N, C)`` tensor."""
        # torch's MultiheadAttention processes (N, B, C) unless built with batch_first=True.
        if isinstance(self.attn, MultiheadAttention) and not getattr(self.attn, "batch_first", False):
            return self.attn(x.permute(1, 0, 2), attn_mask).permute(1, 0, 2)
        return self.attn(x, attn_mask)

    def _layer_scale(self, x: torch.Tensor, gamma) -> torch.Tensor:
        """Multiply by the layer-scale parameter, or pass through when layer scale is off."""
        if self.layer_scale_type is None:
            return x
        return x * gamma

    def forward(self, x: torch.Tensor, attn_mask:torch.Tensor):
        """Apply the block to ``x`` of shape ``(B, N, C)`` and return a tensor of the same shape."""
        # Attention sub-layer: the un-normalised input is kept as the residual.
        attn_out = self._attend(self.norm1(x), attn_mask)
        x = x + self._layer_scale(self.drop_path(attn_out), self.layer_scale_gamma1)

        # Feed-forward sub-layer with its own residual connection.
        mlp_out = self.mlp(self.norm2(x))
        x = x + self._layer_scale(self.drop_path(mlp_out), self.layer_scale_gamma2)

        return x

# Smoke test: push a random batch-first (batch, tokens, channels) batch through one block.
dim = 768
num_heads = 8
tokens = 49
batch_size = 2
x = torch.rand(batch_size, tokens, dim)
attn_target = MultiheadAttention(embed_dim=dim, num_heads=num_heads)

blockwithmaskig = BlockWithMasking(dim = dim, attn_target=attn_target, layer_scale_type="per_channel")

# NOTE(review): this rebinds the notebook-global `attn_layer`, but nothing in
# the rest of this cell passes it to the block — confirm whether it is still needed.
attn_layer = MultiheadAttention(embed_dim = dim, num_heads = num_heads)
# Additive float mask: -inf strictly above the diagonal, 0 elsewhere.
attn_mask = torch.triu(torch.ones(tokens, tokens), diagonal=1)
attn_mask = attn_mask.masked_fill(attn_mask == 1, float('-inf'))
attn_mask = attn_mask.masked_fill(attn_mask == 0, float(0.0))

output = blockwithmaskig(x, attn_mask)
output.shape


torch.Size([2, 49, 768])

In [103]:
# Swap the first two axes of `output` and display the resulting shape.
# NOTE(review): the recorded output equals the un-permuted shape shown for the
# previous cell, so it was presumably produced in a different kernel state — re-run to confirm.
output.permute(1,0,2).shape

torch.Size([2, 49, 768])

In [84]:
# Tiny example whose values can be inspected by eye: 4 tokens, batch of 1,
# 8 channels split over 8 heads (one channel per head).
attn_layer = MultiheadAttention(embed_dim=8, num_heads=8)
x = torch.rand(4, 1, 8)

# Additive float mask: -inf strictly above the diagonal, 0 elsewhere.
attn_mask = torch.triu(torch.ones(4, 4), diagonal=1)
attn_mask = attn_mask.masked_fill(attn_mask == 1, float('-inf'))
attn_mask = attn_mask.masked_fill(attn_mask == 0, float(0.0))
output = attn_layer(x, attn_mask)
output

tensor([[[ 0.2907, -0.1710,  0.0714,  0.2417, -0.2110, -0.4061, -0.2756,
           0.0812]],

        [[ 0.1078, -0.1417,  0.0392,  0.2496, -0.2377, -0.4312, -0.1947,
           0.0795]],

        [[ 0.0292, -0.2555, -0.0110,  0.2481, -0.2067, -0.4454, -0.1864,
           0.0585]],

        [[ 0.0208, -0.2861,  0.0100,  0.2736, -0.2639, -0.4570, -0.2389,
           0.0675]]], grad_fn=<ViewBackward0>)