In [4]:
import torch
import torch.nn as nn

Attention layer class (reference: https://www.geeksforgeeks.org/nlp/adding-attention-layer-to-a-bi-lstm)

In [None]:
class AttentionLayer(nn.Module):
    """Gate every timestep of a sequence with a learned weight in (0, 1).

    A small scoring network (Linear -> Tanh -> Linear) produces one scalar per
    timestep; the scalar is passed through a sigmoid and used to scale that
    timestep's feature vector. Each timestep is weighted independently, so the
    weights are not normalised across the sequence.

    Args:
        input_dim: size of the last dimension of the incoming features.
        attention_dim: width of the hidden layer of the scoring network.
    """

    def __init__(self, input_dim, attention_dim):
        super().__init__()
        # Build the two linear maps in the same order they appear in the
        # Sequential so seeded initialisation stays reproducible.
        projection = nn.Linear(input_dim, attention_dim)
        scorer = nn.Linear(attention_dim, 1)
        self.attention = nn.Sequential(projection, nn.Tanh(), scorer)

    def forward(self, inputs):
        """Return ``(weighted_sequence, weights)`` for ``inputs``.

        ``inputs`` is expected as (batch_size, seq_len, input_dim). The
        weighted sequence has the same shape as ``inputs``; the weights have
        shape (batch_size, seq_len, 1).
        """
        # One raw score per timestep, squashed into the 0..1 range.
        gate = self.attention(inputs).sigmoid()  # (batch_size, seq_len, 1)

        # Broadcast the gate over the feature dimension. Summing the result
        # over seq_len would collapse it to a single vector per sequence.
        gated_seq = torch.mul(inputs, gate)  # (batch_size, seq_len, input_dim)
        return gated_seq, gate

Bidirectional LSTM class with optional attention and optional autoregressive input

In [None]:
class BLSTMWithAttention(nn.Module):
    """(Bi)LSTM sequence model with optional attention and autoregressive input.

    The input sequence is run through an LSTM, optionally re-weighted per
    timestep by an ``AttentionLayer``, passed through dropout and projected to
    ``output_dim`` by a linear layer applied at every timestep.

    Args:
        input_dim: feature size of each input timestep.
        hidden_dim: LSTM hidden size (per direction).
        output_dim: size of the per-timestep prediction.
        num_layers: number of stacked LSTM layers.
        bidirectional: run the LSTM in both directions.
        dropout: dropout probability applied before the final linear layer.
        use_attention: apply the attention layer to the LSTM output.
        attention_dim: hidden width of the attention scoring network.
        autoregress: if True, ``forward`` concatenates a 2-feature
            ``past_status`` tensor to every input timestep.
    """

    def __init__(self, 
                 input_dim,                         # bert embedding dim
                 hidden_dim,                        # lstm hidden dim
                 output_dim,                        # number of classes
                 num_layers=1,                      # lstm layers
                 bidirectional=True,                # bidirectional lstm
                 dropout=0.5,                       # dropout rate
                 use_attention=True,                # use attention mechanism
                 attention_dim=128,                 # attention layer dimension
                 autoregress=False                  # autoregressive flag
                 ):
        super(BLSTMWithAttention, self).__init__()
        self.use_attention = use_attention
        self.autoregress = autoregress

        # In autoregressive mode the 2 past-status features are concatenated
        # to every timestep, so the LSTM sees a wider input.
        input_size = input_dim + 2 if autoregress else input_dim
        self.blstm = nn.LSTM(input_size=input_size,
                             hidden_size=hidden_dim,
                             num_layers=num_layers,
                             batch_first=True,
                             bidirectional=bidirectional)

        # A bidirectional LSTM concatenates both directions on the feature axis.
        lstm_output_dim = hidden_dim * 2 if bidirectional else hidden_dim

        if self.use_attention:
            self.attention_layer = AttentionLayer(lstm_output_dim, attention_dim)

        self.fc = nn.Linear(lstm_output_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, past_status=None, hidden=None):
        """Run the model on a batch of sequences.

        Args:
            x: input of shape (batch_size, seq_len, input_dim).
            past_status: required when ``autoregress`` is True; it is
                concatenated with ``x`` on the last dimension, so it must match
                ``x`` on the first two dimensions and have 2 features.
            hidden: optional initial LSTM state ``(h_0, c_0)``. ``None`` lets
                the LSTM start from zeros.

        Returns:
            A tuple ``(prediction, next_state, weight)`` where ``prediction``
            is (batch_size, seq_len, output_dim), ``next_state`` is the LSTM's
            final ``(h_n, c_n)`` and ``weight`` is the attention weights, or
            ``None`` when attention is disabled.

        Raises:
            ValueError: if ``autoregress`` is True and ``past_status`` is None.
        """
        if self.autoregress:
            if past_status is None:
                raise ValueError("past_status must be provided for autoregressive mode")

            # concat the past status with current input
            x = torch.cat((past_status, x), dim=2)

        # Both modes share one LSTM call so `next_state` is always bound.
        # Previously the non-autoregressive branch stored the state in `weight`
        # and left `next_state` undefined, which crashed at the return.
        out, next_state = self.blstm(x, hidden)

        # Keep the attention weights instead of discarding them, so the third
        # return value is defined on every path.
        weight = None
        if self.use_attention:
            out, weight = self.attention_layer(out)
        out = self.dropout(out)
        prediction = self.fc(out)

        return prediction, next_state, weight
        
        
        