In [1]:
import torch
from torch import nn
import torch.nn.functional as F

### Method 1

In [2]:
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size, debugging=False):
        super(BahdanauAttention, self).__init__()
        self.debugging = debugging
        self.hidden_size = hidden_size
        self.W = nn.Linear(hidden_size, hidden_size)
        self.U = nn.Linear(hidden_size, hidden_size)
        self.V = nn.Linear(hidden_size, 1)

    def forward(self, decoder_hidden, encoder_output):
        # decoder hidden shape = (batch_size, hidden_size)
        # encoder output shape = (seq_len, batch_size, hidden_size)
       
        #add an extra dimension to match the query
        decoder_hidden = decoder_hidden.unsqueeze(0)
        #compute attention/alignment score
        score = self.V(F.tanh(self.W(decoder_hidden) + self.U(encoder_output)))
        #compute attention weights
        weights = F.softmax(score)
        #compute the context vector
        context = torch.sum(weights * encoder_output, dim=0)

        if self.debugging:
            print("Encoder Output Shape", encoder_output.shape)
            print("Decoder Hidden Shape", decoder_hidden.shape)
            print("Attention Weights Shape", weights.shape)
            print("Attention Context Shape", context.shape)

        return context, weights
    
hidden_size = 512
attention = BahdanauAttention(hidden_size)
decoder_hidden = torch.zeros([32, hidden_size])
encoder_output = torch.zeros([21, 32, hidden_size])
context, weight = attention(decoder_hidden, encoder_output)

  weights = F.softmax(score)


### Method 2

In [6]:
class BahdanauAttention(nn.Module):
    def __init__(self, hidden_size, debugging=False):
        super(BahdanauAttention, self).__init__()
        self.debugging = debugging
        self.hidden_size = hidden_size
        self.W = nn.Linear(hidden_size, hidden_size)
        self.U = nn.Linear(hidden_size, hidden_size)
        self.V = nn.Linear(hidden_size, 1)

    def forward(self, decoder_hidden, encoder_output):
        # decoder_hidden shape = (batch_size, hidden_size)
        # encoder_output shape = (seq_len, batch_size, hidden_size)
        print("Encoder Output Shape (Attention)", encoder_output.shape)
        # Expand decoder hidden state to match the shape of encoder output
        decoder_hidden = decoder_hidden.unsqueeze(1).repeat(1, encoder_output.shape[0], 1).permute(1, 0, 2)
        print("Decoder Hidden Shape (Attention)", decoder_hidden.shape)
        # Compute attention scores
        scores = self.V(torch.tanh(self.W(decoder_hidden) + self.U(encoder_output)))
        print("Attention Score Shape", scores.shape)
        # Compute attention weights
        weights = F.softmax(scores)
        print("Attention Weights Shape", weights.shape)
        # Compute the context vector
        context = torch.sum(weights * encoder_output, dim=0)
        print("Attention Context Shape", context.shape)

        if self.debugging:
            print("Encoder Output Shape:", encoder_output.shape)
            print("Decoder Hidden Shape:", decoder_hidden.shape)
            print("Attention Context Shape:", context.shape)
            print("Attention Weights Shape:", weights.shape)

        return context, weights
    
hidden_size = 512
attention = BahdanauAttention(hidden_size)
decoder_hidden = torch.zeros([32, hidden_size])
encoder_output = torch.zeros([21, 32, hidden_size])
context, weight = attention(decoder_hidden, encoder_output)

Encoder Output Shape (Attention) torch.Size([21, 32, 512])
Decoder Hidden Shape (Attention) torch.Size([21, 32, 512])
Attention Score Shape torch.Size([21, 32, 1])
Attention Weights Shape torch.Size([21, 32, 1])
Attention Context Shape torch.Size([32, 512])


  weights = F.softmax(scores)
