<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Attention_Mechanisms_in_Deep_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

class Attention(nn.Module):
    """
    Single-head scaled dot-product attention with learned Q/K/V projections.

    Inputs:
        query: (batch_size, query_len, input_dim)
        key: (batch_size, key_len, input_dim)
        value: (batch_size, key_len, input_dim)
        mask (optional): tensor broadcastable to
            (batch_size, query_len, key_len). Positions where ``mask == 0``
            are excluded from attention; every other position is kept.
    Outputs:
        context: (batch_size, query_len, input_dim)
        attention_weights: (batch_size, query_len, key_len). Each row sums
            to 1, except rows whose keys are all masked out: those rows are
            all zeros (and so is the matching context row) instead of NaN.
    """
    def __init__(self, input_dim: int):
        super().__init__()
        # Separate learned projections for queries, keys and values.
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)
        # Normalize over the key axis (last dimension of the score matrix).
        self.softmax = nn.Softmax(dim=-1)

        # Xavier-initialize the projection weights; biases keep the
        # nn.Linear default initialization.
        nn.init.xavier_uniform_(self.query.weight)
        nn.init.xavier_uniform_(self.key.weight)
        nn.init.xavier_uniform_(self.value.weight)

    def forward(self, query, key, value, mask=None):
        """Return ``(context, attention_weights)`` for the given inputs."""
        Q = self.query(query)
        K = self.key(key)
        V = self.value(value)

        # Scaled dot-product scores. A plain Python scalar is used for the
        # 1/sqrt(d) factor so no tensor is allocated on every call.
        scale = Q.size(-1) ** 0.5
        scores = torch.matmul(Q, K.transpose(-2, -1)) / scale

        if mask is None:
            attention_weights = self.softmax(scores)
        else:
            blocked = mask == 0
            scores = scores.masked_fill(blocked, float('-inf'))

            # A row in which every key is blocked would be all -inf, and
            # softmax over it yields NaN (which then poisons the context and
            # the gradients). Reset such rows to a finite value before the
            # softmax and force their weights to zero afterwards.
            fully_blocked = blocked.all(dim=-1, keepdim=True)
            scores = scores.masked_fill(fully_blocked, 0.0)
            attention_weights = self.softmax(scores)
            attention_weights = attention_weights.masked_fill(fully_blocked, 0.0)

        # Weighted sum of the value vectors.
        context = torch.matmul(attention_weights, V)
        return context, attention_weights

# Demo: push random tensors through the layer and report the output shapes.
demo_shape = (32, 10, 256)  # (batch size, sequence length, feature size)
attention_layer = Attention(input_dim=demo_shape[-1])

# Draw the three inputs in query -> key -> value order.
query = torch.randn(*demo_shape)
key = torch.randn(*demo_shape)
value = torch.randn(*demo_shape)

context, attention_weights = attention_layer(query, key, value)

# Expected: context is (32, 10, 256); attention weights are (32, 10, 10).
print("Context shape:", context.shape)
print("Attention weights shape:", attention_weights.shape)