[Notes/Attention Mechanism notes.pdf](<Notes/Attention Mechanism notes.pdf>)

In [1]:
import torch
import torch.nn as nn

In [None]:
# Attention mechanism class definition
class Attention(nn.Module):
    """Self-attention over a batch of sequences of hidden states.

    Every position attends to every position of the same sequence. Scores are
    dot products between a learned linear projection of the hidden states
    (the queries) and the hidden states themselves (the keys). The
    softmax-normalised scores then weight a sum over the hidden states.

    Args:
        hidden_dim: Size of the last dimension of the tensor passed to
            ``forward``.
    """

    def __init__(self, hidden_dim):
        super().__init__()

        # Learned query projection applied to the hidden states before scoring,
        # so that the compatibility function has trainable parameters.
        self.attn = nn.Linear(hidden_dim, hidden_dim)

    def forward(self, hidden_states):
        """Apply self-attention.

        Args:
            hidden_states: 3-D tensor ``(batch, seq_len, hidden_dim)``
                (``torch.bmm`` requires 3-D operands).

        Returns:
            Tensor with the same shape as ``hidden_states``. Each position holds
            the attention-weighted sum of all hidden states in its sequence.
        """
        # Project the queries with the learned layer; keys are the raw states.
        queries = self.attn(hidden_states)
        # (batch, seq, dim) @ (batch, dim, seq) -> (batch, seq, seq)
        scores = torch.bmm(queries, hidden_states.transpose(1, 2))
        # Normalise over the key axis so each query's weights sum to 1.
        weights = torch.softmax(scores, dim=-1)
        # (batch, seq, seq) @ (batch, seq, dim) -> (batch, seq, dim)
        return torch.bmm(weights, hidden_states)

In [None]:
# BiLSTM with Attention mechanism class definition
class BiLSTMWithAttention(nn.Module):
    """Bidirectional LSTM encoder followed by self-attention and a linear head.

    Args:
        input_dim: Feature size of each input timestep.
        hidden_dim: Hidden size of each LSTM direction. The concatenated
            bidirectional output therefore has ``2 * hidden_dim`` features.
        output_dim: Size of the final prediction.
    """

    def __init__(self, input_dim, hidden_dim, output_dim):
        super().__init__()
        both_directions = 2 * hidden_dim

        # Submodules are registered in the order lstm -> attention -> fc. That
        # order determines parameter-initialisation order under a seeded RNG.
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        self.attention = Attention(both_directions)
        self.fc = nn.Linear(both_directions, output_dim)

    def forward(self, x):
        """Encode ``x``, attend over the sequence, and predict from the last step.

        Args:
            x: Batch-first input, expected shape ``(batch, seq_len, input_dim)``.

        Returns:
            Tensor of shape ``(batch, output_dim)``.
        """
        encoded = self.lstm(x)[0]  # drop the (h_n, c_n) state tuple
        attended = self.attention(encoded)
        final_step = attended[:, -1]  # last timestep of every sequence
        return self.fc(final_step)

In [4]:
# Build the model: 128 input features, 64 LSTM units per direction, 1 output.
model = BiLSTMWithAttention(
    input_dim=128,
    hidden_dim=64,
    output_dim=1,
)
print(model)

BiLSTMWithAttention(
  (lstm): LSTM(128, 64, batch_first=True, bidirectional=True)
  (attention): Attention(
    (attn): Linear(in_features=128, out_features=128, bias=True)
  )
  (fc): Linear(in_features=128, out_features=1, bias=True)
)
