In [1]:
# Reload edited modules (e.g. the local flax_trainer package) automatically before each cell runs,
# so changes to the library code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
import sys

import jax
import mlflow
import optax
import polars as pl
from flax import nnx
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split

sys.path.append("/workspace")
from flax_trainer.evaluator import RegressionEvaluator
from flax_trainer.loader import MiniBatchLoader
from flax_trainer.loss_fn import mean_squared_error
from flax_trainer.model.mlp import RegressionMLP
from flax_trainer.trainer import Trainer

# Show which backend JAX will dispatch computations to (this run reports 'cpu').
jax_backend = jax.default_backend()
jax_backend

'cpu'

In [3]:
# Load California housing and assemble one polars frame: the feature columns
# followed by the target column (named after the dataset's first target name).
housing = fetch_california_housing()
target_name = housing.target_names[0]  # type: ignore
housing_df = pl.from_numpy(data=housing.data, schema=housing.feature_names).with_columns(  # type: ignore
    pl.Series(name=target_name, values=housing.target)  # type: ignore
)

# Hold out 20% of the rows for validation; random_state fixes the split for reproducibility.
train_dataset_df, valid_dataset_df = train_test_split(housing_df, test_size=0.2, random_state=0)

In [4]:
# Input width = all columns except one; assumes the single target column
# (appended last in the data cell) is the only non-feature column — confirm against MiniBatchLoader.
n_input_features = train_dataset_df.shape[1] - 1
hidden_layer_sizes = [32, 16, 4, 2]

model = RegressionMLP(n_input_features, hidden_layer_sizes, 1, nnx.Rngs(0))
train_loader = MiniBatchLoader(
    dataset_df=train_dataset_df,
    batch_size=512,
    rngs=nnx.Rngs(0),
)
valid_evaluator = RegressionEvaluator(dataset_df=valid_dataset_df)

In [5]:
# Training
#
# Without the two lines below, MLflow logs to its default tracking URI.
# Uncomment and adjust them to log to a tracking server instead.
# mlflow.set_tracking_uri(uri="http://localhost:8080")
# mlflow.set_experiment("REGRESSION")

# Hyperparameters for this run. The same dict is logged to MLflow and passed to
# the Trainer, so the recorded parameters always match what was actually used.
# TODO(review): also log the architecture / batch size chosen in the model cell.
train_params = {
    "learning_rate": 0.001,
    "early_stopping_patience": 10,
    "epoch_num": 512,
}

with mlflow.start_run() as run:
    mlflow.log_params(train_params)

    trainer = Trainer(
        model=model,
        optimizer=optax.adam(learning_rate=train_params["learning_rate"]),
        train_loader=train_loader,
        loss_fn=mean_squared_error,
        valid_evaluator=valid_evaluator,
        early_stopping_patience=train_params["early_stopping_patience"],
        epoch_num=train_params["epoch_num"],
        active_run=run,
    )
    # fit() returns an object exposing `best_model`, which the inference cell uses.
    trainer = trainer.fit()

[VALID 000]: loss=69.95780944824219, metrics={'mse': 69.95780944824219}


[TRAIN 001]: 100%|██████████| 33/33 [00:00<00:00, 133.92it/s, batch_loss=6.2312946]


[VALID 001]: loss=6.14644193649292, metrics={'mse': 6.14644193649292}


[TRAIN 002]: 100%|██████████| 33/33 [00:00<00:00, 1187.25it/s, batch_loss=4.9024467]


[VALID 002]: loss=4.766097068786621, metrics={'mse': 4.766097068786621}


[TRAIN 003]: 100%|██████████| 33/33 [00:00<00:00, 1192.32it/s, batch_loss=1.7030346]


[VALID 003]: loss=1.5000752210617065, metrics={'mse': 1.5000752210617065}


[TRAIN 004]: 100%|██████████| 33/33 [00:00<00:00, 1229.08it/s, batch_loss=1.2950621]


[VALID 004]: loss=1.3764324188232422, metrics={'mse': 1.3764324188232422}


[TRAIN 005]: 100%|██████████| 33/33 [00:00<00:00, 1285.89it/s, batch_loss=1.5802534]


[VALID 005]: loss=1.3614575862884521, metrics={'mse': 1.3614575862884521}


[TRAIN 006]: 100%|██████████| 33/33 [00:00<00:00, 1312.31it/s, batch_loss=1.3466582]


[VALID 006]: loss=1.349909782409668, metrics={'mse': 1.349909782409668}


[TRAIN 007]: 100%|██████████| 33/33 [00:00<00:00, 1253.29it/s, batch_loss=1.2202652]


[VALID 007]: loss=1.3333450555801392, metrics={'mse': 1.3333450555801392}


[TRAIN 008]: 100%|██████████| 33/33 [00:00<00:00, 1269.65it/s, batch_loss=1.2851977]


[VALID 008]: loss=1.3204874992370605, metrics={'mse': 1.3204874992370605}


[TRAIN 009]: 100%|██████████| 33/33 [00:00<00:00, 1331.19it/s, batch_loss=1.3062027]


[VALID 009]: loss=1.31205153465271, metrics={'mse': 1.31205153465271}


[TRAIN 010]: 100%|██████████| 33/33 [00:00<00:00, 1227.43it/s, batch_loss=1.3359232]


[VALID 010]: loss=1.3009798526763916, metrics={'mse': 1.3009798526763916}


[TRAIN 011]: 100%|██████████| 33/33 [00:00<00:00, 1248.68it/s, batch_loss=1.2600975]


[VALID 011]: loss=1.2889819145202637, metrics={'mse': 1.2889819145202637}


[TRAIN 012]: 100%|██████████| 33/33 [00:00<00:00, 1268.80it/s, batch_loss=1.255688]


[VALID 012]: loss=1.28468918800354, metrics={'mse': 1.28468918800354}


[TRAIN 013]: 100%|██████████| 33/33 [00:00<00:00, 1283.20it/s, batch_loss=0.9540032]


[VALID 013]: loss=1.2757188081741333, metrics={'mse': 1.2757188081741333}


[TRAIN 014]: 100%|██████████| 33/33 [00:00<00:00, 1298.42it/s, batch_loss=1.3892536]


[VALID 014]: loss=1.2745472192764282, metrics={'mse': 1.2745472192764282}


[TRAIN 015]: 100%|██████████| 33/33 [00:00<00:00, 1278.11it/s, batch_loss=1.3798119]


[VALID 015]: loss=1.2663953304290771, metrics={'mse': 1.2663953304290771}


[TRAIN 016]: 100%|██████████| 33/33 [00:00<00:00, 1301.34it/s, batch_loss=1.3083072]


[VALID 016]: loss=1.2644426822662354, metrics={'mse': 1.2644426822662354}


[TRAIN 017]: 100%|██████████| 33/33 [00:00<00:00, 1264.12it/s, batch_loss=1.7287483]


[VALID 017]: loss=1.2382045984268188, metrics={'mse': 1.2382045984268188}


[TRAIN 018]: 100%|██████████| 33/33 [00:00<00:00, 1229.62it/s, batch_loss=1.0596294]


[VALID 018]: loss=1.235413670539856, metrics={'mse': 1.235413670539856}


[TRAIN 019]: 100%|██████████| 33/33 [00:00<00:00, 1114.87it/s, batch_loss=1.3083704]


[VALID 019]: loss=1.230867862701416, metrics={'mse': 1.230867862701416}


[TRAIN 020]: 100%|██████████| 33/33 [00:00<00:00, 1285.28it/s, batch_loss=1.1092379]


[VALID 020]: loss=1.2115864753723145, metrics={'mse': 1.2115864753723145}


[TRAIN 021]: 100%|██████████| 33/33 [00:00<00:00, 1290.84it/s, batch_loss=1.1625865]


[VALID 021]: loss=1.2026088237762451, metrics={'mse': 1.2026088237762451}


[TRAIN 022]: 100%|██████████| 33/33 [00:00<00:00, 1264.73it/s, batch_loss=1.6750107]


[VALID 022]: loss=1.198635220527649, metrics={'mse': 1.198635220527649}


[TRAIN 023]: 100%|██████████| 33/33 [00:00<00:00, 1339.89it/s, batch_loss=1.2359855]


[VALID 023]: loss=1.206591248512268, metrics={'mse': 1.206591248512268}


[TRAIN 024]: 100%|██████████| 33/33 [00:00<00:00, 1392.03it/s, batch_loss=1.3601649]


[VALID 024]: loss=1.181082844734192, metrics={'mse': 1.181082844734192}


[TRAIN 025]: 100%|██████████| 33/33 [00:00<00:00, 1332.17it/s, batch_loss=1.289645]


[VALID 025]: loss=1.1564401388168335, metrics={'mse': 1.1564401388168335}


[TRAIN 026]: 100%|██████████| 33/33 [00:00<00:00, 1132.48it/s, batch_loss=1.1484969]


[VALID 026]: loss=1.1474781036376953, metrics={'mse': 1.1474781036376953}


[TRAIN 027]: 100%|██████████| 33/33 [00:00<00:00, 1292.92it/s, batch_loss=0.8981936]

[VALID 027]: loss=1.13017737865448, metrics={'mse': 1.13017737865448}



[TRAIN 028]: 100%|██████████| 33/33 [00:00<00:00, 1271.34it/s, batch_loss=1.2062454]

[VALID 028]: loss=1.110532522201538, metrics={'mse': 1.110532522201538}



[TRAIN 029]: 100%|██████████| 33/33 [00:00<00:00, 1296.43it/s, batch_loss=1.1491947]


[VALID 029]: loss=1.1185104846954346, metrics={'mse': 1.1185104846954346}


[TRAIN 030]: 100%|██████████| 33/33 [00:00<00:00, 1350.56it/s, batch_loss=1.326332]


[VALID 030]: loss=1.077314853668213, metrics={'mse': 1.077314853668213}


[TRAIN 031]: 100%|██████████| 33/33 [00:00<00:00, 1290.07it/s, batch_loss=0.84462917]


[VALID 031]: loss=1.0745527744293213, metrics={'mse': 1.0745527744293213}


[TRAIN 032]: 100%|██████████| 33/33 [00:00<00:00, 1257.75it/s, batch_loss=1.1046765]


[VALID 032]: loss=1.0242905616760254, metrics={'mse': 1.0242905616760254}


[TRAIN 033]: 100%|██████████| 33/33 [00:00<00:00, 1205.90it/s, batch_loss=0.9506]


[VALID 033]: loss=0.9979826211929321, metrics={'mse': 0.9979826211929321}


[TRAIN 034]: 100%|██████████| 33/33 [00:00<00:00, 1277.05it/s, batch_loss=1.0952853]


[VALID 034]: loss=0.9620071649551392, metrics={'mse': 0.9620071649551392}


[TRAIN 035]: 100%|██████████| 33/33 [00:00<00:00, 1277.94it/s, batch_loss=0.89672625]


[VALID 035]: loss=0.9341732263565063, metrics={'mse': 0.9341732263565063}


[TRAIN 036]: 100%|██████████| 33/33 [00:00<00:00, 1233.59it/s, batch_loss=1.0559292]


[VALID 036]: loss=0.9058147072792053, metrics={'mse': 0.9058147072792053}


[TRAIN 037]: 100%|██████████| 33/33 [00:00<00:00, 1289.70it/s, batch_loss=0.8707306]


[VALID 037]: loss=0.823212742805481, metrics={'mse': 0.823212742805481}


[TRAIN 038]: 100%|██████████| 33/33 [00:00<00:00, 1276.72it/s, batch_loss=0.8484332]


[VALID 038]: loss=1.0790104866027832, metrics={'mse': 1.0790104866027832}


[TRAIN 039]: 100%|██████████| 33/33 [00:00<00:00, 1366.78it/s, batch_loss=0.86986935]

[VALID 039]: loss=0.741835355758667, metrics={'mse': 0.741835355758667}



[TRAIN 040]: 100%|██████████| 33/33 [00:00<00:00, 1349.36it/s, batch_loss=0.86522675]


[VALID 040]: loss=0.7160562872886658, metrics={'mse': 0.7160562872886658}


[TRAIN 041]: 100%|██████████| 33/33 [00:00<00:00, 1124.42it/s, batch_loss=0.7596467]


[VALID 041]: loss=0.6995048522949219, metrics={'mse': 0.6995048522949219}


[TRAIN 042]: 100%|██████████| 33/33 [00:00<00:00, 1251.25it/s, batch_loss=0.47158062]


[VALID 042]: loss=0.7222909927368164, metrics={'mse': 0.7222909927368164}


[TRAIN 043]: 100%|██████████| 33/33 [00:00<00:00, 1448.49it/s, batch_loss=0.69241726]

[VALID 043]: loss=0.7191174030303955, metrics={'mse': 0.7191174030303955}



[TRAIN 044]: 100%|██████████| 33/33 [00:00<00:00, 1371.64it/s, batch_loss=0.7557936]


[VALID 044]: loss=0.700877845287323, metrics={'mse': 0.700877845287323}


[TRAIN 045]: 100%|██████████| 33/33 [00:00<00:00, 1099.13it/s, batch_loss=0.7585491]


[VALID 045]: loss=0.9418597221374512, metrics={'mse': 0.9418597221374512}


[TRAIN 046]: 100%|██████████| 33/33 [00:00<00:00, 1320.27it/s, batch_loss=0.5145011]


[VALID 046]: loss=0.9173449873924255, metrics={'mse': 0.9173449873924255}


[TRAIN 047]: 100%|██████████| 33/33 [00:00<00:00, 1323.53it/s, batch_loss=0.683028]


[VALID 047]: loss=0.8877587914466858, metrics={'mse': 0.8877587914466858}


[TRAIN 048]: 100%|██████████| 33/33 [00:00<00:00, 1387.81it/s, batch_loss=0.6028898]


[VALID 048]: loss=0.8417936563491821, metrics={'mse': 0.8417936563491821}


[TRAIN 049]: 100%|██████████| 33/33 [00:00<00:00, 1304.03it/s, batch_loss=0.46167946]


[VALID 049]: loss=0.784108579158783, metrics={'mse': 0.784108579158783}


[TRAIN 050]: 100%|██████████| 33/33 [00:00<00:00, 1352.90it/s, batch_loss=0.4787763]


[VALID 050]: loss=0.7463776469230652, metrics={'mse': 0.7463776469230652}


[TRAIN 051]: 100%|██████████| 33/33 [00:00<00:00, 1280.53it/s, batch_loss=0.6046458]


[VALID 051]: loss=0.7534943222999573, metrics={'mse': 0.7534943222999573}


In [6]:
# Inference
#
# NOTE: this notebook has no separate held-out test split. This loader reads the
# validation split, which was evaluated every epoch during training (via
# `valid_evaluator`), so predictions here are not an unbiased test estimate.
test_loader = MiniBatchLoader(dataset_df=valid_dataset_df, batch_size=512, rngs=nnx.Rngs(0))
# setup_epoch() must be called before iterating (presumably it prepares/shuffles the batches — confirm in MiniBatchLoader).
test_loader.setup_epoch()
# Take the first mini-batch. Unlike a `for ...: break` loop, this raises if the
# loader yields nothing instead of silently reusing `Xs`/`y` from an earlier run.
Xs, y = next(iter(test_loader))
trainer.best_model(*Xs)[:10]

Array([[2.0212548],
       [2.1939273],
       [2.0397067],
       [2.655089 ],
       [1.4569345],
       [1.9170147],
       [1.8301634],
       [4.645258 ],
       [3.3913505],
       [2.4202533]], dtype=float32)