In [12]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

In [None]:
class TokenEmbedding(nn.Embedding):
    def __init__(self, vocab_size, d_model):
        super().__init__(self, vocab_size, d_model, padding_idx=1)
        

In [None]:
class PositionEmbedding(nn.Module):
    def __init__(self, d_model, max_len, device):
        super().__init__()
        self.encoding = torch.zeros(max_len, d_model, device)
        self.encoding.requires_grad_(False)
        
        pos = torch.arange(0, max_len, device)
        pos = pos.float().unsqueeze(1)
        _2i = torch.arange(0, d_model, device)
        
        self.encoding[:, 0::2] = torch.sin(pos / (1000 ** (_2i / d_model)))
        self.encoding[:, 1::2] = torch.cos(pos / (1000 ** (_2i / d_model)))
        
    def forward(self, x):
        seq_len = x.shape(0)
        return self.encoding[:seq_len, :]


In [None]:
class LayerNorm(nn.Module):
    def __init__(self, d_model, eps=1e-10):
        super().__init__()
        self.gamma = nn.Parameter(torch.ones(d_model))
        self.beta = nn.Parameter(torch.zeros(d_model))
        self.eps = eps
        
    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        var = x.std(-1, unbiased=False, keepdim=True)
        out = (x - mean) / torch.sqrt(var + self.eps)
        return out * self.gamma +self.beta

In [17]:
class multi_head_attention (nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        self.d_model = d_model
        self.num_heads = num_heads

        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)
        self.w_combine = nn.Linear(d_model, d_model)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, q, k, v, mask=None):
        batch, time, dimension = q.shape
        n_d = self.d_model // self.num_heads
        
        # perform linear operation and split into N heads
        q, k, v = self.q_linear(q), self.k_linear(k), self.v_linear(v)

        q = q.view(batch, time, self.num_heads, n_d).permute(0, 2, 1, 3)
        k = k.view(batch, time, self.num_heads, n_d).permute(0, 2, 1, 3)
        v = v.view(batch, time, self.num_heads, n_d).permute(0, 2, 1, 3)
        
        # score = torch.matmul(q, k.permute(0, 1, 3, 2)) / (self.d_model ** 0.5)
        score = q @ k.transpose(2, 3) / math.sqrt(n_d)
        mask = torch.tril(torch.ones(time, time, dtype=bool))
        score = score.masked_fill(mask == 0 , float('-inf'))
        score = self.softmax(score) @ v
        
        score = score.permute(0, 2, 1, 3).contiguous().view(batch, time, dimension)
        
        output = self.w_combine(score)
        return output
        

In [18]:
d_model = 512
n_heads = 8
X = torch.rand(128, 64, d_model)
attention = multi_head_attention(d_model, n_heads)
out = attention(X, X, X)
print(out, out.shape)


tensor([[[-0.2621, -0.0387,  0.0226,  ..., -0.0882,  0.1103,  0.2545],
         [-0.1952, -0.1403,  0.0070,  ..., -0.0733,  0.1259,  0.3139],
         [-0.2382, -0.0973,  0.0177,  ..., -0.0724,  0.0650,  0.3530],
         ...,
         [-0.1946, -0.0888,  0.0662,  ..., -0.0739,  0.0375,  0.3863],
         [-0.1958, -0.0904,  0.0670,  ..., -0.0757,  0.0388,  0.3863],
         [-0.1953, -0.0912,  0.0685,  ..., -0.0747,  0.0381,  0.3840]],

        [[-0.1965, -0.1198,  0.0130,  ..., -0.1202,  0.0647,  0.5903],
         [-0.2062, -0.1379,  0.1116,  ..., -0.1313,  0.0786,  0.4496],
         [-0.2256, -0.1305,  0.0948,  ..., -0.1183,  0.1172,  0.4691],
         ...,
         [-0.2050, -0.0943,  0.0607,  ..., -0.0677,  0.0321,  0.4063],
         [-0.2061, -0.0949,  0.0614,  ..., -0.0655,  0.0304,  0.4094],
         [-0.2026, -0.0953,  0.0622,  ..., -0.0655,  0.0307,  0.4061]],

        [[-0.3103, -0.0166,  0.1777,  ..., -0.1316, -0.0570,  0.4131],
         [-0.3553, -0.1048,  0.0973,  ..., -0