<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Attention_Mechanisms.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Single-head scaled dot-product self-attention.

    The same input is projected by three independent linear layers into
    queries, keys and values. Each output position is a softmax-weighted
    sum of the value vectors, weighted by query/key similarity.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        output_dim: Size of the query/key/value projections, and therefore
            of the last dimension of the output.
    """

    def __init__(self, input_dim: int, output_dim: int):
        super().__init__()
        self.query = nn.Linear(input_dim, output_dim)
        self.key = nn.Linear(input_dim, output_dim)
        self.value = nn.Linear(input_dim, output_dim)

        # Xavier-uniform init for the projection weights; the biases keep
        # the nn.Linear default initialization.
        for projection in (self.query, self.key, self.value):
            nn.init.xavier_uniform_(projection.weight)

    def forward(self, x: torch.Tensor, mask=None) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tensor with at least two dimensions whose last dimension is
                ``input_dim``. The second-to-last dimension is treated as
                the sequence axis, e.g. ``(batch, seq_len, input_dim)``.
            mask: Optional tensor broadcastable to the score matrix
                ``(..., seq_len, seq_len)``. Positions where ``mask == 0``
                are excluded from attention.

        Returns:
            Tensor shaped like ``x`` but with last dimension ``output_dim``.
            Query rows whose keys are *all* masked yield an all-zero output
            vector rather than NaN.
        """
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Scale by sqrt(d_k). A Python scalar avoids allocating a tensor on
        # every call and leaves the dtype of the scores untouched.
        scale = q.size(-1) ** 0.5
        scores = torch.matmul(q, k.transpose(-2, -1)) / scale

        if mask is None:
            weights = F.softmax(scores, dim=-1)
        else:
            blocked = mask == 0
            # Rows with no visible key at all. Filling such a row entirely
            # with -inf would make softmax return NaN (and poison the
            # gradients), so those rows are left finite here and their
            # weights are zeroed after the softmax instead.
            dead_rows = blocked.all(dim=-1, keepdim=True)
            scores = scores.masked_fill(blocked & ~dead_rows, float('-inf'))
            weights = F.softmax(scores, dim=-1)
            weights = weights.masked_fill(dead_rows, 0.0)

        return torch.matmul(weights, v)

# Example usage: run one random batch through the layer and report its shape.
batch_size, seq_len, feature_dim = 32, 10, 64
input_seq = torch.rand(batch_size, seq_len, feature_dim)
attention = Attention(input_dim=feature_dim, output_dim=feature_dim)
output_seq = attention(input_seq)

# The output keeps the batch and sequence axes; expected shape: (32, 10, 64).
print("Output shape:", output_seq.shape)