In [None]:
import numpy as np
import torch
import torch.nn as nn
#from utils import split_last, merge_last

In [None]:
class LayerNorm(nn.Module):
    """Layer normalization over the last axis, TF style: the epsilon is added
    to the variance *inside* the square root.

    cfg must expose ``dim`` (size of the normalized last axis).
    variance_epsilon is the small constant that keeps the division finite.
    """
    def __init__(self, cfg, variance_epsilon=1e-12):
        super().__init__()
        feature_dim = cfg.dim
        # learnable per-feature scale (gamma) and shift (beta); gamma is
        # registered first so parameter order matches the original module
        self.gamma = nn.Parameter(torch.ones(feature_dim))
        self.beta = nn.Parameter(torch.zeros(feature_dim))
        self.variance_epsilon = variance_epsilon

    def forward(self, x):
        mean = x.mean(-1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(-1, keepdim=True)
        std = torch.sqrt(variance + self.variance_epsilon)
        normalized = centered / std
        return self.gamma * normalized + self.beta

In [None]:
class Embeddings(nn.Module):
    """Input embedding layer: adds token, position and segment (token-type)
    embeddings, then applies LayerNorm and dropout.

    cfg must expose ``vocab_size``, ``max_len``, ``n_segments``, ``dim`` and
    ``p_drop_hidden``.
    """
    def __init__(self, cfg):
        super().__init__()
        width = cfg.dim
        # registration order (tok, pos, seg, norm, drop) is kept as in the original
        self.tok_embed = nn.Embedding(cfg.vocab_size, width)   # token ids -> vectors
        self.pos_embed = nn.Embedding(cfg.max_len, width)      # position index -> vectors
        self.seg_embed = nn.Embedding(cfg.n_segments, width)   # segment / token-type id -> vectors
        self.norm = LayerNorm(cfg)
        self.drop = nn.Dropout(cfg.p_drop_hidden)

    def forward(self, x, seg):
        # positions 0..S-1 along dim 1, broadcast to x's shape (presumably (B, S))
        positions = torch.arange(x.size(1), dtype=torch.long, device=x.device)
        positions = positions[None, :].expand_as(x)

        summed = self.tok_embed(x)
        summed = summed + self.pos_embed(positions)
        summed = summed + self.seg_embed(seg)
        return self.drop(self.norm(summed))

In [None]:
class MultiHeadAttentionLayer( nn.Module ) :
    """Multi-head scaled dot-product self-attention.

    Reads from ``config``:
      * ``hidden_dim`` -- model width D; if absent, falls back to ``dim``
        (the attribute every other module in this notebook uses).
        D must be divisible by ``n_heads``.
      * ``n_heads``    -- number of attention heads H (head width W = D // H).
      * ``drop_attn``  -- dropout probability applied to the attention weights.

    Raises ValueError if D is not divisible by ``n_heads``.
    """
    def __init__( self, config ) :
        super().__init__()
        dim = config.hidden_dim if hasattr( config, "hidden_dim" ) else config.dim
        if dim % config.n_heads != 0 :
            raise ValueError( f"hidden dim {dim} is not divisible by n_heads {config.n_heads}" )
        # registration order kept as in the original (value, query, key)
        self.valueLinearProjection = nn.Linear( dim, dim )
        self.queryLinearProjection = nn.Linear( dim, dim )
        self.keyLinearProjection = nn.Linear( dim, dim )
        self.dropout = nn.Dropout( config.drop_attn )
        self.n_heads = config.n_heads

    def _split_heads( self, t ) :
        # (B, S, D) -> (B, S, H, W) -> (B, H, S, W)
        return t.view( *t.size()[:-1], self.n_heads, -1 ).transpose( 1, 2 )

    def forward( self, x, mask ) :
        """
        x    : (B, S, D) input sequence.
        mask : (B, S) with 1 for keys to attend to and 0 for keys to ignore, or None.
        Returns a (B, S, D) tensor.
        """
        q = self._split_heads( self.queryLinearProjection( x ) )
        k = self._split_heads( self.keyLinearProjection( x ) )
        v = self._split_heads( self.valueLinearProjection( x ) )
        # (B, H, S, W) @ (B, H, W, S) -> (B, H, S, S), scaled by sqrt(W)
        scores = q @ k.transpose( -2, -1 ) / ( k.size( -1 ) ** 0.5 )
        if mask is not None :
            # broadcast (B, S) over heads and query positions -> (B, 1, 1, S)
            mask = mask[:, None, None, :].to( scores.dtype )
            # large negative bias drives masked keys to ~0 probability after softmax
            scores = scores - 10000.0 * ( 1.0 - mask )
        probs = self.dropout( torch.softmax( scores, dim=-1 ) )
        # (B, H, S, S) @ (B, H, S, W) -> (B, H, S, W) -transpose-> (B, S, H, W)
        h = ( probs @ v ).transpose( 1, 2 ).contiguous()
        return h.view( *h.size()[:-2], -1 )  # (B, S, D)

In [None]:
class PositionWiseFeedForward( nn.Module ) :
    """Position-wise feed-forward sublayer:
    Linear(dim -> dimff) -> ReLU -> Linear(dimff -> dim), applied to the last axis.

    config must expose ``dim`` and ``dimff``.
    """
    def __init__( self, config ) :
        # nn.Module.__init__ must run before any submodule is assigned;
        # without it PyTorch raises AttributeError on the first assignment
        super().__init__()
        self.relu = nn.ReLU()
        self.forward1 = nn.Linear( config.dim, config.dimff )
        self.forward2 = nn.Linear( config.dimff, config.dim )

    def forward( self, x ) :
        """x : (..., dim) -> (..., dim); every position is transformed independently."""
        return self.forward2( self.relu( self.forward1( x ) ) )

In [None]:
class Block( nn.Module ) :
    """One post-LayerNorm Transformer encoder layer.

    Self-attention (followed by an output projection) and a position-wise
    feed-forward sublayer, each wrapped as ``norm( input + dropout( sublayer ) )``.

    config must expose ``dim`` and ``dropout`` plus whatever the attention,
    feed-forward and LayerNorm submodules read.
    """
    def __init__( self, config ) :
        super().__init__()
        self.attn = MultiHeadAttentionLayer( config )
        self.proj = nn.Linear( config.dim, config.dim )  # output projection of the attention sublayer
        self.norm1 = LayerNorm( config )
        self.pwff = PositionWiseFeedForward( config )
        self.norm2 = LayerNorm( config )
        self.drop = nn.Dropout( config.dropout )

    def forward( self, x, mask ) :
        """x : hidden states (presumably (B, S, dim)); mask is passed through to self.attn.
        Returns a tensor with the same shape as x."""
        a = self.attn( x, mask )
        h = self.norm1( x + self.drop( self.proj( a ) ) )
        h = self.norm2( h + self.drop( self.pwff( h ) ) )
        return h

In [None]:
class Transformer( nn.Module ) :
    """Encoder: Embeddings followed by ``config.num_layers`` stacked Blocks."""
    def __init__( self, config ) :
        # must run before assigning submodules, otherwise PyTorch raises AttributeError
        super().__init__()
        self.embedding = Embeddings( config )
        # nn.ModuleList registers each block, so its parameters show up in
        # .parameters(), move with .to(device) and are saved in state_dict;
        # a plain Python list hides them from all of that
        self.blocks = nn.ModuleList( [ Block( config ) for _ in range( config.num_layers ) ] )

    def forward( self, x, seg, mask ) :
        """x : token ids, seg : segment ids, mask : attention mask (passed to every Block).
        Returns the hidden states produced by the last Block."""
        h = self.embedding( x, seg )
        for block in self.blocks :
            h = block( h, mask )
        return h

In [None]:
# 1. Input: a token sequence x = [x1, x2, ..., xn]
# 2. MLM selects a random set of k positions (integers between 1 and n) to mask out: m = [m1, ..., mk]
# 3. The tokens at the selected positions are replaced with a [MASK] token: x_masked = replace( x, m, [MASK] )
# 4. The generator learns to predict the original identities of the masked-out tokens

In [None]:
from random import randint, shuffle
from random import random as rand

class PreprocessGenerator( Pipeline ):
    """Turn a sentence-pair instance into padded masked-LM training inputs.

    ``instance`` is ``(is_next, tokens_a, tokens_b)``, where tokens_a/tokens_b
    are lists of string tokens. Calling the pipeline returns
    ``(input_idx, segment_idx, input_mask, masked_idx, masked_pos, masked_weights, is_next)``:
      * input_idx / segment_idx / input_mask are zero-padded to ``max_len``
        (assuming ``indexer`` returns one index per token);
      * masked_idx / masked_pos / masked_weights are zero-padded to ``max_pred``;
        masked_weights is 1 for real targets and 0 for padding.
    """

    def __init__( self, vocab_words, indexer, max_pred=20, mask_prob=0.15, max_len=512 ) :
        super().__init__()
        self.max_pred = max_pred      # max tokens of prediction
        self.mask_prob = mask_prob    # mask coverage (following mlm)
        self.vocab_words = vocab_words
        self.indexer = indexer        # function from a list of tokens to a list of token indices
        self.max_len = max_len

    def __call__( self, instance ) :
        is_next, tokens_a, tokens_b = instance
        # work on copies so truncation does not mutate the caller's instance
        tokens_a, tokens_b = list( tokens_a ), list( tokens_b )

        # reserve room for the special tokens [CLS], [SEP], [SEP]
        self.truncate_tokens_pair( tokens_a, tokens_b, self.max_len - 3 )

        # Add special tokens
        tokens = ['[CLS]'] + tokens_a + ['[SEP]'] + tokens_b + ['[SEP]']
        segment_idx = [0] * ( len( tokens_a ) + 2 ) + [1] * ( len( tokens_b ) + 1 )
        input_mask = [1] * len( tokens )

        # For masked language model (MLM): pick up to n_pred non-special positions
        masked_tokens, masked_pos = [], []
        n_pred = min( self.max_pred, max( 1, int( round( len( tokens ) * self.mask_prob ) ) ) )
        candidate_pos = [ i for i, token in enumerate( tokens )
                          if token != '[CLS]' and token != '[SEP]' ]
        shuffle( candidate_pos )

        for pos in candidate_pos[ :n_pred ] :
            masked_tokens.append( tokens[ pos ] )
            masked_pos.append( pos )
            if rand() < 0.8 :      # 80%: replace with [MASK]
                tokens[ pos ] = '[MASK]'
            elif rand() < 0.5 :    # 10%: replace with a random vocabulary word
                tokens[ pos ] = self.get_random_word( self.vocab_words )
            # remaining 10%: keep the original token
        masked_weights = [1] * len( masked_tokens )

        # Token Indexing (copied into lists so padding can extend them)
        input_idx = list( self.indexer( tokens ) )
        masked_idx = list( self.indexer( masked_tokens ) )

        # Zero Padding up to max_len
        n_pad = self.max_len - len( input_idx )
        input_idx.extend( [0] * n_pad )
        segment_idx.extend( [0] * n_pad )
        input_mask.extend( [0] * n_pad )

        # Zero Padding for Masked Target up to max_pred. Pad by the number of
        # positions actually selected, which is below n_pred when there are
        # fewer candidate positions than n_pred.
        n_pad = self.max_pred - len( masked_pos )
        masked_idx.extend( [0] * n_pad )
        masked_pos.extend( [0] * n_pad )
        masked_weights.extend( [0] * n_pad )

        return ( input_idx, segment_idx, input_mask, masked_idx, masked_pos, masked_weights, is_next )

    @staticmethod
    def truncate_tokens_pair( tokens_a, tokens_b, max_len ):
        """Pop tokens (in place) from the end of the longer list until the
        combined length is at most max_len."""
        while len( tokens_a ) + len( tokens_b ) > max_len :
            if len( tokens_a ) > len( tokens_b ) :
                tokens_a.pop()
            else :
                tokens_b.pop()

    @staticmethod
    def get_random_word( vocab_words ) :
        """Return a uniformly random element of vocab_words."""
        i = randint( 0, len( vocab_words ) - 1 )
        return vocab_words[ i ]
        

In [None]:
class Generator( nn.Module ) :
    """Transformer encoder with two output heads:

    * a masked-LM head that scores the vocabulary at ``masked_pos``; its output
      weight is tied to the token-embedding matrix, plus a separate bias;
    * a 2-way classifier on the pooled hidden state at position 0 (the [CLS]
      token in the preprocessing pipeline).
    """

    def __init__( self, config ) :
        super().__init__()
        self.transformer    = Transformer( config )
        self.fc             = nn.Linear( config.dim, config.dim )
        self.activ1         = nn.Tanh()
        self.linear         = nn.Linear( config.dim, config.dim )
        self.activ2         = nn.ReLU()
        self.norm           = LayerNorm( config )
        self.classifier     = nn.Linear( config.dim, 2 )

        # Decoder: weight-tied to the input token embedding (n_vocab, n_dim)
        embed_weight        = self.transformer.embedding.tok_embed.weight
        n_vocab, n_dim      = embed_weight.size()
        self.decoder        = nn.Linear( n_dim, n_vocab, bias=False )
        self.decoder.weight = embed_weight
        self.decoder_bias   = nn.Parameter( torch.zeros( n_vocab ) )

    def forward( self, input_idx, segment_idx, input_mask, masked_pos ) :
        """
        input_idx, segment_idx, input_mask : token ids, segment ids and attention
            mask (presumably (B, S)), passed straight to the Transformer.
        masked_pos : integer (long) tensor of positions to predict, indexed as (B, P).
        Returns (logits_lm, logits_clsf): vocabulary logits at each masked position
        (last axis n_vocab) and 2-way classification logits.
        """
        h = self.transformer( input_idx, segment_idx, input_mask )
        pooled_h = self.activ1( self.fc( h[:, 0] ) )
        # (B, P) -> (B, P, D) so gather picks whole hidden vectors along dim 1
        gather_idx = masked_pos[:, :, None].expand( -1, -1, h.size( -1 ) )
        h_masked = torch.gather( h, 1, gather_idx )
        h_masked = self.norm( self.activ2( self.linear( h_masked ) ) )
        logits_lm = self.decoder( h_masked ) + self.decoder_bias
        logits_clsf = self.classifier( pooled_h )

        return logits_lm, logits_clsf