# Bahdanau Attention (Additive Attention)

One of the motivations behind the Bahdanau attention approach was the limitation imposed by the fixed-length context vector in the basic encoder–decoder approach. This limitation makes the basic encoder–decoder approach underperform on long sentences. In the basic encoder–decoder approach, the last element of a sequence contains the memory of all the previous elements and thus forms a fixed-dimension context vector. In the Bahdanau attention approach, by contrast:

- First, we initialize the Decoder states by using the last states of the Encoder as usual
- Then at each decoding time step:
    - We use all of the Encoder's hidden states and the Decoder's previous output to calculate a Context Vector by applying an Attention Mechanism.
    - Lastly, we concatenate the Context Vector with the Decoder's previous output to create the input to the Decoder.

In [1]:
import random

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Encoder(nn.Module):
    ''' Bidirectional multi-layer GRU encoder over concatenated text, speaker and
        tag embeddings.

        Besides the per-timestep outputs of the top GRU layer, it produces one
        initial hidden state per decoder layer by summing the forward and backward
        final states of each encoder layer and projecting the sum to n_hidden_dec.
    '''

    def __init__(self, n_vocab, n_speaker, n_tags, n_embed_text, n_embed_speaker, n_embed_tags, n_hidden_enc, n_layers, n_hidden_dec, dropout):
        super().__init__()
        self.n_layers = n_layers
        self.n_hidden_enc = n_hidden_enc
        self.text_embedding = nn.Embedding(n_vocab, n_embed_text)
        self.tag_embedding = nn.Embedding(n_tags, n_embed_tags)
        self.speaker_embedding = nn.Embedding(n_speaker, n_embed_speaker)

        self.bdGRU = nn.GRU(n_embed_text + n_embed_speaker + n_embed_tags,
                            n_hidden_enc,
                            n_layers,
                            bidirectional=True, batch_first=True,
                            dropout=dropout)

        # Projects the summed (forward + backward) final state of each layer
        # from the encoder hidden size to the decoder hidden size.
        self.fc = nn.Linear(n_hidden_enc, n_hidden_dec)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x, s, t, h):
        ''' PARAMS:
                x, s, t: [b, seq_len]  token, speaker and tag indices
                h:       [2*n_layers, b, n_hidden_enc]  initial GRU state (see init_hidden)

            RETURN:
                last_layer_enc: [b, seq_len, 2*n_hidden_enc]  top-layer outputs for every timestep
                last_h_enc_sum: [b, n_layers, n_hidden_dec]   initial decoder state, one row per layer
        '''

        # [b, seq_len, n_embed_text + n_embed_speaker + n_embed_tags]
        embedded = torch.cat((self.text_embedding(x),
                              self.speaker_embedding(s),
                              self.tag_embedding(t)), dim=2)

        last_layer_enc, last_h_enc = self.bdGRU(embedded, h)

        last_layer_enc = self.dropout(last_layer_enc)             #[b, seq_len, 2*n_hidden_enc]
        last_h_enc = self.dropout(last_h_enc.permute(1, 0, 2))    #[b, 2*n_layers, n_hidden_enc]

        # nn.GRU stacks its final states as (layer0_fwd, layer0_bwd, layer1_fwd, ...),
        # so grouping consecutive pairs and summing over the direction axis yields
        # one state per layer, in layer order. reshape (not view) is required
        # because the permuted tensor is not contiguous.
        batch_size = last_h_enc.size(0)
        last_h_enc_sum = last_h_enc.reshape(batch_size, self.n_layers, 2, self.n_hidden_enc).sum(dim=2)

        last_h_enc_sum = self.fc(last_h_enc_sum)                  #[b, n_layers, n_hidden_dec]

        return last_layer_enc, last_h_enc_sum

    def init_hidden(self, batch_size):
        ''' Return an all-zero initial state [2*n_layers, batch_size, n_hidden_enc]
            with the same dtype and device as the module's parameters. '''
        weight = next(self.parameters())

        # Follow the model's own device instead of guessing from CUDA availability,
        # so a CPU model on a CUDA-capable machine still gets a CPU state.
        return torch.zeros(2 * self.n_layers, batch_size, self.n_hidden_enc,
                           dtype=weight.dtype, device=weight.device)

In [3]:
class Attention(nn.Module):
    ''' Additive (Bahdanau) attention: scores every encoder timestep against the
        top-layer decoder hidden state and normalises the scores with a softmax. '''

    def __init__(self, n_hidden_enc, n_hidden_dec):

        super().__init__()

        # Attribute names kept as in the original for backward compatibility.
        self.h_hidden_enc = n_hidden_enc
        self.h_hidden_dec = n_hidden_dec

        # W mixes [decoder state ; encoder output]; V reduces each mixed vector to a scalar score.
        self.W = nn.Linear(2*n_hidden_enc + n_hidden_dec, n_hidden_dec, bias=False)
        self.V = nn.Parameter(torch.rand(n_hidden_dec))

    def forward(self, hidden_dec, last_layer_enc):
        '''
            PARAMS:
                hidden_dec:     [b, n_layers, n_hidden_dec]    (only the last layer, index -1, is used)
                last_layer_enc: [b, seq_len, n_hidden_enc * 2]

            RETURN:
                att_weights:    [b, src_seq_len]  (each row sums to 1)
        '''

        src_seq_len = last_layer_enc.size(1)

        # Broadcast the top-layer decoder state across every source position.
        query = hidden_dec[:, -1, :].unsqueeze(1).expand(-1, src_seq_len, -1)       #[b, src_seq_len, n_hidden_dec]

        tanh_W_s_h = torch.tanh(self.W(torch.cat((query, last_layer_enc), dim=2)))  #[b, src_seq_len, n_hidden_dec]

        # Dot every position's vector with V; equivalent to the batched bmm with a repeated V.
        e = torch.matmul(tanh_W_s_h, self.V)           #[b, src_seq_len]

        att_weights = torch.softmax(e, dim=1)          #[b, src_seq_len]

        return att_weights

In [4]:
class Decoder(nn.Module):
    ''' Single-step attention decoder: embeds the previous target token, attends
        over the encoder outputs, runs one GRU step and projects to output logits. '''

    def __init__(self, n_output, n_embed, n_hidden_enc, n_hidden_dec, n_layers, dropout):
        super().__init__()

        self.n_output = n_output
        self.embedding = nn.Embedding(n_output, n_embed)

        # GRU input is [target embedding ; context vector].
        self.GRU = nn.GRU(n_embed + n_hidden_enc*2, n_hidden_dec, n_layers,
                          bidirectional=False, batch_first=True,
                          dropout=dropout)

        self.attention = Attention(n_hidden_enc, n_hidden_dec)

        # Final projection sees the embedding, the context vector and the GRU output.
        self.fc_final = nn.Linear(n_embed + n_hidden_enc*2 + n_hidden_dec, n_output)

        # NOTE(review): this dropout layer is created but never applied in forward.
        self.dropout = nn.Dropout(dropout)

    def forward(self, target, hidden_dec, last_layer_enc):
        '''
            PARAMS:
                target:         [b]  previous target token indices
                hidden_dec:     [b, n_layers, n_hidden_dec]      (1st hidden_dec comes from the encoder)
                last_layer_enc: [b, seq_len, n_hidden_enc * 2]   (* 2 since bi-directional)
                --> not to be confused with encoder's last hidden state. last_layer_enc is ALL
                    hidden states (of timesteps 0,1,...,t-1) of the last LAYER.

            RETURN:
                output:      [b, n_output]                unnormalised scores for the next token
                last_h_dec:  [b, n_layers, n_hidden_dec]  updated decoder state
                att_weights: [b, 1, input_seq_len]        attention distribution used for this step
        '''

        ######################## 1. TARGET EMBEDDINGS #########################
        target = target.unsqueeze(1)             #[b, 1] : a single decoding step
        embedded_trg = self.embedding(target)    #[b, 1, n_embed]

        ################## 2. CALCULATE ATTENTION WEIGHTS #####################
        att_weights = self.attention(hidden_dec, last_layer_enc)  #[b, input_seq_len]
        att_weights = att_weights.unsqueeze(1)                    #[b, 1, input_seq_len]

        ###################### 3. CALCULATE WEIGHTED SUM ######################
        weighted_sum = torch.bmm(att_weights, last_layer_enc)     #[b, 1, n_hidden_enc*2]

        ############################# 4. GRU LAYER ############################
        gru_input = torch.cat((embedded_trg, weighted_sum), dim=2) #[b, 1, n_embed + n_hidden_enc*2]

        # nn.GRU wants h0 as [n_layers, b, n_hidden_dec]; the permuted view is not
        # contiguous, which the cuDNN backend rejects, so materialise it first.
        h0 = hidden_dec.permute(1, 0, 2).contiguous()
        last_layer_dec, last_h_dec = self.GRU(gru_input, h0)
        # last_layer_dec: [b, 1, n_hidden_dec]
        last_h_dec = last_h_dec.permute(1, 0, 2)  #[b, n_layers, n_hidden_dec]

        ########################### 5. FINAL FC LAYER #########################
        fc_in = torch.cat((embedded_trg.squeeze(1),           #[b, n_embed]
                           weighted_sum.squeeze(1),           #[b, n_hidden_enc*2]
                           last_layer_dec.squeeze(1)), dim=1) #[b, n_hidden_dec]

        output = self.fc_final(fc_in) #[b, n_output]

        return output, last_h_dec, att_weights

In [5]:
class Seq2Seq(nn.Module):
    ''' Encoder-decoder wrapper that decodes the target sequence one step at a
        time with attention, mixing teacher forcing and the model's own predictions. '''

    def __init__(self, n_vocab, n_speaker, n_tags,
                       n_embed_text, n_embed_speaker, n_embed_tags, n_embed_dec,
                       n_hidden_enc, n_hidden_dec, n_layers,
                       n_output, dropout):

        super().__init__()

        self.encoder = Encoder(n_vocab=n_vocab, n_speaker=n_speaker, n_tags=n_tags,
                               n_embed_text=n_embed_text, n_embed_speaker=n_embed_speaker, n_embed_tags=n_embed_tags,
                               n_hidden_enc=n_hidden_enc, n_layers=n_layers, n_hidden_dec=n_hidden_dec,
                               dropout=dropout)

        self.decoder = Decoder(n_output=n_output,
                               n_embed=n_embed_dec,
                               n_hidden_enc=n_hidden_enc, n_hidden_dec=n_hidden_dec, n_layers=n_layers,
                               dropout=dropout)

    def forward(self, inputs, speakers, tags, targets, tf_ratio=0.5):
        ''' PARAMS:
                inputs, speakers, tags: [b, input_seq_len]  index tensors fed to the encoder
                targets:                [b, trg_seq_len]    target indices; column 0 seeds the decoder
                tf_ratio:               probability of feeding the ground-truth token
                                        (teacher forcing) instead of the model's prediction

            RETURN:
                outputs: [b, n_output, trg_seq_len]  per-step scores; slot 0 stays all zeros
                         because decoding starts at t = 1
        '''

        ###########################  1. ENCODER  ##############################
        h = self.encoder.init_hidden(inputs.size(0))

        last_layer_enc, last_h_enc = self.encoder(inputs, speakers, tags, h)

        ###########################  2. DECODER  ##############################
        hidden_dec = last_h_enc       #[b, n_layers, n_hidden_dec]

        trg_seq_len = targets.size(1)

        b = inputs.size(0)
        n_output = self.decoder.n_output
        output = targets[:, 0]

        # Allocate on the device the encoder outputs live on instead of forcing CUDA.
        outputs = torch.zeros(b, n_output, trg_seq_len, device=last_layer_enc.device)

        for t in range(1, trg_seq_len):
            # The decoder also returns its attention weights; they are not needed here.
            output, hidden_dec, _ = self.decoder(output, hidden_dec, last_layer_enc)
            outputs[:, :, t] = output #output: [b, n_output]

            if random.random() < tf_ratio:
                output = targets[:, t]          # teacher forcing
            else:
                output = output.argmax(dim=1)   # greedy: feed back the predicted index

        return outputs  #[b, n_output, trg_seq_len]

## References

- [The Power of Attention in Deep Learning](https://www.youtube.com/watch?v=Qu81irGlR-0)
- [Neural Machine Translation by Jointly Learning to Align and Translate](https://arxiv.org/pdf/1409.0473.pdf)
- [The Bahdanau Attention Mechanism](https://machinelearningmastery.com/the-bahdanau-attention-mechanism/#:~:text=The%20Bahdanau%20attention%20was%20proposed,mechanism%20for%20neural%20machine%20translation.)