In [512]:
import torch
import torch.nn as nn
import math
import copy

In [513]:
class LayerNorm(nn.Module):
    def __init__(self, emb_dim, eps=1e-7):
        super().__init__()
        self.alfa = nn.Parameter(torch.ones(emb_dim))
        self.beta = nn.Parameter(torch.zeros(emb_dim))
        self.eps = eps
    def forward(self, input):
        mean = input.mean(dim=-1, keepdim=True)
        std = input.std(dim=-1, keepdim=True)
        return self.alfa * (input - mean) / (std + self.eps) + self.beta

In [514]:
z = torch.rand((2, 3, 6))


test_ln = LayerNorm(6)
torch_ln = nn.LayerNorm(6)

test_ln(z), torch_ln(z)

(tensor([[[ 1.0167, -0.6778,  0.5319, -0.9807, -1.0072,  1.1170],
          [-1.6710, -0.1143, -0.1853,  0.0722,  1.3696,  0.5288],
          [-1.3628,  0.9835,  0.3506, -0.9035,  1.0992, -0.1670]],
 
         [[-0.0858, -0.5054,  0.7010,  1.5698, -1.2704, -0.4093],
          [-1.0077,  1.0716, -0.0588, -1.1348,  1.2382, -0.1085],
          [ 0.2882, -0.9200, -0.7648, -0.6624,  1.7103,  0.3487]]],
        grad_fn=<AddBackward0>),
 tensor([[[ 1.1137, -0.7424,  0.5826, -1.0742, -1.1033,  1.2236],
          [-1.8303, -0.1252, -0.2030,  0.0791,  1.5002,  0.5792],
          [-1.4927,  1.0772,  0.3840, -0.9896,  1.2039, -0.1829]],
 
         [[-0.0940, -0.5536,  0.7678,  1.7194, -1.3914, -0.4483],
          [-1.1039,  1.1738, -0.0644, -1.2431,  1.3563, -0.1188],
          [ 0.3157, -1.0077, -0.8378, -0.7256,  1.8734,  0.3820]]],
        grad_fn=<NativeLayerNormBackward0>))

In [515]:
def attention(q, k, v, mask=None):
    hid_dim = q.size(-1)
    scores = (q @ k.transpose(-2, -1)) / math.sqrt(hid_dim)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    return torch.matmul(scores.softmax(dim=-1), v)

class MultiHeadAttention(nn.Module):
    def __init__(self, model_dim, heads, dropout=0):
        super().__init__()
        """
        The code string bellow is equivalent to this:
        self.w_q = nn.Linear(model_dim, model_dim)
        self.w_k = nn.Linear(model_dim, model_dim)
        self.w_v = nn.Linear(model_dim, model_dim)
        self.w_out = nn.Linear(model_dim, model_dim)
        """
        self.linears = nn.ModuleList([nn.Linear(model_dim, model_dim) for _ in range(4)])

        assert model_dim % heads == 0, ('In this simple realisation of transformer '
                                        'you should observe equation: model_dim % heads == 0')
        self.hid_dim = model_dim // heads
        self.heads = heads
        self.model_dim = model_dim
        self.dropout = nn.Dropout(p=dropout)
        

    def forward(self, query, key, value, mask=None):
        bs = query.size(0)
        q, k, v = [lin(input).reshape(bs, -1, self.heads, self.hid_dim).transpose(1,2) \
                   for lin, input in zip(self.linears[:3], (query, key, value))]
        att = attention(q, k, v, mask)
        att = att.transpose(1, 2).reshape(bs, -1, self.hid_dim * self.heads)

        return self.dropout(self.linears[-1](att))

In [516]:
test_mha = MultiHeadAttention(512, 8)
torch_mha = nn.MultiheadAttention(512, 8, batch_first=True)
x = torch.rand((2, 5, 512))

In [517]:
for k, v in torch_mha.state_dict().items():
    print(f'{k:20} {tuple(v.shape)}')

in_proj_weight       (1536, 512)
in_proj_bias         (1536,)
out_proj.weight      (512, 512)
out_proj.bias        (512,)


In [518]:
q, k, v = torch_mha.in_proj_weight[:512], torch_mha.in_proj_weight[512:512+512], torch_mha.in_proj_weight[512+512:]
q_b, k_b, v_b = torch_mha.in_proj_bias[:512], torch_mha.in_proj_bias[512:512+512], torch_mha.in_proj_bias[512+512:]
out, out_b = torch_mha.out_proj.weight, torch_mha.out_proj.bias

for param, weight in zip(test_mha.linears, (q, k, v, out)):
    param.weight.data = weight

for param, bias in zip(test_mha.linears, (q_b, k_b, v_b, out_b)):
    param.bias.data = bias

In [519]:
a = torch_mha(x, x, x, need_weights=False)[0].data
b = test_mha(x, x, x).data
torch.allclose(a, b, rtol=1e-04, atol=1e-07)

True

In [520]:
class FeedForward(nn.Module):
    def __init__(self, model_dim, hid_dim=None, dropout=0):
        super().__init__()
        hid_dim = model_dim if hid_dim == None else hid_dim

        self.ff = nn.Sequential(
            nn.Linear(model_dim, hid_dim),
            nn.ReLU(),
            nn.Linear(hid_dim, model_dim),
            nn.Dropout(p=dropout)
        )
        
    def forward(self, x):
        return self.ff(x)


In [521]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout, max_len=5000):
        super().__init__()
        assert d_model % 2 == 0, 'd_model must be even'

        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[:, : x.size(1)].requires_grad_(False)
        return self.dropout(x)

In [522]:
class Embeddings(nn.Module):
    def __init__(self, vocab_size, model_dim, padding_idx=0):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, model_dim, padding_idx)
        self.model_dim = model_dim

    def forward(self, x):
        return self.emb(x) * math.sqrt(self.model_dim)

class Encoder(nn.Module):
    def __init__(self, model_dim, heads, ff_hid_dim=None, dropout=0):
        super().__init__()
        self.MHA = MultiHeadAttention(model_dim, heads, dropout)
        self.LN_1 = LayerNorm(model_dim)
        self.FF = FeedForward(model_dim, ff_hid_dim, dropout)
        self.LN_2 = LayerNorm(model_dim)
        
    def forward(self, x, mask=None):
        out_1 = x + self.MHA(x, x, x)
        out_1 = self.LN_1(out_1)
        
        out_2 = out_1 + self.FF(out_1)
        out_2 = self.LN_2(out_2)
        return out_2

class Decoder(nn.Module):
    def __init__(self, model_dim, heads, ff_hid_dim=None, dropout=0):
        super().__init__()
        self.MHA = MultiHeadAttention(model_dim, heads, dropout)
        self.LN_1 = LayerNorm(model_dim)
        self.EDA = MultiHeadAttention(model_dim, heads, dropout)
        self.LN_2 = LayerNorm(model_dim)
        self.FF = FeedForward(model_dim, ff_hid_dim, dropout)
        self.LN_3 = LayerNorm(model_dim)
    
        
    def forward(self, x, enc_out, mask=None):
        out_1 = x + self.MHA(x, x, x)
        out_1 = self.LN_1(input=out_1)

        out_2 = out_1 + self.EDA(out_1, enc_out, enc_out)
        out_2 = self.LN_2(input=out_2)

        out_3 = out_2 + self.FF(out_2)
        out_3 = self.LN_3(input=out_3)
        return out_3

class EncoderDecoder(nn.Module):
    def __init__(self, num_enc, model_dim, heads, src_vocab_size, tg_vocab_size, 
                 max_len=5000, ff_hid_dim=None, dropout=0, num_dec=None):
        super().__init__()

        self.enc_emb = nn.Sequential(
            Embeddings(src_vocab_size, model_dim, padding_idx=0),
            PositionalEncoding(model_dim, dropout, max_len)
        )
        self.dec_emb = nn.Sequential(
            Embeddings(tg_vocab_size, model_dim, padding_idx=0),
            PositionalEncoding(model_dim, dropout, max_len)
        )
        num_dec = num_enc if num_dec == None else num_dec
        self.encoders = nn.ModuleList(
            [Encoder(model_dim, heads, ff_hid_dim, dropout) for _ in range(num_enc)]
            )
        self.decoders = nn.ModuleList(
            [Decoder(model_dim, heads, ff_hid_dim, dropout) for _ in range(num_dec)]
            )

    def forward(self, enc_input, dec_input, src_mask=None, tg_mask=None):
        memory = self.encode(enc_input, src_mask=None)
        out = self.decode(dec_input, memory, tg_mask)
        return out

    def encode(self, input, src_mask=None):
        x_enc = self.enc_emb(input)
        for enc in self.encoders:
            x_enc = enc(x_enc, src_mask)
        return x_enc

    def decode(self, input, memory, tg_mask=None):
        x_dec = self.dec_emb(input)
        for dec in self.decoders:
            x_dec = dec(x_dec, memory, tg_mask)
        return x_dec


In [523]:
def subsequent_mask(size):
    "Mask out subsequent positions."
    attn_shape = (1, size, size)
    subsequent_mask = torch.triu(torch.ones(attn_shape), diagonal=1).type(
        torch.uint8
    )
    return subsequent_mask == 0

subsequent_mask(6)

tensor([[[ True, False, False, False, False, False],
         [ True,  True, False, False, False, False],
         [ True,  True,  True, False, False, False],
         [ True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True]]])

In [524]:
enc = Encoder(512, 8, dropout=0.1)
x = torch.rand((3, 4, 512))
enc(x).shape

torch.Size([3, 4, 512])

In [538]:
abv = EncoderDecoder(num_enc=3, model_dim=512, heads=8, dropout=0.2, src_vocab_size=20, tg_vocab_size=22)
x = torch.randint(0, 20, (1, 6))
y = torch.randint(0, 22, (1, 2))
abv(x, y, tg_mask=subsequent_mask(2)).shape

torch.Size([1, 2, 512])