<a href="https://colab.research.google.com/github/OneFineStarstuff/State-of-the-Art/blob/main/Image_Captioning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
import torch.nn as nn

# Define the Attention mechanism class
class Attention(nn.Module):
    """Additive attention over a set of encoder feature vectors.

    Each encoder position is scored against the current decoder hidden state
    as ``full_att(tanh(encoder_att(e_i) + decoder_att(h)))``. The scores are
    normalized with a softmax across positions and used to form a weighted
    sum of the encoder features (the context vector).

    Args:
        encoder_dim: Size of each encoder feature vector.
        decoder_dim: Size of the decoder hidden state.
        attention_dim: Size of the shared projection space in which encoder
            features and the decoder state are combined.
    """

    def __init__(self, encoder_dim: int, decoder_dim: int, attention_dim: int):
        super().__init__()
        # Projects encoder features into the attention space.
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        # Projects the decoder hidden state into the same attention space.
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        # Collapses each combined vector to one unnormalized score.
        self.full_att = nn.Linear(attention_dim, 1)

    def forward(self, encoder_out: torch.Tensor, decoder_hidden: torch.Tensor):
        """Compute the context vector and the attention weights.

        The position axis is addressed relative to the last axis, so any
        number of leading batch dimensions (including none) is accepted as
        long as ``encoder_out`` and ``decoder_hidden`` share them.

        Args:
            encoder_out: Encoder features of shape
                ``(..., num_positions, encoder_dim)``.
            decoder_hidden: Decoder hidden state of shape
                ``(..., decoder_dim)``.

        Returns:
            A tuple ``(context, alpha)`` where ``context`` has shape
            ``(..., encoder_dim)`` and ``alpha`` has shape
            ``(..., num_positions, 1)`` and sums to 1 over the position axis.
        """
        att1 = self.encoder_att(encoder_out)  # (..., num_positions, attention_dim)
        att2 = self.decoder_att(decoder_hidden)  # (..., attention_dim)
        # Insert the position axis counted from the end so the decoder
        # projection broadcasts over positions for any leading batch shape.
        att = self.full_att(torch.tanh(att1 + att2.unsqueeze(-2)))  # (..., num_positions, 1)
        alpha = torch.softmax(att, dim=-2)  # Normalize across positions
        context = (encoder_out * alpha).sum(dim=-2)  # Weighted sum of encoder features
        return context, alpha

# Demo: build the attention module and run it once on random inputs.
attention = Attention(encoder_dim=512, decoder_dim=256, attention_dim=128)

# Random stand-ins for real features: a batch of 32 samples, each with
# 196 encoder positions of width 512 and one decoder state of width 256.
encoder_out = torch.randn(32, 196, 512)
decoder_hidden = torch.randn(32, 256)

context, alpha = attention(encoder_out, decoder_hidden)

# Show the resulting shapes, one per line:
#   context -> [32, 512]
#   alpha   -> [32, 196, 1]
print(context.shape, alpha.shape, sep="\n")