# Bigram Language Model

We will build an autoregressive bigram language model that predicts the next token in a sequence of text. This demonstration uses the shakespeare_char dataset, which can be found in the data folder.

The bigram model is probabilistic. It uses only the previous token in the sequence to determine the probability of each possible next token, and the next token is then sampled from that distribution. For further explanation of the bigram model, see Andrej Karpathy's video [2].
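As a rough illustration of the idea, a bigram model can be built by counting which character follows which and normalizing the counts. This is a minimal sketch on a made-up toy string, not the learned lookup table trained later in this notebook; the names `counts`, `next_token_probs`, and `sample_next` are illustrative only.

```python
import random
from collections import Counter, defaultdict

text = "hello hello help"  # toy corpus (illustrative, not the Shakespeare data)

# Count how often each character follows each other character.
counts = defaultdict(Counter)
for prev, nxt in zip(text, text[1:]):
    counts[prev][nxt] += 1

def next_token_probs(prev):
    """Normalize the follower counts of `prev` into a probability distribution."""
    total = sum(counts[prev].values())
    return {ch: c / total for ch, c in counts[prev].items()}

def sample_next(prev, rng):
    """Sample the next character given only the previous one."""
    probs = next_token_probs(prev)
    chars, weights = zip(*probs.items())
    return rng.choices(chars, weights=weights, k=1)[0]

rng = random.Random(0)
print(next_token_probs("l"))  # distribution over characters seen after "l"
print(sample_next("h", rng))  # "h" is only ever followed by "e" in this corpus
```

The neural version below replaces the count table with a trainable table of logits, but the prediction rule is the same: look up the row for the previous token and sample from it.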

N-gram models generalize the bigram model. Instead of conditioning on only the previous token, they condition on the previous n-1 tokens (the bigram model is the case n = 2). This lets them look further back in the sequence when making a prediction.
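To make the difference in context length concrete, here is a small sketch that pairs each target token with the n-1 tokens before it. The helper name `ngram_contexts` is hypothetical and is not part of this notebook's `helper_funcs`.

```python
def ngram_contexts(tokens, n):
    """Yield (context, target) pairs where context is the previous n-1 tokens."""
    for i in range(n - 1, len(tokens)):
        yield tuple(tokens[i - (n - 1):i]), tokens[i]

tokens = list("speak")
bigram_pairs = list(ngram_contexts(tokens, 2))   # context is 1 token
trigram_pairs = list(ngram_contexts(tokens, 3))  # context is 2 tokens
print(bigram_pairs)
print(trigram_pairs)
```

With n = 2 every context holds a single token, which is exactly the information the bigram model in this notebook has available.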

### References:
- [1] [GPT colab notebook](https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing)
- [2] [Video: bigram language model, loss, generation](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=1331s)



In [None]:
import os
import requests
import numpy as np
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax
from jax import value_and_grad

from helper_funcs import get_batch, generate, loss_fn

## Data Preparation

In [2]:
# download the tiny shakespeare dataset if it is not already on disk
input_file_path = './data/shakespeare_char/input.txt'
if not os.path.exists(input_file_path):
    # make sure the target folder exists before writing into it
    os.makedirs(os.path.dirname(input_file_path), exist_ok=True)
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    with open(input_file_path, 'w') as f:
        f.write(requests.get(data_url).text)

with open(input_file_path, 'r') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

# get all the unique characters that occur in this text
chars = sorted(set(data))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
def encode(s):
    return [stoi[c] for c in s] # encoder: take a string, output a list of integers
def decode(l):
    return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

length of dataset in characters: 1,115,394
all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65
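A quick way to sanity-check a character-level tokenizer like this one is a round trip: decoding an encoded string should return the original. The sketch below rebuilds the same kind of mapping on a short stand-in string, using `toy_`-prefixed names (assumptions, chosen so they do not overwrite the notebook's own `stoi` and `itos`).

```python
sample = "First Citizen"  # illustrative stand-in for the full dataset
toy_chars = sorted(set(sample))
toy_stoi = {ch: i for i, ch in enumerate(toy_chars)}
toy_itos = {i: ch for i, ch in enumerate(toy_chars)}

def toy_encode(s):
    return [toy_stoi[c] for c in s]

def toy_decode(ids):
    return "".join(toy_itos[i] for i in ids)

ids = toy_encode("Citizen")
print(ids)
print(toy_decode(ids))

# Characters outside the vocabulary are not handled: the lookup raises KeyError.
try:
    toy_encode("Q")
    unknown_ok = True
except KeyError:
    unknown_ok = False
print(unknown_ok)
```

Because the vocabulary is built from the full text, every character in the Shakespeare dataset can be encoded; only text containing characters never seen in the data would fail.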


In [3]:
# create the train and validation splits
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# encode both to integers
train_ids = encode(train_data)
val_ids = encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)

train has 1,003,854 tokens
val has 111,540 tokens


In [4]:
print(decode(train_ids[:100]))

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [5]:
rng_key = jax.random.PRNGKey(128)

batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

In [6]:
xb, yb = get_batch(train_ids, rng_key, batch_size, block_size)

print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

inputs:
(4, 8)
[[ 1 41 53 51 51 39 52 42]
 [47 41 46  1 40 63  1 58]
 [43  1 58 53  1 57 39 60]
 [58 43 56  5 42  1 46 47]]
targets:
(4, 8)
[[41 53 51 51 39 52 42 43]
 [41 46  1 40 63  1 58 46]
 [ 1 58 53  1 57 39 60 43]
 [43 56  5 42  1 46 47 57]]
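The printed batch shows that each target row is the input row shifted one position to the left. The source of `get_batch` in `helper_funcs` is not shown here, so the following is only a sketch of one way such a function could be written; `get_batch_sketch` is a hypothetical name and the real helper may differ.

```python
import jax
import jax.numpy as jnp
import numpy as np

def get_batch_sketch(ids, key, batch_size, block_size):
    """Sample random windows; targets are the inputs shifted by one position."""
    # Random start offsets, leaving room for block_size inputs plus one target.
    starts = jax.random.randint(key, (batch_size,), 0, len(ids) - block_size)
    # Row b holds the indices starts[b], starts[b] + 1, ..., starts[b] + block_size - 1.
    offsets = starts[:, None] + jnp.arange(block_size)[None, :]
    ids = jnp.asarray(ids)
    x = ids[offsets]      # inputs, shape (batch_size, block_size)
    y = ids[offsets + 1]  # targets: the token that follows each input token
    return x, y

toy_ids = np.arange(50, dtype=np.uint16)  # illustrative data: token i is followed by i + 1
x, y = get_batch_sketch(toy_ids, jax.random.PRNGKey(0), batch_size=4, block_size=8)
print(x.shape, y.shape)
```

On the toy data every target equals its input plus one, which makes the shift easy to verify by eye.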


In [7]:
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

when input is [1] the target: 41
when input is [1, 41] the target: 53
when input is [1, 41, 53] the target: 51
when input is [1, 41, 53, 51] the target: 51
when input is [1, 41, 53, 51, 51] the target: 39
when input is [1, 41, 53, 51, 51, 39] the target: 52
when input is [1, 41, 53, 51, 51, 39, 52] the target: 42
when input is [1, 41, 53, 51, 51, 39, 52, 42] the target: 43
when input is [47] the target: 41
when input is [47, 41] the target: 46
when input is [47, 41, 46] the target: 1
when input is [47, 41, 46, 1] the target: 40
when input is [47, 41, 46, 1, 40] the target: 63
when input is [47, 41, 46, 1, 40, 63] the target: 1
when input is [47, 41, 46, 1, 40, 63, 1] the target: 58
when input is [47, 41, 46, 1, 40, 63, 1, 58] the target: 46
when input is [43] the target: 1
when input is [43, 1] the target: 58
when input is [43, 1, 58] the target: 53
when input is [43, 1, 58, 53] the target: 1
when input is [43, 1, 58, 53, 1] the target: 57
when input is [43, 1, 58, 53, 1, 57] the targe

In [8]:
class BigramLanguageModel(nn.Module):
    """
    Uses the previous token in the sequence to 
    determine the probabilities of the next token.
    """
    vocab_size: int
    
    @nn.compact
    def __call__(self, x):
        # Each token directly reads off the logits for the next token from a lookup table
        token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)
        logits = token_embedding_table(x)
        return logits

In [9]:
model = BigramLanguageModel(vocab_size)

variables = model.init(rng_key, xb)

In [10]:
out = model.apply(variables, xb)
print(out.shape)

(4, 8, 65)
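The output shape (4, 8, 65) means one vector of 65 logits per input token. The sketch below reproduces that lookup with a plain array instead of `nn.Embed`, using a toy vocabulary of 5 and random table values chosen purely for illustration, and shows how a softmax turns each row of logits into next-token probabilities.

```python
import jax
import jax.numpy as jnp

toy_vocab_size = 5  # illustrative; the notebook's vocabulary has 65 characters
key = jax.random.PRNGKey(0)
# Row i of the table holds the logits for the token that follows token i.
table = jax.random.normal(key, (toy_vocab_size, toy_vocab_size))

tokens = jnp.array([[0, 3, 3]])          # shape (batch=1, time=3)
logits = table[tokens]                   # shape (1, 3, 5): one table row per token
probs = jax.nn.softmax(logits, axis=-1)  # each row becomes a probability distribution

print(logits.shape)
print(probs.sum(axis=-1))
```

Note that the two occurrences of token 3 receive identical logits: the prediction depends only on the current token, not on its position or on anything earlier in the sequence.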


## Text Generation Before Training

In [11]:
index_seq = jnp.zeros(shape=(1,1), dtype=jnp.uint16)
max_new_tokens = 100

generated_indices = generate(variables, model.apply, index_seq, rng_key, vocab_size, 1, block_size, max_new_tokens)
generated_indices = list(np.array(generated_indices[0]))
print("Generated text: ")
print(decode(generated_indices))

Generated text: 

FeRkiTvg.,jtMwetQ
x;;zZFeVmFgOtyYaXqu,wzhj Sfh,i3rE.,rrkHm'PDy,sja33d&;K:,EEhIeMCNl zv;wZkPlNl.lqbbL
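The `generate` helper is imported from `helper_funcs` and its source is not shown, so the sketch below is only an assumption about the general shape of an autoregressive sampling loop. It uses a simplified, hypothetical signature (`generate_sketch`) and a plain lookup table in place of the Flax model.

```python
import jax
import jax.numpy as jnp

def generate_sketch(apply_fn, variables, index_seq, key, max_new_tokens):
    """Repeatedly sample a next token and append it to the sequence."""
    for _ in range(max_new_tokens):
        key, subkey = jax.random.split(key)          # fresh randomness for each step
        logits = apply_fn(variables, index_seq)      # (B, T, vocab)
        last_logits = logits[:, -1, :]               # only the last position matters
        next_idx = jax.random.categorical(subkey, last_logits, axis=-1)  # (B,)
        next_idx = next_idx[:, None].astype(index_seq.dtype)
        index_seq = jnp.concatenate([index_seq, next_idx], axis=1)
    return index_seq

toy_vocab = 5  # illustrative vocabulary size
toy_vars = {"table": jnp.zeros((toy_vocab, toy_vocab))}  # all-zero logits: uniform sampling
toy_apply = lambda v, x: v["table"][x]                   # stand-in for model.apply

start = jnp.zeros((1, 1), dtype=jnp.uint16)
out = generate_sketch(toy_apply, toy_vars, start, jax.random.PRNGKey(0), max_new_tokens=10)
print(out.shape)
```

An untrained table gives near-arbitrary probabilities, which is why the text generated before training looks like random characters.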


## Train the Model

In [12]:
optimizer = optax.adamw(learning_rate=1e-2)
opt_state = optimizer.init(variables)

In [13]:
steps = 100
batch_size = 32

for step in range(steps):
    rng_key, subkey = jax.random.split(rng_key)
    xb, yb = get_batch(train_ids, subkey, batch_size, block_size)

    loss, grads = value_and_grad(loss_fn, argnums=(0))(
        variables, 
        model.apply,
        xb, 
        yb
    )
    updates, opt_state = optimizer.update(grads, opt_state, variables)
    variables = optax.apply_updates(variables, updates) 

    print(f"Epoch: {step}, Loss: {loss :.4f}")

Epoch: 0, Loss: 4.1996
Epoch: 1, Loss: 4.1725
Epoch: 2, Loss: 4.1578
Epoch: 3, Loss: 4.1523
Epoch: 4, Loss: 4.1439
Epoch: 5, Loss: 4.1430
Epoch: 6, Loss: 4.1282
Epoch: 7, Loss: 4.1220
Epoch: 8, Loss: 4.1064
Epoch: 9, Loss: 4.1081
Epoch: 10, Loss: 4.0991
Epoch: 11, Loss: 4.0791
Epoch: 12, Loss: 4.0501
Epoch: 13, Loss: 4.0383
Epoch: 14, Loss: 4.0466
Epoch: 15, Loss: 4.0380
Epoch: 16, Loss: 4.0254
Epoch: 17, Loss: 4.0144
Epoch: 18, Loss: 3.9974
Epoch: 19, Loss: 3.9868
Epoch: 20, Loss: 3.9660
Epoch: 21, Loss: 3.9618
Epoch: 22, Loss: 3.9420
Epoch: 23, Loss: 3.9424
Epoch: 24, Loss: 3.9269
Epoch: 25, Loss: 3.9313
Epoch: 26, Loss: 3.9343
Epoch: 27, Loss: 3.9111
Epoch: 28, Loss: 3.8979
Epoch: 29, Loss: 3.8893
Epoch: 30, Loss: 3.8613
Epoch: 31, Loss: 3.8779
Epoch: 32, Loss: 3.8776
Epoch: 33, Loss: 3.8579
Epoch: 34, Loss: 3.8241
Epoch: 35, Loss: 3.8421
Epoch: 36, Loss: 3.7989
Epoch: 37, Loss: 3.7984
Epoch: 38, Loss: 3.7727
Epoch: 39, Loss: 3.8095
Epoch: 40, Loss: 3.7699
Epoch: 41, Loss: 3.7758
Ep
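The starting loss is close to what uniform guessing would give: a model that assigns probability 1/65 to every character has cross-entropy ln 65 ≈ 4.174. The sketch below checks this with a hand-written cross-entropy. It is an assumption about what `loss_fn` in `helper_funcs` computes (its source is not shown), and `loss_fn_sketch` and `toy_apply` are hypothetical names.

```python
import math
import jax
import jax.numpy as jnp

def loss_fn_sketch(variables, apply_fn, x, y):
    """Mean negative log-probability of the target tokens."""
    logits = apply_fn(variables, x)                  # (B, T, vocab)
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick out the log-probability assigned to each target token.
    picked = jnp.take_along_axis(log_probs, y[..., None].astype(jnp.int32), axis=-1)
    return -picked.mean()

vocab = 65
uniform_vars = {"table": jnp.zeros((vocab, vocab))}  # zero logits give a uniform distribution
toy_apply = lambda v, x: v["table"][x]               # stand-in for model.apply

x = jnp.array([[1, 41, 53]])
y = jnp.array([[41, 53, 51]])
loss = loss_fn_sketch(uniform_vars, toy_apply, x, y)
print(float(loss), math.log(vocab))
```

Loss values below ln 65 therefore indicate that the model has learned something about which characters tend to follow which.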

## Text Generation After Training

In [14]:
index_seq = jnp.zeros(shape=(1,1), dtype=jnp.uint16)
max_new_tokens = 100

generated_indices = generate(variables, model.apply, index_seq, rng_key, vocab_size, 1, block_size, max_new_tokens)
generated_indices = list(np.array(generated_indices[0]))
print("Generated text: ")
print(decode(generated_indices))

Generated text: 

frmIIt;DeLINCZlGHk.lxir,Uqh-3QhhSRK$'eI uxnxcUKWD3nCyVWe,cCy
PAgOLCwH;&B!;phtwftheD aqEtoKA$VvO:IM:s
