# Decoder (GPT-style model)

In [1]:
from math import sqrt
import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoTokenizer

In [2]:
# Hub checkpoint whose config and tokenizer drive the from-scratch model below.
model_id = "openai-community/gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
gpt2_config = AutoConfig.from_pretrained(model_id)
# Display the loaded hyper-parameters (last expression of the cell).
gpt2_config

GPT2Config {
  "_name_or_path": "openai-community/gpt2",
  "activation_function": "gelu_new",
  "architectures": [
    "GPT2LMHeadModel"
  ],
  "attn_pdrop": 0.1,
  "bos_token_id": 50256,
  "embd_pdrop": 0.1,
  "eos_token_id": 50256,
  "initializer_range": 0.02,
  "layer_norm_epsilon": 1e-05,
  "model_type": "gpt2",
  "n_ctx": 1024,
  "n_embd": 768,
  "n_head": 12,
  "n_inner": null,
  "n_layer": 12,
  "n_positions": 1024,
  "reorder_and_upcast_attn": false,
  "resid_pdrop": 0.1,
  "scale_attn_by_inverse_layer_idx": false,
  "scale_attn_weights": true,
  "summary_activation": null,
  "summary_first_dropout": 0.1,
  "summary_proj_to_labels": true,
  "summary_type": "cls_index",
  "summary_use_proj": true,
  "task_specific_params": {
    "text-generation": {
      "do_sample": true,
      "max_length": 50
    }
  },
  "transformers_version": "4.38.2",
  "use_cache": true,
  "vocab_size": 50257
}

## Tokenization

In [3]:
# Sample sentence reused by the cells below.
input_text = "Hello Transformers!"

# Sub-word tokens; a leading space shows up as the "Ġ" prefix on a token.
tokenizer.tokenize(input_text, add_special_tokens=True)

['Hello', 'ĠTransformers', '!']

In [4]:
# Integer ids for the sentence, as a (batch=1, seq_len) PyTorch tensor.
tokenized_input = tokenizer(input_text, return_tensors="pt")["input_ids"]
tokenized_input

tensor([[15496, 39185,     0]])

In [5]:
# Decode every id of the first (only) sequence back to its text piece.
for token_id in tokenized_input[0].tolist():
    print(f"{token_id} -> {tokenizer.decode(token_id)}")

15496 -> Hello
39185 ->  Transformers
0 -> !


In [6]:
# Trainable lookup table: one n_embd-sized row per vocabulary entry.
token_embed = nn.Embedding(
    num_embeddings=gpt2_config.vocab_size,
    embedding_dim=gpt2_config.n_embd,
)

print(
    f"Vocabulary size: {gpt2_config.vocab_size}",
    f"Hidden/Embedding size: {gpt2_config.n_embd}",
    f"Embedding matrix: {token_embed}",
    sep="\n",
)

Vocabulary size: 50257
Hidden/Embedding size: 768
Embedding matrix: Embedding(50257, 768)


In [7]:
# Embed the ids: (batch, seq_len) -> (batch, seq_len, embedding_dim).
token_embeddings = token_embed(tokenized_input)
print(
    f"Tokenized shape: {tokenized_input.shape}",
    f"Embedded shape: {token_embeddings.shape}",
    sep="\n",
)

Tokenized shape: torch.Size([1, 3])
Embedded shape: torch.Size([1, 3, 768])


## Scaled dot-product attention (self-attention)

In [8]:
def scaled_dot_product_attention(key, query, value, mask=None):
    """Compute (optionally masked) scaled dot-product attention.

    Implements ``softmax(Q K^T / sqrt(d_k)) V``. Positions where ``mask == 0``
    get zero attention weight.

    Args:
        key: tensor of shape (..., seq_k, d_k).
        query: tensor of shape (..., seq_q, d_k).
        value: tensor of shape (..., seq_k, d_v).
        mask: optional tensor broadcastable to (..., seq_q, seq_k). Entries
            equal to 0 are blocked. ``None`` disables masking.

    Returns:
        Tensor of shape (..., seq_q, d_v) holding the attention-weighted values.
    """
    dim_k = key.size(-1)
    # Rows index query positions and columns index key positions. A
    # lower-triangular (causal) mask therefore lets query i see keys 0..i only.
    # matmul broadcasts over any leading batch/head dimensions.
    scores = torch.matmul(query, key.transpose(-2, -1)) / sqrt(dim_k)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float("-inf"))
    # Normalise over the key axis so each query's weights sum to 1.
    weights = F.softmax(scores, dim=-1)
    return torch.matmul(weights, value)

In [9]:
# Causal mask: lower-triangular ones plus a leading batch axis of size 1.
seq_len = token_embeddings.shape[1]
mask = torch.ones(seq_len, seq_len).tril().unsqueeze(0)

# Plain self-attention: the same tensor is used as key, query and value.
outputs = scaled_dot_product_attention(
    token_embeddings, token_embeddings, token_embeddings, mask
)
outputs.shape

torch.Size([1, 3, 768])

## Multi-headed attention

In [10]:
class AttentionHead(nn.Module):
    """A single causal self-attention head.

    Projects the inputs into key/query/value spaces of width ``head_dim`` and
    applies scaled dot-product attention under a lower-triangular mask, so each
    position attends only to itself and earlier positions.
    """

    def __init__(self, embed_dim, head_dim):
        super(AttentionHead, self).__init__()
        self.key = nn.Linear(embed_dim, head_dim)
        self.query = nn.Linear(embed_dim, head_dim)
        self.value = nn.Linear(embed_dim, head_dim)

    def forward(self, inputs):
        """Attend over ``inputs``, expected as (batch, seq_len, embed_dim).

        Returns a (batch, seq_len, head_dim) tensor.
        """
        seq_len = inputs.size(1)
        # Build the mask on the inputs' device to avoid a CPU/accelerator
        # device mismatch inside masked_fill.
        mask = torch.tril(
            torch.ones(seq_len, seq_len, device=inputs.device)
        ).unsqueeze(0)  # (1, seq_len, seq_len), broadcast over the batch
        return scaled_dot_product_attention(
            self.key(inputs), self.query(inputs), self.value(inputs), mask
        )

In [11]:
class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention built from independent AttentionHeads.

    Each of the ``config.n_head`` heads works in a subspace of width
    ``config.n_embd // config.n_head``. The head outputs are concatenated back
    to ``n_embd`` features and mixed by a final linear projection.

    Raises:
        ValueError: if ``n_embd`` is not divisible by ``n_head``.
    """

    def __init__(self, config):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = config.n_head
        self.embed_dim = config.n_embd
        # The concatenated heads must add back up to exactly embed_dim features.
        # Otherwise output_linear gets a narrower tensor and fails with an
        # unhelpful shape error at forward time.
        if self.embed_dim % self.num_heads != 0:
            raise ValueError(
                f"n_embd ({self.embed_dim}) must be divisible by n_head ({self.num_heads})"
            )
        self.head_dim = self.embed_dim // self.num_heads
        self.heads = nn.ModuleList(
            [AttentionHead(self.embed_dim, self.head_dim) for _ in range(self.num_heads)]
        )
        self.output_linear = nn.Linear(self.embed_dim, self.embed_dim)

    def forward(self, inputs):
        """Return a tensor with the same shape as ``inputs`` (batch, seq_len, n_embd)."""
        x = torch.cat([head(inputs) for head in self.heads], dim=-1)
        return self.output_linear(x)

In [12]:
# Run the full multi-head module on the token embeddings; width is kept at n_embd.
multi_head_attn = MultiHeadAttention(gpt2_config)
outputs = multi_head_attn(token_embeddings)
outputs.size()

torch.Size([1, 3, 768])

## Position-wise feed-forward

In [13]:
class PositionwiseFeedForward(nn.Module):
    """Two-layer MLP applied independently at every sequence position.

    Hyper-parameters come from the config rather than being hard-coded:

    - hidden width ``n_inner``, falling back to ``4 * n_embd`` when unset/None;
    - ``activation_function`` ("gelu_new" selects the tanh approximation of
      GELU; anything else uses exact GELU);
    - dropout probability ``resid_pdrop``, defaulting to 0.1 when absent.
    """

    def __init__(self, config):
        super(PositionwiseFeedForward, self).__init__()
        n_inner = getattr(config, "n_inner", None)
        inner_dim = n_inner if n_inner is not None else 4 * config.n_embd
        self.linear_1 = nn.Linear(config.n_embd, inner_dim)
        self.linear_2 = nn.Linear(inner_dim, config.n_embd)
        # Hugging Face's "gelu_new" is the tanh-approximated GELU.
        activation = getattr(config, "activation_function", "gelu")
        self.gelu = nn.GELU(approximate="tanh" if activation == "gelu_new" else "none")
        self.dropout = nn.Dropout(getattr(config, "resid_pdrop", 0.1))

    def forward(self, inputs):
        """Map (..., n_embd) -> (..., n_embd) through expand, GELU, project, dropout."""
        x = self.linear_1(inputs)
        x = self.gelu(x)
        x = self.linear_2(x)
        return self.dropout(x)

In [14]:
# Apply the feed-forward block directly to the token embeddings.
ff = PositionwiseFeedForward(gpt2_config)
outputs = ff(token_embeddings)
outputs.size()

torch.Size([1, 3, 768])

## Decoder layer

In [15]:
class DecoderLayer(nn.Module):
    """One pre-LayerNorm transformer decoder block.

    inputs -+-> norm_1 -> attention -- + -- +-> norm_2 -> feed-forward -- + --> outputs
            +---- skip-connection -----+    +------ skip-connection ------+
    """

    def __init__(self, config):
        super(DecoderLayer, self).__init__()
        # Take epsilon from the config (as the Embeddings module does); fall
        # back to PyTorch's default for configs that do not define it.
        eps = getattr(config, "layer_norm_epsilon", 1e-5)
        self.layer_norm_1 = nn.LayerNorm(config.n_embd, eps=eps)
        self.layer_norm_2 = nn.LayerNorm(config.n_embd, eps=eps)
        self.attention = MultiHeadAttention(config)
        self.ff = PositionwiseFeedForward(config)

    def forward(self, inputs):
        """Apply attention and feed-forward sub-blocks, each with a residual.

        Each sub-block sees a normalised copy of its input, while the skip
        connection carries the un-normalised tensor. The output has the same
        shape as ``inputs``.
        """
        # block 1: causal self-attention
        x = inputs + self.attention(self.layer_norm_1(inputs))

        # block 2: position-wise feed-forward
        x = x + self.ff(self.layer_norm_2(x))
        return x

In [16]:
# One complete decoder block applied to the token embeddings.
decoder_layer = DecoderLayer(gpt2_config)
outputs = decoder_layer(token_embeddings)
outputs.size()

torch.Size([1, 3, 768])

## Embeddings (token + learned positional)

In [17]:
class Embeddings(nn.Module):
    """Token embeddings plus learned absolute position embeddings.

    The sum is layer-normalised and passed through dropout with probability
    ``config.embd_pdrop`` (0.1 when the config does not define it).
    """

    def __init__(self, config):
        super(Embeddings, self).__init__()
        self.token_embed = nn.Embedding(config.vocab_size, config.n_embd)
        self.position_embed = nn.Embedding(config.n_positions, config.n_embd)
        # NOTE(review): the reference GPT-2 has no LayerNorm at this point; it is
        # kept here to preserve this model's behaviour.
        self.layer_norm = nn.LayerNorm(config.n_embd, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(getattr(config, "embd_pdrop", 0.1))

    def forward(self, inputs):
        """Embed token ids of shape (batch, seq_len) into (batch, seq_len, n_embd).

        Raises:
            IndexError: if ``seq_len`` exceeds the number of learned positions.
        """
        seq_len = inputs.size(1)
        max_positions = self.position_embed.num_embeddings
        if seq_len > max_positions:
            raise IndexError(
                f"Sequence length {seq_len} exceeds the maximum of {max_positions} positions"
            )
        # Create position ids on the token ids' device so the lookup still works
        # when the model runs on an accelerator.
        position_ids = torch.arange(
            seq_len, dtype=torch.long, device=inputs.device
        ).unsqueeze(0)  # (1, seq_len), broadcast over the batch

        token_embeddings = self.token_embed(inputs)  # (batch, seq_len, n_embd)
        position_embeddings = self.position_embed(position_ids)  # (1, seq_len, n_embd)
        embeddings = self.layer_norm(token_embeddings + position_embeddings)
        return self.dropout(embeddings)

In [18]:
# Token + position embeddings computed straight from the token ids.
embed = Embeddings(gpt2_config)
outputs = embed(tokenized_input)
outputs.size()

torch.Size([1, 3, 768])

## Transformer decoder

In [19]:
class Decoder(nn.Module):
    """GPT-style decoder-only language model.

    Embeds token ids, runs them through ``config.n_layer`` pre-LayerNorm
    decoder layers, normalises the final hidden states and projects them to
    vocabulary logits.
    """

    def __init__(self, config):
        super(Decoder, self).__init__()
        self.embedding = Embeddings(config)
        self.decoder_layers = nn.ModuleList([DecoderLayer(config) for _ in range(config.n_layer)])
        # Pre-LN layers leave the residual stream un-normalised. As in GPT-2, a
        # final LayerNorm is applied before the output projection.
        self.final_layer_norm = nn.LayerNorm(
            config.n_embd, eps=getattr(config, "layer_norm_epsilon", 1e-5)
        )
        self.output_linear = nn.Linear(config.n_embd, config.vocab_size)

    def forward(self, inputs):
        """Map token ids (batch, seq_len) to logits (batch, seq_len, vocab_size)."""
        x = self.embedding(inputs)
        for layer in self.decoder_layers:
            x = layer(x)
        return self.output_linear(self.final_layer_norm(x))

In [20]:
# Full model: token ids in, one logit per vocabulary entry out at every position.
decoder = Decoder(gpt2_config)
outputs = decoder(tokenized_input)
outputs.size()

torch.Size([1, 3, 50257])