<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Attention_Mechanisms_in_Deep_Learning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

class Attention(nn.Module):
    """Additive (concatenation-based) attention over encoder states.

    A query vector is paired with every encoder time step, passed through a
    tanh-activated projection, and reduced to one scalar score per step.
    The softmax of those scores weights the encoder states to form a context
    vector.
    """

    def __init__(self, hidden_dim):
        """Create the projection and scoring layers.

        Args:
            hidden_dim: Size of both the query vector and each encoder state.
        """
        super().__init__()
        # Projects the concatenated [query ; encoder state] pair back to hidden_dim.
        self.attn = nn.Linear(2 * hidden_dim, hidden_dim)
        # Collapses each projected pair to a single unnormalised score.
        self.score = nn.Linear(hidden_dim, 1, bias=False)

    def forward(self, hidden, encoder_outputs):
        """Attend over ``encoder_outputs`` using ``hidden`` as the query.

        Args:
            hidden: Query tensor of shape (batch, hidden_dim).
            encoder_outputs: Encoder states of shape (batch, seq_len, hidden_dim).

        Returns:
            A tuple ``(context, attention_weights)`` with shapes
            (batch, hidden_dim) and (batch, seq_len). The weights sum to 1
            along the ``seq_len`` axis.
        """
        num_steps = encoder_outputs.size(1)

        # Pair the query with every time step: (batch, seq_len, hidden_dim).
        query = hidden.unsqueeze(1).expand(-1, num_steps, -1)
        paired = torch.cat((query, encoder_outputs), dim=2)  # (batch, seq_len, 2*hidden_dim)

        # One scalar score per time step.
        energy = self.attn(paired).tanh()  # (batch, seq_len, hidden_dim)
        logits = self.score(energy).squeeze(-1)  # (batch, seq_len)

        # Normalise across the sequence axis.
        attention_weights = logits.softmax(dim=1)  # (batch, seq_len)

        # (batch, 1, seq_len) @ (batch, seq_len, hidden_dim) -> (batch, 1, hidden_dim)
        weighted = torch.bmm(attention_weights.unsqueeze(1), encoder_outputs)
        context = weighted.squeeze(1)  # (batch, hidden_dim)
        return context, attention_weights

class Seq2SeqWithAttention(nn.Module):
    """LSTM encoder / LSTMCell decoder with attention and teacher forcing.

    The encoder reads the whole input sequence. At every decoding step the
    decoder attends over all encoder states, concatenates the resulting
    context vector with the previous target token, and advances one
    ``nn.LSTMCell`` step.
    """

    def __init__(self, input_dim, output_dim, hidden_dim):
        """Build the encoder, decoder cell, attention module and output head.

        Args:
            input_dim: Feature size of each input time step.
            output_dim: Feature size of each target time step and of each
                prediction.
            hidden_dim: Hidden size shared by encoder, decoder and attention.
        """
        super().__init__()
        self.hidden_dim = hidden_dim
        self.encoder = nn.LSTM(input_dim, hidden_dim, batch_first=True)
        # The decoder consumes [previous target token ; attention context].
        self.decoder = nn.LSTMCell(hidden_dim + output_dim, hidden_dim)
        self.attn = Attention(hidden_dim)
        self.fc_out = nn.Linear(hidden_dim, output_dim)

    def forward(self, input_seq, target_seq):
        """Decode ``target_seq`` step by step with teacher forcing.

        Args:
            input_seq: Tensor of shape (batch, src_len, input_dim).
            target_seq: Tensor of shape (batch, target_len, output_dim). The
                step at index 0 is treated as the start-of-sequence input.

        Returns:
            Tensor of shape (batch, target_len, output_dim). Position 0 is
            left as zeros; position ``t >= 1`` is the prediction made after
            feeding ``target_seq[:, t - 1]``.
        """
        # encoder_outputs: (batch, src_len, hidden_dim)
        encoder_outputs, (hidden, cell) = self.encoder(input_seq)

        # nn.LSTM returns states shaped (num_layers, batch, hidden_dim), while
        # nn.LSTMCell consumes and returns (batch, hidden_dim). Drop the layer
        # axis exactly once here; indexing with [-1] inside the loop would,
        # from the second step on, select the last *batch element* instead.
        hidden, cell = hidden[-1], cell[-1]

        batch_size = input_seq.size(0)
        target_len = target_seq.size(1)
        outputs = torch.zeros(
            batch_size,
            target_len,
            self.fc_out.out_features,
            device=input_seq.device,
        )

        # Decode one step at a time.
        for t in range(1, target_len):
            # Teacher forcing: feed the ground-truth token from the previous step.
            decoder_input = target_seq[:, t - 1, :]  # (batch, output_dim)
            context, _ = self.attn(hidden, encoder_outputs)  # (batch, hidden_dim)
            decoder_input_combined = torch.cat((decoder_input, context), dim=1)  # (batch, output_dim + hidden_dim)
            hidden, cell = self.decoder(decoder_input_combined, (hidden, cell))
            outputs[:, t, :] = self.fc_out(hidden)  # Store output for this timestep

        return outputs