In [None]:
# !pip install torch==2.3.0 torchtext==0.18.0
import torch
import torch.nn as nn

In [None]:
from dataclasses import dataclass


@dataclass
class ModelArgs:
    """Global hyper-parameters for the model, read as class attributes.

    None of the attributes below carry type annotations, so ``@dataclass``
    generates no fields for them: they are ordinary class attributes and are
    meant to be read as ``ModelArgs.<name>`` rather than via an instance.
    """

    # Use the GPU when one is present, otherwise fall back to the CPU so that
    # layers built with ``device=ModelArgs.device`` can still be constructed.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    no_of_neurons = 128
    block_size = 32
    batch_size = 32
    # Placeholders; presumably assigned once the vocabularies are built —
    # confirm against the data-loading code.
    en_vocab_size = None
    de_vocab_size = None
    dropout = 0.1
    epoch = 50
    max_lr = 1e-4
    embedding_dims = 1024
    num_layers = 4
    # Evaluated once when the class body runs: reassigning
    # ``ModelArgs.embedding_dims`` later does NOT update this value.
    hidden_dim = 4*embedding_dims

In [None]:
class BandhanauAttention(nn.Module):
    """Additive (Bahdanau-style) attention over a sequence of hidden states.

    Each timestep ``j`` of ``ht`` is scored against the previous decoder
    state ``st_1`` with a two-layer MLP::

        score_j = W2 · tanh(W1 · [st_1 ; ht_j] + b1) + b2

    The scores are normalised with a softmax over the sequence axis (dim 1).
    The context vector is the attention-weighted sum of ``ht`` over that axis.

    Layer sizes are taken from ``ModelArgs``. The concatenation of ``st_1``
    and ``ht`` along the last axis must be ``2 * ModelArgs.no_of_neurons``
    wide, which is typically each being ``ModelArgs.no_of_neurons`` wide.

    The class name keeps its original spelling so existing callers still work.
    """

    def __init__(self, hidden_size):
        """Build the scoring MLP.

        Args:
            hidden_size: Accepted for backward compatibility but not used;
                every layer size is read from ``ModelArgs``.
        """
        super().__init__()
        self.linear_layer_1 = nn.Linear(
            2 * ModelArgs.no_of_neurons, ModelArgs.hidden_dim, device=ModelArgs.device
        )
        # One scalar score per timestep: the softmax in ``forward`` must rank
        # timesteps, so the final projection has a single output unit.
        self.linear_layer_2 = nn.Linear(ModelArgs.hidden_dim, 1, device=ModelArgs.device)

    def forward(self, st_1, ht):
        """Compute the context vector and attention weights.

        Args:
            st_1: Previous decoder state, shape ``(batch, 1, H_s)`` or
                ``(batch, H_s)``.
            ht: Hidden states to attend over, shape ``(batch, seq_len, H_h)``,
                with ``H_s + H_h == 2 * ModelArgs.no_of_neurons``.

        Returns:
            A tuple ``(context_vector, attention_weights)``.

            - ``context_vector`` has shape ``(batch, H_h)``. It is the sum over
              ``seq_len`` of ``attention_weights * ht``.
            - ``attention_weights`` has shape ``(batch, seq_len, 1)``. It is
              non-negative and sums to 1 over ``seq_len``.
        """
        if st_1.dim() == 2:
            # Allow a flat decoder state by giving it a singleton time axis.
            st_1 = st_1.unsqueeze(1)
        # Repeat the decoder state for every timestep so it can be paired
        # with each row of ``ht``.
        st_1 = st_1.expand(-1, ht.shape[1], -1)
        combined = torch.cat([st_1, ht], dim=-1)
        scores = self.linear_layer_2(torch.tanh(self.linear_layer_1(combined)))
        attention_weights = torch.softmax(scores, dim=1)
        # Weight the hidden states (not a constant) and pool over time.
        context_vector = torch.sum(attention_weights * ht, dim=1)
        return context_vector, attention_weights