In [None]:
"""
FeedForward
sequence_length x d_model (4 x 512)

1. Linear  -> sequence_length x d_hidden (4 x 1024)

2. ReLU    -> sequence_length x d_hidden (4 x 1024)

3. Dropout -> sequence_length x d_hidden (4 x 1024)

4. Linear  -> sequence_length x d_model (4 x 512)
"""

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FeedForward(nn.Module):
    """Two-layer feed-forward block: Linear -> ReLU -> Dropout -> Linear.

    The first linear layer maps the last dimension from ``d_model`` to
    ``d_hidden``; the second maps it back to ``d_model``, so the output has
    the same feature size as the input.

    Args:
        d_model: Size of the input and output feature dimension.
        d_hidden: Size of the intermediate feature dimension.
        max_sequence_length: Stored on the instance; not used anywhere
            else in this class.
        dropout: Probability handed to ``nn.Dropout``.
    """

    def __init__(
        self,
        d_model: int = 512,
        d_hidden: int = 1024,
        max_sequence_length: int = 1024,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()

        # Hyperparameters are kept as plain attributes for later inspection.
        self.d_model, self.d_hidden = d_model, d_hidden
        self.max_sequence_length = max_sequence_length  # not used within this class

        # Sub-modules are created in the same order forward() applies them.
        self.linear1 = nn.Linear(d_model, d_hidden)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(d_hidden, d_model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through linear1, ReLU, dropout and linear2, in that order.

        Args:
            x: Input tensor whose last dimension is ``d_model`` (required
                by ``linear1``).

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``d_model``.
        """
        hidden = F.relu(self.linear1(x))
        return self.linear2(self.dropout(hidden))

