In [None]:
# in this notebook, we'll build up to a GPT in JAX

In [None]:
import time
from typing import List, Dict, Tuple, Callable, NamedTuple, Any, Optional
from functools import partial

### intro/what is a GPT



In [None]:
# GPT models Generate output, rather than predicting or classifying labels
# typically, we might pass a model a pair (features, target)
# where the model has to get the target 'correct' relative to the features
# like guessing the correct house price given some data about the house
# or guessing the right digit drawn from an image dataset

# on the other hand:
# passing an input to a GPT model like 'the man outside was'
# the GPT model wouldn't produce a class label, or a number prediction
# rather, it would generate the rest of the sentence like
# 'the man outside was mowing his lawn'

# GPT models learn to do this through 'pretraining'
# instead of explicitly giving the model labeled examples like in the tasks above (features, target)
# the model simply supervises itself, comparing an output it made to the actual data
# this is usually called 'pretraining' rather than 'training' as it's followed by a more classic training task (often called 'fine-tuning')
# where the model actually has to classify or predict.
# as an example, a GPT model might be pretrained on large chunks of text
# and then asked to predict if a given imdb rating is 'positive' or 'negative'
# the idea is that the model will become too specific to a given task if given labels
# but we want a model that can generalize across tasks
# this is clued in by the title of the GPT-2 paper 'Language Models are Unsupervised Multitask Learners'

# finally, the Transformer architecture is where the magic happens
# there's a lot to cover there, so we'll motivate the choices we make step by step in what follows

### why use jax?

In [None]:
# TODO: this motivation for jax could be better

# in order to implement this ourselves
# we will be using JAX
# JAX is a library that is based on doing high performance math with arrays
# arrays are really nice because we can represent very high dimensional data with them
# if you've used numpy, a lot of the syntax will be very familiar
# if a vector is a 1 dimensional object
# a matrix is 2 dimensions
# and everything beyond is a tensor
# as an example, you can imagine a 3d tensor as a bunch of matrices stacked side by side into a cuboid type situation
# (computerphile on tensors: https://www.youtube.com/watch?v=DfK83xEtJ_k)

# jax is an autodifferentiating library
# this means that it automatically can keep track of computations that happen to a given input
# and automatically calculate partial derivatives
# this is very useful for machine learning research
# as backpropagation is how neural networks fit their parameters
# 3blue1brown on backprop (https://www.youtube.com/watch?v=Ilg3gGewQ5U)

# backpropagation is frequently parallelized on 'batches' of input at a time
# jax is nice in that it provides automatic function vectorization
# it makes doing things like [f(i) for i in x] very efficient

# in machine learning research
# we also frequently have functions where the transformations don't really change
# and we might want to compile those functions, rather than re-interpreting it each time
# while python is very convenient for its extensive tooling and ease of writing
# it doesn't natively have a way to compile and efficiently execute chunks of code
# so jax introduces Just In Time compiling
# making it easy to write performant code

# ultimately, jax is just about expressing and composing transformations, or functions
# https://www.youtube.com/watch?v=PAZTIAfaNr8
# jit, vmap, and grad all work very easily with each other due to jax's functional style
# and it makes the compiler's life a lot easier too
# code is also more easy to reproduce
# as jax's functional style avoids implicit global states

# let's import it now

import jax
import jax.numpy as jnp

### what goes into a GPT

In [None]:
# back to the GPT, we'll go over some high level structure of what we need to make

# as we mentioned before, a GPT generates some output, in our case text output
# if we're doing a bunch of array based math
# we'll need a way to turn strings into arrays of numbers
# we call this a 'tokenizer'
# it breaks up a string into different pieces in a way that makes sense
# and each piece gets its own id
# ultimately, we want a model that generates text
# machine learning models don't really use words
# so we'll need a way to go from text: str -> tokens: jax.Array (GPT-2 learns this mapping from data with byte-pair encoding; we'll start with a simple hand-written one)
# this introduces a model configuration setting: 'vocabulary size'
# which is just the total number of tokens the model knows about
# we'll call it vocab_size from now on

# now, our GPT needs to take in that tokens: jax.Array, and output some generated_tokens: jax.Array
# if we take in the text 'hello world' with one token per character, our sequence length is 11
# so our input is a vector of token ids of shape (seq_len,), or (seq_len, vocab_size) if each token is one-hot encoded

# but it doesn't generate all the tokens at once
# it'll take the input, and predict the next token that should come after that input
# continuing this until it either decides to stop, or until we cut it off
# the next token prediction is usually a distribution of values
# one value for each possible next token
# we call that raw output 'logits'
# since it's just for the next token, the logits are of size (vocab_size)

# from this distribution of logits
# we need a way to choose the single token, or sample from the distribution
# the simplest way of doing this is to just pick the highest value
# so we'll take in the logits of shape (vocab_size) and output a single number

# lastly, we'll need to implement the unsupervised training for the model
# we have to give the model a way to understand how good/bad its token prediction was
# also known as a loss function
# and update its parameters based on the the gradient with respect to that loss function

# we'll implement basic versions of all of the above, and build on them over time

# to make this faster, we build it so it can run in parallel
# we can do multiple sequences, of shape (batch_size, seq_len, vocab_size)
# so the logit output becomes (batch_size, vocab_size)
# and the final output is (batch_size,)
# jax's automatic vectorization we discussed earlier makes this very convenient


### basic tokenizer

In [None]:
# TODO: could break this up into initial tokenizer -> with special chars

# for our initial tokenizer, we can use python's ord() and chr()
# to turn chars into integers, and vice versa
# the ascii printable characters are from 32-126
# so we'll define a 'char_to_idx' and an 'idx_to_char' function
# this will make our lives easier by mapping into the range 0-94
# as a default, we can return 0 if ord(x) is outside of the printable character idxs
# and return ord('\n') if we get an index we don't know about

"""
def char_to_idx(char: int) -> int:
    return

def idx_to_char(index: int) -> int:
    return

print(char_to_idx(ord('a')))
print(chr(idx_to_char(char_to_idx(ord('a')))))
"""

# printable ascii spans code points 32 (' ') through 126 ('~'): 95 characters in total
def char_to_idx(char: int) -> int:
    """Map a character code point to a token id in [0, 94].

    Code points outside the printable ascii range (32-126) fall back to id 0.
    """
    if not 32 <= char <= 126:
        return 0
    return char - 32

def idx_to_char(index: int) -> int:
    """Map a token id back to its character code point.

    Ids outside [0, 94] decode to the code point of the newline character.
    """
    return index + 32 if 0 <= index <= 94 else ord('\n')


print(char_to_idx(ord('a')))
print(chr(idx_to_char(char_to_idx(ord('a')))))

In [None]:
# now, we make the tokenizer
# we'll create a namedtuple object, to easily organize our tokenizer

"""
class Tokenizer(NamedTuple):
    encode: Callable
    decode: Callable

# use char_to_idx/idx_to_char from above with a list comprehension

def tokenizer_encode(text: str) -> jax.Array:
    return ...

def tokenizer_decode(ids: jax.Array) -> str:
    return ...

tokenizer = Tokenizer(encode=encode, decode=decode)
# Test the functions
print('hello')
print(tokenizer.encode('hello'))
print(tokenizer.decode(tokenizer.encode('hello')))

print('1.\n2.')
print(tokenizer.encode('1.\n2.'))
print(tokenizer.decode(tokenizer.encode('1.\n2.')))

"""

class Tokenizer(NamedTuple):
    """Bundle of the two functions that convert between text and token ids."""
    encode: Callable  # text -> array of token ids
    decode: Callable  # array of token ids -> text

def encode(text: str) -> jax.Array:
    """Encode a string into a 1-D int32 array of token ids, one id per character.

    Characters outside printable ascii are mapped to id 0 by char_to_idx.
    The explicit dtype keeps the result an integer array even for an empty string,
    where jnp.array([]) would otherwise default to float32.
    """
    return jnp.array([char_to_idx(ord(c)) for c in text], dtype=jnp.int32)

def decode(ids: jax.Array) -> str:
    """Decode a sequence of token ids back into a string, one character per id."""
    chars = []
    for token_id in ids:
        chars.append(chr(idx_to_char(int(token_id))))
    return ''.join(chars)

tokenizer = Tokenizer(encode=encode, decode=decode)
# Test the functions: for each sample, show the text, its token ids, and the decoded round trip
for sample_text in ('hello', '1.\n2.'):
    print(sample_text)
    print(tokenizer.encode(sample_text))
    print(tokenizer.decode(tokenizer.encode(sample_text)))

# notice that our tokenizer does not handle the \n token!

### bigram model

In [None]:
# now we'll implement the most basic possible model
# it should perform a lookup from one vocabulary word, to some random logits, of size (batch_size, vocab_size)
# suppose we had a vocab size of 3: [a,b,c]
# our matrix basically looks like
#      a     b     c
# a  0.1   0.4   0.5
# b  0.9  0.02  0.08
# c  0.8   0.1   0.1

# so if we have token b, we would usually predict token a, with 90% probability
# the above values are probabilities for simplicity
# but in actuality will be generated from a normal distribution (0,1) * scaling_factor

In [None]:
# first we'll create a struct to hold configuration parameters
# an initialization function to initialize parameters
# and a forward pass of the model, taking in some input and creating output

# struct
class ModelConfig(NamedTuple):
    """Hyperparameters that fix the model's shapes."""
    vocab_size: int  # number of distinct token ids the model reads and predicts

In [None]:
# to initialize parameters
# jax wants us to pass in a jax.random.PRNGKey()
# jax uses this sort of like the way a seed works
# but rather than defining a global seed, we get to explicitly pass it to our random functions
# https://jax.readthedocs.io/en/latest/random-numbers.html#explicit-random-state
# we'll store our parameters in a dictionary, and call the matrix 'token_table'
# it should have shape (vocab_size, vocab_size), to have a mapping from each token to each other token in the vocab

"""
def initialize_bigram_params(model_config: ModelConfig, key=jax.random.PRNGKey(0)) -> Dict:
    vocab_size = model_config.vocab_size
    weights = ... # initialize with jax.random.normal
    model_params = {'token_table': }
    return model_params

bigram_config = ModelConfig(vocab_size = 95)
bigram_params = initialize_bigram_params(bigram_config)
bigram_params['token_table'].shape
"""

def initialize_bigram_params(model_config: ModelConfig, key=jax.random.PRNGKey(0)) -> Dict:
    """Create the bigram model's parameters.

    Returns a dict with a single entry, 'token_table': a (vocab_size, vocab_size)
    matrix of standard-normal samples drawn with `key`. Row i is looked up as the
    next-token logits for token i.
    """
    table_shape = (model_config.vocab_size,) * 2
    return {'token_table': jax.random.normal(key, table_shape)}

bigram_config = ModelConfig(vocab_size=95)
bigram_params = initialize_bigram_params(bigram_config)
bigram_params['token_table'].shape

In [None]:
# and now we write the forward pass
# as discussed above, we'll use the token table
# to figure out what our prediction for the next token should be
# for now we'll take in a single sequence of token ids, of shape (sequence_length,)
# take the last token as our key
# keying into the token_table will return a vector of size (vocab_size,)
# and finally, we reshape it into an array of size (1, vocab_size): a batch of one, matching the (batch_size, vocab_size) logits the loss below expects

"""
def bigram_model(model_params: Dict, model_config: ModelConfig, tokens: jax.Array): # (seq_len,) -> (1,vocab_size)
    token_table = ...
    last_token_id = ...
    logits = ...
    reshaped_logits = ...
    return reshaped_logits


logits = bigram_model(bigram_params, bigram_config, tokenizer.encode('hello'))
logits.shape # you can use this to debug!

"""

def bigram_model(model_params: Dict, model_config: ModelConfig, tokens: jax.Array): # (seq_len,) -> (1,vocab_size); (batch, seq_len) -> (batch, vocab_size)
    """Bigram forward pass: the next-token logits depend only on the last token.

    Args:
        model_params: dict holding 'token_table', a (vocab_size, vocab_size) matrix.
        model_config: config providing vocab_size.
        tokens: token ids, either a single sequence of shape (seq_len,) or a batch
            of sequences of shape (batch_size, seq_len).

    Returns:
        logits of shape (1, vocab_size) for a single sequence, or
        (batch_size, vocab_size) for a batch.
    """
    token_table = model_params['token_table']
    # `...` keeps any leading batch axis: a scalar id for a single sequence,
    # a (batch_size,) vector of ids for a batch
    last_token_id = jnp.asarray(tokens)[..., -1]
    logits = token_table[last_token_id]
    # -1 lets the batch size be inferred (1 for a single sequence)
    reshaped_logits = logits.reshape((-1, model_config.vocab_size))
    return reshaped_logits

logits = bigram_model(bigram_params, bigram_config, tokenizer.encode('hello'))
logits.shape # you can use this to debug!

### sampling

In [None]:
# now we take this raw logit output
# and implement a sample function
# which should take in the model output of shape (batch_size, vocab_size)
# and return values, of shape (batch_size,)

# we can think of the next token prediction as similar to a classification task
# suppose we had to pick between [red, green, blue] in some computer vision model
# the correct 'label' might be [1, 0, 0]
# a classification model usually compares the label to some probability distribution output
# as an example, our model output might be one of [0.9, 0.02, 0.08] or [0.4, 0.3, 0.3]
# while the maximum value for both is correct
# it's clear that one model potentially performs better at the task
# as its probability distribution is in some sense 'closer' to the true distribution

# but we don't have the probabilities yet
# our model, in the above example, might be outputting something like [5, 3, 2]
# how do we turn this into a probability?

In [None]:
# lets call that initial output f(x), with values f(x_i)
# we need to create a probability distribution p(x)
# but we don't want to lose any of the information from f(x)
# by the principle of maximum entropy
# the best representation of some state of knowledge is the probability distribution that maximizes entropy
# in this context, entropy is the expected value of 'surprise'

# the 'surprise' of some outcome can actually be mathematically defined
# to develop intuition about this, suppose you had 100 red balls and 2 blue balls in a box
# drawing a blue ball out is more 'surprising' than a red ball
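# so the surprise might be 1/p
# but then a certain outcome (p = 1) would still have a surprise of 1 instead of 0,
# and the surprise of two independent outcomes (probability p1 * p2) wouldn't be the sum of their surprises
# taking the log fixes both: log(1/1) = 0 and log(1/(p1*p2)) = log(1/p1) + log(1/p2)
# thus, the standard information-theoretic definition of surprise is log(1/p)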
# and entropy is E = sum (p(x) * log(1/p(x)))
# this can be simplified into E = -sum(p(x) * log(p(x))) (see if you can do it yourself)

# so we can set up a constrained optimization problem
# to figure out a function that can map f(x) into the best p(x)
# our objective function is maximizing E = -sum(p(x) * log(p(x)))
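# our constraints are that
# 1. sum(p(x)) = 1
# 2. p_i >= 0
# 3. the expected score sum(p(x_i) * f(x_i)) is fixed to some value
#    (this is how p keeps the information in f(x); without it, the maximum entropy answer is just the uniform distribution)

# using lagrange multipliers, we can solve the above to get
# p_i = e^{beta * f(x_i)}/sum_j(e^{beta * f(x_j)})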
# for writing the softmax, we'll remove beta
# but remember that we can scale the logits before doing the softmax
# this beta constant is frequently rewritten as 1/T, where T is temperature
# we call it that because of the boltzmann distribution in statistical mechanics
# I don't know enough about that to explain it, but: https://en.wikipedia.org/wiki/Boltzmann_distribution

"""
def softmax(x: jnp.array) -> jnp.array:
    return

x1 = jnp.array([1.2, 2, -4, 0.0])
x2 = jnp.array([1.2, 2000, -4000, 0.0])

print(softmax(x1))
print(softmax(x2))
"""

def softmax(x: jax.Array) -> jax.Array:
    """Naive softmax: exponentiate every entry, then divide by the total.

    Normalizes over all elements of x. There is no overflow protection: a large
    input makes jnp.exp return inf, which turns the output into nan (see x2 below).
    """
    unnormalized = jnp.exp(x)
    return unnormalized / unnormalized.sum()

x1 = jnp.array([1.2, 2, -4, 0.0])
x2 = jnp.array([1.2, 2000, -4000, 0.0])

print(softmax(x1))
print(softmax(x2))

In [None]:
# you might have noticed we have a bit of instability with this computation
# e^x where x is a really big number can easily overflow
# but e^x for x <= 0 will always be between 0 and 1
# so we want to see if we can shift all of our e^x such that x is, at most, 0
# suppose we subtracted some constant c from all of our x_i (from the initial f(x))
# mathematically, our function is invariant to any shifts (can you prove this?)
# so we can just subtract the largest value from every value
# to get a numerically stable softmax

"""
def stable_softmax(x: jnp.array) -> jnp.array:
    return

x1 = jnp.array([1.2, 2, -4, 0.0])
x2 = jnp.array([1.2, 2000, -4000, 0.0])

print(stable_softmax(x1))
print(stable_softmax(x2))

"""

def stable_softmax(x: jax.Array, axis=None) -> jax.Array:
    """Numerically stable softmax.

    Subtracting the max before exponentiating leaves the result unchanged
    (softmax is shift-invariant), but every exponent becomes <= 0, so exp() cannot overflow.

    Args:
        x: input logits.
        axis: axis (or axes) to normalize over. The default None normalizes over all
            elements, as in the 1-D examples here; pass axis=-1 to apply softmax
            independently to each row of a (batch_size, vocab_size) array.

    Returns:
        array with the same shape as x whose entries sum to 1 along `axis`.
    """
    # keepdims=True lets the reduced max/sum broadcast back against x for any axis choice
    shifted = x - jnp.max(x, axis=axis, keepdims=True)
    exp_x = jnp.exp(shifted)
    return exp_x / jnp.sum(exp_x, axis=axis, keepdims=True)

x1 = jnp.array([1.2, 2, -4, 0.0])
x2 = jnp.array([1.2, 2000, -4000, 0.0])

print(stable_softmax(x1))
print(stable_softmax(x2))

In [None]:
# to set up our generation
# we'll need to pass in our model parameters, config, forward pass
# to make selecting the forward pass easier, we'll set up a dictionary MODEL_DICT

# registry from a model name to its forward-pass function
MODEL_DICT = dict(bigram_model=bigram_model)

# our generate function should use
# model params/config/fwd pass
# prompt tokens
# max tokens to generate
# temperature (as previously discussed)
# and a jax key
# we'll be using jax.random.split as discussed in https://jax.readthedocs.io/en/latest/random-numbers.html#explicit-random-state

# while we haven't finished generating tokens
# get the logits for the current sequence
# get the last token's logits
# scale the logits with temperature
# sample with jax.random.categorical to get the next token (we can use shape (1,), not parallelizing generations)
# add to the generated tokens to return (you can use jnp.concatenate)
# add to the tokens we use for the prompt
# repeat

"""
def generate(params: Dict, model_config: ModelConfig, model_name: str, prompt_tokens: jax.Array, max_new: int, temp=1, key=jax.random.PRNGKey(0)):
    gen_tokens = jnp.array([], dtype=jnp.int32)
    cur_pos = 0

    while cur_pos < max_new:
        # may want to split your key on each generation
        key, subkey = ...
        logits = ...
        last_token_logit = ...
        scaled_logit = ...
        next_token = ...
        gen_tokens = ...
        tokens = ...
        cur_pos += 1

    return gen_tokens

prompt = 'hello?'

model_name = 'bigram_model'
tokenized_prompt = jnp.array(tokenizer.encode(prompt), dtype=jnp.int32)
generated_tokens = generate(bigram_params, bigram_config, model_name, tokens=tokenized_prompt, max_new=10, temp=0.8, key=jax.random.PRNGKey(0))
generated_text = tokenizer.decode(generated_tokens)
print(prompt + generated_text)
"""

def generate(params: Dict, model_config: ModelConfig, model_name: str, tokens: jax.Array, max_new: int, temp=1, key=jax.random.PRNGKey(0)):
    """Autoregressively sample `max_new` tokens that continue `tokens`.

    Args:
        params: model parameters, passed straight to the forward pass.
        model_config: model configuration, passed straight to the forward pass.
        model_name: key into MODEL_DICT selecting the forward-pass function.
        tokens: 1-D integer array of prompt token ids.
        max_new: number of tokens to generate.
        temp: sampling temperature; the logits are divided by it before sampling.
            temp == 0 selects greedy decoding (argmax) instead of dividing by zero.
        key: jax PRNG key; it is split once per generated token.

    Returns:
        1-D integer array of the generated token ids (the prompt is not included).
    """
    gen_tokens = jnp.array([], dtype=jnp.int32)

    for _ in range(max_new):
        key, subkey = jax.random.split(key, 2)
        logits = MODEL_DICT[model_name](params, model_config, tokens)
        # slicing with -1: keeps a leading axis of size 1, so the next token has shape (1,)
        last_token_logit = logits[-1:]
        if temp == 0:
            # greedy decoding: dividing by a zero temperature would produce inf/nan logits
            next_token = jnp.argmax(last_token_logit, axis=-1)
        else:
            scaled_logit = last_token_logit / temp
            next_token = jax.random.categorical(subkey, scaled_logit, shape=(1,))
        gen_tokens = jnp.concatenate((gen_tokens, next_token))
        # feed the new token back in so the next step conditions on it
        tokens = jnp.concatenate((tokens, next_token))

    return gen_tokens

prompt = 'hello?'

model_name = 'bigram_model'
tokenized_prompt = jnp.asarray(tokenizer.encode(prompt), dtype=jnp.int32)
generated_tokens = generate(
    bigram_params,
    bigram_config,
    model_name,
    tokens=tokenized_prompt,
    max_new=10,
    temp=0.8,
    key=jax.random.PRNGKey(0),
)
generated_text = tokenizer.decode(generated_tokens)
# the generated tokens exclude the prompt, so prepend it for display
print(prompt + generated_text)

### loss

In [None]:
# turns out, this model is pretty weak
# before we go about improving it, we need to measure how well/poorly its doing
# our model outputs logits, which get softmaxed into some distribution q
# for each training datapoint, we want to know how surprising the real distribution is
# for example, if our guess q has 0.9 for the token !  and 0.002 for ?, but the actual label is ?
# then we were pretty wrong
# turns out we can basically just use the entropy equation from before
# but rather than using
# E = -sum(p(x) * log(p(x)))
# we just replace one of the p(x) distributions with our model output distribution q(x)
# C = -sum(p(x) * log(q(x)))
# this is known as 'cross-entropy'
# note that the log goes on our q distribution: the true distribution p is one-hot, so log(p(x)) would be log(0) = -inf, while the softmax outputs q(x) are always positive

# cross_entropy_loss should take in logits, of shape (batch_size, vocab_size)
# the targets, should be (batch_size,)

# we know that we have vocab_size classes
# so we reshape the targets using one hot encoding, into (batch_size, vocab size) as well

# one hot encoding would take some id of 3, for a vocab size 5 into [0,0,0,1,0]

# after that, take the log softmax of the logits, since the cross-entropy needs log(q(x))
# and then the loss is just the equation we derived above: -sum(p(x) * log(q(x))), the expected surprise under our model

# we use axis=-1 because we want to sum over the vocab_size in the shape

# finally, we take the mean of this value across the batch size
# to output our final loss value (as a scalar)

"""
def cross_entropy_loss(logits: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:

    num_classes = ...
    onehot_targets = ...
    log_probs = ...
    loss = ...

    mean_loss = ...

    return mean_loss

text = 'hello worl'
tokenized_text = tokenizer.encode(text)
target = tokenizer.encode('d')

logits = bigram_model(bigram_params, bigram_config, tokenized_text)
cross_entropy_loss(logits, target)
"""

def cross_entropy_loss(logits: jnp.ndarray, targets: jnp.ndarray) -> jnp.ndarray:
    """Mean cross-entropy between predicted logits and integer target ids.

    Args:
        logits: unnormalized scores; the last axis indexes the vocabulary.
        targets: integer ids of the correct next tokens, one per row of logits.

    Returns:
        scalar: the average over rows of -sum_k onehot(target)_k * log_softmax(logits)_k.
    """
    log_q = jax.nn.log_softmax(logits)
    p = jax.nn.one_hot(targets, logits.shape[-1])
    per_example = -(p * log_q).sum(axis=-1)
    return per_example.mean()

text = 'hello worl'
# context token ids, and the id of the true next character
tokenized_text, target = tokenizer.encode(text), tokenizer.encode('d')

logits = bigram_model(bigram_params, bigram_config, tokenized_text)
cross_entropy_loss(logits, target)

### backpropagation and optimizers (grad/pytree intro)

In [None]:
# now that we can compute how bad our model is
# we should try to update its parameters

# suppose our model had 2 parameters, x and y
# and the loss function was some z axis
# we could plot the surface of the loss as a function of x and y
# and try to find the minimum value of it

# unfortunately, we won't know what the surface looks like at all points
# we can only compute it for some pair (x,y)
# so how can we find the minimum value?

# one way to motivate this is to imagine a blind man on the loss surface
# trying to get to the bottom
# at each point on the loss surface
# the blind man can feel out how steep the ground he stands on is
# and follow the slope

# if we take the gradient of the loss function with respect to our parameters
# we can basically do the same thing with our model
# for each of our parameters
# we update the parameter by taking a step 'downhill'
# typically, we take a step in the direction by multiplying the gradient by some 'learning rate'

parameter, grad, learning_rate = 2, 0.2, 0.01
# one gradient descent step: move against the gradient, scaled by the learning rate
new_parameter = parameter - grad * learning_rate
new_parameter

In [None]:
# as previously mentioned, jax is an autodiff library
# so we can automatically calculate the gradient with respect to some loss function
# basically jax applies a tracer object to all of its arguments
# the tracer records all the operations that happen to it
# and jax creates a Jaxpr out of it
# as an example, lets look at gelu, the activation function used in GPT-2
from jax import grad

def gelu(x):
    """GELU activation in its tanh-approximation form (the variant used by GPT-2)."""
    cubic_term = 0.044715 * jnp.power(x, 3)
    inner = jnp.sqrt(2 / jnp.pi) * (x + cubic_term)
    return 0.5 * x * (1 + jnp.tanh(inner))

jax.make_jaxpr(gelu)(3.0)

In [None]:
# when computing the gradient
# jax traverses this intermediate representation to calculate the derivative


slope, intercept = 2.0, -0.5

x = 1.5

pred = slope * x + intercept  # 2.0 * 1.5 - 0.5 = 2.5

true = 4.0

def mse(y_hat, y):
    """Squared error between a prediction and its target."""
    return (y_hat - y) ** 2

print(jax.make_jaxpr(mse)(pred, true)) # c = a - b, d = c**2

value, grad = jax.value_and_grad(mse)(pred, true) # value and grad gives us both the value and gradient
value, grad

In [None]:
# let's see what this looks like for a training step here

"""
text = 'hello worl'
tokenized_text = tokenizer.encode(text)
target = tokenizer.encode('d')

def model_loss(model_params, model_config, model_name, tokens, target):
    logits = ...
    loss = ...
    return loss

# use partial to set up the loss function, without passing in params yet
partial_loss_fn = partial(...)

# view jaxpr
jaxpr = jax.make_jaxpr(jax.value_and_grad(partial_loss_fn))(bigram_params)
print(jaxpr)

loss, grads = ...

print(f'loss: {loss}')
print(f'grad: {grad}')

lr = 1e-4
new_bigram_params = ...
"""

text = 'hello worl'
tokenized_text = tokenizer.encode(text)
target = tokenizer.encode('d')

def model_loss(model_params, model_config, model_name, tokens, target):
    """Run the forward pass registered under `model_name` and score it with cross_entropy_loss.

    Returns the scalar mean cross-entropy between the predicted logits and `target`.
    """
    forward_fn = MODEL_DICT[model_name]
    logits = forward_fn(model_params, model_config, tokens)
    return cross_entropy_loss(logits, target)

# use partial to set up the loss function, without passing in params yet
partial_loss_fn = partial(model_loss, model_config=bigram_config, model_name='bigram_model', tokens=tokenized_text, target=target)

# make jaxpr
jaxpr = jax.make_jaxpr(jax.value_and_grad(partial_loss_fn))(bigram_params)
print(jaxpr)

# grads has the same structure as bigram_params: {'token_table': <array of the same shape>}
loss, grads = jax.value_and_grad(partial_loss_fn)(bigram_params)

print(f'loss: {loss}')
# print the gradient computed just above (not the scalar `grad` left over from the mse example)
print(f'grad: {grads}')

lr = 1e-4
# one gradient descent step; keep the dict structure of bigram_params so the result is usable as params
new_bigram_params = {'token_table': bigram_params['token_table'] - lr * grads['token_table']}

In [None]:
# we want a scalable way to apply this new_param = param - lr * grad to all of our parameters
# in a way that maintains the structure that we passed our params in

# jax's pytrees are container-like structures
# that are built out of other python container structures (lists, tuples, dicts, etc)
# its leaves are anything that isn't a pytree, like a jax array

example_trees = [
    [1, {'k1': 2, 'k2': (3, 4)}, jnp.array([1, 3])],
    {'a': 2, 'b': (7, 2), 'c': jnp.array([4, 5, 6])},
    jnp.array([1, 2, 3]),
    [1, 'a', 17.],
    (1, (2, 3), None),
]

# TODO: turn this into a test

# containers (lists, tuples, dicts, None) are pytree nodes; everything else is a leaf
for pytree in example_trees:
    leaves = jax.tree.leaves(pytree)
    print(f"{repr(pytree)} has {len(leaves)} leaves: {leaves}")

In [None]:
# we can apply a function to all of a trees leaves with jax.tree.map
# while maintaining the original structure
# jax.tree.map(func, tree)

tree = [-17, 3, 8, [4, -6, 7, [1, 2, -3], 4]]
# square every leaf; the nested list structure of the result matches `tree`
jax.tree.map(lambda leaf: leaf ** 2, tree)

In [None]:
# use jax.tree.map to implement the update step of stochastic gradient descent
# assume that we already have the gradients

class OptConfig(NamedTuple):
    """Optimizer configuration: a learning rate plus the optimizer's init and update functions."""
    lr: float  # learning rate, e.g. 1e-2 (a float, not an int)
    opt_init: Callable  # (model_params, opt_config) -> initial optimizer state dict
    opt_update: Callable  # (model_params, grads, opt_state, opt_config) -> (new_params, new_opt_state)

# we'll keep track of our optimizer's state using a dictionary, similar to the model params dictionary
"""
class OptConfig(NamedTuple):
    lr: int
    opt_init: Callable
    opt_update: Callable

def init_sgd_state(model_params: Dict, opt_config: OptConfig) -> Dict:
    state = {'step': 0,
             'lr': }
    return state

def sgd_update(model_params: Dict, grads: Dict, opt_state: Dict, opt_config: OptConfig) -> Tuple[Dict, Dict]:
    new_opt_state = {
        'step': ,
    }
    new_params = ... # jax tree map and a lambda will come in handy here
    return new_params, new_opt_state
"""

class OptConfig(NamedTuple):
    """Optimizer configuration: a learning rate plus the optimizer's init and update functions."""
    lr: float
    opt_init: Callable
    opt_update: Callable

def init_sgd_state(model_params: Dict, opt_config: OptConfig) -> Dict:
    """Create the initial SGD state: a step counter at 0 and the configured learning rate.

    model_params is unused here (presumably kept so other optimizers can share this signature).
    """
    state = {'step': 0, 'lr': opt_config.lr}
    return state

def sgd_update(model_params: Dict, grads: Dict, opt_state: Dict, opt_config: OptConfig) -> Tuple[Dict, Dict]:
    """Apply one SGD step, param <- param - lr * grad, to every leaf of model_params.

    Args:
        model_params: pytree of parameters.
        grads: pytree of gradients with the same structure as model_params.
        opt_state: state from init_sgd_state or from a previous sgd_update.
        opt_config: supplies the learning rate.

    Returns:
        (new_params, new_opt_state): new_params has the structure of model_params, and
        new_opt_state keeps every key of opt_state with 'step' incremented by one.
    """
    # carry over the existing keys (e.g. 'lr') so the state structure is stable across steps
    new_opt_state = {
        **opt_state,
        'step': opt_state['step'] + 1,
    }
    # jax.tree.map walks both pytrees in lockstep (jax.tree_map is deprecated)
    new_params = jax.tree.map(lambda p, g: p - opt_config.lr * g, model_params, grads)
    return new_params, new_opt_state

In [None]:
# now we can write a training step

"""

sgd_config = ...
opt_state = ...

text = 'hello worl'
tokenized_text = tokenizer.encode(text)
target = tokenizer.encode('d')

loss, grads = ...

new_params, new_opt_state = ...
new_params

"""

sgd_config = OptConfig(lr=1e-2, opt_init=init_sgd_state, opt_update=sgd_update)
opt_state = sgd_config.opt_init(bigram_params, sgd_config)

text = 'hello worl'
tokenized_text = tokenizer.encode(text)
target = tokenizer.encode('d')

# value_and_grad differentiates w.r.t. the first argument: the parameter dict
loss, grads = jax.value_and_grad(model_loss)(
    bigram_params, bigram_config, 'bigram_model', tokenized_text, target
)

# dispatch through the config so that swapping optimizers only means changing sgd_config
new_params, new_opt_state = sgd_config.opt_update(bigram_params, grads, opt_state, sgd_config)
new_params

### optimizations for larger data (intro to jit/vmap)

In [None]:
# now that we've created a model, a loss, and an optimizer
# we'll need data to train the model, and update its parameters
# the tiny shakespeare dataset has all of shakespeare's works
# we'll download it from hugging face datasets library

In [None]:
pip install datasets

In [None]:
from datasets import load_dataset
# download tiny shakespeare through the Hugging Face `datasets` library (needs network access)
# NOTE(review): if this dataset relies on a loading script, newer `datasets` releases may refuse to
# load it — confirm against the installed version, and pin one in the install cell if needed
dataset = load_dataset('tiny_shakespeare')
dataset

In [None]:
# so the dataset has a train, validation, and test set
# there's just one row for each dataset
# fill out the following to load in the different datasets

"""
train_text = dataset['train']['text'][0]
val_text = ...
test_text = ...
len(train_text), len(val_text), len(test_text)
"""

# each split stores its whole text as a single row, so take row 0 of the 'text' column
train_text, val_text, test_text = (
    dataset[split]['text'][0] for split in ('train', 'validation', 'test')
)

len(train_text), len(val_text), len(test_text)

In [None]:
# for the amount of data we're about to work with
# we'll need to speed up our tokenizing functions
# if you remember when we discussed jax's just in time compilation, here's a great place to use it

# let's look at a quick example of how that works

In [None]:
# jit, just in time compiling, is one of the key reasons to use jax
# just in time compiling compiles a function during runtime, before it executes
# it's one of the key reasons to use JAX, providing speedups to commonly used functions
# suppose for example, the gelu activation function we defined earlier

# baseline: time the plain (un-jitted) gelu on a million-element array
x = jnp.arange(1000000)
# block_until_ready() makes the timing wait until the result has actually been computed
%timeit gelu(x).block_until_ready()

In [None]:
# this function will execute every time it's called with python
# but if we wrap it with jax.jit ...

gelu_jit = jax.jit(gelu)
# the first call traces and compiles gelu; running it once here keeps compilation out of the %timeit measurement
gelu_jit(x).block_until_ready()
%timeit gelu_jit(x).block_until_ready()

In [None]:
# we can see that it's a lot faster!
# if you noticed, the function took longer to start running
# but each run was almost an order of magnitude faster
# this is because jax compiles the function the first time we run it
# by using the tracer objects/method we saw in the grad section
# and optimizes it for the hardware it will run on, like a GPU/TPU
# this adds some overhead for the functions we jit compile
# but it allows every subsequent run to be significantly faster
# this is really great for neural networks, which can apply activation functions, or parameter updates thousands of times

In [None]:
# notice that side effects like print statements only run while jax traces the function, not every time the compiled function runs
def print_gelu(x):
    """gelu with a print side effect, used to show what happens to prints under tracing."""
    print('x: ', x)  # runs during tracing, when x is a tracer rather than a concrete value
    activated = gelu(x)
    return activated

# also notice that the x that gets printed is a traced object!
jax.make_jaxpr(print_gelu)(3.0)

# to get the desired printing behavior, use jax.debug.print() instead

In [None]:
# jax's jax arrays are specifically immutable objects, to comply with the tracing we saw above
# so jit and the other functions really want you to use jax arrays
# let's say that you wanted to create a jnp.sum() that allowed any inputs
# by casting it into an array inside the function

def permissive_sum(x):
    """Sum anything array-like by first converting it with jnp.array."""
    as_array = jnp.array(x)
    return as_array.sum()

x = list(range(10))
jax.make_jaxpr(permissive_sum)(x) # what do you expect this to output?

In [None]:
# a python list is a pytree, so jax flattens it and traces every element as its own separate input
# so make sure to explicitly have your inputs be jax arrays

x = [*range(10)]
# converting up front means jax traces a single array argument instead of ten scalars
jax.make_jaxpr(jnp.sum)(jnp.array(x)) # what will this output?

In [None]:
# lets see how JIT interacts with global variables
g = 10
def func(x):
    # g is read from the enclosing (global) scope rather than passed in as an argument
    return x + 5 + g

func_jit = jax.jit(func)
print(func_jit(10))  # first call: jax traces func (reading g == 10 at that moment) and compiles it
g = 20
print(func_jit(10)) # what will this output?
jax.make_jaxpr(func_jit)(10)

In [None]:
# because of what we talked about with the tracing
# JIT will compile the function with the operation that it observed the first time the function ran
# with g = 10
# and the next time we call it, it will not know that g = 20 now

In [None]:
# what about working with conditionals?

def relu(x):
    """Plain-python ReLU: the `if` needs a concrete value of x, so jit cannot trace it."""
    if x > 0:
        return x
    else:
        return 0
try:
    # wrap first, then call: the error is raised when jit traces relu on its first call
    relu_jit = jax.jit(relu)
    result = relu_jit(-1)
    print(result)
except jax.errors.ConcretizationTypeError as e:
    # jit raises this (or a subclass of it) when python control flow needs a traced value
    print("Error:", str(e))

# the python `if` needs the concrete value of x, but while jax traces the function x is an abstract tracer
# so jax can't compile this function as written

In [None]:
# we can get around this in a couple ways
# one is with static arguments
# this tells the compiler to treat the specified argument as a compile-time constant (it must be hashable)
# jax bakes its value into the compiled function and recompiles whenever it sees a new value
# so this is best used when you don't expect a lot of different values
# we also use jax.lax.cond, a control flow primitive that works with jit compilation

# also note that we can use the @ decorator to jit a function rather than explicitly passing it

@partial(jax.jit, static_argnames=['threshold'])
def conditional_sum(x, y, threshold):
    """Return x + y when x > threshold, otherwise x - y.

    `threshold` is static, so it is a plain python value at trace time; `x` is
    still traced, so the branch is chosen with jax.lax.cond rather than `if`.
    """
    def add_branch():
        return x + y

    def subtract_branch():
        return x - y

    is_above = x > threshold
    return jax.lax.cond(is_above, add_branch, subtract_branch)

x_val = jnp.array(10.0)
y_val = jnp.array(5.0)
threshold_val = 7.0

result = conditional_sum(x_val, y_val, threshold_val)  # first call: compiles for threshold=7.0
print(result)

x_val2 = jnp.array(6.0)
y_val2 = jnp.array(3.0)

result2 = conditional_sum(x_val2, y_val2, threshold_val)  # same threshold, same input shapes as before
print(result2)

threshold_val2 = 8.0

result3 = conditional_sum(x_val, y_val, threshold_val2) # triggers recompilation
print(result3)

In [None]:
# another performance transformation jax provides is vmap
# vmap automatically vectorizes a given function
# by defining a batch dimension to work across
# it basically does a [f(i) for i in x]

# let's look at a basic example
def add_and_square(x, y):
    """Return (x + y) squared, elementwise for arrays."""
    total = x + y
    return total ** 2

vectorized_add_and_square = jax.vmap(add_and_square)  # by default maps over axis 0 of both x and y

x = jnp.array([1, 2, 3])
y = jnp.array([4, 5, 6])

x1 = jnp.arange(10)
y1 = jnp.arange(10)

print("original jaxpr:")
print(jax.make_jaxpr(add_and_square)(x[0], y[0]))  # traced on one pair of scalars
print(add_and_square(x[0],y[0]).shape)

print("\nvmapped jaxpr:")
print(jax.make_jaxpr(vectorized_add_and_square)(x, y))  # same function, traced on length-3 arrays
print(vectorized_add_and_square(x,y).shape)
print(jax.make_jaxpr(vectorized_add_and_square)(x1, y1))  # ... and on length-10 arrays
print(vectorized_add_and_square(x1,y1).shape)

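# notice what happens in the jaxpr that's created:
# the same add/square operations now act on whole arrays instead of scalars
# by default vmap maps over axis 0 of every input (and stacks the outputs along axis 0)
# we can choose different axes ourselves with in_axes, out_axes
# but we won't go over that
# jit/vmap/grad are all function transformations (they take a function and return a new one)
# and they compose with each other
# you can vmap a jitted function, jit a vmapped function, etc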

In [None]:
# with jit and vmap, we are now ready to handle the tiny shakespeare dataset efficiently
# as you saw earlier, we should avoid jit compiling functions that don't take in jax arrays
# to turn text into a jnp array that will play well with jax
# we can just define a preprocess function
# make sure to explicitly pass in dtype so that we get integer values for each index

"""
def preprocess_text(text: str) -> jax.Array:
    return
"""

def preprocess_text(text: str):
    """Turn a string into an int32 jax array of unicode code points, one entry per character."""
    code_points = list(map(ord, text))
    return jnp.array(code_points, dtype=jnp.int32)

In [None]:
# now we can start by jit compiling the char_to_idx and idx_to_char functions
# this is pretty much as simple as adding the jax.jit decorator

"""
def _char_to_idx(char: jax.Array):
    return jnp.where((32 <= char) & (char <= 126), char - 32, 0)

def _idx_to_char(index: jax.Array):
    return jnp.where((0 <= index) & (index <= 94), index + 32, ord(' '))
"""

@jax.jit
def _char_to_idx(char: jax.Array):
    """Map printable-ASCII code points 32..126 to token ids 0..94; anything else maps to 0."""
    is_printable = (char >= 32) & (char <= 126)
    return jnp.where(is_printable, char - 32, 0)

@jax.jit
def _idx_to_char(index: jax.Array):
    """Inverse of the char mapping: ids 0..94 -> code points 32..126; anything else -> space."""
    in_vocab = (index >= 0) & (index <= 94)
    return jnp.where(in_vocab, index + 32, ord(' '))


In [None]:
# now, we can use vmap to wrap those idx/char conversion functions
# we should jit compile these functions as well

"""
def _tokenizer_encode_jax(char_array: jax.Array):
    return # vmap

def _tokenizer_decode_jax(ids: jax.Array):
    return # vmap
"""

@jax.jit
def _tokenizer_encode_jax(char_array: jax.Array):
    """Vectorized encode: apply _char_to_idx to every entry along axis 0 of char_array."""
    encode_each = jax.vmap(_char_to_idx)
    return encode_each(char_array)

@jax.jit
def _tokenizer_decode_jax(ids: jax.Array):
    """Vectorized decode: apply _idx_to_char to every entry along axis 0 of ids."""
    decode_each = jax.vmap(_idx_to_char)
    return decode_each(ids)

In [None]:
# now, we can write our final encode/decode functions

def jit_encode(char_array):
    """Encode an int32 array of code points (as produced by preprocess_text) into token ids."""
    return _tokenizer_encode_jax(char_array)

def jit_decode(ids):
    """Decode a 1-D array of token ids back into a python string.

    The id -> code point mapping runs on device; the result is then copied to the
    host in a single `tolist()` transfer instead of one device-to-host conversion
    per character.
    """
    code_points = _tokenizer_decode_jax(ids).tolist()
    return ''.join(map(chr, code_points))


In [None]:
# wrap the jax-backed encode/decode functions in a Tokenizer
jit_tokenizer = Tokenizer(encode=jit_encode, decode=jit_decode)

In [None]:
train_array = preprocess_text(train_text)  # str -> int32 code points
train_tokens = jit_tokenizer.encode(train_array)  # code points -> token ids
# sanity check: raw text, code points, token ids and the decoded round trip should line up
train_text[:20], train_array[:20], train_tokens[:20], jit_tokenizer.decode(train_tokens[:20])

### vmap intro (batched forward)

In [None]:
# round trip: text -> code points -> token ids -> text
jit_tokenizer.decode(jit_tokenizer.encode(preprocess_text('hello')))

In [None]:
# another optimization to make sure we can process this dataset in reasonable time
# is using batches of training data at a time
# we'll grab some batch_size amount of sequences at once
# and process them in parallel with jax's vmap
# as a reminder, vmap makes it fast/easy to run
# [f(i) for i in x]
# across some batch dimension (like batch_size!)

# we'll create a fake batch here to illustrate the concept

tokens_1 = jit_tokenizer.encode(preprocess_text('hello'))  # First sequence
tokens_2 = jit_tokenizer.encode(preprocess_text('world'))  # Second sequence
tokens_3 = jit_tokenizer.encode(preprocess_text('test!'))   # Third sequence

# Stack the sequences into a single batch with jnp.stack
# (all three strings are 5 characters long, so they stack cleanly into shape (3, 5))
tokens_batch = jnp.stack([tokens_1, tokens_2, tokens_3])
print(tokens_batch.shape)

# we'll write a batch forward for our model
# use jax.vmap and lambda x: to forward the specified model over the input batch

"""

def batch_forward(params: Dict, model_config: Dict, model_name: str, input_batch: jnp.ndarray) -> jnp.ndarray:
    return ...

logits = batch_forward(bigram_params, bigram_config, 'bigram_model', tokens_batch)
logits.shape

"""
def batch_forward(params: Dict, model_config: Dict, model_name: str, input_batch: jnp.ndarray) -> jnp.ndarray:
    """Run the model registered as MODEL_DICT[model_name] on every sequence in input_batch.

    jax.vmap maps the single-sequence forward pass over axis 0 (the batch axis),
    so the model function itself only ever sees one sequence at a time.
    """
    def forward_single(sequence):
        model_fn = MODEL_DICT[model_name]
        return model_fn(params, model_config, sequence)

    batched_forward = jax.vmap(forward_single)
    return batched_forward(input_batch)

# run the bigram model over the fake batch from above
logits = batch_forward(bigram_params, bigram_config, 'bigram_model', tokens_batch)
logits.shape

In [None]:
# now we can evaluate our loss over these batches
# generate a batch_size number of integer start indices
# use those and jax.lax.dynamic_slice to get sequences
# return the input/output batches
# remember that [1, 2, 3, 4, 5]
# would have input batch [1, 2, 3, 4]
# and target batch [2, 3, 4, 5]

"""
def create_random_batches(tokens: jax.Array, batch_size: int, seq_len: int, key) -> Tuple[jax.Array, jax.Array]:
    total_tokens = ...
    max_start_idx = ...
    start_indices = jax.random.randint(key, shape=..., minval = ..., maxval = max_start_idx)

    def get_sequence(start_idx):
        return jax.lax.dynamic_slice(...)

    sequences = ... # use vmap
    input_batches, target_batches = ... # remember targets should be offset by 1 from input but same size
    return input_batches, target_batches


input_batches, target_batches = create_random_batches(tokens=train_tokens, batch_size=4, seq_len=8, key=jax.random.PRNGKey(0))
logit_batch = batch_forward(bigram_params, bigram_config, 'bigram_model', input_batches)
cross_entropy_loss(logit_batch, target_batches)
"""

def create_random_batches(tokens: jax.Array, batch_size: int, seq_len: int, key) -> Tuple[jax.Array, jax.Array]:
    """Sample `batch_size` random windows from `tokens` as (input, target) pairs.

    Each window covers seq_len + 1 consecutive tokens. The input is its first
    seq_len tokens and the target is the same window shifted left by one, so
    target[:, i] is the token that follows input[:, i].

    Args:
        tokens: 1-D array of token ids.
        batch_size: number of windows to sample.
        seq_len: length of each input (and target) sequence.
        key: jax PRNG key used to pick the window start positions.

    Returns:
        (input_batches, target_batches), each of shape (batch_size, seq_len).

    Raises:
        ValueError: if tokens has fewer than seq_len + 1 entries, so no full window fits.
    """
    num_tokens = tokens.shape[0]  # static, so this check also works inside jit
    window = seq_len + 1
    if num_tokens < window:
        raise ValueError(
            f"need at least seq_len + 1 = {window} tokens to build a batch, got {num_tokens}"
        )

    # valid starts are 0 .. num_tokens - window (inclusive); randint's maxval is exclusive,
    # so +1 keeps the final window (the one ending on the last token) reachable
    start_indices = jax.random.randint(
        key,
        shape=(batch_size,),
        minval=0,
        maxval=num_tokens - window + 1,
    )

    def get_sequence(start_idx):
        return jax.lax.dynamic_slice(tokens, (start_idx,), (window,))

    sequences = jax.vmap(get_sequence)(start_indices)  # (batch_size, seq_len + 1)

    input_batches, target_batches = sequences[:, :-1], sequences[:, 1:]
    return input_batches, target_batches



# sample 4 sequences of 8 tokens and compute the bigram model's loss on them
input_batches, target_batches = create_random_batches(tokens=train_tokens, batch_size=4, seq_len=8, key=jax.random.PRNGKey(0))
logit_batch = batch_forward(bigram_params, bigram_config, 'bigram_model', input_batches)
cross_entropy_loss(logit_batch, target_batches)

In [None]:
# finally, we can also parallelize our model over an input
# here we'll use jax vmap
# as a reminder, it works similarly to [f(i) for i in x] over some dimension

# as a start
# rather than predicting separately for each position in the seq_len
# we can do our predictions in parallel across the seq_len dimension
# so instead of outputting (batch_size, 1, vocab_size)
# we output (batch_size, seq_len, vocab_size)
# we'll just have to make sure to select the last logit for generation

"""
def parallel_bigram(model_params: Dict, model_config: Dict, tokens: jax.Array) -> jax.Array:
    logits = jax.vmap(lambda x: ...)(tokens)
    return logits

MODEL_DICT = {
    'bigram_model': bigram_model,
    'parallel_bigram': parallel_bigram
}

logits = parallel_bigram(bigram_params, bigram_config, input_batches)
logits.shape

"""

def parallel_bigram(model_params: Dict, model_config: Dict, tokens: jax.Array) -> jax.Array:
    """Bigram logits for every position at once.

    jax.vmap maps a table lookup over axis 0 of `tokens`: each token id selects
    its row of model_params['token_table'], i.e. its logits for the next token.
    """
    def lookup_rows(token_ids):
        return model_params['token_table'][token_ids]

    logits = jax.vmap(lookup_rows)(tokens)
    return logits

# register the new model so batch_forward/train can look it up by name
MODEL_DICT = {
    'bigram_model': bigram_model,
    'parallel_bigram': parallel_bigram
}

# input_batches is 2-D (batch_size, seq_len) here, so vmap maps over the batch axis
logits = parallel_bigram(bigram_params, bigram_config, input_batches)
logits.shape

### training

In [None]:
# we now have everything we need to train our model!
# typically in training deep learning models
# we work through our data in epochs
# traditionally, this is a full pass through of our dataset
# but it can also just mean a fixed amount of batches
# in each batch, we compute the loss over the batch, and do an update step
# finally, we return the trained parameters

# let's see what that looks like for our language model
# we'll start by creating the TrainConfig struct
# we'll define the number of epochs
# the number of batches per epoch
# the number of sequences in each batch
# the sequence length of each batch
# and a seed for our key

class TrainConfig(NamedTuple):
    """Training-loop hyperparameters.

    Passed to the jitted train_step as a static argument, so it must stay hashable.
    """
    num_epochs: int         # number of epochs to run
    batches_per_epoch: int  # batches sampled (and update steps taken) per epoch
    batch_size: int         # sequences per batch
    batch_seq_len: int      # tokens per input sequence
    seed: int               # seed for the PRNG key that drives batch sampling

In [None]:
# now we can write one training step
# we'll use jax.jit to compile this so training goes faster
# use the batch_forward we defined above to efficiently compute logits for an input batch
# define a batch_loss_function to compute the loss against a target batch (nested inside, so it also gets compiled)
# use create random batches from before as well

# our training step should return the new parameters, optimizer state, and the loss for the step

"""

@partial(jax.jit, static_argnames=[])
def train_step(params: Dict, model_config: ModelConfig, model_name: str, tokens: jax.Array, opt_state: Dict, opt_config: OptConfig, train_config: TrainConfig, key):
    def batch_loss_fn(params, input_batch, target_batch):
        logits = ...
        batch_loss = ...
        return ...

    batch_size = ...
    seq_len = ...
    input_batch, target_batch = ...
    loss, grads = ...
    new_params, new_opt_state = ...
    return new_params, new_opt_state, loss

"""

@partial(jax.jit, static_argnames=['model_config', 'model_name', 'opt_config', 'train_config', ])
def train_step(params: Dict, model_config: ModelConfig, model_name: str, tokens: jax.Array, opt_state: Dict, opt_config: OptConfig, train_config: TrainConfig, key):
    """One optimization step on a freshly sampled random batch.

    Samples (input, target) windows from `tokens` with `key`, computes the
    cross-entropy loss and its gradient with respect to `params`, and applies
    `opt_config.opt_update`.

    The configs and model name are static arguments (they must be hashable);
    a new value for any of them triggers a recompilation.

    Returns:
        (new_params, new_opt_state, loss)
    """
    input_batch, target_batch = create_random_batches(
        tokens, train_config.batch_size, train_config.batch_seq_len, key
    )

    def loss_of(current_params, inputs, targets):
        predictions = batch_forward(current_params, model_config, model_name, inputs)
        return cross_entropy_loss(predictions, targets)

    loss, grads = jax.value_and_grad(loss_of)(params, input_batch, target_batch)
    new_params, new_opt_state = opt_config.opt_update(params, grads, opt_state, opt_config)
    return new_params, new_opt_state, loss

In [None]:
# now that we have our training step defined
# we can write our training loop
# this is also the step where we can use an accelerator
# we'll use jax.device_put() to put our parameters, tokens, and optimizer state onto the device
# and we can get our device from jax.devices()

# for each epoch
# we call the train step for each of the batches in the epoch
# add up the losses by epoch
# and then we'll average the loss over each epoch
# finally we will return the trained parameters

In [None]:
"""
def train(params: Dict, model_config: ModelConfig, model_name: str, tokens: jnp.ndarray, opt_config: OptConfig, train_config: TrainConfig, key=jax.random.PRNGKey(0)):
    if model_name not in MODEL_DICT:
        raise ValueError(f"Unknown model: {model_name}")

    device =  # jax.devices() returns list of at least size 1
    print(f"using device: {device}")

    # initialize optimizer state

    init_opt_state = ...

    # use seed from train config to define initial key
    key = ...

    # move data to device
    params = ...
    tokens = ...

    opt_state = ...

    print('training ...')

    # you may need to change these

    for each epoch:
        epoch_loss = 0.0
        for each batch:
            key, subkey = jax.random.split(key, 2)
            new_params, opt_state, loss = train_step()
            # add to epoch loss
            params = new_params
        avg_loss = ...
        print(f"epoch {}/{}, average Loss: {avg_loss}")

    return params


train_config = TrainConfig(num_epochs=200, batches_per_epoch=128, batch_size=8, batch_seq_len=16, seed=0)
opt_config = OptConfig(lr=1e-2, opt_init=init_sgd_state, opt_update=sgd_update)

trained_bigram_params = train(bigram_params, bigram_config, 'bigram_model', train_tokens, sgd_config, train_config)
"""

def train(params: Dict, model_config: ModelConfig, model_name: str, tokens: jnp.ndarray, opt_config: OptConfig, train_config: TrainConfig, key=None):
    """Train `params` for train_config.num_epochs epochs and return the result.

    Each epoch runs train_config.batches_per_epoch calls to train_step, each on
    a freshly sampled random batch, and prints that epoch's average loss.

    Args:
        params: model parameter pytree.
        model_config: model hyperparameters passed through to the model function.
        model_name: key into MODEL_DICT selecting the forward function.
        tokens: 1-D array of training token ids.
        opt_config: optimizer settings; opt_init builds the initial optimizer
            state and opt_update is applied by train_step.
        train_config: epoch/batch settings and the default PRNG seed.
        key: optional PRNG key for batch sampling. When omitted, a key is built
            from train_config.seed, so a run is reproducible from its config alone.

    Returns:
        The trained parameter pytree.

    Raises:
        ValueError: if model_name is not registered in MODEL_DICT.
    """
    if model_name not in MODEL_DICT:
        raise ValueError(f"Unknown model: {model_name}")

    device = jax.devices()[0]
    print(f"Using device: {device}")

    # honor an explicitly passed key; otherwise derive one from the config's seed
    if key is None:
        key = jax.random.PRNGKey(train_config.seed)
    init_opt_state = opt_config.opt_init(params, opt_config)

    # Move data to device
    params = jax.device_put(params, device)
    tokens = jax.device_put(tokens, device)
    opt_state = jax.device_put(init_opt_state, device)

    print('training ...')
    for epoch in range(train_config.num_epochs):
        epoch_loss = 0.0
        for _ in range(train_config.batches_per_epoch):
            key, subkey = jax.random.split(key, 2)
            params, opt_state, loss = train_step(
                params, model_config, model_name, tokens, opt_state, opt_config, train_config, key=subkey
            )
            epoch_loss += loss
        avg_loss = epoch_loss / train_config.batches_per_epoch
        print(f"Epoch {epoch + 1}/{train_config.num_epochs}, Average Loss: {avg_loss}")

    return params

train_config = TrainConfig(num_epochs=200, batches_per_epoch=128, batch_size=8, batch_seq_len=16, seed=0)
opt_config = OptConfig(lr=1e-2, opt_init=init_sgd_state, opt_update=sgd_update)

# train with the plain-SGD optimizer config defined on the line above
trained_bigram_params = train(bigram_params, bigram_config, 'bigram_model', train_tokens, opt_config, train_config)

In [None]:
# lets compare with our initial generation!

"""
def compare_params(prompt: str, model_config: ModelConfig, model_name: str, params1: Dict, params2: Dict, max_new=10, temp=0.8, key=jax.random.PRNGKey(0)):

    tokenized_prompt = ... # make sure to cast to integer
    generated_tokens = # generate with params1
    generated_tokens2 = # params2

    generated_text = ... # decode
    generated_text2 = ...

    print(jit_tokenizer.decode(tokenized_prompt))
    print(prompt + generated_text)
    print(prompt + generated_text2)

    return generated_text, generated_text2

compare_params('shakespeare', bigram_config, 'bigram_model', bigram_params, trained_bigram_params)
"""

def compare_params(prompt: str, model_config: ModelConfig, model_name: str, params1: Dict, params2: Dict, max_new: int = 10, temp=1, key=None):
    """Print what two parameter sets generate from the same prompt.

    Both generations receive the same PRNG key, so (assuming generate is
    deterministic given its key) differences come from the parameters rather
    than from drawing different random samples.

    Args:
        prompt: text to condition on.
        model_config: model hyperparameters passed to generate.
        model_name: which registered model to run.
        params1: first parameter set (e.g. before training).
        params2: second parameter set (e.g. after training).
        max_new: number of new tokens to generate for each parameter set.
        temp: sampling temperature passed to generate.
        key: PRNG key shared by both generations; defaults to jax.random.PRNGKey(0).

    Prints the round-tripped prompt, then prompt + each generation.
    """
    if key is None:
        key = jax.random.PRNGKey(0)

    tokenized_prompt = jnp.array(jit_tokenizer.encode(preprocess_text(prompt)), dtype=jnp.int32)
    generated_tokens = generate(params1, model_config, model_name, tokens=tokenized_prompt, max_new=max_new, temp=temp, key=key)
    generated_tokens2 = generate(params2, model_config, model_name, tokens=tokenized_prompt, max_new=max_new, temp=temp, key=key)

    generated_text = jit_tokenizer.decode(generated_tokens)
    generated_text2 = jit_tokenizer.decode(generated_tokens2)
    print(jit_tokenizer.decode(tokenized_prompt))
    print(prompt + generated_text)
    print(prompt + generated_text2)

# use the first 21 tokens of the training text as the prompt
prompt = jit_tokenizer.decode(train_tokens[:21])
compare_params(prompt, bigram_config, 'bigram_model', bigram_params, trained_bigram_params)

### token embeddings

In [None]:
# so it seems like bigram models kind of suck
# our loss is going down, but the generations are horrendous

# by themselves, our characters dont really mean anything
# for example, if e is 5 and j is 10
# e * 2 isn't really something that should make sense to do
# but we've implied that it is the case

# also, one integer is not a lot of information
# instead, we might want to represent each character as a vector
# we can do this by having our model learn embedding vectors
# the size of these vectors is the embedding dimension

# for example, 'a' might start out as [?,?]
# and end up as [0,1]
# where the first dimension is about being a consonant
# and the second is about if you can say the letter without moving your lips
# obviously thats a toy example
# but the idea is that the more dimensions we add
# the more information the model can encode about a token as it learns

# we'll need two matrices: one to turn the initial vocab into embedding vectors, and vice versa
# we call these token_embedding and output_projection
# and we'll need to update our model config struct to include an embedding dimension
# we'll also introduce a scaling factor
# to initialize our model with weights drawn from N(0,sigma) instead of N(0,1)

class ModelConfig(NamedTuple):
    """Model hyperparameters, now including the embedding dimension."""
    vocab_size: int     # number of distinct tokens (rows of token_embedding)
    embedding_dim: int  # length of each token's embedding vector

"""
def init_bigram_embed_params(key: jax.random.PRNGKey, model_config: ModelConfig, scaling_factor = 0.02) -> Dict:
    k1, k2 = jax.random.split(key)
    return {
        'token_embedding': ,
        'output_projection': ,
    }


# remember: the lookup table means the ith row vector is the ith token in our vocab

key = jax.random.PRNGKey(0)
b_embed_config = ModelConfig(vocab_size=95, embedding_dim = 5)
b_embed_params = init_bigram_embed_params(key, b_embed_config)

idx = tokenizer.encode('a')
input_embedding = b_embed_params['token_embedding'][idx]
print(input_embedding)

output_projection = b_embed_params['output_projection']
logits = jnp.dot(input_embedding, output_projection)
print(logits.shape)
"""

def init_bigram_embed_params(key: jax.random.PRNGKey, model_config: ModelConfig, scaling_factor = 0.02) -> Dict:
    """Randomly initialize the embedding bigram model.

    Returns a dict with
      'token_embedding':   (vocab_size, embedding_dim) lookup table; row i is token i's vector
      'output_projection': (embedding_dim, vocab_size) map from embeddings back to vocab logits
    Entries are standard-normal samples multiplied by `scaling_factor`.
    """
    embed_key, proj_key = jax.random.split(key)
    vocab, dim = model_config.vocab_size, model_config.embedding_dim
    token_embedding = jax.random.normal(embed_key, (vocab, dim)) * scaling_factor
    output_projection = jax.random.normal(proj_key, (dim, vocab)) * scaling_factor
    return {'token_embedding': token_embedding, 'output_projection': output_projection}

# remember: the lookup table means the ith row vector is the ith token in our vocab

key = jax.random.PRNGKey(0)
b_embed_config = ModelConfig(vocab_size=95, embedding_dim = 5)  # 95 = printable ascii 32..126
b_embed_params = init_bigram_embed_params(key, b_embed_config)

idx = tokenizer.encode('a')
input_embedding = b_embed_params['token_embedding'][idx]  # look up row idx of the embedding table
print(input_embedding)

output_projection = b_embed_params['output_projection']
logits = jnp.dot(input_embedding, output_projection)  # embedding_dim -> vocab_size
print(logits.shape)

In [None]:
# to forward pass through this model
# you need to get the corresponding token embeddings for each token in (seq_len) from your token embedding table
# then use jnp.dot to project the embedded tokens back into the vocab size dimension with the output projection matrix

"""
def bigram_embed(model_params: Dict, model_config: ModelConfig, tokens: jax.Array) -> jax.Array:
    token_embedding = model_params['token_embedding']
    output_projection = model_params['output_projection']

    embedded = ... # (seq_len,) -> (seq_len, embedding_dim)

    logits = ... # (seq_len, embedding_dim) @ (embedding_dim, vocab_size) -> (seq_len, vocab_size)

    return logits

bigram_embed(b_embed_params, b_embed_config, tokenizer.encode('abc')).shape

# now we just update model dict and train again
MODEL_DICT = {
    'bigram_model': bigram_model,
    'bigram_embed': bigram_embed
}

train_config = TrainConfig(num_epochs=50, batches_per_epoch=64, batch_size=16, batch_seq_len=32, seed=0)
trained_b_embed_params = train(b_embed_params, b_embed_config, 'bigram_embed', train_tokens, sgd_config, train_config)
compare_params('shakespeare', b_embed_config, 'bigram_embed', b_embed_params, trained_b_embed_params)
"""

def bigram_embed(model_params: Dict, model_config: ModelConfig, tokens: jax.Array) -> jax.Array:
    """Embedding bigram model: each position's logits depend only on its own token.

    (seq_len,) token ids -> (seq_len, embedding_dim) embeddings -> (seq_len, vocab_size) logits.
    """
    table = model_params['token_embedding']
    projection = model_params['output_projection']
    return jnp.dot(table[tokens], projection)

# print the shape: a bare expression here would not be displayed, since it isn't the cell's last line
print(bigram_embed(b_embed_params, b_embed_config, tokenizer.encode('abc')).shape)

# now we just update model dict and train again
MODEL_DICT = {
    'bigram_model': bigram_model,
    'bigram_embed': bigram_embed
}

train_config = TrainConfig(num_epochs=50, batches_per_epoch=64, batch_size=16, batch_seq_len=32, seed=0)
trained_b_embed_params = train(b_embed_params, b_embed_config, 'bigram_embed', train_tokens, sgd_config, train_config)
prompt = jit_tokenizer.decode(train_tokens[:21])
compare_params(prompt, b_embed_config, 'bigram_embed', b_embed_params, trained_b_embed_params)

### new optimization strategy

In [None]:
# you'll notice the loss went down but not by a lot
# and the sequence generated was basically the same
# another possibility is that our optimization algorithm needs work
# stochastic gradient descent is nice, but it depends only on the current gradient of the hill
# you can imagine the parameters walking down the loss surface, specifically taking a step based on the slope
# as you can see, this is kind of slow
# and it would easily get stuck in a small valley
# one way to speed up learning and roll through small valleys is with momentum
# rather than having a person taking steps on the loss surface
# picture a ball that rolls down the loss surface
# the ball carries momentum from its previous steps into the next step it takes
# letting it roll through small valleys that are not the true bottom of the loss surface
# this will hopefully speed up the training process

# SGD momentum update rule
# velocity = (old_velocity * momentum) + gradient
# updated_param = param - (learning_rate * velocity)

In [None]:
# while SGD with momentum works well in some problems
# it was noticed empirically that gradient magnitudes can differ a lot from one parameter to another
# and a single fixed, global learning rate meant parameters couldn't take appropriately sized steps independently
# RMSProp is an optimizer that keeps a moving average of each parameter's squared gradient
# and divides the gradient by the square root of that average, so every parameter gets its own effective step size

# squared_gradient_average = 0
# decay_rate = 0.9  # typically 0.9
# epsilon = 1e-8  # avoids div by 0

# # for each parameter update
# squared_gradient_average = (decay_rate * squared_gradient_average) + ((1 - decay_rate) * gradient ** 2)

# # RMSprop update rule
# adjusted_gradient = gradient / (sqrt(squared_gradient_average) + epsilon)
# updated_param = param - (learning_rate * adjusted_gradient)

In [None]:
# the adam optimizer combines ideas from rmsprop and sgd with momentum
# similarly to how RMSProp takes the moving average of the squared gradient to update
# Adam also calculates a moving average of the gradient to adjust parameters
# at the start, it initializes averages to 0
# but this effectively implies every previous measurement was a 0
# leading to very small estimations
# so we introduce a bias correction mechanism that depends on the time step

"""
# initialize
gradient_mean = 0  # average of gradient
gradient_sq_mean = 0  # average of squared gradient
beta_1 = 0.9  # decay rate for first moment
beta_2 = 0.999  # decay rate for second moment
epsilon = 1e-8  # to avoid division by zero
learning_rate = 0.001
weight_decay = 0.01
t = 0  # timestep

# update for each parameter
t += 1

# update biased first moment estimate
gradient_mean = (beta_1 * gradient_mean) + ((1 - beta_1) * gradient)

# update biased second raw moment estimate
gradient_sq_mean = (beta_2 * gradient_sq_mean) + ((1 - beta_2) * gradient ** 2)

# compute bias-corrected first moment estimate
gradient_mean_corrected = gradient_mean / (1 - beta_1 ** t)

# compute bias-corrected second raw moment estimate
gradient_sq_mean_corrected = gradient_sq_mean / (1 - beta_2 ** t)

# compute final AdamW parameter update
adjusted_gradient = gradient_mean_corrected / (sqrt(gradient_sq_mean_corrected) + epsilon)
param_update = learning_rate * (adjusted_gradient + weight_decay * param) # weight decay adjustment

# apply update
updated_param = param - param_update
"""

class OptConfig(NamedTuple):
    """Optimizer settings, now with the AdamW hyperparameters.

    Passed to the jitted train_step as a static argument, so every field must be hashable.
    """
    lr: float             # learning rate
    beta1: float          # decay rate of the running gradient mean (first moment)
    beta2: float          # decay rate of the running squared-gradient mean (second moment)
    weight_decay: float   # weight decay coefficient; adamw_update scales it by lr
    eps: float            # added to sqrt(second moment) to avoid division by zero
    opt_init: Callable    # (params, opt_config) -> initial optimizer state
    opt_update: Callable  # (params, grads, opt_state, opt_config) -> (new_params, new_opt_state)

class OptimizerState(NamedTuple):
    """Running state carried between adamw_update calls."""
    step: int                   # updates applied so far; may be a jax int array after going through the jitted train_step
    gradient_mean: Any          # running mean of gradients; pytree with the same structure as params
    gradient_squared_mean: Any  # running mean of squared gradients; pytree with the same structure as params

In [None]:
"""
def init_adam_state(params: Dict, opt_config: OptConfig) -> OptimizerState:
    gradient_mean = ...
    gradient_squared_mean = ...
    return OptimizerState(
        gradient_mean=gradient_mean,
        gradient_squared_mean=gradient_squared_mean,
        step=0
    )

def adamw_update(params: Dict, grads: Dict, opt_state: OptimizerState, opt_config: OptConfig) -> Tuple[Dict, Dict]:

    current_step = opt_state.step + 1

    gradient_mean_biased = ...

    gradient_squared_mean_biased = ...

    gradient_mean_corrected = ...

    gradient_squared_mean_corrected = ...

    param_updates = ...

    updated_params = ...

    updated_opt_state = OptimizerState(
        gradient_mean=gradient_mean_biased,
        gradient_squared_mean=gradient_squared_mean_biased,
        step=current_step
    )

    return updated_params, updated_opt_state

adamw_config = OptConfig(lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01, opt_init=init_adam_state, opt_update=adamw_update)
"""

def init_adam_state(params: Dict, opt_config: OptConfig) -> OptimizerState:
    """Zero-initialized AdamW state for `params`.

    Both running means start as zero pytrees with the same structure, shapes and
    dtypes as `params`, and `step` starts at 0. `opt_config` is unused here but
    kept so every optimizer shares the opt_init(params, opt_config) call signature.
    """
    # jax.tree.map instead of the deprecated jax.tree_map alias (matches adamw_update)
    return OptimizerState(
        gradient_mean=jax.tree.map(jnp.zeros_like, params),
        gradient_squared_mean=jax.tree.map(jnp.zeros_like, params),
        step=0
    )

def adamw_update(params: Dict, grads: Dict, opt_state: OptimizerState, opt_config: OptConfig) -> Tuple[Dict, Dict]:
    """One AdamW step over a pytree of parameters.

    For every leaf (parameter p, gradient g, running means m and v), with t = step + 1:
        m     <- beta1 * m + (1 - beta1) * g
        v     <- beta2 * v + (1 - beta2) * g**2
        m_hat  = m / (1 - beta1**t),   v_hat = v / (1 - beta2**t)     (bias correction)
        p     <- p - lr * (m_hat / (sqrt(v_hat) + eps) + weight_decay * p)

    Returns (updated_params, new OptimizerState). The state keeps the uncorrected
    running means so the next step can continue accumulating them.
    """
    step = opt_state.step + 1
    beta1, beta2 = opt_config.beta1, opt_config.beta2

    # bias-correction denominators are the same for every leaf
    first_correction = 1 - beta1**step
    second_correction = 1 - beta2**step

    def next_first_moment(prev_mean, gradient):
        return beta1 * prev_mean + (1 - beta1) * gradient

    def next_second_moment(prev_squared_mean, gradient):
        return beta2 * prev_squared_mean + (1 - beta2) * gradient**2

    new_mean = jax.tree.map(next_first_moment, opt_state.gradient_mean, grads)
    new_squared_mean = jax.tree.map(next_second_moment, opt_state.gradient_squared_mean, grads)

    def apply_step(param, mean, squared_mean):
        mean_hat = mean / first_correction
        squared_mean_hat = squared_mean / second_correction
        delta = -opt_config.lr * (mean_hat / (jnp.sqrt(squared_mean_hat) + opt_config.eps) + opt_config.weight_decay * param)
        return param + delta

    new_params = jax.tree.map(apply_step, params, new_mean, new_squared_mean)

    new_state = OptimizerState(
        gradient_mean=new_mean,
        gradient_squared_mean=new_squared_mean,
        step=step
    )
    return new_params, new_state

# AdamW config; the values match the hyperparameters in the pseudocode above
adamw_config = OptConfig(lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01, opt_init=init_adam_state, opt_update=adamw_update)


In [None]:
# now we just update model dict and train again
MODEL_DICT = {
    'bigram_model': bigram_model,
    'bigram_embed': bigram_embed
}

train_config = TrainConfig(num_epochs=50, batches_per_epoch=32, batch_size=32, batch_seq_len=64, seed=0)
adamw_config = OptConfig(lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01, opt_init=init_adam_state, opt_update=adamw_update)
trained_b_embed_params = train(b_embed_params, b_embed_config, 'bigram_embed', train_tokens, adamw_config, train_config)
prompt = jit_tokenizer.decode(train_tokens[:21])
compare_params(prompt, b_embed_config, 'bigram_embed', b_embed_params, trained_b_embed_params)


### causal masking

In [None]:
# while we saw some change in generation
# it's still kind of unintelligible
# this is likely due to the fact that bigram models are weak
# even with the embedding vector, we are still only using the previous token to output our prediction
# naturally this is not enough

In [None]:
# so rather than just looking at the previous token
# let's update our model to look at all previous tokens
# remember that we're doing our predictions in parallel
# so we need a way to make sure that we don't look ahead in time to future tokens
# otherwise we would be letting the model peek at the right answer
# how might we do this?

# let's work with a small case, of sequence length 5
# the first token can only pay attention to itself
# the second token can look back by 1, and itself,
# etc
# 1 0 0 0 0
# 1 1 0 0 0
# 1 1 1 0 0
# ...
# this ends up turning into a matrix of (seq_len, seq_len) shape, with all 1s and 0s
# the 1s follow a specific shape known as lower triangular
# and jax has a built in function to create one of these matrices

In [None]:
# we call this a 'causal mask' as it preserves causality:
# it prevents info from future tokens flowing back in time to earlier positions
# use jnp.tril and jnp.ones to create the causal mask for the seq_len

"""
seq_len = 5
causal_mask = ...
causal_mask
"""

seq_len = 5
causal_mask = jnp.tril(jnp.ones((seq_len,seq_len)))  # 1s on and below the diagonal, 0s above it
causal_mask

In [None]:
# now, we just need to divide each row by its sum
# to get the equal weighting for each token in the sequence
# you can use jnp sum for this, you may need some of the optional parameters like 'axis' and 'keepdims'
"""

causal_weights = ...
causal_weights
"""

causal_weights = causal_mask / jnp.sum(causal_mask, axis=1, keepdims=True)  # row i: 1/(i+1) in columns 0..i, 0 elsewhere
causal_weights

In [None]:
# now, averaging the embedding vectors with our causal weights is just a matrix multiplication!
# and our final output sends these contextually weighted vectors back into the vocab_size dimension

# lets update our basic bigram model to equally attend to each of the previous tokens

"""
def causal_embed(model_params: Dict, model_config: ModelConfig, tokens: jax.Array) -> jax.Array: # (seq_len,) -> (seq_len, vocab_size)

    token_embedding = model_params['token_embedding']
    output_projection = model_params['output_projection']

    embedded = token_embedding[tokens] # (seq_len,) -> (seq_len, embedding_dim)

    seq_len = tokens.shape[0]

    causal_mask = ...

    causal_weights = ...

    context_vectors = ... # (seq_len, seq_len) @ (seq_len, embedding_dim) -> (seq_len, embedding_dim)

    logits = ... # (seq_len, embedding_dim) @ (embedding_dim, vocab_size) -> (seq_len, vocab_size)

    return logits
"""

def causal_embed(model_params: Dict, model_config: ModelConfig, tokens: jax.Array) -> jax.Array: # (seq_len,) -> (seq_len, vocab_size)

    token_embedding = model_params['token_embedding']
    output_projection = model_params['output_projection']

    embedded = token_embedding[tokens] # (seq_len,) -> (seq_len, embedding_dim)

    seq_len = tokens.shape[0]
    causal_mask = jnp.tril(jnp.ones((seq_len, seq_len)))

    causal_weights = causal_mask / jnp.sum(causal_mask, axis=1, keepdims=True)

    context_vectors = jnp.dot(causal_weights, embedded) # (seq_len, seq_len) @ (seq_len, embedding_dim) -> (seq_len, embedding_dim)

    logits = jnp.dot(context_vectors, output_projection) # (seq_len, embedding_dim) @ (embedding_dim, vocab_size) -> (seq_len, vocab_size)

    return logits



In [None]:
# now we can just retrain after adding the params to the model dict
# we won't change anything about the training configuration just yet

MODEL_DICT = dict(
    bigram_model=bigram_model,
    bigram_embed=bigram_embed,
    causal_embed=causal_embed,
)
# causal_embed only reads 'token_embedding' and 'output_projection', so the
# bigram_embed params/config are reused unchanged
trained_c_embed_params = train(
    b_embed_params, b_embed_config, 'causal_embed',
    train_tokens, adamw_config, train_config,
)
compare_params(prompt, b_embed_config, 'causal_embed', b_embed_params, trained_c_embed_params)


### positional encoding

In [None]:
# now what if we shuffle the characters of the prompt and run both versions?
# what do we expect to happen?

"""
shuffled_prompt = ... # use jax.random.permutation and the 0 key, as well as the jit_tokenizer functions and preprocess text from before
"""

shuffled_prompt = jit_tokenizer.decode(jax.random.permutation(jax.random.PRNGKey(0), jit_tokenizer.encode(preprocess_text(prompt))))
print(prompt + '\n' + shuffled_prompt)

compare_params(prompt, b_embed_config, 'causal_embed', b_embed_params, trained_c_embed_params)
compare_params(shuffled_prompt, b_embed_config, 'causal_embed', b_embed_params, trained_c_embed_params)

# if you look at the part where the model starts generating
# you can see we got the same output, even after shuffling

# while we're now looking at all previous tokens
# we haven't really given the model a way to understand position
# we're simply averaging over the previous tokens equally

In [None]:
# similarly to what we did with token embedding
# we'll give the model a matrix it can use to learn about position
# so we want the ith row vector to encode information about position i in a sequence
# we cannot make a matrix of infinite size, so we'll need to limit the seq_len to some maximum
# so we'll create a context_length parameter
# and now our positional embedding matrix is of shape (context_len, embedding_dim)

class ModelConfig(NamedTuple):
    """Hyperparameters for the causal model with positional embeddings."""
    vocab_size: int      # number of distinct tokens
    embedding_dim: int   # width of each token / position vector
    context_len: int     # maximum sequence length (rows of the positional table)

# define initialization function

"""
def init_causal_pos_params(key: jax.random.PRNGKey, model_config: ModelConfig, scaling_factor = 0.02) -> Dict:
    k1, k2, k3 = jax.random.split(key, 3)
    return {
        'token_embedding': jax.random.normal(k1, (model_config.vocab_size, model_config.embedding_dim)) * scaling_factor,
        'positional_embedding': ,
        'output_projection': jax.random.normal(k3, (model_config.embedding_dim, model_config.vocab_size)) * scaling_factor
    }
"""

def init_causal_pos_params(key: jax.random.PRNGKey, model_config: ModelConfig, scaling_factor = 0.02) -> Dict:
    """Draw N(0, 1) * scaling_factor tensors for the token table, position table and output projection.

    Returns a dict with
        'token_embedding'      (vocab_size, embedding_dim)
        'positional_embedding' (context_len, embedding_dim)
        'output_projection'    (embedding_dim, vocab_size)
    """
    tok_key, pos_key, out_key = jax.random.split(key, 3)
    vocab = model_config.vocab_size
    dim = model_config.embedding_dim
    ctx = model_config.context_len

    def scaled_normal(k, shape):
        return jax.random.normal(k, shape) * scaling_factor

    return {
        'token_embedding': scaled_normal(tok_key, (vocab, dim)),
        'positional_embedding': scaled_normal(pos_key, (ctx, dim)),
        'output_projection': scaled_normal(out_key, (dim, vocab)),
    }

In [None]:
# define forward pass
# add the positional embeddings

"""
def causal_pos(model_params: Dict, model_config: ModelConfig, tokens: jax.Array) -> jax.Array:
    token_embedding = model_params['token_embedding']
    output_projection = model_params['output_projection']

    token_embedded = token_embedding[tokens] # (seq_len,) -> (seq_len, embedding_dim)

    # add positional embeddings
    pos_embeds = ... # remember that we only need embeddings up to seq_len
    embedded = token_embeds + pos_embeds

    seq_len = tokens.shape[0]

    causal_mask = jnp.tril(jnp.ones((seq_len, seq_len)))

    causal_weights = causal_mask / jnp.sum(causal_mask, axis=1, keepdims=True)

    context_vectors = jnp.dot(causal_weights, embedded) # (seq_len, seq_len) @ (seq_len, embedding_dim) -> (seq_len, embedding_dim)

    logits = jnp.dot(context_vectors, output_projection) # (seq_len, embedding_dim) @ (embedding_dim, vocab_size) -> (seq_len, vocab_size)

    return logits



    return
"""

def causal_pos(model_params: Dict, model_config: ModelConfig, tokens: jax.Array) -> jax.Array:
    token_embedding = model_params['token_embedding']
    positional_embedding = model_params['positional_embedding']
    output_projection = model_params['output_projection']

    token_embeds = token_embedding[tokens]

    seq_len = tokens.shape[0]
    pos_embeds = positional_embedding[:seq_len]
    embedded = token_embeds + pos_embeds

    causal_mask = jnp.tril(jnp.ones((seq_len, seq_len)))

    causal_weights = causal_mask / jnp.sum(causal_mask, axis=1, keepdims=True)

    context_vectors = jnp.dot(causal_weights, embedded)

    logits = jnp.dot(context_vectors, output_projection)

    return logits

In [None]:
# now we just update model dict and train again
MODEL_DICT = dict(
    bigram_model=bigram_model,
    bigram_embed=bigram_embed,
    causal_embed=causal_embed,
    causal_pos=causal_pos,
)

# NOTE(review): batch_seq_len presumably sets the training sequence length; keep it
# <= context_len, since causal_pos only has context_len rows of positional embeddings
c_pos_config = ModelConfig(vocab_size=95, embedding_dim=4, context_len=64)
train_config = TrainConfig(num_epochs=50, batches_per_epoch=32, batch_size=32, batch_seq_len=64, seed=0)
c_pos_params = init_causal_pos_params(key, c_pos_config)
trained_c_pos_params = train(
    c_pos_params, c_pos_config, 'causal_pos',
    train_tokens, adamw_config, train_config,
)

# before/after training, on the original prompt and then on the shuffled one
compare_params(prompt, c_pos_config, 'causal_pos', c_pos_params, trained_c_pos_params)
compare_params(shuffled_prompt, c_pos_config, 'causal_pos', c_pos_params, trained_c_pos_params)

### attention

In [None]:
# although we now understand position
# we're still prioritizing each character equally
# but this doesn't really make sense
# when we write or speak
# we use previous values in the sentence to figure out what to say next
# as an example if we were continuing: 'the boy is eating a blue'
# we would probably pay attention to the word 'boy' and 'eating' and 'blue' to figure out the next token
# and we might choose a word that related to something that a boy would eat that's blue
# like candy

In [None]:
# so in this sequence
# the boy is eating a blue
# our bigram model would perform a lookup on 'blue'
# and output the best predicted word vector to follow it, maybe 'car'

# what we want to do is a similar lookup, but across all of the vectors
# we want to submit our query of 'blue' to each of the keys ['the', 'boy', 'is', ...]
# how well the query matches each key gives an attention score
# and, like in a lookup table, each key comes with a value that is returned in proportion to that score

# for example (weights shown after normalizing so they sum to 1)
# query(blue) . key(boy)    -> attn_weight 0.2
# query(blue) . key(eating) -> attn_weight 0.75
# query(blue) . key(the)    -> attn_weight 0.05
# output = 0.2 * value(boy) + 0.75 * value(eating) + 0.05 * value(the)

# this might lead the model to instead try to choose a vector like 'lollipop'

# by letting the model do these lookups, we give it the ability to learn how words relate to each other
# so for each value in our sequence
# we'll want to create a query vector, a key vector, and a value vector
# the query vector is what each token is looking for in other tokens
# the key vector is something for other tokens to match into
# and the value vector is what the token represents, and is used when matched into

In [None]:
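# so query is of size (embed_dim)
# and we want to compute the similarity of that query and the ith key
# we'll use a dot product for this: it is large when the two vectors point in similar directions,
# near zero when they are unrelated, and negative when they point in opposite directions,
# and computing many of them at once is a single matrix multiplication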
# q @ k_i

# we know that we'll have a maximum amount of keys, (context_len)
# so we can actually make a matrix of keys, where the ith row is k_i
# with (context_len) rows
# this means instead of q @ k_i, we can just do q @ K.T (K being the matrix of k_i rows)

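# these dot products grow with embed_dim: each one is a sum of embed_dim terms, so with unit-variance
# entries its variance is about embed_dim, and large scores push the softmax toward a one-hot output with tiny gradients
# so we scale them down by multiplying by 1/sqrt(embed_dim), which brings the variance back to about 1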

# we want these attention weights to add to 1, so we do a softmax transform on it
# so now this final weighting allows us to extract the relevant amount of information from each point in the sequence
# so V is of size (context_len, embed_dim)
# and we take the dot product of softmaxed_scores @ v

# in order to do all of these lookups of qs in parallel
# we actually want Q to also be a matrix, of size (context_len, embed_dim) (technically batch_size, context_len, embed_dim)

# putting it all together:

"""
def attention(q, k, v):

    # q is of shape (context_len, embed_dim)

    embed_dim = ... # should come from q.shape ...

    attn_scores = ...
    scaled_scores = ...
    softmaxed_scores = ...

    output = ...

    return output

batch_size = 2
seq_len = 4
embed_dim = 8

q = jax.random.normal(key, (batch_size, seq_len, embed_dim))
k = jax.random.normal(key, (batch_size, seq_len, embed_dim))
v = jax.random.normal(key, (batch_size, seq_len, embed_dim))

batched_attention = jax.vmap(attention)
batched_attention(q, k, v)
"""

def attention(q, k, v):
    embed_dim = q.shape[-1] # (batch_size, context_len, embed_dim)

    attn_scores = q @ k.T # (context_len, embed_dim) @ (context_len, embed_dim).T -> (context_len, context_len)
    scaled_scores = attn_scores * 1/jnp.sqrt(embed_dim)
    softmaxed_scores = jax.nn.softmax(scaled_scores)

    output = softmaxed_scores @ v # (context_len, context_len) @ (context_len, embed_dim).T -> (context_len, embed_dim)

    return output

# toy sizes for checking shapes
batch_size, seq_len, embed_dim = 2, 4, 8

# NOTE: all three are drawn from the same `key`, so q, k and v are identical tensors here
q, k, v = (jax.random.normal(key, (batch_size, seq_len, embed_dim)) for _ in range(3))

# vmap maps attention() over the leading (batch) axis of q, k and v
batched_attention = jax.vmap(attention)
batched_attention(q, k, v)

### causal attention

In [None]:
# remember that we need to make sure we don't communicate with the future
# so we can modify the attention matrix by adding our mask from before

def create_causal_mask(seq_len: int):
    """Return a (seq_len, seq_len) float mask: 1 where query i may attend to key j (j <= i), 0 elsewhere."""
    square = (seq_len,) * 2
    return jnp.tril(jnp.ones(square))

# since the scores are raw logits, we mask with a very large negative value (-inf) instead of 0:
# after the softmax exp(-inf) = 0, so masked positions get exactly zero weight
# so we keep the scores where the mask is 1 (on and below the diagonal) and set the rest to -inf

# float('-inf')

"""

def causal_attention(q, k, v, mask):
    embed_dim = q.shape[-1]

    attn_scores = q @ k.T
    scaled_scores = attn_scores * 1/jnp.sqrt(embed_dim)
    masked_scores = ... # add mask, use jnp.where to check for 0s and replace with float ('-inf')
    softmaxed_scores = jax.nn.softmax(masked_scores)

    output = softmaxed_scores @ v

    return output

"""

def causal_attention(q, k, v, mask):
    embed_dim = q.shape[-1]

    attn_scores = q @ k.T
    scaled_scores = attn_scores / jnp.sqrt(embed_dim)
    masked_scores = jnp.where(mask == 0, float('-inf'), scaled_scores)  # Corrected this line
    softmaxed_scores = jax.nn.softmax(masked_scores, axis=-1)

    output = softmaxed_scores @ v

    return output

mask = create_causal_mask(seq_len)
# vmap over the batch axis of q/k/v; the single mask is shared by every batch element
batched_causal_attention = jax.vmap(causal_attention, in_axes=(0,) * 3 + (None,))
batched_causal_attention(q, k, v, mask)

### self attention

In [None]:
# self attention is named as such because the sequence attends to itself:
# q, k and v are all computed from the same input x of shape (seq_len, embed_dim)
# we apply weight matrices of size (embed_dim, embed_dim)
# as well as a bias of size (embed_dim)
# so that the model can learn how to assign q, k, v values to some input x
# we'll call these w_q, b_q, w_k, ...

# we also do a final projection at the end to add some more space to learn about the attended vectors

"""
def self_attention(x, w_q, b_q, w_k, b_k, w_v, b_v, w_proj, b_proj, causal_mask): # (seq_len, embed_dim) -> (seq_len, embed_dim)
    seq_len = q.shape[-1]

    # qkv projections
    q = ...
    k = ...
    v = ...

    # create mask
    causal_mask = ...

    # perform self attention
    x = ...

    # out projection
    x = ...

    return x
"""

def self_attention(x, w_q, b_q, w_k, b_k, w_v, b_v, w_proj, b_proj, causal_mask): # [seq_len, embed_dim] -> [seq_len, embed_dim]
    seq_len = x.shape[-1]
    # qkv projections
    q = x @ w_q + b_q # [seq_len, embed_dim] @ [embed_dim, embed_dim] -> (seq_len, embed_dim)
    k = x @ w_k + b_k # [seq_len, embed_dim] @ [embed_dim, embed_dim] -> (seq_len, embed_dim)
    v = x @ w_v + b_v # [seq_len, embed_dim] @ [embed_dim, embed_dim] -> (seq_len, embed_dim)


    # perform self attention
    attn_scores = (q @ k.T) / jnp.sqrt(embed_dim)
    masked_scores = jnp.where(causal_mask == 0, float('-inf'), attn_scores)
    attn_weights = jax.nn.softmax(masked_scores, axis=-1)
    x = attn_weights @ v

    # out projection
    x = (x @ w_proj) + b_proj  # [seq_len, embed_dim] @ [embed_dim, embed_dim] -> [seq_len, embed_dim]

    return x

# toy batch and projection weights. Every tensor below is drawn from the same `key`,
# so e.g. w_q, w_k, w_v and w_proj are all equal
x = jax.random.normal(key, (batch_size, seq_len, embed_dim))
w_q, w_k, w_v, w_proj = (jax.random.normal(key, (embed_dim, embed_dim)) for _ in range(4))
b_q, b_k, b_v, b_proj = (jax.random.normal(key, (embed_dim,)) for _ in range(4))

# map over the batch axis of x only; the eight weights/biases and the mask are shared
batched_self_attention1 = jax.vmap(self_attention, in_axes=(0,) + (None,) * 9)
batched_self_attention1(x, w_q, b_q, w_k, b_k, w_v, b_v, w_proj, b_proj, mask)

In [None]:
# typically, we store the learned weights for qkv all in the same matrix
# and then split them up after multiplying
# this is more efficient because multiplying one large matrix is better than 3 smaller multiplications
# especially if on an accelerator like a GPU/TPU
# so that fused weight matrix has shape (embed_dim, 3*embed_dim), and multiplying by it gives an output of shape (seq_len, 3*embed_dim)
# and we split the output of that linear layer along the last axis, the embed_dim
# to become the q/k/v for the self attention

"""
def self_attention(x, w_qkv, b_qkv, w_proj, b_proj, causal_mask): # [seq_len, embed_dim] -> [seq_len, embed_dim]

    # qkv weights in one matrix multiply, add bias
    x = ...

    # split weights into qkv
    q, k, v = ...

    # perform self attention
    x = ...

    # out projection
    x = ...

    return x

"""

def self_attention(x, w_qkv, b_qkv, w_proj, b_proj, causal_mask):
    seq_len, embed_dim = x.shape

    # qkv projections
    x = (x @ w_qkv) + b_qkv  # [seq_len, embed_dim] @ [embed_dim, 3*embed_dim] -> [seq_len, 3*embed_dim]

    # split into qkv
    q, k, v = jnp.split(x, 3, axis=-1)  # [seq_len, 3*embed_dim] -> 3 of [seq_len, embed_dim]

    # perform self attention
    attn_scores = (q @ k.T) / jnp.sqrt(embed_dim)
    masked_scores = jnp.where(causal_mask == 0, float('-inf'), attn_scores)
    attn_weights = jax.nn.softmax(masked_scores, axis=-1)
    x = attn_weights @ v

    # out projection
    x = (x @ w_proj) + b_proj  # [seq_len, embed_dim] @ [embed_dim, embed_dim] -> [seq_len, embed_dim]

    return x

# fused projection: its output columns are split into [q | k | v], hence 3 * embed_dim
w_qkv = jax.random.normal(key, (embed_dim, embed_dim * 3))
b_qkv = jax.random.normal(key, (embed_dim * 3,))

# map over the batch axis of x only; the weights, biases and mask are shared
batched_self_attention = jax.vmap(self_attention, in_axes=(0,) + (None,) * 5)
batched_self_attention(x, w_qkv, b_qkv, w_proj, b_proj, mask)

### multi head attention

In [None]:
# as one last final complication, remember that each part of a sentence carries high dimensional information
# 'the boy is eating a'
# we want to generate the next token based on the semantics of 'eating' and it being a food
# we want to have it be a singular item, because of the 'a'
# so we may want to pay different kinds of attention to different tokens, depending on syntax/grammar/semantics

# this leads us to multi head attention
# to maintain the computational complexity
# we downscale the Q/K/V matrices from embed_dim down to embed_dim//n_heads (head_dim)
#
# and we end up with multiple matrices Q_i, K_i, V_i for each ith head
# note that this is a tradeoff! we sacrifice some dimensionality for each token but gain ensembling power

In [None]:
# we'll start by doing the linear layer as before
# which will give us a shape of (seq_len, 3*embed_dim)
# then, we need to split into q/k/v, as well as split along each head
# we want to end up with shape (seq_len, 3, n_head, head_dim)
# so we can just use a .reshape to get there
# then, we can split this tensor along that second dimension we specified earlier
# now x_q, x_k, x_v are tensors of shape (seq_len, n_head, head_dim)
# we need to perform q @ k.T
# and then @ v
# and we want to do it along the n heads
# so we need to transpose our tensors into (n_head, seq_len, head_dim)
# note that while doing this reshaping, we can actually do the transpose on k
# and transpose k to (n_head, head_dim, seq_len) instead
# then, we do the matrix multiplication, scale by 1/sqrt(head_dim), mask and softmax
# to get our context weightings
# remember that our context weightings are of shape (n_head, seq_len, head_dim)
# and we want to go back to (seq_len, n_head, head_dim)
# so we do the .transpose again
# finally, we need to put the n_head/head_dim back together before our final linear layer
# so we .reshape into (seq_len, embed_dim)
# which concatenates the heads back together
# and do the final output projection
# [seq_len, n_head, head_dim]

"""
def multi_head_attn(x: jax.Array, w_qkv: jax.Array, b_qkv: jax.Array, w_proj: jax.Array, b_proj: jax.Array, n_head: int, causal_mask: jax.Array):
    seq_len, embed_dim = ...
    head_dim = ...

    # linear projections
    x_qkv = ...
    x_qkv_heads = ... # split into heads
    xq, xk, xv = ... # indices [0, 1, 2] respectively on the dim we just created

    # reshape for attention
    xq = ...
    xkt = ...
    xv = ...

    # create search with q/k
    raw_scores = ...
    scaled_scores = ...
    masked_scores = ...
    attn_weights = ...

    # apply search with v to get final contextual weights
    context_weights = ...

    # transpose and reshape
    transposed_context_weights = ...
    reshaped_context_weights = ...

    # project vectors back to tokens
    token_logits = ...

    return token_logits
"""

def multi_head_attn(x: jax.Array, w_qkv: jax.Array, b_qkv: jax.Array, w_proj: jax.Array, b_proj: jax.Array, n_head: int, causal_mask: jax.Array):
    seq_len, embed_dim = x.shape
    head_dim = embed_dim // n_head

    x_qkv = jnp.dot(x, w_qkv) + b_qkv
    x_qkv_heads = x_qkv.reshape(seq_len, 3, n_head, head_dim)
    xq, xk, xv = x_qkv_heads[:, 0], x_qkv_heads[:, 1], x_qkv_heads[:, 2]

    xq = xq.transpose(1, 0, 2)  # (n_head, seq_len, head_dim)
    xkt = xk.transpose(1, 2, 0)  # (n_head, head_dim, seq_len)
    xv = xv.transpose(1, 0, 2)  # (n_head, seq_len, head_dim)

    raw_scores = jnp.matmul(xq, xkt)/ jnp.sqrt(head_dim)
    scaled_scores = raw_scores
    masked_scores = jnp.where(causal_mask == 0, float('-inf'), scaled_scores)
    attn_weights = jax.nn.softmax(masked_scores, axis=-1)

    context_weights = jnp.matmul(attn_weights, xv)

    transposed_context_weights = context_weights.transpose(1, 0, 2)
    reshaped_context_weights = transposed_context_weights.reshape(seq_len, embed_dim)

    token_logits = jnp.dot(reshaped_context_weights, w_proj) + b_proj

    return token_logits

n_head = 2
# x is batched along axis 0; weights, n_head and the mask are passed through unmapped
batched_multi_head_attn = jax.vmap(multi_head_attn, in_axes=(0,) + (None,) * 6)
batched_multi_head_attn(x, w_qkv, b_qkv, w_proj, b_proj, n_head, mask)

In [None]:
# let's make our first model that actually puts this into practice!
# make sure to initialize the parameters with the correct shapes

class ModelConfig(NamedTuple):
    """Hyperparameters for the attention-based models."""
    vocab_size: int      # number of distinct tokens
    embedding_dim: int   # width of each token / position vector
    context_len: int     # maximum sequence length (rows of the positional table)
    n_head: int          # number of attention heads; multi_head_attn needs it to divide embedding_dim

"""
def init_attentive_params(key: jax.random.PRNGKey, model_config: ModelConfig, scaling_factor = 0.02) -> Dict:
    k1, k2, k3, k4, k5, k6, k7 = jax.random.split(key, 7)
    return {
        'token_embedding': jax.random.normal(k1, (model_config.vocab_size, model_config.embedding_dim)) * scaling_factor,
        'positional_embedding': jax.random.normal(k2, (model_config.context_len, model_config.embedding_dim)) * scaling_factor,
        'output_projection': jax.random.normal(k3, (model_config.embedding_dim, model_config.vocab_size)) * scaling_factor,
        'attn_in_weights': ...,
        'attn_in_bias': ...,
        'attn_out_weights': ...,
        'attn_out_bias': ...,

"""

def init_attentive_params(key: jax.random.PRNGKey, model_config: ModelConfig, scaling_factor = 0.02) -> Dict:
    """Draw every parameter tensor as N(0, 1) * scaling_factor, one PRNG subkey per tensor.

    The fused attention input projection is (embedding_dim, 3*embedding_dim) so that a
    single matmul produces q, k and v together.
    """
    vocab = model_config.vocab_size
    dim = model_config.embedding_dim
    ctx = model_config.context_len

    # insertion order decides which subkey each tensor gets (first key -> first entry, ...)
    shapes = {
        'token_embedding': (vocab, dim),
        'positional_embedding': (ctx, dim),
        'output_projection': (dim, vocab),
        'attn_in_weights': (dim, 3 * dim),
        'attn_in_bias': (3 * dim,),
        'attn_out_weights': (dim, dim),
        'attn_out_bias': (dim,),
    }
    subkeys = jax.random.split(key, 7)
    return {
        name: jax.random.normal(subkey, shape) * scaling_factor
        for subkey, (name, shape) in zip(subkeys, shapes.items())
    }

In [None]:
# for the forward pass
# after embedding the tokens
# create the causal mask
# apply the multi head attention
# then project the vectors back into tokens

"""
def attentive_model(model_params: Dict, model_config: ModelConfig, tokens: jax.Array) -> jax.Array:
    seq_len = tokens.shape[0]

    # embed tokens
    token_embeds = ...
    pos_embeds = ...
    embedded = ...


    # create causal mask
    causal_mask = ...

    # apply multi-head attention
    context = ...

    # project vectors back to tokens
    token_logits = ...

    return token_logits
"""

def attentive_model(model_params: Dict, model_config: ModelConfig, tokens: jax.Array) -> jax.Array:
    seq_len = tokens.shape[0]


    token_embeds = model_params['token_embedding'][tokens]
    pos_embeds = model_params['positional_embedding'][:seq_len]
    embeds = token_embeds + pos_embeds

    causal_mask = jnp.tril(jnp.ones((seq_len, seq_len)))


    context = multi_head_attn(embeds, model_params['attn_in_weights'], model_params['attn_in_bias'], model_params['attn_out_weights'],
                        model_params['attn_out_bias'], model_config.n_head, causal_mask)

    token_logits = jnp.dot(context, model_params['output_projection'])

    return token_logits

In [None]:
# once again, we train this model

MODEL_DICT = dict(
    bigram_model=bigram_model,
    bigram_embed=bigram_embed,
    causal_embed=causal_embed,
    causal_pos=causal_pos,
    attentive_model=attentive_model,
)

train_config = TrainConfig(num_epochs=50, batches_per_epoch=32, batch_size=32, batch_seq_len=64, seed=0)
adamw_config = OptConfig(
    lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01,
    opt_init=init_adam_state, opt_update=adamw_update,
)
# 4 heads, each of width 32 // 4 = 8
attn_config = ModelConfig(vocab_size=95, embedding_dim=32, context_len=128, n_head=4)
attn_params = init_attentive_params(key, attn_config)
trained_attn_params = train(
    attn_params, attn_config, 'attentive_model',
    train_tokens, adamw_config, train_config,
)

# original vs shuffled prompt: with attention the two continuations should now differ
compare_params(prompt, attn_config, 'attentive_model', attn_params, trained_attn_params)
compare_params(shuffled_prompt, attn_config, 'attentive_model', attn_params, trained_attn_params)

### feed forward networks

In [None]:
# while we've given the model a way to understand context
# and it successfully produces different output for a shuffled prompt
# it's learning which tokens to pay attention to
# but it doesn't really process that information in any way
# we can take the outputs and process them by using a basic feedforward network
# when a feedforward network processes mnist handwritten data
# it goes from understanding the vector representation of the grayscale pixels
# to enough information to classify them into digits

# similarly, taking the output of the attention's contextualized representation of the positionally encoded token vectors
# the feedforward network 'thinks about it' and produces a more useful output

# our network will expand into a 4x dimensional space from its input
# to have more 'space' to work with
# go through an activation function
# to understand nonlinear relationships
# and shrink back down into the original dimension of the token embeddings
# forcing it to compress what its learned into the most useful stuff


In [None]:
# each layer gets a weight matrix and a bias vector
# the input gets multiplied by the weight, then you add the bias
# just like we've seen many times above
# the weight/bias are kind of like the slope/intercept in a regression
# the weight (slope) does most of the heavy lifting in capturing relationships
# and the bias (intercept) means the output doesn't have to pass through 0, which helps shift the vectors into the right place

"""
def linear(x, weight, bias):  # [m, in], [in, out], [out] -> [m, out]
    return
"""

def linear(x, weight, bias):  # [m, in], [in, out], [out] -> [m, out]
    return x @ weight + bias

# a linear layer that doubles the feature width, applied to the first sequence of the batch
in_dim, out_dim = embed_dim, 2 * embed_dim
weight = jax.random.normal(key, (in_dim, out_dim))
bias = jax.random.normal(key, (out_dim,))
linear(x[0], weight, bias)

In [None]:
# to add nonlinearity to this network
# we use an activation function that is not linear
# typically something like relu (f(x) = max(0,x)) is used for this; GPT-2 uses GELU (tanh approximation below)
# I'd love to write an intuitive reason for the choice of activation function for GPT-2
# but it was actually just chosen because it works empirically
# the activation function is applied after each linear layer, except the last one

def gelu(x):
    """GELU activation, tanh approximation:
    0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    """
    inner = jnp.sqrt(2 / jnp.pi) * (x + 0.044715 * x**3)
    return 0.5 * x * (1 + jnp.tanh(inner))

In [None]:
# so now our network should take in two sets of linear layer parameters
# one for projecting up in to the higher dimensional space
# and one for projecting back down

"""
def ffn(x, c_in_params, c_out_params):  # [n_seq, n_embd] -> [n_seq, n_embd]
    # project up
    projected_up = ...

    # project back down
    output = ...

    return output
"""

def ffn(x, c_in_weight, c_in_bias, c_out_weight, c_out_bias):  # [n_seq, n_embd] -> [n_seq, n_embd]
    # project up
    projected_up = gelu(linear(x, c_in_weight, c_in_bias))  # [n_seq, n_embd] -> [n_seq, 4*n_embd]

    # project back down
    output = linear(projected_up, c_out_weight, c_out_bias)  # [n_seq, 4*n_embd] -> [n_seq, n_embd]

    return output

# FFN with a 4x wider hidden layer; all four tensors come from the same `key`
c_in_weight, c_in_bias = jax.random.normal(key, (embed_dim, 4 * embed_dim)), jax.random.normal(key, (4 * embed_dim,))
c_out_weight, c_out_bias = jax.random.normal(key, (4 * embed_dim, embed_dim)), jax.random.normal(key, (embed_dim,))
ffn(x[0], c_in_weight, c_in_bias, c_out_weight, c_out_bias)

In [None]:
# we'll clean up our parameters a little bit while we're here
# we can use nested dictionaries for each of our linear layers

class ModelConfig(NamedTuple):
    """Hyperparameters for the attention + feed-forward model."""
    vocab_size: int      # number of distinct tokens
    embedding_dim: int   # width of each token / position vector
    context_len: int     # maximum sequence length (rows of the positional table)
    n_head: int          # number of attention heads; multi_head_attn needs it to divide embedding_dim

"""
def init_attentive_ffn_params(key: jax.random.PRNGKey, model_config: ModelConfig, scaling_factor = 0.02) -> Dict:
    key, *subkeys  = jax.random.split(key, 12)
    return {
        'token_embedding': jax.random.normal(subkeys[0], (model_config.vocab_size, model_config.embedding_dim)) * scaling_factor,
        'positional_embedding': jax.random.normal(subkeys[1], (model_config.context_len, model_config.embedding_dim)) * scaling_factor,
        'output_projection': jax.random.normal(subkeys[2], (model_config.embedding_dim, model_config.vocab_size)) * scaling_factor,
        'attn_in': {
            'weight': jax.random.normal(subkeys[3], (model_config.embedding_dim, 3*model_config.embedding_dim)) * scaling_factor,
            'bias': jax.random.normal(subkeys[4], (3*model_config.embedding_dim,)) * scaling_factor,
        },
        'attn_out': {
            'weight': jax.random.normal(subkeys[5], (model_config.embedding_dim, model_config.embedding_dim)) * scaling_factor,
            'bias': jax.random.normal(subkeys[6], (model_config.embedding_dim,)) * scaling_factor,
        },
        'ffn_in': {},
        'ffn_out': {},

    }
"""

def init_attentive_ffn_params(key: jax.random.PRNGKey, model_config: ModelConfig, scaling_factor = 0.02) -> Dict:
    """Draw every parameter tensor as N(0, 1) * scaling_factor.

    Linear layers are nested dicts {'weight', 'bias'}; the FFN expands to
    4 * embedding_dim and projects back down.
    """
    # 12-way split: the first key is discarded, the other 11 feed the 11 tensors below in order
    _, *subkeys = jax.random.split(key, 12)
    d = model_config.embedding_dim
    hidden = 4 * d

    def normal(i, shape):
        return jax.random.normal(subkeys[i], shape) * scaling_factor

    def dense(i, n_in, n_out):
        # weight uses subkeys[i], bias uses subkeys[i + 1]
        return {'weight': normal(i, (n_in, n_out)), 'bias': normal(i + 1, (n_out,))}

    return {
        'token_embedding': normal(0, (model_config.vocab_size, d)),
        'positional_embedding': normal(1, (model_config.context_len, d)),
        'output_projection': normal(2, (d, model_config.vocab_size)),
        'attn_in': dense(3, d, 3 * d),
        'attn_out': dense(5, d, d),
        'ffn_in': dense(7, d, hidden),
        'ffn_out': dense(9, hidden, d),
    }


In [None]:
# for the forward pass
# just add in the ffn function we defined earlier
# right after the multi head attention

"""

def attentive_ffn(model_params: Dict, model_config: ModelConfig, tokens: jax.Array) -> jax.Array:
    seq_len = ...
    causal_mask = ...

    token_embeds = ...
    pos_embeds = ...
    embedded = ...


    context = ...

    enhanced_context = ...

    token_logits = ...

    return token_logits
"""

def attentive_ffn(model_params: Dict, model_config: ModelConfig, tokens: jax.Array) -> jax.Array:
    seq_len = tokens.shape[0]
    causal_mask = jnp.tril(jnp.ones((seq_len, seq_len)))

    token_embeds = model_params['token_embedding'][tokens]
    pos_embeds = model_params['positional_embedding'][:seq_len]
    embeds = token_embeds + pos_embeds


    context = multi_head_attn(embeds, model_params['attn_in']['weight'], model_params['attn_in']['bias'], model_params['attn_out']['weight'],
                        model_params['attn_out']['bias'], model_config.n_head, causal_mask)

    enhanced_context = ffn(context, model_params['ffn_in']['weight'], model_params['ffn_in']['bias'], model_params['ffn_out']['weight'], model_params['ffn_out']['bias'])

    token_logits = jnp.dot(enhanced_context, model_params['output_projection'])

    return token_logits

In [None]:
MODEL_DICT = {
    'bigram_model': bigram_model,
    'bigram_embed': bigram_embed,
    'causal_embed': causal_embed,
    'causal_pos': causal_pos,
    'attentive_model': attentive_model,
    'attentive_ffn': attentive_ffn
}

train_config = TrainConfig(num_epochs=50, batches_per_epoch=32, batch_size=32, batch_seq_len=64, seed=0)
adamw_config = OptConfig(lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8, weight_decay=0.01, opt_init=init_adam_state, opt_update=adamw_update)
attn_config = ModelConfig(vocab_size=95, embedding_dim=32, context_len=128, n_head=4)
# attentive_ffn reads the nested 'attn_in'/'attn_out'/'ffn_in'/'ffn_out' params, so it needs
# init_attentive_ffn_params (init_attentive_params builds the flat attn_* layout with no FFN weights),
# and it must be trained and evaluated under its own MODEL_DICT name, not 'attentive_model'
attn_ffn = init_attentive_ffn_params(key, attn_config)
trained_attn_ffn_params = train(attn_ffn, attn_config, 'attentive_ffn', train_tokens, adamw_config, train_config)
compare_params(prompt, attn_config, 'attentive_ffn', attn_ffn, trained_attn_ffn_params)
compare_params(shuffled_prompt, attn_config, 'attentive_ffn', attn_ffn, trained_attn_ffn_params)


### layer normalization

In [None]:
# while we gave the model room to think more
# we're starting to get a pretty deep network
# as evidenced by our higher ending loss
# the model immediately has to start processing the attention output in the feedforward network
# we can think of it as the model needing to recalibrate before moving onto the next task
# a little more mathematically, we might be getting very skewed or weirdly distributed values from the attention output
# so we want a way to stabilize the input that each of our functions gets
# the most obvious way to do this is to normalize the distribution, to have mean 0 and variance 1
# this may not be the best way to represent the data to the next function
# so we also give the model a learnable scale/shift parameter
# so it can affine transform the distribution as needed

# when writing the layer normalization
# we can use jnp mean, jnp var, etc
# be sure to use the axis and keepdims parameters
# to make sure we take the mean across the correct axis
# we want to take the mean/variance over the last axis (the one where the features are)
# and we want to keep the dimensions after the processing

"""
def layer_norm(x, gamma, beta, eps: float = 1e-5):
    mean = ...
    variance = ...
    normalized_x = ...  # normalize x to have mean=0 and var=1 over last axis
    affine_x = ...
    return affine_x
"""

def layer_norm(x, gamma, beta, eps: float = 1e-5):
    """Normalize `x` over its last axis, then apply a learnable affine transform.

    Every vector along the last axis is shifted to mean 0 and scaled to variance 1
    (`eps` is added to the variance so a constant vector does not divide by zero),
    then multiplied by `gamma` and shifted by `beta`.
    In this notebook gamma/beta are created with shape (embedding_dim,), so they
    broadcast against that last axis.
    """
    mu = jnp.mean(x, axis=-1, keepdims=True)
    sigma = jnp.sqrt(jnp.var(x, axis=-1, keepdims=True) + eps)
    return gamma * ((x - mu) / sigma) + beta

# sanity check on an earlier activation (embed_dim and x come from previous cells):
# with the identity affine (gamma=1, beta=0) each row should come out with mean ~0 and variance ~1
gamma, beta = jnp.ones((embed_dim,)), jnp.zeros((embed_dim,))
layer_norm_output = layer_norm(x[0], gamma, beta)
(
    jnp.mean(layer_norm_output, axis=-1, keepdims=True),
    jnp.var(layer_norm_output, axis=-1, keepdims=True),
)

In [None]:
# in our model
# we'll want a layer norm on the input to the attention,
# a layer norm on the input to the feedforward,
# and a final layer norm before we project back into token space
# you might not want to initialize gamma/beta from the normal distribution ... what's a sensible default (jnp.zeros/jnp.ones?)

"""
def init_attentive_ffn_ln_params(key: jax.random.PRNGKey, model_config: ModelConfig, scaling_factor = 0.02) -> Dict:
    key, *subkeys = jax.random.split(key, 16)
    return {
        'token_embedding': jax.random.normal(subkeys[0], (model_config.vocab_size, model_config.embedding_dim)) * scaling_factor,
        'positional_embedding': jax.random.normal(subkeys[1], (model_config.context_len, model_config.embedding_dim)) * scaling_factor,
        'output_projection': jax.random.normal(subkeys[2], (model_config.embedding_dim, model_config.vocab_size)) * scaling_factor,
        'attn_in': {
            'weight': jax.random.normal(subkeys[3], (model_config.embedding_dim, 3*model_config.embedding_dim)) * scaling_factor,
            'bias': jax.random.normal(subkeys[4], (3*model_config.embedding_dim,)) * scaling_factor,
        },
        'attn_out': {
            'weight': jax.random.normal(subkeys[5], (model_config.embedding_dim, model_config.embedding_dim)) * scaling_factor,
            'bias': jax.random.normal(subkeys[6], (model_config.embedding_dim,)) * scaling_factor,
        },
        'ffn_in': {
            'weight': jax.random.normal(subkeys[7], (model_config.embedding_dim, 4*model_config.embedding_dim)) * scaling_factor,
            'bias': jax.random.normal(subkeys[8], (4*model_config.embedding_dim,)) * scaling_factor,
        },
        'ffn_out': {
            'weight': jax.random.normal(subkeys[9], (4*model_config.embedding_dim, model_config.embedding_dim)) * scaling_factor,
            'bias': jax.random.normal(subkeys[10], (model_config.embedding_dim,)) * scaling_factor,
        },
        'ln1': {},
        'ln2': {},
        'lnf': {},
    }
"""
def init_attentive_ffn_ln_params(key: jax.random.PRNGKey, model_config: ModelConfig, scaling_factor = 0.02) -> Dict:
    """Initialize the single-block attention + FFN model with three layer norms.

    Weight matrices and biases are drawn from a standard normal scaled by `scaling_factor`.
    Each layer norm ('ln1', 'ln2', 'lnf') starts as the identity transform (gamma=1, beta=0),
    so at initialization it only normalizes.

    Returns:
        Nested dict keyed by 'token_embedding', 'positional_embedding', 'output_projection',
        'attn_in', 'attn_out', 'ffn_in', 'ffn_out', 'ln1', 'ln2', 'lnf'.
    """
    _, *subkeys = jax.random.split(key, 16)  # 15 subkeys; only the first 11 are consumed
    d = model_config.embedding_dim

    def draw(idx, shape):
        return jax.random.normal(subkeys[idx], shape) * scaling_factor

    def identity_ln():
        return {'gamma': jnp.ones((d,)), 'beta': jnp.zeros((d,))}

    return {
        'token_embedding': draw(0, (model_config.vocab_size, d)),
        'positional_embedding': draw(1, (model_config.context_len, d)),
        'output_projection': draw(2, (d, model_config.vocab_size)),
        'attn_in': {'weight': draw(3, (d, 3 * d)), 'bias': draw(4, (3 * d,))},
        'attn_out': {'weight': draw(5, (d, d)), 'bias': draw(6, (d,))},
        'ffn_in': {'weight': draw(7, (d, 4 * d)), 'bias': draw(8, (4 * d,))},
        'ffn_out': {'weight': draw(9, (4 * d, d)), 'bias': draw(10, (d,))},
        'ln1': identity_ln(),
        'ln2': identity_ln(),
        'lnf': identity_ln(),
    }

In [None]:
# for the forward pass
# add in the layer norms as we discussed above

"""
def attentive_ffn_ln(model_params: Dict, model_config: ModelConfig, tokens: jax.Array) -> jax.Array:
    seq_len = ...
    causal_mask = ...
    token_embeds = ...
    pos_embeds = ...
    embedded = ...
    normed_embedded = ...



    context = ...
    normed_context = ...

    enhanced_context = ...
    normed_enhanced_context = ...

    token_logits = ...

    return token_logits
"""

def attentive_ffn_ln(model_params: Dict, model_config: ModelConfig, tokens: jax.Array) -> jax.Array:
    """Forward pass: embed -> LN1 -> causal multi-head attention -> LN2 -> FFN -> LNf -> logits.

    There are no residual connections at this stage. `tokens` is indexed along axis 0 as the
    sequence (assumed 1-D token ids — TODO confirm). The result is projected through
    'output_projection' (embedding_dim x vocab_size), giving one logit row per position.
    """
    p = model_params
    n = tokens.shape[0]
    mask = jnp.tril(jnp.ones((n, n)))  # lower-triangular ones; presumably 1 = may attend

    h = p['token_embedding'][tokens] + p['positional_embedding'][:n]
    h = layer_norm(h, p['ln1']['gamma'], p['ln1']['beta'])
    h = multi_head_attn(h, p['attn_in']['weight'], p['attn_in']['bias'],
                        p['attn_out']['weight'], p['attn_out']['bias'],
                        model_config.n_head, mask)
    h = layer_norm(h, p['ln2']['gamma'], p['ln2']['beta'])
    h = ffn(h, p['ffn_in']['weight'], p['ffn_in']['bias'], p['ffn_out']['weight'], p['ffn_out']['bias'])
    h = layer_norm(h, p['lnf']['gamma'], p['lnf']['beta'])
    return jnp.dot(h, p['output_projection'])

In [None]:
MODEL_DICT = dict(
    bigram_model=bigram_model,
    bigram_embed=bigram_embed,
    causal_embed=causal_embed,
    causal_pos=causal_pos,
    attentive_model=attentive_model,
    attentive_ffn=attentive_ffn,
    attentive_ffn_ln=attentive_ffn_ln,
)

# same sizes as the previous experiment; the parameters add identity-initialized layer norms
attn_ffn_ln_params = init_attentive_ffn_ln_params(key, attn_config)
trained_attn_ffn_ln_params = train(
    attn_ffn_ln_params, attn_config, 'attentive_ffn_ln', train_tokens, adamw_config, train_config,
)
compare_params(
    prompt, attn_config, 'attentive_ffn_ln', attn_ffn_ln_params, trained_attn_ffn_ln_params,
)

### residual connections

In [None]:
# this is great! our outputs are getting more and more sensible
# but we're still not quite there yet
# as our model gets larger and deeper
# inputs have to travel further and further
# we've alleviated this a bit with the layer normalizations
# but it's not enough
# computer vision researchers (the ResNet paper, He et al. 2015) found that after a certain point, adding more depth stops helping
# suppose we have a model that is some n layers deep
# we would assume that a model n+1 layers deep should always be at least as good as the n layer model
# after all, it could just replicate the n layer deep function
# and use the identity function in the excess layer
# but empirically, training deeper networks actually had higher training loss

# a rough intuition for this is that as the input goes through many complex transformations,
# the early layers can't learn easily, because their gradients have to travel back through every later layer
# it's hard for the model to both build complex transformations and preserve the original data
# imagine playing a game of telephone, where the first person asks a question about the group of people
# and the last person has to answer it, with everyone in between passing the question along
# not only can the initial question get lost,
# but everyone along the way adds their own information to it, further obscuring the original
# so we need a way to pass information through the network that isn't strictly sequential

# this is because learning the identity function is actually not that easy
# we're trying to estimate some H(x)
# by doing H(x) = a(b(c(d(x))))
# since each of these functions is nonlinear
# it's pretty hard for them to preserve x exactly when H(x), or one of these subfunctions, needs to be the identity f(x)=x
# so we can reformulate the problem as trying to predict
# H(x) = F(x) + x
# now each step can just predict the difference that needs to be applied to x
# and it becomes easy to just do nothing to the input

# we basically just save the input before we do our layernorm/attention/ffn
# and add it to whatever the function's output was
# like we saw in the above example

In [None]:
"""

def attentive_ffn_ln_res(model_params: Dict, model_config: ModelConfig, tokens: jax.Array) -> jax.Array:
    seq_len = ...
    causal_mask = ...

    token_embeds = ...
    pos_embeds = ...

    embedded = ...
    embedded_residual = ...
    normed_embedded = ...

    context = ...

    context = ... # add residual

    context_residual = ...

    normed_context = ...

    enhanced_context = ...

    enhanced_context = ... # add residual

    normed_enhanced_context = ...

    token_logits = ...

    return token_logits

"""
def attentive_ffn_ln_res(model_params: Dict, model_config: ModelConfig, tokens: jax.Array) -> jax.Array:
    """One pre-LN transformer block with residual connections, then a final LN and projection.

        h = emb + attn(LN1(emb))
        h = h + ffn(LN2(h))
        logits = LNf(h) @ output_projection

    Uses the same parameter layout as `init_attentive_ffn_ln_params`.
    """
    p = model_params
    n = tokens.shape[0]
    mask = jnp.tril(jnp.ones((n, n)))

    stream = p['token_embedding'][tokens] + p['positional_embedding'][:n]

    # attention reads a normalized copy; its output is added back onto the un-normalized stream
    attn_update = multi_head_attn(layer_norm(stream, p['ln1']['gamma'], p['ln1']['beta']),
                                  p['attn_in']['weight'], p['attn_in']['bias'],
                                  p['attn_out']['weight'], p['attn_out']['bias'],
                                  model_config.n_head, mask)
    stream = attn_update + stream

    ffn_update = ffn(layer_norm(stream, p['ln2']['gamma'], p['ln2']['beta']),
                     p['ffn_in']['weight'], p['ffn_in']['bias'],
                     p['ffn_out']['weight'], p['ffn_out']['bias'])
    stream = ffn_update + stream

    return jnp.dot(layer_norm(stream, p['lnf']['gamma'], p['lnf']['beta']), p['output_projection'])



In [None]:
MODEL_DICT = dict(
    bigram_model=bigram_model,
    bigram_embed=bigram_embed,
    causal_embed=causal_embed,
    causal_pos=causal_pos,
    attentive_model=attentive_model,
    attentive_ffn=attentive_ffn,
    attentive_ffn_ln=attentive_ffn_ln,
    attentive_ffn_ln_res=attentive_ffn_ln_res,
)

# the residual model reads the same parameter keys as attentive_ffn_ln, so it shares that initializer
# NOTE(review): these names overwrite the previous cell's attn_ffn_ln params and trained results
attn_ffn_ln_params = init_attentive_ffn_ln_params(key, attn_config)
trained_attn_ffn_ln_params = train(
    attn_ffn_ln_params, attn_config, 'attentive_ffn_ln_res', train_tokens, adamw_config, train_config,
)
compare_params(
    prompt, attn_config, 'attentive_ffn_ln_res', attn_ffn_ln_params, trained_attn_ffn_ln_params,
)

### transformer blocks

In [None]:
# the model still isn't fully there yet
# while we now have a powerful structure
# that can understand context and process on that context
# we want to give it even more space to learn
# so we can repeat the attention and the feedforward network
# what this does is it allows the model to process the input in multiple passes
# kind of like reading a textbook in multiple passes
# the first pass might be like skimming over the material, looking at headings and subheadings
# the second pass might be a closer look, reading more deeply into the important parts
# and the final pass could be a deep reading of the book, made better by the abstract understanding from previous passes

# so we'll define a block of a transformer: the residuals and layer norms around the attention/feedforward
# and pass our input through multiple of these blocks

In [None]:
class ModelConfig(NamedTuple):
    """Model hyperparameters for the multi-block transformer.

    This redefines the earlier 4-field ModelConfig (see attn_config) with an extra
    `n_blocks` field. It defaults to 1 so configs built the earlier way, e.g.
    ModelConfig(vocab_size=95, embedding_dim=32, context_len=128, n_head=4),
    still construct if earlier cells are re-run after this one.
    """
    vocab_size: int      # rows of the token embedding / columns of the output projection
    embedding_dim: int   # width of every hidden vector
    context_len: int     # rows of the positional embedding table
    n_head: int          # attention heads passed to multi_head_attn
    n_blocks: int = 1    # transformer blocks stacked by multilayer_transformer

"""

def init_multilayer_transformer_params(key: jax.random.PRNGKey, model_config: ModelConfig, scaling_factor = 0.02) -> Dict:
    keys = jax.random.split(key, ) # use n blocks to figure out how many keys you need

    params = {
        'token_embedding': jax.random.normal(keys[0], (model_config.vocab_size, model_config.embedding_dim)) * scaling_factor,
        'positional_embedding': jax.random.normal(keys[1], (model_config.context_len, model_config.embedding_dim)) * scaling_factor,
        'output_projection': jax.random.normal(keys[2], (model_config.embedding_dim, model_config.vocab_size)) * scaling_factor,
        'lnf': {
            'gamma': jnp.ones((model_config.embedding_dim,)),
            'beta': jnp.zeros((model_config.embedding_dim,)),
        },
    }

    for i in range(model_config.n_blocks):
        block_key_start = ... # figure out the correct indexing strategy here
        params[f'block_{i}'] = {
            'attn_in': {
                'weight': jax.random.normal(keys[block_key_start], (model_config.embedding_dim, 3*model_config.embedding_dim)) * scaling_factor,
                'bias': jax.random.normal(keys[block_key_start+1], (3*model_config.embedding_dim,)) * scaling_factor,
            },
            'attn_out': {
                'weight': jax.random.normal(keys[block_key_start+2], (model_config.embedding_dim, model_config.embedding_dim)) * scaling_factor,
                'bias': jax.random.normal(keys[block_key_start+3], (model_config.embedding_dim,)) * scaling_factor,
            },
            'ln1': {
                'gamma': jnp.ones((model_config.embedding_dim,)),
                'beta': jnp.zeros((model_config.embedding_dim,)),
            },
            'ln2': {
                'gamma': jnp.ones((model_config.embedding_dim,)),
                'beta': jnp.zeros((model_config.embedding_dim,)),
            },
            'ffn_in': {
                'weight': jax.random.normal(keys[block_key_start+4], (model_config.embedding_dim, 4*model_config.embedding_dim)) * scaling_factor,
                'bias': jax.random.normal(keys[block_key_start+5], (4*model_config.embedding_dim,)) * scaling_factor,
            },
            'ffn_out': {
                'weight': jax.random.normal(keys[block_key_start+6], (4*model_config.embedding_dim, model_config.embedding_dim)) * scaling_factor,
                'bias': jax.random.normal(keys[block_key_start+7], (model_config.embedding_dim,)) * scaling_factor,
            },
        }

    return params
"""
def init_multilayer_transformer_params(key: jax.random.PRNGKey, model_config: ModelConfig, scaling_factor = 0.02) -> Dict:
    """Initialize parameters for `multilayer_transformer`.

    Shared pieces (token/positional embeddings, output projection, final layer norm 'lnf')
    sit at the top level. Each block i gets its own dict under 'block_{i}', holding
    attention, FFN and two layer norms. Random tensors are standard normal times
    `scaling_factor`; layer norms start as the identity (gamma=1, beta=0).

    Key layout: the key is split into 7 + 10 * n_blocks subkeys. Indices 0-2 feed the
    shared tensors, and block i draws from indices 3 + 10*i through 3 + 10*i + 7.
    The spare keys are never used, but the layout is kept so initial weights stay the same.
    """
    keys = jax.random.split(key, 7 + 10 * model_config.n_blocks)
    d = model_config.embedding_dim
    vocab = model_config.vocab_size

    def draw(idx, shape):
        return jax.random.normal(keys[idx], shape) * scaling_factor

    def identity_ln():
        return {'gamma': jnp.ones((d,)), 'beta': jnp.zeros((d,))}

    params = {
        'token_embedding': draw(0, (vocab, d)),
        'positional_embedding': draw(1, (model_config.context_len, d)),
        'output_projection': draw(2, (d, vocab)),
        'lnf': identity_ln(),
    }

    for block_idx in range(model_config.n_blocks):
        base = 3 + 10 * block_idx  # 10-key stride per block; offsets 0..7 are used
        params[f'block_{block_idx}'] = {
            'attn_in': {'weight': draw(base, (d, 3 * d)), 'bias': draw(base + 1, (3 * d,))},
            'attn_out': {'weight': draw(base + 2, (d, d)), 'bias': draw(base + 3, (d,))},
            'ln1': identity_ln(),
            'ln2': identity_ln(),
            'ffn_in': {'weight': draw(base + 4, (d, 4 * d)), 'bias': draw(base + 5, (4 * d,))},
            'ffn_out': {'weight': draw(base + 6, (4 * d, d)), 'bias': draw(base + 7, (d,))},
        }

    return params

In [None]:
# lets first define the structure of what we're creating
# assume the transformer_block function is already implemented
# none of this should be new, we're just adding in the logic to pass the input through multiple blocks

"""
def multilayer_transformer(model_params: Dict, model_config: ModelConfig, tokens: jax.Array) -> jax.Array:
    seq_len = ...

    token_embeds = ...
    pos_embeds = ...
    x = ...

    causal_mask = ...

    for i in range(model_config.n_blocks):
        x = ...

    x = ...

    token_logits = ...

    return token_logits
"""

def multilayer_transformer(model_params: Dict, model_config: ModelConfig, tokens: jax.Array) -> jax.Array:
    """Embed tokens, run them through `n_blocks` transformer blocks, then final LN and projection.

    Each block's parameters live under model_params['block_{i}']. The causal mask is built
    once and shared by every block. Returns one logit row per position.
    """
    n = tokens.shape[0]
    hidden = model_params['token_embedding'][tokens] + model_params['positional_embedding'][:n]

    mask = create_causal_mask(n)

    for block_idx in range(model_config.n_blocks):
        hidden = transformer_block(hidden, model_params[f'block_{block_idx}'], model_config, mask)

    final_ln = model_params['lnf']
    hidden = layer_norm(hidden, final_ln['gamma'], final_ln['beta'])
    return jnp.dot(hidden, model_params['output_projection'])

In [None]:
# here we implement the actual block
# again, nothing you haven't seen before

"""
def transformer_block(x: jax.Array, block_params: Dict, model_config: ModelConfig, causal_mask: jax.Array):

    residual = ...

    normed_x = ...

    context = ...

    context = ...

    context_residual = ...
    normed_context = ...

    enhanced_context = ...

    enhanced_context = ...

    return enhanced_context
"""

def transformer_block(x: jax.Array, block_params: Dict, model_config: ModelConfig, causal_mask: jax.Array):
    """One pre-LN transformer block.

        h   = attn(LN1(x)) + x
        out = ffn(LN2(h)) + h

    The sublayers read layer-normed copies, and their outputs are added back onto the
    un-normalized residual stream.
    """
    ln1, ln2 = block_params['ln1'], block_params['ln2']
    attn_in, attn_out = block_params['attn_in'], block_params['attn_out']
    ffn_in, ffn_out = block_params['ffn_in'], block_params['ffn_out']

    attended = multi_head_attn(layer_norm(x, ln1['gamma'], ln1['beta']),
                               attn_in['weight'], attn_in['bias'],
                               attn_out['weight'], attn_out['bias'],
                               model_config.n_head, causal_mask)
    h = attended + x

    transformed = ffn(layer_norm(h, ln2['gamma'], ln2['beta']),
                      ffn_in['weight'], ffn_in['bias'],
                      ffn_out['weight'], ffn_out['bias'])
    return transformed + h

In [None]:
MODEL_DICT = dict(
    bigram_model=bigram_model,
    bigram_embed=bigram_embed,
    causal_embed=causal_embed,
    causal_pos=causal_pos,
    attentive_model=attentive_model,
    attentive_ffn=attentive_ffn,
    attentive_ffn_ln=attentive_ffn_ln,
    attentive_ffn_ln_res=attentive_ffn_ln_res,
    multilayer_transformer=multilayer_transformer,
)

# 12 stacked blocks; other sizes match the single-block experiments
transformer_config = ModelConfig(
    vocab_size=95, embedding_dim=32, context_len=128, n_head=4, n_blocks=12,
)
multilayer_transformer_params = init_multilayer_transformer_params(key, transformer_config)
trained_multilayer_transformer_params = train(
    multilayer_transformer_params, transformer_config, 'multilayer_transformer',
    train_tokens, adamw_config, train_config,
)
compare_params(
    prompt, transformer_config, 'multilayer_transformer',
    multilayer_transformer_params, trained_multilayer_transformer_params,
)

### sampling hyper parameters

In [None]:
# that's all for the model architecture part of gpt2
# but there are a couple more things we can do with generation and tokenization

# right now we sample just using temperature and softmax
# even with temperature control
# we still consider literally every token
# this is kind of a pain, because we might generate a super unlikely token
# even with low temperature
# so we want a way to only select some 'reasonable' tokens

# the way that gpt-2 did this was a method called top k
# exactly as it sounds, it restricts the sample space to only the top k tokens

In [None]:
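# top k is nice, but it doesn't really account for differently shaped distributions
# at some steps, there might be more than k highly probable choices (and top k cuts some of them off)
# while at others only one or two tokens are plausible (and top k still lets in the unlikely rest of the k)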
# instead of the discrete cutoff of k
# we could instead use a cutoff of some amount of cumulative probability
# and only sample tokens within that space
# this is known as top p, or nucleus sampling
# this is because it samples from the nucleus or core of the probability distribution

In [None]:
# another sampling filter we can add is min_p
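# it keeps only tokens whose probability is at least min_p times the probability of the most likely token
# this is especially helpful when temperatures are high, like above 1
# high temperatures usually produce creative but unintelligible output
# but min_p helps control it, since its threshold scales with how confident the model is at each step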

In [None]:
"""
def generate(params: Dict, model_config: ModelConfig, model_name: str, tokens: jax.Array, max_new: int, key: jax.random.PRNGKey, temp=1.0, top_k=None, top_p=None, min_p=None):

    gen_tokens = jnp.array([], dtype=jnp.int32)
    cur_pos = 0

    # you'll probably need to use jnp where, jnp take along axis, jnp maximum, etc
    # one neat way to do this is to update a logit cutoff value with each of the filtering steps

    def filter_logits(logits, top_k=None, top_p=None, min_p=None):
        ...

        # Initialize cutoff with the minimum logit value
        cutoff = ...

        if min_p is not None:
            min_p_mask = ...
            cutoff = ...

        if top_k is not None:
            top_k_cutoff = ...
            cutoff = ...

        if top_p is not None:
            top_p_mask = ...
            cutoff = ...

        return jnp.where(logits < cutoff, -jnp.inf, logits)

    while cur_pos < max_new:
        key, subkey = jax.random.split(key, 2)
        logits = MODEL_DICT[model_name](params, model_config, tokens)
        last_token_logit = logits[-1:]
        scaled_logits = last_token_logit / temp
        filtered_logits = filter_logits(scaled_logits, top_k, top_p, min_p)

        next_token = jax.random.categorical(subkey, filtered_logits, shape=(1,))
        gen_tokens = jnp.concatenate((gen_tokens, next_token))
        tokens = jnp.concatenate((tokens, next_token))
        cur_pos += 1

    return gen_tokens
"""

def generate(params: Dict, model_config: ModelConfig, model_name: str, tokens: jax.Array, max_new: int, key: jax.random.PRNGKey, temp=1.0, top_k=None, top_p=None, min_p=None):
    """Autoregressively sample `max_new` tokens from the model registered as `model_name`.

    Args:
        params: parameters passed to MODEL_DICT[model_name].
        model_config: model hyperparameters. Only the last `context_len` tokens are fed to the
            model on each step, because the positional embedding table has that many rows.
        model_name: key into MODEL_DICT selecting the forward function.
        tokens: 1-D int array of prompt token ids.
        max_new: number of tokens to generate.
        key: PRNG key; split once per generated token.
        temp: temperature; the last position's logits are divided by it before filtering.
        top_k: if set, keep only the `top_k` highest-scoring tokens (clamped to the vocab size).
        top_p: if set, keep the smallest set of most likely tokens whose cumulative
            probability reaches `top_p` (nucleus sampling).
        min_p: if set, keep only tokens whose probability is at least `min_p` times the
            probability of the most likely token.

    Returns:
        1-D int array of the `max_new` sampled token ids (the prompt is not included).
    """

    def filter_logits(logits, top_k=None, top_p=None, min_p=None):
        # logits: (1, vocab). Every filter keeps a run of the highest-scoring tokens, so each can be
        # expressed as a cutoff logit. Applying them all means keeping logits >= the largest cutoff.
        sorted_logits = jnp.sort(logits, axis=-1)[..., ::-1]  # descending
        top_logit = sorted_logits[..., :1]
        cutoff = jnp.full_like(top_logit, -jnp.inf)

        if min_p is not None:
            # p_i >= min_p * p_max  <=>  logit_i >= logit_max + log(min_p)
            cutoff = jnp.maximum(cutoff, top_logit + jnp.log(min_p))

        if top_k is not None:
            # clamp so an oversized k cannot produce an empty slice
            k = max(1, min(int(top_k), logits.shape[-1]))
            cutoff = jnp.maximum(cutoff, sorted_logits[..., k - 1:k])

        if top_p is not None:
            sorted_probs = jax.nn.softmax(sorted_logits, axis=-1)
            # a token is in the nucleus while the mass ranked above it is still below top_p;
            # the most likely token is always in
            mass_above = jnp.cumsum(sorted_probs, axis=-1) - sorted_probs
            in_nucleus = (mass_above < top_p).at[..., 0].set(True)
            nucleus_cutoff = jnp.min(jnp.where(in_nucleus, sorted_logits, jnp.inf), axis=-1, keepdims=True)
            cutoff = jnp.maximum(cutoff, nucleus_cutoff)

        # never filter out the most likely token, so sampling always has a candidate
        cutoff = jnp.minimum(cutoff, top_logit)
        return jnp.where(logits < cutoff, -jnp.inf, logits)

    context_len = model_config.context_len
    # the model can only see context_len tokens, so never keep (or feed) more than that
    context = tokens[-context_len:]
    generated = [jnp.zeros((0,), dtype=jnp.int32)]  # seed keeps an int32 result when max_new == 0

    for _ in range(max_new):
        key, subkey = jax.random.split(key, 2)
        logits = MODEL_DICT[model_name](params, model_config, context)
        scaled_logits = logits[-1:] / temp
        filtered_logits = filter_logits(scaled_logits, top_k, top_p, min_p)

        next_token = jax.random.categorical(subkey, filtered_logits, shape=(1,))
        generated.append(next_token)
        context = jnp.concatenate((context, next_token))[-context_len:]

    return jnp.concatenate(generated)

In [None]:
# use this to play around with the parameters, notice how they affect the generations

tokenized_prompt = jnp.array(jit_tokenizer.encode(preprocess_text(prompt)), dtype=jnp.int32)
generated_tokens = generate(
    trained_multilayer_transformer_params,
    transformer_config,
    'multilayer_transformer',
    tokens=tokenized_prompt,
    max_new=10,
    temp=1.5,      # > 1 flattens the distribution (logits are divided by temp)
    min_p=0.05,
    top_k=40,
    top_p=0.5,
    key=jax.random.PRNGKey(0),
)
generated_text = jit_tokenizer.decode(generated_tokens)
print(jit_tokenizer.decode(tokenized_prompt))  # the prompt as the tokenizer sees it, after preprocess_text
print(prompt + generated_text)

### byte pair tokenization

In [None]:
# right now we tokenize with every single character
# but this becomes quite inefficient quickly
# one of the easiest ways to understand this is the context length
# since our model has a fixed context length,
# if each of our tokens is a character, we won't get long enough generations

# making each token a word would be pretty tough too
# for one, our model would never be able to say a word it hasn't seen before
# this limits the model's creativity and forces us to have very, very strong training data
# so we want some way to get between a full word and a character

# our problem is basically a data compression problem
# words might be overly compressive, and each character is under compressive
# lets think of each character as carrying some amount of information, on its own
# suppose each character is a byte, from 0-255
# if we were to inspect our training data
# we might find that some bytes tend to be paired with others
# rather than having to predict both bytes
# why not develop a way to predict both of those bytes at the same time?
# if we see a very common pair of bytes, we can replace it with a new token id that hasn't been used yet, say 256 (one past the largest byte value)
# we can do this as many times as we want on our training data, until we think we've achieved sufficient compression


In [None]:
# rather than the ASCII characters we used above
# we'll work on the UTF-8 bytes of the text, so we can support any unicode character, including emojis
text = (
    "This is an example text for BPE tokenization. "
    "It includes some unicode characters like 你好 and émoji 😊."
)
# UTF-8 bytes: ASCII characters take one byte each, while the Chinese characters and the emoji take several
list(text.encode('utf-8'))[:20]

In [None]:
# so to compress our data, we'll want to find frequent token pairs
# and greedily merge them
# and maintain a mapping from the original bytes to our merged indexes
# first lets write a function to count token pairs

"""
def count_token_pairs(tokens):
    token_pairs = {}
    ...
    return token_pairs
"""

def count_token_pairs(tokens):
    """Count how often each adjacent (left, right) pair of tokens occurs.

    Pairs overlap, so [1, 1, 1] yields {(1, 1): 2}. Keys appear in first-seen order.
    Sequences shorter than two tokens give an empty dict.
    """
    counts = {}
    for left, right in zip(tokens, tokens[1:]):
        counts[(left, right)] = counts.get((left, right), 0) + 1
    return counts

tokens = list(text.encode('utf-8'))
token_pairs = count_token_pairs(tokens)
# (count, pair) tuples, most frequent first (ties go to the larger pair)
sorted_pairs = sorted([(count, pair) for pair, count in token_pairs.items()], reverse=True)
[tuple(chr(byte) for byte in pair) for _, pair in sorted_pairs[:5]]  # top 5 pairs shown as characters

In [None]:
# now, wherever the most frequent pair appears, replace it with a new id one past the current maximum (256 here)
# [66, 121, 116, 101, 32, 112, 97, 105, 114, 32]
# [66, 121, 116, NEW, 112, 97, 105, 114, 32]
"""
def merge_pair(tokens, top_pair, new_id):
    new_tokens = []
    i = 0
    while i < len(tokens):
        if i < len(tokens) - 1 and (tokens[i], tokens[i+1]) == top_pair:
            new_tokens.append(new_id)
            i += 2
        else:
            new_tokens.append(tokens[i])
            i += 1
    return new_tokens
"""

def merge_pair(tokens, top_pair, new_id):
    """Replace each occurrence of `top_pair` with `new_id`, scanning left to right.

    Matches do not overlap: [a, a, a] with pair (a, a) becomes [new_id, a].
    `top_pair` is compared as a tuple, so it should be a (left, right) tuple.
    Returns a new list; `tokens` is not modified.
    """
    merged = []
    pos, last = 0, len(tokens) - 1
    while pos <= last:
        if pos < last and (tokens[pos], tokens[pos + 1]) == top_pair:
            merged.append(new_id)
            pos += 2
            continue
        merged.append(tokens[pos])
        pos += 1
    return merged

_, top_pair = sorted_pairs[0]  # the most frequent pair (first entry of the descending sort)
print(top_pair)
new_tokens = merge_pair(tokens, top_pair, new_id=256)  # 256 = first id past the single-byte range
print(new_tokens[:20])

In [None]:
# in order to encode new text
# we need to know the merges we performed, in order
# so that we can perform them correctly when encoding new text

# to decode text
# we'll need a dictionary that maps integer ids back into bytes objects to decode with utf8
# we'll call this id_to_token


# let's write the encode/decode functions now

"""

def encode(text, merges):
    tokens = list(text.encode('utf-8'))
    # apply merges
    return tokens

def decode(encoded_tokens, id_to_token):
    # use id_to_token and convert to a bytes object with b''.join()
    return

"""

def encode(text, merges):
    """Encode `text` to token ids by replaying learned BPE merges on its UTF-8 bytes.

    `merges` is a sequence of ((left_id, right_id), new_id) entries. They are applied in the
    order given, which must be the order they were learned in, because later merges
    refer to ids created by earlier ones.
    """
    ids = [byte for byte in text.encode('utf-8')]
    for pair, merged_id in merges:
        ids = merge_pair(ids, pair, merged_id)
    return ids

def decode(encoded_tokens, id_to_token):
    """Map token ids back to their bytes, join them, and decode as UTF-8.

    Invalid UTF-8 (e.g. a multi-byte character split across a generation boundary) becomes
    U+FFFD instead of raising. An id missing from `id_to_token` raises KeyError.
    """
    raw = b''.join(map(id_to_token.__getitem__, encoded_tokens))
    return raw.decode('utf-8', errors='replace')


In [None]:
# next we learn the byte pair encoding itself,
# which produces the id_to_token mapping and the ordered merges list used above

In [None]:
# in order to learn the byte pair encoding
# we'll define a target vocab size
# the amount of integer ids we'll end up with that our model can use
# first we list/encode the text
# initialize the mappings from bytes objects to tokens
# and initialize the next_id parameter
# which we'll have to increment
# while our vocabulary size isn't at the target
# we count pairs
# exiting early if there are no more pairs to merge
# then we take the top pair
# make the new token
# create the mappings between that new token and the ids
# add this information to merges
# merge the tokens
# and continue
# there are some cool opportunities to optimize this
# but the solutions only implement the brute force method for now

"""
def learn_bpe(text, target_vocab_size):
    tokens = list(text.encode('utf-8'))
    token_to_id = {i: bytes([i]) for i in range(256)}  # Initialize with byte tokens
    id_to_token = {i: bytes([i]) for i in range(256)}  # Reverse mapping
    next_id = 256
    merges = []

    while len(token_to_id) < target_vocab_size:
        ...

    return token_to_id, id_to_token, merges
"""

def learn_bpe(text, target_vocab_size):
    """Learn a byte-level BPE vocabulary from `text` by repeated greedy pair merging.

    Starts from the 256 single-byte tokens. While fewer than `target_vocab_size`
    ids exist, it takes the pair with the highest count reported by
    count_token_pairs, assigns it a new id, and merges it everywhere in the
    token list. It stops early if count_token_pairs finds no pairs left.

    Args:
        text: training text; it is encoded to UTF-8 bytes before merging.
        target_vocab_size: total number of ids to end with, including the 256
            byte ids (values <= 256 learn no merges).

    Returns:
        (token_to_id, id_to_token, merges):
          token_to_id: dict mapping each token's bytes to its integer id.
          id_to_token: dict mapping each integer id to its bytes (used by decode).
          merges: list of (pair, new_id) in the order learned; encode replays
                  them in this same order.
    """
    tokens = list(text.encode('utf-8'))
    # bytes -> id, the same direction as the merged entries added in the loop below
    token_to_id = {bytes([i]): i for i in range(256)}
    # id -> bytes, the reverse mapping
    id_to_token = {i: bytes([i]) for i in range(256)}
    next_id = 256
    merges = []

    # Count the vocabulary through id_to_token: two different pairs can concatenate
    # to the same bytes (e.g. b'a' + b'bc' and b'ab' + b'c'), which overwrites a
    # token_to_id key and would make len(token_to_id) undercount the ids created.
    while len(id_to_token) < target_vocab_size:
        if len(id_to_token) % 100 == 0:
            print('vocab at: ', len(id_to_token))

        token_pairs = count_token_pairs(tokens)
        if not token_pairs:
            print("no more pairs to merge. current vocab size:", len(id_to_token))
            break

        top_pair = max(token_pairs, key=token_pairs.get)
        # the new token's bytes are the concatenation of its two parts' bytes
        new_token = id_to_token[top_pair[0]] + id_to_token[top_pair[1]]
        token_to_id[new_token] = next_id
        id_to_token[next_id] = new_token
        merges.append((top_pair, next_id))

        tokens = merge_pair(tokens, top_pair, next_id)
        next_id += 1

    return token_to_id, id_to_token, merges



# 256 single-byte ids plus up to 1000 learned merges
target_vocab_size = 1256
# learning merges on all of train_text takes a long time, so only its first 10,000 characters are used
token_to_id, id_to_token, merges = learn_bpe(text=train_text[:10000], target_vocab_size=target_vocab_size)

In [None]:
# Measure how much the learned merges shrink a sample of the text.
# encode() starts from UTF-8 bytes, so compare token count against byte length,
# not character count (the two differ for any non-ASCII text).
# Note: this sample lies inside the text learn_bpe was trained on (train_text[:10000]),
# so the ratio is likely optimistic compared to unseen text.
sample_text = train_text[:1000]
sample_bytes = len(sample_text.encode('utf-8'))
encoded = encode(sample_text, merges)

compression_ratio = len(encoded) / max(sample_bytes, 1)  # tokens per byte; lower means more compression
print(f"compression: ({sample_bytes}) -> ({len(encoded)}): ({compression_ratio:.2f})")

### full gpt-2 weights

In [None]:
# now we will load the actual gpt2 tokenizer and weights
# we've already written everything above - so now we can just skip to the finished model

# we'll import from hugging face
from transformers import GPT2LMHeadModel, GPT2Tokenizer

In [None]:
# the tokenizer is as simple as
# loading gpt2's pretrained tokenizer takes a single call, by model name
gpt2_tokenizer = GPT2Tokenizer.from_pretrained(pretrained_model_name_or_path='gpt2')

# the model itself is a bit more complicated; leaving `gpt2` as the cell's
# last expression makes jupyter display the model so we can inspect its layers
gpt2 = GPT2LMHeadModel.from_pretrained(pretrained_model_name_or_path='gpt2')
gpt2

In [None]:
# as you can see, every component of the printed model is something we've already built!
# the next step would be to unpack the pretrained parameters into the parameter structure we've defined
# this is honestly kind of tedious, so here's the code

In [None]:
import numpy as np

def _to_jnp(tensor):
    """Copy a PyTorch parameter tensor into a JAX array (detached and moved to CPU first)."""
    return jnp.array(tensor.detach().cpu().numpy())


def _convert_linear(layer):
    """Convert a layer's `.weight` / `.bias` into a {'weight', 'bias'} dict of JAX arrays.

    The weight is copied exactly as stored, with no transpose.
    NOTE(review): assumes this notebook's forward pass uses these weights in the
    layout Hugging Face stores them for GPT-2's attention/MLP layers; confirm.
    """
    return {'weight': _to_jnp(layer.weight), 'bias': _to_jnp(layer.bias)}


def _convert_layernorm(ln):
    """Convert a layer norm's `.weight` / `.bias` into a {'gamma', 'beta'} dict of JAX arrays."""
    return {'gamma': _to_jnp(ln.weight), 'beta': _to_jnp(ln.bias)}


def load_gpt2_params(model_name="gpt2", model=None):
    """Convert pretrained Hugging Face GPT-2 weights into this notebook's parameter layout.

    Args:
        model_name: name passed to GPT2LMHeadModel.from_pretrained; only used when
            `model` is None.
        model: an already-loaded GPT2LMHeadModel to convert. Passing one avoids
            loading the weights a second time.

    Returns:
        (converted_params, model_config):
          converted_params: dict with 'token_embedding', 'positional_embedding',
            'output_projection', 'lnf', and one 'block_{i}' dict per transformer
            layer (keys 'attn_in', 'attn_out', 'ln1', 'ln2', 'ffn_in', 'ffn_out').
          model_config: ModelConfig built from the Hugging Face model config.
    """
    if model is None:
        model = GPT2LMHeadModel.from_pretrained(model_name)

    config = model.config
    model_config = ModelConfig(
        vocab_size=config.vocab_size,
        embedding_dim=config.n_embd,
        context_len=config.n_positions,
        n_head=config.n_head,
        n_blocks=config.n_layer
    )

    transformer = model.transformer
    converted_params = {
        'token_embedding': _to_jnp(transformer.wte.weight),
        'positional_embedding': _to_jnp(transformer.wpe.weight),
        # the only weight transposed during conversion
        'output_projection': _to_jnp(model.lm_head.weight).T,
        'lnf': _convert_layernorm(transformer.ln_f),
    }

    # Convert transformer blocks
    for i in range(model_config.n_blocks):
        block = transformer.h[i]
        converted_params[f'block_{i}'] = {
            'attn_in': _convert_linear(block.attn.c_attn),
            'attn_out': _convert_linear(block.attn.c_proj),
            'ln1': _convert_layernorm(block.ln_1),
            'ln2': _convert_layernorm(block.ln_2),
            'ffn_in': _convert_linear(block.mlp.c_fc),
            'ffn_out': _convert_linear(block.mlp.c_proj),
        }

    return converted_params, model_config

# usage example: reuse the `gpt2` model loaded above instead of loading the weights again
gpt2_params, gpt2_config = load_gpt2_params(model=gpt2)

In [None]:
# and finally: gpt2

prompt = 'the short, hilarious tweet from gpt2 read as follows:'

tokenized_prompt = jnp.array(gpt2_tokenizer.encode(prompt), dtype=jnp.int32)
# Use gpt2_config, built by load_gpt2_params from the pretrained model's own config,
# so the sizes match gpt2_params. transformer_config belongs to the earlier
# from-scratch model and need not match GPT-2's dimensions.
generated_tokens = generate(
    gpt2_params, gpt2_config, 'multilayer_transformer',
    tokens=tokenized_prompt, max_new=20,
    temp=1.5, min_p=0.05, top_k=40, top_p=0.5,
    key=jax.random.PRNGKey(0),
)
generated_text = gpt2_tokenizer.decode(generated_tokens)
print(gpt2_tokenizer.decode(tokenized_prompt))
print(prompt + generated_text)

In [None]:
# hopefully this was instructive! the author would love your feedback as a github issue or twitter dm at @arb8020