# Encoder (BERT type model)

**Source**: [Natural Language Processing with Transformers by Lewis Tunstall, Leandro von Werra, Thomas Wolf](https://transformersbook.com/)

In [1]:
from math import sqrt
import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoTokenizer

In [2]:
# Checkpoint whose configuration and tokenizer are reused by the cells below.
model_id = "bert-base-uncased"
bert_config = AutoConfig.from_pretrained(model_id)
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Trailing expression: display the loaded configuration as the cell output.
bert_config

BertConfig {
  "_name_or_path": "bert-base-uncased",
  "architectures": [
    "BertForMaskedLM"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "gradient_checkpointing": false,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "transformers_version": "4.38.2",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

## Tokenization

In [3]:
# Example sentence used as the running input for the rest of the notebook.
input_text = "Hello Transformers!"

# Tokenized input, with the tokenizer's special tokens included
tokenizer.tokenize(input_text, add_special_tokens=True)

['[CLS]', 'hello', 'transformers', '!', '[SEP]']

In [4]:
# Token ids, returned as a PyTorch tensor (return_tensors="pt")
tokenized_input = tokenizer(input_text, return_tensors="pt").input_ids
# Trailing expression: display the id tensor as the cell output.
tokenized_input

tensor([[  101,  7592, 19081,   999,   102]])

In [5]:
# Map every token id of the first (and only) sequence back to its token
for i in tokenized_input[0]:
    print(f"{i} -> {tokenizer.decode(i)}")

101 -> [CLS]
7592 -> hello
19081 -> transformers
999 -> !
102 -> [SEP]


## Token embeddings

In [6]:
# Embedding table sized from the config: one hidden_size-dimensional vector per
# vocabulary entry. It is freshly constructed here, not loaded from the checkpoint.
token_embed = nn.Embedding(bert_config.vocab_size, bert_config.hidden_size)

print(f"Vocabulary size: {bert_config.vocab_size}")
print(f"Hidden/Embedding size: {bert_config.hidden_size}")
print(f"Embedding matrix: {token_embed}")

Vocabulary size: 30522
Hidden/Embedding size: 768
Embedding matrix: Embedding(30522, 768)


In [7]:
# Embed the input token ids: each id is replaced by its row of the embedding table
token_embeddings = token_embed(tokenized_input)
print(f"Tokenized shape: {tokenized_input.shape}")
print(f"Embedded shape: {token_embeddings.shape}") # batch_size, input_length, embedding_dim

Tokenized shape: torch.Size([1, 5])
Embedded shape: torch.Size([1, 5, 768])


## Scaled dot-product attention (self-attention)

In [8]:
def scaled_dot_product_attention(query, key, value):
    """Compute scaled dot-product attention.

    Every query position attends over all key positions: the scores
    ``query @ key^T / sqrt(dim_k)`` are normalised with a softmax over the key
    axis and then used to form a weighted average of ``value``.

    Args:
        query: Tensor of shape (batch_size, query_length, dim_k).
        key: Tensor of shape (batch_size, key_length, dim_k).
        value: Tensor of shape (batch_size, key_length, dim_v).

    Returns:
        Tensor of shape (batch_size, query_length, dim_v).
    """
    # Scaling factor: dividing by sqrt(dim_k) keeps the magnitude of the
    # softmax inputs independent of the projection size.
    dim_k = key.size(-1)
    # Rows index queries and columns index keys, so the softmax over the last
    # axis below produces one distribution over the keys for each query.
    scores = torch.bmm(query, key.transpose(1, 2)) / sqrt(dim_k)
    # Attention weights: each row sums to 1.
    weights = F.softmax(scores, dim=-1)
    # Weighted average of the value vectors for each query position.
    return torch.bmm(weights, value)

In [9]:
# Self-attention: the same token embeddings are passed as query, key and value.
self_attention = scaled_dot_product_attention(token_embeddings, token_embeddings, token_embeddings)
# Trailing expression: display the output shape.
self_attention.shape

torch.Size([1, 5, 768])

## Multi-headed attention

In [10]:
class AttentionHead(nn.Module):
    """One attention head: query/key/value projections followed by attention."""

    def __init__(self, embed_dim, head_dim):
        """Create the three linear projections from ``embed_dim`` to ``head_dim``."""
        super().__init__()
        self.query = nn.Linear(embed_dim, head_dim)
        self.key = nn.Linear(embed_dim, head_dim)
        self.value = nn.Linear(embed_dim, head_dim)

    def forward(self, inputs):
        """Project ``inputs`` to queries, keys and values and attend over them."""
        projected_q = self.query(inputs)
        projected_k = self.key(inputs)
        projected_v = self.value(inputs)
        return scaled_dot_product_attention(projected_q, projected_k, projected_v)

In [11]:
class MultiHeadAttention(nn.Module):
    """Multi-headed attention built from several independent ``AttentionHead``s.

    Each head projects the input down to ``hidden_size // num_attention_heads``
    features. The head outputs are concatenated along the feature axis and
    mixed by a final linear layer.

    Args:
        config: Object exposing ``hidden_size`` and ``num_attention_heads``.

    Raises:
        ValueError: If ``num_attention_heads`` is not positive or does not
            divide ``hidden_size`` evenly. The concatenated heads would
            otherwise not match the input width of the output projection.
    """

    def __init__(self, config):
        super().__init__()
        self.embed_dim = config.hidden_size
        self.num_heads = config.num_attention_heads

        # Fail at construction time instead of with an opaque shape mismatch
        # inside ``output_linear`` on the first forward pass.
        if self.num_heads <= 0 or self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"hidden_size ({self.embed_dim}) must be a positive multiple of "
                f"num_attention_heads ({self.num_heads})"
            )

        # Split the embedding width across the heads so that different heads
        # can attend to different parts of the representation.
        self.head_dim = self.embed_dim // self.num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(self.embed_dim, self.head_dim) for _ in range(self.num_heads)]
        )

        self.output_linear = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, inputs):
        """Run every head on ``inputs``, concatenate the results and project them."""
        # Concatenating num_heads outputs of width head_dim restores embed_dim.
        x = torch.cat([h(inputs) for h in self.heads], dim=-1)
        return self.output_linear(x)

In [12]:
# Build multi-head attention from the BERT config and run it on the token embeddings.
multi_head_attn = MultiHeadAttention(bert_config)
outputs = multi_head_attn(token_embeddings)
# Trailing expression: display the output shape.
outputs.shape

torch.Size([1, 5, 768])

## Position-wise feed-forward layer

In [13]:
class PositionwiseFeedForward(nn.Module):
    """Feed-forward block: expand to the intermediate width, GELU, project back.

    Args:
        config: Object exposing ``hidden_size``, ``intermediate_size`` and
            ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        model_width = config.hidden_size
        inner_width = config.intermediate_size
        self.linear_1 = nn.Linear(model_width, inner_width)
        self.linear_2 = nn.Linear(inner_width, model_width)
        self.gelu = nn.GELU()
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, inputs):
        """Apply linear -> GELU -> linear -> dropout to ``inputs``."""
        expanded = self.gelu(self.linear_1(inputs))
        return self.dropout(self.linear_2(expanded))

In [14]:
# Build the feed-forward block from the BERT config and run it on the token embeddings.
ff = PositionwiseFeedForward(bert_config)
outputs = ff(token_embeddings)
# Trailing expression: display the output shape.
outputs.shape

torch.Size([1, 5, 768])

## Encoder layer

In [15]:
class EncoderLayer(nn.Module):
    """Single encoder layer: attention and feed-forward sub-blocks with skip connections.

    Layer normalization is applied before each sub-block (pre layer
    normalization) and the sub-block output is added back to its input.

    Args:
        config: Object exposing the fields required by ``MultiHeadAttention``
            and ``PositionwiseFeedForward``, plus ``hidden_size``. When it
            provides ``layer_norm_eps`` that value is used for both
            normalizations; otherwise the PyTorch default of 1e-5 applies.
    """

    def __init__(self, config):
        super().__init__()
        # Honour the configured epsilon instead of silently using the
        # LayerNorm default; fall back to that default for minimal configs.
        eps = getattr(config, "layer_norm_eps", 1e-5)
        self.layer_norm_1 = nn.LayerNorm(config.hidden_size, eps=eps)
        self.layer_norm_2 = nn.LayerNorm(config.hidden_size, eps=eps)
        self.attention = MultiHeadAttention(config)
        self.ff = PositionwiseFeedForward(config)

    def forward(self, inputs):
        """ Pre layer normalization:

        inputs -+-> norm_1 -> attention -- + -- +-> norm_2 -> feed-forward -- + --> outputs
                +---- skip-connection -----+    +------ skip-connection ------+
        """
        # block 1
        norm_1 = self.layer_norm_1(inputs) # normalization
        x = inputs + self.attention(norm_1) # attention + skip-connection
        # block 2
        norm_2 = self.layer_norm_2(x) # normalization
        x = x + self.ff(norm_2) # feed-forward + skip-connection
        return x

In [16]:
# Build one encoder layer from the BERT config and run it on the token embeddings.
encoder_layer = EncoderLayer(bert_config)
outputs = encoder_layer(token_embeddings)
# Trailing expression: display the output shape.
outputs.shape

torch.Size([1, 5, 768])

## Positional encodings

In [17]:
class Embeddings(nn.Module):
    """Token embeddings plus learned position embeddings, normalized and regularized.

    Args:
        config: Object exposing ``vocab_size``, ``hidden_size``,
            ``max_position_embeddings``, ``layer_norm_eps`` and
            ``hidden_dropout_prob``.
    """

    def __init__(self, config):
        super().__init__()
        self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.position_embed = nn.Embedding(config.max_position_embeddings, config.hidden_size)
        self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
        # Use the configured rate; a bare nn.Dropout() would drop 50% of the
        # embedding features during training.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)

    def forward(self, inputs):
        """Embed token ids.

        Args:
            inputs: Integer tensor of token ids, shape (batch_size, seq_len),
                with ``seq_len`` at most ``max_position_embeddings``.

        Returns:
            Tensor of shape (batch_size, seq_len, hidden_size).
        """
        # Position ids 0..seq_len-1, created on the same device as the inputs
        # so the lookup also works when the module lives on an accelerator.
        seq_len = inputs.size(1)
        position_ids = torch.arange(
            seq_len, dtype=torch.long, device=inputs.device
        ).unsqueeze(0)
        # The (1, seq_len, hidden) position embeddings broadcast over the batch.
        token_embeddings = self.token_embed(inputs)
        position_embeddings = self.position_embed(position_ids)
        embeddings = token_embeddings + position_embeddings
        embeddings = self.layer_norm(embeddings)
        return self.dropout(embeddings)

In [18]:
# Build the embedding module from the BERT config and apply it to the token ids.
embed = Embeddings(bert_config)
outputs = embed(tokenized_input)
# Trailing expression: display the output shape.
outputs.shape

torch.Size([1, 5, 768])

## Transformer encoder

In [19]:
class TransformerEncoder(nn.Module):
    """Full encoder: an embedding module followed by a stack of encoder layers.

    Args:
        config: Object exposing ``num_hidden_layers`` plus the fields needed by
            ``Embeddings`` and ``EncoderLayer``.
    """

    def __init__(self, config):
        super().__init__()
        self.embeddings = Embeddings(config)
        depth = config.num_hidden_layers
        self.encoder_layers = nn.ModuleList(EncoderLayer(config) for _ in range(depth))

    def forward(self, inputs):
        """Embed the token ids in ``inputs`` and pass them through every layer in order."""
        hidden_states = self.embeddings(inputs)
        for encoder_layer in self.encoder_layers:
            hidden_states = encoder_layer(hidden_states)
        return hidden_states

In [20]:
# Build the full encoder stack from the BERT config and run it on the token ids.
encoder = TransformerEncoder(bert_config)
outputs = encoder(tokenized_input)
# Trailing expression: display the output shape.
outputs.shape

torch.Size([1, 5, 768])