<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Transformers_and_Attention_Mechanisms.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

# Positional Encoding
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position information to a batch of embeddings.

    A ``(max_len, embed_size)`` table is precomputed once and stored as the
    non-trainable buffer ``pe``: even feature columns hold
    ``sin(pos * freq_i)`` and odd columns hold ``cos(pos * freq_i)``, with
    ``freq_i = 10000 ** (-2i / embed_size)``.

    Args:
        embed_size: Number of features per position. May be odd.
        max_len: Number of positions to precompute.
    """

    def __init__(self, embed_size, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, embed_size)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, embed_size, 2)
            * (-torch.log(torch.tensor(10000.0)) / embed_size)
        )
        # (max_len, ceil(embed_size / 2)) table of pos * freq products.
        angles = position * div_term
        pe[:, 0::2] = torch.sin(angles)
        # With an odd embed_size there is one fewer odd column than even
        # columns, so drop the surplus frequency instead of failing on a
        # shape mismatch.
        pe[:, 1::2] = torch.cos(angles[:, : embed_size // 2])
        # A buffer follows the module across devices and is saved in the
        # state dict, but is not a trainable parameter.
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Return ``x`` with the encoding of its first ``x.size(1)`` positions added.

        ``x`` is indexed as ``(batch, seq_len, embed_size)``; the table slice
        ``(seq_len, embed_size)`` broadcasts over the batch dimension.
        ``seq_len`` must not exceed ``max_len`` (not checked here).
        """
        return x + self.pe[:x.size(1), :]

# Self-Attention Mechanism
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The last dimension of each input is split into ``heads`` chunks of
    ``head_dim = embed_size // heads`` features. One bias-free
    ``head_dim -> head_dim`` projection per role (values, keys, queries) is
    applied to every chunk, i.e. the projection weights are shared across
    heads. The per-head results are concatenated and passed through
    ``fc_out``.

    Args:
        embed_size: Feature size of the inputs and of the output.
        heads: Number of attention heads; must divide ``embed_size``.
    """

    def __init__(self, embed_size, heads):
        super().__init__()
        self.embed_size = embed_size
        self.heads = heads
        self.head_dim = embed_size // heads

        assert self.head_dim * heads == embed_size, "Embed size needs to be divisible by heads"

        self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)
        self.fc_out = nn.Linear(heads * self.head_dim, embed_size)

    @staticmethod
    def _expand_mask(mask):
        """Reshape ``mask`` so it broadcasts against ``(N, heads, q_len, k_len)``."""
        if mask.dim() == 2:  # (N, k_len): per-key padding mask
            return mask[:, None, None, :]
        if mask.dim() == 3:  # (N, q_len, k_len): per-query mask, e.g. causal
            return mask[:, None, :, :]
        if mask.dim() == 4:  # already (N, heads or 1, q_len, k_len)
            return mask
        raise ValueError(
            f"mask must have 2, 3 or 4 dimensions, got {mask.dim()}"
        )

    def forward(self, values, keys, query, mask=None):
        """Attend from ``query`` positions over ``keys``/``values`` positions.

        Args:
            values: ``(N, value_len, embed_size)`` tensor.
            keys: ``(N, key_len, embed_size)`` tensor; ``key_len`` must equal
                ``value_len``.
            query: ``(N, query_len, embed_size)`` tensor.
            mask: Optional tensor whose zero entries mark key positions to
                ignore. Accepted layouts are ``(N, key_len)``,
                ``(N, query_len, key_len)`` and a 4-D tensor broadcastable to
                ``(N, heads, query_len, key_len)``.

        Returns:
            ``(N, query_len, embed_size)`` tensor.

        Raises:
            ValueError: If ``mask`` has fewer than 2 or more than 4 dimensions.
        """
        N = query.shape[0]
        value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]

        # Split the embedding into heads and project each head chunk.
        # reshape (rather than view) also accepts non-contiguous inputs.
        values = self.values(values.reshape(N, value_len, self.heads, self.head_dim))
        keys = self.keys(keys.reshape(N, key_len, self.heads, self.head_dim))
        queries = self.queries(query.reshape(N, query_len, self.heads, self.head_dim))

        # energy[n, h, q, k] = <queries[n, q, h], keys[n, k, h]> / sqrt(head_dim)
        energy = torch.einsum("nqhd,nkhd->nhqk", queries, keys) / (self.head_dim ** 0.5)

        if mask is not None:
            # Use the most negative finite value of the working dtype so the
            # fill stays representable in reduced precision (a literal -1e20
            # overflows float16). Masked keys get softmax weight 0.
            energy = energy.masked_fill(
                self._expand_mask(mask) == 0, torch.finfo(energy.dtype).min
            )

        # Normalise over the key axis.
        attention = F.softmax(energy, dim=3)

        # Weighted sum of values. The key axis of `attention` and the position
        # axis of `values` must share one einsum index ("k"); with different
        # letters each would be summed on its own and the weights would be
        # discarded.
        out = torch.einsum("nhqk,nkhd->nqhd", attention, values).reshape(
            N, query_len, self.heads * self.head_dim
        )

        # Mix the concatenated heads.
        return self.fc_out(out)

# Example of usage
if __name__ == "__main__":
    # Smoke test: push random tensors through one attention layer and
    # report the shape of the result.
    embed_size, heads = 64, 8
    attention = SelfAttention(embed_size, heads)

    batch_size, seq_len = 32, 50
    # Three independent random inputs, drawn in the order values, keys, queries.
    values, keys, queries = (
        torch.rand(batch_size, seq_len, embed_size) for _ in range(3)
    )
    # All-ones mask: every key position is kept (no padding).
    mask = torch.ones((batch_size, seq_len))

    output = attention(values, keys, queries, mask)
    # The layer returns (batch_size, seq_len, embed_size).
    print("Output shape:", output.shape)