# Ninjax: General Modules for JAX

Ninjax is a general module system for JAX. It gives the user complete and transparent control over updating the state of each module, bringing the flexibility of PyTorch and TensorFlow to JAX. Moreover, Ninjax makes it easy to mix and match modules from different libraries, such as Flax and Haiku.

# Motivation

Existing deep learning libraries for JAX provide modules, but those modules only specify neural networks and cannot easily implement training logic. Orchestrating training all in one place, outside of the modules, is fine for simple code bases. But it becomes a problem when there are many modules with their own training logic and optimizers.

Ninjax solves this problem by giving each nj.Module full read and write access to its state. This means modules can have train functions to implement custom training logic, and call each other's train functions. Ninjax is intended to be used with one or more neural network libraries, such as Haiku and Flax.

The main differences to existing deep learning libraries are:

* Ninjax does not need separate apply()/init() functions. Instead, the first function call creates variables automatically.
* Ninjax lets you access and update model parameters inside of impure functions, so modules can handle their own optimizers and update logic.
* Natural support for modules with multiple functions, without the need for Flax's setup() function or Haiku's hk.multi_transform().
* Ninjax's flexible state handling makes it trivial to mix and match modules from other deep learning libraries in your models.
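
The first two points can be illustrated with a toy model of the mechanism. The sketch below is plain Python, not Ninjax's actual implementation; `_CONTEXT`, `get`, `put`, and `pure` are hypothetical stand-ins. An impure function reads and writes a shared context, and a wrapper turns it into a pure function of the form `(state, *args) -> (out, new_state)`, creating entries lazily on the first call:

```python
import copy

# Hypothetical global context that impure code reads from and writes to.
_CONTEXT = None


def get(name, ctor, *args, **kwargs):
  """Return entry `name`, creating it with ctor(*args, **kwargs) on first use."""
  if name not in _CONTEXT:
    _CONTEXT[name] = ctor(*args, **kwargs)
  return _CONTEXT[name]


def put(name, value):
  """Overwrite entry `name` in the current context."""
  _CONTEXT[name] = value


def pure(fn):
  """Wrap an impure function so that state goes in and comes out explicitly."""
  def wrapper(state, *args, **kwargs):
    global _CONTEXT
    previous = _CONTEXT
    _CONTEXT = copy.copy(state)  # Work on a copy; the caller's dict is untouched.
    try:
      out = fn(*args, **kwargs)
      new_state = _CONTEXT
    finally:
      _CONTEXT = previous
    return out, new_state
  return wrapper


def step(increment):
  # Impure: reads and updates a counter without receiving it as an argument.
  count = get('counter/value', int)  # int() == 0 on the first call.
  put('counter/value', count + increment)
  return count


pure_step = pure(step)
toy_state = {}
toy_out, toy_state = pure_step(toy_state, 5)
print(toy_out, toy_state)  # 0 {'counter/value': 5}
toy_out, toy_state = pure_step(toy_state, 5)
print(toy_out, toy_state)  # 5 {'counter/value': 10}
```

Because the wrapped function only touches the dictionary it is given, the caller decides what state goes in and what to keep, which is what allows JAX transformations to see all of the state.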


# Installation

Ninjax is a single file, so you can just copy it to your project directory. Or you can install the package:

```sh
pip install ninjax
```

# Quickstart

In [5]:
import functools

import haiku as hk
import jax
import jax.numpy as jnp
import ninjax as nj
import optax
import flax.linen as nn

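# Ninjax supports all Haiku and Flax modules and new libraries are easy to add.
# Shorthand used by the examples below (see "How can I use Haiku modules?").
Linear = functools.partial(nj.HaikuModule, hk.Linear)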

In [2]:
class MyModel(nj.Module):

  def __init__(self, size, lr=0.01, act=jax.nn.relu):
    self.size = size
    self.lr = lr
    self.act = act
    # Define submodules upfront.
    self.h1 = Linear(128, name='h1')
    self.h2 = Linear(128, name='h2')

  def __call__(self, x):
    x = self.act(self.h1(x))
    x = self.act(self.h2(x))
    # Define submodules inline.
    x = self.get('h3', Linear, self.size, with_bias=False)(x)
    # Create state entries of array values.
    x += self.get('bias', jnp.array, 0.0)
    return x

  def train(self, x, y):
    # Compute gradient with respect to all parameters in this module.
    loss, params, grad = nj.grad(self.loss, self)(x, y)
    # Update the parameters with gradient descent.
    state = jax.tree_util.tree_map(lambda p, g: p - self.lr * g, params, grad)
    # Update multiple state entries of this module.
    self.putm(state)
    return loss

  def loss(self, x, y):
    return ((self(x) - y) ** 2).mean()


# The complete state is stored in a flat dictionary. Ninjax automatically
# applies scopes to the string keys based on the module names.
state = {}
model = MyModel(8, name='MyModel')
train = nj.pure(model.train)  # nj.jit(...), nj.pmap(...)
main = jax.random.PRNGKey(0)

# Let's train on some example data.
dataset = [(jnp.ones((64, 32)), jnp.ones((64, 8)))] * 10
for x, y in dataset:
  rng, main = jax.random.split(main)
  # Variables are automatically initialized on the first call. This adds them
  # to the state dictionary.
  loss, state = train(state, rng, x, y)
  # To look at parameters, simply use the state dictionary.
  assert state['MyModel/bias'].shape == ()
  print('Loss:', float(loss))

Loss: 0.7322518825531006
Loss: 0.597238302230835
Loss: 0.48707011342048645
Loss: 0.39765995740890503
Loss: 0.3283187747001648
Loss: 0.27209892868995667
Loss: 0.2248704433441162
Loss: 0.18528717756271362
Loss: 0.15232132375240326
Loss: 0.12485767900943756
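
Since the state is one flat dictionary with `/`-separated keys, it can help to view it as a nested tree when inspecting a model. A small helper for that, written for illustration only (`unflatten` is a hypothetical name, not a Ninjax function, and the keys and values below are made up):

```python
def unflatten(flat, sep='/'):
  """Turn {'a/b/c': v} into {'a': {'b': {'c': v}}} for easier inspection."""
  nested = {}
  for key, value in flat.items():
    *scopes, leaf = key.split(sep)
    node = nested
    for scope in scopes:
      node = node.setdefault(scope, {})  # Create intermediate scopes on demand.
    node[leaf] = value
  return nested


flat_state = {
    'MyModel/h1/w': 'w1',
    'MyModel/h1/b': 'b1',
    'MyModel/bias': 0.0,
}
print(unflatten(flat_state))
# {'MyModel': {'h1': {'w': 'w1', 'b': 'b1'}, 'bias': 0.0}}
```

The sketch assumes no key is also a scope of another key (for example both `a` and `a/b`).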


# Tutorial

## How can I create state entries?
Ninjax gives modules full control over reading and updating their state entries. Use `self.get(name, ctor, *args, **kwargs)` to define state entries. The first call creates the entry as `ctor(*args, **kwargs)`. Later calls return the current value:
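
The create-once behavior can be illustrated with a toy class. `ToyGetModule` below is a hypothetical plain-Python stand-in, not Ninjax's implementation; it records constructor calls to show that `ctor` runs only the first time a name is requested:

```python
class ToyGetModule:
  """Mimics the create-once semantics of `self.get(name, ctor, *args, **kwargs)`."""

  def __init__(self, name, state):
    self.name = name
    self.state = state  # Shared flat dict with keys of the form '<module>/<entry>'.

  def get(self, name, ctor=None, *args, **kwargs):
    key = f'{self.name}/{name}'
    if key not in self.state:
      if ctor is None:
        raise KeyError(f'{key} does not exist and no constructor was given')
      self.state[key] = ctor(*args, **kwargs)
    return self.state[key]


calls = []


def make_zeros(n):
  calls.append(n)  # Record every constructor call.
  return [0.0] * n


toy_state = {}
module = ToyGetModule('layer', toy_state)
first = module.get('bias', make_zeros, 3)
second = module.get('bias', make_zeros, 3)
print(first is second, calls, list(toy_state))  # True [3] ['layer/bias']
```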

In [3]:
class ModuleToCreateStateEntries(nj.Module):

  def compute(self, x):
    init = jax.nn.initializers.variance_scaling(1, 'fan_avg', 'uniform')
    weights = self.get('weights', init, nj.rng(), (64, 32))
    bias = self.get('bias', jnp.zeros, (32,), jnp.float32)
    print(self.getm())  # {'ModuleToCreateStateEntries/weights': ..., 'ModuleToCreateStateEntries/bias': ...}
    return x @ weights + bias
  

state = {}
model = ModuleToCreateStateEntries(name='ModuleToCreateStateEntries')
compute = nj.pure(model.compute)  # nj.jit(...), nj.pmap(...)
main = jax.random.PRNGKey(0)

# Let's run the module on some example data.
dataset = [(jnp.ones((64, 64)), jnp.ones((64, 64)))] * 10
for x, y in dataset:
  rng, main = jax.random.split(main)
  # Variables are automatically initialized on the first call. This adds them
  # to the state dictionary.
  out, state = compute(state, rng, x)

{'ModuleToCreateStateEntries/weights': Array([[-0.04147053, -0.1314227 , -0.02465397, ...,  0.02183795,
         0.01697004,  0.04713267],
       [ 0.19299734, -0.11297554, -0.07765299, ...,  0.09174806,
        -0.14098048,  0.21593362],
       [ 0.14658523, -0.24184895, -0.14916015, ...,  0.10904205,
         0.0512017 , -0.05426091],
       ...,
       [ 0.17859662,  0.0211283 , -0.00192875, ...,  0.16585249,
         0.22263348,  0.0516718 ],
       [ 0.10639775, -0.20550174,  0.15000093, ..., -0.00155634,
        -0.2283287 , -0.19280887],
       [ 0.04920799, -0.03702986,  0.17916483, ..., -0.21073604,
         0.05745828,  0.0972172 ]], dtype=float32), 'ModuleToCreateStateEntries/bias': Array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],      dtype=float32)}

## How can I update state entries?
To update the state entries of a module, use `self.put(name, value)` for individual entries or `self.putm(mapping)` to update multiple values:
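
A toy stand-in for the two update methods, assuming the same flat-dictionary layout (`ToyPutModule` is hypothetical, not Ninjax's code). `put` addresses one entry by its short name, while `putm` takes a mapping keyed by full paths, such as the one `getm` returns:

```python
class ToyPutModule:
  """Hypothetical stand-in for put/putm on a shared flat state dict."""

  def __init__(self, name, state):
    self.name = name
    self.state = state

  def get(self, name):
    return self.state[f'{self.name}/{name}']

  def put(self, name, value):
    # One entry, addressed by its short name within this module.
    self.state[f'{self.name}/{name}'] = value

  def getm(self):
    # All entries of this module, keyed by their full paths.
    prefix = f'{self.name}/'
    return {k: v for k, v in self.state.items() if k.startswith(prefix)}

  def putm(self, mapping):
    # Several entries at once, keyed by full paths (as returned by getm).
    self.state.update(mapping)


shared_state = {'Counter/count': 0, 'Counter/total': 0, 'Other/x': 1}
counter = ToyPutModule('Counter', shared_state)
counter.put('count', counter.get('count') + 1)
updated = {key: value + 10 for key, value in counter.getm().items()}
counter.putm(updated)
print(shared_state)  # {'Counter/count': 11, 'Counter/total': 10, 'Other/x': 1}
```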

In [4]:
class ModuleToUpdateState(nj.Module):

  def counting(self):
    counter = self.get('counter', jnp.zeros, (), jnp.int32)
    self.put('counter', counter + 1)
    print(self.get('counter'))  # 1
    # `state` maps every entry of this module to its value. Note that the keys
    # are **full paths**, not just the entry names.
    state = self.getm()
    print(state)
    print(state['ModuleToUpdateState/counter']) # 1. It's not `state['counter']`.
    counter = self.get('counter')
    self.put('counter', counter + 1)
    print(self.get('counter'))  # 2
    
    
state = {}
model = ModuleToUpdateState(name="ModuleToUpdateState")
counting = nj.pure(model.counting)  # nj.jit(...), nj.pmap(...)
main = jax.random.PRNGKey(0)

_, state = counting(state, main)  # `counting` returns None.


1
{'ModuleToUpdateState/counter': Array(1, dtype=int32)}
1
2


## How can I use JIT compilation?
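Because `nj.pure()` makes the state that your code reads and writes explicit, the resulting function can be transformed freely. Ninjax also provides wrappers such as `nj.jit` and `nj.pmap`. Note that inside a jitted function, Python's `print` only runs while the function is being traced, so it shows abstract tracer values instead of concrete numbers: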
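
To see why explicit state matters, here is a plain-JAX sketch (no Ninjax) of a function with the same shape as the output of `nj.pure()`: the state dictionary is an argument and the updated dictionary is a return value, so `jax.jit` can compile it. `pure_counting` and the key `'Counter/counter'` are hypothetical:

```python
import jax
import jax.numpy as jnp


def pure_counting(state, rng):
  """Pure: all state comes in as an argument and goes out as a result."""
  del rng  # Unused; kept only to mirror the (state, rng, *args) signature.
  counter = state['Counter/counter']
  new_state = {**state, 'Counter/counter': counter + 1}
  return counter, new_state


counting_jit = jax.jit(pure_counting)
jax_state = {'Counter/counter': jnp.zeros((), jnp.int32)}
key = jax.random.PRNGKey(0)
for _ in range(3):
  last_count, jax_state = counting_jit(jax_state, key)
print(int(last_count), int(jax_state['Counter/counter']))  # 2 3
```

The first call compiles the function; later calls with the same input shapes and structure reuse the compiled version.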

In [5]:
state = {}
counting = nj.jit(nj.pure(model.counting))
main = jax.random.PRNGKey(0)

_, state = counting(state, main)  # `counting` returns None.

Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
{'ModuleToUpdateState/counter': Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>}
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
{'ModuleToUpdateState/counter': Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>}
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>
Traced<ShapedArray(int32[])>with<DynamicJaxprTrace(level=1/0)>


In [9]:
import jax
from jax import jit
import jax.numpy as jnp

def print_and_return(x):
    print("Value:", x)
    return x

@jit
def my_function(x):
    result_shape = jax.ShapeDtypeStruct(x.shape, x.dtype)
    x = jax.pure_callback(print_and_return, result_shape, x)
    # Continue computation with x...
    return x ** 2

# Now calling my_function will print the value of x before squaring
result = my_function(jnp.array([1, 2, 3]))


Value: [1 2 3]
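
JAX also provides `jax.debug.print`, which prints runtime values from inside transformed functions without declaring a result shape. A minimal sketch using the same input as above (`square_and_print` is a hypothetical name):

```python
import jax
import jax.numpy as jnp


@jax.jit
def square_and_print(x):
  # Unlike print(), this runs every time the compiled function executes and
  # shows concrete values rather than tracers.
  jax.debug.print("Value: {}", x)
  return x ** 2


squared = square_and_print(jnp.array([1, 2, 3]))
```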


## How can I compute gradients?

You can use `jax.grad` as normal for computing gradients with respect to explicit inputs of your function. To compute gradients with respect to Ninjax state, use `nj.grad(fn, keys)`:
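
Conceptually, a gradient with respect to selected state entries treats those entries as the differentiable argument and holds the rest of the state fixed. The plain-JAX sketch below illustrates that idea; it is not Ninjax's implementation, and `grad_wrt_keys`, `demo_loss`, and the key names are hypothetical:

```python
import jax
import jax.numpy as jnp


def grad_wrt_keys(loss_fn, state, keys, *args):
  """Return (loss, grads), with grads taken only w.r.t. state[k] for k in keys."""
  selected = {k: state[k] for k in keys}
  rest = {k: v for k, v in state.items() if k not in selected}

  def wrapped(selected):
    # Recombine the differentiable entries with the fixed ones.
    return loss_fn({**rest, **selected}, *args)

  return jax.value_and_grad(wrapped)(selected)


def demo_loss(state, x, y):
  pred = x * state['Model/w'] + state['Model/b'] * state['Model/scale']
  return jnp.mean((pred - y) ** 2)


demo_state = {'Model/w': 1.0, 'Model/b': 0.0, 'Model/scale': 2.0}
xs = jnp.array([1.0, 2.0])
ys = jnp.array([3.0, 5.0])
loss_value, grads = grad_wrt_keys(
    demo_loss, demo_state, ['Model/w', 'Model/b'], xs, ys)
print(sorted(grads))  # ['Model/b', 'Model/w']
```

Here `Model/scale` still affects the loss but receives no gradient, because it is not among the selected keys.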

In [None]:
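class ModuleToComputeGradients(nj.Module):

  def __init__(self, size, lr=0.01):
    self.size = size
    self.lr = lr

  def __call__(self, x):
    w = self.get('w', jnp.zeros, (x.shape[-1], self.size), jnp.float32)
    b = self.get('bias', jnp.zeros, (self.size,), jnp.float32)
    return x @ w + b

  def train(self, x, y):
    # Gradient of the loss with respect to all state entries of this module,
    # using the same call pattern as the quickstart.
    loss, params, grads = nj.grad(self.loss, self)(x, y)
    params = jax.tree_util.tree_map(lambda p, g: p - self.lr * g, params, grads)
    self.putm(params)
    return loss

  def loss(self, x, y):
    return ((self(x) - y) ** 2).mean()


state = {}
model = ModuleToComputeGradients(8, name='ModuleToComputeGradients')
train = nj.pure(model.train)  # nj.jit(...), nj.pmap(...)
main = jax.random.PRNGKey(0)

# Let's train on some example data.
dataset = [(jnp.ones((64, 32)), jnp.ones((64, 8)))] * 10
for x, y in dataset:
  rng, main = jax.random.split(main)
  # Variables are automatically initialized on the first call. This adds them
  # to the state dictionary.
  loss, state = train(state, rng, x, y)
  # To look at parameters, simply use the state dictionary.
  assert state['ModuleToComputeGradients/bias'].shape == (8,)
  print('Loss:', float(loss))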

The `self.getm(filter='.*')` method optionally accepts a regex pattern to select only a subset of the state entries. It only returns entries of the current module; to access the global state, use `nj.state()`.
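
The filtering can be sketched on a plain dictionary. The hypothetical `select` function below keeps one module's entries and then applies a regex. It assumes the pattern is matched against the entry name relative to the module, which may differ from Ninjax's exact matching rule:

```python
import re


def select(state, module_name, pattern='.*'):
  """Return entries of `module_name` whose relative name fully matches `pattern`."""
  prefix = module_name + '/'
  return {
      key: value for key, value in state.items()
      if key.startswith(prefix) and re.fullmatch(pattern, key[len(prefix):])
  }


demo_state = {
    'Net/h1/w': 1, 'Net/h1/b': 2, 'Net/h2/w': 3,
    'Net/opt/step': 4, 'Other/w': 5,
}
print(select(demo_state, 'Net'))            # all four entries of Net
print(select(demo_state, 'Net', r'h\d/w'))  # {'Net/h1/w': 1, 'Net/h2/w': 3}
```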


## How can I define modules compactly?
You can use `self.get(name, ctor, *args, **kwargs)` inside methods of your modules. When called for the first time, it creates a new state entry from the constructor `ctor(*args, **kwargs)`. Later calls return the existing entry:

In [10]:
class SimpleMLP(nj.Module):

  def __call__(self, x):
    x = jax.nn.relu(self.get('h1', Linear, 128)(x))
    x = jax.nn.relu(self.get('h2', Linear, 128)(x))
    x = self.get('h3', Linear, 32)(x)
    return x

## How can I use Haiku modules?

There is nothing special about using external libraries with Ninjax. Haiku requires its modules to be passed through `hk.transform` and then initialized via `transformed.init(rng, batch)`. For convenience, Ninjax provides `nj.HaikuModule` to do this for you:
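
The pattern behind such wrappers does not depend on the library: keep the library's init/apply pair, call init lazily on the first call, and store the resulting parameters in the flat state. A plain-Python sketch under that assumption, where `LazyExternalModule` and the toy `init_fn`/`apply_fn` are hypothetical; with Haiku they would correspond to `transformed.init` and `transformed.apply` (whose apply also takes an rng, omitted here for brevity):

```python
import random


class LazyExternalModule:
  """Wraps an (init, apply) pair; parameters live in a shared flat state dict."""

  def __init__(self, name, init_fn, apply_fn, state, seed=0):
    self.name = name
    self.init_fn = init_fn
    self.apply_fn = apply_fn
    self.state = state
    self.rng = random.Random(seed)

  def __call__(self, x):
    key = f'{self.name}/params'
    if key not in self.state:
      # First call: initialize from an example input, like init(rng, batch).
      self.state[key] = self.init_fn(self.rng, x)
    # Every call: run the function with the stored parameters.
    return self.apply_fn(self.state[key], x)


def init_fn(rng, x):
  # Toy "library" init: one random weight per input feature.
  return {'w': [rng.uniform(-1.0, 1.0) for _ in x]}


def apply_fn(params, x):
  return sum(w * xi for w, xi in zip(params['w'], x))


demo_state = {}
layer = LazyExternalModule('linear', init_fn, apply_fn, demo_state)
y1 = layer([1.0, 2.0, 3.0])
y2 = layer([1.0, 2.0, 3.0])
print(list(demo_state), y1 == y2)  # ['linear/params'] True
```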

In [None]:
class ModuleHaiku(nj.Module):

  def __init__(self):
    self.mlp = nj.HaikuModule(hk.nets.MLP, [128, 128, 32])

  def __call__(self, x):
    return self.mlp(x)

You can also predefine a list of aliases for Haiku modules that you want to use frequently:

In [11]:
Linear = functools.partial(nj.HaikuModule, hk.Linear)
Conv2D = functools.partial(nj.HaikuModule, hk.Conv2D)
MLP = functools.partial(nj.HaikuModule, hk.nets.MLP)
# ...

## How can I use Flax modules?
There is nothing special about using external libraries with Ninjax. Flax requires its modules to be initialized via `params = model.init(rng, batch)` and used via `model.apply(params, data)`. For convenience, Ninjax provides `nj.FlaxModule` to do this for you:

In [None]:
class ModuleFlax(nj.Module):

  def __init__(self):
    self.linear = nj.FlaxModule(nn.Dense, 128)

  def __call__(self, x):
    return self.linear(x)

You can also predefine a list of aliases for Flax modules that you want to use frequently:

In [None]:
Dense = functools.partial(nj.FlaxModule, nn.Dense)
Conv = functools.partial(nj.FlaxModule, nn.Conv)
# ...

## How can I use Optax optimizers?
There is nothing special about using external libraries like Optax with Ninjax. Optax requires its optimizers to be initialized, their state to be passed through the optimizer call, and the resulting updates to be applied. For convenience, Ninjax provides `nj.OptaxModule` to do this for you:
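
These three steps make up Optax's `init`/`update`/`optax.apply_updates` protocol. The sketch below imitates that protocol in plain Python for SGD with momentum on a dictionary of floats; it is not Optax itself, and `sgd_momentum` and `apply_updates` are hypothetical names:

```python
def sgd_momentum(learning_rate, beta=0.9):
  """Return an (init, update) pair in the style of an Optax transformation."""

  def init(params):
    # Optimizer state: one momentum buffer per parameter.
    return {k: 0.0 for k in params}

  def update(grads, momentum):
    # Fold the new gradient into the momentum and turn it into a step.
    momentum = {k: beta * momentum[k] + g for k, g in grads.items()}
    updates = {k: -learning_rate * m for k, m in momentum.items()}
    return updates, momentum

  return init, update


def apply_updates(params, updates):
  return {k: params[k] + updates[k] for k in params}


# Minimize (w - 3)^2 starting from w = 0.
opt_init, opt_update = sgd_momentum(learning_rate=0.1)
demo_params = {'w': 0.0}
demo_opt_state = opt_init(demo_params)
for _ in range(500):
  grads = {'w': 2.0 * (demo_params['w'] - 3.0)}
  updates, demo_opt_state = opt_update(grads, demo_opt_state)
  demo_params = apply_updates(demo_params, updates)
print(round(demo_params['w'], 4))  # 3.0
```

Threading `demo_opt_state` through every call is the bookkeeping that a wrapper module can hide by storing the optimizer state alongside the parameters.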

In [None]:
class ModuleOptax(nj.Module):

  def __init__(self):
    self.mlp = MLP([128, 128, 32])
    self.opt = nj.OptaxModule(optax.adam, 1e-3)

  def train(self, x, y):
    self.mlp(x)  # Ensure parameters are created.
    metrics = self.opt(self.mlp.getm('.*'), self.loss, x, y)
    return metrics  # {'loss': ..., 'grad_norm': ...}

  def loss(self, x, y):
    return ((self.mlp(x) - y) ** 2).mean()


# Load the images


In [6]:
import os

import imageio
import numpy as np

def load_images_from_folder(folder):
    images = []
    # Sort so that frames keep a deterministic order (os.listdir order is arbitrary).
    for filename in sorted(os.listdir(folder)):
        img = imageio.imread(os.path.join(folder, filename))
        if img is not None:
            images.append(img)
    return np.array(images)

# Load your dataset
image_folder = '/home/lyk/Projects/Machine-Learning-Basic/VAE/images'
custom_images = load_images_from_folder(image_folder)


# Split into train, validation, and test sets

In [None]:
# BouncingBall
total_images = custom_images.shape[0]

# For example, let's say you want a 70-15-15 split
train_end = int(total_images * 0.7)
valid_end = train_end + int(total_images * 0.15)

train_images = custom_images[:train_end]
valid_images = custom_images[train_end:valid_end]
test_images = custom_images[valid_end:]
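
The split boundaries are plain index arithmetic. A small sketch of the same 70-15-15 computation using integer percentages, which sidesteps floating-point rounding at the boundaries (`split_bounds` is a hypothetical helper); whatever remains after truncation goes to the test split:

```python
def split_bounds(total, train_pct=70, valid_pct=15):
  """Return (train_end, valid_end) for a sequential train/valid/test split."""
  train_end = total * train_pct // 100
  valid_end = train_end + total * valid_pct // 100
  return train_end, valid_end


for total in (1000, 7):
  t_end, v_end = split_bounds(total)
  sizes = (t_end, v_end - t_end, total - v_end)
  print(total, sizes)
# 1000 (700, 150, 150)
# 7 (4, 1, 2)
```

Because the split is sequential rather than shuffled, consecutive frames stay in the same subset.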

# Create dicts

In [None]:
train_data_dict = {'image': train_images}
valid_data_dict = {'image': valid_images}
test_data_dict = {'image': test_images}
train_data_variance = np.var(train_data_dict['image'] / 255.0)
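
`train_data_variance` is the variance of the normalized training pixels. In the DeepMind VQ-VAE example (which this notebook's model and checkpoint names suggest it is based on), it divides the mean squared reconstruction error, so predicting the mean pixel value everywhere scores 1.0 regardless of the dataset's contrast. A NumPy sketch of that normalized error on random stand-in data, assuming the same convention (`normalized_recon_error` is a hypothetical name):

```python
import numpy as np

np_rng = np.random.default_rng(0)
# Random stand-in for uint8 images of shape (N, H, W, C).
fake_images = np_rng.integers(0, 256, size=(16, 8, 8, 3), dtype=np.uint8)
x = fake_images / 255.0
data_variance = np.var(x)


def normalized_recon_error(x_recon, x, data_variance):
  """Mean squared reconstruction error divided by the data variance."""
  return np.mean((x_recon - x) ** 2) / data_variance


mean_prediction = np.full_like(x, x.mean())
print(normalized_recon_error(mean_prediction, x, data_variance))  # close to 1.0
print(normalized_recon_error(x, x, data_variance))                # 0.0
```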

In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds

def cast_and_normalise_images(features):
    # Assuming your images are uint8 [0, 255]
    features['image'] = tf.cast(features['image'], tf.float32) / 255.0
    return features

# Create dataset

In [None]:
# `batch_size` is assumed to be defined elsewhere in the notebook.
train_dataset = tfds.as_numpy(
    tf.data.Dataset.from_tensor_slices(train_data_dict)
    .map(cast_and_normalise_images)
    .shuffle(10000)
    .repeat(-1)  # repeat indefinitely
    .batch(batch_size, drop_remainder=True)
    .prefetch(-1))
valid_dataset = tfds.as_numpy(
    tf.data.Dataset.from_tensor_slices(valid_data_dict)
    .map(cast_and_normalise_images)
    .repeat(1)  # 1 epoch
    .batch(batch_size)
    .prefetch(-1))
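
For readers without TensorFlow, the training pipeline's behavior (shuffle, repeat forever, fixed-size batches with the remainder dropped) can be approximated in NumPy. This is a hypothetical generator, not a drop-in replacement: it reshuffles once per epoch instead of using a 10,000-element shuffle buffer, and it does no prefetching.

```python
import numpy as np


def numpy_batches(images, batch_size, seed=0):
  """Yield {'image': float32 batch in [0, 1]} forever, reshuffling every epoch."""
  rng = np.random.default_rng(seed)
  num_batches = len(images) // batch_size  # Drop the incomplete last batch.
  if num_batches == 0:
    raise ValueError('batch_size is larger than the number of images')
  while True:
    order = rng.permutation(len(images))
    for i in range(num_batches):
      idx = order[i * batch_size:(i + 1) * batch_size]
      yield {'image': images[idx].astype(np.float32) / 255.0}


demo_images = np.random.default_rng(1).integers(
    0, 256, size=(10, 4, 4, 1), dtype=np.uint8)
it = numpy_batches(demo_images, batch_size=4)
first = next(it)
print(first['image'].shape, first['image'].dtype)  # (4, 4, 4, 1) float32
```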

# Build the model

In [None]:
devices = jax.devices()
# Use the first available device (a GPU or TPU if present, otherwise the CPU).
some_device = devices[0]

In [None]:
# `Encoder` and `Decoder` are assumed to be Haiku modules defined elsewhere in the notebook.
encoder_forward = nj.HaikuModule(Encoder, num_hiddens, num_residual_layers, num_residual_hiddens, name="encoder")
decoder_forward = nj.HaikuModule(Decoder, num_hiddens, num_residual_layers, num_residual_hiddens, name="decoder")


In [None]:
# Debugging helper, kept for reference (disabled).
#
# When a HaikuModule instance is called with input data (*args, **kwargs), it
# first retrieves or initializes the model's state with
# self.get('state', self.transformed.init, rng(), *args, **kwargs):
#   1. If the parameters do not exist yet in the module's state, it calls
#      self.transformed.init to create them. This is where `init` is invoked
#      indirectly for parameter initialization.
#   2. The new parameters are stored in the module's state, so subsequent
#      calls retrieve them instead of initializing again.
#
# def f(obj, *args, **kwargs):
#     print(f"args len: {len(args)}")
#     print(f"kwargs: {kwargs}")
#     print(type(obj.get))
#     rng_key = jax.random.PRNGKey(42)
#     print(rng_key)
#     state = obj.get('state', obj.forward.init, rng_key, *args, **kwargs)
#     print(f"state: {state}")
#     return state

In [None]:
from ninjax import pure

train_dataset_iter = iter(train_dataset)
batch = next(train_dataset_iter)

# Turn the encoder call into a pure function: (state, rng, *args) -> (out, state).
initial_state = {}
rng_key = jax.random.PRNGKey(42)
f_pure = pure(encoder_forward)
# The dataset yields dicts, so pass the image array rather than the whole batch.
encoder_out, encoder_state = f_pure(initial_state, rng_key, batch['image'])

In [None]:
# Inspect the shapes of the parameters created on the first call.
print(jax.tree_util.tree_map(jnp.shape, encoder_state))

In [None]:
def forward(data, is_training):
  # Inside hk.transform_with_state, build plain Haiku modules. nj.HaikuModule
  # needs a Ninjax context (nj.pure), so it cannot be used here.
  encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
  decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)
  pre_vq_conv1 = hk.Conv2D(
      output_channels=embedding_dim,
      kernel_shape=(1, 1),
      stride=(1, 1),
      name="to_vq")
  if vq_use_ema:
    vq_vae = hk.nets.VectorQuantizerEMA(
        embedding_dim=embedding_dim,
        num_embeddings=num_embeddings,
        commitment_cost=commitment_cost,
        decay=decay)
  else:
    # The non-EMA quantizer has no `decay` argument.
    vq_vae = hk.nets.VectorQuantizer(
        embedding_dim=embedding_dim,
        num_embeddings=num_embeddings,
        commitment_cost=commitment_cost)

  model = ImageEncoderVQVAE(
      encoder,
      vq_vae,
      pre_vq_conv1,
      data_variance=train_data_variance)

  return model(data['image'], is_training)

forward = hk.transform_with_state(forward)

In [None]:
rng = jax.random.PRNGKey(42)
train_dataset_iter = iter(train_dataset)
params, state = forward.init(rng, next(train_dataset_iter), is_training=True)

In [None]:
# Haiku modules can only be constructed inside a transformed function such as
# `forward` above; building them at the top level raises an error. This helper
# groups the same construction so it can be reused inside such a function.
def build_modules():
    encoder = Encoder(num_hiddens, num_residual_layers, num_residual_hiddens)
    decoder = Decoder(num_hiddens, num_residual_layers, num_residual_hiddens)
    pre_vq_conv1 = hk.Conv2D(
        output_channels=embedding_dim,
        kernel_shape=(1, 1),
        stride=(1, 1),
        name="to_vq")
    if vq_use_ema:
        vq_vae = hk.nets.VectorQuantizerEMA(
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
            commitment_cost=commitment_cost,
            decay=decay)
    else:
        vq_vae = hk.nets.VectorQuantizer(
            embedding_dim=embedding_dim,
            num_embeddings=num_embeddings,
            commitment_cost=commitment_cost)
    return encoder, decoder, pre_vq_conv1, vq_vae

# Optax optimizer used by `train_step` (`learning_rate` is assumed to be
# defined elsewhere in the notebook).
optimizer = optax.adam(learning_rate)

In [None]:
@jax.jit
def train_step(params, state, opt_state, data):
    def adapt_forward(params, state, data):
        # Run the model and return the loss plus auxiliary outputs and new state.
        # `None` is passed as the rng, assuming the model uses no randomness.
        model_output, state = forward.apply(params, state, None, data, is_training=True)
        return model_output['loss'], (model_output, state)

    # Gradients are taken w.r.t. params only; model_output and the new Haiku
    # state are returned as auxiliary values.
    grads, (model_output, state) = jax.grad(adapt_forward, has_aux=True)(
        params, state, data)
    updates, opt_state = optimizer.update(grads, opt_state)
    params = optax.apply_updates(params, updates)
    # `model_output` is assumed to hold 'loss', 'recon_error', and 'vq_output'.
    return params, state, opt_state, model_output


In [None]:
%%time

train_losses = []
train_recon_errors = []
train_perplexities = []
train_vqvae_loss = []

rng = jax.random.PRNGKey(42)
train_dataset_iter = iter(train_dataset)
params, state = forward.init(rng, next(train_dataset_iter), is_training=True)
opt_state = optimizer.init(params)

for step in range(1, num_training_updates + 1):
  data = next(train_dataset_iter)
  params, state, opt_state, train_results = (
      train_step(params, state, opt_state, data))

  train_results = jax.device_get(train_results)
  train_losses.append(train_results['loss'])
  train_recon_errors.append(train_results['recon_error'])
  train_perplexities.append(train_results['vq_output']['perplexity'])
  train_vqvae_loss.append(train_results['vq_output']['loss'])

  if step % 100 == 0:
    print(f'[Step {step}/{num_training_updates}] ' + 
          ('train loss: %f ' % np.mean(train_losses[-100:])) +
          ('recon_error: %.9f ' % np.mean(train_recon_errors[-100:])) +
          ('perplexity: %.9f ' % np.mean(train_perplexities[-100:])) +
          ('vqvae loss: %.9f' % np.mean(train_vqvae_loss[-100:])))

In [None]:
import pickle

checkpoint_path = '/home/lyk/Projects/Machine-Learning-Basic/VAE/checkpoints/VQ-VAE-DeepMind—BouncingBall.pkl'

with open(checkpoint_path, 'rb') as file:
    loaded_model_dict = pickle.load(file)

# Extract the params and state
loaded_params = loaded_model_dict['params']
loaded_state = loaded_model_dict['state']

params = loaded_params
state = loaded_state
print("Model loaded successfully.")
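
The notebook loads a checkpoint but does not show how it was written. Assuming the same `{'params': ..., 'state': ...}` layout, a matching save step could look like this sketch; `save_checkpoint` and `load_checkpoint` are hypothetical helpers, and it runs here on plain Python values:

```python
import os
import pickle
import tempfile


def save_checkpoint(path, params, state):
  # For JAX pytrees, calling jax.device_get(...) first copies arrays to host memory.
  with open(path, 'wb') as file:
    pickle.dump({'params': params, 'state': state}, file)


def load_checkpoint(path):
  with open(path, 'rb') as file:
    ckpt = pickle.load(file)
  return ckpt['params'], ckpt['state']


with tempfile.TemporaryDirectory() as tmp:
  demo_path = os.path.join(tmp, 'demo.pkl')
  save_checkpoint(demo_path, {'w': [1.0, 2.0]}, {'step': 3})
  demo_params, demo_state = load_checkpoint(demo_path)
print(demo_params, demo_state)  # {'w': [1.0, 2.0]} {'step': 3}
```

Since unpickling can execute arbitrary code, only load checkpoint files from trusted sources.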

# View recon

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Reconstructions
train_batch = next(iter(train_dataset))
valid_batch = next(iter(valid_dataset))

# Put data through the model with is_training=False, so that in the case of 
# using EMA the codebook is not updated.
train_reconstructions = forward.apply(params, state, rng, train_batch, is_training=False)[0]['x_recon']
valid_reconstructions = forward.apply(params, state, rng, valid_batch, is_training=False)[0]['x_recon']



def convert_batch_to_image_grid(image_batch, rows=4, cols=8):
    # Assuming image_batch is of shape (B, H, W, C)
    B, H, W, C = image_batch.shape
    assert B >= rows * cols, "Not enough images to fill the grid"

    reshaped = image_batch[:rows * cols].reshape(rows, cols, H, W, C)
    reshaped = reshaped.transpose(0, 2, 1, 3, 4)  # Transpose to (rows, H, cols, W, C)
    grid = reshaped.reshape(rows * H, cols * W, C)
    # Reconstructions can fall slightly outside [0, 1]; clip for display.
    grid = np.clip(grid, 0.0, 1.0)
    # imshow does not accept a trailing channel axis of size 1.
    if C == 1:
        grid = grid[..., 0]
    return grid

# Assuming 'train_batch', 'train_reconstructions', 'valid_batch', and 'valid_reconstructions' are available
f = plt.figure(figsize=(16, 16))

# Training Data Originals
ax = f.add_subplot(2, 2, 1)
ax.imshow(convert_batch_to_image_grid(train_batch['image']), interpolation='nearest')
ax.set_title('Training Data Originals')
plt.axis('off')

# Training Data Reconstructions
ax = f.add_subplot(2, 2, 2)
ax.imshow(convert_batch_to_image_grid(train_reconstructions), interpolation='nearest')
ax.set_title('Training Data Reconstructions')
plt.axis('off')

# Validation Data Originals
ax = f.add_subplot(2, 2, 3)
ax.imshow(convert_batch_to_image_grid(valid_batch['image']), interpolation='nearest')
ax.set_title('Validation Data Originals')
plt.axis('off')

# Validation Data Reconstructions
ax = f.add_subplot(2, 2, 4)
ax.imshow(convert_batch_to_image_grid(valid_reconstructions), interpolation='nearest')
ax.set_title('Validation Data Reconstructions')
plt.axis('off')

plt.show()
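
The reshape/transpose trick in `convert_batch_to_image_grid` can be checked on a tiny batch in which every image is filled with its own index: tile `(r, c)` of the grid should then contain image `r * cols + c`. A standalone NumPy sketch, with the core logic restated as a hypothetical `to_grid` so the snippet runs on its own:

```python
import numpy as np


def to_grid(image_batch, rows, cols):
  """Arrange (B, H, W, C) images into a (rows*H, cols*W, C) grid, row-major."""
  _, h, w, c = image_batch.shape
  grid = image_batch[:rows * cols].reshape(rows, cols, h, w, c)
  grid = grid.transpose(0, 2, 1, 3, 4)  # (rows, H, cols, W, C)
  return grid.reshape(rows * h, cols * w, c)


# Six 2x3 single-channel images; image i is filled with the value i.
demo_batch = np.arange(6).reshape(6, 1, 1, 1) * np.ones((6, 2, 3, 1), dtype=int)
demo_grid = to_grid(demo_batch, rows=2, cols=3)
print(demo_grid[..., 0])
# [[0 0 0 1 1 1 2 2 2]
#  [0 0 0 1 1 1 2 2 2]
#  [3 3 3 4 4 4 5 5 5]
#  [3 3 3 4 4 4 5 5 5]]
```

Without the transpose, a direct reshape would interleave rows from different images instead of placing each image in its own tile.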
