# Building an attention mechanism in an encoder-decoder model from scratch

In [None]:
import torch
from torch import nn

In [None]:
class Encoder(nn.Module):
    """Bidirectional LSTM encoder over embedded token ids.

    Args:
        vocab_size: Number of rows in the embedding table.
        embd_dropout: Dropout probability applied to the token embeddings.
        input_size: Embedding dimension, which is also the LSTM input size.
        hidden_size: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability passed to the LSTM; replaced by 0 when
            ``num_layers`` is 1.
    """

    def __init__(self, vocab_size, embd_dropout, input_size, hidden_size, num_layers, dropout):
        super().__init__()
        # Index 0 is the padding index; ``max_norm=1`` makes the embedding
        # renormalise looked-up rows whose norm exceeds 1.
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=input_size,
            padding_idx=0,
            max_norm=1,
        )
        # nn.LSTM only applies dropout between stacked layers, so a single
        # layer gets 0 to avoid passing a meaningless non-zero value.
        self.lstm_layers = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0,
            bidirectional=True,
        )
        self.embd_dropout = nn.Dropout(embd_dropout)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Embed ``x``, apply embedding dropout and run the bidirectional LSTM.

        Args:
            x: Integer token ids; batch-first, i.e. ``(batch, seq_len)`` for
                batched input, because the LSTM uses ``batch_first=True``.

        Returns:
            A tuple ``(output, h_n, c_n)`` exactly as produced by the LSTM:
            ``output`` holds the per-step features of the last layer
            (``2 * hidden_size`` wide because both directions are
            concatenated), and ``h_n`` / ``c_n`` are the final hidden and
            cell states with ``2 * num_layers`` entries along dim 0.
        """
        embeddings = self.embedding(x)
        embeddings = self.embd_dropout(embeddings)
        output, (h_n, c_n) = self.lstm_layers(embeddings)
        return output, h_n, c_n
    
class Decoder(nn.Module):
    """Unidirectional LSTM decoder producing per-step vocabulary scores.

    Args:
        vocab_size: Size of the embedding table and of the output projection.
        input_size: Embedding dimension, which is also the LSTM input size.
        hidden_size: LSTM hidden size.
        num_layers: Number of stacked LSTM layers.
        dropout: Dropout probability passed to the LSTM; replaced by 0 when
            ``num_layers`` is 1.
    """

    def __init__(self, vocab_size, input_size, hidden_size, num_layers, dropout):
        super().__init__()
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=input_size,
            padding_idx=0,
            max_norm=1,
        )
        # Dropout is only kept when there is more than one LSTM layer.
        inter_layer_dropout = dropout if num_layers > 1 else 0
        self.lstm_layers = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
            bidirectional=False,
        )
        # Projects every LSTM step from hidden_size to vocab_size.
        self.out = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, tgt, hidden):
        """Decode ``tgt`` starting from the given LSTM state.

        Args:
            tgt: Integer target token ids, batch-first.
            hidden: Initial ``(h_0, c_0)`` pair handed straight to the LSTM.

        Returns:
            ``(scores, h, c)``: the linear projection of every LSTM step onto
            the vocabulary (no softmax applied), plus the final hidden and
            cell states.
        """
        token_vectors = self.embedding(tgt)
        step_features, final_state = self.lstm_layers(token_vectors, hidden)
        h, c = final_state
        scores = self.out(step_features)
        return scores, h, c

class EncoderDecoderAttention(nn.Module):
    """Sequence-to-sequence model: bidirectional encoder, bridge, decoder.

    The encoder's final hidden and cell states are projected down to the
    decoder's size and used as the decoder's initial state.

    NOTE: ``forward`` currently discards the encoder's per-step outputs, so
    no attention over the source sequence is computed yet.

    Args:
        vocab_size: Vocabulary size, used for both encoder and decoder.
        embd_dropout: Dropout probability on the encoder embeddings.
        input_size: Embedding dimension for encoder and decoder.
        hidden_size: LSTM hidden size (per direction in the encoder).
        num_layers: Number of LSTM layers in encoder and decoder.
        dropout: LSTM dropout probability for encoder and decoder.
    """

    def __init__(self, vocab_size, embd_dropout, input_size, hidden_size, num_layers, dropout):
        super().__init__()
        self.encoder = Encoder(
            vocab_size=vocab_size,
            embd_dropout=embd_dropout,
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )
        # One bridge projection per LSTM layer, mapping the concatenated
        # two-direction state (2 * hidden_size) to hidden_size.
        bridge_layers = [
            nn.Linear(in_features=hidden_size * 2, out_features=hidden_size)
            for _ in range(num_layers)
        ]
        self.linear = nn.ModuleList(bridge_layers)

        self.decoder = Decoder(
            vocab_size=vocab_size,
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            dropout=dropout,
        )

    def combine_bidirectional_outputs(self, matrix, linear):
        """Merge consecutive pairs of ``matrix`` entries through ``linear``.

        Entries ``2*i`` and ``2*i + 1`` along the first dimension are
        concatenated on the last dimension, passed through ``linear[i]`` and
        squashed with ``tanh``.

        Args:
            matrix: Stacked states, indexed along the first dimension.
            linear: Indexable collection of layers, one per pair.

        Returns:
            The merged entries stacked along a new first dimension; its
            length is ``len(matrix) // 2``.
        """
        # Pair up even- and odd-indexed entries; zip stops at the shorter
        # slice, so a trailing unpaired entry is ignored.
        state_pairs = zip(matrix[0::2], matrix[1::2])
        merged = [
            torch.tanh(linear[layer_idx](torch.cat([first, second], dim=-1)))
            for layer_idx, (first, second) in enumerate(state_pairs)
        ]
        return torch.stack(merged)

    def forward(self, x: torch.Tensor, tgt: torch.Tensor) -> torch.Tensor:
        """Encode ``x`` and decode ``tgt`` from the bridged encoder state.

        Args:
            x: Source token ids fed to the encoder.
            tgt: Target token ids fed to the decoder.

        Returns:
            The decoder's per-step vocabulary scores.
        """
        # The encoder's per-step outputs are not used.
        _, enc_hidden, enc_cell = self.encoder(x)

        # The same projection layers are applied to hidden and cell states.
        init_hidden = self.combine_bidirectional_outputs(matrix=enc_hidden, linear=self.linear)
        init_cell = self.combine_bidirectional_outputs(matrix=enc_cell, linear=self.linear)

        scores, _, _ = self.decoder(tgt, (init_hidden, init_cell))
        return scores
        

In [None]:
# Build the demo model; the encoder and decoder share every hyper-parameter.
model = EncoderDecoderAttention(
    vocab_size=10000,
    embd_dropout=0.2,
    input_size=512,
    hidden_size=256,
    num_layers=2,
    dropout=0.25,
)

In [65]:
from torchinfo import summary

# forward() takes both source and target token ids, so torchinfo needs one
# input size and one dtype per argument; a single size leaves ``tgt`` unset.
summary(
    model=model,
    input_size=[(64, 1024), (64, 1024)],
    dtypes=[torch.long, torch.long],
    col_names=["input_size", "output_size", "num_params"],
)

torch.Size([2, 64, 256]) torch.Size([2, 64, 256])


Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
EncoderDecoderAttention                  [64, 1024]                --                        --
├─Encoder: 1-1                           [64, 1024]                [64, 1024, 512]           --
│    └─Embedding: 2-1                    [64, 1024]                [64, 1024, 512]           5,120,000
│    └─Dropout: 2-2                      [64, 1024, 512]           [64, 1024, 512]           --
│    └─LSTM: 2-3                         [64, 1024, 512]           [64, 1024, 512]           3,153,920
├─ModuleList: 1-2                        --                        --                        --
│    └─Linear: 2-4                       [64, 512]                 [64, 256]                 131,328
│    └─Linear: 2-5                       [64, 512]                 [64, 256]                 131,328
│    └─Linear: 2-6                       [64, 512]                 [64, 256]                 (recursive)
│ 