In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class GlobalAttention( nn.Module ):
  """Global attention pooling over the sequence dimension.

  Each timestep is projected into an attention space (linear layer + tanh),
  scored by a dot product with a single learnable context vector, and the
  softmax-normalised scores are used to take a weighted sum of the ORIGINAL
  inputs (not of the projections).

  Args:
    input_dim: size of the last dimension of the inputs.
    attention_dim: size of the hidden attention space.
  """

  def __init__( self, input_dim, attention_dim ):
    super().__init__()
    self.input_projection = nn.Linear( input_dim, attention_dim )
    # Learnable query vector; nn.Parameter is trainable by default.
    self.context_vector = nn.Parameter( torch.randn( attention_dim ) )

  def forward( self, inputs, mask = None ):
    """Pool ``inputs`` over the sequence dimension.

    Args:
      inputs: tensor of shape (batch_size, sequence_len, input_dim).
      mask: optional tensor broadcastable to (batch_size, sequence_len),
        interpreted as bool. True marks real timesteps and False marks padding,
        which receives zero attention weight. With the default None, every
        timestep is attended to.

    Returns:
      context_vector: (batch_size, input_dim) weighted sum of ``inputs``.
      attention_weights: (batch_size, sequence_len, 1). Each row sums to 1,
        unless every position of that row is masked; such a row is all zeros
        and its context vector is zero.
    """
    # Project inputs to attention space: (batch_size, sequence_len, attention_dim)
    projected_inputs = torch.tanh( self.input_projection( inputs ) )

    # Dot product with the learnable context vector: (batch_size, sequence_len)
    attention_scores = torch.matmul( projected_inputs, self.context_vector )

    if mask is not None:
      mask = mask.to( dtype = torch.bool, device = attention_scores.device )
      # Use the most negative finite value rather than -inf. A fully masked row
      # then yields a uniform softmax instead of NaN; it is zeroed just below.
      attention_scores = attention_scores.masked_fill(
        ~mask, torch.finfo( attention_scores.dtype ).min )

    attention_weights = F.softmax( attention_scores, dim = -1 )
    if mask is not None:
      # Guarantees exactly-zero weight on padding (including fully masked rows).
      attention_weights = attention_weights * mask.to( attention_weights.dtype )
    attention_weights = attention_weights.unsqueeze( -1 )  # (batch_size, sequence_len, 1)

    # Weighted sum of the original inputs: (batch_size, input_dim)
    context_vector = torch.sum( attention_weights * inputs, dim = 1 )

    return context_vector, attention_weights

# Example: pool one random batch and inspect the outputs.
# Seed first so both the random inputs and the layer's random initialisation
# are reproducible under Restart & Run All.
torch.manual_seed( 0 )

batch_size = 1
sequence_length = 5
input_dim = 10
attention_dim = 20

inputs = torch.randn( batch_size, sequence_length, input_dim )
attention_layer = GlobalAttention( input_dim, attention_dim )
context_vector, attention_weights = attention_layer( inputs )

# Sanity-check the shapes, and check that the weights over the sequence sum to 1.
assert context_vector.shape == ( batch_size, input_dim )
assert attention_weights.shape == ( batch_size, sequence_length, 1 )
assert torch.allclose( attention_weights.sum( dim = 1 ), torch.ones( batch_size, 1 ) )

print( "Context Vector: ", context_vector )
print()
print( "Attention Weights: ", attention_weights )

Context Vector:  tensor([[ 0.2647, -0.0531, -0.8628,  0.9418, -0.3954,  0.4831,  0.0982,  0.5392,
          0.4690,  0.3095]], grad_fn=<SumBackward1>)

Attention Weights:  tensor([[[0.4770],
         [0.3783],
         [0.0115],
         [0.1059],
         [0.0273]]], grad_fn=<UnsqueezeBackward0>)
