In [1]:
# Enable IPython's autoreload extension in mode 2, so imported modules are
# reloaded before each cell executes and edits to the project package are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

import optax
import polars as pl
from flax import nnx
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

# Extend the module search path so the `flax_trainer` imports below resolve.
# NOTE(review): hard-coded absolute path — presumably the repository root inside
# the development container; confirm, and consider deriving it from the
# notebook location or installing the package instead.
sys.path.append("/workspace")
from flax_trainer.evaluator import RegressionEvaluator
from flax_trainer.loader import MiniBatchLoader
from flax_trainer.loss_fn import mean_squared_error
from flax_trainer.model.mlp import RegressionMLP
from flax_trainer.trainer import Trainer

In [3]:
# Load the California housing data and assemble a single polars frame whose
# last column is the regression target.
housing = fetch_california_housing()

features_df = pl.from_numpy(data=housing.data, schema=housing.feature_names)  # type: ignore
target_series = pl.Series(housing.target).alias(housing.target_names[0])  # type: ignore

# Hold out 20% of the rows for evaluation; the fixed seed keeps the split reproducible.
train_dataset_df, test_dataset_df = train_test_split(
    features_df.with_columns(target_series),
    test_size=0.2,
    random_state=0,
)

In [4]:
# Every column except one (the target) is fed to the network as an input.
input_feature_num = train_dataset_df.shape[1] - 1

# NOTE(review): [32, 16, 4, 2] and 1 are presumably the hidden-layer widths and
# the output width — confirm against the RegressionMLP signature.
model = RegressionMLP(input_feature_num, [32, 16, 4, 2], 1, nnx.Rngs(0))

# Mini-batches of 512 rows drawn from the training split.
train_loader = MiniBatchLoader(
    dataset_df=train_dataset_df,
    batch_size=512,
    rngs=nnx.Rngs(0),
)

# Evaluator bound to the held-out split.
test_evaluator = RegressionEvaluator(dataset_df=test_dataset_df)

In [5]:
# Training
#
# Adam on mean squared error, with the held-out evaluator supplied to the trainer.
# NOTE(review): `early_stopping_patience=10` / `epoch_num=512` presumably stop
# training after 10 epochs without test-loss improvement, capped at 512 epochs —
# confirm against Trainer.

adam_optimizer = optax.adam(learning_rate=0.001)

trainer = Trainer(
    model=model,
    train_loader=train_loader,
    test_evaluator=test_evaluator,
    loss_fn=mean_squared_error,
    optimizer=adam_optimizer,
    epoch_num=512,
    early_stopping_patience=10,
)
# Rebind to whatever fit() returns; later cells use this object.
trainer = trainer.fit()

[TEST  000]: loss=69.95780944824219, metrics={'mse': 69.95780944824219}


[TRAIN 001]: 100%|██████████| 33/33 [00:00<00:00, 94.49it/s, batch_loss=5.4964595] 


[TEST  001]: loss=6.54409646987915, metrics={'mse': 6.54409646987915}


[TRAIN 002]: 100%|██████████| 33/33 [00:00<00:00, 651.80it/s, batch_loss=5.1092644]


[TEST  002]: loss=5.6618971824646, metrics={'mse': 5.6618971824646}


[TRAIN 003]: 100%|██████████| 33/33 [00:00<00:00, 641.84it/s, batch_loss=2.053743]


[TEST  003]: loss=2.792478084564209, metrics={'mse': 2.792478084564209}


[TRAIN 004]: 100%|██████████| 33/33 [00:00<00:00, 617.15it/s, batch_loss=1.4365823]


[TEST  004]: loss=1.8629634380340576, metrics={'mse': 1.8629634380340576}


[TRAIN 005]: 100%|██████████| 33/33 [00:00<00:00, 608.74it/s, batch_loss=1.3979833]


[TEST  005]: loss=1.5767855644226074, metrics={'mse': 1.5767855644226074}


[TRAIN 006]: 100%|██████████| 33/33 [00:00<00:00, 637.01it/s, batch_loss=1.2941186]


[TEST  006]: loss=1.4899089336395264, metrics={'mse': 1.4899089336395264}


[TRAIN 007]: 100%|██████████| 33/33 [00:00<00:00, 608.28it/s, batch_loss=1.4245962]


[TEST  007]: loss=1.5455416440963745, metrics={'mse': 1.5455416440963745}


[TRAIN 008]: 100%|██████████| 33/33 [00:00<00:00, 637.92it/s, batch_loss=1.36414]


[TEST  008]: loss=1.452880859375, metrics={'mse': 1.452880859375}


[TRAIN 009]: 100%|██████████| 33/33 [00:00<00:00, 604.19it/s, batch_loss=1.1332653]


[TEST  009]: loss=1.396668553352356, metrics={'mse': 1.396668553352356}


[TRAIN 010]: 100%|██████████| 33/33 [00:00<00:00, 600.14it/s, batch_loss=1.209461]


[TEST  010]: loss=1.3831496238708496, metrics={'mse': 1.3831496238708496}


[TRAIN 011]: 100%|██████████| 33/33 [00:00<00:00, 630.33it/s, batch_loss=1.825509]


[TEST  011]: loss=1.362729787826538, metrics={'mse': 1.362729787826538}


[TRAIN 012]: 100%|██████████| 33/33 [00:00<00:00, 600.67it/s, batch_loss=1.458847]


[TEST  012]: loss=1.3359392881393433, metrics={'mse': 1.3359392881393433}


[TRAIN 013]: 100%|██████████| 33/33 [00:00<00:00, 619.19it/s, batch_loss=1.3124889]


[TEST  013]: loss=1.3195993900299072, metrics={'mse': 1.3195993900299072}


[TRAIN 014]: 100%|██████████| 33/33 [00:00<00:00, 617.62it/s, batch_loss=1.3559215]


[TEST  014]: loss=1.319619059562683, metrics={'mse': 1.319619059562683}


[TRAIN 015]: 100%|██████████| 33/33 [00:00<00:00, 631.77it/s, batch_loss=1.5402603]


[TEST  015]: loss=1.3320515155792236, metrics={'mse': 1.3320515155792236}


[TRAIN 016]: 100%|██████████| 33/33 [00:00<00:00, 630.81it/s, batch_loss=1.3136804]


[TEST  016]: loss=1.2950414419174194, metrics={'mse': 1.2950414419174194}


[TRAIN 017]: 100%|██████████| 33/33 [00:00<00:00, 615.22it/s, batch_loss=1.4038014]


[TEST  017]: loss=1.2775354385375977, metrics={'mse': 1.2775354385375977}


[TRAIN 018]: 100%|██████████| 33/33 [00:00<00:00, 593.91it/s, batch_loss=1.2930261]


[TEST  018]: loss=1.2538059949874878, metrics={'mse': 1.2538059949874878}


[TRAIN 019]: 100%|██████████| 33/33 [00:00<00:00, 625.83it/s, batch_loss=1.2121102]


[TEST  019]: loss=1.2387796640396118, metrics={'mse': 1.2387796640396118}


[TRAIN 020]: 100%|██████████| 33/33 [00:00<00:00, 631.47it/s, batch_loss=1.2322085]


[TEST  020]: loss=1.228780746459961, metrics={'mse': 1.228780746459961}


[TRAIN 021]: 100%|██████████| 33/33 [00:00<00:00, 628.94it/s, batch_loss=0.90739125]


[TEST  021]: loss=1.2016555070877075, metrics={'mse': 1.2016555070877075}


[TRAIN 022]: 100%|██████████| 33/33 [00:00<00:00, 630.29it/s, batch_loss=1.4209318]


[TEST  022]: loss=1.1864066123962402, metrics={'mse': 1.1864066123962402}


[TRAIN 023]: 100%|██████████| 33/33 [00:00<00:00, 605.79it/s, batch_loss=1.1080449]


[TEST  023]: loss=1.1748501062393188, metrics={'mse': 1.1748501062393188}


[TRAIN 024]: 100%|██████████| 33/33 [00:00<00:00, 629.65it/s, batch_loss=1.0827132]


[TEST  024]: loss=1.1553606986999512, metrics={'mse': 1.1553606986999512}


[TRAIN 025]: 100%|██████████| 33/33 [00:00<00:00, 619.81it/s, batch_loss=0.9466663]


[TEST  025]: loss=1.1238033771514893, metrics={'mse': 1.1238033771514893}


[TRAIN 026]: 100%|██████████| 33/33 [00:00<00:00, 631.36it/s, batch_loss=1.1592906]


[TEST  026]: loss=1.100164771080017, metrics={'mse': 1.100164771080017}


[TRAIN 027]: 100%|██████████| 33/33 [00:00<00:00, 632.45it/s, batch_loss=1.1622491]


[TEST  027]: loss=1.0822927951812744, metrics={'mse': 1.0822927951812744}


[TRAIN 028]: 100%|██████████| 33/33 [00:00<00:00, 621.53it/s, batch_loss=1.3485482]


[TEST  028]: loss=1.0435136556625366, metrics={'mse': 1.0435136556625366}


[TRAIN 029]: 100%|██████████| 33/33 [00:00<00:00, 626.10it/s, batch_loss=0.89076626]


[TEST  029]: loss=1.0270001888275146, metrics={'mse': 1.0270001888275146}


[TRAIN 030]: 100%|██████████| 33/33 [00:00<00:00, 624.07it/s, batch_loss=0.9466839]


[TEST  030]: loss=0.9725115299224854, metrics={'mse': 0.9725115299224854}


[TRAIN 031]: 100%|██████████| 33/33 [00:00<00:00, 359.14it/s, batch_loss=0.9790628]


[TEST  031]: loss=0.9285380840301514, metrics={'mse': 0.9285380840301514}


[TRAIN 032]: 100%|██████████| 33/33 [00:00<00:00, 624.57it/s, batch_loss=1.0009487]


[TEST  032]: loss=0.9496695399284363, metrics={'mse': 0.9496695399284363}


[TRAIN 033]: 100%|██████████| 33/33 [00:00<00:00, 528.84it/s, batch_loss=0.9695524]


[TEST  033]: loss=0.8595136404037476, metrics={'mse': 0.8595136404037476}


[TRAIN 034]: 100%|██████████| 33/33 [00:00<00:00, 618.19it/s, batch_loss=0.87037337]


[TEST  034]: loss=0.8998928070068359, metrics={'mse': 0.8998928070068359}


[TRAIN 035]: 100%|██████████| 33/33 [00:00<00:00, 609.69it/s, batch_loss=0.8215214]


[TEST  035]: loss=0.7850393652915955, metrics={'mse': 0.7850393652915955}


[TRAIN 036]: 100%|██████████| 33/33 [00:00<00:00, 620.38it/s, batch_loss=0.67198384]


[TEST  036]: loss=0.7623736262321472, metrics={'mse': 0.7623736262321472}


[TRAIN 037]: 100%|██████████| 33/33 [00:00<00:00, 612.76it/s, batch_loss=0.9225741]


[TEST  037]: loss=0.8671588897705078, metrics={'mse': 0.8671588897705078}


[TRAIN 038]: 100%|██████████| 33/33 [00:00<00:00, 600.31it/s, batch_loss=0.6830254]


[TEST  038]: loss=0.7081394791603088, metrics={'mse': 0.7081394791603088}


[TRAIN 039]: 100%|██████████| 33/33 [00:00<00:00, 622.91it/s, batch_loss=0.47315803]


[TEST  039]: loss=0.7726365923881531, metrics={'mse': 0.7726365923881531}


[TRAIN 040]: 100%|██████████| 33/33 [00:00<00:00, 647.90it/s, batch_loss=0.54812706]


[TEST  040]: loss=0.6759756207466125, metrics={'mse': 0.6759756207466125}


[TRAIN 041]: 100%|██████████| 33/33 [00:00<00:00, 630.40it/s, batch_loss=0.46781164]


[TEST  041]: loss=0.6632670760154724, metrics={'mse': 0.6632670760154724}


[TRAIN 042]: 100%|██████████| 33/33 [00:00<00:00, 620.44it/s, batch_loss=0.7207941]


[TEST  042]: loss=0.6461907625198364, metrics={'mse': 0.6461907625198364}


[TRAIN 043]: 100%|██████████| 33/33 [00:00<00:00, 615.76it/s, batch_loss=0.62912834]


[TEST  043]: loss=0.7795215249061584, metrics={'mse': 0.7795215249061584}


[TRAIN 044]: 100%|██████████| 33/33 [00:00<00:00, 565.38it/s, batch_loss=0.7285481]


[TEST  044]: loss=0.6324975490570068, metrics={'mse': 0.6324975490570068}


[TRAIN 045]: 100%|██████████| 33/33 [00:00<00:00, 619.22it/s, batch_loss=0.96120083]


[TEST  045]: loss=0.7085521817207336, metrics={'mse': 0.7085521817207336}


[TRAIN 046]: 100%|██████████| 33/33 [00:00<00:00, 621.31it/s, batch_loss=0.4581238]


[TEST  046]: loss=0.62704998254776, metrics={'mse': 0.62704998254776}


[TRAIN 047]: 100%|██████████| 33/33 [00:00<00:00, 360.71it/s, batch_loss=0.517988]


[TEST  047]: loss=0.7350544929504395, metrics={'mse': 0.7350544929504395}


[TRAIN 048]: 100%|██████████| 33/33 [00:00<00:00, 628.29it/s, batch_loss=0.72112006]


[TEST  048]: loss=0.7458639740943909, metrics={'mse': 0.7458639740943909}


[TRAIN 049]: 100%|██████████| 33/33 [00:00<00:00, 597.23it/s, batch_loss=0.43982118]


[TEST  049]: loss=0.6496586203575134, metrics={'mse': 0.6496586203575134}


[TRAIN 050]: 100%|██████████| 33/33 [00:00<00:00, 507.35it/s, batch_loss=0.6592014]


[TEST  050]: loss=0.6130358576774597, metrics={'mse': 0.6130358576774597}


[TRAIN 051]: 100%|██████████| 33/33 [00:00<00:00, 625.71it/s, batch_loss=0.46203598]


[TEST  051]: loss=0.6085271239280701, metrics={'mse': 0.6085271239280701}


[TRAIN 052]: 100%|██████████| 33/33 [00:00<00:00, 343.73it/s, batch_loss=0.48669547]


[TEST  052]: loss=0.6259031295776367, metrics={'mse': 0.6259031295776367}


[TRAIN 053]: 100%|██████████| 33/33 [00:00<00:00, 601.09it/s, batch_loss=0.62545073]


[TEST  053]: loss=0.613401472568512, metrics={'mse': 0.613401472568512}


[TRAIN 054]: 100%|██████████| 33/33 [00:00<00:00, 612.24it/s, batch_loss=0.75011396]


[TEST  054]: loss=0.6261976361274719, metrics={'mse': 0.6261976361274719}


[TRAIN 055]: 100%|██████████| 33/33 [00:00<00:00, 599.20it/s, batch_loss=0.64530957]


[TEST  055]: loss=0.6033914685249329, metrics={'mse': 0.6033914685249329}


[TRAIN 056]: 100%|██████████| 33/33 [00:00<00:00, 600.53it/s, batch_loss=0.44484204]


[TEST  056]: loss=0.6100307106971741, metrics={'mse': 0.6100307106971741}


[TRAIN 057]: 100%|██████████| 33/33 [00:00<00:00, 602.02it/s, batch_loss=0.5762122]


[TEST  057]: loss=0.6247549057006836, metrics={'mse': 0.6247549057006836}


[TRAIN 058]: 100%|██████████| 33/33 [00:00<00:00, 601.16it/s, batch_loss=0.65936756]


[TEST  058]: loss=0.6025345921516418, metrics={'mse': 0.6025345921516418}


[TRAIN 059]: 100%|██████████| 33/33 [00:00<00:00, 627.36it/s, batch_loss=0.52248704]


[TEST  059]: loss=0.6039100289344788, metrics={'mse': 0.6039100289344788}


[TRAIN 060]: 100%|██████████| 33/33 [00:00<00:00, 637.22it/s, batch_loss=0.6943023]


[TEST  060]: loss=0.5875662565231323, metrics={'mse': 0.5875662565231323}


[TRAIN 061]: 100%|██████████| 33/33 [00:00<00:00, 354.34it/s, batch_loss=0.36497062]


[TEST  061]: loss=0.6242859959602356, metrics={'mse': 0.6242859959602356}


[TRAIN 062]: 100%|██████████| 33/33 [00:00<00:00, 637.24it/s, batch_loss=0.74231356]


[TEST  062]: loss=0.615252673625946, metrics={'mse': 0.615252673625946}


[TRAIN 063]: 100%|██████████| 33/33 [00:00<00:00, 605.31it/s, batch_loss=0.6440625]


[TEST  063]: loss=0.6004959940910339, metrics={'mse': 0.6004959940910339}


[TRAIN 064]: 100%|██████████| 33/33 [00:00<00:00, 422.18it/s, batch_loss=0.606089]


[TEST  064]: loss=0.6048641800880432, metrics={'mse': 0.6048641800880432}


[TRAIN 065]: 100%|██████████| 33/33 [00:00<00:00, 624.00it/s, batch_loss=0.78683496]

[TEST  065]: loss=0.5949965119361877, metrics={'mse': 0.5949965119361877}



[TRAIN 066]: 100%|██████████| 33/33 [00:00<00:00, 574.97it/s, batch_loss=0.63641846]


[TEST  066]: loss=0.5981246829032898, metrics={'mse': 0.5981246829032898}


[TRAIN 067]: 100%|██████████| 33/33 [00:00<00:00, 633.92it/s, batch_loss=0.85438067]


[TEST  067]: loss=0.6092846989631653, metrics={'mse': 0.6092846989631653}


[TRAIN 068]: 100%|██████████| 33/33 [00:00<00:00, 595.41it/s, batch_loss=0.564083]


[TEST  068]: loss=0.619128942489624, metrics={'mse': 0.619128942489624}


[TRAIN 069]: 100%|██████████| 33/33 [00:00<00:00, 636.89it/s, batch_loss=0.3845288]


[TEST  069]: loss=0.6422792077064514, metrics={'mse': 0.6422792077064514}


[TRAIN 070]: 100%|██████████| 33/33 [00:00<00:00, 586.03it/s, batch_loss=0.63241506]


[TEST  070]: loss=0.6144530177116394, metrics={'mse': 0.6144530177116394}


In [6]:
# Inference

test_loader = MiniBatchLoader(dataset_df=test_dataset_df, batch_size=512, rngs=nnx.Rngs(0))

# Take only the first mini-batch. Unlike a `for ...: break` loop, next(iter(...))
# raises StopIteration right here when the loader yields nothing, instead of
# leaving `Xs`/`y` unbound — or silently stale from an earlier execution of this
# cell in the same kernel.
Xs, y = next(iter(test_loader))

# Display the predictions of `trainer.best_model` for the first 10 rows of the batch.
trainer.best_model(*Xs)[:10]

Array([[1.9849969],
       [2.1086004],
       [1.9600796],
       [2.944981 ],
       [1.1065688],
       [1.7676227],
       [1.6479461],
       [4.7380037],
       [3.5829756],
       [2.6884596]], dtype=float32)