In [1]:
import math
import torch
from torch import nn
import torch.nn.functional as F

### Positional Encoding

In [7]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal positional encodings to the input, then apply dropout.

    The encoding table ``P`` has shape ``(max_len, num_hiddens)`` with

        P[pos, 2i]     = sin(pos / 10000 ** (2i / num_hiddens))
        P[pos, 2i + 1] = cos(pos / 10000 ** (2i / num_hiddens))

    Args:
        num_hiddens: size of the last dimension of the inputs to ``forward``.
        dropout: dropout probability applied after the encodings are added.
        max_len: number of positions precomputed in ``P``; ``forward`` rejects
            inputs whose second-to-last dimension is larger than this.
    """

    def __init__(self, num_hiddens, dropout, max_len=1000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(dropout)
        # Registered as a buffer so the table follows the module through
        # .to()/.cuda()/.double(); non-persistent so state_dict keys are
        # unchanged (the table is fully determined by the constructor args).
        self.register_buffer(
            "P", self.get_positional_matrix(max_len, num_hiddens), persistent=False)

    def forward(self, X):
        """Return ``dropout(X + P[:X.shape[-2]])``.

        ``X`` must have ``num_hiddens`` as its last dimension and at most
        ``max_len`` entries along its second-to-last dimension.
        """
        assert X.shape[-1] == self.P.shape[-1]
        # Without this check an over-long input fails later with an opaque
        # broadcasting error.
        assert X.shape[-2] <= self.P.shape[0]
        X = X + self.P[:X.shape[-2]]
        return self.dropout(X)

    def get_positional_matrix(self, max_len, d):
        """Build the ``(max_len, d)`` sinusoidal table described on the class."""
        dtype = torch.get_default_dtype()
        positions = torch.arange(max_len, dtype=dtype).reshape(-1, 1)
        # `even_dims` already holds 2i (0, 2, 4, ...), so the exponent is
        # even_dims / d; multiplying by 2 again would double it.
        even_dims = torch.arange(0, d, 2, dtype=dtype)
        angles = positions / torch.pow(10000.0, even_dims / d)
        P = torch.empty([max_len, d], dtype=dtype)
        P[:, 0::2] = torch.sin(angles)
        # For odd `d` there is one fewer odd column than even columns.
        P[:, 1::2] = torch.cos(angles[:, :d // 2])
        return P

### Position-wise Feed-Forward Network (PFFN)

In [39]:
class PositionWiseFFN(nn.Module):
    """Two-layer feed-forward network acting on the last dimension of its input.

    Computes ``dense2(ReLU(dense1(X)))``.

    Args:
        ffn_num_hiddens: width of the intermediate (hidden) layer.
        pw_num_outputs: size of the last dimension of both input and output.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, ffn_num_hiddens, pw_num_outputs, **kwargs):
        super().__init__(**kwargs)
        # Expand to the hidden width, apply the non-linearity, project back.
        self.dense1 = nn.Linear(in_features=pw_num_outputs,
                                out_features=ffn_num_hiddens)
        self.activation = nn.ReLU()
        self.dense2 = nn.Linear(in_features=ffn_num_hiddens,
                                out_features=pw_num_outputs)

    def forward(self, X):
        """Apply the two linear layers with a ReLU in between."""
        return self.dense2(self.activation(self.dense1(X)))

### Add&Norm

In [6]:
class AddNorm(nn.Module):
    """Residual connection followed by layer normalization.

    Computes ``LayerNorm(X + Dropout(Y))``.

    Args:
        num_hiddens: size of the last dimension, passed to ``nn.LayerNorm``.
        dropout: dropout probability applied to ``Y`` before the addition.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, num_hiddens, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(p=dropout)
        self.ln = nn.LayerNorm(normalized_shape=num_hiddens)

    def forward(self, X, Y):
        """Add the dropped-out ``Y`` to ``X`` and normalize the sum."""
        return self.ln(X + self.dropout(Y))

### Multi-Head Attention

In [33]:
class DotProductAttention(nn.Module):
    """Scaled dot-product attention with dropout on the attention weights.

    Args:
        dropout: dropout probability applied to the attention weights.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, dropout, **kwargs):
        super().__init__(**kwargs)
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, query, key, value):
        """Attend from ``query`` over the ``key``/``value`` pairs.

        Shapes:
            query: (batch_size, #queries, d)
            key:   (batch_size, #kv_pairs, d)
            value: (batch_size, #kv_pairs, dim_v)

        Returns a tensor of shape (batch_size, #queries, dim_v).
        """
        dim = query.shape[-1]
        assert dim == key.shape[-1]

        scaled = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(dim)
        probs = self.mask_softmax(scaled, dim=-1)
        return torch.bmm(self.dropout(probs), value)

    def mask_softmax(self, X, dim=-1):
        """Softmax over ``dim`` in which entries exactly equal to 0 are masked.

        Every element of ``X`` that equals 0 is replaced by -1e6 before the
        softmax; all other elements are left unchanged.
        """
        is_zero = X.eq(0).int()
        keep = 1 - is_zero
        return F.softmax(X * keep + is_zero * -1e6, dim=dim)
            

class MultiHeadAttention(nn.Module):
    """Multi-head attention built on ``DotProductAttention``.

    Queries, keys and values are each passed through a bias-free linear
    projection, split into ``num_heads`` heads with ``transpose_qkv``,
    attended, merged back with ``transpose_output`` and finally projected
    by ``W_o``.

    Args:
        num_hiddens: input and output feature size of every projection.
        num_heads: number of heads the projected features are split into.
        dropout: dropout probability used inside the attention.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, num_hiddens, num_heads, dropout, **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.attention = DotProductAttention(dropout)
        # Query, key, value and output projections, created in that order.
        for proj_name in ("W_q", "W_k", "W_v", "W_o"):
            setattr(self, proj_name,
                    nn.Linear(num_hiddens, num_hiddens, bias=False))

    def forward(self, query, key, value):
        """Run multi-head attention.

        For self-attention, `query`, `key`, and `value` all have shape
        (`batch_size`, `seq_len`, `dim`), where `seq_len` is the length of
        the input sequence.
        """
        heads = self.num_heads
        q = transpose_qkv(self.W_q(query), heads)
        k = transpose_qkv(self.W_k(key), heads)
        v = transpose_qkv(self.W_v(value), heads)

        merged = transpose_output(self.attention(q, k, v), heads)
        return self.W_o(merged)


def transpose_qkv(X, num_heads):
    """Split the feature dimension into heads and fold the heads into the batch.

    Args:
        X: tensor of shape (batch_size, seq_len, num_hiddens).
        num_heads: number of heads; must divide ``num_hiddens``.

    Returns:
        Tensor of shape (batch_size * num_heads, seq_len, num_hiddens / num_heads).
        Row ``b * num_heads + h`` holds feature slice ``h`` of batch item ``b``
        for every sequence position.
    """
    num_heads = int(num_heads)
    batch_size, seq_len, num_hiddens = X.shape
    head_dim = num_hiddens // num_heads
    X = X.reshape(batch_size, seq_len, num_heads, head_dim)
    # The head axis has to be *moved* in front of the sequence axis; a plain
    # view/reshape would reinterpret memory and mix sequence positions
    # into the heads.
    X = X.permute(0, 2, 1, 3)
    return X.reshape(batch_size * num_heads, seq_len, head_dim)


def transpose_output(X, num_heads):
    """Inverse of ``transpose_qkv``: merge the heads back into the feature axis.

    Args:
        X: tensor of shape (batch_size * num_heads, seq_len, head_dim).
        num_heads: number of heads that were folded into the batch axis.

    Returns:
        Tensor of shape (batch_size, seq_len, head_dim * num_heads).
    """
    num_heads = int(num_heads)
    bs_mul_nh, seq_len, head_dim = X.shape
    batch_size = bs_mul_nh // num_heads
    X = X.reshape(batch_size, num_heads, seq_len, head_dim)
    # Move the head axis back next to the per-head features before merging.
    X = X.permute(0, 2, 1, 3)
    return X.reshape(batch_size, seq_len, head_dim * num_heads)

### Encoder

In [8]:
class EncoderBlock(nn.Module):
    """One encoder layer: self-attention and a position-wise FFN, each wrapped
    in an ``AddNorm`` residual step.

    Args:
        num_hiddens: feature size of the block's input and output.
        ffn_num_hiddens: hidden width of the feed-forward network.
        num_heads: number of attention heads.
        dropout: dropout probability used by the attention and both AddNorms.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, num_hiddens, ffn_num_hiddens, num_heads, dropout, **kwargs):
        super().__init__(**kwargs)
        # Sub-layer 1: multi-head self-attention + residual/normalization.
        self.attention = MultiHeadAttention(num_hiddens, num_heads, dropout)
        self.addnorm1 = AddNorm(num_hiddens, dropout)
        # Sub-layer 2: position-wise feed-forward + residual/normalization.
        self.ffn = PositionWiseFFN(ffn_num_hiddens, num_hiddens)
        self.addnorm2 = AddNorm(num_hiddens, dropout)

    def forward(self, X):
        """Apply both sub-layers to ``X`` and return the result."""
        attended = self.addnorm1(X, self.attention(X, X, X))
        return self.addnorm2(attended, self.ffn(attended))

### Transformer Encoder

In [9]:
class TransformerEncoder(nn.Module):
    """Token embedding + positional encoding followed by a stack of encoder blocks.

    Args:
        vocab_size: number of rows in the embedding table.
        num_hiddens: embedding size and feature size of every block.
        ffn_num_hiddens: hidden width of each block's feed-forward network.
        num_heads: number of attention heads per block.
        num_layers: number of ``EncoderBlock`` layers in the stack.
        dropout: dropout probability used throughout.
        use_bias: accepted for interface compatibility; currently unused.
        **kwargs: forwarded to ``nn.Module.__init__``.
    """

    def __init__(self, vocab_size, num_hiddens, ffn_num_hiddens,
                 num_heads, num_layers, dropout, use_bias=False, **kwargs):
        super(TransformerEncoder, self).__init__(**kwargs)
        self.num_hiddens = num_hiddens
        self.embedding = nn.Embedding(vocab_size, num_hiddens)
        self.pos_encoding = PositionalEncoding(num_hiddens, dropout)
        self.blocks = nn.Sequential()
        for i in range(num_layers):
            # Each layer needs a distinct name: add_module() replaces any
            # existing child registered under the same name, which would
            # leave a single block no matter how large `num_layers` is.
            self.blocks.add_module("Encoder" + str(i),
                                   EncoderBlock(num_hiddens,
                                                ffn_num_hiddens,
                                                num_heads,
                                                dropout))

    def forward(self, X, *args):
        """Encode a batch of token indices.

        ``X`` is passed to ``nn.Embedding``; the result has a trailing
        dimension of size ``num_hiddens``. Extra positional arguments are
        accepted and ignored.
        """
        # Scale the embeddings by sqrt(num_hiddens) before the positional
        # encodings are added.
        X = self.embedding(X) * math.sqrt(self.num_hiddens)
        X = self.pos_encoding(X)
        for block in self.blocks:
            X = block(X)
        return X

In [40]:
# Smoke test: push a dummy batch of token ids through the encoder and print
# the shape of the result.
vocab_size = 200
batch_size = 2
seq_len = 100
num_hiddens = 24
ffn_num_hiddens = 48
num_heads = 8
num_layers = 2
dropout = 0.5
encoder = TransformerEncoder(
    vocab_size=vocab_size,
    num_hiddens=num_hiddens,
    ffn_num_hiddens=ffn_num_hiddens,
    num_heads=num_heads,
    num_layers=num_layers,
    dropout=dropout,
)
X = encoder(torch.ones((batch_size, seq_len), dtype=torch.long))
print(X.shape)

torch.Size([2, 100, 24])
