In [2]:
import copy
from typing import Optional, List

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data

In [3]:
class linear_transformation(nn.Module):
    """Linear projection whose output feature axis is split into attention heads."""

    def __init__(self, d_model, heads, d_k, bias=False):
        """
        Parameters:
        -----------
        d_model : int
            Number of input (and output) features of the projection.
        heads   : int
            Number of heads the projected features are split into.
        d_k     : int
            Number of features per head; heads * d_k must equal d_model.
        bias    : bool
            Whether the linear projection has a bias term.
        """
        super().__init__()
        # The projected features are viewed as (heads, d_k), so the product must match d_model.
        assert d_model == heads * d_k, "d_model should be divisible by heads"
        self.linear = nn.Linear(d_model, d_model, bias=bias)
        self.heads, self.d_k = heads, d_k

    def forward(self, x: torch.Tensor):
        """
        Parameters:
        -----------
        x   : torch.Tensor
            Input with shape (seq_len, batch_size, d_model)

        Return:
        x   : torch.Tensor
            Output with shape (seq_len, batch_size, heads, d_k)
        """
        # Every axis except the feature axis is carried over unchanged.
        leading_dims = x.shape[:-1]
        projected = self.linear(x)
        # Split the last axis (d_model) into (heads, d_k).
        return projected.view(*leading_dims, self.heads, self.d_k)

In [4]:
class multi_headed_attention(nn.Module):
    """Multi-head scaled dot-product attention over sequence-first tensors."""

    def __init__(self, heads: int, d_model: int, dropout_prob: float = 0.1, bias: bool = False):
        """
        Parameters:
        ---------------
        heads       : int
                    Number of heads.
        d_model     : int
                    Number of features in the query, key, value vectors.
        dropout_prob: float
                    Drop out rate.
        bias        : bool
                    Whether we have the bias terms in the query/key/value
                    linear transformations.
        """

        # Calling the parent class constructor, inheriting from it
        super(multi_headed_attention, self).__init__()

        # number of features per head
        self.d_k = d_model // heads
        self.heads = heads

        # Linear Transformation of the query, key, value matrices
        self.query = linear_transformation(d_model, heads, self.d_k, bias=bias)
        self.key = linear_transformation(d_model, heads, self.d_k, bias=bias)
        self.value = linear_transformation(d_model, heads, self.d_k, bias=bias)

        # Scores have shape (i, j, b, h); normalise over the key axis j.
        self.softmax = nn.Softmax(dim=1)

        self.output = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout_prob)
        # Plain Python float so scaling never mixes NumPy scalars into tensor ops.
        self.scale = 1.0 / float(np.sqrt(self.d_k))
        # Attention weights of the most recent forward pass (detached), or None.
        self.attn = None

    def get_scores(self, query: torch.Tensor, key: torch.Tensor):
        """
        Dot-product alignment scores between every query and key position.

        Parameters:
        -----------
        query : torch.Tensor
            Shape (seq_len_q, batch_size, heads, d_k)
        key   : torch.Tensor
            Shape (seq_len_k, batch_size, heads, d_k)

        Return:
        torch.Tensor with shape (seq_len_q, seq_len_k, batch_size, heads);
        entry [i, j] is the score of the i-th query word against the j-th key word.
        """
        # i, j = seq_len
        # b = batch_size
        # h = # of heads
        # d = d_k (summed over)
        return torch.einsum("ibhd, jbhd->ijbh", query, key)

    def prepare_mask(self, mask: torch.Tensor, query_shape: List[int], key_shape: List[int]):
        """
        Validate a mask of shape (seq_len_q or 1, seq_len_k, batch_size or 1)
        and append a trailing axis so it broadcasts over the heads.
        """
        assert mask.shape[0] == 1 or mask.shape[0] == query_shape[0]
        assert mask.shape[1] == key_shape[0]
        assert mask.shape[2] == 1 or mask.shape[2] == query_shape[1]

        # (seq_len_q, seq_len_k, batch_size) -> (seq_len_q, seq_len_k, batch_size, 1)
        mask = mask.unsqueeze(-1)

        return mask

    def forward(self, *, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """
        Parameters:
        -----------
        query, key, value : torch.Tensor
            Shape (seq_len, batch_size, d_model); key and value share their seq_len.
        mask  : Optional[torch.Tensor]
            Shape (seq_len_q or 1, seq_len_k, batch_size or 1). Positions where
            the mask equals 0 are excluded from attention.

        Return:
        torch.Tensor with shape (seq_len_q, batch_size, d_model)
        """
        seq_len, batch_size, _ = query.shape

        if mask is not None:
            mask = self.prepare_mask(mask, query.shape, key.shape)

        query = self.query(query)
        key = self.key(key)
        value = self.value(value)

        # Out-of-place scaling: leaves the einsum result untouched.
        scores = self.get_scores(query, key) * self.scale

        if mask is not None:
            scores = scores.masked_fill(mask == 0, float("-inf"))

        attn = self.softmax(scores)
        # Keep the (pre-dropout) weights around for inspection without holding the graph.
        self.attn = attn.detach()

        attn = self.dropout(attn)

        # i, j = seq_len
        # b = batch
        # h = heads
        # d = d_k

        # For the weight of the j-th word with respect to the i-th word, take the
        # j-th word's d-th feature and sum the weighted features over all j.
        x = torch.einsum("ijbh, jbhd->ibhd", attn, value)

        # Concatenate all heads. The einsum result is not guaranteed to be
        # contiguous, so use reshape (view would raise on a non-contiguous tensor).
        x = x.reshape(seq_len, batch_size, -1)

        return self.output(x) # final linear transformation

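Positional encoding: $p$ is the position of a token in the sequence and $i$ indexes the sine/cosine pairs of embedding dimensions, so $2i$ and $2i + 1$ are the even and odd feature indices.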

$$
\begin{aligned}
PE_{p,2i} &= \sin\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg) \\
PE_{p,2i + 1} &= \cos\Bigg(\frac{p}{10000^{\frac{2i}{d_{model}}}}\Bigg)
\end{aligned}
$$


In [5]:
def positional_encoding(d_model, max_len):
    """
    Build the fixed sinusoidal positional-encoding table.

    Parameters:
    -------------------
    d_model:    # of features of a word vector
    max_len:    max length of a sentence
    -------------------

    Return:
    torch.Tensor with shape (max_len, 1, d_model) that does not require grad.
    Even feature columns hold sin(p / 10000^(2i / d_model)) and odd feature
    columns hold cos(p / 10000^(2i / d_model)) for position p.
    """
    encodings = torch.zeros(max_len, d_model)
    positions = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)
    two_i = torch.arange(0, d_model, 2, dtype=torch.float32)

    # 1 / 10000^(2i / d_model), evaluated as exp(-2i / d_model * ln(10000)).
    # torch.log only accepts tensors, so the constant is computed as a plain float.
    log_base = float(np.log(10000.0))
    div_term = torch.exp(two_i * (-log_base / d_model))

    angles = positions * div_term  # (max_len, ceil(d_model / 2))
    encodings[:, 0::2] = torch.sin(angles)
    # With an odd d_model there is one fewer odd column than even columns.
    encodings[:, 1::2] = torch.cos(angles[:, : d_model // 2])

    # Insert a batch axis of size 1: inputs are shaped (seq_len, batch, embed_size).
    encodings = encodings.unsqueeze(1).requires_grad_(False)

    return encodings


In [6]:
class word_embedding_with_positional_encoding(nn.Module):
    """Token embedding scaled by sqrt(d_model) plus fixed positional encodings."""

    def __init__(self, d_model, n_vocab, max_len=5000):
        """
        Parameters:
        -----------
        d_model : int
            Number of features of each embedding vector.
        n_vocab : int
            Vocabulary size (number of rows of the embedding table).
        max_len : int
            Number of positions the positional-encoding table is built for.
        """
        super(word_embedding_with_positional_encoding, self).__init__()
        self.embed = nn.Embedding(n_vocab, d_model)
        self.d_model = d_model
        # Registered as a buffer: moves with the module but is not a trainable parameter.
        self.register_buffer("positional_encoding", positional_encoding(d_model, max_len))

    def forward(self, x):
        """
        Parameters:
        -----------
        x : torch.Tensor
            Integer token ids; the first axis is the sequence axis.

        Return:
        torch.Tensor of embeddings with one extra trailing axis of size d_model.

        Raises:
        ValueError if the sequence is longer than the positional-encoding table.
        """
        seq_len = x.shape[0]
        table_len = self.positional_encoding.shape[0]
        if seq_len > table_len:
            raise ValueError(
                f"sequence length {seq_len} exceeds the positional encoding table length {table_len}"
            )

        # `requires_grad` is a bool attribute, not a method; detach() keeps the
        # slice out of the autograd graph.
        pe = self.positional_encoding[:seq_len].detach()
        embeddings = self.embed(x) * np.sqrt(self.d_model) + pe

        return embeddings

In [7]:
class feed_forward(nn.Module):
    """Position-wise feed-forward network: Linear -> ReLU -> Dropout -> Linear."""

    def __init__(self, d_model, d_hidden, bias, drop_out_rate):
        """
        Parameters:
        -----------
        d_model       : int
            Number of input and output features.
        d_hidden      : int
            Number of features of the hidden layer.
        bias          : bool
            Whether both linear layers have a bias term.
        drop_out_rate : float
            Dropout probability applied to the hidden activations.
        """
        super().__init__()
        # Expand to the hidden width, then project back to d_model.
        self.layer1 = nn.Linear(d_model, d_hidden, bias=bias)
        self.layer2 = nn.Linear(d_hidden, d_model, bias=bias)
        self.activation = nn.ReLU()
        self.drop_out = nn.Dropout(drop_out_rate)

    def forward(self, x):
        """Apply the network to the last axis of x and return a tensor of the same shape."""
        hidden = self.drop_out(self.activation(self.layer1(x)))
        return self.layer2(hidden)

In [9]:
class transformer_layer(nn.Module):
    """
    Pre-norm transformer layer: self-attention, optional source attention and a
    feed-forward network, each wrapped as x + dropout(sublayer(layer_norm(x))).
    """

    def __init__(self, d_model, self_attn, feed_forward, drop_out_rate, src_attn = None):
        """
        Parameters:
        -----------
        d_model       : int
            Number of features normalised by the layer norms.
        self_attn     : nn.Module
            Attention module called with query/key/value/mask keywords.
        feed_forward  : nn.Module
            Feed-forward module applied after the attention sub-layers.
        drop_out_rate : float
            Dropout probability applied to every sub-layer output.
        src_attn      : nn.Module or None
            Attention over the `src` tensor; only given its own layer norm when provided.
        """
        super().__init__()
        self.d_model = d_model

        # Sub-layers supplied by the caller.
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.drop_out = nn.Dropout(drop_out_rate)
        self.src_attn = src_attn

        # One layer norm per sub-layer; the source-attention norm is optional.
        self.norm_self_attn = nn.LayerNorm([d_model])
        if src_attn is not None:
            self.norm_src_attn = nn.LayerNorm([d_model])
        self.norm_feed_forward = nn.LayerNorm([d_model])

    def forward(self, x, mask, src = None, src_mask = None):
        """
        Run the sub-layers in order and return the updated x.

        The source-attention step runs only when `src` is given; it uses the
        normalised x as query and `src` (not normalised here) as key and value.
        """
        # Self-attention with residual connection.
        normed = self.norm_self_attn(x)
        attended = self.self_attn(query=normed, key=normed, value=normed, mask=mask)
        x = x + self.drop_out(attended)

        # Attention over the source with residual connection.
        if src is not None:
            normed = self.norm_src_attn(x)
            attended = self.src_attn(query=normed, key=src, value=src, mask=src_mask)
            x = x + self.drop_out(attended)

        # Feed-forward with residual connection.
        transformed = self.feed_forward(self.norm_feed_forward(x))
        return x + self.drop_out(transformed)

In [11]:
class encoder(nn.Module):
    """Stack of n_layers transformer layers followed by a final layer norm."""

    def __init__(self, d_model, n_layers, self_attn, feed_forward, drop_out_rate):
        """
        Parameters:
        -----------
        d_model       : int
            Number of features per position.
        n_layers      : int
            Number of stacked transformer layers.
        self_attn     : nn.Module
            Self-attention module used as the template for every layer.
        feed_forward  : nn.Module
            Feed-forward module used as the template for every layer.
        drop_out_rate : float
            Dropout probability used inside every layer.
        """
        super(encoder, self).__init__()
        # Deep-copy the templates so every layer owns its parameters; passing the
        # same instances to each layer would tie the weights of all layers together.
        self.layers = nn.ModuleList([
            transformer_layer(d_model, copy.deepcopy(self_attn), copy.deepcopy(feed_forward), drop_out_rate)
            for _ in range(n_layers)
        ])

        self.norm = nn.LayerNorm([d_model])

    def forward(self, x, mask):
        """Pass x through every layer with the same mask, then normalise the result."""
        for layer in self.layers:
            x = layer(x, mask=mask)

        return self.norm(x)

In [12]:
class decoder(nn.Module):
    """Stack of n_layers transformer layers with source attention, then a final layer norm."""

    def __init__(self, d_model, n_layers, self_attn, feed_forward, drop_out_rate, src_attn):
        """
        Parameters:
        -----------
        d_model       : int
            Number of features per position.
        n_layers      : int
            Number of stacked transformer layers.
        self_attn     : nn.Module
            Self-attention module used as the template for every layer.
        feed_forward  : nn.Module
            Feed-forward module used as the template for every layer.
        drop_out_rate : float
            Dropout probability used inside every layer.
        src_attn      : nn.Module
            Source (encoder-decoder) attention module used as the template for every layer.
        """
        super(decoder, self).__init__()
        # n_layers is an int, so it has to be wrapped in range() to be iterated.
        # Deep-copy the templates so every layer owns its parameters; passing the
        # same instances to each layer would tie the weights of all layers together.
        self.layers = nn.ModuleList([
            transformer_layer(
                d_model,
                copy.deepcopy(self_attn),
                copy.deepcopy(feed_forward),
                drop_out_rate,
                copy.deepcopy(src_attn),
            )
            for _ in range(n_layers)
        ])

        self.norm = nn.LayerNorm([d_model])

    def forward(self, x, mask, src, src_mask):
        """Pass x through every layer, attending to src, then normalise the result."""
        for layer in self.layers:
            x = layer(x, mask, src, src_mask)

        return self.norm(x)

In [13]:
class transformer(nn.Module):
    """Encoder-decoder transformer over sequence-first token id tensors."""

    def __init__(self, d_model, heads, dropout_prob, ff_hidden_size, n_layers, src_nvocab, tgt_nvocab):
        """
        Parameters:
        -----------
        d_model        : int
            Number of features per position.
        heads          : int
            Number of attention heads.
        dropout_prob   : float
            Dropout probability used in attention, feed-forward and residual paths.
        ff_hidden_size : int
            Hidden width of the feed-forward networks.
        n_layers       : int
            Number of layers in the encoder and in the decoder.
        src_nvocab     : int
            Source vocabulary size.
        tgt_nvocab     : int
            Target vocabulary size.
        """
        # nn.Module must be initialised before any sub-module can be assigned.
        super(transformer, self).__init__()

        encoder_attention = multi_headed_attention(heads=heads, d_model=d_model, dropout_prob=dropout_prob)
        decoder_attention = multi_headed_attention(heads=heads, d_model=d_model, dropout_prob=dropout_prob)
        encoder_decoder_attention = multi_headed_attention(heads=heads, d_model=d_model, dropout_prob=dropout_prob)

        encoder_ffn = feed_forward(d_model=d_model, d_hidden=ff_hidden_size, bias=True, drop_out_rate=dropout_prob)
        decoder_ffn = feed_forward(d_model=d_model, d_hidden=ff_hidden_size, bias=True, drop_out_rate=dropout_prob)

        self.encoder = encoder(d_model=d_model, n_layers=n_layers, self_attn=encoder_attention, feed_forward=encoder_ffn, drop_out_rate=dropout_prob)
        self.decoder = decoder(d_model=d_model, n_layers=n_layers, self_attn=decoder_attention, feed_forward=decoder_ffn, drop_out_rate=dropout_prob, src_attn=encoder_decoder_attention)

        # Keyword arguments: the embedding signature is (d_model, n_vocab, max_len).
        self.encode_embedding = word_embedding_with_positional_encoding(d_model=d_model, n_vocab=src_nvocab)
        self.decode_embedding = word_embedding_with_positional_encoding(d_model=d_model, n_vocab=tgt_nvocab)

        self.linear = nn.Linear(d_model, tgt_nvocab)
        self.softmax = nn.Softmax(dim=-1)

        # Xavier-initialise every weight matrix; 1-D parameters (biases, norms) keep their defaults.
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def src_embed(self, x):
        """Embed source token ids (adds positional encodings)."""
        return self.encode_embedding(x)

    def tgt_embed(self, x):
        """Embed target token ids (adds positional encodings)."""
        return self.decode_embedding(x)

    def generate_target_mask(self, tgt):
        """
        Build the decoder self-attention mask for target token ids.

        Parameters:
        -----------
        tgt : torch.Tensor
            Token ids with shape (seq_len, batch_size); id 0 is treated as padding.

        Return:
        Boolean tensor with shape (seq_len, seq_len, batch_size) where entry
        [i, j, b] is True when query position i may attend to key position j,
        i.e. j <= i and tgt[j, b] is not padding.
        """
        seq_len = tgt.shape[0]
        # (1, seq_len_k, batch): padded key positions are never attended to.
        pad_mask = (tgt != 0).unsqueeze(0)
        # (seq_len_q, seq_len_k, 1): lower triangle, so no position sees the future.
        causal_mask = torch.tril(torch.ones(seq_len, seq_len, device=tgt.device)).bool().unsqueeze(-1)
        return pad_mask & causal_mask

    def forward(self, src, tgt, src_mask=None):
        """
        Parameters:
        -----------
        src      : torch.Tensor
            Source token ids with shape (src_len, batch_size).
        tgt      : torch.Tensor
            Target token ids with shape (tgt_len, batch_size).
        src_mask : Optional[torch.Tensor]
            Mask over source key positions. It is used both for encoder
            self-attention and for decoder source attention, so a shape of
            (1, src_len, batch_size or 1) satisfies both.

        Return:
        torch.Tensor with shape (tgt_len, batch_size, tgt_nvocab) holding
        softmax probabilities over the target vocabulary.
        """
        # Token ids must be embedded before entering the layer stacks.
        encoded = self.encoder(x=self.src_embed(src), mask=src_mask)
        tgt_mask = self.generate_target_mask(tgt)
        decoded = self.decoder(x=self.tgt_embed(tgt), mask=tgt_mask, src=encoded, src_mask=src_mask)

        result = self.linear(decoded)
        result = self.softmax(result)

        return result
        