In [None]:
import torch
import math
from torch import nn
import torch.nn.functional as F

In [None]:
def scaled_dot_product_attention(q,k,v,mask=None):
  """Scaled dot-product attention: ``softmax(q @ k^T / sqrt(d_k) + mask) @ v``.

  Args:
    q: query tensor of shape ``(..., q_len, d_k)``.
    k: key tensor of shape ``(..., k_len, d_k)``.
    v: value tensor of shape ``(..., k_len, d_v)``.
    mask: optional additive mask broadcastable to ``(..., q_len, k_len)``.
      Use ``0`` for positions that may be attended to and ``-inf`` (or a
      large negative number) for positions that must be ignored.

  Returns:
    Tensor of shape ``(..., q_len, d_v)`` holding the attention-weighted
    values.
  """
  d_k = q.size(-1)
  # Scale the logits by sqrt(d_k) so their magnitude does not grow with the head size.
  scores = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(d_k)
  if mask is not None:
    # The mask has to be applied to the logits *before* softmax so that
    # blocked positions end up with zero attention weight.
    scores = scores + mask
  # Normalise over the key axis explicitly; the implicit-dim form of softmax
  # is deprecated and picks axis 0 for 3-D inputs.
  weights = F.softmax(scores, dim=-1)
  return torch.matmul(weights, v)

In [None]:
class FeedForward(nn.Module):
  """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear.

  Args:
    d_model: size of the input and output feature dimension.
    hidden: size of the intermediate (expanded) feature dimension.
    drop_prob: dropout probability applied after the activation.
  """
  def __init__(self,d_model,hidden,drop_prob):
    # nn.Module must be initialised before sub-modules are assigned,
    # otherwise they cannot be registered as children/parameters.
    super().__init__()
    self.linear1 = nn.Linear(d_model,hidden)
    self.linear2 = nn.Linear(hidden,d_model)
    self.relu = nn.ReLU()
    self.dropout = nn.Dropout(p=drop_prob)

  def forward(self,x):
    """Apply the block to ``x`` of shape ``(..., d_model)``; the shape is preserved."""
    x = self.linear1(x)
    x = self.relu(x)
    x = self.dropout(x)
    x = self.linear2(x)
    return x


In [None]:
class LayerNormalization(nn.Module):
  """Layer normalisation over the trailing ``parameter_shape`` dimensions.

  Each sample is normalised to zero mean and unit variance across the last
  ``len(parameter_shape)`` axes, then scaled by ``gamma`` and shifted by
  ``beta`` (both learnable, of shape ``parameter_shape``).

  Args:
    parameter_shape: shape of the trailing dimensions to normalise over,
      e.g. ``[d_model]``. A bare ``int`` is treated as a single dimension.
    epsilon: small constant added to the variance for numerical stability.
  """
  def __init__(self,parameter_shape,epsilon=1e-5):
    super().__init__()
    self.parameter_shape = parameter_shape
    self.epsilon = epsilon
    # Scale starts at one and shift at zero, so the freshly built layer
    # performs a plain normalisation.
    self.gamma = nn.Parameter(torch.ones(parameter_shape))
    self.beta = nn.Parameter(torch.zeros(parameter_shape))
    num_dims = 1 if isinstance(parameter_shape, int) else len(parameter_shape)
    # Negative indices of the trailing axes the statistics are computed over.
    self._norm_dims = tuple(range(-num_dims, 0))

  def forward(self,inputs):
    """Normalise ``inputs`` of shape ``(..., *parameter_shape)``; the shape is preserved."""
    # Reduce only over the normalised axes (keepdim for broadcasting) so
    # that every position gets its own statistics.
    mean = inputs.mean(dim=self._norm_dims, keepdim=True)
    var = ((inputs-mean)**2).mean(dim=self._norm_dims, keepdim=True)
    std = (var + self.epsilon).sqrt()
    y = (inputs-mean)/std
    out = y*self.gamma + self.beta
    return out

In [None]:
class MultiHeadAttention(nn.Module):
  """Multi-head self-attention.

  Args:
    d_model: feature size of the input and output (e.g. 512).
    num_heads: number of attention heads (e.g. 8); must divide ``d_model``.

  Raises:
    ValueError: if ``d_model`` is not divisible by ``num_heads``.
  """
  def __init__(self,d_model,num_heads):
    super().__init__()
    if d_model % num_heads != 0:
      raise ValueError(
          f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
    self.d_model = d_model # 512
    self.num_heads = num_heads # 8
    self.head_dim = d_model // num_heads # 64
    # A single projection produces queries, keys and values for all heads.
    self.qkv_layer = nn.Linear(d_model,3*d_model) # 1536
    self.linear_layer = nn.Linear(d_model,d_model) # 512

  def forward(self,x,mask=None):
    """Apply self-attention to ``x``.

    Args:
      x: tensor of shape ``(seq_len, d_model)`` or, with leading batch
        dimensions, ``(..., seq_len, d_model)``.
      mask: optional additive mask broadcastable to
        ``(..., num_heads, seq_len, seq_len)``; passed through to
        ``scaled_dot_product_attention``.

    Returns:
      Tensor with the same shape as ``x``.
    """
    *lead, sequence_length, _ = x.size() # e.g. 200 x 512
    qkv = self.qkv_layer(x) # ... x 200 x 1536
    qkv = qkv.reshape(*lead, sequence_length, self.num_heads, 3*self.head_dim) # ... x 200 x 8 x 192
    qkv = qkv.transpose(-3, -2) # ... x 8 x 200 x 192
    # Split along the feature axis, not the head axis.
    q,k,v = qkv.chunk(3, dim=-1) # each ... x 8 x 200 x 64
    values = scaled_dot_product_attention(q,k,v,mask) # ... x 8 x 200 x 64
    # Move the head axis back next to the features before merging heads;
    # a plain reshape would mix data from different positions.
    values = values.transpose(-3, -2) # ... x 200 x 8 x 64
    values = values.reshape(*lead, sequence_length, self.num_heads*self.head_dim) # ... x 200 x 512
    out = self.linear_layer(values)
    return out

In [None]:
class MultiHeadCrossAttention(nn.Module):
    """Multi-head cross-attention: queries come from ``y``, keys/values from ``x``.

    Args:
        d_model: feature size of both inputs and of the output (e.g. 512).
        num_heads: number of attention heads (e.g. 8); must divide ``d_model``.

    Raises:
        ValueError: if ``d_model`` is not divisible by ``num_heads``.
    """

    def __init__(self, d_model, num_heads):
        super().__init__()
        if d_model % num_heads != 0:
            raise ValueError(
                f"d_model ({d_model}) must be divisible by num_heads ({num_heads})")
        self.d_model = d_model
        self.num_heads = num_heads
        self.head_dim = d_model // num_heads
        self.kv_layer = nn.Linear(d_model , 2 * d_model) # 1024
        self.q_layer = nn.Linear(d_model , d_model)
        self.linear_layer = nn.Linear(d_model, d_model)

    def forward(self, x, y, mask=None):
        """Attend from ``y`` (queries) to ``x`` (keys and values).

        Args:
            x: key/value source of shape ``(..., src_len, d_model)``.
            y: query source of shape ``(..., tgt_len, d_model)``; ``tgt_len``
                may differ from ``src_len``.
            mask: optional additive mask broadcastable to
                ``(..., num_heads, tgt_len, src_len)``; passed through to
                ``scaled_dot_product_attention``.

        Returns:
            Tensor of shape ``(..., tgt_len, d_model)``.
        """
        *kv_lead, source_length, d_model = x.size()  # e.g. 200 x 512
        # Queries have their own length; reusing the source length would
        # break whenever the two sequences differ in size.
        *q_lead, target_length, _ = y.size()
        kv = self.kv_layer(x)  # ... x src x 1024
        q = self.q_layer(y)  # ... x tgt x 512
        kv = kv.reshape(*kv_lead, source_length, self.num_heads, 2 * self.head_dim)  # ... x src x 8 x 128
        q = q.reshape(*q_lead, target_length, self.num_heads, self.head_dim)  # ... x tgt x 8 x 64
        kv = kv.transpose(-3, -2)  # ... x 8 x src x 128
        q = q.transpose(-3, -2)  # ... x 8 x tgt x 64
        k, v = kv.chunk(2, dim=-1)  # k, v: ... x 8 x src x 64
        # The helper returns a single tensor (the attended values).
        values = scaled_dot_product_attention(q, k, v, mask)  # ... x 8 x tgt x 64
        # Bring the head axis back next to the features before merging heads.
        values = values.transpose(-3, -2)  # ... x tgt x 8 x 64
        values = values.reshape(*q_lead, target_length, d_model)  # ... x tgt x 512
        out = self.linear_layer(values)  # ... x tgt x 512
        return out

In [None]:
class DecoderLayer(nn.Module):
    """One Transformer decoder layer.

    Three sub-blocks are applied in sequence, each followed by dropout, a
    residual connection and layer normalisation:
    self-attention, encoder-decoder (cross) attention, feed-forward.

    Args:
        d_model: feature size of the inputs and outputs.
        ffn_hidden: hidden size of the feed-forward sub-block.
        num_heads: number of attention heads in both attention sub-blocks.
        drop_prob: dropout probability used throughout the layer.
    """

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob):
        super().__init__()
        self.self_attention = MultiHeadAttention(d_model=d_model, num_heads=num_heads)
        # LayerNormalization's constructor argument is named `parameter_shape`.
        self.norm1 = LayerNormalization(parameter_shape=[d_model])
        self.dropout1 = nn.Dropout(p=drop_prob)
        self.encoder_decoder_attention = MultiHeadCrossAttention(d_model=d_model, num_heads=num_heads)
        self.norm2 = LayerNormalization(parameter_shape=[d_model])
        self.dropout2 = nn.Dropout(p=drop_prob)
        self.ffn = FeedForward(d_model=d_model, hidden=ffn_hidden, drop_prob=drop_prob)
        self.norm3 = LayerNormalization(parameter_shape=[d_model])
        self.dropout3 = nn.Dropout(p=drop_prob)

    def forward(self, x, y, decoder_mask):
        """Run the layer.

        Args:
            x: tensor supplying keys/values to the cross-attention sub-block
                (the encoder side).
            y: tensor that is transformed by the layer (the decoder side).
            decoder_mask: mask forwarded to the self-attention sub-block; the
                cross-attention sub-block is called without a mask.

        Returns:
            The transformed ``y``.
        """
        # 1) masked self-attention over the decoder sequence
        resy = y
        y = self.self_attention(y, mask=decoder_mask)
        y = self.dropout1(y)
        y = self.norm1(y + resy)

        # 2) cross-attention: queries from y, keys/values from x
        resy = y
        y = self.encoder_decoder_attention(x, y, mask=None)
        y = self.dropout2(y)
        y = self.norm2(y + resy)

        # 3) position-wise feed-forward
        resy = y
        y = self.ffn(y)
        y = self.dropout3(y)
        y = self.norm3(y + resy)
        return y

class SequentialDecoder(nn.Sequential):
    """``nn.Sequential`` variant whose children take ``(x, y, mask)``.

    Only ``y`` is threaded from one child to the next; ``x`` and ``mask``
    are handed unchanged to every child.
    """

    def forward(self, *inputs):
        """Expect exactly three positional inputs and return the final ``y``."""
        encoder_out, decoder_state, mask = inputs
        # Iterating the container yields the registered children in order.
        for layer in self:
            decoder_state = layer(encoder_out, decoder_state, mask)
        return decoder_state

class Decoder(nn.Module):
    """Stack of ``num_layers`` identically configured ``DecoderLayer`` modules.

    Args:
        d_model: feature size passed to every layer.
        ffn_hidden: feed-forward hidden size passed to every layer.
        num_heads: number of attention heads passed to every layer.
        drop_prob: dropout probability passed to every layer.
        num_layers: how many layers to stack (default 1).
    """

    def __init__(self, d_model, ffn_hidden, num_heads, drop_prob, num_layers=1):
        super().__init__()
        stack = []
        for _ in range(num_layers):
            stack.append(DecoderLayer(d_model, ffn_hidden, num_heads, drop_prob))
        self.layers = SequentialDecoder(*stack)

    def forward(self, x, y, mask):
        """Pass ``(x, y, mask)`` through the stacked layers and return the resulting ``y``."""
        return self.layers(x, y, mask)