In [1]:
import torch
from torch import nn
import numpy as np
import math
import time
import copy


In [2]:
# Arch units
## Self Attention Unit
## Multi Head Attention
## Encode Decode Unit
## Norm + Residual Layer
## Feed Forward
## Input Positional Encoding 


In [3]:
# parameters from paper
N = 2 # 6
d_model = 512
h = 8
d_k = d_v = d_model//h
d_ff = 2048 #128
vocab_size = 11


In [4]:
class MultiHeadAttention(nn.Module):
    def __init__(self):
        super(MultiHeadAttention, self).__init__()
        self.W_Q = nn.Linear(d_model, d_model)
        self.W_K = nn.Linear(d_model, d_model)
        self.W_V = nn.Linear(d_model, d_model)
        self.W_O = nn.Linear(d_model, d_model)
        
    def forward(self, Q, K, V, mask):
        b, q_seq, _ = Q.size()
        b, k_seq, _ = K.size()
        query = self.W_Q(Q).view(b, q_seq, h, d_k) # view (b, q_seq, h, d)
        key = self.W_K(K).view(b, k_seq, h, d_k) # view (b, k_seq, h, d)
        value  = self.W_V(V).view(b, k_seq, h, d_k) # view (b, k_seq, h, d)
        
        query = query.transpose(1, 2).contiguous() # view (b, h, q_seq, d)
        key = key.transpose(1, 2).contiguous() # view (b, h, k_seq, d)
        value = value.transpose(1, 2).contiguous() # view (b, h, k_seq, d)
        
        qk = query.matmul(key.transpose(-2,-1))
        scale_qk = qk/(math.sqrt(d_k)) # shape (b, h, q_seq, k_seq)
        
        if mask is not None: # mask size (b, 1, k_seq)
            mask = mask.unsqueeze(1) # mask size (b, 1, 1, k_seq)
            scale_qk = scale_qk.masked_fill(mask==0, 1e-9)
        
        softmax_qk = nn.functional.softmax(scale_qk, dim=-1) # (b, h, q_seq, k_seq)
        weighted_value = softmax_qk.matmul(value) # (b, h, q_seq, d)
        return self.W_O(weighted_value.transpose(2,1).contiguous().view(b, q_seq, h*d_k)) # (b, h, d_model)


In [5]:
class LayerNorm(nn.Module):
    def __init__(self, d_mod=d_model):
        super(LayerNorm, self).__init__()
        self.d_mod = d_mod
        # https://stackoverflow.com/questions/39095252/fail-to-implement-layer-normalization-with-keras
        # https://stackoverflow.com/questions/50935345/understanding-torch-nn-parameter
        self.alpha = nn.Parameter(torch.ones(d_mod))
        self.beta = nn.Parameter(torch.zeros(d_mod))
    def forward(self, x, eps=1e-6):
        u = x.mean(-1, keepdim=True)
        sigma = x.std(-1, keepdim=True)
        return self.alpha * (x - u)/(sigma + eps) + self.beta
    

In [6]:
class EncoderCell(nn.Module):
    def __init__(self):
        super(EncoderCell, self).__init__()
        self.attn = MultiHeadAttention()
        self.norm_1 = LayerNorm()
        self.pff = nn.Sequential(
                    nn.Linear(d_model, d_ff),
                    nn.ReLU(),
                    nn.Linear(d_ff, d_model))
        self.norm_2 = LayerNorm()
    
    def forward(self, x, src_mask=None):
        x_norm_1 = self.norm_1(x + self.attn(x, x, x, src_mask))  # Layer 1
        return self.norm_2(x_norm_1 + self.pff(x_norm_1)) # Layer 2

In [7]:
class DecoderCell(nn.Module):
    def __init__(self):
        super(DecoderCell, self).__init__()
        self.attn = MultiHeadAttention()
        self.norm_1 = LayerNorm()
        self.attn = MultiHeadAttention()
        self.norm_2 = LayerNorm()
        self.pff = nn.Sequential(
                    nn.Linear(d_model, d_ff),
                    nn.ReLU(),
                    nn.Linear(d_ff, d_model))
        self.norm_3 = LayerNorm()
        
    def forward(self, x, enc, src_mask=None, trg_mask=None):
        x_norm_1 = self.norm_1(x + self.attn(x, x, x, trg_mask))
        x_norm_2 = self.norm_2(x_norm_1 + self.attn(x_norm_1, enc, enc, src_mask))
        return self.norm_3(x_norm_2 + self.pff(x_norm_2)) # (b, seq, d_model)


In [8]:
class EncoderStack(nn.Module):
    def __init__(self, N):
        super(EncoderStack, self).__init__()
        self.N = N
        
    def forward(self, x, src_mask):
        cell = EncoderCell()
        encoders = nn.ModuleList([copy.deepcopy(cell) for _ in range(self.N)])
        
        for enc in encoders:
            x = enc(x, src_mask)
        return x
    

In [9]:
class DecoderStack(nn.Module):
    def __init__(self, N):
        super(DecoderStack, self).__init__()
        self.N = N
    
    def forward(self, x, enc, src_mask, trg_mask):
        cell = DecoderCell()
        decoders = nn.ModuleList([copy.deepcopy(cell) for _ in range(self.N)])
        for decdr in decoders:
            x = decdr(x, enc, src_mask, trg_mask)
        return x
    

In [10]:
class EmbeddingLayer(nn.Module):
    def __init__(self, vocab_size, d_model):
        super(EmbeddingLayer, self).__init__()
        self.d_model = d_model
        self.vocab_size = vocab_size
        self.embedding = nn.Embedding(self.vocab_size, self.d_model)
        
    def forward(self, x):
        return self.embedding(x) * math.sqrt(d_model)


In [11]:
# https://pytorch.org/tutorials/beginner/transformer_tutorial.html
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dpout=0.1, max_seq=50):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dpout)
        
        pe_matx = torch.zeros(max_seq, d_model)
        position = torch.arange(0, max_seq, dtype=torch.float).unsqueeze(-1)
        w_t = torch.exp(torch.arange(0, d_model, 2).float() * -math.log(10000)/d_model)
        val = position * w_t
        pe_matx[:, 0::2] = torch.sin(val)
        pe_matx[:, 1::2] = torch.cos(val)
        pe_matx = pe_matx.unsqueeze(1)
        self.register_buffer("pe_matx", pe_matx)
        
    def forward(self, x):
        # x - (batch, seq, emb), pe_matrix - (max_seq, 1, d_model)
        x += self.pe_matx[:x.size(0), :]
        return(self.dropout(x))
    

In [12]:
class Transformer(nn.Module):
    def __init__(self, embedd = True, log_softmx=True):
        super(Transformer, self).__init__()
#         self.W_in = nn.Linear(word_emb_dim, d_model)
        self.encoder = EncoderStack(N)
        self.decoder = DecoderStack(N)
        # https://stats.stackexchange.com/questions/392213/understand-the-output-layer-of-transformer
        self.W_out = nn.Linear(d_model, vocab_size)
        self.sftmx = log_softmx
        self.embedd = embedd
        if self.embedd:
            embed_x = EmbeddingLayer(vocab_size, d_model)
            embed_y = EmbeddingLayer(vocab_size, d_model)
            pe_x = PositionalEncoding(d_model)
            pe_y = copy.deepcopy(pe_x)
            self.enc_x = nn.Sequential(embed_x, pe_x)
            self.enc_y = nn.Sequential(embed_y, pe_y)
        
    def forward(self, inp_x, inp_y, src_mask, trg_mask):
        if self.embedd:
            inp_x, inp_y = self.enc_x(inp_x), self.enc_y(inp_y)
        enc_x = self.encoder(inp_x, src_mask)
        dec_x = self.decoder(inp_y, enc_x, src_mask, trg_mask)
        if self.sftmx:
            return nn.functional.log_softmax(self.W_out(dec_x), dim=-1)
        return self.W_out(dec_x)
        

In [13]:
# https://www.reddit.com/r/MachineLearning/comments/bjgpt2/d_confused_about_using_masking_in_transformer/

In [14]:
# https://nlp.seas.harvard.edu/2018/04/03/attention.html
class Batch:
    def __init__(self, src, trg=None, pad=0): # size src, trg (b, seq)
        self.src = src 
        self.src_mask = (src != pad).unsqueeze(-2)
        if trg is not None:
            self.trg = trg[:,:-1] # size (b,0:seq-1)
            self.trg_y = trg[:,1:] # size (b,1:seq)
            self.trg_mask = self.std_mask(self.trg, pad)
            self.ntokens = (self.trg_y != pad).data.sum() # size (1)
    
    @staticmethod
    def std_mask(tgt, pad):
        tgt_mask = (tgt != pad).unsqueeze(-2) # size (b, 1, seq)
        mask = torch.from_numpy(np.triu(np.ones((1,tgt.shape[-1],tgt.shape[-1])), k=1).astype('uint8')) == 0
        return tgt_mask & mask # size (b, 1, seq) * (1, seq, seq) -> (b, seq, seq) 
    
#     @staticmethod
#     def subsequent_mask(size):
#         return torch.from_numpy(np.triu(np.ones((1,size,size)), k=1).astype('uint8')) == 0 # size (1, seq, seq)


In [24]:
# https://github.com/pytorch/pytorch/issues/7455    
# kldivLoss = nn.KLDivLoss(size_average=False)

def labelSmoothingLoss(x, y, epsilon, padding_value=0, cls=1, d=-1):
    # concat x, y batch as index_fill_ don't support vector dim > 1
#     x = x.view(-1, x.size(-1))    
    x=x.contiguous().view(-1, x.size(-1))
    y=y.contiguous().view(-1)
    
    x_ = x.data.clone()
    x_.fill_(epsilon / (x_.size(-1) - cls))
    x_.scatter_(d, y.data.unsqueeze(-1), (1 - epsilon))
    x_[:, padding_value] = 0
    mask = torch.nonzero(y.data == padding_value)
    if mask.dim() > 0:
        x_.index_fill_(0, mask.squeeze(), 0.0)
    return torch.mean(torch.sum(-x_*x), dim=d) # x_ is true distribution and x is prediction
#     return kldivLoss(x, copy.deepcopy(x_))



In [16]:
model = Transformer()

In [17]:
# init all parameters as we used deepcopy to save computation tym
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)
        

In [18]:
optimizer = torch.optim.Adam(model.parameters())


In [19]:
# https://nlp.seas.harvard.edu/2018/04/03/attention.html#synthetic-data
def data_generation(V, batch, nbatches):
    for i in range(nbatches):
        data = torch.randint(1, V, size=(batch, 10))
        data[:, 0] = 1
        src = data.clone().detach()
        trg = data.clone().detach()
        yield Batch(src, trg, 0)


In [35]:
# https://nlp.seas.harvard.edu/2018/04/03/attention.html#synthetic-data
def run_epoch(data_itr, model, optimizer):
    start = time.time()
    total_token = 0
    total_loss = 0
    tokens = 0
    
    for i, batch in enumerate(data_itr):
        optimizer.zero_grad()
        
        outp = model(batch.src, batch.trg, batch.src_mask, batch.trg_mask)
        loss = labelSmoothingLoss(outp, batch.trg_y, batch.ntokens)
        print(outp, batch.trg_y)
        print(loss)
        
        loss.backward()
#         torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
        optimizer.step()
        
        total_loss+=loss
        total_token+=batch.ntokens
        tokens+=batch.ntokens
        
        if i%30 == 1:
            elapsed = time.time() - start
            print("Epoch Step: %d Loss: %f Tokens per Sec: %f" %
                    (i, loss / batch.ntokens, tokens / elapsed))
            start = time.time()
            tokens = 0
#     return total_loss/total_token


In [36]:
run_epoch(data_generation(vocab_size, 30, 20), model, optimizer)


tensor([[[-7.4237, -3.1665, -2.4552,  ..., -2.6077, -2.3816, -2.3433],
         [-6.7839, -2.2102, -2.8418,  ..., -2.3309, -3.0578, -2.5213],
         [-7.3292, -3.4833, -2.2683,  ..., -3.5495, -2.4031, -2.3797],
         ...,
         [-7.9818, -3.2699, -2.4587,  ..., -1.8401, -2.7156, -2.7715],
         [-6.5371, -3.1272, -2.6855,  ..., -2.9109, -2.2116, -2.9183],
         [-6.9132, -3.5621, -3.2908,  ..., -3.3836, -2.3981, -2.7933]],

        [[-7.1767, -3.2429, -1.5856,  ..., -2.7775, -2.2431, -1.8857],
         [-6.0338, -2.9686, -1.8795,  ..., -2.6771, -1.9329, -1.6746],
         [-6.2772, -3.1710, -1.4743,  ..., -2.0955, -2.8640, -2.0553],
         ...,
         [-6.0521, -2.5058, -2.6947,  ..., -2.5138, -2.8490, -1.9991],
         [-6.2784, -3.9501, -2.2114,  ..., -3.4146, -2.5325, -1.9739],
         [-6.1703, -2.6596, -2.9692,  ..., -2.5299, -2.2412, -1.7625]],

        [[-8.1208, -2.5282, -1.8604,  ..., -3.5268, -2.6839, -1.9251],
         [-6.7543, -3.3216, -3.1664,  ..., -2

tensor([[[-7.7950, -2.5703, -2.9710,  ..., -2.1655, -2.3503, -2.3699],
         [-7.4835, -3.0805, -0.9570,  ..., -2.8210, -1.9875, -4.0076],
         [-6.4127, -2.9519, -1.3419,  ..., -2.0622, -2.1407, -3.8500],
         ...,
         [-7.4673, -2.6674, -1.4339,  ..., -2.4560, -2.0066, -3.3499],
         [-7.6415, -2.7999, -2.3931,  ..., -2.0003, -1.9770, -3.2338],
         [-6.9788, -3.3380, -0.8468,  ..., -2.9426, -2.3153, -4.9298]],

        [[-8.2784, -2.3322, -1.5016,  ..., -1.7233, -2.5958, -3.1344],
         [-6.5740, -3.2061, -1.9602,  ..., -1.7172, -3.4880, -3.1112],
         [-5.8225, -3.5218, -0.9733,  ..., -2.8371, -2.1779, -2.9079],
         ...,
         [-6.2010, -3.0233, -1.2086,  ..., -1.9640, -2.9815, -4.0497],
         [-6.2395, -3.3574, -0.8612,  ..., -2.0572, -3.3959, -3.9298],
         [-7.9357, -2.8254, -1.8276,  ..., -2.1613, -1.5951, -4.3203]],

        [[-7.8032, -2.4187, -2.5121,  ..., -1.8451, -2.2766, -2.6166],
         [-5.8621, -2.7302, -1.9287,  ..., -3

tensor([[[-9.0640, -1.9357, -2.1684,  ..., -2.5488, -4.0397, -3.4942],
         [-9.9760, -2.7557, -2.7528,  ..., -3.1590, -4.3148, -3.2182],
         [-8.2828, -2.1472, -2.5048,  ..., -2.6651, -4.5133, -3.1216],
         ...,
         [-8.2096, -2.7888, -4.3449,  ..., -3.1330, -3.6702, -2.7656],
         [-8.2211, -3.0613, -1.8829,  ..., -2.8156, -4.4655, -2.0563],
         [-9.4004, -2.8567, -3.2573,  ..., -3.0598, -3.3100, -2.8578]],

        [[-8.5595, -1.6066, -2.0826,  ..., -2.9281, -3.1487, -3.9539],
         [-7.7073, -2.4613, -1.9801,  ..., -2.3729, -4.2752, -3.4054],
         [-7.7665, -2.7319, -3.0793,  ..., -2.4929, -2.5588, -3.5737],
         ...,
         [-9.4756, -1.8293, -3.1520,  ..., -2.3735, -3.5188, -3.5549],
         [-9.4577, -2.7827, -2.4907,  ..., -2.4375, -3.2882, -2.9545],
         [-8.7958, -2.1797, -1.5928,  ..., -3.1574, -4.4527, -2.8628]],

        [[-8.4212, -1.0725, -3.0453,  ..., -2.1283, -3.8493, -3.7909],
         [-8.6560, -1.8614, -2.9651,  ..., -1

tensor([[[ -7.6686,  -2.6692,  -2.2077,  ...,  -3.0155,  -1.4579,  -4.0962],
         [ -8.3326,  -1.9635,  -2.1802,  ...,  -2.8757,  -3.7400,  -4.5951],
         [ -7.2284,  -2.2436,  -2.0601,  ...,  -2.1644,  -2.2278,  -4.6842],
         ...,
         [ -7.8346,  -2.9236,  -1.7084,  ...,  -2.8776,  -2.1188,  -4.2384],
         [ -7.0040,  -2.8205,  -2.5351,  ...,  -3.1265,  -2.2237,  -4.0277],
         [ -6.3820,  -3.1449,  -2.2299,  ...,  -3.4557,  -2.7186,  -4.6734]],

        [[ -8.1564,  -3.0022,  -1.4018,  ...,  -3.0190,  -2.3795,  -4.5552],
         [ -7.9691,  -3.3696,  -1.9625,  ...,  -2.8768,  -2.4125,  -3.9694],
         [ -6.5091,  -3.2751,  -2.9469,  ...,  -2.5798,  -2.2792,  -4.0519],
         ...,
         [ -6.0019,  -3.1605,  -2.6372,  ...,  -3.9005,  -3.9120,  -4.6277],
         [ -7.7918,  -2.4856,  -2.5246,  ...,  -2.0425,  -3.0443,  -3.4174],
         [ -6.9321,  -2.4207,  -2.1239,  ...,  -2.3858,  -2.6688,  -3.7237]],

        [[ -8.5404,  -3.6106,  -3.1058,  ...

tensor([[[-8.9883, -2.5696, -1.4298,  ..., -3.8510, -2.1397, -2.0918],
         [-8.5362, -2.6937, -3.1370,  ..., -1.8854, -2.6877, -2.0839],
         [-8.9282, -2.1887, -2.2971,  ..., -3.8569, -0.9835, -2.2690],
         ...,
         [-7.2700, -1.8914, -3.2496,  ..., -2.5491, -2.3772, -3.2633],
         [-7.2238, -2.2511, -2.3577,  ..., -2.6042, -2.1268, -2.8295],
         [-8.6966, -2.1429, -2.5223,  ..., -3.4619, -1.3798, -1.9817]],

        [[-8.8524, -2.2490, -1.5240,  ..., -4.0613, -2.0026, -2.3724],
         [-8.9724, -2.6095, -1.3805,  ..., -3.7779, -2.7258, -1.8968],
         [-7.4153, -1.3709, -2.1509,  ..., -2.3789, -3.8681, -3.4361],
         ...,
         [-8.0366, -1.3491, -2.1945,  ..., -2.3130, -2.4094, -2.3112],
         [-6.8883, -1.0840, -2.8858,  ..., -2.1855, -3.3781, -3.0761],
         [-7.4506, -1.7626, -3.4523,  ..., -2.0498, -3.3335, -1.9650]],

        [[-7.8095, -2.7104, -1.4204,  ..., -3.8801, -2.1621, -2.0319],
         [-8.9938, -1.9252, -3.0991,  ..., -2

tensor([[[ -8.0309,  -2.0771,  -2.4560,  ...,  -2.1631,  -3.7331,  -3.2290],
         [ -6.8165,  -2.6509,  -1.8513,  ...,  -1.1613,  -3.6555,  -3.9478],
         [ -7.0170,  -2.4088,  -1.4648,  ...,  -2.1647,  -4.4594,  -3.7020],
         ...,
         [ -7.7841,  -2.5428,  -2.2696,  ...,  -1.9898,  -3.9101,  -2.4256],
         [ -7.4984,  -1.6807,  -2.5428,  ...,  -2.3319,  -3.5676,  -2.8616],
         [ -5.7282,  -2.6225,  -2.0321,  ...,  -1.3262,  -4.5095,  -3.6193]],

        [[ -7.3972,  -2.5764,  -1.9637,  ...,  -2.8853,  -3.1784,  -3.4155],
         [ -8.3359,  -2.9705,  -2.7824,  ...,  -1.9058,  -2.9187,  -3.5370],
         [ -5.9456,  -2.1443,  -1.9955,  ...,  -2.2690,  -3.4628,  -3.5615],
         ...,
         [ -6.1778,  -1.4154,  -2.1805,  ...,  -2.4596,  -3.2713,  -3.9701],
         [ -5.9148,  -1.7623,  -2.1031,  ...,  -1.6275,  -3.3959,  -3.8403],
         [ -6.7613,  -3.0064,  -1.7406,  ...,  -1.8410,  -3.7613,  -2.8214]],

        [[ -7.8922,  -1.3879,  -1.5823,  ...

tensor([[[-8.4550, -2.9835, -3.4971,  ..., -1.6640, -2.3634, -3.4419],
         [-6.8296, -2.8738, -3.0319,  ..., -1.4398, -1.8754, -3.4103],
         [-5.5514, -3.9574, -2.1337,  ..., -1.8429, -1.2302, -2.7686],
         ...,
         [-7.3497, -4.2878, -2.3761,  ..., -1.7713, -1.1571, -2.4241],
         [-6.9080, -2.4289, -3.0508,  ..., -1.8231, -1.5831, -3.7087],
         [-5.2932, -2.8746, -2.5184,  ..., -2.5834, -1.8484, -2.4638]],

        [[-7.4699, -1.8588, -2.8951,  ..., -1.9002, -3.7966, -3.6139],
         [-6.2276, -2.1028, -1.8719,  ..., -2.5294, -1.6711, -3.0962],
         [-6.9636, -1.9278, -2.9781,  ..., -2.5948, -1.4958, -3.8909],
         ...,
         [-8.0812, -2.9461, -2.1905,  ..., -1.3258, -1.9487, -3.0307],
         [-6.4973, -2.9118, -3.2967,  ..., -1.8283, -2.0779, -3.3045],
         [-7.6328, -2.1274, -2.9064,  ..., -1.1728, -2.6317, -3.0612]],

        [[-8.6435, -2.4947, -2.5999,  ..., -1.6422, -2.8852, -3.5315],
         [-7.0771, -2.1398, -2.1403,  ..., -1

tensor([[[ -8.5835,  -3.4841,  -3.0031,  ...,  -1.0436,  -3.4863,  -2.4134],
         [ -6.8819,  -3.1626,  -2.7444,  ...,  -1.5908,  -2.7158,  -1.9327],
         [ -7.3672,  -2.6947,  -2.6577,  ...,  -1.8993,  -2.7553,  -1.9024],
         ...,
         [ -7.7022,  -4.2047,  -2.7235,  ...,  -2.4368,  -3.7119,  -3.0456],
         [ -7.1616,  -3.1409,  -1.9512,  ...,  -1.3642,  -4.8656,  -2.6074],
         [ -8.2057,  -3.8415,  -2.5858,  ...,  -1.6895,  -3.1653,  -1.9008]],

        [[ -8.8131,  -3.7407,  -2.7678,  ...,  -2.0190,  -3.5917,  -2.6485],
         [ -9.6056,  -2.5318,  -2.5042,  ...,  -1.9846,  -3.1428,  -2.0980],
         [ -9.5626,  -3.5737,  -2.9097,  ...,  -1.9179,  -3.1275,  -3.0009],
         ...,
         [ -8.5078,  -4.0170,  -2.8542,  ...,  -3.0236,  -2.7475,  -2.1927],
         [ -9.2287,  -3.2094,  -2.1290,  ...,  -2.8714,  -2.8302,  -2.2562],
         [ -7.0902,  -3.6871,  -2.9169,  ...,  -3.1065,  -2.8225,  -1.8487]],

        [[ -8.8308,  -3.3480,  -2.4984,  ...

tensor([[[-12.2439,  -2.0432,  -4.0043,  ...,  -2.9572,  -4.1879,  -4.0169],
         [-10.3995,  -1.7207,  -2.9050,  ...,  -1.6106,  -3.4019,  -3.5648],
         [-11.0546,  -1.9456,  -2.9454,  ...,  -3.0831,  -4.1149,  -3.0242],
         ...,
         [-12.4393,  -1.7196,  -3.7912,  ...,  -3.2925,  -4.6297,  -4.0563],
         [-11.1165,  -1.7953,  -2.3561,  ...,  -1.6216,  -3.4290,  -2.7587],
         [-11.0054,  -1.1023,  -2.6034,  ...,  -1.8923,  -3.5524,  -3.9801]],

        [[-12.5950,  -2.6382,  -3.4111,  ...,  -3.5086,  -4.6873,  -3.3329],
         [ -9.9841,  -2.8181,  -2.7481,  ...,  -1.3355,  -3.3912,  -2.7073],
         [-11.2497,  -3.4175,  -2.8210,  ...,  -1.8986,  -3.4739,  -2.9435],
         ...,
         [-11.4252,  -2.4923,  -2.4568,  ...,  -2.7353,  -3.6650,  -3.5153],
         [-10.2964,  -2.0941,  -2.3796,  ...,  -1.7336,  -2.4531,  -3.0293],
         [-10.5630,  -1.5368,  -2.1035,  ...,  -1.5091,  -2.9948,  -3.8950]],

        [[-11.4207,  -1.6895,  -2.7983,  ...

tensor([[[ -9.1052,  -2.2428,  -2.5777,  ...,  -3.7386,  -1.8165,  -2.2349],
         [ -7.5599,  -2.6670,  -2.4536,  ...,  -1.6416,  -1.4798,  -2.7118],
         [ -7.4797,  -2.5364,  -1.3139,  ...,  -2.0850,  -2.6401,  -2.7074],
         ...,
         [ -6.8819,  -1.9469,  -2.2590,  ...,  -2.3285,  -2.3416,  -2.7335],
         [ -7.0576,  -3.2943,  -2.3534,  ...,  -2.3509,  -1.7811,  -1.8144],
         [ -7.7797,  -2.0951,  -3.4601,  ...,  -3.1771,  -1.8270,  -2.2424]],

        [[ -9.9063,  -2.4166,  -3.5781,  ...,  -2.3321,  -1.2041,  -3.4466],
         [ -7.5711,  -2.1749,  -3.0413,  ...,  -2.6019,  -1.3060,  -2.7487],
         [ -7.2986,  -3.2925,  -2.6669,  ...,  -1.8804,  -1.2586,  -2.4692],
         ...,
         [ -7.9600,  -1.7045,  -3.7339,  ...,  -1.5264,  -2.2428,  -2.4012],
         [ -7.9277,  -2.0072,  -3.3829,  ...,  -1.6722,  -1.5308,  -2.7701],
         [ -7.5210,  -3.5596,  -3.2201,  ...,  -2.0914,  -1.2236,  -2.0737]],

        [[ -7.9487,  -0.7318,  -2.8386,  ...

In [None]:
mah = MultiHeadAttention()

In [None]:
q = torch.rand(5, 10, d_model)
k = torch.rand(5, 10, d_model)
v = torch.rand(5, 10, d_model)

In [None]:
mah(q,k,v,None)