<a href="https://colab.research.google.com/github/AkHiLdEvGoD/DeepLearning-Algorithms/blob/main/Bahdanau_Attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [43]:
import torch
import torch.nn as nn

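Notation used in the shape comments below:

batch_size = B

src_len = S (length of the source sentence)

trg_len = T (length of the target sentence)

emb_dim = E (embedding size)

enc_hidden_dim = H (encoder GRU hidden size per direction)

dec_hidden_dim = D (decoder GRU hidden size)

output_dim = V (target vocabulary size)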

In [44]:
class Encoder(nn.Module):
  """Bidirectional GRU encoder over embedded source tokens.

  Args:
    input_dims: size of the source vocabulary (rows of the embedding table).
    embd_dims: embedding size E.
    hidden_dims: GRU hidden size H per direction.
    dropout: dropout probability applied to the embeddings.
  """

  def __init__(self, input_dims, embd_dims, hidden_dims, dropout):
    super().__init__()
    # Submodule order fixes state_dict key order and initialisation order;
    # keep embedding -> gru -> dropout.
    self.embedding = nn.Embedding(input_dims, embd_dims)
    self.gru = nn.GRU(
        embd_dims,
        hidden_dims,
        bidirectional=True,
        batch_first=True,
    )
    self.dropout = nn.Dropout(dropout)

  def forward(self, X):
    """Encode a batch of token ids.

    Args:
      X: token ids, shape [B, S] (batch_first).

    Returns:
      (outputs, hidden): per-step outputs [B, S, 2H] with forward and
      backward states concatenated, and final hidden states [2, B, H].
    """
    tokens = self.embedding(X)        # [B,S,E]
    gru_in = self.dropout(tokens)     # [B,S,E]
    return self.gru(gru_in)           # ([B,S,2H], [2,B,H])

In [45]:
class BahdanauAtttention(nn.Module):
  """Additive (Bahdanau) attention over encoder outputs.

  score(s, h_j) = v^T tanh(W [s; h_j]), normalised with a softmax over the
  source positions. The class name keeps its existing spelling so current
  references keep working.

  Args:
    enc_hidden_dims: encoder hidden size H per direction; encoder outputs are
      expected to carry 2H features (bidirectional encoder).
    dec_hidden_dims: decoder hidden size D, also used as the attention size.
  """

  def __init__(self, enc_hidden_dims, dec_hidden_dims):
    super().__init__()
    self.attention = nn.Linear(enc_hidden_dims * 2 + dec_hidden_dims, dec_hidden_dims)
    self.v = nn.Parameter(torch.rand(dec_hidden_dims))

  def forward(self, hidden_dec, enc_outs):
    """Compute attention weights.

    Args:
      hidden_dec: previous decoder hidden state, shape [B, D].
      enc_outs: encoder outputs, shape [B, S, 2H].

    Returns:
      Attention weights of shape [B, S]; each row sums to 1.
    """
    seq_len = enc_outs.shape[1]
    # Broadcast the decoder state to every source position without copying:
    # [B,D] -> [B,1,D] -> [B,S,D].
    hidden_dec = hidden_dec.unsqueeze(1).expand(-1, seq_len, -1)

    # torch.cat takes a sequence of tensors and the keyword `dim`;
    # concatenate along features: [B,S,D+2H] -> [B,S,D].
    energy = torch.tanh(self.attention(torch.cat((hidden_dec, enc_outs), dim=2)))

    scores = torch.matmul(energy, self.v)       # [B,S,D] @ [D] -> [B,S]
    return torch.softmax(scores, dim=1)         # [B,S]

In [47]:
class Decoder(nn.Module):
  """Single-step GRU decoder with attention over the encoder outputs.

  Args:
    output_dims: target vocabulary size V.
    embd_dims: embedding size E.
    dropout: dropout probability applied to the target embeddings.
    enc_hidden_dims: encoder hidden size H per direction (encoder outputs
      carry 2H features).
    dec_hidden_dims: decoder hidden size D.
    attention: callable invoked as ``attention(hidden, enc_outs)`` that
      returns weights of shape [B, S].
  """

  def __init__(self, output_dims, embd_dims, dropout, enc_hidden_dims, dec_hidden_dims, attention):
    super().__init__()
    self.output_dims = output_dims
    self.attention = attention
    self.embedding = nn.Embedding(output_dims, embd_dims)
    self.gru = nn.GRU(2 * enc_hidden_dims + embd_dims, dec_hidden_dims, batch_first=True)
    self.fc = nn.Linear(2 * enc_hidden_dims + embd_dims + dec_hidden_dims, output_dims)
    self.dropouts = nn.Dropout(dropout)

  def forward(self, Y_t, hidden, enc_outs):
    """Decode one time step.

    Args:
      Y_t: previous target token ids, shape [B].
      hidden: previous decoder hidden state, shape [B, D].
      enc_outs: encoder outputs, shape [B, S, 2H].

    Returns:
      (pred, hidden): logits over the target vocabulary [B, V] and the new
      decoder hidden state [B, D].
    """
    embedded = self.dropouts(self.embedding(Y_t.unsqueeze(1)))   # [B,1,E]

    # Attention takes (decoder state, encoder outputs) in that order.
    weights = self.attention(hidden, enc_outs)                   # [B,S]
    context = torch.bmm(weights.unsqueeze(1), enc_outs)          # [B,1,S] x [B,S,2H] -> [B,1,2H]

    gru_input = torch.cat((context, embedded), dim=2)            # [B,1,2H+E]
    # nn.GRU expects h_0 as [num_layers, B, D]: add the layer axis in front.
    outputs, hidden = self.gru(gru_input, hidden.unsqueeze(0))   # [B,1,D], [1,B,D]

    fc_input = torch.cat(
        (outputs.squeeze(1), context.squeeze(1), embedded.squeeze(1)),
        dim=1,
    )                                                            # [B,D+2H+E]
    pred = self.fc(fc_input)                                     # [B,V]
    return pred, hidden.squeeze(0)                               # [B,V], [B,D]

In [64]:
class Seq2Seq(nn.Module):
  """Encoder-decoder wrapper that unrolls the decoder with teacher forcing.

  Args:
    encoder: module returning (enc_outs [B,S,2H], hidden [2,B,H]) for src.
    decoder: module with an ``output_dims`` attribute, called as
      ``decoder(token_ids [B], hidden [B,D], enc_outs)`` and returning
      (logits [B,V], hidden [B,D]).
  """

  def __init__(self, encoder, decoder):
    super().__init__()
    self.encoder = encoder
    self.decoder = decoder
    # The joined final encoder states have 2H features but the decoder state
    # has D. When both sizes are readable from the GRUs and differ, learn a
    # projection; otherwise pass the states through unchanged (no params).
    enc_gru = getattr(encoder, "gru", None)
    dec_gru = getattr(decoder, "gru", None)
    if (
        enc_gru is not None
        and dec_gru is not None
        and 2 * enc_gru.hidden_size != dec_gru.hidden_size
    ):
      self.bridge = nn.Linear(2 * enc_gru.hidden_size, dec_gru.hidden_size)
    else:
      self.bridge = nn.Identity()

  def forward(self, src, trg, teacher_forcing_ratio=0.5):
    """Run the full encode/decode loop.

    Args:
      src: source token ids, shape [B, S].
      trg: target token ids, shape [B, T]; ``trg[:, 0]`` is used as the first
        decoder input (expected to be the <sos> token).
      teacher_forcing_ratio: per-step probability of feeding the ground-truth
        token instead of the model's own argmax prediction.

    Returns:
      Logits of shape [B, T, V]. Position 0 stays zero because no prediction
      is made for the first target token.
    """
    batch_size, trg_len = trg.shape[0], trg.shape[1]
    trg_vocab_size = self.decoder.output_dims

    enc_outs, hidden = self.encoder(src)                                 # [B,S,2H], [2,B,H]
    # Allocate on the activations' device/dtype so the result works on GPU.
    outputs = enc_outs.new_zeros(batch_size, trg_len, trg_vocab_size)    # [B,T,V]

    # Join the last forward and backward states, then map 2H -> D.
    hidden = self.bridge(torch.cat((hidden[-2], hidden[-1]), dim=1))     # [B,D]
    dec_input = trg[:, 0]

    for t in range(1, trg_len):
      output, hidden = self.decoder(dec_input, hidden, enc_outs)         # [B,V], [B,D]
      outputs[:, t] = output

      teacher_force = torch.rand(1).item() < teacher_forcing_ratio
      dec_input = trg[:, t] if teacher_force else output.argmax(1)

    return outputs                                                       # [B,T,V]