In [1]:
%%capture
# Colab-only
!pip install -U einops transformers datasets torch

# Workflow:
$$
\text{Text} \xrightarrow{\text{Tokenize}} \text{Token IDs} \xrightarrow{\text{Embed}} \text{Embeddings} \xrightarrow{\text{Multi-Head Attention}} \text{Attention} \xrightarrow{\text{Feed Forward}} \text{Output}
$$

First, I need to get a tokenizer to tokenize the text.

The researchers used byte-pair encoding to tokenize the inputs that seems to come from this repo: https://github.com/google/seq2seq

However, I will cheat a bit by using a more recent one: BERT `AutoTokenizer`, offered by HuggingFace.

# Encoder Attention

In [2]:
import torch
import torch.nn as nn
from torch.nn import functional as F
from transformers import AutoTokenizer, AutoConfig

In [3]:
config = AutoConfig.from_pretrained("bert-base-uncased")

In [4]:
class Embeddings(nn.Module):
    def __init__(self, config) -> None:
        super().__init__()
        self.tok_embedder = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_embedder = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.dropout = nn.Dropout()

    def forward(self, input_ids):
        seq_length = input_ids.size(1)
        position_ids = torch.arange(seq_length, dtype=torch.long).unsqueeze(0)

        tok_emb = self.tok_embedder(input_ids)
        pos_emb = self.pos_embedder(position_ids)
        return self.dropout(self.ln(tok_emb + pos_emb))

In [5]:
input_embedding = Embeddings(config)
text = "Time flies like an arrow."
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [6]:
tokens = tokenizer(text, return_tensors='pt', add_special_tokens=False)
tokens.input_ids.size()

torch.Size([1, 6])

In [7]:
tokens.input_ids

tensor([[ 2051, 10029,  2066,  2019,  8612,  1012]])

In [8]:
tok_embs = input_embedding(tokens.input_ids)
tok_embs.size(), tok_embs # (batch_size, seq_length, embedding)

(torch.Size([1, 6, 768]),
 tensor([[[-0.0000,  0.0469,  0.0000,  ...,  2.5158, -0.0000, -0.0224],
          [-0.0000, -0.0000,  3.3932,  ...,  0.0000,  0.0000, -1.9809],
          [ 0.8667, -5.4253,  2.4648,  ...,  1.2762, -2.3111,  0.0000],
          [-0.0000, -0.0000, -0.0000,  ..., -3.1015,  0.0000,  0.0000],
          [ 0.0000, -0.0000, -0.0000,  ..., -0.1973, -0.0000,  0.0000],
          [-2.2891,  0.0000, -0.0000,  ...,  0.0000, -0.0000,  1.6231]]],
        grad_fn=<MulBackward0>))

In [9]:
config

BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.30.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

In [10]:
class EncoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.mha = nn.MultiheadAttention(embed_dim = config.hidden_size,
                                            num_heads = config.num_attention_heads,
                                            dropout = config.attention_probs_dropout_prob,
                                            batch_first = True)
        self.mlp = nn.Sequential(
            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.intermediate_size, config.hidden_size),
            nn.Dropout(config.hidden_dropout_prob),
        )

    def forward(self, x):
        x = self.ln(x)
        x = x + self.mha(x, x, x, need_weights=False)[0]
        x = x + self.mlp(x)
        return x

In [11]:
torch.manual_seed(13)
embedder = Embeddings(config)
encoder_layer = EncoderLayer(config)
input_embs = embedder(tokens.input_ids)
attn_output = encoder_layer(input_embs)

In [12]:
attn_output.size(), attn_output

(torch.Size([1, 6, 768]),
 tensor([[[-0.3095, -0.1367,  0.8761,  ...,  0.6114, -0.1188,  0.5869],
          [-2.4964,  0.2702,  0.2411,  ...,  0.2163, -2.1636,  0.3740],
          [-1.8449, -0.1501,  0.1115,  ..., -0.1459, -1.0278,  0.1750],
          [-0.4407, -0.2148, -1.1614,  ..., -0.2032,  2.0405, -0.0879],
          [ 0.3524, -0.2698,  0.1365,  ..., -1.8543,  2.8713,  0.3415],
          [ 0.0638,  0.1646,  0.3128,  ..., -0.5256, -0.1556,  0.1668]]],
        grad_fn=<AddBackward0>))

In [13]:
# torch built-in
torch.manual_seed(13)
embedder = Embeddings(config)
encoder_layer = nn.TransformerEncoderLayer(d_model=768, # Hidden size - d_model = 512 in the original paper
                                            nhead=12, # Number of heads
                                            dim_feedforward=3072, # MLP size - 2048 in the original paper
                                            dropout=0.1, # Amount of dropout for dense layers
                                            activation="gelu", # GELU non-linear activation - ReLU in the original paper
                                            batch_first=True, # Do our batches come first?
                                            norm_first=True) # Normalize first or after? False in the original paper
input_embs = embedder(tokens.input_ids)
attn_output = encoder_layer(input_embs)
attn_output.size(), attn_output

(torch.Size([1, 6, 768]),
 tensor([[[-0.2543, -0.0695,  1.3249,  ...,  1.0119, -0.2667,  0.6509],
          [-3.6748,  0.2192,  0.2183,  ...,  0.3020, -3.0234,  0.3456],
          [-2.6925,  0.2037, -0.0856,  ..., -0.3380, -1.5309,  0.2545],
          [-0.9036, -0.4099, -1.5102,  ..., -0.1269,  2.5892, -0.1954],
          [ 0.0365, -0.2197, -0.0721,  ..., -2.3647,  4.1379,  0.4813],
          [ 0.0783,  0.1790,  0.3027,  ..., -0.3547, -0.3315,  0.2470]]],
        grad_fn=<AddBackward0>))

In [14]:
# torch built-in
encoder_block = nn.TransformerEncoder(encoder_layer, num_layers=12, norm=nn.GELU())
attn_out = encoder_block(input_embs)
attn_out.size(), attn_out

(torch.Size([1, 6, 768]),
 tensor([[[ 1.1389e+00, -7.1901e-02,  1.4478e+00,  ..., -3.5197e-02,
           -1.0566e-06,  2.2843e+00],
          [ 0.0000e+00, -8.8637e-02,  2.4241e+00,  ..., -1.4967e-01,
            0.0000e+00, -1.6726e-01],
          [ 0.0000e+00, -3.0400e-02,  2.2148e+00,  ..., -6.8196e-02,
            0.0000e+00, -5.5695e-02],
          [-4.2814e-04,  9.0434e-01, -1.4771e-01,  ..., -1.6984e-01,
            4.8173e-01, -1.4140e-01],
          [ 7.4693e-01, -5.3238e-04, -1.1459e-01,  ...,  0.0000e+00,
            1.9923e-01, -1.6861e-01],
          [ 7.2464e-01,  4.9042e-01,  3.3780e+00,  ..., -9.0422e-02,
           -3.4072e-02,  4.0730e-01]]], grad_fn=<GeluBackward0>))

In [15]:
attn_out.mean(dim=-1), attn_out.std(dim=-1)

(tensor([[1.5086, 1.5549, 1.4868, 1.4596, 1.4525, 1.4986]],
        grad_fn=<MeanBackward1>),
 tensor([[2.3516, 2.4070, 2.3345, 2.2248, 2.2195, 2.3881]],
        grad_fn=<StdBackward0>))

# Cross-attention

In [16]:
transformer = nn.Transformer()
transformer

Transformer(
  (encoder): TransformerEncoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True)
        )
        (linear1): Linear(in_features=512, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=512, bias=True)
        (norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
    )
    (norm): LayerNorm((512,), eps=1e-05, elementwise_affine=True)
  )
  (decoder): TransformerDecoder(
    (layers): ModuleList(
      (0-5): 6 x TransformerDecoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=512, o

The output of the whole encoder stack is fed into the whole decoder stack. It is not the output of each encoder layer is fed into each decoder layer as I have worried about.

In [17]:
out_embedder = Embeddings(config)

In [18]:
out_text = "Die Zeit fliegt wie ein Pfeil."
out_tokens = tokenizer(out_text, return_tensors='pt', add_special_tokens=False)
out_tokens.input_ids.size()

torch.Size([1, 13])

In [21]:
out_embs = out_embedder(out_tokens.input_ids)
out_embs.size(), out_embs

(torch.Size([1, 13, 768]),
 tensor([[[-1.0874,  0.0000, -0.0000,  ..., -0.0000, -0.0000,  2.4156],
          [-0.4436, -0.0000,  0.0000,  ..., -0.8739, -1.7109, -3.8949],
          [-1.0612,  0.0000, -0.2693,  ...,  1.2509, -0.0000,  0.0000],
          ...,
          [ 0.9750,  2.0410,  0.0000,  ..., -0.0000,  0.0000, -0.2905],
          [ 0.0000,  2.2961,  1.8994,  ..., -0.0000, -0.0000,  0.0000],
          [ 0.0000,  2.9155, -3.7012,  ...,  0.0000, -0.4442, -0.0000]]],
        grad_fn=<MulBackward0>))

In [33]:
class DecoderLayer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.ln_1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.ln_2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        self.masked_mha = nn.MultiheadAttention(embed_dim = config.hidden_size,
                                                num_heads = config.num_attention_heads,
                                                dropout = config.attention_probs_dropout_prob,
                                                batch_first = True)
        self.cross_mha = nn.MultiheadAttention(embed_dim = config.hidden_size,
                                                num_heads = config.num_attention_heads,
                                                dropout = config.attention_probs_dropout_prob,
                                                batch_first = True)
        self.mlp = nn.Sequential(
            nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps),
            nn.Linear(config.hidden_size, config.intermediate_size),
            nn.GELU(),
            nn.Dropout(config.hidden_dropout_prob),
            nn.Linear(config.intermediate_size, config.hidden_size),
            nn.Dropout(config.hidden_dropout_prob),
        )

    def forward(self, x, targets):
        seq_len = x.size(1)
        tril = torch.tril(torch.ones(seq_len, seq_len))
        x = self.ln_1(x)
        x = x + self.masked_mha(x, x, x, need_weights=False, attn_mask=tril)[0]
        x = self.ln_2(x)
        x = x + self.cross_mha(x, targets, targets, need_weights=False)[0]
        x = x + self.mlp(x)
        return x

In [34]:
decoder_layer = DecoderLayer(config)

In [36]:
fin_attn = decoder_layer(out_embs, attn_out)
fin_attn.shape, fin_attn

(torch.Size([1, 13, 768]),
 tensor([[[ 0.4823,  0.9333,  2.7316,  ..., -0.1462,  1.2293,  3.0945],
          [ 1.4167,  0.3474,  3.0290,  ..., -0.6089,  0.2863, -1.6745],
          [ 0.9513,  0.3664,  2.4151,  ...,  1.0125,  1.3159,  1.2196],
          ...,
          [ 1.9623,  1.9298,  2.6363,  ..., -0.0759,  0.8840,  0.6493],
          [ 1.7121,  2.1650,  3.5570,  ...,  0.0386,  1.5015,  1.3978],
          [ 2.0135,  2.2830, -0.3611,  ..., -0.3804,  1.2136,  0.7575]]],
        grad_fn=<AddBackward0>))

In [37]:
decoder_layer = nn.TransformerDecoderLayer(d_model=768,
                                            nhead=12,
                                            dim_feedforward=3072,
                                            dropout=0.1,
                                            activation="gelu",
                                            batch_first=True,
                                            norm_first=True)

In [39]:
nn_attn = decoder_layer(out_embs, attn_out)
nn_attn.size(), nn_attn

(torch.Size([1, 13, 768]),
 tensor([[[-1.6848,  1.3700, -1.1266,  ..., -0.4563, -0.4142,  2.1117],
          [-1.0528,  0.2012, -1.2836,  ..., -0.8361, -1.6717, -3.9338],
          [-1.5154,  0.5556, -1.9454,  ...,  1.2982,  0.2885, -0.6141],
          ...,
          [ 0.6797,  3.1068, -0.0083,  ..., -0.4355,  0.0760, -0.2905],
          [ 0.2414,  3.1351,  0.5836,  ...,  0.0694,  0.9064, -0.2163],
          [-0.5599,  4.0275, -5.1363,  ..., -0.6990, -0.0588, -0.0551]]],
        grad_fn=<AddBackward0>))

In [40]:
transformer = nn.Transformer(d_model=768,
                                nhead=12,
                                dim_feedforward=3072,
                                dropout=0.1,
                                activation="gelu",
                                batch_first=True,
                                norm_first=True)
trans_attn = transformer(input_embs, out_embs)
trans_attn.size(), trans_attn

(torch.Size([1, 13, 768]),
 tensor([[[-1.1917,  1.1229, -0.5712,  ...,  0.1839, -0.6941,  1.1376],
          [-0.7294,  0.2228,  0.0463,  ...,  0.1341, -1.1323, -1.5835],
          [-1.9350,  1.0680, -0.7251,  ...,  0.5380, -0.2834, -0.1388],
          ...,
          [-1.0327,  0.7142,  0.0236,  ..., -0.1222,  0.0857, -0.2588],
          [-1.8987,  1.2154,  1.3066,  ...,  0.3801, -0.1634,  0.2587],
          [-0.1824,  0.6913, -1.1397,  ...,  0.3787, -0.6034, -0.9291]]],
        grad_fn=<NativeLayerNormBackward0>))

Great! Now let's train that model from scratch