<a href="https://colab.research.google.com/github/Shaobin675/Path_in_ML_model_training/blob/main/103_Multi_Head_Attention_DecayingAttention_in_Transformers.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class DecayingAttention(nn.Module):
    '''
    Multi-head self-attention with a learnable, distance-based decay bias.

    Each head adds a penalty of -gamma_h * |i - j| to the attention logits
    between positions i and j. With a positive gamma_h, distant tokens are
    down-weighted. Every head has its own gamma, so different heads can learn
    different decay patterns.
    '''

    def __init__(self, embed_dim, num_heads):
        '''
        Args:
            embed_dim: size of the last input dimension. It must be divisible
                by num_heads.
            num_heads: number of attention heads. It must be positive.

        Raises:
            ValueError: if num_heads is not positive, or if embed_dim is not
                divisible by num_heads.
        '''
        super().__init__()
        # Validate here. Otherwise `embed_dim // num_heads` silently floors, and
        # forward() later fails with a hard-to-read reshape error.
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads

        # One fused projection produces Q, K and V (3 * embed_dim outputs).
        # The learned weights let some features act as queries (searching),
        # some as keys (being found), and some as values (content).
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        # Learnable per-head decay strength, initialized to 1 and always applied.
        # Its sign is not constrained: a negative value would favor distant tokens.
        self.gamma = nn.Parameter(torch.ones(num_heads, 1, 1))

    def forward(self, x):
        '''
        Apply decaying multi-head self-attention.

        Args:
            x: tensor of shape [B, L, embed_dim].

        Returns:
            Tensor of shape [B, L, embed_dim]. The per-head outputs are
            concatenated; no output projection is applied afterwards.
        '''
        B, L, C = x.shape

        # [B, L, 3*C] -> [B, L, 3, heads, head_dim] -> [3, B, heads, L, head_dim]
        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # each [B, num_heads, L, head_dim]

        # 1. Dot product q @ k^T, scaled by sqrt(d_k) for stability -> [B, heads, L, L]
        scores = (q @ k.transpose(-2, -1)) / (self.head_dim ** 0.5)

        # 2. Distance matrix b[i, j] = |i - j|, shape [L, L], built by broadcasting.
        indices = torch.arange(L, device=x.device)
        distance_matrix = torch.abs(indices.view(L, 1) - indices.view(1, L))

        # 3. Negative (decay) bias: gamma [heads, 1, 1] * distance [1, L, L]
        #    -> [heads, L, L]. This broadcasts over the batch dimension of scores.
        decay_bias = -self.gamma * distance_matrix.unsqueeze(0)
        scores = scores + decay_bias

        # 4. Softmax over keys, aggregate values, and merge heads back to [B, L, C].
        attn = F.softmax(scores, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, L, C)
        return out


In [None]:
# Demo: run DecayingAttention on random input and confirm it preserves the shape.
batch_size = 32
seq_len = 100
d_model = 512  # hidden (embedding) dimension of the model
num_heads = 8

# Create the module instance
decay_layer = DecayingAttention(embed_dim=d_model, num_heads=num_heads)

# Dummy input (random noise) of shape [batch_size, seq_len, d_model]
x = torch.randn(batch_size, seq_len, d_model)

# The module handles the Q, K, V projection internally
output = decay_layer(x)

# Check the expected result rather than only asserting it in a comment.
assert output.shape == x.shape, f"unexpected output shape {tuple(output.shape)}"
print(output.shape)  # torch.Size([32, 100, 512]) -> matches input shape


