# Basic Attention Mechanism

The following code implements a basic attention mechanism using  
PyTorch as the main library. The code is based on the following paper:  
[Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/abs/1409.0473) by Bahdanau et al.  
This mechanism can be described as follows:  
  
1. Initialization (`__init__`): Constructor method that initializes the attention  
mechanism using two linear layers.  
The first layer, `attention_weights_layer` (an `nn.Linear`), transforms the input dimension to the  
intermediate attention dimension.  
The second layer, `context_vector_layer`, reduces the intermediate attention  
dimension to 1 to obtain a single attention score per element in the input; it has  
no bias term (`bias=False`).  
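2. Forward (`forward`): Method that computes the attention weights and the context vector from the  
input tensor, which has shape (batch_size, sequence_length, input_dim).  
- Step 1: Compute the attention weights: transform the inputs to the attention space, apply `tanh`, reduce to one score per sequence element, and normalize the scores with a softmax over the sequence.  
- Step 2: Multiply the inputs by the attention weights and sum over the sequence length to obtain the context vector.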

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class BasicAttention(nn.Module):
    """Attention pooling that collapses a sequence into one context vector.

    Every sequence element is scored by a two-layer network
    (linear -> tanh -> bias-free linear). The scores are normalized with a
    softmax across the sequence and used as weights for a sum of the inputs.
    """

    def __init__(self, input_dim, attention_dim):
        """Build the two scoring layers.

        Parameters:
        - input_dim: Size of the last dimension of the input tensor
        - attention_dim: Size of the intermediate attention space
        """
        super().__init__()
        # Maps each input element into the attention space.
        self.attention_weights_layer = nn.Linear(input_dim, attention_dim)
        # Reduces the attention space to a single score; no bias term.
        self.context_vector_layer = nn.Linear(attention_dim, 1, bias=False)

    def forward(self, inputs):
        """Compute the context vector and the attention weights.

        Parameters:
        - inputs: Tensor of shape (batch_size, sequence_length, input_dim)

        Returns:
        - context_vector: Attention-weighted sum of `inputs` over dim 1
        - attention_weights: Softmax-normalized weights with a trailing
          singleton dimension, one weight per sequence element
        """
        # Score each element: projection, non-linearity, reduction to 1.
        projected = torch.tanh(self.attention_weights_layer(inputs))
        raw_scores = self.context_vector_layer(projected)

        # Drop the trailing singleton so the softmax runs over the sequence,
        # then restore it so the weights broadcast against `inputs`.
        normalized = torch.softmax(raw_scores.squeeze(-1), dim=-1)
        attention_weights = normalized.unsqueeze(-1)

        # Weighted sum over the sequence dimension.
        context_vector = (attention_weights * inputs).sum(dim=1)

        return context_vector, attention_weights
  
# Demo: run BasicAttention once on a small random batch and print the results.
batch_size, sequence_length = 1, 5
input_dim, attention_dim = 10, 20

# Larger alternative configuration:
#   input_dim = 128        (dimension of the input data)
#   attention_dim = 64     (dimension of the attention space)
#   batch_size = 32        (batch size)
#   sequence_length = 10   (sequence length)

# The random input is drawn before the layer is constructed.
inputs = torch.randn(batch_size, sequence_length, input_dim)
attention_layer = BasicAttention(input_dim, attention_dim)
context_vector, attention_weights = attention_layer(inputs)

print("Context Vector: ", context_vector)
print("")
print("Attention Weights: ", attention_weights)

Context Vector:  tensor([[ 0.1510, -0.0671,  0.7450,  ...,  0.2707,  0.2550, -0.1792],
        [ 0.5306,  0.6922, -0.4998,  ...,  0.1630,  0.5041, -0.3247],
        [ 0.0834,  0.1286,  0.3469,  ..., -0.1087,  0.0986,  0.3567],
        ...,
        [ 0.0211,  0.4529,  0.1570,  ..., -0.1851, -0.0764,  0.0087],
        [-0.0842, -0.0134,  0.0250,  ..., -0.0995,  0.0669, -0.3522],
        [ 0.0158, -0.0941,  0.3257,  ...,  0.0744,  0.0198,  0.3272]],
       grad_fn=<SumBackward1>)

Attention Weights:  tensor([[[0.0925],
         [0.1648],
         [0.0810],
         [0.1057],
         [0.1073],
         [0.0779],
         [0.0686],
         [0.0884],
         [0.1309],
         [0.0830]],

        [[0.1150],
         [0.0758],
         [0.1068],
         [0.0942],
         [0.1265],
         [0.0844],
         [0.0900],
         [0.1485],
         [0.0642],
         [0.0945]],

        [[0.0995],
         [0.0634],
         [0.0989],
         [0.1063],
         [0.0832],
         [0.1068],

In [1]:
# Updated Code for Soft Attention (Basic attention)
# The following code implements a simple and basic attention mechanism using PyTorch as the main library.
# The code is based on the following paper: Neural Machine Translation by Jointly Learning to Align and Translate by Bahdanau et al.
# 10 - 04 - 2024

import torch
import torch.nn as nn
import torch.nn.functional as F

class AttentionNetwork(nn.Module):
    """Additive attention: score each annotation against a decoder hidden state.

    The score for each annotation is V(tanh(W1(annotation) + W2(hidden))),
    and the scores are normalized with a softmax across the sequence.
    """

    def __init__(self, annotation_dim, hidden_dim, attention_dim):
        """
        Initialize the attention network.
        
        Parameters:
        - annotation_dim: Dimension of the input annotations (a)
        - hidden_dim: Dimension of the decoder's hidden state (h_0)
        - attention_dim: Dimension for the intermediate representation in the attention mechanism
        """
        super(AttentionNetwork, self).__init__()

        # Define the weight matrices (all bias-free)
        self.W1 = nn.Linear(annotation_dim, attention_dim, bias=False)  # Applies W1 to annotation a
        self.W2 = nn.Linear(hidden_dim, attention_dim, bias=False)  # Applies W2 to hidden state h0
        self.V = nn.Linear(attention_dim, 1, bias=False)  # Applies V to the tanh output

    def forward(self, annotations, hidden):
        """
        Forward pass through the attention network.
        
        Parameters:
        - annotations: Tensor containing annotations from the encoder (shape: batch_size x seq_len x annotation_dim)
        - hidden: Tensor containing the current hidden state of the decoder (shape: batch_size x hidden_dim)
        
        Returns:
        - attention_weights: Tensor containing the attention weights (shape: batch_size x seq_len)
        """
        # Project both inputs into the shared attention space first. The
        # hidden state must NOT be expanded to the shape of `annotations`
        # before projection: that only works when hidden_dim == annotation_dim.
        projected_annotations = self.W1(annotations)   # batch_size x seq_len x attention_dim
        projected_hidden = self.W2(hidden).unsqueeze(1)  # batch_size x 1 x attention_dim

        # Broadcasting adds the single projected hidden state to every
        # sequence position.
        attn_scores = self.V(torch.tanh(projected_annotations + projected_hidden))

        # Squeeze the last dimension and apply softmax to get attention weights
        attention_weights = F.softmax(attn_scores.squeeze(-1), dim=-1)

        return attention_weights


In [2]:
import torch

# Example sizes, chosen only for demonstration:
#   annotation_dim - dimension of the encoder's output annotations
#   hidden_dim     - dimension of the decoder's hidden state
#   attention_dim  - intermediate attention representation dimension
annotation_dim, hidden_dim, attention_dim = 256, 256, 128

# Build the attention network
attention_network = AttentionNetwork(annotation_dim, hidden_dim, attention_dim)

# Randomly generated annotations and hidden state for the demo
batch_size, seq_len = 2, 10
annotations = torch.randn(batch_size, seq_len, annotation_dim)
hidden = torch.randn(batch_size, hidden_dim)

# Run one forward pass to obtain the attention weights
attention_weights = attention_network(annotations, hidden)

# Expected shape: (batch_size, seq_len)
print("Attention weights shape:", attention_weights.shape)


Attention weights shape: torch.Size([2, 10])
