In [6]:
import jax
import jax.numpy as jnp
import numpy as np
from jaxmao.layers import Dense, BatchNorm
from jaxmao.modules import Module
from jaxmao.optimizers import GradientDescent

In [4]:
# Root PRNG key for the notebook, built from a fixed integer seed.
key = jax.random.PRNGKey(seed=42)

In [114]:
class ClassifierWithSimplifiedState(Module):
    """Module that chains three registered layers: ``fc1`` -> ``bn1`` -> ``fc2``.

    ``fc1`` is ``Dense(128, 64)``, ``bn1`` is ``BatchNorm(64)`` and ``fc2`` is
    ``Dense(64, 10)`` with a ``tanh`` activation.
    """

    def __init__(self):
        super().__init__()
        # (name, layer) pairs, listed in the order they are registered.
        layer_table = (
            ("fc1", Dense(128, 64)),
            ("bn1", BatchNorm(64)),
            ("fc2", Dense(64, 10, activation=jnp.tanh)),
        )
        for layer_name, layer in layer_table:
            self.add(layer_name, layer)

    def forward(self, params, x, state):
        """Run ``x`` through ``fc1``, ``bn1`` and ``fc2`` in that order.

        Each step goes through ``self.forward_with_state``, which returns the
        layer output together with the state to hand to the next step.

        Returns:
            Tuple ``(output, state)`` from the final layer call.
        """
        for layer_name in ("fc1", "bn1", "fc2"):
            x, state = self.forward_with_state(params, x, layer_name, state)
        return x, state

# Build the model, JIT-compile its forward method, then initialise parameters.
model = ClassifierWithSimplifiedState()
model.forward = jax.jit(model.forward)
model.init_params(key)

# Synthetic inputs: ones plus uniform noise in [0, 1). The noise comes from a
# seeded generator so that Restart & Run All reproduces the same data; the
# global np.random state used previously was never seeded.
x = jnp.ones((40, 128)) + np.random.default_rng(42).uniform(0, 1, (40, 128))
y_true = jnp.zeros((40, 10))

y = model(model.params, x)  # State is managed automatically
y.shape

(40, 10)

In [115]:
from sklearn.utils import shuffle

def loss_fn(model, params, x, y_true, state):
    """Mean squared error between the model output and ``y_true``.

    Args:
        model: Object exposing ``pure_forward(params, x, state)`` that returns
            ``(prediction, new_state)``.
        params: Parameters forwarded to ``model.pure_forward``.
        x: Input batch.
        y_true: Targets compared against the prediction.
        state: State forwarded to ``model.pure_forward``.

    Returns:
        Tuple ``(loss, new_state)``; the state is returned next to the scalar
        loss so the caller can treat it as auxiliary output.
    """
    prediction, updated_state = model.pure_forward(params, x, state)
    residual = prediction - y_true
    mse = jnp.mean(residual ** 2)
    return mse, updated_state

def training_loop(model, optimizer, x, y, epochs, lr=0.01, batch_size=32):
    """Train ``model`` with mini-batch updates, printing the mean loss per epoch.

    Args:
        model: Model exposing ``params``, ``state`` and ``update_state``; it is
            also passed to ``loss_fn``.
        optimizer: Object whose ``step(params, gradients, lr=...)`` returns a
            ``(new_params, optimizer_state)`` pair.
        x: Training inputs, indexed along the first axis.
        y: Training targets aligned with ``x``.
        epochs: Number of passes over the data.
        lr: Learning rate forwarded to ``optimizer.step``.
        batch_size: Samples per update; a trailing partial batch is dropped.

    Side effects:
        Rebinds ``model.params`` and calls ``model.update_state`` after every
        batch, and prints the mean batch loss once per epoch.
    """
    # Differentiate w.r.t. ``params`` (argument 1); ``has_aux=True`` lets
    # ``loss_fn`` return the new state next to the loss. This wrapper does not
    # change between batches, so it is built once, outside the loops.
    loss_grad_fn = jax.value_and_grad(loss_fn, argnums=1, has_aux=True)

    num_batch = len(x) // batch_size
    for _epoch in range(epochs):
        losses = 0
        x, y = shuffle(x, y)
        for n in range(num_batch):
            batch = slice(n * batch_size, (n + 1) * batch_size)
            # Slice the shuffled ``y`` argument, not a notebook-level global,
            # so inputs and targets stay paired.
            (loss, new_state), gradients = loss_grad_fn(
                model, model.params, x[batch], y[batch], model.state
            )
            losses = losses + loss
            # Update model parameters; the optimizer's own state is not used here.
            model.params, _optim_state = optimizer.step(model.params, gradients, lr=lr)
            model.update_state(new_state)

        print(losses / num_batch)

In [116]:
# Optimizer and hyperparameters for this run.
optimizer = GradientDescent()
epochs, lr = 5, 0.1

training_loop(model, optimizer, x, y_true, epochs=epochs, lr=lr, batch_size=4)

0.5352564
0.4989636
0.47177288
0.42541075
0.38362566


In [None]:
# Call the model on the full input batch `x` with its current parameters;
# as the last expression in the cell, the result is shown as the cell output.
model(model.params, x)