In [1]:
import torch.nn as nn
import torch 
import torch.nn.functional as F 
import numpy as np

In [2]:
class PositionalEncoding(nn.Module):
    def __init__(self,
                 d_model: int,
                 max_seq_len: int):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len

    def forward(self) -> torch.tensor:
        pos = torch.arange(0, self.max_seq_len)
        denominator = torch.arange(0, self.d_model, 2)
        denominator = torch.pow(10_000, denominator/self.d_model)

        pos = pos.reshape(-1, 1)
        denominator = denominator.reshape(1, -1)
        even_pos = torch.sin(pos / denominator)
        odd_pos = torch.cos(pos / denominator)

        PE = torch.stack([even_pos, odd_pos], dim=2)
        PE = torch.flatten(PE, start_dim=1, end_dim=2)
        return PE


class MultiHeadAttention(nn.Module):
    def __init__(self,
                 input_dim: int,
                 d_model: int,
                 n_head: int):
        super().__init__()
        self.input_dim = input_dim
        self.d_model = d_model
        self.n_head = n_head
        self.h_dim = d_model // n_head
        self.qkv_layer = nn.Linear(input_dim, 3 * d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self,
                x: torch.tensor,
                mask: torch.tensor = None):
        B, sen_len, input_dim = x.size()
        qkv = self.qkv_layer(x)  # B, sen_len, 3 * d_model
        qkv = qkv.reshape(B, sen_len, self.n_head, self.h_dim * 3)
        qkv = qkv.permute(0, 2, 1, 3)
        q, k, v = qkv.chunk(3, dim=-1)
        d_k = q.size()[-1]
        att = (q @ k.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k))
        if mask is not None:
            att += mask
        att = F.softmax(att, dim=-1)
        new_emb = att @ v
        new_emb = new_emb.reshape(B, sen_len, self.n_head * self.h_dim)
        new_emb = self.linear_layer(new_emb)
        return att, new_emb


class LayerNormalization(nn.Module):
    def __init__(self, parameters_shape, eps=1e-5):
        super().__init__()
        self.parameters_shape = parameters_shape
        self.eps = eps
        self.gamma = nn.Parameter(torch.ones(parameters_shape))
        self.beta = nn.Parameter(torch.zeros(parameters_shape))

    def forward(self, x):
        dims = [-(i + 1) for i in range(len(self.parameters_shape))]
        mean = x.mean(dim=dims, keepdim=True)
        var = ((x - mean) ** 2).mean(dim=dims, keepdim=True)
        std = (var + self.eps).sqrt()
        y = (x - mean) / std
        out = self.gamma * y + self.beta
        return out


# what is being passed to encoder?
# well as per the original diagram we will be giving out vectors
# after adding pos encoding :)

class FeedForwardNetwork(nn.Module):
    def __init__(self,
                 d_model: int,
                 ffn_hidden: int,
                 drop_prob: float):
        super().__init__()
        self.l = nn.Sequential(
            nn.Linear(d_model, ffn_hidden),
            nn.ReLU(),
            nn.Dropout(drop_prob),
            nn.Linear(ffn_hidden, d_model),
            nn.ReLU(),
            nn.Dropout(drop_prob),
        )

    def forward(self,
                x):
        out = self.l(x)
        return out



class MultiHeadCrossAttention(nn.Module):
    def __init__(self,
                 d_model: int,
                 n_head: int):
        
        super().__init__()
        self.q_layer = nn.Linear(d_model, d_model)
        self.k_layer = nn.Linear(d_model, d_model)
        self.v_layer = nn.Linear(d_model, d_model)
        self.n_head = n_head

    def forward(self,
                enc_out: torch.tensor,
                dec_out: torch.tensor):
        
        B, max_sen_len, d_model = enc_out.size()

        q = self.q_layer(dec_out)
        k = self.k_layer(enc_out)
        v = self.v_layer(enc_out)

        q = q.reshape(B, max_sen_len, self.n_head, d_model // self.n_head)
        k = k.reshape(B, max_sen_len, self.n_head, d_model // self.n_head)
        v = v.reshape(B, max_sen_len, self.n_head, d_model // self.n_head)

        q = q.permute(0, 2, 1, 3)
        k = k.permute(0, 2, 1, 3)
        v = v.permute(0, 2, 1, 3)

        att = (q @ k.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_model))
        att = F.softmax(att, dim=-1)

        new_emb = att @ v
        new_emb = new_emb.reshape(B, max_sen_len, self.n_head * (d_model // self.n_head))

        return att, new_emb
    

class DecoderLayer(nn.Module):
    def __init__(self,
                 d_model: int,
                 ffn_hidden: int,
                 n_head: int,
                 drop_prob: float):
        super().__init__()
        # self.m_att = MultiHeadAttention(input_dim=d_model,
        #                                 d_model=d_model,
        #                                 n_head=n_head)
        self.masked_att = MultiHeadAttention(input_dim=d_model,
                                        d_model=d_model,
                                        n_head=n_head)
        self.ffn = FeedForwardNetwork(d_model=d_model,
                                      ffn_hidden=ffn_hidden,
                                      drop_prob=drop_prob)
        self.l_norm1 = LayerNormalization(parameters_shape=[d_model])
        self.l_norm2 = LayerNormalization(parameters_shape=[d_model])
        self.l_norm3 = LayerNormalization(parameters_shape=[d_model])


        self.mcross_att = MultiHeadCrossAttention(d_model=d_model,
                                                  n_head=n_head)



    def forward(self,
                x: torch.tensor,
                mask: torch.tensor,
                enc_out: torch.tensor):
        
        _, att = self.masked_att(x, mask)
        att = self.l_norm1(att + x)

        _, out = self.mcross_att(enc_out, att)
        out = self.l_norm2(out + att)
        
        f_out = self.ffn(out)
        out = self.l_norm2(f_out + out)

        return out


class SequentialDecoder(nn.Sequential):
    def forward(self, *inputs):
        x, mask, y = inputs
        for module in self._modules.values():
            y = module(x, mask, y) #30 x 200 x 512
        return y

class Decoder(nn.Module):
    def __init__(self,
                 d_model: int,
                 ffn_hidden: int,
                 n_head: int,
                 drop_prob: float,
                 n_layers: int):
        super().__init__()
        self.l = SequentialDecoder(*[DecoderLayer(d_model=d_model,
                                              ffn_hidden=ffn_hidden,
                                              n_head=n_head,
                                              drop_prob=drop_prob) for _ in range(n_layers)])

    def forward(self,
                x: torch.tensor,
                mask: torch.tensor,
                enc_out: torch.tensor) -> torch.tensor:
        out = self.l(x, mask, enc_out)
        return out

In [3]:
d_model = 512
n_heads = 8
drop_prob = 0.1
batch_size = 32
max_seq_len = 200
ffn_hidden = 2048
n_layers = 5
x = torch.randn( (batch_size, max_seq_len, d_model)) 

dec = DecoderLayer(d_model,
             ffn_hidden,
             n_heads, 
             drop_prob)
mask = torch.full([max_seq_len, max_seq_len] , float('-inf'))
mask = torch.triu(mask, diagonal=1)
dec(x=x, enc_out=x, mask=mask).shape



# dec = Decoder(d_model,
#              ffn_hidden,
#              n_heads, 
#              drop_prob, 
#              2)
# dec(x, None, x).shape

torch.Size([32, 200, 512])

In [4]:
mask

tensor([[0., -inf, -inf,  ..., -inf, -inf, -inf],
        [0., 0., -inf,  ..., -inf, -inf, -inf],
        [0., 0., 0.,  ..., -inf, -inf, -inf],
        ...,
        [0., 0., 0.,  ..., 0., -inf, -inf],
        [0., 0., 0.,  ..., 0., 0., -inf],
        [0., 0., 0.,  ..., 0., 0., 0.]])

In [6]:
encoder_padding_mask = torch.full([1, max_seq_len, max_seq_len] , False)
encoder_padding_mask

tensor([[[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]]])

In [18]:
max_sequence_length = 5
idx = 0
num_sentences = 2
decoder_padding_mask_self_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)
decoder_padding_mask_cross_attention = torch.full([num_sentences, max_sequence_length, max_sequence_length] , False)

eng_sentence_length, kn_sentence_length = 3, 2
eng_chars_to_padding_mask = np.arange(eng_sentence_length + 1, max_sequence_length)
kn_chars_to_padding_mask = np.arange(kn_sentence_length + 1, max_sequence_length)
encoder_padding_mask[idx, :, eng_chars_to_padding_mask] = True
encoder_padding_mask[idx, eng_chars_to_padding_mask, :] = True
decoder_padding_mask_self_attention[idx, :, kn_chars_to_padding_mask] = True
decoder_padding_mask_self_attention[idx, kn_chars_to_padding_mask, :] = True
decoder_padding_mask_cross_attention[idx, :, eng_chars_to_padding_mask] = True
decoder_padding_mask_cross_attention[idx, kn_chars_to_padding_mask, :] = True

In [19]:
eng_chars_to_padding_mask, kn_chars_to_padding_mask
encoder_padding_mask

tensor([[[False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         ...,
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False],
         [False, False, False,  ..., False, False, False]]])