In [None]:
#| default_exp trainer

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)


In [None]:
#| export
from __future__ import annotations
import jax, jax.numpy as jnp, jax.random as jrand
import haiku as hk
import optax
import chex
from dataclasses import dataclass
import functools as ft
from typing import Callable, Tuple, Any, Sequence, Iterable, Mapping, Dict, List, NamedTuple
import copy
from haiku_trainer.callbacks import *

## Trainer

In [None]:
#| export
class TrainState(NamedTuple):
    """Snapshot of the training loop, passed to and returned by every step function.

    Being a ``NamedTuple`` it is a JAX pytree, so it can flow through ``jax.jit``-ed steps.
    Equality only looks at the ``epoch``/``step`` counters (see ``__eq__``).
    """
    epoch: int                  # epoch counter; advanced by `StepFn.epoch_step`
    step: int                   # step counter; advanced by the step functions
    params: hk.Params           # model parameters
    state: hk.State             # non-trainable haiku state (e.g. batch-norm statistics)
    opt_state: optax.OptState   # optimizer state
    next_key: jrand.PRNGKey     # PRNG key to be split by the next step
    logs: dict = None           # metrics produced by the most recent step, if any

    def __eq__(self, compare: object) -> bool:
        """Return True when ``epoch`` and ``step`` match; params/state are not compared.

        `Trainer._run_step_fn(..., validate=True)` uses this to detect a step that did
        not advance the counters. Comparing with a non-`TrainState` returns
        ``NotImplemented`` (so ``state == None`` is simply False) instead of raising.
        """
        if not isinstance(compare, TrainState):
            return NotImplemented
        return bool((self.epoch == compare.epoch) and (self.step == compare.step))

    def __ne__(self, compare: object) -> bool:
        """Negation of ``__eq__``.

        Without this override the inherited ``tuple.__ne__`` would compare *all*
        fields (including parameter arrays), contradicting ``__eq__``.
        """
        result = self.__eq__(compare)
        if result is NotImplemented:
            return result
        return not result

In [None]:
#| export
@dataclass
class Trainer:
    """Training loop for haiku models.

    `fit` owns the epoch/batch loop and the current `TrainState`. The numerical work is
    delegated to `step_fn` (`init_step`, `train_step`, `val_step`, `epoch_step`), and
    lifecycle hooks (``on_train_begin``, ``on_epoch_end``, ...) are looked up on
    `callbacks` and called with the current train state.
    """
    transformed: hk.TransformedWithState | hk.MultiTransformedWithState
    optimizers: optax.GradientTransformation | Sequence[optax.GradientTransformation]
    rng_key: jrand.PRNGKey = None

    # callback functions
    callbacks: Sequence[Callback] = None   # normalized to a `CallbackList` in `_initialize_callbacks`
    step_fn: StepFn = None                 # defaults to `DefaultStepFn` when omitted

    # trainer configs
    lr: float = 1e-3      # NOTE(review): not read by any code visible in this file
    n_epochs: int = 1

    # model train state
    _train_state: TrainState = None

    @property
    def train_state(self):
        """Current `TrainState`; ``None`` until `init_step` has run on the first batch."""
        return self._train_state

    def _initialize_key(self):
        """Return `rng_key`, or ``PRNGKey(42)`` when unset.

        NOTE(review): not called in this file; `StepFn._init_key` falls back to
        ``PRNGKey(0)`` instead, so the two defaults disagree.
        """
        if self.rng_key is None:    return jrand.PRNGKey(42) # TODO: use global
        else:                       return self.rng_key

    def _initialize_callbacks(self):
        """Normalize `callbacks` into a `CallbackList` and bind it to this trainer.

        Raises:
            ValueError: if `callbacks` is neither a `CallbackList` nor a sequence.
        """
        if self.callbacks is None:
            self.callbacks = CallbackList()
        elif isinstance(self.callbacks, CallbackList):
            pass  # already normalized
        elif isinstance(self.callbacks, Sequence):
            self.callbacks = CallbackList(self.callbacks)
        else:
            raise ValueError("Invalid callbacks. Expected `CallbackList` or `Sequence[Callback]`.")

        self.callbacks.init_trainer(self)

    def _initialize_step_fn(self):
        """Create a `DefaultStepFn` when none is given, otherwise bind the given `StepFn`.

        Raises:
            ValueError: if `step_fn` is not a `StepFn` instance.
        """
        if self.step_fn is None:
            self.step_fn = DefaultStepFn(trainer=self)
        elif isinstance(self.step_fn, StepFn):
            self.step_fn.init_trainer(self)
        else:
            raise ValueError(f"Invalid `Trainer.step_fn`. Expected `StepFn`, but got `{type(self.step_fn)}`.")

    def _initialize(self):
        """Prepare callbacks and step function before a `fit` run."""
        self._initialize_callbacks()
        self._initialize_step_fn()

    def _run_callbacks(self, hook_name: str):
        """Call ``callbacks.<hook_name>(train_state)`` if that hook exists."""
        hook_fn = getattr(self.callbacks, hook_name, None)
        if hook_fn is not None:
            hook_fn(self.train_state)

    def _run_step_fn(self, step_name: str, batch: Tuple[jax.Array, ...], validate: bool = False):
        """Run ``step_fn.<step_name>(train_state, batch)`` and store the returned state.

        Raises:
            ValueError: if `validate` is set and the counters did not advance.
        """
        step_fn = getattr(self.step_fn, step_name)
        train_state = step_fn(self.train_state, batch)

        # Going from "no state" to a state is always an update; only compare real states
        # (comparing a `TrainState` with ``None`` is not meaningful).
        if validate and self.train_state is not None and train_state == self.train_state:
            raise ValueError(f"Train state is not updated after `{step_name}`.")
        self.update_train_state(train_state)

    def update_train_state(self, train_state: TrainState = None, **kwargs):
        """Replace the current state, or update selected fields of it.

        Pass either a whole `train_state`, or field overrides as keyword arguments
        (applied with ``_replace``). When `train_state` is given, `kwargs` are ignored.

        Raises:
            ValueError: if neither is provided.
        """
        if train_state is None and not kwargs:
            raise ValueError("Either `train_state` or `kwargs` must be provided.")
        if train_state is None:
            train_state = self.train_state._replace(**kwargs)
        self._train_state = train_state

    def fit(self, train_dataloader, val_dataloader=None):
        """Train for `n_epochs` epochs, optionally validating after each epoch.

        `on_train_begin` and `on_train_end` fire exactly once per call; the epoch,
        batch and validation hooks fire once per corresponding event.
        """
        self._initialize()
        self._run_callbacks("on_train_begin")
        for _ in range(self.n_epochs):
            self._run_callbacks("on_epoch_begin")
            for batch in train_dataloader:
                self._run_callbacks("on_train_batch_begin")
                # Params/optimizer state are built lazily from the first batch,
                # since model initialization needs example inputs.
                if self.train_state is None:
                    self._run_step_fn("init_step", batch)
                self._run_step_fn("train_step", batch)
                self._run_callbacks("on_train_batch_end")
            self._run_callbacks("on_epoch_end")

            if val_dataloader is not None:
                self._run_callbacks("on_val_begin")
                for batch in val_dataloader:
                    self._run_callbacks("on_val_batch_begin")
                    self._run_step_fn("val_step", batch)
                    self._run_callbacks("on_val_batch_end")
                self._run_callbacks("on_val_end")

            # `on_train_end` used to be fired here too (once per epoch, plus once
            # after the loop); it now fires only once, after the final epoch.
            self._run_step_fn("epoch_step", batch=None)

        self._run_callbacks("on_train_end")


## Step Functions

In [None]:
#| export
class StepFn:
    """Base class bundling the per-phase step functions driven by `Trainer.fit`.

    Every step takes the current `TrainState` and a batch and returns a new
    `TrainState`. Subclasses must implement `train_step` and `val_step`.
    """
    def __init__(self, trainer: Trainer=None, *args, **kwargs) -> None:
        # Extra positional/keyword arguments are accepted and ignored here.
        if trainer is not None:
            self.init_trainer(trainer)

    def init_trainer(self, trainer: Trainer):
        """Bind this step function to `trainer` (source of model, optimizers and key)."""
        self._trainer = trainer

    @property
    def trainer(self): return self._trainer

    @property
    def transformed(self): return self.trainer.transformed

    # Alias: `self.forward` is the same property as `self.transformed`.
    forward = transformed

    @property
    def optimizers(self): return self.trainer.optimizers

    def init_step(self, train_state: TrainState, batch: Tuple[jax.Array, ...]) -> TrainState:
        """Build the initial `TrainState` from the first batch.

        `train_state` is ignored (`Trainer.fit` calls this while it is still ``None``).
        Params and state are initialized from ``batch[0]``; counters start at 0.
        """
        init_key, next_key = jrand.split(self._init_key())

        params, state = self._init_params_and_state(init_key, batch[0])
        opt_state = self._init_opt_state(params)
        return TrainState(
            epoch=0, step=0, params=params, state=state,
            opt_state=opt_state, next_key=next_key,
        )

    def epoch_step(self, train_state: TrainState, batch=None) -> TrainState:
        """Advance the epoch counter by one; `batch` is unused."""
        return train_state._replace(epoch=train_state.epoch+1)

    def train_step(self, train_state: TrainState, batch: Tuple[jax.Array, ...]) -> TrainState:
        raise NotImplementedError

    def val_step(self, train_state: TrainState, batch: Tuple[jax.Array, ...]) -> TrainState:
        raise NotImplementedError

    def _init_key(self):
        """Return the trainer's `rng_key`, or ``PRNGKey(0)`` when none is configured.

        Raises:
            ValueError: if `rng_key` is set but is not a JAX array.
        """
        rng_key = self.trainer.rng_key
        if rng_key is None:
            return jrand.PRNGKey(0)
        # `jrand.PRNGKey` is a function, not a type, so it cannot be used with
        # `isinstance` (doing so raised TypeError for every user-supplied key).
        # Keys created by `jax.random.PRNGKey` / `jax.random.key` are `jax.Array`s.
        if isinstance(rng_key, jax.Array):
            return rng_key
        raise ValueError("Invalid rng_key. Expected `jax.random.PRNGKey`.")

    def _init_params_and_state(self, key: jrand.PRNGKey, xs: jax.Array):
        """Initialize model params and state with ``transformed.init(key, xs)``."""
        params, state = self.transformed.init(key, xs)
        return params, state

    def _init_opt_state(self, params: hk.Params):
        """Initialize optimizer state; only a single optax `GradientTransformation` is supported.

        Raises:
            ValueError: for anything else (including a sequence of optimizers).
        """
        if isinstance(self.optimizers, optax.GradientTransformation):
            return self.optimizers.init(params)
        # NOTE(review): `Trainer.optimizers` is typed to also allow a sequence of
        # optimizers, but that case is not handled here.
        raise ValueError("Invalid optimizers. Expected `optax` optimizers.")


In [None]:
#| export
class DefaultStepFn(StepFn):
    """Step functions for supervised classification with integer labels.

    Both steps unpack ``batch`` as ``(inputs, labels)``, run
    ``transformed.apply(params, state, rng, inputs, is_training=...)`` and score the
    logits with mean softmax cross-entropy.
    """

    @ft.partial(jax.jit, static_argnums=(0,))
    def train_step(self, train_state: TrainState, batch: Tuple[jax.Array, ...]) -> TrainState:
        """Apply one optimizer update and log ``'train/loss'``."""
        inputs, labels = batch
        rng_key, next_key = jrand.split(train_state.next_key)

        def loss_fn(params: hk.Params, state: hk.State):
            # The updated haiku state is returned as auxiliary output: it is
            # carried forward but not differentiated.
            logits, new_state = self.transformed.apply(
                params, state,
                rng_key,                  # <== rng
                inputs, is_training=True  # <== inputs
            )
            loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
            return loss, new_state

        # Gradients are taken w.r.t. the first argument (params) only.
        (loss, new_state), grads = jax.value_and_grad(loss_fn, has_aux=True)(
            train_state.params, train_state.state)
        updates, new_opt_state = self.optimizers.update(
            grads, train_state.opt_state, train_state.params)
        new_params = optax.apply_updates(train_state.params, updates)
        return TrainState(
            epoch=train_state.epoch,
            step=train_state.step + 1,
            params=new_params,
            state=new_state,
            opt_state=new_opt_state,
            next_key=next_key,
            logs={'train/loss': loss},
        )

    # Jitted like `train_step`, so validation does not run the haiku model eagerly
    # on every batch. `self` is static and hashed by identity.
    @ft.partial(jax.jit, static_argnums=(0,))
    def val_step(self, train_state: TrainState, batch: Tuple[jax.Array, ...]) -> TrainState:
        """Evaluate one batch and log ``'val/loss'`` and ``'val/accuracy'``.

        Params, haiku state and optimizer state are left unchanged.
        NOTE(review): validation batches also advance `step`; confirm this is intended.
        """
        inputs, labels = batch
        rng_key, next_key = jrand.split(train_state.next_key)
        logits, _ = self.transformed.apply(
            train_state.params, train_state.state,
            rng_key,                   # <== rng
            inputs, is_training=False  # <== inputs
        )
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()
        acc = (jnp.argmax(logits, axis=-1) == labels).mean()
        logs = {'val/loss': loss, "val/accuracy": acc}

        return train_state._replace(
            step=train_state.step + 1,
            next_key=next_key, logs=logs
        )

    

## Test

### Fake Module

In [None]:
class LinearBatchNorm(hk.Module):
    """`hk.Linear` projection followed by `hk.BatchNorm` (learned scale/offset, decay 0.9)."""

    def __init__(self, output_size, name=None):
        super().__init__(name=name)
        self.output_size = output_size

    def __call__(self, x, training=False):
        # Construct the norm first, then the projection, matching the original
        # evaluation order of the nested one-liner.
        batch_norm = hk.BatchNorm(create_scale=True, create_offset=True, decay_rate=0.9)
        projected = hk.Linear(self.output_size)(x)
        return batch_norm(projected, is_training=training)

In [None]:
def make_hk_module(
    module: Callable[..., hk.Module], # haiku module class (or factory) to build inside the transform
    *args, # positional arguments forwarded to `module`
    **kwargs, # keyword arguments forwarded to `module`
) -> hk.TransformedWithState:
    """Wrap a haiku module class with `hk.transform_with_state`.

    The module is instantiated inside the transformed function and called as
    ``module(*args, **kwargs)(x, is_training)``. The result's ``init`` returns
    ``(params, state)`` and its ``apply`` takes the state, hence
    `hk.TransformedWithState` (not `hk.Transformed`).
    """
    def model_fn(x, is_training: bool = True):
        return module(*args, **kwargs)(x, is_training)

    return hk.transform_with_state(model_fn)

# `2` is forwarded to `LinearBatchNorm` as `output_size` (the number of logits);
# presumably it matches the class count of the synthetic data below — confirm if that changes.
module = make_hk_module(LinearBatchNorm, 2)

### Fake Data

In [None]:
from jax_dataloader import DataLoader, ArrayDataset
from sklearn.datasets import make_classification

In [None]:
# Synthetic classification data: 2000 samples x 10 features, seeded for reproducibility.
xs, ys = make_classification(n_samples=2000, n_features=10, random_state=0)
ds = ArrayDataset(xs, ys)
# 'jax' backend, batches of 128; no `shuffle` argument is passed, so the loader's default applies.
dl = DataLoader(ds, 'jax', batch_size=128)

### Training

In [None]:
# Adam at 1e-3; the empty list is normalized into a `CallbackList` when `fit` runs.
trainer = Trainer(transformed=module, optimizers=optax.adam(learning_rate=1e-3), callbacks=[])

In [None]:
# Smoke test: one training epoch (the default `n_epochs`), no validation loader.
trainer.fit(train_dataloader=dl)

  param = init(shape, dtype)
