In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

import optax
import polars as pl
from flax import nnx
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

sys.path.append("/workspace")
from flax_trainer.evaluator import RegressionEvaluator
from flax_trainer.loader import MiniBatchLoader
from flax_trainer.loss_fn import mean_squared_error
from flax_trainer.model.mlp import RegressionMLP
from flax_trainer.trainer import Trainer

In [3]:
housing = fetch_california_housing()
df_TRAIN_DATA, df_TEST_DATA = train_test_split(
    pl.from_numpy(data=housing.data, schema=housing.feature_names).with_columns(  # type: ignore
        pl.Series(housing.target).alias(housing.target_names[0])  # type: ignore
    ),
    test_size=0.2,
    random_state=0,
)

In [4]:
model = RegressionMLP(df_TRAIN_DATA.shape[1] - 1, [32, 16, 4, 2], 1, nnx.Rngs(0))

In [5]:
trainer = Trainer(
    model=model,
    optimizer=optax.adam(learning_rate=0.0001),
    train_loader=MiniBatchLoader(df_DATA=df_TRAIN_DATA, batch_size=512, seed=0),
    loss_fn=mean_squared_error,
    test_evaluator=RegressionEvaluator(df_TEST_DATA),
    early_stopping_patience=10,
    epoch_num=512,
)
trainer = trainer.fit()

[TEST  000] loss=69.95780944824219


[TRAIN 001]: 100%|██████████| 33/33 [00:00<00:00, 107.53it/s, batch_loss=4.8399506]


[TEST  001] loss=10.040655136108398


[TRAIN 002]: 100%|██████████| 33/33 [00:00<00:00, 573.58it/s, batch_loss=4.3720737]


[TEST  002] loss=7.081933498382568


[TRAIN 003]: 100%|██████████| 33/33 [00:00<00:00, 567.42it/s, batch_loss=5.3580875]


[TEST  003] loss=5.573915958404541


[TRAIN 004]: 100%|██████████| 33/33 [00:00<00:00, 543.03it/s, batch_loss=4.7945204]


[TEST  004] loss=5.5116167068481445


[TRAIN 005]: 100%|██████████| 33/33 [00:00<00:00, 570.97it/s, batch_loss=5.106104]


[TEST  005] loss=5.547776699066162


[TRAIN 006]: 100%|██████████| 33/33 [00:00<00:00, 582.57it/s, batch_loss=5.4647923]


[TEST  006] loss=5.579135417938232


[TRAIN 007]: 100%|██████████| 33/33 [00:00<00:00, 576.01it/s, batch_loss=5.686099]


[TEST  007] loss=5.604259490966797


[TRAIN 008]: 100%|██████████| 33/33 [00:00<00:00, 523.08it/s, batch_loss=5.046925]


[TEST  008] loss=5.623812675476074


[TRAIN 009]: 100%|██████████| 33/33 [00:00<00:00, 592.48it/s, batch_loss=4.2932835]


[TEST  009] loss=5.637764930725098


[TRAIN 010]: 100%|██████████| 33/33 [00:00<00:00, 583.01it/s, batch_loss=5.31737]


[TEST  010] loss=5.651158332824707


[TRAIN 011]: 100%|██████████| 33/33 [00:00<00:00, 574.61it/s, batch_loss=5.755681]


[TEST  011] loss=5.664552211761475


[TRAIN 012]: 100%|██████████| 33/33 [00:00<00:00, 337.96it/s, batch_loss=5.1685014]


[TEST  012] loss=5.677972793579102


[TRAIN 013]: 100%|██████████| 33/33 [00:00<00:00, 566.18it/s, batch_loss=5.034153]


[TEST  013] loss=5.6914448738098145


[TRAIN 014]: 100%|██████████| 33/33 [00:00<00:00, 576.00it/s, batch_loss=5.3188295]


[TEST  014] loss=5.704908847808838
