In [1]:
# annotations for code readability
import typing as tp
import chex

# data read and split
import polars as pl
from sklearn.model_selection import train_test_split

# nn
import jax
from jax import numpy as jnp
from flax import nnx
from flax.typing import Array
import optax

from utils import (
    MLP,
    MAE,
    _init_model,
    batch
)

$$MAE = \frac{1}{n}\sum_{i=0}^{n}|y_i - \hat{y}_i|$$

$$MSE = \frac{1}{n}\sum_{i=0}^{n}(y_i - \hat{y}_i)^2$$

In [2]:
df = pl.read_csv("./data/BostonHousing.csv")

X = df.drop(pl.col("medv")).to_jax()
y = df.select(pl.col("medv")).to_jax().squeeze()

X_train, X_test, y_train, y_test = train_test_split(X, 
                                                    y,
                                                    test_size=0.195,
                                                    random_state=42,
                                                    shuffle=True)

X_train -= X_train.mean(axis=0)
X_train /= X_train.std(axis=0)

X_test -= X_test.mean(axis=0)
X_test /= X_test.std(axis=0)


In [3]:
X_train.shape[0]

407

In [4]:
model = MLP(input_dim=13, hidden=(5,), output_dim=1, rngs=nnx.Rngs(42))
model 

MLP(
  nn=Sequential(
    layers=[Linear(
      kernel=Param(
        value=Array(shape=(13, 5), dtype=float32)
      ),
      bias=Param(
        value=Array(shape=(5,), dtype=float32)
      ),
      in_features=13,
      out_features=5,
      use_bias=True,
      dtype=None,
      param_dtype=<class 'jax.numpy.float32'>,
      precision=None,
      kernel_init=<function variance_scaling.<locals>.init at 0x0000028FF2C9ADC0>,
      bias_init=<function zeros at 0x0000028FE21D3EE0>,
      dot_general=<function dot_general at 0x0000028FE1C37550>
    ), <PjitFunction of <function silu at 0x0000028FE21E95E0>>, Linear(
      kernel=Param(
        value=Array(shape=(5, 1), dtype=float32)
      ),
      bias=Param(
        value=Array([0.], dtype=float32)
      ),
      in_features=5,
      out_features=1,
      use_bias=True,
      dtype=None,
      param_dtype=<class 'jax.numpy.float32'>,
      precision=None,
      kernel_init=<function variance_scaling.<locals>.init at 0x0000028FF2C9ADC0>,

In [5]:
y = model(jnp.ones(13))
y

Array([-1.1617936], dtype=float32)

In [6]:
def loss_fn(model: MLP, 
            features: Array, 
            targets: Array) -> tuple[Array]:
    """Computes L2-Loss with optax. Returns loss and logits"""
    logits = model(features)
    loss = jnp.mean(
        optax.l2_loss(
            predictions=logits.squeeze(), targets=targets
        )
    )
    return loss, logits


@jax.jit
def train_step(graphdef: nnx.GraphDef, 
               state: nnx.GraphState,   
               features: Array, 
               targets: Array) -> nnx.GraphState:
    """Train for a single step."""
    model, optimizer, metrics = nnx.merge(graphdef, state)
    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)
    (loss, logits), grads = grad_fn(model, features, targets)
    metrics.update(loss=loss, logits=logits, targets=targets)
    optimizer.update(grads)
    _, state = nnx.split((model, optimizer, metrics))
    return state


@jax.jit
def eval_step(graphdef: nnx.GraphDef, 
              state: nnx.GraphState,   
              features: Array,
              targets: Array) -> nnx.GraphState:
    """Eval for single step"""
    model, optimizer, metrics = nnx.merge(graphdef, state)
    loss, logits = loss_fn(model, features, targets)
    metrics.update(loss=loss, logits=logits, targets=targets)
    _, state = nnx.split((model, optimizer, metrics))
    return state
    

In [7]:
def _init_metrics_hisory() -> dict[str, Array]:
    """Creates metrics_history
    Example:
    >>> metrics = _init_metrics_hisory()
    >>> metrics
    {'train_loss': [], 'train_mae': [], 'test_loss': [], 'test_mae': []}
    """
    metrics_history = {
        'train_loss': [],
        'train_mae': [],
        'test_loss': [],
        'test_mae': [],
    }
    return metrics_history


def run(X_train: Array,
        y_train: Array,
        X_test: Array,
        y_test: Array,
        batch_size: int = 8,
        key: tp.Optional[chex.PRNGKey] = None,
        input_dim: int = 60,
        hidden: tp.Optional[tuple[int]] = None,
        output_dim: int = 1,
        k: int = 3,
        num_epochs: int = 10) -> tuple[MLP, dict[str, Array], nnx.State, nnx.State]:
    """Inits model and it's components and run train and test"""
    if key is None:
        key = jax.random.key(42)

    batch_key, model_key = jax.random.split(key,)
    model_key = nnx.Rngs(model_key)
    
    # get model and default weights of the first layer
    num_val_samples = len(X_train) // k
    metrics_history = _init_metrics_hisory()
    metrics = nnx.MultiMetric(
        mae=MAE(),
        loss=nnx.metrics.Average('loss'),
    )
    test_metrics = {
        'test_loss': [],
        'test_mae': [],
    }
    for k_fold in range(k):
        model, _ = _init_model(input_dim=input_dim, hidden=hidden, output_dim=output_dim, rngs=model_key)
        optimizer = nnx.Optimizer(model, optax.rmsprop(learning_rate=3e-3))
        print("\nK-fold:", k_fold)
        X_val = X_train[k_fold * num_val_samples: (k_fold + 1) * num_val_samples]
        y_val = y_train[k_fold * num_val_samples: (k_fold + 1) * num_val_samples]

        partial_train_data = jnp.concatenate(
            [X_train[:k_fold * num_val_samples], 
             X_train[(k_fold + 1) * num_val_samples:]], axis=0)
        partial_train_targets = jnp.concatenate(
            [y_train[:k_fold * num_val_samples], 
             y_train[(k_fold + 1) * num_val_samples:]], axis=0)
        for _ in range(num_epochs):
            # train
            model.train()
            graphdef, state = nnx.split((model, optimizer, metrics))

            # new order for shuffle
            batch_key = jax.random.split(batch_key)[0]
            for X_batched, y_batched in batch(X=partial_train_data, 
                                              y=partial_train_targets,
                                              batch_size=batch_size,
                                              key=batch_key,
                                              train=True):
                state = train_step(graphdef=graphdef, 
                                   state=state,
                                   features=X_batched, 
                                   targets=y_batched)

            # if flax >= 0.10.0
            # nnx.update((model, optimizer, metrics), state)
            
            # else
            model, optimizer, metrics = nnx.merge(graphdef, state)
            # store train metrics
            for metric, value in metrics.compute().items():     
                metrics_history[f'train_{metric}'].append(value)
            metrics.reset()

            # eval
            model.eval()
            graphdef, state = nnx.split((model, optimizer, metrics))
            for X_batched, y_batched in batch(X=X_val, 
                                              y=y_val,
                                              batch_size=batch_size,
                                              key=batch_key,
                                              train=False):
                state = eval_step(graphdef=graphdef,
                                  state=state, 
                                  features=X_batched, 
                                  targets=y_batched)
            
            # if flax >= 0.10.0
            # nnx.update((model, optimizer, metrics), state)
            
            # else
            model, optimizer, metrics = nnx.merge(graphdef, state)
            # store eval metrics
            for metric, value in metrics.compute().items():    
                metrics_history[f'test_{metric}'].append(value)
            metrics.reset() 

        model.eval()
        graphdef, state = nnx.split((model, optimizer, metrics))
        for X_batched, y_batched in batch(X=X_val, 
                                          y=y_val,
                                          batch_size=batch_size,
                                          key=batch_key,
                                          train=False):
            state = eval_step(graphdef=graphdef,
                              state=state,
                              features=X_batched, 
                              targets=y_batched)
        
        # store eval metrics
        print("Test")

        # if flax >= 0.10.0
        # nnx.update((model, optimizer, metrics), state)
        
        # else
        model, _, metrics = nnx.merge(graphdef, state)
        for metric, value in metrics.compute().items():    
            test_metrics[f'test_{metric}'].append(value)
            print(f"{metric}: {value}")
        metrics.reset() 

    
    return model, metrics_history, test_metrics


In [8]:
_, metrics_history, test_metrics = run(X_train=X_train, 
                                       y_train=y_train, 
                                       X_test=X_test, 
                                       y_test=y_test,
                                       batch_size=9,
                                       key=jax.random.key(42),
                                       input_dim=X_train.shape[1], 
                                       hidden=(64,), 
                                       output_dim=1, 
                                       num_epochs=100, 
                                       k=4)


K-fold: 0




Test
mae: 8.474502563476562
loss: 5.22266149520874

K-fold: 1




Test
mae: 9.651590347290039
loss: 9.30495548248291

K-fold: 2




Test
mae: 10.561206817626953
loss: 4.997997760772705

K-fold: 3




Test
mae: 8.21519947052002
loss: 5.934994220733643


In [9]:
test_metrics

{'test_loss': [Array(5.2226615, dtype=float32),
  Array(9.3049555, dtype=float32),
  Array(4.9979978, dtype=float32),
  Array(5.934994, dtype=float32)],
 'test_mae': [Array(8.474503, dtype=float32),
  Array(9.65159, dtype=float32),
  Array(10.561207, dtype=float32),
  Array(8.215199, dtype=float32)]}