# Building attention mechanism in Encoder-decoder from Scratch

In [1]:
import torch
from torch import nn
import pandas as pd

In [None]:
class Encoder(nn.Module):
    def __init__(self, input_size, hidden_size, num_layers, dropout):
        super().__init__()
        self.lstm_layers = nn.LSTM(input_size=input_size, hidden_size=hidden_size, num_layers=num_layers, batch_first=True, dropout=dropout if num_layers>1 else 0, bidirectional=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        output, (h_n, c_n) = self.lstm_layers(x)
        return output, h_n, c_n
    
class Decoder(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self):
        pass

class EncoderDecoderAttention(nn.Module):
    def __init__(self, vocab_size, d_hidden, hidden_size, num_layers, dropout):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(num_embeddings=vocab_size, embedding_dim=d_hidden, padding_idx=0, max_norm=1)
        self.encoder = Encoder(input_size=512, hidden_size=hidden_size, num_layers=num_layers, dropout=0.2)
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.ModuleList([
            nn.Linear(in_features=self.hidden_size*2, out_features=self.hidden_size) for i in range(num_layers)
        ])

    def context(self, matrix, linear):
        m_new = []

        for i in range(len(matrix)//2):
            m_first = matrix[2*i]
            m_second = matrix[2*i+1]
            m_n_new = torch.cat([m_first, m_second], dim=-1)
            m_n_new = torch.tanh(linear[i](m_n_new))
            m_new.append(m_n_new)
        m_new = torch.stack(m_new)

        return m_new

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        embeddings = self.embedding(x)
        embeddings = self.dropout(embeddings)
        print(embeddings.shape)
        output, h_n, c_n = self.encoder(embeddings)
        print(output.shape, h_n.shape, c_n.shape)

        h_new = self.context(matrix=h_n, linear=self.linear)
        c_new = self.context(matrix=c_n, linear=self.linear)

        print(h_new.shape, c_new.shape)
        
        

In [53]:
model = EncoderDecoderAttention(vocab_size=10000, d_hidden=512, hidden_size=256, num_layers=2, dropout=0.25)

In [54]:
from torchinfo import summary
summary(model=model, input_size=(64, 1024), dtypes=[torch.long], col_names=["input_size", "output_size", "num_params"])

torch.Size([64, 1024, 512])
torch.Size([64, 1024, 512]) torch.Size([4, 64, 256]) torch.Size([4, 64, 256])
torch.Size([2, 64, 256]) torch.Size([2, 64, 256])


Layer (type:depth-idx)                   Input Shape               Output Shape              Param #
EncoderDecoderAttention                  [64, 1024]                --                        --
├─Embedding: 1-1                         [64, 1024]                [64, 1024, 512]           5,120,000
├─Dropout: 1-2                           [64, 1024, 512]           [64, 1024, 512]           --
├─Encoder: 1-3                           [64, 1024, 512]           [64, 1024, 512]           --
│    └─LSTM: 2-1                         [64, 1024, 512]           [64, 1024, 512]           3,153,920
├─ModuleList: 1-4                        --                        --                        --
│    └─Linear: 2-2                       [64, 512]                 [64, 256]                 131,328
│    └─Linear: 2-3                       [64, 512]                 [64, 256]                 131,328
│    └─Linear: 2-4                       [64, 512]                 [64, 256]                 (recursive)
│ 