# In [4]:
import torch
from torch import nn

# In [5]:
class SelfAttention(nn.Module):
    """Multi-head scaled dot-product attention.

    The model dimension is split evenly into ``num_heads`` heads of size
    ``depth = dim_model // num_heads``. Each head's slice of the query, key
    and value is passed through a bias-free ``depth -> depth`` linear layer
    (the same layer is shared by all heads), attention is computed per head,
    and the concatenated heads go through a final ``dim_model -> dim_model``
    linear layer.

    Args:
        dim_model: Size of the last dimension of the inputs and the output.
        num_heads: Number of attention heads; must divide ``dim_model``.

    Raises:
        AssertionError: If ``dim_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, dim_model, num_heads):
        super(SelfAttention, self).__init__()
        self.dim_model = dim_model
        self.num_heads = num_heads
        self.depth = dim_model // num_heads

        assert self.depth * num_heads == self.dim_model, (
            "dim_model must be divisible by num_heads"
        )

        # Per-head projections operate on the ``depth``-sized head slice.
        self.values = nn.Linear(self.depth, self.depth, bias=False)
        self.keys = nn.Linear(self.depth, self.depth, bias=False)
        self.queries = nn.Linear(self.depth, self.depth, bias=False)

        self.feed_forward = nn.Linear(self.num_heads * self.depth, dim_model)

    def forward(self, query, key, value, mask):
        """Attend from ``query`` positions to ``key``/``value`` positions.

        Args:
            query: Tensor reshaped here to (batch, query_len, heads, depth),
                i.e. expected as (batch, query_len, dim_model).
            key: Tensor expected as (batch, key_len, dim_model).
            value: Tensor expected as (batch, key_len, dim_model).
            mask: ``None`` or a tensor broadcastable to
                (batch, heads, query_len, key_len); entries equal to 0 mark
                positions that must not be attended to.

        Returns:
            Tensor of shape (batch, query_len, dim_model).
        """
        batch_size = query.shape[0]
        query_length = query.shape[1]

        # Split the model dimension into heads.
        query = query.reshape(batch_size, query_length, self.num_heads, self.depth)
        key = key.reshape(batch_size, key.shape[1], self.num_heads, self.depth)
        value = value.reshape(batch_size, value.shape[1], self.num_heads, self.depth)

        # Apply the learned projections (previously declared but never used).
        query = self.queries(query)
        key = self.keys(key)
        value = self.values(value)

        # (batch, q, heads, depth) x (batch, k, heads, depth)
        #   -> (batch, heads, q, k)
        similarity = torch.einsum("nqhd,nkhd->nhqk", query, key)

        if mask is not None:
            # A very large negative score becomes ~0 weight after softmax.
            similarity = similarity.masked_fill(mask == 0, float(-1e28))

        # Scale by sqrt of the per-head key size, since each dot product
        # sums over ``depth`` terms; normalise over the key axis.
        attention = torch.softmax(similarity / (self.depth ** 0.5), dim=3)

        # (batch, heads, q, k) x (batch, k, heads, depth)
        #   -> (batch, q, heads, depth), then concatenate the heads.
        out = torch.einsum("nhqk,nkhd->nqhd", attention, value).reshape(
            batch_size, query_length, self.num_heads * self.depth
        )

        return self.feed_forward(out)


# In [6]:
class TransformerBlock(nn.Module):
    """Attention sub-layer followed by a position-wise feed-forward sub-layer.

    Each sub-layer output is added to its input (residual connection) and
    layer-normalised. Dropout is applied after the first normalisation only.

    Args:
        dim_model: Feature size of the inputs and outputs.
        num_heads: Number of heads passed to ``SelfAttention``.
        dropout: Drop probability for the ``nn.Dropout`` layer.
        forward_expansion: Multiplier for the hidden width of the
            feed-forward network (``forward_expansion * dim_model``).
    """

    def __init__(self, dim_model, num_heads, dropout, forward_expansion):
        super().__init__()
        hidden_dim = forward_expansion * dim_model

        self.attention = SelfAttention(dim_model, num_heads)
        self.normalization1 = nn.LayerNorm(dim_model)
        self.normalization2 = nn.LayerNorm(dim_model)
        self.feed_forward = nn.Sequential(
            nn.Linear(dim_model, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, dim_model),
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, query, key, value, mask):
        """Run attention and the feed-forward network with residuals.

        ``query`` doubles as the residual input of the attention sub-layer;
        ``mask`` is forwarded unchanged to the attention module.
        """
        # Sub-layer 1: attention + residual, then norm and dropout.
        attended = self.attention(query, key, value, mask)
        hidden = self.dropout(self.normalization1(attended + query))

        # Sub-layer 2: feed-forward + residual, then norm.
        transformed = self.feed_forward(hidden)
        return self.normalization2(transformed + hidden)

# In [7]:
class Encoder(nn.Module):
    """Token + learned positional embedding followed by a stack of blocks.

    Args:
        src_vocab_size: Number of rows in the token embedding table.
        dim_model: Embedding / feature size.
        units: Number of ``TransformerBlock`` layers to stack.
        heads: Number of attention heads per block.
        forward_expansion: Feed-forward width multiplier for each block.
        dropout: Drop probability used on the embeddings and in each block.
        max_length: Number of rows in the positional embedding table, i.e.
            the longest sequence length that can be embedded.
        device: Stored as ``self.device`` for callers; position indices are
            created on the input tensor's own device.
    """

    def __init__(self,
                 src_vocab_size,
                 dim_model,
                 units,
                 heads,
                 forward_expansion,
                 dropout,
                 max_length,
                 device):
        super(Encoder, self).__init__()
        self.dim_model = dim_model
        self.device = device
        self.embedding = nn.Embedding(src_vocab_size, dim_model)
        self.pos_embedding = nn.Embedding(max_length, dim_model)

        self.layers = nn.ModuleList(
            [
                TransformerBlock(dim_model, heads, dropout, forward_expansion)
                for _ in range(units)
            ]
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask):
        """Encode a batch of token indices.

        Args:
            x: Integer tensor of shape (N, seq_length) holding token indices.
            mask: Passed unchanged to every layer as the attention mask.

        Returns:
            Tensor of shape (N, seq_length, dim_model).
        """
        N, seq_length = x.shape
        # Build the indices directly on the input's device: the embedding
        # weights must already live there for ``self.embedding(x)`` to work,
        # so this stays correct even when ``self.device`` names another
        # device (e.g. the 'cuda' default while running on CPU).
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)
        out = self.dropout(self.embedding(x) + self.pos_embedding(positions))

        # Self-attention: the running output serves as query, key and value.
        for layer in self.layers:
            out = layer(out, out, out, mask)

        return out

# In [8]:
class DecoderBlock(nn.Module):
    """One decoder layer: masked self-attention, then a ``TransformerBlock``.

    The self-attention output (with residual, layer norm and dropout) is used
    as the query of the inner ``TransformerBlock``, which attends over the
    supplied ``key``/``value`` tensors.

    Args:
        dim_model: Feature size of the inputs and outputs.
        heads: Number of attention heads.
        forward_expansion: Feed-forward width multiplier of the inner block.
        dropout: Drop probability.
        device: Accepted for signature compatibility; not used by this class.
    """

    def __init__(self, dim_model, heads, forward_expansion, dropout, device):
        super().__init__()
        self.attention = SelfAttention(dim_model, heads)
        self.normalize = nn.LayerNorm(dim_model)
        self.transform_block = TransformerBlock(
            dim_model=dim_model,
            num_heads=heads,
            dropout=dropout,
            forward_expansion=forward_expansion,
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, key, value, src_mask, tar_mask):
        """Apply the layer to ``x``.

        ``tar_mask`` masks the self-attention over ``x``; ``src_mask`` masks
        the attention of the inner block over ``key``/``value``.
        """
        # Self-attention over the decoder input, with residual connection.
        self_attended = self.attention(x, x, x, tar_mask)
        decoder_query = self.dropout(self.normalize(self_attended + x))

        # Attend from the decoder state to the supplied key/value tensors.
        return self.transform_block(decoder_query, key, value, src_mask)

class Decoder(nn.Module):
    """Token + learned positional embedding, decoder layers, vocab projection.

    Args:
        dim_model: Embedding / feature size.
        heads: Number of attention heads per layer.
        tar_vocab_length: Target vocabulary size; sets both the embedding
            table size and the width of the output projection.
        units: Number of ``DecoderBlock`` layers to stack.
        forward_expansion: Feed-forward width multiplier for each layer.
        device: Stored as ``self.device`` and forwarded to each layer;
            position indices are created on the input tensor's own device.
        dropout: Drop probability used on the embeddings and in each layer.
        max_length: Number of rows in the positional embedding table, i.e.
            the longest sequence length that can be embedded.
    """

    def __init__(self,
                 dim_model,
                 heads,
                 tar_vocab_length,
                 units,
                 forward_expansion,
                 device,
                 dropout,
                 max_length):
        super(Decoder, self).__init__()
        self.embedding = nn.Embedding(tar_vocab_length, dim_model)
        self.pos_embedding = nn.Embedding(max_length, dim_model)
        self.device = device
        self.layers = nn.ModuleList(
            [
                DecoderBlock(dim_model,
                             heads,
                             forward_expansion,
                             dropout,
                             device)
                for _ in range(units)
            ]
        )

        self.fc_out = nn.Linear(dim_model, tar_vocab_length)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_out, src_mask, tar_mask):
        """Decode a batch of target token indices.

        Args:
            x: Integer tensor of shape (N, seq_length) holding token indices.
            enc_out: Passed to every layer as both key and value.
            src_mask: Passed to every layer as the mask for ``enc_out``.
            tar_mask: Passed to every layer as the self-attention mask.

        Returns:
            Tensor of shape (N, seq_length, tar_vocab_length) with one
            un-normalised score per vocabulary entry.
        """
        N, seq_length = x.shape
        # Build the indices directly on the input's device: the embedding
        # weights must already live there for ``self.embedding(x)`` to work,
        # so this stays correct even when ``self.device`` names another
        # device (e.g. the 'cuda' default while running on CPU).
        positions = torch.arange(seq_length, device=x.device).expand(N, seq_length)
        x = self.dropout(self.embedding(x) + self.pos_embedding(positions))
        for layer in self.layers:
            x = layer(x, enc_out, enc_out, src_mask, tar_mask)

        return self.fc_out(x)

# In [9]:
class Transformer(nn.Module):
    """Encoder-decoder model built from ``Encoder`` and ``Decoder``.

    Args:
        src_vocab_size: Source vocabulary size.
        tar_vocab_size: Target vocabulary size.
        src_pad_idx: Source token index treated as padding when building the
            source mask.
        tar_pad_idx: Stored as ``self.tar_pad_idx``; not used when building
            the target mask.
        dim_model: Embedding / feature size.
        units: Number of layers in both the encoder and the decoder.
        forward_expansion: Feed-forward width multiplier.
        heads: Number of attention heads.
        dropout: Drop probability.
        device: Stored as ``self.device`` and forwarded to the encoder and
            decoder. Masks are created on the device of the input tensors.
        max_length: Longest sequence length the positional embeddings support.
    """

    def __init__(self,
                 src_vocab_size,
                 tar_vocab_size,
                 src_pad_idx,
                 tar_pad_idx,
                 dim_model=256,
                 units=6,
                 forward_expansion=4,
                 heads=8,
                 dropout=0,
                 device='cuda',
                 max_length=100):
        super(Transformer, self).__init__()
        self.encoder = Encoder(src_vocab_size,
                               dim_model,
                               units,
                               heads,
                               forward_expansion,
                               dropout,
                               max_length,
                               device)
        self.decoder = Decoder(dim_model,
                               heads,
                               tar_vocab_size,
                               units,
                               forward_expansion,
                               device,
                               dropout,
                               max_length)

        self.device = device
        self.src_pad_idx = src_pad_idx
        self.tar_pad_idx = tar_pad_idx

    def create_src_mask(self, src):
        """Return a boolean mask that is False at source padding positions.

        Args:
            src: Integer tensor of shape (N, src_len).

        Returns:
            Boolean tensor of shape (N, 1, 1, src_len), on ``src``'s device.
        """
        # The comparison result already lives on ``src``'s device, which is
        # where the attention scores it masks will be computed.
        return (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

    def create_tar_mask(self, tar):
        """Return a lower-triangular (causal) mask for the target sequence.

        Args:
            tar: Tensor of shape (N, tar_len).

        Returns:
            Tensor of shape (N, 1, tar_len, tar_len), on ``tar``'s device,
            with ones on and below the diagonal and zeros above it, so
            position ``i`` may only attend to positions ``<= i``.
        """
        N, tar_len = tar.shape
        causal = torch.tril(torch.ones((tar_len, tar_len), device=tar.device))
        return causal.expand(N, 1, tar_len, tar_len)

    def forward(self, src, tar):
        """Encode ``src`` and decode ``tar`` against the encoder output.

        Args:
            src: Integer tensor of shape (N, src_len) of source token indices.
            tar: Integer tensor of shape (N, tar_len) of target token indices.

        Returns:
            The decoder output for ``tar``.
        """
        src_mask = self.create_src_mask(src)
        tar_mask = self.create_tar_mask(tar)

        enc_out = self.encoder(src, src_mask)
        return self.decoder(tar, enc_out, src_mask, tar_mask)
