In [1]:
import torch
import torch.nn as nn
import copy
import torch.nn.functional as F
import math
import numpy as np

# Record the installed PyTorch version so results can be tied to a library release.
print(str(torch.__version__))

2.3.0


In [2]:
class Transformer(nn.Module):
    """Encoder-decoder wrapper that builds attention masks and routes data between the halves.

    The encoder and decoder are injected. This class only derives the masks from the raw
    ``src`` / ``tgt`` tensors (compared against a padding index) and calls
    ``encoder(src, src_mask)`` followed by
    ``decoder(tgt, encoder_out, tgt_mask, src_tgt_mask)``.
    """

    def __init__(self, encoder, decoder):
        super(Transformer, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def encode(self, src, src_mask):
        """Run the encoder on ``src`` using ``src_mask``."""
        return self.encoder(src, src_mask)

    def decode(self, tgt, encoder_out, tgt_mask, src_tgt_mask):
        """Run the decoder on ``tgt`` attending to ``encoder_out``."""
        return self.decoder(tgt, encoder_out, tgt_mask, src_tgt_mask)

    def forward(self, src, tgt):
        """Build all three masks, encode ``src`` and decode ``tgt``.

        Args:
            src: (n_batch, src_seq_len) tensor compared against the padding index.
            tgt: (n_batch, tgt_seq_len) tensor compared against the padding index.
        Returns:
            Whatever the injected decoder returns.
        """
        src_mask = self.make_src_mask(src)
        tgt_mask = self.make_tgt_mask(tgt)
        # The decoder's cross-attention needs its own (tgt query x src key) mask.
        src_tgt_mask = self.make_src_tgt_mask(src, tgt)
        encoder_out = self.encode(src, src_mask)
        return self.decode(tgt, encoder_out, tgt_mask, src_tgt_mask)

    def make_subsequent_mask(self, query, key):
        """Build a lower-triangular (causal) mask.

        Args:
            query (torch.Tensor): (n_batch, query_seq_len)
            key (torch.Tensor): (n_batch, key_seq_len)
        Returns:
            torch.BoolTensor of shape (query_seq_len, key_seq_len), where entry [i, j]
            is True iff j <= i. The tensor is placed on ``query.device``.
        """
        query_seq_len, key_seq_len = query.size(1), key.size(1)
        ones = torch.ones((query_seq_len, key_seq_len), dtype=torch.bool, device=query.device)
        return torch.tril(ones)

    def make_pad_mask(self, query, key, pad_idx=1):
        """Build a padding mask that is False wherever the query OR the key token equals ``pad_idx``.

        Args:
            query: (n_batch, query_seq_len)
            key:   (n_batch, key_seq_len)
            pad_idx: value treated as padding.
        Returns:
            torch.BoolTensor of shape (n_batch, 1, query_seq_len, key_seq_len).
        """
        key_mask = key.ne(pad_idx).unsqueeze(1).unsqueeze(2)      # (n_batch, 1, 1, key_seq_len)
        query_mask = query.ne(pad_idx).unsqueeze(1).unsqueeze(3)  # (n_batch, 1, query_seq_len, 1)
        # Broadcasting expands both operands to (n_batch, 1, query_seq_len, key_seq_len).
        return key_mask & query_mask

    def make_src_mask(self, src):
        """Padding mask for encoder self-attention over ``src``."""
        return self.make_pad_mask(src, src)

    def make_tgt_mask(self, tgt):
        """Padding mask combined with the causal mask for decoder self-attention."""
        pad_mask = self.make_pad_mask(tgt, tgt)
        seq_mask = self.make_subsequent_mask(tgt, tgt)
        return pad_mask & seq_mask

    def make_src_tgt_mask(self, src, tgt):
        """Padding mask for cross-attention (queries from ``tgt``, keys from ``src``)."""
        return self.make_pad_mask(tgt, src)



In [3]:
class MultiHeadAttentionLayer(nn.Module):
    """Multi-head scaled dot-product attention.

    Args:
        d_model: total width of the projected Q/K/V, split evenly across ``h`` heads.
        h: number of heads; ``d_model`` must be a positive multiple of ``h``.
        qkv_fc: projection module deep-copied three times (independent Q, K and V projections).
        out_fc: output projection applied after the heads are concatenated.
    Raises:
        ValueError: if ``h`` is not positive or ``d_model`` is not divisible by ``h``.
    """

    def __init__(self, d_model, h, qkv_fc, out_fc):
        super(MultiHeadAttentionLayer, self).__init__()
        # Validate early: otherwise the per-head view() in forward fails or silently mixes heads.
        if h <= 0 or d_model % h != 0:
            raise ValueError(f"d_model ({d_model}) must be a positive multiple of h ({h})")
        self.d_model = d_model
        self.h = h
        self.q_fc = copy.deepcopy(qkv_fc) # (d_embed, d_model)
        self.k_fc = copy.deepcopy(qkv_fc) # (d_embed, d_model)
        self.v_fc = copy.deepcopy(qkv_fc) # (d_embed, d_model)
        self.out_fc = out_fc              # (d_model, d_embed)

    def calculate_attention(self, query, key, value, mask):
        """softmax(Q K^T / sqrt(d_k)) V with optional masking.

        Args:
            query: (n_batch, h, query_seq_len, d_k)
            key, value: (n_batch, h, key_seq_len, d_k)
            mask: None, or a tensor broadcastable to (n_batch, h, query_seq_len, key_seq_len);
                positions where ``mask == 0`` receive (effectively) zero attention weight.
        Returns:
            (n_batch, h, query_seq_len, d_k)
        """
        d_k = key.shape[-1]
        attention_score = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)

        if mask is not None:
            # Use the most negative finite value of the score dtype: it zeroes masked weights
            # after softmax and, unlike a literal -1e9, does not overflow float16.
            # A finite value (not -inf) keeps fully-masked rows free of NaNs.
            fill_value = torch.finfo(attention_score.dtype).min
            attention_score = attention_score.masked_fill(mask == 0, fill_value)

        attention_prob = F.softmax(attention_score, dim=-1)  # (n_batch, h, query_seq_len, key_seq_len)
        return torch.matmul(attention_prob, value)

    def forward(self, *args, query, key, value, mask=None):
        """Project, split into heads, attend, merge heads and project back.

        ``query``/``key``/``value``/``mask`` are keyword-only. Any positional arguments are
        accepted and ignored.

        Args:
            query : (n_batch, query_seq_len, d_embed)
            key : (n_batch, key_seq_len, d_embed)
            value : (n_batch, key_seq_len, d_embed)
            mask : None or broadcastable to (n_batch, h, query_seq_len, key_seq_len)
        Return:
            (n_batch, query_seq_len, d_out), where d_out is the output width of ``out_fc``
        """
        n_batch = query.size(0)
        d_k = self.d_model // self.h

        def transform(x, fc):                          # (n_batch, seq_len, d_embed)
            out = fc(x)                                # (n_batch, seq_len, d_model)
            out = out.view(n_batch, -1, self.h, d_k)   # (n_batch, seq_len, h, d_k)
            return out.transpose(1, 2)                 # (n_batch, h, seq_len, d_k)

        query = transform(query, self.q_fc)
        key = transform(key, self.k_fc)
        value = transform(value, self.v_fc)

        out = self.calculate_attention(query, key, value, mask)  # (n_batch, h, query_seq_len, d_k)
        out = out.transpose(1, 2)                                # (n_batch, query_seq_len, h, d_k)
        out = out.contiguous().view(n_batch, -1, self.d_model)   # (n_batch, query_seq_len, d_model)
        return self.out_fc(out)

In [4]:
class PositionWiseFeedForwardLayer(nn.Module):
    """Feed-forward network applied at every position: ``fc2(relu(fc1(x)))``.

    Args:
        fc1: first projection module.
        fc2: second projection module.
    """

    def __init__(self, fc1, fc2):
        super().__init__()
        self.fc1 = fc1
        self.relu = nn.ReLU()
        self.fc2 = fc2

    def forward(self, x):
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)
    
class ResidualConnectionLayer(nn.Module):
    """Skip connection: returns ``sub_layer(x) + x``.

    This layer applies no normalization or dropout.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x, sub_layer):
        return sub_layer(x) + x
        

In [5]:

class Encoder(nn.Module):
    """Stack of ``n_layer`` independent deep copies of ``encoder_layer``, applied in sequence.

    Args:
        encoder_layer: block to replicate; each copy is called as ``layer(out, src_mask)``.
        n_layer: number of encoder blocks.
    """

    def __init__(self, encoder_layer, n_layer):
        super(Encoder, self).__init__()
        self.n_layer = n_layer
        # nn.ModuleList (not a plain list) registers the copies as submodules, so their
        # parameters appear in .parameters()/state_dict() and follow .to() and .train()/.eval().
        self.layers = nn.ModuleList([copy.deepcopy(encoder_layer) for _ in range(n_layer)])

    def forward(self, src, src_mask):
        out = src
        for layer in self.layers:
            out = layer(out, src_mask)
        return out
    
class EncoderBlock(nn.Module):
    """One encoder block with two residual-wrapped sub-layers.

    The first is self-attention, where query = key = value = the block input. The second
    is the position-wise feed-forward module.

    Args:
        self_attention: module called with keyword args ``query``, ``key``, ``value``, ``mask``.
        position_ff: module called with a single tensor.
    """

    def __init__(self, self_attention, position_ff):
        super(EncoderBlock, self).__init__()
        self.self_attention = self_attention
        self.position_ff = position_ff
        # Registered as submodules so they follow .to()/.train()/.eval() and any future
        # parameters (e.g. normalization) are tracked.
        self.residuals = nn.ModuleList([ResidualConnectionLayer() for _ in range(2)])

    def forward(self, src, src_mask):
        out = src
        out = self.residuals[0](out, lambda out: self.self_attention(query=out, key=out, value=out, mask=src_mask))
        out = self.residuals[1](out, self.position_ff)
        return out

In [None]:
class Decoder(nn.Module):
    """Sequential stack of ``n_layer`` independent deep copies of ``decoder_block``.

    Each copy is called as ``block(hidden, encoder_out, tgt_mask, src_tgt_mask)``, and its
    output becomes the next block's ``hidden``.
    """

    def __init__(self, decoder_block, n_layer):
        super().__init__()
        self.n_layer = n_layer
        self.layers = nn.ModuleList(copy.deepcopy(decoder_block) for _ in range(n_layer))

    def forward(self, tgt, encoder_out, tgt_mask, src_tgt_mask):
        hidden = tgt
        for block in self.layers:
            hidden = block(hidden, encoder_out, tgt_mask, src_tgt_mask)
        return hidden
    
class DecoderBlock(nn.Module):
    """One decoder block with three residual-wrapped sub-layers.

    The sub-layers are masked self-attention, cross-attention over the encoder output, and
    the position-wise feed-forward module.

    Args:
        self_attention: called with query = key = value = decoder state and ``tgt_mask``.
        cross_attention: called with query = decoder state, key = value = ``encoder_out``
            and ``src_tgt_mask``.
        position_ff: module called with a single tensor.
    """

    def __init__(self, self_attention, cross_attention, position_ff):
        super(DecoderBlock, self).__init__()
        self.self_attention = self_attention
        self.cross_attention = cross_attention
        self.position_ff = position_ff
        # Registered as submodules so they follow .to()/.train()/.eval() and any future
        # parameters (e.g. normalization) are tracked.
        self.residuals = nn.ModuleList([ResidualConnectionLayer() for _ in range(3)])

    def forward(self, tgt, encoder_out, tgt_mask, src_tgt_mask):
        out = tgt
        out = self.residuals[0](out, lambda out: self.self_attention(query=out, key=out, value=out, mask=tgt_mask))
        # Cross-attention uses the dedicated cross_attention module (queries from the decoder,
        # keys/values from the encoder output).
        out = self.residuals[1](out, lambda out: self.cross_attention(query=out, key=encoder_out, value=encoder_out, mask=src_tgt_mask))
        out = self.residuals[2](out, self.position_ff)
        return out
    
