In [210]:
from chex import assert_rank
import haiku as hk
import jax.numpy as jnp
from jax import vmap, jit, grad
import jax
import optax
from torch.utils.data import DataLoader

In [192]:
from tinyshakespeareloader.hamlet import get_data

In [193]:
class SimpleBigram(hk.Module):
    """Bigram model: a (vocab_size x vocab_size) embedding table used as logits.

    Each input token id selects one row of the table; that row is returned as
    the logits for the position holding the token.
    """

    def __init__(self, vocab_size) -> None:
        super().__init__()
        self.vocab_size = vocab_size

    def __call__(self, x):
        """Look up logits for a rank-2 array of token ids; returns a rank-3 array."""
        assert_rank(x, 2)
        table = hk.Embed(vocab_size=self.vocab_size, embed_dim=self.vocab_size)
        logits = table(x)
        assert_rank(logits, 3)
        return logits

In [232]:
data = get_data(batch_size=32)
vocab_size = data["vocabulary_size"]
# One batch is pulled only to give model.init example inputs of the right shape.
dummy_x, dummy_y = next(iter(data["train_dataloader"]))


def bigram_forward(x):
    """Build a SimpleBigram sized to the vocabulary and apply it to ``x``."""
    bigram = SimpleBigram(vocab_size)
    return bigram(x)


model = hk.transform(bigram_forward)
init_key = next(hk.PRNGSequence(42))
params = model.init(rng=init_key, x=dummy_x.numpy())
dummy_x.shape

torch.Size([32, 8])

In [239]:
def net(params, x, rng=None):
    """Run the transformed model's forward pass.

    Args:
        params: parameters produced by ``model.init``.
        x: token ids passed to ``model.apply`` as ``x``.
        rng: an iterator of PRNG keys (e.g. ``hk.PRNGSequence``). One key is
            drawn per call. When ``None``, ``model.apply`` is called without a
            key, which is sufficient for models that use no randomness.

    Returns:
        The output of ``model.apply`` (the model's logits).
    """
    key = None if rng is None else next(rng)
    return model.apply(params=params, x=x, rng=key)

In [254]:
def loss_fn(params: optax.Params, batch: jnp.ndarray, labels: jnp.ndarray) -> jnp.ndarray:
    """Mean softmax cross-entropy between the model's logits and integer labels.

    Args:
        params: model parameters.
        batch: integer token ids fed to ``net``.
        labels: integer target ids. Presumably the next-token targets with the
            same shape as ``batch``; TODO confirm against the data loader.

    Returns:
        Scalar: the cross-entropy averaged over every position in the batch.
    """
    # A fresh fixed-seed key stream; the bigram forward pass draws no randomness.
    rng_key = hk.PRNGSequence(42)
    logits = net(params, batch, rng_key)
    per_position_loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels)
    return per_position_loss.mean()

@jit
def step(params, opt_state, batch, labels):
    """One jitted optimization step on a single batch.

    Reads the module-level ``optim`` when the function is traced. Because of
    ``jit`` caching, later rebinding of ``optim`` does not affect an already
    compiled ``step``.

    Returns:
        ``(new_params, new_opt_state, loss_value)``.
    """
    value_and_grad_fn = jax.value_and_grad(loss_fn)
    loss_value, grads = value_and_grad_fn(params, batch, labels)
    updates, new_opt_state = optim.update(grads, opt_state, params)
    new_params = optax.apply_updates(params, updates)
    return new_params, new_opt_state, loss_value


learning_rate = 0.001  # AdamW step size
optim = optax.adamw(learning_rate)
def fit(params: optax.Params, optim: optax.GradientTransformation, train_dataloader: DataLoader):
    """Train ``params`` for one pass over ``train_dataloader`` using ``optim``.

    Progress and the current batch loss are printed roughly ten times per pass.

    Args:
        params: initial model parameters.
        optim: the optimizer used for BOTH state initialization and updates.
        train_dataloader: yields ``(x, y)`` batches exposing ``.numpy()``.

    Returns:
        The updated parameters.
    """
    opt_state = optim.init(params)

    @jit
    def _update(params, opt_state, batch, labels):
        # Bound to this call's ``optim`` so that init and update always use the
        # same optimizer. The module-level ``step`` reads the global ``optim``
        # instead, which desynchronizes the two when a different one is passed.
        loss_value, grads = jax.value_and_grad(loss_fn)(params, batch, labels)
        updates, opt_state = optim.update(grads, opt_state, params)
        return optax.apply_updates(params, updates), opt_state, loss_value

    num_batches = len(train_dataloader)
    # max(1, ...) prevents a modulo-by-zero when there are fewer than 10 batches.
    log_every = max(1, num_batches // 10)

    for i, (x, y) in enumerate(train_dataloader):
        # Debug aid: raises if traced values leak out of the jitted function.
        with jax.checking_leaks():
            params, opt_state, loss_value = _update(params, opt_state, x.numpy(), y.numpy())
        if i % log_every == 0:
            print(f"step={(i / num_batches) * 100}%, {loss_value=}")
    return params

# Train for one pass over the training loader; the result replaces the initial params.
params = fit(params=params, optim=optim, train_dataloader=data["train_dataloader"])

step=0.0%, loss_value=Array(2.6456897, dtype=float32)
step=9.999681234260942%, loss_value=Array(2.0665283, dtype=float32)
step=19.999362468521884%, loss_value=Array(2.2822242, dtype=float32)
step=29.999043702782824%, loss_value=Array(2.519982, dtype=float32)
step=39.99872493704377%, loss_value=Array(2.3446436, dtype=float32)
step=49.99840617130471%, loss_value=Array(2.5622246, dtype=float32)
step=59.99808740556565%, loss_value=Array(2.11761, dtype=float32)
step=69.9977686398266%, loss_value=Array(2.3336637, dtype=float32)
step=79.99744987408754%, loss_value=Array(2.4252636, dtype=float32)
step=89.99713110834847%, loss_value=Array(2.1375194, dtype=float32)
step=99.99681234260942%, loss_value=Array(3.441904, dtype=float32)


In [246]:
def blabber(params, max_new_tokens=100, seed=42):
    """Autoregressively sample ``max_new_tokens`` token ids from the model.

    Starts from a single token id 0. At each step it runs the model on the
    whole sequence so far and samples the next token from the logits at the
    last position.

    Args:
        params: model parameters.
        max_new_tokens: number of tokens to append.
        seed: seed for the ``hk.PRNGSequence`` that supplies keys for both the
            model calls and the sampling.

    Returns:
        Integer array of shape ``(1, 1 + max_new_tokens)`` with the generated ids.
    """
    rng = hk.PRNGSequence(seed)
    idx = jnp.zeros(shape=(1, 1), dtype=jnp.int32)

    for _ in range(max_new_tokens):
        logits = net(params=params, x=idx, rng=rng)
        last_logits = logits[:, -1, :]
        next_idx = jax.random.categorical(next(rng), last_logits).reshape(-1, 1)
        idx = jnp.concatenate((idx, next_idx), axis=1)

    return idx
idx = blabber(params)

# Turn the first (and only) generated sequence back into text.
decode = data["decode"]
first_sample = idx[0].tolist()
decode(first_sample)

"\n\nWetis'shyst ase say\nWhr tow.\nANe,\nEreroneisa, shisthadove, a banod\nINartilos han'shios w; tin, 't s"