<a href="https://colab.research.google.com/github/Kushagra481/Attention_in_Transformers/blob/main/Addition_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class AdditiveAttention(nn.Module):
    """Self-attention whose scores come from an additive MLP kernel.

    For every query position ``i`` and key position ``j`` the unnormalized
    score is ``add_mlp(Q_i + K_j)``, where ``add_mlp`` is
    ``Linear(H, H) -> Tanh -> Linear(H, 1)``. Scores are softmax-normalized
    over ``j`` and used to take a weighted average of the value vectors.

    Args:
        embed_dim: Size ``C`` of the last dimension of the input.
        hidden_dim: Size ``H`` of the query/key/value projections and of the
            kernel MLP's hidden layer; also the output feature size.
    """

    def __init__(self, embed_dim, hidden_dim):
        super().__init__()
        self.q_proj = nn.Linear(embed_dim, hidden_dim)
        self.k_proj = nn.Linear(embed_dim, hidden_dim)
        self.v_proj = nn.Linear(embed_dim, hidden_dim)
        # Scoring kernel. forward() relies on the first layer being an
        # nn.Linear so it can be applied to Q and K separately.
        self.add_mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.Tanh(),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x, mask=None):
        """Attend over the sequence dimension of ``x``.

        Args:
            x: Input of shape ``(B, T, C)`` with ``C == embed_dim``.
            mask: Optional tensor broadcastable to ``(B, T, T)``; entries equal
                to 0/False mark (query, key) pairs that must not attend. A
                query whose keys are all masked receives uniform weights.

        Returns:
            Tensor of shape ``(B, T, H)`` with ``H == hidden_dim``.
        """
        Q = self.q_proj(x)  # (B, T, H)
        K = self.k_proj(x)  # (B, T, H)
        V = self.v_proj(x)  # (B, T, H)

        # The kernel's first layer is linear, so
        #     W (Q_i + K_j) + b == (W Q_i + b) + W K_j.
        # Projecting Q and K once each costs O(T * H^2) instead of applying the
        # layer to all T^2 (i, j) pairs, which costs O(T^2 * H^2).
        first = self.add_mlp[0]
        q_part = first(Q).unsqueeze(2)  # (B, T, 1, H); bias added exactly once
        k_part = F.linear(K, first.weight).unsqueeze(1)  # (B, 1, T, H); no bias
        # Remaining kernel layers (Tanh -> Linear(H, 1)) on every (i, j) pair.
        e = self.add_mlp[1:](q_part + k_part).squeeze(-1)  # (B, T, T)

        if mask is not None:
            # Most negative finite value of the score dtype: a fixed literal
            # such as -1e9 overflows float16 and makes masked_fill raise.
            e = e.masked_fill(mask == 0, torch.finfo(e.dtype).min)

        att = F.softmax(e, dim=-1)  # (B, T, T)
        return torch.matmul(att, V)  # (B, T, H)


In [2]:
# Smoke test: feed one random batch (B=1, T=10, C=64) through a freshly
# constructed layer. The input is sampled before the layer is built.
dummy = torch.randn(1, 10, 64)
out = AdditiveAttention(hidden_dim=64, embed_dim=64)(dummy)
print(out.size())  # expected: torch.Size([1, 10, 64])


torch.Size([1, 10, 64])
