In [47]:
import torch
import math

In [64]:
class Encoder(torch.nn.Module):
    """One transformer encoder layer (post-norm).

    Self-attention followed by a position-wise feed-forward network; each
    sub-layer is wrapped in a residual connection and then a LayerNorm.

    Args:
        d: width of the input/output features (last axis of ``x``).
        att_dim: width of the query/key/value projections.

    NOTE: despite the method name ``multi_head_att``, this layer uses a single
    attention head of width ``att_dim``.
    """

    def __init__(self, d, att_dim):
        super().__init__()
        self.d = d
        self.W_q = torch.nn.Linear(d, att_dim)
        self.W_k = torch.nn.Linear(d, att_dim)
        self.W_v = torch.nn.Linear(d, att_dim)
        # Projects the attention output back to width d so it can be added to x.
        self.W_tr = torch.nn.Linear(att_dim, d)
        # Position-wise feed-forward network with a 4x hidden expansion.
        self.feed1 = torch.nn.Linear(d, d*4)
        self.feed2 = torch.nn.Linear(d*4, d)
        self.norm1 = torch.nn.LayerNorm(d)
        self.norm2 = torch.nn.LayerNorm(d)

    def multi_head_att(self, x):
        """Scaled dot-product self-attention (single head).

        Attention mixes positions along the second-to-last axis of ``x``; the
        last axis must have size ``d``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v(x)
        # Scale by sqrt of the query/key width (att_dim), not the model width d,
        # so the softmax logits keep unit-order variance.
        key_dim = self.W_k.out_features
        scores = (queries @ torch.transpose(keys, -2, -1)) / math.sqrt(key_dim)
        weights = torch.nn.functional.softmax(scores, dim=-1)
        return self.W_tr(weights @ values)

    def feed_forw(self, x):
        """Position-wise feed-forward: Linear(d, 4d) -> ReLU -> Linear(4d, d)."""
        hidden = torch.nn.functional.relu(self.feed1(x))
        return self.feed2(hidden)

    def forward(self, x):
        """Apply attention and feed-forward sub-layers, each as norm(sublayer(h) + h)."""
        y = self.norm1(self.multi_head_att(x) + x)
        return self.norm2(self.feed_forw(y) + y)


In [73]:
class Decoder(torch.nn.Module):
    """One transformer decoder layer (post-norm).

    Causal (masked) self-attention, then cross-attention over the encoder
    output, then a position-wise feed-forward network. Each sub-layer is
    wrapped in a residual connection followed by a LayerNorm.

    Args:
        d: width of the input/output features (last axis of ``x``/``enc_y``).
        att_dim: width of the query/key/value projections.

    NOTE: both attentions use a single head of width ``att_dim``.
    """

    def __init__(self, d, att_dim):
        super().__init__()
        self.d = d
        self.W_q_self = torch.nn.Linear(d, att_dim)
        self.W_k_self = torch.nn.Linear(d, att_dim)
        self.W_v_self = torch.nn.Linear(d, att_dim)
        self.W_tr_self = torch.nn.Linear(att_dim, d)
        self.W_q_cross = torch.nn.Linear(d, att_dim)
        self.W_k_cross = torch.nn.Linear(d, att_dim)
        self.W_v_cross = torch.nn.Linear(d, att_dim)
        self.W_tr_cross = torch.nn.Linear(att_dim, d)
        # Position-wise feed-forward network with a 4x hidden expansion.
        self.feed1 = torch.nn.Linear(d, d*4)
        self.feed2 = torch.nn.Linear(d*4, d)
        self.norm1 = torch.nn.LayerNorm(d)
        self.norm2 = torch.nn.LayerNorm(d)
        self.norm3 = torch.nn.LayerNorm(d)

    def mask(self, row, col, device=None):
        """Build a causal mask of shape (row, col).

        Entry [i, j] is 1 when j <= i (allowed) and 0 when j > i (a future
        position). The tensor uses the default float dtype and is created on
        ``device`` (default: the default device).
        """
        # Vectorised lower-triangular mask; replaces an O(row*col) Python loop.
        return torch.tril(torch.ones((row, col), device=device))

    def masked_self_att(self, x):
        """Causal scaled dot-product self-attention along the second-to-last axis of x."""
        queries = self.W_q_self(x)
        keys = self.W_k_self(x)
        values = self.W_v_self(x)
        # Scale by sqrt of the query/key width (att_dim), not the model width d.
        key_dim = self.W_k_self.out_features
        scores = (queries @ torch.transpose(keys, -2, -1)) / math.sqrt(key_dim)
        # Build the mask on the scores' device so masked_fill works off-CPU too.
        causal = self.mask(scores.shape[-2], scores.shape[-1], device=scores.device)
        scores = scores.masked_fill(causal == 0, -math.inf)
        weights = torch.nn.functional.softmax(scores, dim=-1)
        return self.W_tr_self(weights @ values)

    def cross_multi_att(self, x, enc_y):
        """Cross-attention: queries from x, keys/values from the encoder output enc_y."""
        queries = self.W_q_cross(x)
        keys = self.W_k_cross(enc_y)
        values = self.W_v_cross(enc_y)
        key_dim = self.W_k_cross.out_features
        scores = (queries @ torch.transpose(keys, -2, -1)) / math.sqrt(key_dim)
        weights = torch.nn.functional.softmax(scores, dim=-1)
        return self.W_tr_cross(weights @ values)

    def feed_forw(self, x):
        """Position-wise feed-forward: Linear(d, 4d) -> ReLU -> Linear(4d, d)."""
        hidden = torch.nn.functional.relu(self.feed1(x))
        return self.feed2(hidden)

    def forward(self, x, enc_y):
        """Run the three sub-layers; returns a tensor shaped like x."""
        y = self.norm1(x + self.masked_self_att(x))
        z = self.norm2(self.cross_multi_att(y, enc_y) + y)
        return self.norm3(self.feed_forw(z) + z)

In [76]:
class Transformer(torch.nn.Module):
    def __init__(self, enc, dec, emb_dim, att_dim, voc_size, enc_num: int, dec_num: int):
        '''
        Encoder-decoder stack operating on already-embedded inputs.

        enc: encoder layer class, constructed as enc(emb_dim, att_dim)
        dec: decoder layer class, constructed as dec(emb_dim, att_dim) and
             called as dec(target_states, encoder_output)
        emb_dim: dimension of embeddings (last axis of both inputs)
        att_dim: width of the attention projections inside each layer
        voc_size: number of words in target language (output classes)
        enc_num: number of stacked encoder layers
        dec_num: number of stacked decoder layers
        '''
        super().__init__()
        # Sequential registers these layers; the list is only a convenience alias.
        self.encoders = [enc(emb_dim, att_dim) for _ in range(enc_num)]
        self.encoder = torch.nn.Sequential(*self.encoders)
        # ModuleList (not a plain list) so decoder weights are registered as
        # submodules: visible to .parameters(), .to(device) and state_dict().
        self.decoders = torch.nn.ModuleList([dec(emb_dim, att_dim) for _ in range(dec_num)])
        self.lin = torch.nn.Linear(emb_dim, voc_size)

    def forward(self, x, x_tgt, return_logits: bool = False):
        '''
        x: source embeddings (last axis emb_dim)
        x_tgt: target embeddings (last axis emb_dim)
        return_logits: if True, return raw scores over the vocabulary instead of
            softmax probabilities (use this with torch.nn.CrossEntropyLoss,
            which applies log-softmax itself).

        Returns a tensor shaped like x_tgt with the last axis replaced by voc_size.
        '''
        enc_out = self.encoder(x)
        out = x_tgt
        for decoder in self.decoders:
            out = decoder(out, enc_out)
        logits = self.lin(out)
        if return_logits:
            return logits
        return torch.nn.functional.softmax(logits, dim=-1)

In [83]:
# Fixed seed so the random stand-in inputs and the initial weights are reproducible.
torch.manual_seed(0)
# Random stand-ins for embedded sentences, shape (1, 5, 10); the last axis must equal emb_dim.
russian = torch.randn((1, 5, 10))
english = torch.randn((1, 5, 10))
tfr = Transformer(Encoder, Decoder, emb_dim=10, att_dim=30, voc_size=100, enc_num=3, dec_num=3)
probs = tfr(russian, english)
assert probs.shape == (1, 5, 100)
probs.shape

torch.Size([1, 5, 100])