<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Attention_Mechanisms.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

class Attention(nn.Module):
    """Additive (concatenation-based) attention over encoder outputs.

    Every source position is scored as ``v . tanh(W [hidden; encoder_output])``
    and the scores are normalised with a softmax across the source positions.

    Args:
        hidden_dim: Feature size of both the query state and each encoder
            output; the scoring layer consumes their concatenation
            (``2 * hidden_dim`` features).
    """

    def __init__(self, hidden_dim: int) -> None:
        super().__init__()
        self.attn = nn.Linear(hidden_dim * 2, hidden_dim)
        self.v = nn.Parameter(torch.rand(hidden_dim))

    def forward(self, hidden: torch.Tensor, encoder_outputs: torch.Tensor) -> torch.Tensor:
        """Return attention weights for one query state.

        Args:
            hidden: Query state, expected as ``(batch, hidden_dim)``.
            encoder_outputs: Encoder states, expected as
                ``(batch, src_len, hidden_dim)``.

        Returns:
            Tensor of shape ``(batch, src_len)`` whose rows sum to 1.
        """
        src_len = encoder_outputs.shape[1]

        # Broadcast the query along the source axis. expand() returns a view,
        # so no per-position copy of the hidden state is materialised before
        # the concatenation below.
        query = hidden.unsqueeze(1).expand(-1, src_len, -1)

        energy = torch.tanh(self.attn(torch.cat((query, encoder_outputs), dim=2)))
        # Project each position's energy vector onto v to get a scalar score.
        scores = torch.sum(self.v * energy, dim=2)
        return torch.softmax(scores, dim=1)

class Seq2Seq(nn.Module):
    """LSTM encoder-decoder that attends over the encoder outputs at each step.

    Args:
        input_dim: Feature size of each source step and each target step.
        output_dim: Feature size of each prediction produced by ``fc``.
        hidden_dim: Hidden size of both the encoder and decoder LSTMs.
        attention: Callable invoked as ``attention(hidden, encoder_outputs)``
            that returns per-position weights of shape ``(batch, src_len)``.
    """

    def __init__(self, input_dim: int, output_dim: int, hidden_dim: int, attention: nn.Module) -> None:
        super().__init__()
        self.encoder = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        # Each decoder step consumes the target features concatenated with
        # the attention context vector.
        self.decoder = nn.LSTM(hidden_dim + input_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, output_dim)
        self.attention = attention

    def forward(self, src: torch.Tensor, trg: torch.Tensor) -> torch.Tensor:
        """Decode ``trg`` step by step, feeding ``trg[:, t]`` as the step-t input.

        Args:
            src: Source batch, expected as ``(batch, src_len, input_dim)``.
            trg: Target batch, expected as ``(batch, trg_len, input_dim)``.

        Returns:
            Tensor of shape ``(batch, trg_len, output_dim)``; the time axis is
            empty when ``trg`` has no steps.
        """
        encoder_outputs, (hidden, cell) = self.encoder(src)
        # Keep the last layer's final state; unsqueeze(0) restores the leading
        # num_layers axis that nn.LSTM expects for its initial state.
        hidden = hidden[-1].unsqueeze(0)
        cell = cell[-1].unsqueeze(0)

        num_steps = trg.size(1)
        if num_steps == 0:
            # torch.stack cannot stack an empty list; return an empty sequence.
            return encoder_outputs.new_zeros((encoder_outputs.size(0), 0, self.fc.out_features))

        outputs = []
        for t in range(num_steps):
            attention_weights = self.attention(hidden.squeeze(0), encoder_outputs)
            # Weighted sum of encoder outputs -> (batch, hidden_dim) context.
            context = torch.sum(attention_weights.unsqueeze(2) * encoder_outputs, dim=1)
            decoder_input = torch.cat((trg[:, t].unsqueeze(1), context.unsqueeze(1)), dim=2)
            # Carry both hidden and cell state forward; zeroing the cell state
            # on every step would discard the LSTM's memory between steps.
            output, (hidden, cell) = self.decoder(decoder_input, (hidden, cell))
            outputs.append(self.fc(output.squeeze(1)))

        return torch.stack(outputs, dim=1)

# Demo hyperparameters: per-step feature sizes and the LSTM hidden width.
input_dim, output_dim, hidden_dim = 10, 10, 128

# Build the attention scorer first, then hand it to the encoder-decoder.
attention = Attention(hidden_dim)
model = Seq2Seq(input_dim, output_dim, hidden_dim, attention=attention)

# Smoke test on random data: 32 sequences, 15 source steps, 10 target steps.
src = torch.randn((32, 15, input_dim))
trg = torch.randn((32, 10, input_dim))
output = model(src, trg)
print(output.size())  # Output: torch.Size([32, 10, 10])