# Bigram Language Model

We will build an autoregressive bigram language model that predicts the next token in a sequence of text. This demonstration uses the shakespeare_char dataset, which can be found in the data folder.  

A bigram model is probabilistic. It uses the previous token in the sequence to determine the probability of each candidate next token. The next token is then sampled from that probability distribution.  

N-gram models generalize the bigram model, which is the case n = 2. They condition on the last n-1 tokens in the sequence instead of just the last token, which lets them look further back in the text when making a prediction. 
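
To make the idea concrete, here is a minimal count-based sketch that is separate from the neural model built in this notebook. It estimates next-character probabilities by counting adjacent character pairs in a toy string. The helper name `bigram_probs` and the toy text are illustrative assumptions, not part of the original code.

```python
from collections import Counter, defaultdict

def bigram_probs(text):
    """Estimate P(next char | previous char) by counting adjacent pairs."""
    counts = defaultdict(Counter)
    for prev, nxt in zip(text, text[1:]):
        counts[prev][nxt] += 1
    # normalise each row of counts into a probability distribution
    return {
        prev: {nxt: c / sum(row.values()) for nxt, c in row.items()}
        for prev, row in counts.items()
    }

probs = bigram_probs("abababac")
print(probs["a"])  # distribution over the characters seen after 'a'
```

In the toy string, 'a' is followed by 'b' three times and by 'c' once, so the estimate for 'b' after 'a' is 0.75. The model trained later in the notebook learns a comparable table of scores by gradient descent instead of by counting.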

### References:
- [GPT colab notebook](https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing)
- [Video: simplest baseline: bigram language model, loss, generation](https://www.youtube.com/watch?v=kCc8FmEb1nY&t=1331s)



In [42]:
import os
import requests
import numpy as np
import jax
from jax import random
import jax.numpy as jnp
import flax.linen as nn
import optax
from jax import value_and_grad
from functools import partial

In [43]:
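# download the tiny shakespeare dataset
input_file_path = os.path.join('data', 'shakespeare_char', 'input.txt')
if not os.path.exists(input_file_path):
    data_url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
    # create the target folder first so that open() does not fail
    os.makedirs(os.path.dirname(input_file_path), exist_ok=True)
    response = requests.get(data_url)
    response.raise_for_status()  # stop on a failed download instead of saving an error page
    with open(input_file_path, 'w') as f:
        f.write(response.text)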

with open(input_file_path, 'r') as f:
    data = f.read()
print(f"length of dataset in characters: {len(data):,}")

# get all the unique characters that occur in this text
chars = sorted(list(set(data)))
vocab_size = len(chars)
print("all the unique characters:", ''.join(chars))
print(f"vocab size: {vocab_size:,}")

# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
def encode(s):
    return [stoi[c] for c in s] # encoder: take a string, output a list of integers
def decode(l):
    return ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

length of dataset in characters: 1,115,394
all the unique characters: 
 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
vocab size: 65
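
As a quick check of the character-level tokenizer, the sketch below rebuilds the same kind of mapping on a toy string and verifies the round trip. The toy text is an assumption chosen for illustration; the notebook builds its mapping from the full dataset.

```python
text = "hello world"  # toy corpus, stands in for the Shakespeare text
chars = sorted(set(text))
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for i, ch in enumerate(chars)}

def encode(s):
    return [stoi[c] for c in s]  # string -> list of integers

def decode(ids):
    return ''.join(itos[i] for i in ids)  # list of integers -> string

ids = encode("hello")
print(ids)          # [3, 2, 4, 4, 5]
print(decode(ids))  # hello
```

Because the vocabulary comes only from characters present in the corpus, `encode` raises a `KeyError` for any character that never appears in it.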


In [44]:
# create the train and validation splits
n = len(data)
train_data = data[:int(n*0.9)]
val_data = data[int(n*0.9):]

# encode both to integers
train_ids = encode(train_data)
val_ids = encode(val_data)
print(f"train has {len(train_ids):,} tokens")
print(f"val has {len(val_ids):,} tokens")

train_ids = np.array(train_ids, dtype=np.uint16)
val_ids = np.array(val_ids, dtype=np.uint16)

train has 1,003,854 tokens
val has 111,540 tokens


In [45]:
print(decode(train_ids[:100]))

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You


In [46]:
rng_key = jax.random.PRNGKey(128)

In [47]:
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

def get_batch(data, rng_key, batch_size, block_size):
    """
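    Extracts a random batch of input and target sequences.
    Args:
        data: Array of token IDs for the whole dataset.
        rng_key: Random number generator key.
        batch_size: Number of independent sequences in the batch.
        block_size: Length of each token sequence (the context length).
    Returns:
        Input token IDs and target token IDs, each of shape
        (batch_size, block_size). The targets are the inputs shifted
        one position forward in the data.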
    """
    ix = random.randint(key=rng_key, shape=(batch_size, ), minval=0, maxval=len(data) - block_size)
    x = jnp.stack([data[i:i+block_size] for i in ix])
    y = jnp.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y
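
The relationship between inputs and targets can be checked with a small NumPy sketch. It uses fixed start offsets and a toy array in place of random offsets and the real token IDs; both are assumptions made so the result is easy to read.

```python
import numpy as np

data = np.arange(20)           # stand-in for the token ID array
block_size = 8
starts = np.array([0, 5, 11])  # fixed start offsets instead of random ones

x = np.stack([data[i:i + block_size] for i in starts])
y = np.stack([data[i + 1:i + block_size + 1] for i in starts])

# every target is the token that follows the input at the same position
assert np.array_equal(x[:, 1:], y[:, :-1])
print(x[0], y[0])
```

For the bigram model only the pairs `(x[b, t], y[b, t])` matter, since each prediction is made from a single previous token.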

In [48]:
xb, yb = get_batch(train_ids, rng_key, batch_size, block_size)

print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

inputs:
(4, 8)
[[ 1 41 53 51 51 39 52 42]
 [47 41 46  1 40 63  1 58]
 [43  1 58 53  1 57 39 60]
 [58 43 56  5 42  1 46 47]]
targets:
(4, 8)
[[41 53 51 51 39 52 42 43]
 [41 46  1 40 63  1 58 46]
 [ 1 58 53  1 57 39 60 43]
 [43 56  5 42  1 46 47 57]]


In [49]:
for b in range(batch_size):
    for t in range(block_size):
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

when input is [1] the target: 41
when input is [1, 41] the target: 53
when input is [1, 41, 53] the target: 51
when input is [1, 41, 53, 51] the target: 51
when input is [1, 41, 53, 51, 51] the target: 39
when input is [1, 41, 53, 51, 51, 39] the target: 52
when input is [1, 41, 53, 51, 51, 39, 52] the target: 42
when input is [1, 41, 53, 51, 51, 39, 52, 42] the target: 43
when input is [47] the target: 41
when input is [47, 41] the target: 46
when input is [47, 41, 46] the target: 1
when input is [47, 41, 46, 1] the target: 40
when input is [47, 41, 46, 1, 40] the target: 63
when input is [47, 41, 46, 1, 40, 63] the target: 1
when input is [47, 41, 46, 1, 40, 63, 1] the target: 58
when input is [47, 41, 46, 1, 40, 63, 1, 58] the target: 46
when input is [43] the target: 1
when input is [43, 1] the target: 58
when input is [43, 1, 58] the target: 53
when input is [43, 1, 58, 53] the target: 1
when input is [43, 1, 58, 53, 1] the target: 57
when input is [43, 1, 58, 53, 1, 57] the targe

In [50]:
class BigramLanguageModel(nn.Module):
    """
    Uses the previous token in the sequence to 
    determine the probabilities of the next token.
    """
    vocab_size: int
    
    @nn.compact
    def __call__(self, x):
        # Each token directly reads off the logits for the next token from a lookup table
        token_embedding_table = nn.Embed(num_embeddings=self.vocab_size, features=self.vocab_size)
        logits = token_embedding_table(x)
        return logits

In [60]:
model = BigramLanguageModel(vocab_size)

variables = model.init(rng_key, xb)

In [52]:
out = model.apply(variables, xb)
print(out.shape)

(4, 8, 65)
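
The output shape follows from how an embedding lookup works: each token ID selects one row of a `(vocab_size, vocab_size)` matrix, and that row is used as the logits for the next token. The NumPy sketch below imitates the lookup with a random matrix as a stand-in for the learned parameters; it is not Flax's implementation.

```python
import numpy as np

vocab_size = 65
rng = np.random.default_rng(0)
table = rng.normal(size=(vocab_size, vocab_size))  # stand-in for the embedding matrix

xb = np.array([[1, 41, 53], [47, 41, 46]])  # (B, T) token IDs
logits = table[xb]                           # row lookup -> (B, T, vocab_size)

assert logits.shape == (2, 3, vocab_size)
# the logits at a position depend only on the token at that position
assert np.array_equal(logits[0, 1], logits[1, 1])  # both positions hold token 41
assert np.array_equal(logits[0, 1], table[41])
```

This is why the model is a bigram model: the earlier tokens in the context cannot influence the logits.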


In [53]:
@partial(jax.jit, static_argnames=['vocab_size', 'batch_size', 'max_new_tokens'])
def generate(variables, index_seq, rng_key, vocab_size, batch_size, max_new_tokens):
    """
    Generates max_new_tokens new tokens, given the
    starting sequence of tokens index_seq.
    Relies on the global `model` defined above.
    Args:
        variables: Bigram model's parameters.
        index_seq: Array of token indices with shape (B, T), 
            where B is the batch size and T is the number of time steps.
        rng_key: Random number generator key.
        vocab_size: Number of distinct tokens in the vocabulary.
        batch_size: Number of sequences in index_seq (B).
        max_new_tokens: Number of new tokens to generate.
    Returns:
        An array of token indices with shape (B, T + max_new_tokens).
    """
    # Batched sampling function
    batched_choice = jax.vmap(jax.random.choice)
    
    for _ in range(max_new_tokens):
        logits = model.apply(variables, index_seq)
        # Focus only on the last time step
        # Shape changes from (B, T, C) -> (B, C)
        logits = logits[:, -1, :]
        # Convert to probabilities using softmax
        probs = jax.nn.softmax(logits, axis=-1)
        # Sample a token index using probs
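        rng_key, subkey = jax.random.split(rng_key)
        # Note: the same subkey is repeated for every row, so rows are not
        # sampled independently of each other when batch_size > 1
        batched_key = subkey.reshape(1, -1)
        batched_key = jnp.repeat(batched_key, batch_size, axis=0)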
        a = jnp.arange(vocab_size).reshape(1, -1)
        a = jnp.repeat(a, batch_size, axis=0)
        next_indexes = batched_choice(batched_key, a, p=probs)
        next_indexes = next_indexes.reshape(batch_size, -1)
        # Append the sampled index to the running sequence
        index_seq = jnp.concatenate([index_seq, next_indexes], axis=1)
    return index_seq
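
One possible simplification of the sampling step, shown as a sketch and not as the notebook's method, is `jax.random.categorical`. It samples directly from unnormalised logits and draws each batch row independently, so the softmax, the repeated keys, and the `vmap` over `jax.random.choice` are not needed. The helper name `sample_next` and the toy logits are assumptions for illustration.

```python
import jax
import jax.numpy as jnp

def sample_next(rng_key, logits):
    """Sample one token index per row from unnormalised logits of shape (B, C)."""
    # categorical treats the logits as unnormalised log-probabilities
    next_idx = jax.random.categorical(rng_key, logits, axis=-1)  # (B,)
    return next_idx[:, None]                                     # (B, 1)

key = jax.random.PRNGKey(0)
# toy logits: token 7 is overwhelmingly likely in every row
logits = jnp.full((4, 65), -1e9).at[:, 7].set(0.0)
print(sample_next(key, logits).ravel())
```

The returned array has shape `(B, 1)`, so it can be concatenated to the running sequence along axis 1 in the same way as `next_indexes`.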



In [54]:
index_seq = jnp.zeros(shape=(1,1), dtype=jnp.uint16)
max_new_tokens = 100

generated_indices = generate(variables, index_seq, rng_key, vocab_size, 1, max_new_tokens)
generated_indices = list(np.array(generated_indices[0]))
print("Generated text: ")
print(decode(generated_indices))

Generated text: 

FeRkiTvg.,jtMwetQ
x;;zZFeVmFgOtyYaXqu,wzhj Sfh,i3rE.,rrkHm'PDy,sja33d&;K:,EEhIeMCNl zv;wZkPlNl.lqbbL


In [55]:
def loss_fn(variables, index_seq, labels):
    """
    Calculates the cross entropy loss at every
    batch element and time step,
    then returns the mean.
    Args:
        variables: Language model parameters.
        index_seq: Array of token indices with shape (B, T), 
            where B is the batch size and T is the number of time steps.
        labels: Array of target token indices with shape (B, T),
            holding the next token for each position in index_seq.
    Returns:
        Mean cross entropy loss (a scalar).
    """
    logits = model.apply(variables, index_seq)

    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    # Average loss across all batches and time steps
    loss = loss.mean()
    return loss
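
A useful sanity check for this loss is the value it takes for a model that knows nothing. If the prediction is uniform over 65 characters, the cross entropy is ln(65) ≈ 4.174, which is close to the first loss values printed by the training loop in this notebook. The NumPy sketch below reimplements the mean cross entropy by hand to confirm that number; the function `cross_entropy` is an illustrative stand-in, not the Optax implementation.

```python
import numpy as np

def cross_entropy(logits, labels):
    """Mean cross entropy over all (B, T) positions with integer labels."""
    # numerically stable log-softmax over the vocabulary axis
    shifted = logits - logits.max(axis=-1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    # pick the log-probability assigned to each correct token
    picked = np.take_along_axis(log_probs, labels[..., None], axis=-1)
    return -picked.mean()

vocab_size = 65
logits = np.zeros((4, 8, vocab_size))      # equal scores -> uniform prediction
labels = np.zeros((4, 8), dtype=np.int64)  # any labels give the same result here
ce = cross_entropy(logits, labels)
print(ce, np.log(vocab_size))
```

A trained model should reach a loss clearly below this baseline.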

In [59]:
optimizer = optax.adamw(learning_rate=1e-2)
opt_state = optimizer.init(variables)

In [61]:
steps = 100
batch_size = 32

for step in range(steps):
    rng_key, subkey = jax.random.split(rng_key)
    xb, yb = get_batch(train_ids, subkey, batch_size, block_size)

    loss, grads = value_and_grad(loss_fn, argnums=(0))(
        variables, 
        xb, 
        yb
    )
    updates, opt_state = optimizer.update(grads, opt_state, variables)
    variables = optax.apply_updates(variables, updates) 

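    # each iteration is one optimizer step on a single random batch,
    # so the "Epoch" label below counts steps rather than passes over the data
    print(f"Epoch: {step}, Loss: {loss :.4f}")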

Epoch: 0, Loss: 4.1947
Epoch: 1, Loss: 4.1744
Epoch: 2, Loss: 4.1583
Epoch: 3, Loss: 4.1578
Epoch: 4, Loss: 4.1329
Epoch: 5, Loss: 4.1435
Epoch: 6, Loss: 4.1164
Epoch: 7, Loss: 4.1246
Epoch: 8, Loss: 4.1056
Epoch: 9, Loss: 4.0930
Epoch: 10, Loss: 4.0616
Epoch: 11, Loss: 4.0610
Epoch: 12, Loss: 4.0631
Epoch: 13, Loss: 4.0392
Epoch: 14, Loss: 4.0296
Epoch: 15, Loss: 4.0367
Epoch: 16, Loss: 4.0135
Epoch: 17, Loss: 3.9982
Epoch: 18, Loss: 3.9788
Epoch: 19, Loss: 3.9806
Epoch: 20, Loss: 3.9684
Epoch: 21, Loss: 3.9566
Epoch: 22, Loss: 3.9539
Epoch: 23, Loss: 3.9224
Epoch: 24, Loss: 3.9471
Epoch: 25, Loss: 3.9120
Epoch: 26, Loss: 3.9014
Epoch: 27, Loss: 3.9146
Epoch: 28, Loss: 3.8944
Epoch: 29, Loss: 3.8898
Epoch: 30, Loss: 3.8620
Epoch: 31, Loss: 3.8377
Epoch: 32, Loss: 3.8293
Epoch: 33, Loss: 3.8346
Epoch: 34, Loss: 3.8517
Epoch: 35, Loss: 3.8439
Epoch: 36, Loss: 3.8121
Epoch: 37, Loss: 3.7939
Epoch: 38, Loss: 3.8129
Epoch: 39, Loss: 3.7650
Epoch: 40, Loss: 3.7597
Epoch: 41, Loss: 3.7425
Ep
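
To see what each update does to the lookup table, the following NumPy sketch trains a tiny bigram table with hand-written gradients. It makes several simplifying assumptions: a toy stream of three token types, full-batch updates, and plain gradient descent in place of AdamW. The gradient of the mean cross entropy with respect to the logits is `(softmax - one_hot) / N`, and each table row accumulates the gradients of all positions that looked it up.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def loss_and_grad(table, x, y):
    """Mean cross entropy of a lookup-table bigram model and its gradient."""
    n = x.shape[0]
    probs = softmax(table[x])                    # (N, V)
    loss = -np.log(probs[np.arange(n), y]).mean()
    # d(loss)/d(logits) = (softmax - one_hot) / N
    dlogits = probs.copy()
    dlogits[np.arange(n), y] -= 1.0
    dlogits /= n
    # each row collects the gradients of every position that looked it up
    grad = np.zeros_like(table)
    np.add.at(grad, x, dlogits)
    return loss, grad

tokens = np.array([0, 1, 0, 1, 0, 2, 0, 1])  # toy token stream, vocabulary of 3
x, y = tokens[:-1], tokens[1:]
table = np.zeros((3, 3))

losses = []
for _ in range(50):
    loss, grad = loss_and_grad(table, x, y)
    losses.append(loss)
    table -= 1.0 * grad  # plain gradient descent, not AdamW

print(losses[0], losses[-1])
```

The first loss equals ln(3) because the zero-initialised table predicts uniformly, and the loss then falls as the rows move toward the observed pair frequencies.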

In [62]:
index_seq = jnp.zeros(shape=(1,1), dtype=jnp.uint16)
max_new_tokens = 100

generated_indices = generate(variables, index_seq, rng_key, vocab_size, 1, max_new_tokens)
generated_indices = list(np.array(generated_indices[0]))
print("Generated text: ")
print(decode(generated_indices))

Generated text: 

P?NLp,IRUmvt&UnpeAqodwb
ky;JkDLORenmgrkn,Pm, owraSle-nsVit;b3k!haugy wt:!MI',
YLDGLnicbbunbeG'T?UvbL
