In [3]:
import math
import numpy as np

In [4]:
import torch
import torch.nn as nn

In [5]:
class InputEmbedding(nn.Module):

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(d_model, vocab_size)
        # shape = [num of words * dimension of embedding layer]

    def forward(self, x):
        return self.embedding(x) * math.sqrt(d_model)
        # dimension same

In [6]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, seq_length, dropout = 0):
        super().__init__()
        self.d_model = d_model
        self.seq_length = seq_length
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(self.seq_length, self.d_model)  # To get the matrix of dimension as of embedding layer
        positions = torch.arange(0, self.seq_length, dtype = torch.float32).unsqueeze(1)  # matrix of [seq_length x 1]
        div_term = (positions /(torch.pow(10000, 2 * torch.arange(0, d_model, 2).float() /self.d_model))) #to calculate say (angle)  pos/(10000^(2i/dmodel))
        pe[:, 0::2] = torch.sin(div_term)   #Apply sine formula in even positions
        pe[:, 1::2] = torch.cos(div_term)   # Appply cosine formula in odd positions
        
        self.pe = pe.unsqueeze(0)  # for batches dimension [1 x seq_length x d_model]

        # self.register_buffer('pe', self.pe) # By adding this in register buffer this stores pe too while saving the model without considering it as a learning parameter
                

    def forward(self, x):
        x = x + self.pe.required_grad(False)  #To make it not to learn
        return self.dropout(x)
    # def forward(self, ..):
        # pe = torch.zeros()

In [7]:
class LayerNormalization(nn.Module):
    def __init__(self, epsilon=1e-5):
        super().__init__()
        self.epsilon = epsilon
        self.gamma = nn.Parameter(torch.ones(1))  # Scale
        self.beta = nn.Parameter(torch.zeros(1))  # Shift

    def forward(self, x):
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, keepdim=True) 
        return self.gamma * (x - mean) / torch.sqrt(var + self.epsilon) + self.beta


In [8]:
class FeedForward(nn.Module):

    def __init__(self, d_model, dff):
        super().__init__()
        self.forward1 = nn.Linear(d_model, dff)
        self.dropout = nn.Dropout(dropout)
        self.forward2 = nn.Linear(dff, d_model)

    def forward(self, x):
        return self.forward2(self.dropout(torch.relu(self.forward1(x))))

In [9]:
# HERE I USED ALL EMBEDDING FOR EACH HEAD AND CONCATENATE THEM AND USE LINEAR TRANSFORMATION TO GET THE OUTPUT SAME DIMENSION AS INPUT
# class MultiHeadAttention(nn.Module):

#     def __init__(self, d_model, heads, dropout = 0.5):
#         super(MultiHeadAttention, self).__init__()
#         self.d_model = d_model
#         self.heads = heads
#         self.dropout = dropout

#         self.w_q = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(heads))
#         self.w_k = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(heads))
#         self.w_v = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(heads))

#         self.w_o = nn.Linear(d_model * heads, d_model)

#         self.softmax = nn.Softmax(dim = -1)
#         self.dropout = nn.Dropout(dropout)

#     def forward(self, embeded_layer):

#         attention_outputs = []

#         for head in range(self.heads):
        
#             query = self.w_q[head](embeded_layer)
#             key = self.w_k[head](embeded_layer)
#             value = self.w_v[head](embeded_layer)

#             similarity = torch.matmul(query, torch.transpose(key, -2, -1))  / math.sqrt(self.d_model)

#             sim = self.softmax(similarity)
#             sim = self.dropout(sim)

#             final = torch.matmul(sim, value)

#             attention_outputs.append(final)
            
#         concat_matrix = torch.cat(attention_outputs, -1)
#         print(concat_matrix.shape)
#         print(self.w_o.weight.shape)
#         return self.w_o(concat_matrix)
        

        
        

In [70]:
class MultiHeadAttention(nn.Module):

    def __init__(self, d_model, heads, dropout = 0.5):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.heads = heads
        self.d_heads = d_model//heads

        self.w_q = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)

        self.w_o = nn.Linear(d_model, d_model)

        self.softmax = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

    
    # def splitweights(self, x):
    #     batch_size, seq_len, d_model = x.shape
    #     x = x.view(batch_size, seq_len, self.heads, -1)
    #     return x.permute(0, 2, 1, 3)
        

    def forward(self, x_q, x_k, x_v, mask = None):

        batch_size, seq_len, d_model = x.shape
        print(x.shape)
        print(self.w_q.weight.shape)

        query = self.w_q(x_q).view(batch_size, seq_len, self.heads, -1).permute(0, 2, 1, 3)
        key = self.w_k(x_k).view(batch_size, seq_len, self.heads, -1).permute(0, 2, 1, 3)
        value = self.w_v(x_v).view(batch_size, seq_len, self.heads, -1).permute(0, 2, 1, 3)

        # query = self.splitweights(self.w_q(x))
        # key = self.splitweights(self.w_k(x))
        # value = self.splitweights(self.w_v(x))
        print(query.shape)

        similarity = torch.matmul(query, key.transpose(-2, -1))/math.sqrt(self.d_heads)

        # print(similarity.shape)

        if mask is not None:
            # print(mask)
            mask = mask.unsqueeze(0).unsqueeze(0)
            print(mask)
            # print(similarity)
            similarity = similarity.masked_fill(mask == 0, float('-inf'))
        print(similarity)

        sim = self.softmax(similarity)
        print(sim)
        sim = self.dropout(sim)

        # print(sim)

        final = torch.matmul(sim, value)

        final = final.permute(0, 2, 1, 3).contiguous()
        final = final.view(batch_size, seq_len, self.d_model)
        

        # print(final.shape)
        return self.w_o(final), key, value
        

In [72]:
m = torch.tril(torch.ones(4, 4))
x = torch.rand(2, 4 ,15)
a = MultiHeadAttention(15, 3)
a(x,x,x, m)


torch.Size([2, 4, 15])
torch.Size([15, 15])
torch.Size([2, 3, 4, 5])
tensor([[[[1., 0., 0., 0.],
          [1., 1., 0., 0.],
          [1., 1., 1., 0.],
          [1., 1., 1., 1.]]]])
tensor([[[[-0.0834,    -inf,    -inf,    -inf],
          [-0.0686, -0.0825,    -inf,    -inf],
          [-0.0881, -0.0991, -0.1975,    -inf],
          [-0.0862, -0.1245, -0.2875, -0.2431]],

         [[ 0.2568,    -inf,    -inf,    -inf],
          [ 0.2862,  0.2572,    -inf,    -inf],
          [ 0.4133,  0.3764,  0.2460,    -inf],
          [ 0.1937,  0.1318, -0.0408,  0.1696]],

         [[ 0.0255,    -inf,    -inf,    -inf],
          [-0.0236,  0.0309,    -inf,    -inf],
          [-0.0040,  0.0678,  0.1092,    -inf],
          [ 0.0426,  0.0868,  0.1742, -0.0161]]],


        [[[-0.0673,    -inf,    -inf,    -inf],
          [-0.1612, -0.0941,    -inf,    -inf],
          [-0.1624, -0.1452, -0.2296,    -inf],
          [-0.1657, -0.0959, -0.1652, -0.2055]],

         [[ 0.1737,    -inf,    -inf, 

tensor([[[-0.0998,  0.0088, -0.2239,  0.0868, -0.4007, -0.0979,  0.5699,
          -0.2157,  0.0259,  0.2374, -0.0623, -0.3742,  0.1215, -0.3384,
          -0.3979],
         [-0.0518,  0.0212, -0.1555,  0.0111, -0.3704, -0.1655,  0.5500,
          -0.2036, -0.0059,  0.1552, -0.1609, -0.3379,  0.1518, -0.2935,
          -0.3778],
         [ 0.1753, -0.2545, -0.0684,  0.0714, -0.2857,  0.2374,  0.0751,
          -0.1218,  0.2345,  0.0125, -0.2632, -0.1939, -0.0456, -0.3247,
          -0.3117],
         [ 0.2450, -0.2002, -0.0748, -0.0008, -0.2626,  0.1804, -0.0265,
          -0.0452,  0.2256, -0.0400, -0.3293, -0.2018, -0.0606, -0.2342,
          -0.1655]],

        [[ 0.3797, -0.6368, -0.1588,  0.0027, -0.3756,  0.4985, -0.2552,
          -0.0759,  0.2512, -0.3902, -0.2752,  0.1731,  0.0134, -0.3697,
          -0.3246],
         [ 0.2548, -0.0706, -0.1386, -0.1202, -0.2569,  0.0901,  0.0898,
           0.0363,  0.1339,  0.0278, -0.2490, -0.2191,  0.0837, -0.2117,
          -0.2002],
  

In [12]:
class ResidualConnection(nn.Module):

    def __init__(self, d_model ,dropout):

        super(ResidualConnection, self).__init__()
        self.ln = LayerNormalization()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x1, x2):

        return self.ln(x1 + self.dropout(x2))

In [64]:
class EncoderBlock(nn.Module):

    def __init__(self,  dff, d_model, heads, dropout):

        super(EncoderBlock, self).__init__()

        self.multi_attention = MutltiHeadAttention(d_model, heads, dropout)
        self.residual_connections = nn.ModuleList([ResidualConnection(d_model, dropout) for _ in range(2)])

        self.feed_forward = FeedForward(d_model, dff)

    def forward(self, x):

        x1, key, value = self.multi_attention(x, x, x)
        x2 = self.residual_connections[0](x, x1)
        x3 = self.feed_forward(x2)
        x4 = self.residual_connections[1](x2, x3)
        return x4, key, value
        
        

In [66]:
class Encoder(nn.Module):

    def __init__(self, vocab_size, dff, seq_length, d_model, heads,dropout, n = 6):

        super(Encoder, self).__init__()
        self.embedding_layer = InputEmbedding(d_model, vocab_size)
        self.positional_embedding = PositionalEmbedding(d_model, seq_length, dropout)
        
        self.encoder_blocks = nn.ModuleList([EncoderBlock(dff, d_model, heads,dropout) for _ in range(n)])
        print(type(self.encoder_blocks))

    def forward(self,x):

        x = self.positional_embedding(self.embedding_layer(x))

        for block in self.encoder_blocks:
            x, key, value = block(x)
        return x, key, value

In [15]:
Encoder(0, 0, 0, 0, 0)

<class 'torch.nn.modules.container.ModuleList'>


Encoder(
  (encoder_blocks): ModuleList(
    (0-5): 6 x EncoderBlock()
  )
)

In [74]:
class DecoderBlock(nn.Module):

    def __init__(self, dff, d_model, heads, dropout):

        super(DecoderBlock, self).__init__()
        # self.masked_attention = masked_attention
        # self.residual_connections = residual_connections
        # self.cross_attention = cross_attention
        # self.feed_forward = feed_forward
        self.masked_attention = MultiHeadAttention(d_model, heads)
        self.residual_connections = nn.ModuleList([ResidualConnection(d_model, dropout) for _ in range(3)])
        self.cross_attention = MultiHeadAttention(d_model, heads)
        self.feed_forward = FeedForward(d_model, dff)


    def forward(self, x, key, value, mask):

        x1, _, _ = self.masked_attention(x, x, x, mask)
        x2 = self.residual_connections[0](x, x1)

        x3, _, _ = self.cross_attention(x2, key, value)
        x4 = self.residual_connections[1](x2, x3)

        x5 = self.feed_forward(x4)
        x6 = self.residual_connections[2](x4, x5)

        return x6

In [78]:
class Decoder(nn.Module):

    def __init__(self, vocab_size, dff, seq_length, d_model, heads,dropout, n = 6):

        super(Decoder, self).__init__()
        
        self.embedding_layer = InputEmbedding(d_model, vocab_size)
        self.positional_embedding = PositionalEmbedding(d_model, seq_length, dropout)

        self.decoder_blocks = nn.ModuleList([DecoderBlock(dff, d_model, heads,dropout) for _ in range(n)])

        self.mask = m = torch.tril(torch.ones(seq_length, seq_length))


    def forward(self, x, key, value):

        x = self.positional_embedding(self.embedding_layer(x))
        for block in self.decoder_blocks:
            x = block(x, key, value, self.mask)
        return x            
        
        

In [16]:
# torch.manual_seed(44)

In [17]:
# a = torch.rand((4, 3, 3))
# a

In [18]:
# sm = nn.Softmax(dim = -1)

In [19]:
# sm(a)

In [20]:
x = torch.rand(5, 4)

In [21]:
fc = nn.Linear(4, 5)

In [22]:
fc.weight.shape

torch.Size([5, 4])

In [23]:
x

tensor([[0.2745, 0.2622, 0.7498, 0.2155],
        [0.2069, 0.7540, 0.0982, 0.3528],
        [0.9714, 0.7038, 0.2820, 0.3485],
        [0.0767, 0.5571, 0.3614, 0.1596],
        [0.7256, 0.7579, 0.5823, 0.4380]])

In [24]:
10//5

2

In [25]:
16//3

5

In [26]:
x = torch.rand(4, 5, 6)

In [27]:
x

tensor([[[0.9769, 0.1334, 0.0875, 0.5955, 0.2906, 0.4639],
         [0.8240, 0.0231, 0.4150, 0.5795, 0.7420, 0.0552],
         [0.1538, 0.9790, 0.6506, 0.7592, 0.7137, 0.3978],
         [0.6863, 0.1378, 0.6559, 0.1667, 0.4780, 0.3286],
         [0.6324, 0.0663, 0.9344, 0.9531, 0.6710, 0.2322]],

        [[0.9834, 0.8209, 0.3834, 0.4204, 0.7666, 0.2699],
         [0.8639, 0.9251, 0.1020, 0.7359, 0.0199, 0.7547],
         [0.9273, 0.0284, 0.5178, 0.3728, 0.7439, 0.7334],
         [0.8518, 0.1130, 0.4126, 0.9243, 0.6957, 0.1537],
         [0.1205, 0.7821, 0.1062, 0.3711, 0.2527, 0.8687]],

        [[0.9578, 0.2201, 0.6164, 0.9399, 0.6242, 0.7545],
         [0.2806, 0.1959, 0.0043, 0.7410, 0.7151, 0.9750],
         [0.9442, 0.6250, 0.2752, 0.1552, 0.1160, 0.8005],
         [0.6515, 0.2737, 0.6172, 0.1177, 0.4433, 0.2350],
         [0.9038, 0.8049, 0.6357, 0.6561, 0.6560, 0.4586]],

        [[0.9834, 0.7357, 0.0055, 0.8204, 0.5147, 0.9410],
         [0.4765, 0.5496, 0.3550, 0.5732, 0.2556, 

In [28]:
y = x.view(4, 5, 2, -1)

In [29]:
y.permute(0, 2, 1, 3)

tensor([[[[0.9769, 0.1334, 0.0875],
          [0.8240, 0.0231, 0.4150],
          [0.1538, 0.9790, 0.6506],
          [0.6863, 0.1378, 0.6559],
          [0.6324, 0.0663, 0.9344]],

         [[0.5955, 0.2906, 0.4639],
          [0.5795, 0.7420, 0.0552],
          [0.7592, 0.7137, 0.3978],
          [0.1667, 0.4780, 0.3286],
          [0.9531, 0.6710, 0.2322]]],


        [[[0.9834, 0.8209, 0.3834],
          [0.8639, 0.9251, 0.1020],
          [0.9273, 0.0284, 0.5178],
          [0.8518, 0.1130, 0.4126],
          [0.1205, 0.7821, 0.1062]],

         [[0.4204, 0.7666, 0.2699],
          [0.7359, 0.0199, 0.7547],
          [0.3728, 0.7439, 0.7334],
          [0.9243, 0.6957, 0.1537],
          [0.3711, 0.2527, 0.8687]]],


        [[[0.9578, 0.2201, 0.6164],
          [0.2806, 0.1959, 0.0043],
          [0.9442, 0.6250, 0.2752],
          [0.6515, 0.2737, 0.6172],
          [0.9038, 0.8049, 0.6357]],

         [[0.9399, 0.6242, 0.7545],
          [0.7410, 0.7151, 0.9750],
          [0.1