In [3]:
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
from torch.autograd import Variable
import math, copy, time
%matplotlib inline

In [2]:
class EncoderDecoder(nn.Module):
    """Container that chains an encoder and a decoder through two embeddings.

    The five collaborators are stored as attributes exactly as passed in.
    ``generator`` is only kept on the instance; no method of this class
    invokes it.
    """

    def __init__(self, encoder, decoder, source_embedding, target_embedding, generator):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.source_embedding = source_embedding
        self.target_embedding = target_embedding
        self.generator = generator

    def encode(self, source, source_mask):
        """Embed ``source`` and hand it, with ``source_mask``, to the encoder."""
        embedded_source = self.source_embedding(source)
        return self.encoder(embedded_source, source_mask)

    def decode(self, encoder_output, source_mask, target, target_mask):
        """Embed ``target`` and run the decoder against ``encoder_output``.

        The decoder is called as
        ``decoder(embedded_target, encoder_output, source_mask, target_mask)``.
        """
        embedded_target = self.target_embedding(target)
        return self.decoder(embedded_target, encoder_output, source_mask, target_mask)

    def forward(self, x, source_mask, target_mask, source, target):
        """Encode ``source`` and then decode ``target`` against that encoding.

        NOTE: ``x`` is accepted but never read by this method; it is retained
        only so the existing call signature keeps working.
        """
        memory = self.encode(source, source_mask)
        return self.decode(memory, source_mask, target, target_mask)

In [4]:
class Generator(nn.Module):
    """Linear projection from ``d_model`` features to ``vocab`` outputs,
    followed by a log-softmax over the last dimension."""

    def __init__(self, d_model, vocab):
        super().__init__()
        self.projection = nn.Linear(d_model, vocab)

    def forward(self, x):
        """Return ``log_softmax(projection(x))`` taken along the final axis."""
        logits = self.projection(x)
        return F.log_softmax(logits, dim=-1)

In [5]:
def clones(module, N):
    """Return an ``nn.ModuleList`` containing ``N`` deep copies of ``module``.

    Each entry is produced by ``copy.deepcopy``, so the copies do not share
    state with ``module`` or with each other.
    """
    replicas = nn.ModuleList()
    for _ in range(N):
        replicas.append(copy.deepcopy(module))
    return replicas

In [6]:
class Encoder(nn.Module):
    """Runs the input through ``N`` cloned copies of ``layer`` in sequence and
    applies a ``LayerNorm`` (sized by ``layer.size``) to the final result."""

    def __init__(self, layer, N):
        super().__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        """Feed ``x`` through every layer, passing the same ``mask`` to each,
        then return the normalised output."""
        hidden = x
        for sublayer in self.layers:
            hidden = sublayer(hidden, mask)
        return self.norm(hidden)

In [7]:
class LayerNorm(nn.Module):
    """Normalises over the last dimension with a learnable scale and shift.

    ``a2`` (initialised to ones) scales and ``b2`` (initialised to zeros)
    shifts the normalised values; both have ``features`` elements. ``eps`` is
    added to the standard deviation before dividing.
    """

    def __init__(self, features, eps=1e-6):
        super().__init__()
        self.a2 = nn.Parameter(torch.ones(features))
        self.b2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        """Return ``a2 * (x - mean) / (eps + std) + b2`` along the last axis.

        ``std`` is ``x.std(-1, keepdim=True)`` with torch's default settings.
        """
        centered = x - x.mean(-1, keepdim=True)
        denominator = self.eps + x.std(-1, keepdim=True)
        scaled = self.a2 * centered / denominator
        return scaled + self.b2