In [1]:
# Enable IPython's autoreload extension in mode 2: imported modules are
# re-imported before each cell runs, so edits to local source files take
# effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

import optax
import polars as pl
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

# Extend the module search path before the `flax_trainer` imports below
# (presumably /workspace is where that package is checked out — confirm).
# NOTE(review): this is a hard-coded absolute path; consider deriving it from
# an environment variable or the notebook location.
# The membership guard keeps the cell idempotent: re-running it no longer
# appends a duplicate "/workspace" entry each time.
if "/workspace" not in sys.path:
    sys.path.append("/workspace")
from flax_trainer.evaluator import RegressionEvaluator
from flax_trainer.loader import MiniBatchLoader
from flax_trainer.loss_fn import mean_squared_error
from flax_trainer.model.mlp import RegressionMLP
from flax_trainer.trainer import Trainer

In [3]:
# Fetch the California housing dataset from scikit-learn.
housing = fetch_california_housing()

# Target values as a polars Series, named after the dataset's first target name.
target_column = pl.Series(name=housing.target_names[0], values=housing.target)  # type: ignore

# One polars frame: the feature columns with the target column appended.
df_HOUSING = pl.from_numpy(data=housing.data, schema=housing.feature_names).with_columns(  # type: ignore
    target_column
)

# 80/20 train/test split; random_state=0 fixes the shuffle so the split is reproducible.
df_TRAIN_DATA, df_TEST_DATA = train_test_split(df_HOUSING, test_size=0.2, random_state=0)

In [None]:
# All training settings gathered in one place so they can be inspected or
# tweaked before the Trainer is constructed.
trainer_kwargs = dict(
    # MLP with layer sizes 32 -> 16 -> 4 -> 2 -> 1.
    model=RegressionMLP(layer_sizes=[32, 16, 4, 2, 1]),
    optimizer=optax.adam(learning_rate=0.01),
    # Mini-batches of 512 rows drawn from the training frame.
    train_loader=MiniBatchLoader(df_DATA=df_TRAIN_DATA, batch_size=512),
    loss_fn=mean_squared_error,
    # Evaluator built on the held-out test frame.
    test_evaluator=RegressionEvaluator(df_TEST_DATA),
    early_stopping_patience=10,
    epoch_num=512,
    seed=0,
)

trainer = Trainer(**trainer_kwargs)
# The value returned by fit() is rebound to `trainer`
# (presumably the fitted trainer — confirm against Trainer.fit).
trainer = trainer.fit()