前馈神经网络

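由两个线性层和一个激活函数组成，其中激活函数为模型引入非线性：$\mathrm{FFN}(x)=\mathrm{Linear}_2(\mathrm{ReLU}(\mathrm{Linear}_1(x)))$，即 $\mathrm{FFN}(x)=\max(0,\,xW_1+b_1)\,W_2+b_2$。下面的实现还在激活之后和第二个线性层之后各施加一次 Dropout。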

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class FFN(nn.Module):
    """Position-wise feed-forward block: Linear -> ReLU -> Dropout -> Linear -> Dropout.

    The block expands the last dimension from ``embedding_dim`` to ``ff_dim`` and
    projects it back, so the output has the same shape as the input.

    Args:
        embedding_dim: size of the input's last dimension (and of the output's).
        ff_dim: width of the hidden layer between the two linear maps.
        dropout: drop probability used at both dropout sites.
    """

    def __init__(self, embedding_dim, ff_dim, dropout=0.1):
        super().__init__()
        # Expand to the hidden width, then project back. The two layers are created
        # in this order so parameter initialisation draws from the RNG in sequence.
        self.linear1 = nn.Linear(in_features=embedding_dim, out_features=ff_dim)
        self.linear2 = nn.Linear(in_features=ff_dim, out_features=embedding_dim)
        self.relu = nn.ReLU()
        # One parameter-free Dropout module is shared by both dropout sites.
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, x):
        """Apply the block to ``x``, whose last dimension must equal ``embedding_dim``."""
        # Hidden representation: expand, activate, then regularise against overfitting.
        hidden = self.dropout(self.relu(self.linear1(x)))
        # Project back to embedding_dim and regularise the block's output as well.
        return self.dropout(self.linear2(hidden))

# Smoke test: the FFN should map a (batch, seq_len, embedding_dim) tensor back to
# the same shape. The module is left in its default training mode, so dropout is
# active here; only the printed shape is meaningful.
embedding_dim = 768
# Hidden width of the FFN; it is normally larger than embedding_dim (the original
# Transformer uses 4x, e.g. 2048 for a 512-dim model).
ff_dim = 1024
# Dummy batch of one sequence of 196 tokens. The feature size is tied to
# embedding_dim so the input and the model cannot drift apart. (196 = 14 x 14
# presumably corresponds to ViT-style image patches — confirm with the notebook.)
# The input is still drawn before the model is built, so the RNG order is unchanged.
dummy_input = torch.randn(1, 196, embedding_dim)
ffn = FFN(embedding_dim=embedding_dim, ff_dim=ff_dim)
output = ffn(dummy_input)
print(output.shape)

torch.Size([1, 196, 768])
