# Study: Transformer

## Reference
* https://cpm0722.github.io/pytorch-implementation/transformer

In [45]:
import torch
from torch import nn
import copy
from torch.nn import functional as F

## Transformer

In [46]:
# Sentence: src
# Shifted sentence: tgt (target)

class Transformer(nn.Module):

    def __init__(self, src_embed, tgt_embed, encoder, decoder, generator):
        super().__init__()
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.encoder = encoder
        self.decoder = decoder
        self.generator = generator


    def encode(self, src, src_mask):
        return self.encoder(self.src_embed(src), src_mask)


    def decode(self, tgt, encoder_out, tgt_mask, src_tgt_mask):
        return self.decoder(self.tgt_embed(tgt), encoder_out, tgt_mask, src_tgt_mask)


    def forward(self, src, tgt):
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)
        src_tgt_mask = self.make_src_tgt_mask(src, tgt)
        encoder_out = self.encode(src, src_mask)
        decoder_out = self.decode(tgt, encoder_out, tgt_mask, src_tgt_mask)
        out = self.generator(decoder_out)
        out = F.log_softmax(out, dim=-1)
        return out, decoder_out
    
    @staticmethod
    def make_pad_mask(query, key, pad_value=1):
        # query: (n_batch, query_seq_len)
        # key: (n_batch, key_seq_len)
        # pad_value is special value of padding token
        #   if we want to mask padding token, pad_value should be same as padding token

        query_seq_len, key_seq_len = query.size(1), key.size(1)

        key_mask = key.ne(pad_value).unsqueeze(1).unsqueeze(2)  # (n_batch, 1, 1, key_seq_len)
        key_mask = key_mask.repeat(1, 1, query_seq_len, 1)    # (n_batch, 1, query_seq_len, key_seq_len)

        query_mask = query.ne(pad_value).unsqueeze(1).unsqueeze(3)  # (n_batch, 1, query_seq_len, 1)
        query_mask = query_mask.repeat(1, 1, 1, key_seq_len)  # (n_batch, 1, query_seq_len, key_seq_len)

        mask = key_mask & query_mask
        mask.requires_grad = False
        return mask
    
    @staticmethod
    def make_src_mask(src, pad_value=1):
        src_mask = Transformer.make_pad_mask(src, src, pad_value)
        return src_mask
    
    @staticmethod
    def make_subsequent_mask(tgt):
        # tgt: (n_batch, tgt_seq_len)
        tgt_seq_len = tgt.size(1)
        tgt_mask = torch.tril(torch.ones((tgt_seq_len, tgt_seq_len), device=tgt.device)).bool()
        tgt_mask.requires_grad = False
        return tgt_mask
    
    @staticmethod
    def make_tgt_mask(tgt, pad_value=1):
        pad_mask = Transformer.make_pad_mask(tgt, tgt, pad_value)
        seq_mask = Transformer.make_subsequent_mask(tgt)
        return pad_mask & seq_mask
    
    @staticmethod
    def make_src_tgt_mask(src, tgt, pad_value=1):
        pad_mask = Transformer.make_pad_mask(tgt, src, pad_value)
        return pad_mask

In [47]:
class Encoder(nn.Module):

    def __init__(self, encoder_block, n_layer, norm):  # n_layer: Encoder Block의 개수
        super().__init__()
        self.layers = [copy.deepcopy(encoder_block) for _ in range(n_layer)]
        self.norm = norm

    def forward(self, src, mask):
        x = src
        for layer in self.layers:
            x = layer(x, mask)
            
        x = self.norm(x)
        return x

# class EncoderBlock(nn.Module):

#     def __init__(self, self_attention, position_ff):
#         super().__init__()
#         self.self_attention = self_attention 
#         self.position_ff = position_ff


#     def forward(self, src, src_mask):
#         query = src
#         key = src
#         value = src
#         out = self.self_attention(query, key, value, src_mask)
#         context = self.position_ff(out)
#         return context

In [48]:
class MultiHeadAttentionLayer(nn.Module):

    def __init__(self, d_model, h, qkv_fc, out_fc, dropout):
        super().__init__()
        self.d_model = d_model
        self.h = h                        # number of heads for multi-head attention
        self.q_fc = copy.deepcopy(qkv_fc) # (d_embed, d_model)
        self.k_fc = copy.deepcopy(qkv_fc) # (d_embed, d_model)
        self.v_fc = copy.deepcopy(qkv_fc) # (d_embed, d_model)
        self.out_fc = out_fc              # (d_model, d_embed)
        self.dropout = dropout

    def forward(self, *args, query, key, value, mask=None):
        # query, key, value: (n_batch, seq_len, d_embed)
        # mask: (n_batch, seq_len, seq_len)
        # return value: (n_batch, h, seq_len, d_k)
        n_batch = query.size(0)

        def transform(x, fc):                                         # (n_batch, seq_len, d_embed) input
            out = fc(x)                                               # (n_batch, seq_len, d_model) FC
            out = out.view(n_batch, -1, self.h, self.d_model//self.h) # (n_batch, seq_len, h, d_k) reshape
            out = out.transpose(1, 2)                                 # (n_batch, h, seq_len, d_k) transpose
            return out

        query = transform(query, self.q_fc) # (n_batch, h, seq_len, d_k)
        key = transform(key, self.k_fc)     # (n_batch, h, seq_len, d_k)
        value = transform(value, self.v_fc) # (n_batch, h, seq_len, d_k)

        out = self.calculate_attention(query, key, value, mask) # (n_batch, h, seq_len, d_k) attention
        out = out.transpose(1, 2)                               # (n_batch, seq_len, h, d_k) transpose
        out = out.contiguous().view(n_batch, -1, self.d_model)  # (n_batch, seq_len, d_model) reshape
        out = self.out_fc(out)                                  # (n_batch, seq_len, d_embed) FC
        return out
    
    def calculate_attention(self, query, key, value, mask):
        # query, key, value: (n_batch, seq_len, d_k)                    -> multi-head attention (n_batch, h, seq_len, d_k)
        # mask: (n_batch, seq_len, seq_len)                             -> multi-head attention (n_batch, 1, seq_len, seq_len)
        
        # step1. Q x K^T = Attention Score
        # (n_batch, seq_len, seq_len)                                   -> multi-head attention (n_batch, h, seq_len, seq_len)
        attention_score = torch.matmul(query, key.transpose(-2, -1)) 

        # step2. Normalization with sqrt(d_k) which is the dimension of key
        d_k = key.shape[-1]
        attention_score = attention_score / torch.sqrt(d_k) 

        # step3. Masking
        if mask is not None:
            attention_score = attention_score.masked_fill(mask==0, -1e9)

        # step4. Softmax
        # (n_batch, seq_len, seq_len)                                   -> multi-head attention (n_batch, h, seq_len, seq_len)
        attention_prob = F.softmax(attention_score, dim=-1) 
        attention_prob = self.dropout(attention_prob)

        # step5. Attention Score x V = Attention Value
        # (n_batch, seq_len, d_k)                                       -> multi-head attention (n_batch, h, seq_len, d_k)
        out = torch.matmul(attention_prob, value) 
        return out

In [49]:
class PositionWiseFeedForwardLayer(nn.Module):

    def __init__(self, fc1, fc2, dropout):
        super().__init__()
        self.fc1 = fc1   # (d_embed, d_ff)
        self.relu = nn.ReLU()
        self.fc2 = fc2 # (d_ff, d_embed)

    def forward(self, x):
        out = x
        out = self.fc1(out)
        out = self.relu(out)
        out = self.dropout(out)
        out = self.fc2(out)
        return out

In [50]:
class ResidualConnectionLayer(nn.Module):

    def __init__(self, norm, dropout):
        super().__init__()
        self.norm = norm
        self.dropout = dropout


    def forward(self, x, sub_layer):
        out = x
        
        out = self.norm(out)
        out = sub_layer(out)
        out = self.dropout(out)

        out = out + x
        return out

In [51]:
class EncoderBlock(nn.Module):

    def __init__(self, self_attention, position_ff, norm, dropout):
        super().__init__()
        self.self_attention = self_attention 
        self.position_ff = position_ff
        self.residual_connection = [ResidualConnectionLayer(copy.deepcopy(norm), copy.deepcopy(dropout)) for _ in range(2)]


    def forward(self, src, src_mask):
        out = src
        out = self.residuals[0](out, lambda out: self.self_attention(query=out, key=out, value=out, mask=src_mask))
        context = self.residuals[1](out, self.position_ff)
        return context

In [52]:
class Decoder(nn.Module):

    def __init__(self, decoder_block, n_layer, norm):
        super().__init__()
        self.layers = nn.ModuleList([copy.deepcopy(decoder_block) for _ in range(n_layer)])
        self.norm = norm

    def forward(self, tgt, encoder_out, tgt_mask, src_tgt_mask):
        out = tgt
        for layer in self.layers:
            out = layer(out, encoder_out, tgt_mask, src_tgt_mask)
        
        out = self.norm(out)
        return out

class DecoderBlock(nn.Module):

    def __init__(self, self_attention, cross_attention, position_ff, norm, dropout):
        super().__init__()
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.position_ff = position_ff
        self.residuals = [ResidualConnectionLayer(copy.deepcopy(norm), copy.deepcopy(dropout)) for _ in range(3)]


    def forward(self, tgt, encoder_out, tgt_mask, src_tgt_mask):
        out = tgt
        out = self.residuals[0](out, lambda out: self.self_attention(query=out, key=out, value=out, mask=tgt_mask))
        out = self.residuals[1](out, lambda out: self.cross_attention(query=out, key=encoder_out, value=encoder_out, mask=src_tgt_mask))
        out = self.residuals[2](out, self.position_ff)
        return out

In [53]:
class TransformerEmbedding(nn.Module):

    def __init__(self, token_embed, pos_embed, dropout):
        super().__init__()
        self.embedding = nn.Sequential(token_embed, pos_embed)
        self.dropout = dropout


    def forward(self, x):
        out = self.embedding(x)
        out = self.dropout(out)
        return out
    
class TokenEmbedding(nn.Module):

    def __init__(self, d_embed, vocab_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, d_embed)
        self.d_embed = d_embed


    def forward(self, x):
        out = self.embedding(x) * torch.sqrt(self.d_embed)
        return out
    
class PositionalEncoding(nn.Module):

    def __init__(self, d_embed, max_len=256, device=torch.device("cpu")):
        super(PositionalEncoding, self).__init__()
        encoding = torch.zeros(max_len, d_embed)
        encoding.requires_grad = False
        position = torch.arange(0, max_len).float().unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_embed, 2) * -(torch.tensor(10000.0, device=device) / d_embed))
        encoding[:, 0::2] = torch.sin(position * div_term)
        encoding[:, 1::2] = torch.cos(position * div_term)
        self.encoding = encoding.unsqueeze(0).to(device)


    def forward(self, x):
        _, seq_len, _ = x.size()
        pos_embed = self.encoding[:, :seq_len, :]
        out = x + pos_embed
        return out

In [56]:
def build_model(
    src_vocab_size,
    tgt_vocab_size,
    device=torch.device("cpu"),
    max_len=256,
    d_embed=512,
    n_layer=6,
    d_model=512,
    h=8,
    d_ff=2048,
    dropout_rate=0.1,
    norm_eps=1e-6,
):
    import copy

    copy = copy.deepcopy

    src_token_embed = TokenEmbedding(d_embed=d_embed, vocab_size=src_vocab_size)
    tgt_token_embed = TokenEmbedding(d_embed=d_embed, vocab_size=tgt_vocab_size)
    pos_embed = PositionalEncoding(d_embed=d_embed, max_len=max_len, device=device)

    src_embed = TransformerEmbedding(
        token_embed=src_token_embed,
        pos_embed=copy(pos_embed),
        dropout=nn.Dropout(dropout_rate),
    )
    tgt_embed = TransformerEmbedding(
        token_embed=tgt_token_embed,
        pos_embed=copy(pos_embed),
        dropout=nn.Dropout(dropout_rate),
    )

    attention = MultiHeadAttentionLayer(
        d_model=d_model,
        h=h,
        qkv_fc=nn.Linear(d_embed, d_model),
        out_fc=nn.Linear(d_model, d_embed),
        dropout=nn.Dropout(dropout_rate),
    )
    position_ff = PositionWiseFeedForwardLayer(
        fc1=nn.Linear(d_embed, d_ff),
        fc2=nn.Linear(d_ff, d_embed),
        dropout=nn.Dropout(dropout_rate),
    )
    norm = nn.LayerNorm(d_embed, eps=norm_eps)

    encoder_block = EncoderBlock(
        self_attention=copy(attention),
        position_ff=copy(position_ff),
        norm=copy(norm),
        dropout=nn.Dropout(dropout_rate),
    )
    decoder_block = DecoderBlock(
        self_attention=copy(attention),
        cross_attention=copy(attention),
        position_ff=copy(position_ff),
        norm=copy(norm),
        dropout=nn.Dropout(dropout_rate),
    )

    encoder = Encoder(
        encoder_block=encoder_block, n_layer=n_layer, norm=copy(norm)
    )
    decoder = Decoder(
        decoder_block=decoder_block, n_layer=n_layer, norm=copy(norm)
    )
    generator = nn.Linear(d_model, tgt_vocab_size)

    model = Transformer(
        src_embed=src_embed,
        tgt_embed=tgt_embed,
        encoder=encoder,
        decoder=decoder,
        generator=generator,
    ).to(device)
    model.device = device

    return model

In [57]:
model = build_model(1024, 1024)
print(model)

Transformer(
  (src_embed): TransformerEmbedding(
    (embedding): Sequential(
      (0): TokenEmbedding(
        (embedding): Embedding(1024, 512)
      )
      (1): PositionalEncoding()
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (tgt_embed): TransformerEmbedding(
    (embedding): Sequential(
      (0): TokenEmbedding(
        (embedding): Embedding(1024, 512)
      )
      (1): PositionalEncoding()
    )
    (dropout): Dropout(p=0.1, inplace=False)
  )
  (encoder): Encoder(
    (norm): LayerNorm((512,), eps=1e-06, elementwise_affine=True)
  )
  (decoder): Decoder(
    (layers): ModuleList(
      (0-5): 6 x DecoderBlock(
        (self_attention): MultiHeadAttentionLayer(
          (q_fc): Linear(in_features=512, out_features=512, bias=True)
          (k_fc): Linear(in_features=512, out_features=512, bias=True)
          (v_fc): Linear(in_features=512, out_features=512, bias=True)
          (out_fc): Linear(in_features=512, out_features=512, bias=True)
          (dropout)