In [1]:
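# Implementation of global (soft) attention using only NumPy
# Reference: Luong, Pham & Manning (2015), "Effective Approaches to Attention-based
# Neural Machine Translation", https://arxiv.org/pdf/1508.04025.pdf
# Note: the scores below come from a learned context vector, v^T tanh(W x_s), rather
# than from a decoder hidden state, so this is an attention-pooling variant of that idea.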
import numpy as np 

def softmax( x, axis = -1 ):
  """Numerically stable softmax along a single axis.

  Parameters
  ----------
  x : array_like
    Input scores of any shape.
  axis : int, optional
    Axis along which the values are normalized. Defaults to the last axis,
    which is what earlier versions of this function always used.

  Returns
  -------
  numpy.ndarray
    Array with the same shape as ``x`` whose entries along ``axis`` are
    non-negative and sum to 1.
  """
  x = np.asarray( x )
  # Subtracting the per-slice maximum does not change the result mathematically
  # but keeps np.exp from overflowing for large scores.
  e_x = np.exp( x - np.max( x, axis = axis, keepdims = True ) )
  return e_x / np.sum( e_x, axis = axis, keepdims = True )

class GlobalAttentionManual:
  """Soft attention pooling over a sequence, using a learned context vector.

  Each sequence position ``x_s`` is scored as ``u . tanh( x_s @ W )``, where
  ``W`` is ``input_projection`` and ``u`` is ``context_vector``. The scores
  are softmax-normalized over the sequence axis, and the normalized weights
  form a weighted sum of the (unprojected) inputs.

  NOTE(review): Luong et al. (2015) score source states against a decoder
  hidden state. Here the query is a single learned vector and the projection
  has no bias term, so this is closer to attention pooling. The class has no
  training/update method, so the parameters stay at their random
  initialization.
  """

  def __init__( self, input_dim, attention_dim, rng = None ):
    """Create standard-normal random parameters.

    Parameters
    ----------
    input_dim : int
      Size of the last axis of the inputs passed to ``forward``.
    attention_dim : int
      Size of the projected attention space.
    rng : None, int, or numpy.random.Generator, optional
      Source of randomness for initialization. ``None`` (the default) keeps
      the original behavior of drawing from NumPy's global legacy RNG, so
      ``np.random.seed`` controls it. Any other value is passed to
      ``np.random.default_rng`` for an explicit, reproducible draw.
    """
    if rng is None:
      # W: ( input_dim, attention_dim ), u: ( attention_dim, )
      self.input_projection = np.random.randn( input_dim, attention_dim )
      self.context_vector = np.random.randn( attention_dim )
    else:
      generator = np.random.default_rng( rng )
      self.input_projection = generator.standard_normal( ( input_dim, attention_dim ) )
      self.context_vector = generator.standard_normal( attention_dim )

  def forward( self, inputs ):
    """Pool a sequence of vectors into a single attention-weighted vector.

    Parameters
    ----------
    inputs : array_like
      Shape ``( ..., sequence_length, input_dim )``: for example
      ``( batch_size, sequence_length, input_dim )``, or an unbatched
      ``( sequence_length, input_dim )``.

    Returns
    -------
    context : numpy.ndarray
      Attention-weighted sum of ``inputs`` over the sequence axis, shape
      ``( ..., input_dim )``.
    attention_weights : numpy.ndarray
      Softmax weights over the sequence axis, shape
      ``( ..., sequence_length )``. Each slice sums to 1.
    """
    inputs = np.asarray( inputs )

    # Project into the attention space: ( ..., sequence_length, attention_dim )
    projected_inputs = np.tanh( inputs @ self.input_projection )

    # One scalar score per position: ( ..., sequence_length )
    attention_scores = projected_inputs @ self.context_vector

    # Normalize the scores over the sequence (last) axis.
    attention_weights = softmax( attention_scores )

    # Weighted sum over the sequence axis. axis=-2 is the sequence axis for
    # both batched (3-D) and unbatched (2-D) input; a fixed axis=1 would sum
    # over the feature axis for 2-D input.
    # The local is named `context` so it is not confused with the learned
    # query self.context_vector.
    context = np.sum( attention_weights[ ..., np.newaxis ] * inputs, axis = -2 )

    return context, attention_weights

# Example usage
SEED = 0 # fixed seed so "Restart & Run All" reproduces the printed numbers
# np.random.seed controls np.random.rand below and the default (global-RNG)
# parameter initialization inside GlobalAttentionManual.__init__.
np.random.seed( SEED )

input_dim = 128 # input dimension
attention_dim = 64 # attention space dimension
batch_size = 32 # batch size
sequence_length = 10 # sequence length

# Example input data: uniform samples in [0, 1)
inputs = np.random.rand( batch_size, sequence_length, input_dim )

# Instantiate the GlobalAttentionManual class and apply the forward pass
attention = GlobalAttentionManual( input_dim, attention_dim )
context_vector, attention_weights = attention.forward( inputs )

# Sanity checks: expected shapes, and each row of weights is a probability distribution
assert context_vector.shape == ( batch_size, input_dim )
assert attention_weights.shape == ( batch_size, sequence_length )
assert np.allclose( attention_weights.sum( axis = -1 ), 1.0 )

print( "Context Vector Shape: ")
print( context_vector.shape ) # ( batch_size, input_dim )
print( "" )
print( "Context Vector: ", context_vector )
print( "" )
print( "Attention Weights Shape: ")
print( attention_weights.shape ) # ( batch_size, sequence_length )
print( "" )
# Show only a few rows; the full ( batch_size, sequence_length ) matrix floods the output
print( "Attention Weights (first 3 rows): ", attention_weights[ :3 ] )

Context Vector Shape: 
(32, 128)

Context Vector:  [[0.46813264 0.78216055 0.46192932 ... 0.39823554 0.78992913 0.76973062]
 [0.43649398 0.48718817 0.5085277  ... 0.77742532 0.48125993 0.52651599]
 [0.47534268 0.60960369 0.46723035 ... 0.28381167 0.6513187  0.39213449]
 ...
 [0.16946894 0.75410793 0.37903254 ... 0.71693853 0.88752933 0.5679254 ]
 [0.51978342 0.0029969  0.62563177 ... 0.09780153 0.59340388 0.14625027]
 [0.47141838 0.22208784 0.58006632 ... 0.54693116 0.42051949 0.54724907]]

Attention Weights Shape: 
(32, 10)

Attention Weights:  [[8.18980558e-06 1.47678585e-01 2.90603281e-05 1.21652008e-02
  9.63945962e-02 3.76444876e-07 7.19366339e-01 6.71324788e-03
  1.76442976e-02 1.07090551e-07]
 [2.31403755e-02 2.19038604e-02 3.14949327e-04 1.86260763e-01
  4.00526551e-01 2.66879992e-01 9.64238359e-02 1.69930892e-03
  2.29105545e-03 5.59309100e-04]
 [3.60357904e-02 2.23946990e-01 5.41057857e-02 1.58290713e-01
  3.22000716e-01 9.66439163e-04 1.22939363e-02 3.87683215e-02
  7.354175