# In [3]:
import torch
from torch import nn

# In [35]:
class Attention(nn.Module):
    """Multi-head self-attention over a sequence of token embeddings.

    Args:
        dim: Channel size of each token; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        qkv_bias: Whether the fused query/key/value projection has a bias.
        qk_scale: Optional override for the attention-score scale. When
            ``None`` the standard ``head_dim ** -0.5`` is used.
        attn_drop: Dropout probability applied to the attention weights.
        proj_drop: Dropout probability applied after the output projection.

    Raises:
        ValueError: If ``num_heads`` is not positive or ``dim`` is not
            divisible by ``num_heads``.
    """

    def __init__(self, dim, num_heads: int = 8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0):
        super().__init__()
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if dim % num_heads != 0:
            # Fail at construction instead of with an opaque reshape error in forward().
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")

        self.num_heads = num_heads
        head_dim = dim // num_heads

        # Compare against None so an explicit qk_scale (even 0.0) is honoured.
        self.scale = qk_scale if qk_scale is not None else head_dim**-0.5

        self.qkv = nn.Linear(in_features=dim, out_features=dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, N, C) and return a tensor of the same shape.

        B = batch size, N = number of tokens, C = channels (must equal ``dim``).
        """
        B, N, C = x.shape
        head_dim = C // self.num_heads

        # (B, N, C) -> (B, N, 3, num_heads, head_dim) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # each: (B, num_heads, N, head_dim)

        attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, num_heads, N, N)
        attn = attn.softmax(dim=-1)  # normalise over the key tokens
        attn = self.attn_drop(attn)

        # (B, num_heads, N, head_dim) -> (B, N, num_heads, head_dim) -> (B, N, C)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        # Mix information across heads with the output projection before dropout.
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
# Smoke test: build a 768-channel attention layer and push a random
# (batch=6, tokens=49, channels=768) tensor through it; the notebook
# displays the resulting shape.
attention = Attention(768)
x = torch.rand((6, 49, 768))
attention(x).size()

# Out: torch.Size([6, 49, 768])