In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import mlflow
import polars as pl
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split
from torch import optim

from nn_trainer.torch.evaluator import MeanSquaredErrorEvaluator
from nn_trainer.torch.loader import MiniBatchLoader
from nn_trainer.torch.loss_fn import mean_squared_error
from nn_trainer.torch.model import RegressionMLP
from nn_trainer.torch.trainer import Trainer

In [3]:
housing = fetch_california_housing()
train_dataset_df, valid_dataset_df = train_test_split(
    pl.from_numpy(data=housing.data, schema=housing.feature_names).with_columns(  # type: ignore
        pl.Series(housing.target).alias(housing.target_names[0])  # type: ignore
    ),
    test_size=0.2,
    random_state=0,
)
train_X_df, train_y_df, valid_X_df, valid_y_df = (
    train_dataset_df.drop(housing.target_names[0]).cast(pl.Float32),  # type: ignore
    train_dataset_df.select(housing.target_names[0]).cast(pl.Float32),  # type: ignore
    valid_dataset_df.drop(housing.target_names[0]).cast(pl.Float32),  # type: ignore
    valid_dataset_df.select(housing.target_names[0]).cast(pl.Float32),  # type: ignore
)

In [4]:
model = RegressionMLP(train_dataset_df.shape[1] - 1, [10, 10], 1)
train_loader = MiniBatchLoader(X_df=train_X_df, y_df=train_y_df, batch_size=512, seed=0)
valid_evaluator = MeanSquaredErrorEvaluator(X_df=valid_X_df, y_df=valid_y_df)

### Training

In [5]:
%%time

mlflow.set_tracking_uri(uri="http://localhost:8080")
mlflow.set_experiment("REGRESSION")

with mlflow.start_run() as run:
    mlflow.log_param("my", "param")

    trainer = Trainer(
        model=model,
        optimizer=optim.Adam(model.parameters(), lr=0.001),
        train_loader=train_loader,
        loss_fn=mean_squared_error,
        valid_evaluator=valid_evaluator,
        early_stopping_patience=10,
        epoch_num=512,
        active_run=run,
    )
    trainer = trainer.fit()

[VALID 000]: loss=129.85862731933594, metrics={'MSE': 129.85862731933594}


[TRAIN 001]: 100%|██████████| 32/32 [00:00<00:00, 530.54it/s, batch_loss=4.86]


[VALID 001]: loss=5.124122619628906, metrics={'MSE': 5.124122619628906}


[TRAIN 002]: 100%|██████████| 32/32 [00:00<00:00, 546.81it/s, batch_loss=2.1]


[VALID 002]: loss=2.525139808654785, metrics={'MSE': 2.525139808654785}


[TRAIN 003]: 100%|██████████| 32/32 [00:00<00:00, 426.16it/s, batch_loss=1.79]


[VALID 003]: loss=2.1007893085479736, metrics={'MSE': 2.1007893085479736}


[TRAIN 004]: 100%|██████████| 32/32 [00:00<00:00, 573.03it/s, batch_loss=1.48]


[VALID 004]: loss=1.6646232604980469, metrics={'MSE': 1.6646232604980469}


[TRAIN 005]: 100%|██████████| 32/32 [00:00<00:00, 552.02it/s, batch_loss=1.32]


[VALID 005]: loss=1.462822675704956, metrics={'MSE': 1.462822675704956}


[TRAIN 006]: 100%|██████████| 32/32 [00:00<00:00, 607.09it/s, batch_loss=1.41]

[VALID 006]: loss=1.346858024597168, metrics={'MSE': 1.346858024597168}



[TRAIN 007]: 100%|██████████| 32/32 [00:00<00:00, 610.47it/s, batch_loss=1.22]


[VALID 007]: loss=1.2897601127624512, metrics={'MSE': 1.2897601127624512}


[TRAIN 008]: 100%|██████████| 32/32 [00:00<00:00, 631.38it/s, batch_loss=1.02]


[VALID 008]: loss=1.2558667659759521, metrics={'MSE': 1.2558667659759521}


[TRAIN 009]: 100%|██████████| 32/32 [00:00<00:00, 554.61it/s, batch_loss=1.12]


[VALID 009]: loss=1.2402173280715942, metrics={'MSE': 1.2402173280715942}


[TRAIN 010]: 100%|██████████| 32/32 [00:00<00:00, 563.95it/s, batch_loss=1.08]


[VALID 010]: loss=1.2338237762451172, metrics={'MSE': 1.2338237762451172}


[TRAIN 011]: 100%|██████████| 32/32 [00:00<00:00, 591.72it/s, batch_loss=1.27]

[VALID 011]: loss=1.2166286706924438, metrics={'MSE': 1.2166286706924438}



[TRAIN 012]: 100%|██████████| 32/32 [00:00<00:00, 563.03it/s, batch_loss=1.17]


[VALID 012]: loss=1.196032166481018, metrics={'MSE': 1.196032166481018}


[TRAIN 013]: 100%|██████████| 32/32 [00:00<00:00, 566.52it/s, batch_loss=1.26]


[VALID 013]: loss=1.1865695714950562, metrics={'MSE': 1.1865695714950562}


[TRAIN 014]: 100%|██████████| 32/32 [00:00<00:00, 605.47it/s, batch_loss=1.16]


[VALID 014]: loss=1.1720062494277954, metrics={'MSE': 1.1720062494277954}


[TRAIN 015]: 100%|██████████| 32/32 [00:00<00:00, 551.76it/s, batch_loss=1.1]


[VALID 015]: loss=1.1634902954101562, metrics={'MSE': 1.1634902954101562}


[TRAIN 016]: 100%|██████████| 32/32 [00:00<00:00, 495.46it/s, batch_loss=1.09]


[VALID 016]: loss=1.151039481163025, metrics={'MSE': 1.151039481163025}


[TRAIN 017]: 100%|██████████| 32/32 [00:00<00:00, 612.54it/s, batch_loss=1.26]


[VALID 017]: loss=1.1400619745254517, metrics={'MSE': 1.1400619745254517}


[TRAIN 018]: 100%|██████████| 32/32 [00:00<00:00, 635.27it/s, batch_loss=1.16]


[VALID 018]: loss=1.130333662033081, metrics={'MSE': 1.130333662033081}


[TRAIN 019]: 100%|██████████| 32/32 [00:00<00:00, 615.70it/s, batch_loss=1.07]


[VALID 019]: loss=1.1225303411483765, metrics={'MSE': 1.1225303411483765}


[TRAIN 020]: 100%|██████████| 32/32 [00:00<00:00, 596.66it/s, batch_loss=1.1]


[VALID 020]: loss=1.1165341138839722, metrics={'MSE': 1.1165341138839722}


[TRAIN 021]: 100%|██████████| 32/32 [00:00<00:00, 608.03it/s, batch_loss=1.02]


[VALID 021]: loss=1.1058661937713623, metrics={'MSE': 1.1058661937713623}


[TRAIN 022]: 100%|██████████| 32/32 [00:00<00:00, 621.28it/s, batch_loss=1.09]


[VALID 022]: loss=1.0929332971572876, metrics={'MSE': 1.0929332971572876}


[TRAIN 023]: 100%|██████████| 32/32 [00:00<00:00, 386.06it/s, batch_loss=1]


[VALID 023]: loss=1.0852105617523193, metrics={'MSE': 1.0852105617523193}


[TRAIN 024]: 100%|██████████| 32/32 [00:00<00:00, 534.78it/s, batch_loss=0.968]


[VALID 024]: loss=1.0739105939865112, metrics={'MSE': 1.0739105939865112}


[TRAIN 025]: 100%|██████████| 32/32 [00:00<00:00, 627.46it/s, batch_loss=0.995]


[VALID 025]: loss=1.0635515451431274, metrics={'MSE': 1.0635515451431274}


[TRAIN 026]: 100%|██████████| 32/32 [00:00<00:00, 570.55it/s, batch_loss=1.01]


[VALID 026]: loss=1.0524530410766602, metrics={'MSE': 1.0524530410766602}


[TRAIN 027]: 100%|██████████| 32/32 [00:00<00:00, 526.98it/s, batch_loss=1.06]

[VALID 027]: loss=1.045375108718872, metrics={'MSE': 1.045375108718872}



[TRAIN 028]: 100%|██████████| 32/32 [00:00<00:00, 611.01it/s, batch_loss=0.97]


[VALID 028]: loss=1.031516194343567, metrics={'MSE': 1.031516194343567}


[TRAIN 029]: 100%|██████████| 32/32 [00:00<00:00, 595.91it/s, batch_loss=0.97]

[VALID 029]: loss=1.0206650495529175, metrics={'MSE': 1.0206650495529175}



[TRAIN 030]: 100%|██████████| 32/32 [00:00<00:00, 583.95it/s, batch_loss=1.01]


[VALID 030]: loss=1.0075750350952148, metrics={'MSE': 1.0075750350952148}


[TRAIN 031]: 100%|██████████| 32/32 [00:00<00:00, 584.29it/s, batch_loss=1.03]


[VALID 031]: loss=0.9997551441192627, metrics={'MSE': 0.9997551441192627}


[TRAIN 032]: 100%|██████████| 32/32 [00:00<00:00, 500.39it/s, batch_loss=0.93]

[VALID 032]: loss=0.9888955950737, metrics={'MSE': 0.9888955950737}



[TRAIN 033]: 100%|██████████| 32/32 [00:00<00:00, 558.64it/s, batch_loss=1.06]


[VALID 033]: loss=0.9740414619445801, metrics={'MSE': 0.9740414619445801}


[TRAIN 034]: 100%|██████████| 32/32 [00:00<00:00, 546.84it/s, batch_loss=1.04]


[VALID 034]: loss=0.9611275792121887, metrics={'MSE': 0.9611275792121887}


[TRAIN 035]: 100%|██████████| 32/32 [00:00<00:00, 490.27it/s, batch_loss=1.05]


[VALID 035]: loss=0.9585281610488892, metrics={'MSE': 0.9585281610488892}


[TRAIN 036]: 100%|██████████| 32/32 [00:00<00:00, 691.54it/s, batch_loss=0.897]


[VALID 036]: loss=0.9982350468635559, metrics={'MSE': 0.9982350468635559}


[TRAIN 037]: 100%|██████████| 32/32 [00:00<00:00, 613.36it/s, batch_loss=0.849]


[VALID 037]: loss=1.0247745513916016, metrics={'MSE': 1.0247745513916016}


[TRAIN 038]: 100%|██████████| 32/32 [00:00<00:00, 643.50it/s, batch_loss=0.734]


[VALID 038]: loss=1.2138173580169678, metrics={'MSE': 1.2138173580169678}


[TRAIN 039]: 100%|██████████| 32/32 [00:00<00:00, 609.52it/s, batch_loss=0.7]


[VALID 039]: loss=1.1335588693618774, metrics={'MSE': 1.1335588693618774}


[TRAIN 040]: 100%|██████████| 32/32 [00:00<00:00, 275.58it/s, batch_loss=0.705]


[VALID 040]: loss=1.0485162734985352, metrics={'MSE': 1.0485162734985352}


[TRAIN 041]: 100%|██████████| 32/32 [00:00<00:00, 499.94it/s, batch_loss=0.683]


[VALID 041]: loss=0.9254560470581055, metrics={'MSE': 0.9254560470581055}


[TRAIN 042]: 100%|██████████| 32/32 [00:00<00:00, 542.80it/s, batch_loss=0.622]


[VALID 042]: loss=0.8767102956771851, metrics={'MSE': 0.8767102956771851}


[TRAIN 043]: 100%|██████████| 32/32 [00:00<00:00, 430.27it/s, batch_loss=0.693]


[VALID 043]: loss=0.8360200524330139, metrics={'MSE': 0.8360200524330139}


[TRAIN 044]: 100%|██████████| 32/32 [00:00<00:00, 364.72it/s, batch_loss=0.566]


[VALID 044]: loss=0.8432755470275879, metrics={'MSE': 0.8432755470275879}


[TRAIN 045]: 100%|██████████| 32/32 [00:00<00:00, 499.26it/s, batch_loss=0.641]


[VALID 045]: loss=0.734619140625, metrics={'MSE': 0.734619140625}


[TRAIN 046]: 100%|██████████| 32/32 [00:00<00:00, 576.31it/s, batch_loss=0.601]


[VALID 046]: loss=0.714637279510498, metrics={'MSE': 0.714637279510498}


[TRAIN 047]: 100%|██████████| 32/32 [00:00<00:00, 495.33it/s, batch_loss=0.506]


[VALID 047]: loss=0.7006469368934631, metrics={'MSE': 0.7006469368934631}


[TRAIN 048]: 100%|██████████| 32/32 [00:00<00:00, 611.07it/s, batch_loss=0.736]


[VALID 048]: loss=0.6885693073272705, metrics={'MSE': 0.6885693073272705}


[TRAIN 049]: 100%|██████████| 32/32 [00:00<00:00, 593.51it/s, batch_loss=0.588]


[VALID 049]: loss=0.6953480839729309, metrics={'MSE': 0.6953480839729309}


[TRAIN 050]: 100%|██████████| 32/32 [00:00<00:00, 642.61it/s, batch_loss=0.576]


[VALID 050]: loss=0.6677286028862, metrics={'MSE': 0.6677286028862}


[TRAIN 051]: 100%|██████████| 32/32 [00:00<00:00, 608.77it/s, batch_loss=0.632]


[VALID 051]: loss=0.6665276288986206, metrics={'MSE': 0.6665276288986206}


[TRAIN 052]: 100%|██████████| 32/32 [00:00<00:00, 615.57it/s, batch_loss=0.696]


[VALID 052]: loss=0.6381672620773315, metrics={'MSE': 0.6381672620773315}


[TRAIN 053]: 100%|██████████| 32/32 [00:00<00:00, 632.20it/s, batch_loss=0.578]


[VALID 053]: loss=0.6273580193519592, metrics={'MSE': 0.6273580193519592}


[TRAIN 054]: 100%|██████████| 32/32 [00:00<00:00, 588.19it/s, batch_loss=0.629]


[VALID 054]: loss=0.6467772722244263, metrics={'MSE': 0.6467772722244263}


[TRAIN 055]: 100%|██████████| 32/32 [00:00<00:00, 560.17it/s, batch_loss=0.606]


[VALID 055]: loss=0.6193878054618835, metrics={'MSE': 0.6193878054618835}


[TRAIN 056]: 100%|██████████| 32/32 [00:00<00:00, 618.70it/s, batch_loss=0.649]


[VALID 056]: loss=0.6247596740722656, metrics={'MSE': 0.6247596740722656}


[TRAIN 057]: 100%|██████████| 32/32 [00:00<00:00, 617.78it/s, batch_loss=0.583]


[VALID 057]: loss=0.6166625618934631, metrics={'MSE': 0.6166625618934631}


[TRAIN 058]: 100%|██████████| 32/32 [00:00<00:00, 675.19it/s, batch_loss=0.568]


[VALID 058]: loss=0.6080231666564941, metrics={'MSE': 0.6080231666564941}


[TRAIN 059]: 100%|██████████| 32/32 [00:00<00:00, 481.49it/s, batch_loss=0.603]


[VALID 059]: loss=0.6428979635238647, metrics={'MSE': 0.6428979635238647}


[TRAIN 060]: 100%|██████████| 32/32 [00:00<00:00, 412.47it/s, batch_loss=0.51]


[VALID 060]: loss=0.6154789924621582, metrics={'MSE': 0.6154789924621582}


[TRAIN 061]: 100%|██████████| 32/32 [00:00<00:00, 491.02it/s, batch_loss=0.627]


[VALID 061]: loss=0.6213372349739075, metrics={'MSE': 0.6213372349739075}


[TRAIN 062]: 100%|██████████| 32/32 [00:00<00:00, 624.21it/s, batch_loss=0.585]


[VALID 062]: loss=0.5988492965698242, metrics={'MSE': 0.5988492965698242}


[TRAIN 063]: 100%|██████████| 32/32 [00:00<00:00, 601.73it/s, batch_loss=0.735]

[VALID 063]: loss=0.632836103439331, metrics={'MSE': 0.632836103439331}



[TRAIN 064]: 100%|██████████| 32/32 [00:00<00:00, 596.94it/s, batch_loss=0.565]


[VALID 064]: loss=0.5982796549797058, metrics={'MSE': 0.5982796549797058}


[TRAIN 065]: 100%|██████████| 32/32 [00:00<00:00, 605.03it/s, batch_loss=0.579]


[VALID 065]: loss=0.5936631560325623, metrics={'MSE': 0.5936631560325623}


[TRAIN 066]: 100%|██████████| 32/32 [00:00<00:00, 582.19it/s, batch_loss=0.608]


[VALID 066]: loss=0.5942617058753967, metrics={'MSE': 0.5942617058753967}


[TRAIN 067]: 100%|██████████| 32/32 [00:00<00:00, 600.49it/s, batch_loss=0.577]


[VALID 067]: loss=0.5928099751472473, metrics={'MSE': 0.5928099751472473}


[TRAIN 068]: 100%|██████████| 32/32 [00:00<00:00, 613.11it/s, batch_loss=0.497]


[VALID 068]: loss=0.5927556753158569, metrics={'MSE': 0.5927556753158569}


[TRAIN 069]: 100%|██████████| 32/32 [00:00<00:00, 569.43it/s, batch_loss=0.617]


[VALID 069]: loss=0.58624666929245, metrics={'MSE': 0.58624666929245}


[TRAIN 070]: 100%|██████████| 32/32 [00:00<00:00, 603.36it/s, batch_loss=0.623]


[VALID 070]: loss=0.5924893617630005, metrics={'MSE': 0.5924893617630005}


[TRAIN 071]: 100%|██████████| 32/32 [00:00<00:00, 342.47it/s, batch_loss=0.622]


[VALID 071]: loss=0.5863922834396362, metrics={'MSE': 0.5863922834396362}


[TRAIN 072]: 100%|██████████| 32/32 [00:00<00:00, 562.65it/s, batch_loss=0.548]


[VALID 072]: loss=0.6032487750053406, metrics={'MSE': 0.6032487750053406}


[TRAIN 073]: 100%|██████████| 32/32 [00:00<00:00, 635.52it/s, batch_loss=0.568]


[VALID 073]: loss=0.5860846042633057, metrics={'MSE': 0.5860846042633057}


[TRAIN 074]: 100%|██████████| 32/32 [00:00<00:00, 387.10it/s, batch_loss=0.542]


[VALID 074]: loss=0.5844674110412598, metrics={'MSE': 0.5844674110412598}


[TRAIN 075]: 100%|██████████| 32/32 [00:00<00:00, 625.37it/s, batch_loss=0.543]


[VALID 075]: loss=0.5975632667541504, metrics={'MSE': 0.5975632667541504}


[TRAIN 076]: 100%|██████████| 32/32 [00:00<00:00, 655.58it/s, batch_loss=0.559]


[VALID 076]: loss=0.5839272737503052, metrics={'MSE': 0.5839272737503052}


[TRAIN 077]: 100%|██████████| 32/32 [00:00<00:00, 627.44it/s, batch_loss=0.514]


[VALID 077]: loss=0.6000392436981201, metrics={'MSE': 0.6000392436981201}


[TRAIN 078]: 100%|██████████| 32/32 [00:00<00:00, 621.44it/s, batch_loss=0.629]


[VALID 078]: loss=0.5907929539680481, metrics={'MSE': 0.5907929539680481}


[TRAIN 079]: 100%|██████████| 32/32 [00:00<00:00, 616.01it/s, batch_loss=0.507]


[VALID 079]: loss=0.5810020565986633, metrics={'MSE': 0.5810020565986633}


[TRAIN 080]: 100%|██████████| 32/32 [00:00<00:00, 599.84it/s, batch_loss=0.562]


[VALID 080]: loss=0.5804796814918518, metrics={'MSE': 0.5804796814918518}


[TRAIN 081]: 100%|██████████| 32/32 [00:00<00:00, 571.44it/s, batch_loss=0.627]


[VALID 081]: loss=0.5851203203201294, metrics={'MSE': 0.5851203203201294}


[TRAIN 082]: 100%|██████████| 32/32 [00:00<00:00, 616.04it/s, batch_loss=0.538]


[VALID 082]: loss=0.5815272331237793, metrics={'MSE': 0.5815272331237793}


[TRAIN 083]: 100%|██████████| 32/32 [00:00<00:00, 580.48it/s, batch_loss=0.573]


[VALID 083]: loss=0.5788565278053284, metrics={'MSE': 0.5788565278053284}


[TRAIN 084]: 100%|██████████| 32/32 [00:00<00:00, 583.05it/s, batch_loss=0.531]


[VALID 084]: loss=0.6020616888999939, metrics={'MSE': 0.6020616888999939}


[TRAIN 085]: 100%|██████████| 32/32 [00:00<00:00, 623.40it/s, batch_loss=0.722]


[VALID 085]: loss=0.5798342227935791, metrics={'MSE': 0.5798342227935791}


[TRAIN 086]: 100%|██████████| 32/32 [00:00<00:00, 622.31it/s, batch_loss=0.555]


[VALID 086]: loss=0.5784165263175964, metrics={'MSE': 0.5784165263175964}


[TRAIN 087]: 100%|██████████| 32/32 [00:00<00:00, 354.90it/s, batch_loss=0.621]


[VALID 087]: loss=0.5847401022911072, metrics={'MSE': 0.5847401022911072}


[TRAIN 088]: 100%|██████████| 32/32 [00:00<00:00, 584.73it/s, batch_loss=0.607]


[VALID 088]: loss=0.5777101516723633, metrics={'MSE': 0.5777101516723633}


[TRAIN 089]: 100%|██████████| 32/32 [00:00<00:00, 619.30it/s, batch_loss=0.544]


[VALID 089]: loss=0.5795391201972961, metrics={'MSE': 0.5795391201972961}


[TRAIN 090]: 100%|██████████| 32/32 [00:00<00:00, 627.26it/s, batch_loss=0.529]


[VALID 090]: loss=0.5914598107337952, metrics={'MSE': 0.5914598107337952}


[TRAIN 091]: 100%|██████████| 32/32 [00:00<00:00, 667.99it/s, batch_loss=0.544]


[VALID 091]: loss=0.5793656706809998, metrics={'MSE': 0.5793656706809998}


[TRAIN 092]: 100%|██████████| 32/32 [00:00<00:00, 564.12it/s, batch_loss=0.535]


[VALID 092]: loss=0.5788114666938782, metrics={'MSE': 0.5788114666938782}


[TRAIN 093]: 100%|██████████| 32/32 [00:00<00:00, 624.32it/s, batch_loss=0.543]


[VALID 093]: loss=0.5773330330848694, metrics={'MSE': 0.5773330330848694}


[TRAIN 094]: 100%|██████████| 32/32 [00:00<00:00, 600.46it/s, batch_loss=0.507]


[VALID 094]: loss=0.6135315299034119, metrics={'MSE': 0.6135315299034119}


[TRAIN 095]: 100%|██████████| 32/32 [00:00<00:00, 674.06it/s, batch_loss=0.515]


[VALID 095]: loss=0.600973904132843, metrics={'MSE': 0.600973904132843}


[TRAIN 096]: 100%|██████████| 32/32 [00:00<00:00, 641.41it/s, batch_loss=0.587]


[VALID 096]: loss=0.5985718965530396, metrics={'MSE': 0.5985718965530396}


[TRAIN 097]: 100%|██████████| 32/32 [00:00<00:00, 562.40it/s, batch_loss=0.535]


[VALID 097]: loss=0.5769162774085999, metrics={'MSE': 0.5769162774085999}


[TRAIN 098]: 100%|██████████| 32/32 [00:00<00:00, 612.41it/s, batch_loss=0.676]


[VALID 098]: loss=0.5797731280326843, metrics={'MSE': 0.5797731280326843}


[TRAIN 099]: 100%|██████████| 32/32 [00:00<00:00, 661.17it/s, batch_loss=0.581]


[VALID 099]: loss=0.5729255676269531, metrics={'MSE': 0.5729255676269531}


[TRAIN 100]: 100%|██████████| 32/32 [00:00<00:00, 603.67it/s, batch_loss=0.595]


[VALID 100]: loss=0.5795766115188599, metrics={'MSE': 0.5795766115188599}


[TRAIN 101]: 100%|██████████| 32/32 [00:00<00:00, 535.92it/s, batch_loss=0.467]


[VALID 101]: loss=0.5736844539642334, metrics={'MSE': 0.5736844539642334}


[TRAIN 102]: 100%|██████████| 32/32 [00:00<00:00, 339.39it/s, batch_loss=0.581]


[VALID 102]: loss=0.5762389302253723, metrics={'MSE': 0.5762389302253723}


[TRAIN 103]: 100%|██████████| 32/32 [00:00<00:00, 563.87it/s, batch_loss=0.522]


[VALID 103]: loss=0.5740442872047424, metrics={'MSE': 0.5740442872047424}


[TRAIN 104]: 100%|██████████| 32/32 [00:00<00:00, 622.87it/s, batch_loss=0.569]


[VALID 104]: loss=0.5758420825004578, metrics={'MSE': 0.5758420825004578}


[TRAIN 105]: 100%|██████████| 32/32 [00:00<00:00, 646.99it/s, batch_loss=0.52]


[VALID 105]: loss=0.609481930732727, metrics={'MSE': 0.609481930732727}


[TRAIN 106]: 100%|██████████| 32/32 [00:00<00:00, 594.96it/s, batch_loss=0.518]


[VALID 106]: loss=0.5769953727722168, metrics={'MSE': 0.5769953727722168}


[TRAIN 107]: 100%|██████████| 32/32 [00:00<00:00, 612.65it/s, batch_loss=0.59]


[VALID 107]: loss=0.575471818447113, metrics={'MSE': 0.575471818447113}


[TRAIN 108]: 100%|██████████| 32/32 [00:00<00:00, 609.86it/s, batch_loss=0.563]


[VALID 108]: loss=0.5723907351493835, metrics={'MSE': 0.5723907351493835}


[TRAIN 109]: 100%|██████████| 32/32 [00:00<00:00, 614.14it/s, batch_loss=0.642]


[VALID 109]: loss=0.5749688148498535, metrics={'MSE': 0.5749688148498535}


[TRAIN 110]: 100%|██████████| 32/32 [00:00<00:00, 629.52it/s, batch_loss=0.561]


[VALID 110]: loss=0.5723550915718079, metrics={'MSE': 0.5723550915718079}


[TRAIN 111]: 100%|██████████| 32/32 [00:00<00:00, 585.50it/s, batch_loss=0.5]


[VALID 111]: loss=0.5837398767471313, metrics={'MSE': 0.5837398767471313}


[TRAIN 112]: 100%|██████████| 32/32 [00:00<00:00, 646.65it/s, batch_loss=0.605]


[VALID 112]: loss=0.5722090005874634, metrics={'MSE': 0.5722090005874634}


[TRAIN 113]: 100%|██████████| 32/32 [00:00<00:00, 588.56it/s, batch_loss=0.674]


[VALID 113]: loss=0.6095923185348511, metrics={'MSE': 0.6095923185348511}


[TRAIN 114]: 100%|██████████| 32/32 [00:00<00:00, 611.03it/s, batch_loss=0.543]


[VALID 114]: loss=0.5703207850456238, metrics={'MSE': 0.5703207850456238}


[TRAIN 115]: 100%|██████████| 32/32 [00:00<00:00, 579.68it/s, batch_loss=0.56]


[VALID 115]: loss=0.5699183940887451, metrics={'MSE': 0.5699183940887451}


[TRAIN 116]: 100%|██████████| 32/32 [00:00<00:00, 616.24it/s, batch_loss=0.643]


[VALID 116]: loss=0.5725022554397583, metrics={'MSE': 0.5725022554397583}


[TRAIN 117]: 100%|██████████| 32/32 [00:00<00:00, 355.33it/s, batch_loss=0.515]


[VALID 117]: loss=0.6172506809234619, metrics={'MSE': 0.6172506809234619}


[TRAIN 118]: 100%|██████████| 32/32 [00:00<00:00, 561.29it/s, batch_loss=0.646]


[VALID 118]: loss=0.5678952932357788, metrics={'MSE': 0.5678952932357788}


[TRAIN 119]: 100%|██████████| 32/32 [00:00<00:00, 617.67it/s, batch_loss=0.478]


[VALID 119]: loss=0.6343254446983337, metrics={'MSE': 0.6343254446983337}


[TRAIN 120]: 100%|██████████| 32/32 [00:00<00:00, 581.09it/s, batch_loss=0.646]


[VALID 120]: loss=0.5696931481361389, metrics={'MSE': 0.5696931481361389}


[TRAIN 121]: 100%|██████████| 32/32 [00:00<00:00, 630.30it/s, batch_loss=0.517]


[VALID 121]: loss=0.578473687171936, metrics={'MSE': 0.578473687171936}


[TRAIN 122]: 100%|██████████| 32/32 [00:00<00:00, 639.49it/s, batch_loss=0.55]


[VALID 122]: loss=0.5672616958618164, metrics={'MSE': 0.5672616958618164}


[TRAIN 123]: 100%|██████████| 32/32 [00:00<00:00, 547.39it/s, batch_loss=0.633]


[VALID 123]: loss=0.5670028924942017, metrics={'MSE': 0.5670028924942017}


[TRAIN 124]: 100%|██████████| 32/32 [00:00<00:00, 586.57it/s, batch_loss=0.635]


[VALID 124]: loss=0.5779176354408264, metrics={'MSE': 0.5779176354408264}


[TRAIN 125]: 100%|██████████| 32/32 [00:00<00:00, 630.36it/s, batch_loss=0.583]


[VALID 125]: loss=0.5729992389678955, metrics={'MSE': 0.5729992389678955}


[TRAIN 126]: 100%|██████████| 32/32 [00:00<00:00, 591.23it/s, batch_loss=0.596]


[VALID 126]: loss=0.564089834690094, metrics={'MSE': 0.564089834690094}


[TRAIN 127]: 100%|██████████| 32/32 [00:00<00:00, 585.69it/s, batch_loss=0.583]


[VALID 127]: loss=0.6025799512863159, metrics={'MSE': 0.6025799512863159}


[TRAIN 128]: 100%|██████████| 32/32 [00:00<00:00, 594.26it/s, batch_loss=0.513]


[VALID 128]: loss=0.5664742588996887, metrics={'MSE': 0.5664742588996887}


[TRAIN 129]: 100%|██████████| 32/32 [00:00<00:00, 573.93it/s, batch_loss=0.672]


[VALID 129]: loss=0.5706853270530701, metrics={'MSE': 0.5706853270530701}


[TRAIN 130]: 100%|██████████| 32/32 [00:00<00:00, 620.91it/s, batch_loss=0.628]


[VALID 130]: loss=0.6043971180915833, metrics={'MSE': 0.6043971180915833}


[TRAIN 131]: 100%|██████████| 32/32 [00:00<00:00, 278.74it/s, batch_loss=0.495]


[VALID 131]: loss=0.6030324101448059, metrics={'MSE': 0.6030324101448059}


[TRAIN 132]: 100%|██████████| 32/32 [00:00<00:00, 585.45it/s, batch_loss=0.496]


[VALID 132]: loss=0.598459005355835, metrics={'MSE': 0.598459005355835}


[TRAIN 133]: 100%|██████████| 32/32 [00:00<00:00, 589.75it/s, batch_loss=0.586]


[VALID 133]: loss=0.5661537051200867, metrics={'MSE': 0.5661537051200867}


[TRAIN 134]: 100%|██████████| 32/32 [00:00<00:00, 630.88it/s, batch_loss=0.552]


[VALID 134]: loss=0.5835055708885193, metrics={'MSE': 0.5835055708885193}


[TRAIN 135]: 100%|██████████| 32/32 [00:00<00:00, 623.87it/s, batch_loss=0.623]


[VALID 135]: loss=0.6151082515716553, metrics={'MSE': 0.6151082515716553}


[TRAIN 136]: 100%|██████████| 32/32 [00:00<00:00, 627.13it/s, batch_loss=0.555]


[VALID 136]: loss=0.5626857876777649, metrics={'MSE': 0.5626857876777649}


[TRAIN 137]: 100%|██████████| 32/32 [00:00<00:00, 596.13it/s, batch_loss=0.517]


[VALID 137]: loss=0.5643937587738037, metrics={'MSE': 0.5643937587738037}


[TRAIN 138]: 100%|██████████| 32/32 [00:00<00:00, 601.98it/s, batch_loss=0.602]


[VALID 138]: loss=0.5825226902961731, metrics={'MSE': 0.5825226902961731}


[TRAIN 139]: 100%|██████████| 32/32 [00:00<00:00, 648.34it/s, batch_loss=0.588]


[VALID 139]: loss=0.5700984001159668, metrics={'MSE': 0.5700984001159668}


[TRAIN 140]: 100%|██████████| 32/32 [00:00<00:00, 661.88it/s, batch_loss=0.52]


[VALID 140]: loss=0.5667704343795776, metrics={'MSE': 0.5667704343795776}


[TRAIN 141]: 100%|██████████| 32/32 [00:00<00:00, 606.89it/s, batch_loss=0.571]


[VALID 141]: loss=0.5608839988708496, metrics={'MSE': 0.5608839988708496}


[TRAIN 142]: 100%|██████████| 32/32 [00:00<00:00, 614.00it/s, batch_loss=0.6]


[VALID 142]: loss=0.5754746198654175, metrics={'MSE': 0.5754746198654175}


[TRAIN 143]: 100%|██████████| 32/32 [00:00<00:00, 571.54it/s, batch_loss=0.564]


[VALID 143]: loss=0.5586804151535034, metrics={'MSE': 0.5586804151535034}


[TRAIN 144]: 100%|██████████| 32/32 [00:00<00:00, 605.20it/s, batch_loss=0.523]


[VALID 144]: loss=0.5871487259864807, metrics={'MSE': 0.5871487259864807}


[TRAIN 145]: 100%|██████████| 32/32 [00:00<00:00, 636.38it/s, batch_loss=0.624]


[VALID 145]: loss=0.5620145797729492, metrics={'MSE': 0.5620145797729492}


[TRAIN 146]: 100%|██████████| 32/32 [00:00<00:00, 372.94it/s, batch_loss=0.619]


[VALID 146]: loss=0.5711836218833923, metrics={'MSE': 0.5711836218833923}


[TRAIN 147]: 100%|██████████| 32/32 [00:00<00:00, 660.83it/s, batch_loss=0.513]


[VALID 147]: loss=0.5619962811470032, metrics={'MSE': 0.5619962811470032}


[TRAIN 148]: 100%|██████████| 32/32 [00:00<00:00, 584.80it/s, batch_loss=0.557]


[VALID 148]: loss=0.5752579569816589, metrics={'MSE': 0.5752579569816589}


[TRAIN 149]: 100%|██████████| 32/32 [00:00<00:00, 671.93it/s, batch_loss=0.559]


[VALID 149]: loss=0.5583211779594421, metrics={'MSE': 0.5583211779594421}


[TRAIN 150]: 100%|██████████| 32/32 [00:00<00:00, 626.76it/s, batch_loss=0.529]


[VALID 150]: loss=0.5770490765571594, metrics={'MSE': 0.5770490765571594}


[TRAIN 151]: 100%|██████████| 32/32 [00:00<00:00, 613.01it/s, batch_loss=0.532]


[VALID 151]: loss=0.5597124099731445, metrics={'MSE': 0.5597124099731445}


[TRAIN 152]: 100%|██████████| 32/32 [00:00<00:00, 611.62it/s, batch_loss=0.599]


[VALID 152]: loss=0.5622738599777222, metrics={'MSE': 0.5622738599777222}


[TRAIN 153]: 100%|██████████| 32/32 [00:00<00:00, 631.69it/s, batch_loss=0.547]


[VALID 153]: loss=0.5571503639221191, metrics={'MSE': 0.5571503639221191}


[TRAIN 154]: 100%|██████████| 32/32 [00:00<00:00, 604.53it/s, batch_loss=0.566]


[VALID 154]: loss=0.565450131893158, metrics={'MSE': 0.565450131893158}


[TRAIN 155]: 100%|██████████| 32/32 [00:00<00:00, 612.35it/s, batch_loss=0.666]


[VALID 155]: loss=0.5702102184295654, metrics={'MSE': 0.5702102184295654}


[TRAIN 156]: 100%|██████████| 32/32 [00:00<00:00, 635.10it/s, batch_loss=0.623]


[VALID 156]: loss=0.5748679637908936, metrics={'MSE': 0.5748679637908936}


[TRAIN 157]: 100%|██████████| 32/32 [00:00<00:00, 646.17it/s, batch_loss=0.456]


[VALID 157]: loss=0.6073436141014099, metrics={'MSE': 0.6073436141014099}


[TRAIN 158]: 100%|██████████| 32/32 [00:00<00:00, 580.58it/s, batch_loss=0.643]


[VALID 158]: loss=0.5961087942123413, metrics={'MSE': 0.5961087942123413}


[TRAIN 159]: 100%|██████████| 32/32 [00:00<00:00, 658.33it/s, batch_loss=0.543]


[VALID 159]: loss=0.5579560995101929, metrics={'MSE': 0.5579560995101929}


[TRAIN 160]: 100%|██████████| 32/32 [00:00<00:00, 361.79it/s, batch_loss=0.448]


[VALID 160]: loss=0.5576397776603699, metrics={'MSE': 0.5576397776603699}


[TRAIN 161]: 100%|██████████| 32/32 [00:00<00:00, 621.88it/s, batch_loss=0.519]


[VALID 161]: loss=0.5619193315505981, metrics={'MSE': 0.5619193315505981}


[TRAIN 162]: 100%|██████████| 32/32 [00:00<00:00, 565.33it/s, batch_loss=0.593]


[VALID 162]: loss=0.572329044342041, metrics={'MSE': 0.572329044342041}


[TRAIN 163]: 100%|██████████| 32/32 [00:00<00:00, 599.76it/s, batch_loss=0.573]


[VALID 163]: loss=0.5595735311508179, metrics={'MSE': 0.5595735311508179}
🏃 View run delicate-shark-591 at: http://localhost:8080/#/experiments/663538387065204330/runs/a22eab205e624575bda6f31b4ad71e79
🧪 View experiment at: http://localhost:8080/#/experiments/663538387065204330
CPU times: user 27.3 s, sys: 2.24 s, total: 29.5 s
Wall time: 13.8 s


### Inference

In [None]:
valid_loader = MiniBatchLoader(X_df=train_X_df, y_df=train_y_df, batch_size=512, seed=0)
valid_loader.setup_epoch()
for Xs, y in valid_loader:
    break
trainer.best_model(*Xs)[:10]

tensor([[2.0524],
        [2.1204],
        [1.2183],
        [2.6473],
        [2.3307],
        [2.4153],
        [1.8353],
        [1.3988],
        [3.3848],
        [1.8919]], grad_fn=<SliceBackward0>)