In [None]:
#| default_exp callbacks.pbar

In [None]:
#| include: false
%load_ext autoreload
%autoreload 2
from ipynb_path import *
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

In [None]:
#| export
from haiku_trainer.core import *
import haiku as hk
import importlib
import importlib.util  # `import importlib` alone does not guarantee the `util` submodule is loaded
import sys  # used for `file=sys.stdout` in the progress bars

if importlib.util.find_spec("ipywidgets") is not None:
    from tqdm.auto import tqdm
else:
    from tqdm import tqdm

In [None]:
#| export
def _update_pbar_n(pbar: tqdm, n: int):
    """Move `pbar` to the absolute count `n` and redraw it right away.

    `tqdm.update` only accepts increments, so the counter attribute is written
    directly and `refresh()` is called to push the new value to the display.
    """
    pbar.n = n
    pbar.refresh()

In [None]:
#| export
class ProgbarLogger(Callback):
    """Show `tqdm` progress bars while a trainer runs.

    One training bar is opened in `on_train_begin`, reset to
    `num_train_batches` at the start of each epoch, advanced after every
    training batch (with `state.logs` shown as the postfix), and closed in
    `on_train_end`. Each validation pass opens its own bar in `on_val_begin`
    (created with `leave=False`) and closes it in `on_val_end`.
    """
    def __init__(self):
        self._train_pbar = None
        self._valid_pbar = None
        # Count of training batches finished in the current epoch;
        # set to 0 by `on_epoch_begin`.
        self._batch_idx = None

    @property
    def train_pbar(self) -> tqdm: return self._train_pbar

    @property
    def valid_pbar(self) -> tqdm: return self._valid_pbar

    @train_pbar.setter
    def train_pbar(self, pbar: tqdm): self._train_pbar = pbar

    @valid_pbar.setter
    def valid_pbar(self, pbar: tqdm): self._valid_pbar = pbar

    @property
    def num_train_batches(self) -> int:
        "Number of training batches per epoch, read from `self.trainer`."
        return self.trainer.num_train_batches

    @property
    def num_valid_batches(self) -> int:
        "Number of validation batches, read from `self.trainer.num_val_batches`."
        return self.trainer.num_val_batches

    @property
    def num_epoches(self) -> int:
        "Total number of epochs, read from `self.trainer.n_epochs`."
        return self.trainer.n_epochs

    def init_train_pbar(self):
        "Create the training bar; it stays on screen after closing (`leave=True`)."
        return tqdm(
            desc='Training', leave=True, dynamic_ncols=True,
            file=sys.stdout, smoothing=0, position=0
        )

    def init_val_pbar(self):
        "Create a validation bar that is cleared when it closes (`leave=False`)."
        # The training bar (position 0) is opened in `on_train_begin` and only
        # closed in `on_train_end`, so it is still on screen when validation
        # runs during training. Drawing this bar one line lower keeps it from
        # overwriting the training bar in a terminal.
        return tqdm(
            desc='Validation', leave=False, dynamic_ncols=True,
            file=sys.stdout, smoothing=0, position=1
        )

    def on_train_begin(self, state: TrainState):
        self.train_pbar = self.init_train_pbar()

    def on_epoch_begin(self, state: TrainState):
        # Reuse the single training bar: rewind it and set this epoch's total.
        self.train_pbar.reset(self.num_train_batches)
        self.train_pbar.initial = 0
        self.train_pbar.set_description(f"Epoch {state.epoch}")
        self._batch_idx = 0

    def on_train_batch_end(self, state: TrainState):
        self._batch_idx += 1
        _update_pbar_n(self.train_pbar, self._batch_idx)
        self.train_pbar.set_postfix(state.logs)

    def on_epoch_end(self, state: TrainState):
        # Refresh the postfix with whatever logs are present at epoch end.
        if not self.train_pbar.disable:
            self.train_pbar.set_postfix(state.logs)

    def on_train_end(self, state: TrainState):
        self.train_pbar.close()

    def on_val_begin(self, state: TrainState):
        # A fresh bar per validation pass, sized to the validation loader.
        self.valid_pbar = self.init_val_pbar()
        self.valid_pbar.reset(self.num_valid_batches)
        self.valid_pbar.initial = 0
        self.valid_pbar.set_description("Validation Dataloader")

    def on_val_batch_end(self, state: TrainState):
        self.valid_pbar.update(1)

    def on_val_end(self, state: TrainState):
        self.valid_pbar.close()


## Test

Smoke-test `ProgbarLogger` by fitting a small Haiku model on synthetic classification data.

### Fake Module

In [None]:
def make_hk_module(output_size: int = 2):
    """Build a stateful Haiku model: a linear layer followed by batch norm.

    Args:
        output_size: width of the linear layer's output.

    Returns:
        The `hk.transform_with_state` pair for the model; the batch-norm
        statistics live in the returned state.
    """
    def model(x, is_training=True):
        # Construct batch norm before the linear layer, matching the original
        # expression's evaluation order.
        batch_norm = hk.BatchNorm(
            create_scale=True, create_offset=True, decay_rate=0.9)
        projected = hk.Linear(output_size)(x)
        return batch_norm(projected, is_training=is_training)

    return hk.transform_with_state(model)

In [None]:
# Two output units, spelled out rather than relying on the default.
module = make_hk_module(output_size=2)

### Fake Data

In [None]:
from jax_dataloader import DataLoader, ArrayDataset
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
import optax

In [None]:
# 2000 synthetic samples with 10 features; fixed seeds keep the run reproducible.
xs, ys = make_classification(n_samples=2000, n_features=10, random_state=0)
# Hold out 20% of the samples for the validation loader.
train_xs, test_xs, train_ys, test_ys = train_test_split(
    xs, ys, test_size=0.2, random_state=0,
)
train_ds, test_ds = ArrayDataset(train_xs, train_ys), ArrayDataset(test_xs, test_ys)
train_dl, test_dl = [DataLoader(ds, 'jax', batch_size=128) for ds in (train_ds, test_ds)]

### Training

In [None]:
# Two-epoch run with Adam; the progress-bar callback is the object under test.
trainer = Trainer(
    transformed=module,
    optimizers=optax.adam(learning_rate=1e-3),
    callbacks=[ProgbarLogger()],
    n_epochs=2,
)

In [None]:
# Train on `train_dl`; `test_dl` is presumably the validation loader (the
# recorded output below shows one validation bar per epoch).
trainer.fit(train_dl, test_dl)

  param = init(shape, dtype)


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]