<a href="https://colab.research.google.com/github/KhawajaAbaid/flax-deep-learning/blob/main/makemore_mlp.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [56]:
from flax import linen as nn
from flax.training import train_state
import jax
from jax import numpy as jnp, random, tree
import optax

In [3]:
# Load dataset
import os

filepath = "./names.txt"
if (not os.path.exists(filepath)):
    !wget https://raw.githubusercontent.com/karpathy/makemore/master/names.txt

dataset = open(filepath, 'r').read().splitlines()
print(dataset[:10])

['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia', 'harper', 'evelyn']


In [43]:
# Create vocab and mapping
vocab = sorted(list(set(''.join(dataset))))

stoi = {c: i for i, c in enumerate(vocab, start=1)}
itos = {i: c for i, c in enumerate(vocab, start=1)}

vocab.append('.')
stoi['.'] = 0
itos[0] = '.'
vocab_size = len(vocab)

In [53]:
context_length = 3

def build_dataset(data):
  X, Y = [], []
  for word in data:
    context = [0] * context_length
    for c in word + '.':
      X.append(context)
      ix = stoi[c]
      Y.append(ix)
      context = context[1:] + [ix]
  return jnp.asarray(X), jnp.asarray(Y)

In [62]:
# Create dataset splits
def shuffle_randomly(dataset):
  import random
  random.shuffle(dataset)
  return dataset

dataset = shuffle_randomly(dataset)

n_train = int(0.9 * len(dataset))
print(n_train)
x_train, y_train = build_dataset(dataset[:n_train])
x_test, y_test = build_dataset(dataset[n_train:])

print(f"Train Samples: {len(x_train)} | Test Samples {len(x_test)}")

28829
Train Samples: 205240 | Test Samples 22906


In [64]:
# Create / Initialize model
embedding_dim = 10
hidden_dim = 200
input_dim = context_length * embedding_dim

key, subkey = random.split(random.PRNGKey(1337))
embeddings = random.normal(key, (vocab_size, embedding_dim))


class MLP(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Dense(features=hidden_dim)(x)
    x = nn.tanh(x)
    return nn.Dense(features=vocab_size)(x)


key, subkey = random.split(subkey)
model = MLP()
params = model.init(key, jnp.zeros((1, input_dim)))

tx = optax.sgd(learning_rate=0.1)

state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx)

In [67]:
def train_step(state, x, y):
  def forward_and_loss(params, x, y):
    x = embeddings[x]
    x = x.reshape((x.shape[0], -1))
    logits = state.apply_fn(params, x)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
    return loss
  grad_fn = jax.value_and_grad(forward_and_loss)
  loss, grads = grad_fn(state.params, x, y)
  state = state.apply_gradients(grads=grads)
  return state, loss

In [111]:
# Training loop
batch_size = 128
for _ in range(10000):
  # create random batches
  key, subkey = random.split(key)
  idx = random.randint(key, (batch_size,), 0, len(x_train))
  x_batch = x_train[idx]
  y_batch = y_train[idx]

  # train step
  state, loss = train_step(state, x_batch, y_batch)
print(loss)

2.1480448


In [112]:
# Generate new names
def generate(state, n_names=10, seed=2005):
  key, subkey = random.split(random.key(seed))
  for i in range(n_names):
    context = [0] * context_length
    name = ''
    while True:
      e = embeddings[jnp.asarray(context)].reshape((1, -1))
      logits = state.apply_fn(state.params, e)
      probs = jax.nn.softmax(logits).reshape(-1)
      key, subkey = random.split(subkey)
      ix = random.choice(key, jnp.arange(vocab_size), p=probs)
      if ix == 0:
        break
      context = context[1:] + [ix]
      name += itos[ix.tolist()]
    print(name)


In [113]:
generate(state)

sibh
nyia
khluns
lailtilysiami
erily
kir
ankaham
kimlyien
mibraquidaniy
marian


In [114]:
# The end.