<a href="https://colab.research.google.com/github/apeforest/nanoGPT/blob/master/gpt_dev_jax0.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

*Building* a NanoGPT example using JAX (based on the PyTorch version from https://github.com/karpathy/nanoGPT)

In [1]:
# We always start with a dataset to train on. Let's download the tiny shakespeare dataset
!wget https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt

--2023-05-26 21:51:40--  https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 1115394 (1.1M) [text/plain]
Saving to: ‘input.txt’


2023-05-26 21:51:40 (39.8 MB/s) - ‘input.txt’ saved [1115394/1115394]



In [None]:
# read it in to inspect it
with open('input.txt', 'r', encoding='utf-8') as f:
    text = f.read()

In [None]:
print("length of dataset in characters: ", len(text))

length of dataset in characters:  1115394


In [None]:
# let's look at the first 1000 characters
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [None]:
chars = sorted(list(set(text)))
vocab_size = len(chars)
print(''.join(chars))
print(vocab_size)


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
65


In [None]:
# create a mapping from characters to integers
stoi = { ch:i for i,ch in enumerate(chars) }
itos = { i:ch for i,ch in enumerate(chars) }
encode = lambda s: [stoi[c] for c in s] # encoder: take a string, output a list of integers
decode = lambda l: ''.join([itos[i] for i in l]) # decoder: take a list of integers, output a string

print(encode("hii there"))
print(decode(encode("hii there")))

[46, 47, 47, 1, 58, 46, 43, 56, 43]
hii there


In [None]:
# let's now encode the entire text dataset and store it into a jnp array
import jax.numpy as jnp
import jax
import numpy as np

jax.config.update("jax_enable_x64", True)

data = jnp.array(encode(text), dtype=jnp.int64)
print(data.shape, data.dtype)
print(data[:1000]) # the 1000 characters we looked at earier will to the GPT look like this



(1115394,) int64
[18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 14 43 44 53 56 43  1 61 43
  1 54 56 53 41 43 43 42  1 39 52 63  1 44 59 56 58 46 43 56  6  1 46 43
 39 56  1 51 43  1 57 54 43 39 49  8  0  0 13 50 50 10  0 31 54 43 39 49
  6  1 57 54 43 39 49  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10
  0 37 53 59  1 39 56 43  1 39 50 50  1 56 43 57 53 50 60 43 42  1 56 39
 58 46 43 56  1 58 53  1 42 47 43  1 58 46 39 52  1 58 53  1 44 39 51 47
 57 46 12  0  0 13 50 50 10  0 30 43 57 53 50 60 43 42  8  1 56 43 57 53
 50 60 43 42  8  0  0 18 47 56 57 58  1 15 47 58 47 64 43 52 10  0 18 47
 56 57 58  6  1 63 53 59  1 49 52 53 61  1 15 39 47 59 57  1 25 39 56 41
 47 59 57  1 47 57  1 41 46 47 43 44  1 43 52 43 51 63  1 58 53  1 58 46
 43  1 54 43 53 54 50 43  8  0  0 13 50 50 10  0 35 43  1 49 52 53 61  5
 58  6  1 61 43  1 49 52 53 61  5 58  8  0  0 18 47 56 57 58  1 15 47 58
 47 64 43 52 10  0 24 43 58  1 59 57  1 49 47 50 50  1 46 47 51  6  1 39
 52 42  1 61 43  5 50 50  1 46 39 

In [None]:
# Let's now split up the data into train and validation sets
n = int(0.9*len(data)) # first 90% will be train, rest val
train_data = data[:n]
val_data = data[n:]

In [None]:
block_size = 8
train_data[:block_size+1]

Array([18, 47, 56, 57, 58,  1, 15, 47, 58], dtype=int64)

In [None]:
x = train_data[:block_size]
y = train_data[1:block_size+1]
for t in range(block_size):
    context = x[:t+1]
    target = y[t]
    print(f"when input is {context} the target: {target}")

when input is [18] the target: 47
when input is [18 47] the target: 56
when input is [18 47 56] the target: 57
when input is [18 47 56 57] the target: 58
when input is [18 47 56 57 58] the target: 1
when input is [18 47 56 57 58  1] the target: 15
when input is [18 47 56 57 58  1 15] the target: 47
when input is [18 47 56 57 58  1 15 47] the target: 58


In [None]:
from jax import random
key = random.PRNGKey(42)
batch_size = 4 # how many independent sequences will we process in parallel?
block_size = 8 # what is the maximum context length for predictions?

def get_batch(split):
    # generate a small batch of data of inputs x and targets y
    data = train_data if split == 'train' else val_data
    ix = random.randint(key, (batch_size, ), 0, len(data) - block_size)
    x = jnp.stack([data[i:i+block_size] for i in ix])
    y = jnp.stack([data[i+1:i+block_size+1] for i in ix])
    return x, y

xb, yb = get_batch('train')
print('inputs:')
print(xb.shape)
print(xb)
print('targets:')
print(yb.shape)
print(yb)

print('----')

for b in range(batch_size): # batch dimension
    for t in range(block_size): # time dimension
        context = xb[b, :t+1]
        target = yb[b,t]
        print(f"when input is {context.tolist()} the target: {target}")

inputs:
(4, 8)
[[ 0  5 18 53 56 43  1 63]
 [ 1 46 47 51  0 32 53  1]
 [ 1 46 43 56  0 33 52 58]
 [36 17 32 17 30 10  0 20]]
targets:
(4, 8)
[[ 5 18 53 56 43  1 63 53]
 [46 47 51  0 32 53  1 46]
 [46 43 56  0 33 52 58 47]
 [17 32 17 30 10  0 20 43]]
----
when input is [0] the target: 5
when input is [0, 5] the target: 18
when input is [0, 5, 18] the target: 53
when input is [0, 5, 18, 53] the target: 56
when input is [0, 5, 18, 53, 56] the target: 43
when input is [0, 5, 18, 53, 56, 43] the target: 1
when input is [0, 5, 18, 53, 56, 43, 1] the target: 63
when input is [0, 5, 18, 53, 56, 43, 1, 63] the target: 53
when input is [1] the target: 46
when input is [1, 46] the target: 47
when input is [1, 46, 47] the target: 51
when input is [1, 46, 47, 51] the target: 0
when input is [1, 46, 47, 51, 0] the target: 32
when input is [1, 46, 47, 51, 0, 32] the target: 53
when input is [1, 46, 47, 51, 0, 32, 53] the target: 1
when input is [1, 46, 47, 51, 0, 32, 53, 1] the target: 46
when input i

In [None]:
print(xb) # our input to the transformer

[[ 0  5 18 53 56 43  1 63]
 [ 1 46 47 51  0 32 53  1]
 [ 1 46 43 56  0 33 52 58]
 [36 17 32 17 30 10  0 20]]


In [None]:
from flax.core.frozen_dict import V
import flax.linen as nn
import optax

def cross_entropy(y, label):
  log_p = jax.nn.log_softmax(y)

  m = label.shape[0]
  # This is equivalent to converting label into one-hot and sum
  log_likelihood = -log_p[jnp.arange(m), label]
  return jnp.mean(log_likelihood)


class BigramLanguageModel(nn.Module):
    vocab_size: int

    def setup(self):
        # each token directly reads off the logits for the next token from a lookup table
        self.token_embedding_table = nn.Embed(vocab_size, vocab_size)

    def __call__(self, idx):

        # idx are (B, T) tensor of integers, C is vocab_size here.
        logits = self.token_embedding_table(idx) # (B,T,C)

        return logits


def loss_fn(logits, label):
    B, T, C = logits.shape
    y = jnp.reshape(logits, ((B * T), C)) # (B, T, C) ->
    loss = cross_entropy(y, jnp.reshape(label, (y.shape[0],)))
    return loss

def multinomial(key, input, num_samples, replacement=False):
    # input: Tensor of shape (batch_size, num_categories) containing the probabilities for each category
    # num_samples: The number of samples to draw for each batch
    # replacement: Whether to sample with replacement or not

    batch_size, num_categories = input.shape

    # Transform input to log probabilities
    log_probs = jnp.log(input)

    # Generate random samples
    samples = jax.random.categorical(key, log_probs, axis=-1)

    # If replacement is False, we need to ensure unique samples within each batch
    if not replacement:
        unique_samples = jnp.unique(samples, axis=-1, return_counts=True)
        repeated_indices = unique_samples[1] > 1
        while jnp.any(repeated_indices):
            # Replace repeated samples with new samples
            new_samples = jax.random.categorical(key, log_probs, jnp.sum(repeated_indices), axis=-1)
            samples = jnp.where(repeated_indices[:, jnp.newaxis], new_samples, samples)

            unique_samples = jnp.unique(samples, axis=-1, return_counts=True)
            repeated_indices = unique_samples[1] > 1

    return jnp.expand_dims(samples, 1)

def generate(model, params, idx, max_new_tokens):
    # idx is (B, T) array of indices in the current context
    key = random.PRNGKey(42)
    for _ in range(max_new_tokens):
        # get the predictions
        logits = model.apply(params, idx)
        # focus only on the last time step
        logits = logits[:, -1, :] # becomes (B, C)

        # apply softmax to get probabilities
        probs = jax.nn.softmax(logits, axis=-1) # (B, C)
        # sample from the distribution
        key, subkey = random.split(key)

        idx_next = multinomial(key, probs, num_samples=1) # (B, 1)

        # append sampled index to the running sequence
        idx = jnp.concatenate((idx, idx_next), axis=1) # (B, T+1)
    return idx


m = BigramLanguageModel(vocab_size)
params = m.init(key, jnp.zeros((vocab_size, vocab_size), dtype=jnp.int64)) # Initialization call

# Check the parameters
jax.tree_map(lambda x: x.shape, params)

logits = m.apply(params, xb)

print(f'{logits.shape=}')

loss = loss_fn(logits, yb)
print(f'{loss=}')


idx = generate(m, params, jnp.zeros((1, 1), dtype=jnp.int64), max_new_tokens=100)
print(idx)
print(decode(generate(m, params, idx = jnp.zeros((1, 1), dtype=jnp.int64), max_new_tokens=100)[0].tolist()))


logits.shape=(4, 8, 65)
loss=Array(4.1609755, dtype=float32)
[[ 0 34  3  8 27 25 12 44 54 13 45 40 20  3 37 46 38 34 63  9 37 34 16 35
  34 32 54 23  5 57 48 21 15 11 63 26 48 35  0 62 13 11 60  8 43 29 57 28
  24 33  7 43  4 23 50 16 42 44  9 17 51  7 27 48 23 20 32 62  5 18 38 54
  17 11 61 40 34 21 52 22 14 59 38 47  8 31 41 33 31 36  5 22 48 34 39 10
  29 13 36 35 63]]

V$.OM?fpAgbH$YhZVy3YVDWVTpK'sjIC;yNjW
xA;v.eQsPLU-e&KlDdf3Em-OjKHTx'FZpE;wbVInJBuZi.ScUSX'JjVa:QAXWy
