In [50]:
import numpy as np
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

def get_sinusoid_encoding_table(n_seq, emb_d):
    """
    Build the sinusoidal positional-encoding table.

    Entry (pos, i) holds sin(angle) for even i and cos(angle) for odd i, where
    angle = pos / 10000 ** (2 * (i // 2) / emb_d).

    n_seq: seq_len (number of rows / positions)
    emb_d: dim of sinusoid table (number of columns)
        - equal to the dim of word embedded weight mat
    return: float64 np.ndarray of shape (n_seq, emb_d)
        - n_seq == 0 gives an empty (0, emb_d) table instead of raising
    """
    # Column vector of positions so it broadcasts against the per-feature divisors.
    positions = np.arange(n_seq, dtype=np.float64)[:, np.newaxis]
    # Features 2k and 2k+1 share one frequency, hence the floor division by 2.
    exponents = 2 * (np.arange(emb_d) // 2) / emb_d
    sinusoid_table = positions / np.power(10000, exponents)

    # Even feature indices carry the sine, odd ones the cosine of the same angle.
    sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2])
    sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2])

    return sinusoid_table

def get_attn_pad_mask(seq_q, seq_k, i_pad): # i_pad=0
    """
    Build the attention mask that hides padded key positions.

    seq_q: query sequence of token ids, 2-D (not embedded)
        - only its length is used
    seq_k: key sequence of token ids, 2-D
    i_pad: padding vocab_idx
        - eg. 0
    return: bool tensor of shape (batch, len_q, len_k); True where the key
        token equals i_pad
    """
    _, n_query = seq_q.size()
    n_batch, n_key = seq_k.size()
    # Flag padded keys once per batch row ...
    key_is_pad = seq_k.data.eq(i_pad)
    # ... then repeat that row (as a view) for every query position.
    return key_is_pad.unsqueeze(1).expand(n_batch, n_query, n_key)

def get_attn_decoder_mask(seq):
    """
    Build the look-ahead mask for decoder self-attention
    (mask upper triangular part).

    seq: decoder sequence of token ids, 2-D (batch, seq_len)
        - not word embedding
    return: bool tensor of shape (batch, seq_len, seq_len); entry [b, i, j] is
        True when j > i, i.e. position i must not attend to later position j
    """
    batch_size, seq_len = seq.size()
    triangular_mask = torch.ones_like(seq).unsqueeze(-1).expand(batch_size, seq_len, seq_len)
    # diagonal=1 keeps the strict upper triangle only, so every position can
    # still attend to itself; bool matches what masked_fill_ expects.
    return triangular_mask.triu(diagonal=1).bool()

class ScaledDotProductAttention(nn.Module):
    """Attention computed as dropout(softmax(Q @ K^T * scale, masked)) @ V."""

    def __init__(self, config):
        """
        config: use types.SimpleNamespace
            - config.dropout: probability handed to nn.Dropout
            - config.k_dim: used for the 1 / sqrt(k_dim) score scaling
        """
        super().__init__()
        self.config = config
        self.dropout = nn.Dropout(config.dropout)
        # Scaling factor applied to the raw dot products.
        self.scale = 1 / (self.config.k_dim**0.5)

    def forward(self, Q, K, V, attn_mask):
        """
        Q, K, V: query / key / value tensors; K is transposed over its last
            two dims before the product with Q
        attn_mask: encoder part and decoder part has different one
            - True entries are filled with -1e+9 before the softmax
        return: (context, attn_prob)
            - attn_prob is the post-dropout attention distribution
        """
        similarity = torch.matmul(Q, K.transpose(-1, -2))
        # Scale and mask in place on the freshly created score tensor.
        similarity.mul_(self.scale)
        similarity.masked_fill_(attn_mask, -1e+9)

        weights = torch.softmax(similarity, dim=-1)
        attn_prob = self.dropout(weights)
        return torch.matmul(attn_prob, V), attn_prob
    

In [None]:
class GPT(nn.Module):
    """Wrapper around the decoder that adds checkpoint save/load helpers."""

    def __init__(self, config):
        """
        config: model settings, passed unchanged to Decoder
        """
        super().__init__()
        self.config = config
        self.decoder = Decoder(self.config)

    def forward(self, dec_inputs):
        """
        dec_inputs: decoder input, handed to the decoder as-is
        return: (dec_outputs, dec_self_attn_probs) as produced by the decoder
        """
        dec_outputs, dec_self_attn_probs = self.decoder(dec_inputs)
        return dec_outputs, dec_self_attn_probs

    def save(self, epoch, loss, path):
        """
        Write a checkpoint containing epoch, loss and the model state_dict.

        epoch: stored under the 'epoch' key
        loss: stored under the 'loss' key
        path: destination passed to torch.save
        """
        torch.save({
            'epoch': epoch,
            'loss': loss,
            'state_dict': self.state_dict()
        }, path)

    def load(self, path, map_location=None):
        """
        Restore the model weights from a checkpoint written by save().

        path: checkpoint location passed to torch.load
        map_location: forwarded to torch.load; the default None keeps
            torch.load's own behaviour
            - eg. 'cpu' to open a checkpoint saved on GPU on a CPU-only host
        return: (epoch, loss) stored in the checkpoint
        """
        # NOTE: torch.load is pickle-based; only open checkpoints from trusted sources.
        checkpoint = torch.load(path, map_location=map_location)
        self.load_state_dict(checkpoint['state_dict'])
        return checkpoint['epoch'], checkpoint['loss']
    
class GPTpretrain(nn.Module):
    """GPT plus a bias-free linear head that shares its weight with the decoder embedding."""

    def __init__(self, config):
        """
        config: passed to GPT; config.d_hidn and config.n_dec_vocab size the head
        """
        super().__init__()
        self.config = config
        self.gpt = GPT(self.config)

        lm_head = nn.Linear(self.config.d_hidn, self.config.n_dec_vocab, bias=False)
        # Weight tying: the head reuses the decoder's dec_emb parameter
        # (presumably the token embedding matrix -- verify against Decoder).
        lm_head.weight = self.gpt.decoder.dec_emb.weight
        self.projection_lm = lm_head

    def forward(self, dec_inputs):
        """
        dec_inputs: input forwarded to the wrapped GPT
        return: (logits_lm, dec_self_attn_probs)
            - logits_lm excludes the last index along dim 1 and is contiguous
        """
        hidden, dec_self_attn_probs = self.gpt(dec_inputs)
        logits_lm = self.projection_lm(hidden)
        # Drop the final position's logits (presumably it has no next-token
        # target -- confirm against the training loop).
        trimmed = logits_lm[:, :-1, :].contiguous()
        return trimmed, dec_self_attn_probs
    