In [3]:
import math
import numpy as np

In [4]:
import torch
import torch.nn as nn

In [5]:
class InputEmbedding(nn.Module):

    def __init__(self, d_model, vocab_size):
        super().__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(d_model, vocab_size)
        # shape = [num of words * dimension of embedding layer]

    def forward(self, x):
        return self.embedding(x) * math.sqrt(d_model)
        # dimension same

In [6]:
class PositionalEncoding(nn.Module):

    def __init__(self, d_model, seq_length, dropout = 0):
        super().__init__()
        self.d_model = d_model
        self.seq_length = seq_length
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(self.seq_length, self.d_model)  # To get the matrix of dimension as of embedding layer
        positions = torch.arange(0, self.seq_length, dtype = torch.float32).unsqueeze(1)  # matrix of [seq_length x 1]
        div_term = (positions /(torch.pow(10000, 2 * torch.arange(0, d_model, 2).float() /self.d_model))) #to calculate say (angle)  pos/(10000^(2i/dmodel))
        pe[:, 0::2] = torch.sin(div_term)   #Apply sine formula in even positions
        pe[:, 1::2] = torch.cos(div_term)   # Appply cosine formula in odd positions
        
        self.pe = pe.unsqueeze(0)  # for batches dimension [1 x seq_length x d_model]

        # self.register_buffer('pe', self.pe) # By adding this in register buffer this stores pe too while saving the model without considering it as a learning parameter
                

    def forward(self, x):
        x = x + self.pe.required_grad(False)  #To make it not to learn
        return self.dropout(x)
    # def forward(self, ..):
        # pe = torch.zeros()

In [7]:
class LayerNormalization(nn.Module):
    def __init__(self, epsilon=1e-5):
        super().__init__()
        self.epsilon = epsilon
        self.gamma = nn.Parameter(torch.ones(1))  # Scale
        self.beta = nn.Parameter(torch.zeros(1))  # Shift

    def forward(self, x):
        mean = x.mean(dim=1, keepdim=True)
        var = x.var(dim=1, keepdim=True) 
        return self.gamma * (x - mean) / torch.sqrt(var + self.epsilon) + self.beta


In [8]:
class FeedForward(nn.Module):

    def __init__(self, d_model, dff):
        super().__init__()
        self.forward1 = nn.Linear(d_model, dff)
        self.dropout = nn.Dropout(dropout)
        self.forward2 = nn.Linear(dff, d_model)

    def forward(self, x):
        return self.forward2(self.dropout(torch.relu(self.forward1(x))))

In [85]:
# HERE I USED ALL EMBEDDING FOR EACH HEAD AND CONCATENATE THEM AND USE LINEAR TRANSFORMATION TO GET THE OUTPUT SAME DIMENSION AS INPUT
# class MultiHeadAttention(nn.Module):

#     def __init__(self, d_model, heads, dropout = 0.5):
#         super(MultiHeadAttention, self).__init__()
#         self.d_model = d_model
#         self.heads = heads
#         self.dropout = dropout

#         self.w_q = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(heads))
#         self.w_k = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(heads))
#         self.w_v = nn.ModuleList(nn.Linear(d_model, d_model) for _ in range(heads))

#         self.w_o = nn.Linear(d_model * heads, d_model)

#         self.softmax = nn.Softmax(dim = -1)
#         self.dropout = nn.Dropout(dropout)

#     def forward(self, embeded_layer):

#         attention_outputs = []

#         for head in range(self.heads):
        
#             query = self.w_q[head](embeded_layer)
#             key = self.w_k[head](embeded_layer)
#             value = self.w_v[head](embeded_layer)

#             similarity = torch.matmul(query, torch.transpose(key, -2, -1))  / math.sqrt(self.d_model)

#             sim = self.softmax(similarity)
#             sim = self.dropout(sim)

#             final = torch.matmul(sim, value)

#             attention_outputs.append(final)
            
#         concat_matrix = torch.cat(attention_outputs, -1)
#         print(concat_matrix.shape)
#         print(self.w_o.weight.shape)
#         return self.w_o(concat_matrix)
        

        
        

In [302]:
class MultiHeadAttention(nn.Module):

    def __init__(self, d_model, heads, dropout = 0.5):
        super(MultiHeadAttention, self).__init__()
        self.d_model = d_model
        self.heads = heads
        self.d_heads = d_model//heads

        self.w_q = nn.Linear(d_model, d_model, bias = False)
        self.w_v = nn.Linear(d_model, d_model, bias = False)
        self.w_k = nn.Linear(d_model, d_model, bias = False)

        self.w_o = nn.Linear(d_model, d_model)

        self.softmax = nn.Softmax(dim = -1)
        self.dropout = nn.Dropout(dropout)

    
    # def splitweights(self, x):
    #     batch_size, seq_len, d_model = x.shape
    #     x = x.view(batch_size, seq_len, self.heads, -1)
    #     return x.permute(0, 2, 1, 3)
        

    def forward(self, x, mask = None):

        batch_size, seq_len, d_model = x.shape
        print(x.shape)
        print(self.w_q.weight.shape)

        query = self.w_q(x).view(batch_size, seq_len, self.heads, -1).permute(0, 2, 1, 3)
        key = self.w_k(x).view(batch_size, seq_len, self.heads, -1).permute(0, 2, 1, 3)
        value = self.w_v(x).view(batch_size, seq_len, self.heads, -1).permute(0, 2, 1, 3)

        # query = self.splitweights(self.w_q(x))
        # key = self.splitweights(self.w_k(x))
        # value = self.splitweights(self.w_v(x))
        # print(query.shape)

        similarity = torch.matmul(query, key.transpose(-2, -1))/math.sqrt(self.d_heads)

        # print(similarity.shape)

        if mask is not None:
            # print(mask)
            mask = mask.unsqueeze(0).unsqueeze(0)
            print(mask)
            # print(similarity)
            similarity = similarity.masked_fill(mask == 0, float('-inf'))
        print(similarity)

        sim = self.softmax(similarity)
        sim = self.dropout(sim)

        print(sim)

        final = torch.matmul(sim, value)

        final = final.permute(0, 2, 1, 3).contiguous()
        final = final.view(batch_size, seq_len, self.d_model)
        

        # print(final.shape)
        return self.w_o(final)
        

In [304]:
m = torch.tril(torch.ones(4, 4))
x = torch.rand(2, 4 ,15)
a = MultiHeadAttention(15, 3)
a(x, m)


torch.Size([2, 4, 15])
torch.Size([15, 15])
tensor([[[[1., 0., 0., 0.],
          [1., 1., 0., 0.],
          [1., 1., 1., 0.],
          [1., 1., 1., 1.]]]])
tensor([[[[ 0.0630,    -inf,    -inf,    -inf],
          [ 0.0257,  0.0704,    -inf,    -inf],
          [-0.0041,  0.0296, -0.0377,    -inf],
          [-0.0146,  0.0632, -0.0005,  0.0133]],

         [[-0.0426,    -inf,    -inf,    -inf],
          [-0.0475, -0.1261,    -inf,    -inf],
          [-0.0737, -0.2088, -0.0939,    -inf],
          [-0.0003, -0.0865,  0.0612, -0.0616]],

         [[-0.0050,    -inf,    -inf,    -inf],
          [-0.0262, -0.0471,    -inf,    -inf],
          [-0.0636, -0.0541, -0.0202,    -inf],
          [ 0.0243, -0.0100, -0.0345,  0.0510]]],


        [[[ 0.0138,    -inf,    -inf,    -inf],
          [-0.0100, -0.0418,    -inf,    -inf],
          [-0.0077, -0.0439, -0.0243,    -inf],
          [ 0.0799,  0.0544,  0.1289,  0.0327]],

         [[ 0.0221,    -inf,    -inf,    -inf],
          [ 0.0

tensor([[[ 0.1922, -0.5326, -0.0211,  0.1084,  0.6304,  0.4637,  0.2269,
          -0.3404,  0.2666,  0.3058,  0.1101, -0.3443,  0.3794, -0.3468,
           0.1906],
         [ 0.2257, -0.2249, -0.2407,  0.4867,  0.4513,  0.1386, -0.3822,
           0.2340,  0.1526, -0.2040,  0.1067,  0.1499,  0.3272, -0.3012,
          -0.0111],
         [ 0.2209, -0.2316, -0.1655,  0.3477,  0.3848,  0.1129, -0.3217,
           0.1897,  0.0846, -0.0457,  0.0901,  0.1435,  0.2708, -0.2666,
           0.0532],
         [ 0.2882, -0.3533, -0.0039,  0.1248,  0.5535,  0.3160,  0.1243,
          -0.1467,  0.1550,  0.3483, -0.0814, -0.1228,  0.4367, -0.2792,
           0.1446]],

        [[ 0.2381, -0.4796, -0.0370, -0.0113,  0.7348,  0.3329,  0.2110,
          -0.3423,  0.4280,  0.3690, -0.0656, -0.1981,  0.6929, -0.3684,
           0.3548],
         [ 0.4724, -0.2926, -0.0177,  0.2570,  0.5469,  0.4887, -0.4414,
           0.0290,  0.5697, -0.1539, -0.0862, -0.0877,  0.7005, -0.4761,
          -0.1113],
  

In [258]:
class ResidualConnection(nn.Module):

    def __init__(self, d_model ,dropout):

        super(ResidualConnection, self).__init__()
        self.ln = LayerNormalization()
        self.dropout = nn.Dropout(dropout)

    def forward(self, x1, x2):

        return self.ln(x1 + self.dropout(x2))

In [12]:
class EncoderBlock(nn.Module):

    def __init__(self, attention, rc, ff):

        super(EncoderBlock, self).__init__()
        self.attention = attention
        self.rc = rc
        self.ff = ff

    def forward(self, x):

        x1 = self.attention(x)
        x2 = self.rc(x, x1)

        x3 = self.ff(x2)
        x4 = self.rc(x2, x3)
        return x4
        
        

In [13]:
class Encoder(nn.Module):

    def __init__(self, embedding, pos_encoding, attention, rc, ff, n= 6):

        super(Encoder, self).__init__()
        self.embedding = embedding
        self.pos_encoding = pos_encoding
        
        self.encoder_blocks = nn.ModuleList(EncoderBlock(attention, rc, ff) for _ in range(n))
        print(type(self.encoder_blocks))

    def forward(self,x):

        x = self.pos_encoding(self.embedding(x))

        for block in self.encoder_blocks:
            x = block(x)
        return x

In [14]:
Encoder(0, 0, 0, 0, 0)

<class 'torch.nn.modules.container.ModuleList'>


Encoder(
  (encoder_blocks): ModuleList(
    (0-5): 6 x EncoderBlock()
  )
)

In [15]:
# torch.manual_seed(44)

In [16]:
# a = torch.rand((4, 3, 3))
# a

In [17]:
# sm = nn.Softmax(dim = -1)

In [18]:
# sm(a)

In [76]:
x = torch.rand(5, 4)

In [78]:
fc = nn.Linear(4, 5)

In [80]:
fc.weight.shape

torch.Size([5, 4])

In [83]:
x

tensor([[0.8063, 0.8474, 0.3038, 0.2912],
        [0.7710, 0.5077, 0.0820, 0.4452],
        [0.9906, 0.8621, 0.5196, 0.7618],
        [0.9640, 0.7600, 0.2415, 0.9532],
        [0.7602, 0.2326, 0.6311, 0.9994]])

In [90]:
10//5

2

In [94]:
16//3

5

In [96]:
x = torch.rand(4, 5, 6)

In [98]:
x

tensor([[[0.4523, 0.7673, 0.4615, 0.7689, 0.4221, 0.0651],
         [0.0223, 0.9863, 0.1898, 0.7654, 0.9906, 0.9972],
         [0.2631, 0.7573, 0.2952, 0.9816, 0.9212, 0.7088],
         [0.0012, 0.7520, 0.7898, 0.9398, 0.4272, 0.2188],
         [0.2945, 0.0889, 0.1906, 0.3348, 0.8903, 0.8236]],

        [[0.6094, 0.0213, 0.1959, 0.9292, 0.7505, 0.4193],
         [0.3287, 0.3430, 0.8965, 0.3555, 0.3171, 0.1416],
         [0.6776, 0.4158, 0.4936, 0.3396, 0.0263, 0.8206],
         [0.6837, 0.3355, 0.0200, 0.8911, 0.7474, 0.6377],
         [0.5047, 0.9601, 0.4932, 0.4784, 0.2009, 0.4425]],

        [[0.1479, 0.2026, 0.9293, 0.8562, 0.6848, 0.2807],
         [0.1583, 0.9729, 0.0499, 0.6895, 0.3252, 0.5378],
         [0.3896, 0.3278, 0.4169, 0.1010, 0.4727, 0.4761],
         [0.2407, 0.8861, 0.8218, 0.4522, 0.6229, 0.4474],
         [0.5529, 0.6962, 0.3056, 0.0646, 0.2397, 0.5979]],

        [[0.3858, 0.7512, 0.5848, 0.4403, 0.4271, 0.7991],
         [0.9456, 0.8152, 0.4946, 0.4556, 0.3767, 

In [106]:
y = x.view(4, 5, 2, -1)

In [108]:
y.permute(0, 2, 1, 3)

tensor([[[[0.4523, 0.7673, 0.4615],
          [0.0223, 0.9863, 0.1898],
          [0.2631, 0.7573, 0.2952],
          [0.0012, 0.7520, 0.7898],
          [0.2945, 0.0889, 0.1906]],

         [[0.7689, 0.4221, 0.0651],
          [0.7654, 0.9906, 0.9972],
          [0.9816, 0.9212, 0.7088],
          [0.9398, 0.4272, 0.2188],
          [0.3348, 0.8903, 0.8236]]],


        [[[0.6094, 0.0213, 0.1959],
          [0.3287, 0.3430, 0.8965],
          [0.6776, 0.4158, 0.4936],
          [0.6837, 0.3355, 0.0200],
          [0.5047, 0.9601, 0.4932]],

         [[0.9292, 0.7505, 0.4193],
          [0.3555, 0.3171, 0.1416],
          [0.3396, 0.0263, 0.8206],
          [0.8911, 0.7474, 0.6377],
          [0.4784, 0.2009, 0.4425]]],


        [[[0.1479, 0.2026, 0.9293],
          [0.1583, 0.9729, 0.0499],
          [0.3896, 0.3278, 0.4169],
          [0.2407, 0.8861, 0.8218],
          [0.5529, 0.6962, 0.3056]],

         [[0.8562, 0.6848, 0.2807],
          [0.6895, 0.3252, 0.5378],
          [0.1