In [2]:
import torch
from torch import nn
from torch.nn import functional as F

import sys
sys.path.append("..")

from math import sqrt

<h2 align="center">Position-wise Feed-Forward Networks</h2>

>In addition to attention sub-layers, each of the layers in our encoder and decoder contains a fully connected feed-forward network, which is applied to each position separately and identically. This consists of two linear transformations with a ReLU activation in between.

<div align="center">
    <h3>
        FFN(<i>x</i>) = max(0, <i>xW<sub>1</sub> + b<sub>1</sub></i>)<i>W<sub>2</sub> + b<sub>2</sub></i>
    </h3>
</div>

* While the linear transformations are the same across different positions, they use different parameters from layer to layer.
    * input-size and output-units have dimensionality d_model = 512.
    * first feed forward layer output-units d_ff = 2048.
* Another way of looking at this is as 2 convolutions of kernel size 1.

In [3]:
class PositionWiseFeedForward(nn.Module):
    """Two-layer feed-forward network applied along the last dimension.

    Computes ``w_2(dropout(relu(w_1(x))))``. Both projections are
    ``nn.Linear`` layers, so they act on the last dimension of ``x`` only.
    """

    def __init__(self, d_model: int, d_ff: int, dropout_prob: float = 0.1):
        """Build the two linear projections and the dropout layer.

        Args:
            d_model: Size of the input features and of the output features.
            d_ff: Size of the hidden layer between the two projections.
            dropout_prob: Probability handed to ``nn.Dropout``; it is a
                float, not an integer.
        """
        super().__init__()
        self.w_1 = nn.Linear(in_features=d_model, out_features=d_ff)
        self.w_2 = nn.Linear(in_features=d_ff, out_features=d_model)
        self.dropout = nn.Dropout(p=dropout_prob)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project to ``d_ff``, apply ReLU and dropout, project back to ``d_model``.

        Args:
            x: Tensor whose last dimension has size ``d_model``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``d_model``.
        """
        hidden = F.relu(self.w_1(x))
        # Dropout acts on the hidden activation, before the second projection.
        return self.w_2(self.dropout(hidden))

<h2 align="center">Embeddings and Softmax</h2>

* >We use learned embeddings to convert the input tokens and output tokens to vectors of dimension d_model. We also use the usual learned linear transformation and softmax function to convert the decoder output to predict next-token probabilities. 
* >In our model, we share the same weight matrix between the two embedding layers and the pre-softmax linear transformation. 
* >In the embedding layers, we multiply those weights by sqrt(d_model)

In [4]:
class Embeddings(nn.Module):
    """Learned token embeddings scaled by ``sqrt(d_model)``."""

    def __init__(self, d_model: int, vocab_size: int):
        """Create the embedding lookup table.

        Args:
            d_model: Dimension of each embedding vector.
            vocab_size: Number of rows in the embedding table.
        """
        super().__init__()
        self.look_up = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=d_model
        )
        self.d_model = d_model

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Look up embeddings for ``x`` and scale them by ``sqrt(d_model)``.

        Args:
            x: Integer tensor of token indices accepted by ``nn.Embedding``.

        Returns:
            The embedded tokens multiplied by ``sqrt(d_model)``, with one
            extra trailing dimension of size ``d_model``.
        """
        # The embedding weights are multiplied (not divided) by sqrt(d_model).
        return self.look_up(x) * sqrt(self.d_model)