# Exploring JAX as a model backend.

In [None]:
import pandas as pd
import ast

import torch
import numpy as np
from torch.utils import data

import numpy as np
from jax.tree_util import tree_map
from torch.utils import data
import jax.numpy as jnp
import haiku as hk
import optax
import jax
import time

from pylibrarian.recommender.models import AttentionModel
from pylibrarian.dataset.package_dataset import PackageDataset
from pylibrarian.dataset.numpy_loader import NumpyLoader

## Dataset creation

In [None]:
# Load the semicolon-separated PyPI package table and wrap it for batching.
# `sep` is the canonical name of pandas' `delimiter` alias.
df = pd.read_csv('../data/pypi_packages.csv', sep=";")
dataset = PackageDataset(df)
# Single-process loading (num_workers=0) in mini-batches of 4 samples.
training_generator = NumpyLoader(
    dataset,
    batch_size=4,
    num_workers=0,
)

## Model

In [None]:
def _custom_forward_fn(x, y):
  """Haiku forward function: build an AttentionModel and apply it to (x, y).

  Must run inside `hk.transform`, since it constructs a Haiku module.
  """
  # One slot beyond the tokenizer length; presumably reserved for padding /
  # unknown tokens — TODO confirm against PackageDataset's tokenizer.
  vocab_size = len(dataset.tokenizer) + 1
  return AttentionModel(vocab_size=vocab_size)(x, y)

In [None]:
# Pure (init, apply) pair; `apply` needs no RNG key because the forward
# pass is wrapped with `hk.without_apply_rng`.
custom_forward_without_rng = hk.without_apply_rng(
    hk.transform(_custom_forward_fn)
)
# Initialise parameters from the first dataset sample, with a leading
# batch axis of size 1 added to each input.
params = custom_forward_without_rng.init(
    rng=jax.random.PRNGKey(0),
    x=dataset[0]['x'][np.newaxis, :],
    y=dataset[0]['y'][np.newaxis, :],
)

In [None]:
def loss(params, x, y, label):
  """Return the mean sigmoid binary cross-entropy of the model on a batch.

  Args:
    params: Haiku parameter tree for `custom_forward_without_rng`.
    x, y: Model inputs, forwarded unchanged to the forward function.
    label: Binary targets; assumed broadcast-compatible with the model's
      logits — TODO confirm shapes.

  Returns:
    Scalar mean of the elementwise losses.
  """
  logits = custom_forward_without_rng.apply(params, x=x, y=y)
  elementwise = optax.sigmoid_binary_cross_entropy(logits, label)
  return jnp.mean(elementwise)

## Training loop

In [None]:
def fit(params, optimizer, num_epochs: int = 10, *, data_loader=None,
        loss_fn=None, log_every: int = 100):
  """Train `params` with `optimizer` for `num_epochs` passes over the data.

  Args:
    params: Parameter pytree to optimise.
    optimizer: An optax gradient transformation (e.g. `optax.adam(...)`).
    num_epochs: Number of full passes over `data_loader`.
    data_loader: Iterable of batches. Each batch is unpacked as keyword
      arguments into `loss_fn`, so presumably a dict with keys `x`, `y`,
      `label` — TODO confirm against NumpyLoader. Defaults to the
      module-level `training_generator`.
    loss_fn: Callable `loss_fn(params, **batch) -> scalar`. Defaults to the
      module-level `loss`.
    log_every: Print the loss every `log_every` steps within an epoch.

  Returns:
    The trained parameter pytree, so callers can write
    `params = fit(params, optimizer)`.
  """
  # Resolve defaults at call time so the original globals are still used
  # when the caller does not supply them.
  if data_loader is None:
    data_loader = training_generator
  if loss_fn is None:
    loss_fn = loss

  opt_state = optimizer.init(params)
  grad_fn = jax.value_and_grad(loss_fn)

  @jax.jit
  def step(params, opt_state, batch):
    loss_value, grads = grad_fn(params, **batch)
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss_value

  for epoch in range(num_epochs):
    start_time = time.time()
    # Step index `i` restarts at 0 each epoch.
    for i, batch in enumerate(data_loader):
      params, opt_state, loss_value = step(params, opt_state, batch)
      if i % log_every == 0:
        print(f'step {i}, loss: {loss_value}')
    # Approximate: the first epoch includes jit compilation, and JAX
    # dispatches work asynchronously.
    epoch_time = time.time() - start_time
    print(f'epoch {epoch} finished in {epoch_time:.2f}s')

  return params


# Adam with a fixed learning rate of 0.01; train for the default 10 epochs
# and keep whatever `fit` returns as the new parameters.
optimizer = optax.adam(learning_rate=1e-2)
params = fit(params=params, optimizer=optimizer)