In [1]:
# Load IPython's autoreload extension so edits to imported project modules
# are picked up without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules (except those excluded via %aimport) before
# each cell is executed.
%autoreload 2

In [2]:
import sys

import optax
import polars as pl
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

# Make the project checkout importable so the `flax_trainer` imports below
# resolve. Guarded so that re-running this cell does not keep appending
# duplicate "/app" entries to sys.path.
# NOTE(review): "/app" looks like a container mount point — confirm before
# running this notebook outside that environment.
if "/app" not in sys.path:
    sys.path.append("/app")
from flax_trainer.evaluator import RegressionEvaluator
from flax_trainer.loader import MiniBatchLoader
from flax_trainer.loss_fn import mean_squared_error
from flax_trainer.model.mlp import RegressionMLP
from flax_trainer.trainer import Trainer

In [3]:
# Fetch the California housing dataset bundled with scikit-learn.
housing = fetch_california_housing()

# Feature matrix as a polars frame, with columns named after the dataset's
# feature names.
df_FEATURES = pl.from_numpy(data=housing.data, schema=housing.feature_names)  # type: ignore

# Target values as a polars Series named after the dataset's first target name.
target_series = pl.Series(housing.target).alias(housing.target_names[0])  # type: ignore

# Single frame holding the features plus the target column appended at the end.
df_HOUSING = df_FEATURES.with_columns(target_series)

# 80/20 train/test split; random_state is fixed so the split is reproducible.
df_TRAIN_DATA, df_TEST_DATA = train_test_split(
    df_HOUSING,
    test_size=0.2,
    random_state=0,
)

In [4]:
# Build each collaborator on its own line (same order as they were evaluated
# inline before) so the Trainer call below reads as plain wiring.
mlp_model = RegressionMLP(layer_sizes=[32, 16, 4, 2, 1])
adam_optimizer = optax.adam(learning_rate=0.01)
train_batches = MiniBatchLoader(df_DATA=df_TRAIN_DATA, batch_size=512)
held_out_evaluator = RegressionEvaluator(df_TEST_DATA)

# NOTE(review): in the recorded run the test loss flattens near 1.304 from
# roughly epoch 24 and the log ends at epoch 34, which looks consistent with
# early_stopping_patience=10 — confirm against the Trainer implementation.
# The features are passed in unscaled as far as this notebook shows; it may
# be worth checking whether standardizing inputs (or widening the 4 -> 2
# bottleneck) improves on that plateau. TODO confirm.
trainer = Trainer(
    model=mlp_model,
    optimizer=adam_optimizer,
    train_loader=train_batches,
    loss_fn=mean_squared_error,
    test_evaluator=held_out_evaluator,
    early_stopping_patience=10,
    epoch_num=512,
    seed=0,
)

# fit() runs the training loop; its return value is rebound to `trainer`.
trainer = trainer.fit()

[TEST  000] loss=837.3048095703125


[TRAIN 001]: 100%|██████████| 33/33 [00:00<00:00, 94.64it/s, batch_loss=5.8636403] 


[TEST  001] loss=4.837385177612305


[TRAIN 002]: 100%|██████████| 33/33 [00:00<00:00, 1290.90it/s, batch_loss=4.2527456]


[TEST  002] loss=4.302976131439209


[TRAIN 003]: 100%|██████████| 33/33 [00:00<00:00, 1134.83it/s, batch_loss=3.8469005]


[TEST  003] loss=3.762735605239868


[TRAIN 004]: 100%|██████████| 33/33 [00:00<00:00, 1102.23it/s, batch_loss=3.4326794]


[TEST  004] loss=3.2613182067871094


[TRAIN 005]: 100%|██████████| 33/33 [00:00<00:00, 1046.74it/s, batch_loss=3.086034]


[TEST  005] loss=2.8228976726531982


[TRAIN 006]: 100%|██████████| 33/33 [00:00<00:00, 1053.13it/s, batch_loss=2.3669844]


[TEST  006] loss=2.4525628089904785


[TRAIN 007]: 100%|██████████| 33/33 [00:00<00:00, 1078.62it/s, batch_loss=1.9929717]


[TEST  007] loss=2.1553919315338135


[TRAIN 008]: 100%|██████████| 33/33 [00:00<00:00, 1055.67it/s, batch_loss=2.0151653]


[TEST  008] loss=1.9200503826141357


[TRAIN 009]: 100%|██████████| 33/33 [00:00<00:00, 1019.65it/s, batch_loss=1.5134]


[TEST  009] loss=1.739291787147522


[TRAIN 010]: 100%|██████████| 33/33 [00:00<00:00, 1105.39it/s, batch_loss=1.2151306]


[TEST  010] loss=1.6061941385269165


[TRAIN 011]: 100%|██████████| 33/33 [00:00<00:00, 1030.20it/s, batch_loss=1.5015311]


[TEST  011] loss=1.507871389389038


[TRAIN 012]: 100%|██████████| 33/33 [00:00<00:00, 1160.41it/s, batch_loss=1.6962211]


[TEST  012] loss=1.4383573532104492


[TRAIN 013]: 100%|██████████| 33/33 [00:00<00:00, 1041.96it/s, batch_loss=0.97760606]


[TEST  013] loss=1.389829397201538


[TRAIN 014]: 100%|██████████| 33/33 [00:00<00:00, 1077.16it/s, batch_loss=1.4378728]


[TEST  014] loss=1.3579179048538208


[TRAIN 015]: 100%|██████████| 33/33 [00:00<00:00, 1127.91it/s, batch_loss=1.339653]


[TEST  015] loss=1.3364989757537842


[TRAIN 016]: 100%|██████████| 33/33 [00:00<00:00, 862.30it/s, batch_loss=1.1904805]


[TEST  016] loss=1.322684645652771


[TRAIN 017]: 100%|██████████| 33/33 [00:00<00:00, 905.32it/s, batch_loss=1.3465756]


[TEST  017] loss=1.3144276142120361


[TRAIN 018]: 100%|██████████| 33/33 [00:00<00:00, 1197.04it/s, batch_loss=1.3009009]


[TEST  018] loss=1.3093302249908447


[TRAIN 019]: 100%|██████████| 33/33 [00:00<00:00, 1123.88it/s, batch_loss=1.2136148]


[TEST  019] loss=1.3066539764404297


[TRAIN 020]: 100%|██████████| 33/33 [00:00<00:00, 1048.43it/s, batch_loss=1.1875241]


[TEST  020] loss=1.305134654045105


[TRAIN 021]: 100%|██████████| 33/33 [00:00<00:00, 1054.30it/s, batch_loss=1.3473296]


[TEST  021] loss=1.304396152496338


[TRAIN 022]: 100%|██████████| 33/33 [00:00<00:00, 1063.85it/s, batch_loss=1.1117574]


[TEST  022] loss=1.3040844202041626


[TRAIN 023]: 100%|██████████| 33/33 [00:00<00:00, 1055.47it/s, batch_loss=1.1334361]


[TEST  023] loss=1.3039700984954834


[TRAIN 024]: 100%|██████████| 33/33 [00:00<00:00, 497.48it/s, batch_loss=1.349792]


[TEST  024] loss=1.3039584159851074


[TRAIN 025]: 100%|██████████| 33/33 [00:00<00:00, 1065.24it/s, batch_loss=1.3552812]


[TEST  025] loss=1.3039944171905518


[TRAIN 026]: 100%|██████████| 33/33 [00:00<00:00, 1130.87it/s, batch_loss=1.2368824]


[TEST  026] loss=1.304064393043518


[TRAIN 027]: 100%|██████████| 33/33 [00:00<00:00, 1092.90it/s, batch_loss=1.2309859]


[TEST  027] loss=1.3041328191757202


[TRAIN 028]: 100%|██████████| 33/33 [00:00<00:00, 1014.71it/s, batch_loss=1.2110941]


[TEST  028] loss=1.3041861057281494


[TRAIN 029]: 100%|██████████| 33/33 [00:00<00:00, 1134.51it/s, batch_loss=1.6188705]


[TEST  029] loss=1.3041958808898926


[TRAIN 030]: 100%|██████████| 33/33 [00:00<00:00, 1118.36it/s, batch_loss=1.0611517]


[TEST  030] loss=1.3042751550674438


[TRAIN 031]: 100%|██████████| 33/33 [00:00<00:00, 1174.48it/s, batch_loss=1.4398775]


[TEST  031] loss=1.3042314052581787


[TRAIN 032]: 100%|██████████| 33/33 [00:00<00:00, 1194.50it/s, batch_loss=1.1009076]


[TEST  032] loss=1.3042563199996948


[TRAIN 033]: 100%|██████████| 33/33 [00:00<00:00, 1133.33it/s, batch_loss=1.2533345]


[TEST  033] loss=1.3042422533035278


[TRAIN 034]: 100%|██████████| 33/33 [00:00<00:00, 1124.09it/s, batch_loss=1.4525952]


[TEST  034] loss=1.3042882680892944
