In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class LocalAttention(nn.Module):
    """Attention pooling over a fixed-size window around a center position.

    Every timestep inside the window is passed through a linear projection to
    ``attention_dim`` followed by ``tanh``, scored by a dot product with a
    learned context vector, and the softmax of those scores is used to take a
    weighted sum of the (unprojected) inputs in the window.

    Args:
        input_dim: Size of the last dimension of ``inputs``.
        attention_dim: Size of the space in which attention scores are computed.
        window_size: Number of timesteps in the window before it is clipped at
            the sequence edges. Must be >= 1.

    Raises:
        ValueError: If ``window_size`` is smaller than 1.
    """

    def __init__(self, input_dim, attention_dim, window_size):
        super().__init__()
        if window_size < 1:
            raise ValueError(f"window_size must be >= 1, got {window_size}")
        self.input_projection = nn.Linear(input_dim, attention_dim)
        # nn.Parameter already requires grad by default.
        self.context_vector = nn.Parameter(torch.randn(attention_dim))
        self.window_size = window_size

    def forward(self, inputs, center_position):
        """Attend over the window of ``inputs`` centred on ``center_position``.

        Args:
            inputs: Tensor indexed as ``(batch, sequence_length, input_dim)``.
            center_position: Integer index (Python int or 0-d tensor) of the
                window's center along the sequence dimension. Non-integer
                values are truncated toward zero.

        Returns:
            ``(context, weights)`` where ``context`` has shape
            ``(batch, input_dim)`` and ``weights`` has shape
            ``(batch, window_len, 1)``; ``window_len`` is ``window_size``
            unless the window is clipped at a sequence edge.

        Raises:
            ValueError: If ``center_position`` is outside
                ``[0, sequence_length)``.
        """
        seq_len = inputs.size(1)
        center = int(center_position)
        if not 0 <= center < seq_len:
            raise ValueError(
                f"center_position must be in [0, {seq_len}), got {center}"
            )

        # Split the window into a left and a right half so it spans exactly
        # window_size positions; for even sizes the extra position goes left.
        left = self.window_size // 2
        right = (self.window_size - 1) // 2
        start = max(center - left, 0)
        end = min(center + right, seq_len - 1)

        # Contiguous slice (a view, no copy) of the window, end inclusive.
        local_inputs = inputs[:, start:end + 1, :]

        projected = torch.tanh(self.input_projection(local_inputs))
        # (batch, window_len, attention_dim) @ (attention_dim,) -> (batch, window_len)
        scores = torch.matmul(projected, self.context_vector)
        weights = F.softmax(scores, dim=-1).unsqueeze(-1)

        # Weighted sum of the raw window inputs over the window dimension.
        context = torch.sum(weights * local_inputs, dim=1)
        return context, weights
  
# Example: attend over a 5-step window centred on position 4 of a random
# sequence. Guarded so that importing this module does not run the demo;
# names stay at module level so later notebook cells can still use them.
if __name__ == "__main__":
    torch.manual_seed(0)  # make the random inputs and weights reproducible

    batch_size = 1
    sequence_length = 10
    input_dim = 16
    attention_dim = 32
    window_size = 5
    center_position = 4

    inputs = torch.randn(batch_size, sequence_length, input_dim)
    attention_layer = LocalAttention(input_dim, attention_dim, window_size)
    context_vector, attention_weights = attention_layer(inputs, center_position)

    print("Context Vector: ", context_vector)
    print("shape:", tuple(context_vector.shape))
    print()
    print("Attention Weights: ", attention_weights)
    print("shape:", tuple(attention_weights.shape))

Context Vector:  tensor([[-0.7241, -0.8201,  1.9877,  0.4045, -0.6351, -0.6169,  0.3873, -0.9794,
         -0.2766, -1.3938,  1.6483,  0.0800,  0.4006, -1.9020, -0.2444, -0.3690]],
       grad_fn=<SumBackward1>)

Attention Weights:  tensor([[[3.5325e-04],
         [3.2313e-02],
         [2.3629e-01],
         [1.6559e-02],
         [7.1449e-01]]], grad_fn=<UnsqueezeBackward0>)
