In [5]:
# Prerequisites: torchtext (installed with `-c pytorch`), seaborn, spacy
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import math, copy, time
from torch.autograd import Variable
import matplotlib.pyplot as plt
import seaborn
seaborn.set_context(context="talk")
%matplotlib inline

In [6]:
class EncoderDecoder(nn.Module):
    """A standard Encoder-Decoder architecture.

    The model is assembled from five externally supplied components, which
    are stored as attributes:

    * ``encoder``   -- called as ``encoder(embedded_src, src_mask)``
    * ``decoder``   -- called as ``decoder(embedded_tgt, memory, src_mask, tgt_mask)``
    * ``src_embed`` -- called as ``src_embed(src)``
    * ``tgt_embed`` -- called as ``tgt_embed(tgt)``
    * ``generator`` -- stored only; it is NOT applied by ``forward``, so
      callers that need its output must invoke ``model.generator`` themselves.
    """
    def __init__(self, encoder, decoder, src_embed, tgt_embed, generator):
        super(EncoderDecoder, self).__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.src_embed = src_embed
        self.tgt_embed = tgt_embed
        self.generator = generator

    def forward(self, src, tgt, src_mask, tgt_mask):
        """Take in and process masked src and target sequences.

        Encodes ``src`` and then decodes ``tgt`` against that encoding.
        Returns whatever ``self.decoder`` returns (the generator is not applied).
        """
        memory = self.encode(src, src_mask)
        return self.decode(memory, src_mask, tgt, tgt_mask)

    def encode(self, src, src_mask):
        "Embed ``src`` and run the encoder over it; returns the encoder output."
        embedded = self.src_embed(src)
        return self.encoder(embedded, src_mask)

    def decode(self, memory, src_mask, tgt, tgt_mask):
        """Embed ``tgt`` and run the decoder against ``memory``.

        ``memory`` is the value previously produced by ``encode``.
        Returns the decoder output.
        """
        embedded = self.tgt_embed(tgt)
        # The result must be returned: without this, decode() -- and
        # therefore forward() -- would silently evaluate to None.
        return self.decoder(embedded, memory, src_mask, tgt_mask)


In [7]:
class Generator(nn.Module):
    """Final generation step: a linear projection followed by log-softmax.

    ``proj`` is an ``nn.Linear`` mapping ``d_model`` input features to
    ``vocab`` output features.
    """
    def __init__(self, d_model, vocab):
        super(Generator, self).__init__()
        self.proj = nn.Linear(d_model, vocab)

    def forward(self, x):
        "Project ``x`` and return log-probabilities over the last dimension."
        return F.log_softmax(self.proj(x), dim=-1)

In [8]:
def clones(module, N):
    """Produce N identical layers.

    Each entry is an independent ``copy.deepcopy`` of ``module`` (so the
    copies do not share parameters), collected into an ``nn.ModuleList``.
    """
    return nn.ModuleList(copy.deepcopy(module) for _ in range(N))

In [9]:
class LayerNorm(nn.Module):
    """Construct a layernorm module.

    Normalizes the input over its last dimension and then applies a
    learnable element-wise scale and shift.

    * ``a_2`` -- learnable scale (gain), one value per feature, starts at 1.
    * ``b_2`` -- learnable shift (bias), one value per feature, starts at 0.
    * ``eps`` -- small constant added to the standard deviation (not the
      variance) to avoid division by zero.

    ``features`` is the length of ``a_2`` / ``b_2``; for the element-wise
    ops in ``forward`` to broadcast it has to be compatible with the size of
    the last dimension of the input.
    """
    def __init__(self, features, eps=1e-6):
        super(LayerNorm, self).__init__()
        self.a_2 = nn.Parameter(torch.ones(features))
        self.b_2 = nn.Parameter(torch.zeros(features))
        self.eps = eps

    def forward(self, x):
        "Return ``a_2 * (x - mean) / (std + eps) + b_2`` over the last dim."
        centered = x - x.mean(-1, keepdim=True)
        # Tensor.std is called with its default estimator settings.
        denom = x.std(-1, keepdim=True) + self.eps
        # Scale first, then divide: same operation order as the formula above.
        scaled = self.a_2 * centered
        return scaled / denom + self.b_2

In [10]:
class Encoder(nn.Module):
    """Core encoder is a stack of N layers.

    ``layer`` is replicated ``N`` times via ``clones``; a final ``LayerNorm``
    sized by ``layer.size`` is applied to the output of the last layer, so
    ``layer`` must expose a ``size`` attribute.
    """
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask):
        "Pass the input (and mask) through each layer in turn, then normalize."
        hidden = x
        for sublayer in self.layers:
            # Each layer receives the previous layer's output and the same mask.
            hidden = sublayer(hidden, mask)
        return self.norm(hidden)