# Following along https://colab.research.google.com/drive/1JMLa53HDuA-i7ZBmqV7ZnA3c_fvtXnx-?usp=sharing but with JAX/Flax

In [348]:
!mkdir data
!curl https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt -o data/tinyshakespeare

  pid, fd = os.forkpty()


mkdir: data: File exists
  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 1089k  100 1089k    0     0  2185k      0 --:--:-- --:--:-- --:--:-- 2187k


In [349]:
with open('data/tinyshakespeare') as f:
    text = f.read()

print('Corpus size: ' + str(len(text)))
print(text[:1000])

vocab = list(set(text))
vocab_size = len(vocab)
print('Vocabulary size: ' + str(len(vocab)))
print(''.join(sorted(vocab)))

Corpus size: 1115394
First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in 

In [350]:
itos = {i:s for i,s in enumerate(vocab)}
stoi = {s:i for i,s in enumerate(vocab)}

encode = lambda x: [stoi[s] for s in x]
decode = lambda x: ''.join([itos[i] for i in x])

print(decode(encode('hello world')))

hello world


In [351]:
import jax.numpy as jnp

data = jnp.array(encode(text), dtype=jnp.int32)
print(data.dtype)
print(data.shape)
print(data[:100])

int32
(1115394,)
[48 34 40 59 24 18 23 34 24 34 32 55 26  7  1 25 55 44 50 40 55 18 56 55
 18 49 40 50  5 55 55  9 18 42 26  0 18 44 57 40 24 12 55 40 13 18 12 55
 42 40 18  4 55 18 59 49 55 42  2 54  1  1 60 11 11  7  1 37 49 55 42  2
 13 18 59 49 55 42  2 54  1  1 48 34 40 59 24 18 23 34 24 34 32 55 26  7
  1 30 50 57]


In [352]:
train_data = data[: int(.9 * len(data))]
val_data = data[int(.9 * len(data)):]

In [353]:
import jax

batch_size = 4
block_size = 8

dynamic_slice_vmap = jax.vmap(jax.lax.dynamic_slice, in_axes=(None, 0, None))


@jax.jit
def get_batch(data, key):
    ix = jax.random.randint(key, shape=(batch_size, 1), minval=0, maxval=len(data) - block_size)
    x = dynamic_slice_vmap(data, ix, (block_size,))
    y = dynamic_slice_vmap(data, ix + 1, (block_size,))
    return x, y


key = jax.random.key(1337)
print(get_batch(train_data, key))
xb, yb = get_batch(train_data, key)

(Array([[18,  4, 42, 26, 24, 11, 55,  1],
       [27, 14, 36,  7,  1, 36, 13, 18],
       [34,  4, 54,  1, 25, 57, 24, 18],
       [18, 59, 57,  5, 12, 18, 11, 55]], dtype=int32), Array([[ 4, 42, 26, 24, 11, 55,  1, 50],
       [14, 36,  7,  1, 36, 13, 18, 11],
       [ 4, 54,  1, 25, 57, 24, 18, 12],
       [59, 57,  5, 12, 18, 11, 55, 26]], dtype=int32))


In [354]:
from flax import nnx
import optax

class BigramLanguageModel(nnx.Module):
    rngs: nnx.Rngs

    def __init__(self, vocab_size, n_embed, rngs: nnx.Rngs):
        self.rngs = rngs
        # TODO: Something is off since certain values of n_embed is causing loss to be NaN.
        self.token_embedding_table = nnx.Embed(num_embeddings=vocab_size, features=n_embed, rngs=rngs)

    def __call__(self, x):
        logits = self.token_embedding_table(x)
        return logits
    
    def generate(self, x, length): # x has the shape (batch_size, block_size)
        for i in range(length):
            logits = self(x)
            next_token = jax.random.categorical(self.rngs.next(), logits[:, -1])
            x = jnp.concatenate([x, next_token[:, None]], axis=1)
        return x
    

key = jax.random.key(1337)
rngs = nnx.Rngs(key)
model = BigramLanguageModel(vocab_size, 32, rngs)

In [355]:
def loss(model, x, targets):
        logits = model(x)
        print(jnp.min(logits))
        print(jnp.max(logits))
        print(targets)
        return optax.softmax_cross_entropy_with_integer_labels(logits, targets).mean()

print(loss(model, xb, yb))

-0.4952408
0.47264346
[[ 4 42 26 24 11 55  1 50]
 [14 36  7  1 36 13 18 11]
 [ 4 54  1 25 57 24 18 12]
 [59 57  5 12 18 11 55 26]]
nan


In [356]:
optimizer = nnx.Optimizer(model, optax.adam(1e-3))
batch_size = 32

@nnx.jit
def train_step(model, optimizer, xb, yb):
    grads = (nnx.grad(loss))(model, xb, yb)
    optimizer.update(grads)

for i in tqdm.trange(10000):
    key, subkey = jax.random.split(key)
    xb, yb = get_batch(train_data, subkey)
    train_step(model, optimizer, xb, yb)
    if i % 1000 == 0:
        print(loss(model, xb, yb))
print(loss(model, xb, yb))

val_xb, val_yb = get_batch(val_data, key)
print(loss(model, val_xb, val_yb))

  0%|          | 0/10000 [00:00<?, ?it/s]

Traced<ShapedArray(float32[])>with<JVPTrace(level=3/0)> with
  primal = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace(level=1/0)>
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace(level=2/0)> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x14f7215a0>, in_tracers=(Traced<ShapedArray(float32[4,8,32]):JaxprTrace(level=2/0)>, Traced<ShapedArray(float32[4,8,32]):JaxprTrace(level=2/0)>, Traced<ShapedArray(float32[]):JaxprTrace(level=2/0)>), out_tracer_refs=[<weakref at 0x14f85ea20; to 'JaxprTracer' at 0x14f85e940>], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[4,8,32] b:f32[4,8,32] c:f32[]. let
    d:f32[4,8,32] = mul a b
    e:f32[] = reduce_sum[axes=(0, 1, 2)] d
    f:f32[] = div e c
  in (f,) }, 'in_shardings': (UnspecifiedValue, UnspecifiedValue, UnspecifiedValue), 'out_shardings': (UnspecifiedValue,), 'in_layouts': (None, None, None), 'out_layouts': (None,), 'resource_env': None

  0%|          | 1/10000 [00:00<46:40,  3.57it/s]

-0.4962408
0.47164348
[[52 50 59 24 13 18 11 55]
 [42 26  9 18  9 55 42 24]
 [34 26 62 18 12 55 40 18]
 [18 38 40 50 24 12 55 40]]
nan


 11%|█         | 1092/10000 [00:01<00:11, 767.05it/s]

-1.2612628
0.854791
[[26 18 50 57 24 47  1 28]
 [34  4 55 13  1 60 26  9]
 [57 40 55 47 18  4  0 18]
 [42 11 11 18 24 12 55  4]]
nan


 21%|██        | 2113/10000 [00:03<00:09, 864.78it/s]

-1.9625386
1.2824993
[[26 18 42 40  4 59 54  1]
 [42 38 11 55 18  5 50 57]
 [ 9 13 18 24 12 50 57 62]
 [40 55 18  0 50 57 40 18]]
nan


 31%|███▏      | 3146/10000 [00:04<00:09, 755.02it/s]

-2.641821
1.4377373
[[15 60 61 18 43 36 64 51]
 [40 42  4 55 18 24 50 18]
 [18 23 11 42 57  9 34 50]
 [ 1 39 12 42 24 18 59 12]]
nan


 41%|████▏     | 4130/10000 [00:05<00:07, 774.64it/s]

-3.2882185
1.4454585
[[42  9 55 18 44 42 34 40]
 [57 18  4  0 18 12 55 11]
 [55  0 18  3 42 40  1 39]
 [18  5 50  4 55 18 42 62]]
nan


 51%|█████     | 5108/10000 [00:07<00:06, 779.60it/s]

-3.9214413
1.877178
[[51 17 18 64 14 51 61 30]
 [40  2  8  1  1 27 14 51]
 [12 55 26 13 18 50 57 40]
 [50  4 18  0 50 57 18 12]]
nan


 61%|██████▏   | 6149/10000 [00:08<00:05, 765.45it/s]

-4.5589967
2.0489893
[[ 9 18 42 38 57 59 55 18]
 [55 13  1 64 42 40 49 18]
 [11  9 18 34 26  5 40 55]
 [18 38 55 18  9 34 59 12]]
nan


 71%|███████   | 7081/10000 [00:09<00:05, 547.31it/s]

-5.2123175
1.4436244
[[50 57 47 18  4 42  0 18]
 [ 0  1 59 56 55 55 24 18]
 [18 55 26 59 57 55 18 12]
 [42 26  9 18 12 34 59 18]]
nan


 79%|███████▉  | 7894/10000 [00:11<00:03, 693.79it/s]


KeyboardInterrupt: 

In [None]:
print([decode(row.tolist()) for row in model.generate(jnp.zeros((1, 1), dtype=jnp.int32), 100)][0])

y
IWhym m,

CL:
Bl my m:
'tll m! th yEYy, ty ck CLE:

Cm thl ck:

CKI hth dm,
DKILY!

Wh, th
I cQYhy,
