# GPT-2

We will build a GPT-2 language model that predicts the next token in a sequence of text. This demonstration uses the shakespeare_char dataset, which can be found in the data folder.

GPT-2 [6] is a scaled-up version of the transformer decoder. It uses multiple stacked decoder blocks, a larger embedding size, a larger block size, more heads in each multi-headed attention block, and a larger vocabulary. The original vocabulary has 50,257 tokens, but this example keeps the simple 65-character vocabulary. GPT-2 comes in four sizes (small, medium, large, and XL) to accommodate the varying compute budgets of its users. Due to compute constraints, we will create an unofficial extra-small variant. If compute is not a limitation, any of the GPT-2 variants can be created by using the parameters listed in the **Parameter Selection** section below.

One major change from the transformer decoder notebook is the introduction of dropout layers [7]. Dropout is a regularization technique that randomly zeroes activations during training to help prevent overfitting, which becomes more valuable as the model is scaled up.

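As a rough illustration of what a dropout layer does, the sketch below implements inverted dropout directly in JAX. The function name `dropout` and its arguments are assumptions made for this example; the notebook itself relies on Flax's `nn.Dropout`.

```python
import jax
import jax.numpy as jnp

def dropout(x, rate, key, deterministic=False):
    """Inverted dropout (illustrative sketch, not the Flax implementation).

    Each element is zeroed with probability `rate`; the survivors are
    rescaled by 1 / (1 - rate) so the expected value stays the same.
    """
    if deterministic or rate == 0.0:
        return x  # evaluation mode: pass the input through unchanged
    keep_prob = 1.0 - rate
    mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
    return jnp.where(mask, x / keep_prob, 0.0)

key = jax.random.PRNGKey(0)
x = jnp.ones((4, 8))
y_train = dropout(x, rate=0.1, key=key)                      # some entries zeroed
y_eval = dropout(x, rate=0.1, key=key, deterministic=True)   # identical to x
```

Because the mask is random, the training-mode output changes with the key, while the evaluation-mode output is always the input itself.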

### References:
- [1] [GPT colab notebook](https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing)
- [2] [Video: scaling up the model!](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=5869s)
- [3] [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
- [4] [Deep Residual Learning for Image Recognition](https://arxiv.org/abs/1512.03385)
- [5] [Layer Normalization](https://arxiv.org/abs/1607.06450)
- [6] [Language Models are Unsupervised Multitask Learners](https://d4mucfpksywv.cloudfront.net/better-language-models/language-models.pdf)
- [7] [Improving neural networks by preventing co-adaptation of feature detectors](https://arxiv.org/abs/1207.0580)

In [1]:
import os
import requests
import numpy as np
import jax
from jax import random
import jax.numpy as jnp
import flax.linen as nn
import optax
from jax import value_and_grad

from helper_funcs import get_batch, generate, masked_fill, loss_fn

## Parameter Selection

The parameters used below are a scaled-down version of GPT-2. GPT-2 comes in four sizes: small, medium, large, and XL. The model built here could be considered an extra-small version. Note that the full-sized models may not fit into memory on your device. The specifications of the official sizes are listed below:

### GPT-2 Small
- n_embed: 768
- block_size: 1024
- num_heads: 12
- num_layers: 12
- vocab_size: 50257 (uses Tiktoken vocab)

### GPT-2 Medium
- n_embed: 1024
- block_size: 1024
- num_heads: 16
- num_layers: 24
- vocab_size: 50257 (uses Tiktoken vocab)

### GPT-2 Large
- n_embed: 1280
- block_size: 1024
- num_heads: 20
- num_layers: 36
- vocab_size: 50257 (uses Tiktoken vocab)

### GPT-2 XL
- n_embed: 1600
- block_size: 1024
- num_heads: 25
- num_layers: 48
- vocab_size: 50257 (uses Tiktoken vocab)

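The four specifications above can be collected into a small lookup table. The sketch below is only an illustration: the names `GPT2_CONFIGS`, `head_size`, and `approx_params` are hypothetical, and the parameter estimate counts only the large weight matrices (it ignores biases and layer norm parameters, and counts the token embedding table once, as if the output layer shared it).

```python
# Hypothetical lookup table built from the specifications listed above.
GPT2_CONFIGS = {
    "small":  dict(n_embed=768,  block_size=1024, num_heads=12, num_layers=12, vocab_size=50257),
    "medium": dict(n_embed=1024, block_size=1024, num_heads=16, num_layers=24, vocab_size=50257),
    "large":  dict(n_embed=1280, block_size=1024, num_heads=20, num_layers=36, vocab_size=50257),
    "xl":     dict(n_embed=1600, block_size=1024, num_heads=25, num_layers=48, vocab_size=50257),
}

def head_size(cfg):
    """Each attention head works on n_embed // num_heads features."""
    return cfg["n_embed"] // cfg["num_heads"]

def approx_params(cfg):
    """Rough weight count: 12 * n_embed^2 per block plus the embedding tables.

    Per block: 4 * C^2 for attention (query, key, value, projection)
    and 8 * C^2 for the feed forward network (C -> 4C -> C).
    """
    C = cfg["n_embed"]
    per_block = 12 * C * C
    embeddings = (cfg["vocab_size"] + cfg["block_size"]) * C
    return cfg["num_layers"] * per_block + embeddings

for name, cfg in GPT2_CONFIGS.items():
    print(f"{name:>6}: head_size={head_size(cfg)}, ~{approx_params(cfg) / 1e6:.0f}M weights")
```

Every official size uses a head size of 64, whereas the extra-small model in this notebook uses 32 // 4 = 8.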

In [40]:
n_embed = 32 # Number of embedding dimensions
batch_size = 16 # How many independent sequences will we process in parallel?
block_size = 32 # What is the maximum context length for predictions?
num_heads = 4 # Number of heads in the multi-headed block
num_layers = 6 # Number of transformer decoder blocks
drop_rate = 0.1 # Dropout rate for regularization

rng_key = jax.random.PRNGKey(128)

## Data Preparation

In [33]:
# download the tiny shakespeare dataset if it is not already present
input_file_path = os.path.join('data', 'shakespeare_char', 'input.txt')
if not os.path.exists(input_file_path):
    # make sure the target folder exists before writing the file
    os.makedirs(os.path.dirname(input_file_path), exist_ok=True)
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    with open(input_file_path, 'w') as f:
        f.write(requests.get(data_url).text)

with open(input_file_path, 'r') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

# get all the unique characters that occur in this text
chars = sorted(list(set(data)))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
def encode(s):
    # encoder: take a string, output a list of integers
    return [stoi[c] for c in s]
def decode(ids):
    # decoder: take a sequence of integers (list or array), output a string
    return ''.join(itos[int(i)] for i in np.asarray(ids))

# create the train and validation splits
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# encode both to integers
train_ids = encode(train_data)
val_ids = encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

train_ids = jnp.array(train_ids, dtype=jnp.uint16)
val_ids = jnp.array(val_ids, dtype=jnp.uint16)

length of dataset in characters: 1,115,394
all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65
train has 1,003,854 tokens
val has 111,540 tokens


In [34]:
print(decode(train_ids[:100]))

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You

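The training loop later draws mini-batches with `get_batch` from `helper_funcs`, whose source is not shown here. The sketch below is one plausible way such a function could work, assuming it samples random windows of `block_size` tokens and uses the same windows shifted by one position as targets. The name `get_batch_sketch` is hypothetical.

```python
import jax
import jax.numpy as jnp

def get_batch_sketch(ids, key, batch_size, block_size):
    """Sample `batch_size` random windows of length `block_size` (illustrative).

    x holds the input tokens and y holds the next-token targets,
    i.e. the same windows shifted one position to the right.
    """
    # the last valid start leaves room for the shifted target window
    starts = jax.random.randint(key, (batch_size,), 0, len(ids) - block_size)
    positions = starts[:, None] + jnp.arange(block_size)[None, :]  # (B, T)
    x = ids[positions]
    y = ids[positions + 1]
    return x, y

# Toy data: with ids = 0, 1, 2, ... every target is its input plus one.
ids = jnp.arange(1000, dtype=jnp.uint16)
xb, yb = get_batch_sketch(ids, jax.random.PRNGKey(0), batch_size=4, block_size=8)
```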

## Build the Attention Model

In [14]:
class FeedForward(nn.Module):
    """
    A feed forward multi-layer perceptron network.
    """
    n_embed: int
    drop_rate: float

    @nn.compact
    def __call__(self, x):
        net = nn.Sequential([
            nn.Dense(4 * self.n_embed),
            jax.nn.relu,
            nn.Dense(self.n_embed),
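            # NOTE: deterministic=True turns this layer into a no-op; set it to False
            # (and pass a 'dropout' RNG to apply) to activate dropout during training
            nn.Dropout(rate=self.drop_rate, deterministic=True)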
        ])
        x = net(x)

        return x

In [15]:
class Head(nn.Module):
    """
    A single head of causal self-attention.
    Takes the combined token and position embedding as input
    and projects it to keys and queries. The queries are
    multiplied with the keys to compute the attention
    scores/affinities. Affinities to future positions are
    masked to -inf before the softmax, so they receive zero
    weight and the model can't "cheat". The input is also
    projected to values, which are aggregated by multiplying
    them with the attention weights.
    """
    head_size: int
    drop_rate: float

    @nn.compact
    def __call__(self, x):
        B,T,C = x.shape
        key = nn.Dense(self.head_size, use_bias=False)
        k = key(x) # (B, T, head_size)
        query = nn.Dense(self.head_size, use_bias=False)
        q = query(x) # (B, T, head_size)
        # compute attention scores ("affinities"), scaled by 1/sqrt(head_size)
        weights = q @ k.transpose((0, -1, -2)) * self.head_size**-0.5 # (B, T, head_size) @ (B, head_size, T) ---> (B, T, T)
        # lower-triangular mask: position t may only attend to positions <= t
        tril = jnp.tril(jnp.ones(shape=(T, T), dtype=bool))
        tril = jnp.repeat(tril[None, ...], repeats=B, axis=0)
        weights = masked_fill(tril, weights, -jnp.inf)
        weights = jax.nn.softmax(weights, axis=-1)
        # deterministic=True disables dropout (inference behaviour)
        drop = nn.Dropout(rate=self.drop_rate, deterministic=True)
        weights = drop(weights)
        # perform the weighted aggregation of the values
        value = nn.Dense(self.head_size, use_bias=False)
        v = value(x) # (B, T, head_size)
        out = weights @ v # (B, T, T) @ (B, T, head_size) ---> (B, T, head_size)
        return out

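To see the masking at work without any learned layers, the sketch below applies the same computation to random query, key, and value arrays. `masked_fill_sketch` is an assumed stand-in for the `masked_fill` helper (whose source is not shown); it takes its arguments in the same order as the call in `Head`.

```python
import jax
import jax.numpy as jnp

def masked_fill_sketch(mask, x, fill_value):
    """Keep x where mask is True, otherwise use fill_value (assumed behaviour)."""
    return jnp.where(mask, x, fill_value)

def causal_attention(q, k, v):
    """Scaled dot-product attention with a causal mask; q, k, v: (B, T, head_size)."""
    T, head_size = q.shape[-2], q.shape[-1]
    scores = q @ jnp.swapaxes(k, -1, -2) * head_size ** -0.5  # (B, T, T)
    mask = jnp.tril(jnp.ones((T, T), dtype=bool))             # broadcasts over the batch
    scores = masked_fill_sketch(mask, scores, -jnp.inf)
    weights = jax.nn.softmax(scores, axis=-1)                 # each row sums to 1
    return weights @ v, weights

kq, kk, kv = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(kq, (2, 5, 8))
k = jax.random.normal(kk, (2, 5, 8))
v = jax.random.normal(kv, (2, 5, 8))
out, weights = causal_attention(q, k, v)
```

The first position can only attend to itself, so `out[:, 0]` equals `v[:, 0]`, and every weight above the diagonal is exactly zero.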
In [16]:
class MultiHeadedAttention(nn.Module):
    """
    Runs multiple heads of scaled self-attention
    in parallel, then concatenates the heads' outputs.
    """
    num_heads: int
    head_size: int
    n_embed: int
    drop_rate: float

    @nn.compact
    def __call__(self, x):
        # Create a list of num_heads heads
        heads = [Head(self.head_size, self.drop_rate) for _ in range(self.num_heads)]
        # Provide the same input for each head
        heads_out = [h(x) for h in heads]
        combined_logits = jnp.concatenate(heads_out, axis=-1)
        # Perform a linear projection of the self-attention
        proj = nn.Dense(self.n_embed)
        logits = proj(combined_logits)
        # deterministic=True disables dropout (inference behaviour)
        drop = nn.Dropout(rate=self.drop_rate, deterministic=True)
        logits = drop(logits)
        return logits

In [17]:
class Block(nn.Module):
    """
    Transformer decoder block.
    It combines communication and computation.
    The communication is performed by the 
    multi-headed attention layer.
    Then the computation is performed by 
    the feed forward block.
    Skip connections are used to make the block scalable 
    and layer norm is used to speed up training.
    """
    n_embed: int
    num_heads: int
    drop_rate: float

    @nn.compact
    def __call__(self, x):
        head_size = self.n_embed // self.num_heads
        sa_heads = MultiHeadedAttention(self.num_heads, head_size, self.n_embed, self.drop_rate)
        # Skip connection: add the attention output back onto the input
        x = x + sa_heads(nn.LayerNorm()(x)) # apply multi-headed self-attention (B, T, C)
        ffwd = FeedForward(self.n_embed, self.drop_rate)
        x = x + ffwd(nn.LayerNorm()(x))
        return x

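The block applies layer norm before each sub-layer and adds the result back to its input. A minimal sketch of that pattern is shown below, assuming a layer norm without the learnable scale and bias that `nn.LayerNorm` adds by default. The function names are hypothetical.

```python
import jax
import jax.numpy as jnp

def layer_norm(x, eps=1e-6):
    """Normalise the last axis to zero mean and unit variance (no learned parameters)."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / jnp.sqrt(var + eps)

def pre_norm_residual(x, sublayer):
    """Pre-norm skip connection, as in x = x + sublayer(LayerNorm(x))."""
    return x + sublayer(layer_norm(x))

x = 3.0 + 2.0 * jax.random.normal(jax.random.PRNGKey(0), (2, 4, 32))
normed = layer_norm(x)

# A sub-layer that outputs zeros leaves the input untouched, which is why
# skip connections make it easy for deep stacks to start near the identity.
unchanged = pre_norm_residual(x, lambda h: jnp.zeros_like(h))
```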
In [42]:
class GPT2(nn.Module):
    """
    GPT-2 language model.
    Uses the previous tokens in the sequence to 
    determine the probabilities of the next token.
    Processes the combined position and token embedding
    through multiple transformer decoder blocks, 
    which is then processed through a dense layer to 
    acquire the token logits.
    The logits can then be processed through a softmax
    function to calculate the token probabilities.
    """
    vocab_size: int
    n_embed: int
    block_size: int
    num_heads: int
    num_layers: int
    drop_rate: float
    
    @nn.compact
    def __call__(self, index_seq):
        B, T = index_seq.shape

        # Look up a learned embedding vector for each token index
        token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.n_embed) 
        token_emb = token_embedding_table(index_seq) # (B, T, C)

        position_embedding_table = nn.Embed(num_embeddings=self.block_size, features=self.n_embed) 
        pos_emb = position_embedding_table(jnp.arange(T)) # (T, C)

        x = token_emb + pos_emb # (B, T, C)

        decoder_blocks = [Block(self.n_embed, num_heads=self.num_heads, drop_rate=self.drop_rate) for _ in range(self.num_layers)]
        decoder_blocks.append(nn.LayerNorm())
        blocks = nn.Sequential(
            decoder_blocks
        )
        x = blocks(x)

        lm_head = nn.Dense(self.vocab_size)
        logits = lm_head(x) # (B, T, vocab_size)

        return logits

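The sum `token_emb + pos_emb` relies on broadcasting: the position embeddings have shape (T, C) and are added to every sequence in the batch. The sketch below reproduces this with plain arrays standing in for the two `nn.Embed` tables; the table values are random placeholders.

```python
import jax
import jax.numpy as jnp

vocab_size, block_size, n_embed = 65, 32, 32
B, T = 4, 8  # T must not exceed block_size

k_tok, k_pos, k_idx = jax.random.split(jax.random.PRNGKey(0), 3)
token_table = jax.random.normal(k_tok, (vocab_size, n_embed))  # stand-in for the token nn.Embed
pos_table = jax.random.normal(k_pos, (block_size, n_embed))    # stand-in for the position nn.Embed
index_seq = jax.random.randint(k_idx, (B, T), 0, vocab_size)

token_emb = token_table[index_seq]   # (B, T, C)
pos_emb = pos_table[jnp.arange(T)]   # (T, C)
x = token_emb + pos_emb              # (T, C) broadcasts across the batch -> (B, T, C)
```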
In [43]:
model = GPT2(vocab_size, n_embed, block_size, num_heads, num_layers, drop_rate)
dummy_x = jnp.zeros(shape=(batch_size, block_size), dtype=jnp.uint16)
variables = model.init(rng_key, dummy_x)

In [36]:
out = model.apply(variables, dummy_x)
print(out.shape)

(1, 32, 65)


## Text Generation Before Training

In [37]:
index_seq = jnp.zeros(shape=(1,1), dtype=jnp.uint16)
max_new_tokens = 20

generated_indices = generate(variables, model.apply, index_seq, rng_key, vocab_size, 1, block_size, max_new_tokens)
generated_indices = list(np.array(generated_indices[0]))
print("Generated text: ")
print(decode(generated_indices))

Generated text: 

EiQnlVvj''mtKvgrI w:

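The `generate` helper comes from `helper_funcs` and its source is not shown. The sketch below illustrates one common way to sample autoregressively, assuming the helper crops the context to `block_size`, takes the logits of the last position, and samples from them. The name `generate_sketch`, its argument order, and the `uniform_apply` stand-in are all hypothetical.

```python
import jax
import jax.numpy as jnp

def generate_sketch(apply_fn, variables, index_seq, key, block_size, max_new_tokens):
    """Append `max_new_tokens` sampled tokens to index_seq of shape (B, T)."""
    for _ in range(max_new_tokens):
        context = index_seq[:, -block_size:]   # keep at most block_size tokens
        logits = apply_fn(variables, context)  # (B, T, vocab_size)
        last_logits = logits[:, -1, :]         # prediction for the next token
        key, subkey = jax.random.split(key)
        next_index = jax.random.categorical(subkey, last_logits, axis=-1)  # (B,)
        next_index = next_index[:, None].astype(index_seq.dtype)
        index_seq = jnp.concatenate([index_seq, next_index], axis=1)
    return index_seq

def uniform_apply(variables, idx):
    """Stand-in for model.apply: every token is equally likely."""
    B, T = idx.shape
    return jnp.zeros((B, T, 65))

start = jnp.zeros((1, 1), dtype=jnp.uint16)
tokens = generate_sketch(uniform_apply, None, start, jax.random.PRNGKey(0),
                         block_size=32, max_new_tokens=20)
```

An untrained model behaves much like the uniform stand-in, which is why the text generated before training looks like random characters.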

## Train the Model

In [44]:
optimizer = optax.adamw(learning_rate=1e-2)
opt_state = optimizer.init(variables)

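`loss_fn` is imported from `helper_funcs` and is not shown in this notebook. A plausible version, assuming it computes the mean cross-entropy between the predicted logits and the next-token targets, is sketched below. The names `cross_entropy_sketch` and `loss_fn_sketch` are hypothetical; only the argument order follows the call in the training loop.

```python
import jax
import jax.numpy as jnp

def cross_entropy_sketch(logits, targets):
    """Mean negative log-likelihood; logits: (B, T, vocab), targets: (B, T)."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    index = targets[..., None].astype(jnp.int32)
    picked = jnp.take_along_axis(log_probs, index, axis=-1)  # (B, T, 1)
    return -picked.mean()

def loss_fn_sketch(variables, apply_fn, xb, yb):
    """Same argument order as the loss_fn call in the training loop."""
    logits = apply_fn(variables, xb)
    return cross_entropy_sketch(logits, yb)

# Sanity check: uniform predictions over 65 tokens give a loss of ln(65).
vocab_size = 65
uniform_logits = jnp.zeros((16, 32, vocab_size))
targets = jnp.zeros((16, 32), dtype=jnp.uint16)
baseline = cross_entropy_sketch(uniform_logits, targets)
expected = jnp.log(65.0)  # about 4.17
```

This baseline is a useful reference point: before the model has learned anything, the training loss should be in the neighbourhood of ln(65) ≈ 4.17.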
In [45]:
steps = 100

for step in range(steps):
    rng_key, subkey = jax.random.split(rng_key)
    xb, yb = get_batch(train_ids, subkey, batch_size, block_size)

    loss, grads = value_and_grad(loss_fn, argnums=(0))(
        variables, 
        model.apply,
        xb, 
        yb
    )
    updates, opt_state = optimizer.update(grads, opt_state, variables)
    variables = optax.apply_updates(variables, updates) 

    print(f"Epoch: {step}, Loss: {loss :.4f}")

Epoch: 0, Loss: 4.4984
Epoch: 1, Loss: 4.0247
Epoch: 2, Loss: 3.7785
Epoch: 3, Loss: 3.4819
Epoch: 4, Loss: 3.5813
Epoch: 5, Loss: 3.4390
Epoch: 6, Loss: 3.3522
Epoch: 7, Loss: 3.3741
Epoch: 8, Loss: 3.2018
Epoch: 9, Loss: 3.5234
Epoch: 10, Loss: 3.2206
Epoch: 11, Loss: 3.1798
Epoch: 12, Loss: 3.2482
Epoch: 13, Loss: 3.1972
Epoch: 14, Loss: 3.2870
Epoch: 15, Loss: 3.2179
Epoch: 16, Loss: 3.2862
Epoch: 17, Loss: 3.3190
Epoch: 18, Loss: 3.3815
Epoch: 19, Loss: 3.3236
Epoch: 20, Loss: 3.2920
Epoch: 21, Loss: 3.1990
Epoch: 22, Loss: 3.3100
Epoch: 23, Loss: 3.2309
Epoch: 24, Loss: 3.3584
Epoch: 25, Loss: 3.2649
Epoch: 26, Loss: 3.1731
Epoch: 27, Loss: 3.1962
Epoch: 28, Loss: 3.4236
Epoch: 29, Loss: 3.1557
Epoch: 30, Loss: 3.2859
Epoch: 31, Loss: 3.1309
Epoch: 32, Loss: 3.1754
Epoch: 33, Loss: 3.1635
Epoch: 34, Loss: 3.2867
Epoch: 35, Loss: 3.1182
Epoch: 36, Loss: 2.9672
Epoch: 37, Loss: 3.1443
Epoch: 38, Loss: 3.1414
Epoch: 39, Loss: 3.1504
Epoch: 40, Loss: 3.0108
Epoch: 41, Loss: 3.0403
Ep

## Text Generation After Training

In [46]:
index_seq = jnp.zeros(shape=(1,1), dtype=jnp.uint16)
max_new_tokens = 20

generated_indices = generate(variables, model.apply, index_seq, rng_key, vocab_size, 1, block_size, max_new_tokens)
generated_indices = list(np.array(generated_indices[0]))
print("Generated text: ")
print(decode(generated_indices))

Generated text: 

Wos at arXOTerof t t
