# Transformer

In [10]:
import math 
import pandas as pd
import torch
from torch import nn 
from d2l import torch as d2l
from IPython.display import Image

Below is the Transformer architecture from the paper "Attention Is All You Need" (Vaswani et al., 2017).

In [11]:
# Show the Transformer architecture diagram; the bare expression is the cell's output.
Image(
    url="Transformer_architecture.png",
    width=800,
    height=800,
)

#### Positionwise Feed-Forward Networks

In [12]:
class PositionWiseFFN(nn.Module):
	"""Two-layer feed-forward block: ``Linear -> ReLU -> Linear``.

	Both linear layers act on the last dimension of the input, so all
	leading dimensions pass through unchanged.

	Args:
		ffn_num_input: number of input features of the first linear layer.
		ffn_num_hiddens: number of features between the two linear layers.
		ffn_num_outputs: number of output features of the second linear layer.
		**kwargs: forwarded to ``nn.Module.__init__``.
	"""

	def __init__(self, ffn_num_input, ffn_num_hiddens, ffn_num_outputs, **kwargs):
		super().__init__(**kwargs)
		self.dense1 = nn.Linear(ffn_num_input, ffn_num_hiddens)
		self.relu = nn.ReLU()
		self.dense2 = nn.Linear(ffn_num_hiddens, ffn_num_outputs)

	def forward(self, X):
		"""Return ``dense2(relu(dense1(X)))``."""
		hidden = self.dense1(X)
		activated = self.relu(hidden)
		return self.dense2(activated)

In [13]:
# Small demo network: 4 input features -> 4 hidden features -> 8 output features.
ffn = PositionWiseFFN(ffn_num_input=4, ffn_num_hiddens=4, ffn_num_outputs=8)
# Switch to evaluation mode; eval() returns the module, so its repr is the cell output.
ffn.eval()

PositionWiseFFN(
  (dense1): Linear(in_features=4, out_features=4, bias=True)
  (relu): ReLU()
  (dense2): Linear(in_features=4, out_features=8, bias=True)
)

In [14]:
# Shape check with input laid out as (batch_size, sequence_positions, hidden_dimensions)

# Take the first sample of the batch, then apply the feed-forward network
print(ffn(torch.ones(2, 3, 4)[0]).size())

# Apply the feed-forward network to the whole batch, then take the first sample
print(ffn(torch.ones(2, 3, 4))[0].size())

torch.Size([3, 8])
torch.Size([3, 8])
