<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Attention_Mechanism.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class Attention(nn.Module):
    """Concat-style attention scorer over a sequence of encoder outputs.

    Every encoder time step is paired with the current decoder hidden state,
    passed through a linear layer and ``tanh``, reduced to one scalar score by
    a learned vector, and the scores are normalised across the sequence.
    """

    def __init__(self, encoder_hidden_dim, decoder_hidden_dim):
        super().__init__()
        pair_dim = encoder_hidden_dim + decoder_hidden_dim
        # Projects each [decoder state ; encoder output] pair.
        self.attn = nn.Linear(pair_dim, decoder_hidden_dim)
        # Learned vector that collapses a projected pair into a scalar score.
        self.v = nn.Parameter(torch.rand(decoder_hidden_dim))

    def forward(self, hidden, encoder_outputs):
        """Compute attention weights for one decoder state.

        Args:
            hidden: decoder state, ``[batch_size, decoder_hidden_dim]``.
            encoder_outputs: encoder states,
                ``[batch_size, seq_len, encoder_hidden_dim]``.

        Returns:
            Tensor of shape ``[batch_size, seq_len]`` whose rows sum to 1.
        """
        steps = encoder_outputs.size(1)
        # Broadcast the single decoder state across every encoder position.
        tiled_state = hidden.unsqueeze(1).expand(-1, steps, -1)
        pairs = torch.cat((tiled_state, encoder_outputs), dim=2)
        energy = torch.tanh(self.attn(pairs))
        # Dot each projected pair with ``v``: [B, S, D] x [D] -> [B, S].
        scores = torch.matmul(energy, self.v)
        return F.softmax(scores, dim=1)

class Seq2Seq(nn.Module):
    """LSTM encoder-decoder that attends over the encoder outputs at each step.

    The encoder reads ``src`` once. The decoder is then unrolled one target
    step at a time; at every step it receives the target frame for that step
    concatenated with a context vector (the attention-weighted sum of the
    encoder outputs) and emits one ``output_dim``-sized prediction.

    Args:
        input_dim: feature size of each source frame.
        output_dim: feature size of each target frame and of each prediction.
        encoder_hidden_dim: hidden size of the encoder LSTM.
        decoder_hidden_dim: hidden size of the decoder LSTM.
    """

    def __init__(self, input_dim, output_dim, encoder_hidden_dim, decoder_hidden_dim):
        super().__init__()
        self.encoder = nn.LSTM(input_dim, encoder_hidden_dim, batch_first=True)
        # A decoder step consumes one target frame (output_dim features) plus
        # the context vector (encoder_hidden_dim features), so the LSTM input
        # size is their sum -- not decoder_hidden_dim + encoder_hidden_dim.
        self.decoder = nn.LSTM(
            output_dim + encoder_hidden_dim,
            decoder_hidden_dim,
            batch_first=True,
        )
        self.fc_out = nn.Linear(decoder_hidden_dim, output_dim)
        self.attention = Attention(encoder_hidden_dim, decoder_hidden_dim)
        # The encoder's final state seeds the decoder. When the two hidden
        # sizes match it is passed through untouched (Identity adds no
        # parameters); otherwise it is linearly projected to the decoder size.
        if encoder_hidden_dim == decoder_hidden_dim:
            self.init_hidden = nn.Identity()
            self.init_cell = nn.Identity()
        else:
            self.init_hidden = nn.Linear(encoder_hidden_dim, decoder_hidden_dim)
            self.init_cell = nn.Linear(encoder_hidden_dim, decoder_hidden_dim)

    def forward(self, src, trg):
        """Run the encoder once, then decode ``trg.size(1)`` steps.

        Args:
            src: source batch, ``[batch_size, src_len, input_dim]``.
            trg: decoder inputs, ``[batch_size, trg_len, output_dim]``. Frame
                ``t`` is fed at step ``t`` and the prediction is stored at
                index ``t``. NOTE(review): if predictions are compared
                against next-step targets, the caller presumably needs to
                pass a shifted target sequence -- confirm against the
                training loop.

        Returns:
            Tensor of shape ``[batch_size, trg_len, output_dim]``.
        """
        encoder_outputs, (hidden, cell) = self.encoder(src)

        # Take the last layer's final state and restore the leading
        # num_layers axis the decoder LSTM expects: [1, batch, decoder_hidden].
        hidden = self.init_hidden(hidden[-1]).unsqueeze(0)
        cell = self.init_cell(cell[-1]).unsqueeze(0)

        trg_len = trg.size(1)
        # Allocate directly with the model's device/dtype; the width is the
        # projection's output size rather than whatever trg happens to carry.
        outputs = encoder_outputs.new_zeros(
            src.size(0), trg_len, self.fc_out.out_features
        )

        for t in range(trg_len):
            # weights: [batch, src_len], one distribution per batch item.
            weights = self.attention(hidden[-1], encoder_outputs)
            # Weighted sum of encoder outputs:
            # [batch, 1, src_len] x [batch, src_len, enc] -> [batch, 1, enc].
            context = torch.bmm(weights.unsqueeze(1), encoder_outputs)
            rnn_input = torch.cat((trg[:, t:t + 1, :], context), dim=2)
            output, (hidden, cell) = self.decoder(rnn_input, (hidden, cell))
            outputs[:, t:t + 1, :] = self.fc_out(output)

        return outputs

# Example usage: build a model and push one random batch through it.
input_dim, output_dim = 10, 10
encoder_hidden_dim, decoder_hidden_dim = 64, 64

model = Seq2Seq(
    input_dim=input_dim,
    output_dim=output_dim,
    encoder_hidden_dim=encoder_hidden_dim,
    decoder_hidden_dim=decoder_hidden_dim,
)

# Random source and target batches laid out as [batch_size, seq_len, features].
src = torch.rand(32, 20, input_dim)
trg = torch.rand(32, 20, output_dim)

output = model(src, trg)
print(output.shape)  # [batch_size, seq_len, output_dim]