In [1]:
import typing
from copy import deepcopy
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
class Encoder(nn.Module):
    """Encoder stack of the Transformer: ``N`` layers applied one after another.

    Args:
        layer: Prototype layer module. It is deep-copied ``N`` times so that
            every position in the stack owns its own parameters.
        N: Number of stacked layers.
    """

    def __init__(self, layer, N=6):
        super(Encoder, self).__init__()
        self.layers = nn.ModuleList([deepcopy(layer) for _ in range(N)])

    # NOTE: ``forward`` must live at class level. When it is nested inside
    # ``__init__`` it is only a local function and ``nn.Module`` never sees it.
    def forward(self, x):
        """Feed ``x`` through every layer in order and return the last output."""
        for layer in self.layers:
            x = layer(x)
        return x

In [3]:
class EncoderLayer(nn.Module):
    """Encoder layer that consists of two sublayers
            1. Multi-head self attention
            2. Feed Forward Neural Network (FFNN)
        Each sublayer is wrapped in a ``SubLayerConnection`` (residual
        connection followed by layer normalization).

    Args:
        self_attn: Callable/module invoked as ``self_attn(x)``.
        ffnn: Callable/module invoked as ``ffnn(x)``.
        dropout: Dropout probability forwarded to each ``SubLayerConnection``.
        size: Feature width forwarded to each ``SubLayerConnection`` (the
            width its ``LayerNorm`` normalizes over).
    """
    def __init__(self, self_attn, ffnn, dropout: float = 0.1, size: int = 512):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.ffnn = ffnn
        # One residual/normalization wrapper per sublayer.
        self.sublayer = nn.ModuleList(
            [SubLayerConnection(dropout=dropout, size=size) for _ in range(2)]
        )

    # NOTE: ``forward`` must be a method of the class; defining it inside
    # ``__init__`` leaves the module without a usable ``forward``.
    def forward(self, x):
        """Apply self attention, then the feed-forward network, to ``x``."""
        x = self.sublayer[0](x, self.self_attn)
        return self.sublayer[1](x, self.ffnn)

In [4]:
class FeedForwardNetwork(nn.Module):
    """Position-wise two-layer feed-forward block.

    Computes ``FFN(x) = max(0, x W1 + b1) W2 + b2``: a linear map from
    ``d_model`` to ``d_ff`` features, a ReLU, and a linear map back to
    ``d_model`` features.

    Args:
        d_model: Number of input and output features.
        d_ff: Number of features of the hidden layer.
    """

    def __init__(self, d_model: int = 512, d_ff: int = 2048):
        super().__init__()
        # Expansion followed by contraction back to the model width.
        self.W_1 = nn.Linear(d_model, d_ff)
        self.W_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        """Return ``W_2(relu(W_1(x)))``."""
        hidden = F.relu(self.W_1(x))
        return self.W_2(hidden)

In [5]:
class SubLayerConnection(nn.Module):
    """Creates a residual connection and performs Layer Normalization for a
    sublayer.
            LayerNorm(x + Dropout(Sublayer(x)))

    Args:
        dropout: Dropout probability applied to the sublayer output.
        size: Width of the last dimension normalized by ``nn.LayerNorm``.
    """
    def __init__(self, dropout: float = 0.1, size: int = 512):
        super(SubLayerConnection, self).__init__()
        self.dropout = nn.Dropout(dropout)
        self.norm = nn.LayerNorm(size)

    def forward(self, x, sublayer):
        """Apply dropout to the output of the sub-layer, before it is added
        to the sub-layer input and normalized.

        Args:
            x: Input tensor; also used as the residual branch.
            sublayer: Callable invoked as ``sublayer(x)``; its result must be
                addable to ``x``.

        Returns:
            ``self.norm(x + self.dropout(sublayer(x)))``.
        """
        # Dropout regularizes what the sublayer produced; the sublayer itself
        # must see the untouched input, and the residual branch stays intact.
        return self.norm(x + self.dropout(sublayer(x)))

In [6]:
class Decoder(nn.Module):
    """Decoder stack of the Transformer: ``N`` layers applied one after another.

    Args:
        layer: Prototype layer module. It is deep-copied ``N`` times so that
            every position in the stack owns its own parameters.
        N: Number of stacked layers.
    """

    def __init__(self, layer, N: int = 6):
        # The parentheses matter: without the call, ``nn.Module`` is never
        # initialised and no submodule can be registered.
        super(Decoder, self).__init__()
        self.layers = nn.ModuleList([deepcopy(layer) for _ in range(N)])

    def forward(self, x, memory, src_mask=None, tgt_mask=None):
        """Run ``x`` through every layer in order.

        Each layer is called as ``layer(x, memory, src_mask, tgt_mask)`` and
        its result becomes the input of the next layer.

        Args:
            x: Decoder input.
            memory: Passed unchanged to every layer.
            src_mask: Passed unchanged to every layer.
            tgt_mask: Passed unchanged to every layer.

        Returns:
            The output of the last layer.
        """
        for layer in self.layers:
            x = layer(x, memory, src_mask, tgt_mask)
        return x


In [7]:
class DecoderLayer(nn.Module):
    """Decoder layer that consists of three sublayers, applied in this order
            1. Multi-head self attention (called with ``tgt_mask``)
            2. Multi-head attention over ``memory`` (called with ``src_mask``)
            3. Feed Forward Neural Network (FFNN)
        Each sublayer is wrapped in a ``SubLayerConnection`` (residual
        connection followed by layer normalization).

    Args:
        self_attn: Callable/module invoked as ``self_attn(x, tgt_mask)``.
        cross_attn: Callable/module invoked as
            ``cross_attn(x, memory, src_mask)``.
        ffnn: Callable/module invoked as ``ffnn(x)``.
        dropout: Dropout probability forwarded to each ``SubLayerConnection``.
        size: Feature width forwarded to each ``SubLayerConnection``.
    """
    def __init__(self, self_attn, cross_attn, ffnn,
                 dropout: float = 0.1, size: int = 512):
        # Must name this class: ``super(EncoderLayer, self)`` raises TypeError
        # because a DecoderLayer is not an EncoderLayer instance.
        super(DecoderLayer, self).__init__()
        self.self_attn = self_attn
        self.cross_attn = cross_attn
        self.ffnn = ffnn
        # One residual/normalization wrapper per sublayer.
        self.sublayer = nn.ModuleList(
            [SubLayerConnection(dropout=dropout, size=size) for _ in range(3)]
        )

    def forward(self, x, memory, src_mask=None, tgt_mask=None):
        """Apply self attention, cross attention and the FFNN to ``x``.

        Args:
            x: Decoder-side input.
            memory: Second input handed to ``cross_attn``.
            src_mask: Mask handed to ``cross_attn``.
            tgt_mask: Mask handed to ``self_attn``.

        Returns:
            Output of the third (feed-forward) sublayer connection.
        """
        # Lambdas adapt the multi-argument attention calls to the
        # single-argument ``sublayer(x)`` contract of SubLayerConnection.
        x = self.sublayer[0](x, lambda t: self.self_attn(t, tgt_mask))
        x = self.sublayer[1](x, lambda t: self.cross_attn(t, memory, src_mask))
        return self.sublayer[2](x, self.ffnn)

In [8]:
def attention(query, key, value, mask=None):
    """Scaled dot-product attention: ``softmax(Q K^T / sqrt(d_k)) V``.

    The last two dimensions of each argument are treated as the matrix
    dimensions; any leading dimensions are batched/broadcast by
    ``torch.matmul``.

    Args:
        query: Tensor whose last dimension is ``d_k``.
        key: Tensor whose last dimension is ``d_k``.
        value: Tensor whose second-to-last dimension matches that of ``key``.
        mask: Optional tensor broadcastable to the score matrix. Positions
            where ``mask == 0`` receive zero attention weight.

    Returns:
        The attention-weighted combination of ``value``.
    """
    d_k = query.shape[-1]
    # Transpose the last two axes so inputs of any rank >= 2 are supported.
    scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k)
    # ``if mask:`` would raise for any tensor with more than one element.
    if mask is not None:
        # Out-of-place fill: do not mutate a tensor autograd may still need.
        # NOTE: a row that is masked everywhere yields NaN after softmax.
        scores = scores.masked_fill(mask == 0, float("-inf"))
    # Attention weights are probabilities, not log-probabilities.
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)

In [9]:
class MultiSelfAttn(nn.Module):
    """Multi-head self attention built from one bias-free projection triple
    (query, key, value) per head plus a shared output projection.

    Args:
        heads: Number of attention heads.
        d_model: Input/output feature width; each head projects to
            ``d_model // heads`` features.
    """

    def __init__(self, heads=8, d_model=512):
        super().__init__()
        head_dim = d_model // heads

        def per_head_projections():
            # One independent bias-free linear map per head.
            return nn.ModuleList(
                [nn.Linear(d_model, head_dim, bias=False) for _ in range(heads)]
            )

        # Built in Q, K, V, O order so parameters are created in the same
        # sequence as before.
        self.W_Q = per_head_projections()
        self.W_K = per_head_projections()
        self.W_V = per_head_projections()
        self.W_O = nn.Linear(heads * head_dim, d_model)
        self.heads = heads

    def forward(self, x, mask=None):
        """Attend ``x`` to itself with every head and merge the results.

        Each head projects ``x`` to its own query/key/value, calls
        ``attention`` on them (with ``mask``), and the per-head results are
        concatenated along the last dimension before ``W_O`` is applied.
        """
        per_head = [
            attention(self.W_Q[h](x), self.W_K[h](x), self.W_V[h](x), mask)
            for h in range(self.heads)
        ]
        merged = torch.cat(per_head, dim=-1)
        return self.W_O(merged)

In [10]:
class MultiCrossAttn(nn.Module):
    """Performs multi-head cross attention between encoder
    and decoder.

    Queries are projected from the decoder-side input ``x``; keys and values
    are projected from ``memory``. One bias-free projection triple exists per
    head, and a shared linear layer merges the concatenated head outputs.

    Args:
        heads: Number of attention heads.
        d_model: Input/output feature width; each head projects to
            ``d_model // heads`` features.
    """

    def __init__(self, heads=8, d_model=512):
        super(MultiCrossAttn, self).__init__()
        d_k = d_v = d_model//heads
        self.W_Q = nn.ModuleList(
            [nn.Linear(d_model, d_k, bias=False) for _ in range(heads)]
            )
        self.W_K = nn.ModuleList(
            [nn.Linear(d_model, d_k, bias=False) for _ in range(heads)]
            )
        self.W_V = nn.ModuleList(
            [nn.Linear(d_model, d_v, bias=False) for _ in range(heads)]
            )
        self.W_O = nn.Linear(heads*d_v, d_model)
        self.heads = heads

    def forward(self, x, memory, mask=None):
        """Let every position of ``x`` attend over ``memory``.

        Args:
            x: Tensor the queries are computed from.
            memory: Tensor the keys and values are computed from.
            mask: Optional mask forwarded unchanged to ``attention``.

        Returns:
            ``W_O`` applied to the per-head attention results concatenated
            along the last dimension.
        """
        multihead_dotP = []
        for head in range(self.heads):
            # Queries come from x; keys AND values must both come from
            # memory so that the score matrix lines up with the values even
            # when x and memory have different lengths.
            Q = self.W_Q[head](x)
            K = self.W_K[head](memory)
            V = self.W_V[head](memory)
            multihead_dotP.append(attention(Q, K, V, mask))
        multihead_attn = torch.cat(multihead_dotP, dim=-1)
        return self.W_O(multihead_attn)