In [1]:
import random

import numpy as np
import torch
import torch.optim

from nni.nas.experiment import NasExperiment
from nni.nas.space import model_context
from nni.nas.strategy import DARTS as DartsStrategy

from dataset.regression import fetch_data
from evaluators.regression import RegressionEvaluator
from models.mlp import MLP

In [2]:
# Seed every global RNG this notebook's stack may draw from, using a single
# constant, so each fresh-kernel run starts from the same RNG state.
SEED = 0
random.seed(SEED)        # Python's stdlib generator
np.random.seed(SEED)     # NumPy's legacy global generator
torch.manual_seed(SEED)  # PyTorch default generator (returned, hence the repr shown below)

<torch._C.Generator at 0x7ddcd81531d0>

# Fetch dataset loaders

In [3]:
# Loader settings, named so they are easy to find and tune.
BATCH_SIZE = 256
NUM_WORKERS = 4

# `loaders` is indexed by split name further down; `task_config` supplies the
# input/output feature counts used when building the MLP.
task_config, loaders = fetch_data(batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

In [4]:
# Report how many samples the dataset behind each split's loader holds.
for split_name, loader in loaders.items():
    n_samples = len(loader.dataset)
    print(split_name, 'dataset size:', n_samples)

train dataset size: 13209
val dataset size: 3303
test dataset size: 4128


# Architecture search (DARTS)

In [5]:
# Evaluator handed to the NAS experiment below: AdamW on the train split,
# validated on the val split, capped at 100 epochs.
search_hparams = dict(
    optimizer=torch.optim.AdamW,
    learning_rate=3e-4,
    weight_decay=1e-5,
    max_epochs=100,
)
evaluator = RegressionEvaluator(
    train_dataloaders=loaders['train'],
    val_dataloaders=loaders['val'],
    **search_hparams,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
# DARTS is the exploration strategy for the search.
strategy = DartsStrategy()

# The MLP built outside any model_context serves as the model space to search.
model_space = MLP(
    d_in=task_config.in_features,
    d_out=task_config.out_features,
    dropout=0.1,
)

# The experiment ties space, evaluator and strategy together; run() blocks
# until the search is done and its return value is shown as the cell output.
experiment = NasExperiment(model_space, evaluator, strategy)
experiment.run()

[2024-05-14 01:46:00] [32mConfig is not provided. Will try to infer.[0m
[2024-05-14 01:46:00] [32mStrategy is found to be a one-shot strategy. Setting execution engine to "sequential" and format to "raw".[0m
[2024-05-14 01:46:01] [32mCheckpoint saved to /home/sisha/nni-experiments/9mf8us0g/checkpoint.[0m
[2024-05-14 01:46:01] [32mExperiment initialized successfully. Starting exploration strategy...[0m


You are using a CUDA device ('NVIDIA GeForce RTX 4060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type             | Params
-----------------------------------------------------
0 | training_module | RegressionModule | 2.6 M 
-----------------------------------------------------
2.6 M     Trainable params
0         Non-trainable params
2.6 M     Total params
10.527    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


[2024-05-14 01:50:10] [32mWaiting for models submitted to engine to finish...[0m
[2024-05-14 01:50:10] [32mExperiment is completed.[0m


/home/sisha/miniconda3/envs/test_env/lib/python3.12/site-packages/nni/nas/evaluator/pytorch/lightning.py:311: Multiple metrics without "default" is not supported by current framework.


True

# Train final model

In [7]:
# Ask the finished experiment for its top models in dict form and keep the
# first entry as the architecture to retrain.
top_architectures = experiment.export_top_models(formatter='dict')
exported_arch = top_architectures[0]
print(exported_arch)

{'MLP/d_block': 32, 'MLP/in_act': 1, 'MLP/n_blocks': 2, 'MLP/blocks_act': 0}


In [8]:
# Same constructor arguments as the search space; only the surrounding
# model_context differs.
final_model_kwargs = dict(
    d_in=task_config.in_features,
    d_out=task_config.out_features,
    dropout=0.1,
)

# Instantiate under the exported architecture so the MLP is built with those
# selected choices (NOTE(review): relies on NNI's model_context semantics).
with model_context(exported_arch):
    final_model = MLP(**final_model_kwargs)

In [9]:
# Evaluator for training the exported architecture: same optimizer settings
# and data splits as before, with a 150-epoch budget. Rebinds `evaluator`.
final_hparams = dict(
    optimizer=torch.optim.AdamW,
    learning_rate=3e-4,
    weight_decay=1e-5,
    max_epochs=150,
)
evaluator = RegressionEvaluator(
    train_dataloaders=loaders['train'],
    val_dataloaders=loaders['val'],
    **final_hparams,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [10]:
# Train the exported architecture with the 150-epoch evaluator; validation
# metrics (mse, rmse) are logged as intermediate results in the output below.
evaluator.fit(final_model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type       | Params
-----------------------------------------
0 | criterion | MSELoss    | 0     
1 | metrics   | ModuleDict | 0     
2 | _model    | MLP        | 2.4 K 
-----------------------------------------
2.4 K     Trainable params
0         Non-trainable params
2.4 K     Total params
0.010     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:04] [32mIntermediate result: {'mse': 0.9081547856330872, 'rmse': 0.951565682888031, 'default': 0.9081547856330872}  (Index 0)[0m


/home/sisha/miniconda3/envs/test_env/lib/python3.12/site-packages/nni/nas/evaluator/pytorch/lightning.py:311: Multiple metrics without "default" is not supported by current framework.


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:05] [32mIntermediate result: {'mse': 0.6922996044158936, 'rmse': 0.8305162191390991, 'default': 0.6922996044158936}  (Index 1)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:05] [32mIntermediate result: {'mse': 0.41494056582450867, 'rmse': 0.6422576308250427, 'default': 0.41494056582450867}  (Index 2)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:05] [32mIntermediate result: {'mse': 0.3316473960876465, 'rmse': 0.5735471844673157, 'default': 0.3316473960876465}  (Index 3)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:06] [32mIntermediate result: {'mse': 0.3063434064388275, 'rmse': 0.5510765910148621, 'default': 0.3063434064388275}  (Index 4)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:06] [32mIntermediate result: {'mse': 0.2948641777038574, 'rmse': 0.5404336452484131, 'default': 0.2948641777038574}  (Index 5)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:07] [32mIntermediate result: {'mse': 0.28902557492256165, 'rmse': 0.5348267555236816, 'default': 0.28902557492256165}  (Index 6)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:07] [32mIntermediate result: {'mse': 0.2846122980117798, 'rmse': 0.5306180715560913, 'default': 0.2846122980117798}  (Index 7)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:08] [32mIntermediate result: {'mse': 0.28214046359062195, 'rmse': 0.5283733606338501, 'default': 0.28214046359062195}  (Index 8)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:08] [32mIntermediate result: {'mse': 0.27961304783821106, 'rmse': 0.5259442925453186, 'default': 0.27961304783821106}  (Index 9)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:09] [32mIntermediate result: {'mse': 0.27673083543777466, 'rmse': 0.5230873227119446, 'default': 0.27673083543777466}  (Index 10)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:09] [32mIntermediate result: {'mse': 0.2757616341114044, 'rmse': 0.5223219394683838, 'default': 0.2757616341114044}  (Index 11)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:10] [32mIntermediate result: {'mse': 0.2726818025112152, 'rmse': 0.5192016959190369, 'default': 0.2726818025112152}  (Index 12)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:10] [32mIntermediate result: {'mse': 0.2704954445362091, 'rmse': 0.5171775221824646, 'default': 0.2704954445362091}  (Index 13)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:11] [32mIntermediate result: {'mse': 0.26926666498184204, 'rmse': 0.5159916877746582, 'default': 0.26926666498184204}  (Index 14)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:11] [32mIntermediate result: {'mse': 0.26767438650131226, 'rmse': 0.5144939422607422, 'default': 0.26767438650131226}  (Index 15)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:11] [32mIntermediate result: {'mse': 0.26601290702819824, 'rmse': 0.512839138507843, 'default': 0.26601290702819824}  (Index 16)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:12] [32mIntermediate result: {'mse': 0.26439422369003296, 'rmse': 0.5112988948822021, 'default': 0.26439422369003296}  (Index 17)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:12] [32mIntermediate result: {'mse': 0.26338502764701843, 'rmse': 0.5103438496589661, 'default': 0.26338502764701843}  (Index 18)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:13] [32mIntermediate result: {'mse': 0.2618167996406555, 'rmse': 0.5088082551956177, 'default': 0.2618167996406555}  (Index 19)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:13] [32mIntermediate result: {'mse': 0.2612992823123932, 'rmse': 0.5082762837409973, 'default': 0.2612992823123932}  (Index 20)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:14] [32mIntermediate result: {'mse': 0.25921356678009033, 'rmse': 0.5061895251274109, 'default': 0.25921356678009033}  (Index 21)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:14] [32mIntermediate result: {'mse': 0.2583608329296112, 'rmse': 0.5054470300674438, 'default': 0.2583608329296112}  (Index 22)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:15] [32mIntermediate result: {'mse': 0.2577168643474579, 'rmse': 0.5048337578773499, 'default': 0.2577168643474579}  (Index 23)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:15] [32mIntermediate result: {'mse': 0.25611159205436707, 'rmse': 0.5032243728637695, 'default': 0.25611159205436707}  (Index 24)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:16] [32mIntermediate result: {'mse': 0.25478070974349976, 'rmse': 0.5018427968025208, 'default': 0.25478070974349976}  (Index 25)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:16] [32mIntermediate result: {'mse': 0.25425985455513, 'rmse': 0.501356840133667, 'default': 0.25425985455513}  (Index 26)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:17] [32mIntermediate result: {'mse': 0.2535795569419861, 'rmse': 0.5007583498954773, 'default': 0.2535795569419861}  (Index 27)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:17] [32mIntermediate result: {'mse': 0.25199514627456665, 'rmse': 0.49911147356033325, 'default': 0.25199514627456665}  (Index 28)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:18] [32mIntermediate result: {'mse': 0.2513000965118408, 'rmse': 0.4984060227870941, 'default': 0.2513000965118408}  (Index 29)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:18] [32mIntermediate result: {'mse': 0.2513766288757324, 'rmse': 0.4985346496105194, 'default': 0.2513766288757324}  (Index 30)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:19] [32mIntermediate result: {'mse': 0.2503277063369751, 'rmse': 0.49755337834358215, 'default': 0.2503277063369751}  (Index 31)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:19] [32mIntermediate result: {'mse': 0.24974460899829865, 'rmse': 0.49681389331817627, 'default': 0.24974460899829865}  (Index 32)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:19] [32mIntermediate result: {'mse': 0.24868114292621613, 'rmse': 0.49584487080574036, 'default': 0.24868114292621613}  (Index 33)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:20] [32mIntermediate result: {'mse': 0.24758578836917877, 'rmse': 0.4946666359901428, 'default': 0.24758578836917877}  (Index 34)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:20] [32mIntermediate result: {'mse': 0.24725185334682465, 'rmse': 0.4943418502807617, 'default': 0.24725185334682465}  (Index 35)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:21] [32mIntermediate result: {'mse': 0.2463676780462265, 'rmse': 0.49343758821487427, 'default': 0.2463676780462265}  (Index 36)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:21] [32mIntermediate result: {'mse': 0.24531827867031097, 'rmse': 0.49239838123321533, 'default': 0.24531827867031097}  (Index 37)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:22] [32mIntermediate result: {'mse': 0.24592503905296326, 'rmse': 0.4929882884025574, 'default': 0.24592503905296326}  (Index 38)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:22] [32mIntermediate result: {'mse': 0.24456946551799774, 'rmse': 0.49164846539497375, 'default': 0.24456946551799774}  (Index 39)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:23] [32mIntermediate result: {'mse': 0.2444908320903778, 'rmse': 0.4915936589241028, 'default': 0.2444908320903778}  (Index 40)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:23] [32mIntermediate result: {'mse': 0.24400071799755096, 'rmse': 0.49115705490112305, 'default': 0.24400071799755096}  (Index 41)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:24] [32mIntermediate result: {'mse': 0.24342605471611023, 'rmse': 0.49059581756591797, 'default': 0.24342605471611023}  (Index 42)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:24] [32mIntermediate result: {'mse': 0.24299456179141998, 'rmse': 0.4900326728820801, 'default': 0.24299456179141998}  (Index 43)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:25] [32mIntermediate result: {'mse': 0.24267065525054932, 'rmse': 0.4898702800273895, 'default': 0.24267065525054932}  (Index 44)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:25] [32mIntermediate result: {'mse': 0.24109923839569092, 'rmse': 0.4882029891014099, 'default': 0.24109923839569092}  (Index 45)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:26] [32mIntermediate result: {'mse': 0.2413003146648407, 'rmse': 0.4883432686328888, 'default': 0.2413003146648407}  (Index 46)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:26] [32mIntermediate result: {'mse': 0.240660160779953, 'rmse': 0.48774823546409607, 'default': 0.240660160779953}  (Index 47)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:27] [32mIntermediate result: {'mse': 0.24014754593372345, 'rmse': 0.4872003197669983, 'default': 0.24014754593372345}  (Index 48)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:27] [32mIntermediate result: {'mse': 0.24079853296279907, 'rmse': 0.4878478944301605, 'default': 0.24079853296279907}  (Index 49)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:28] [32mIntermediate result: {'mse': 0.23956871032714844, 'rmse': 0.48651790618896484, 'default': 0.23956871032714844}  (Index 50)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:28] [32mIntermediate result: {'mse': 0.23907454311847687, 'rmse': 0.4861368238925934, 'default': 0.23907454311847687}  (Index 51)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:29] [32mIntermediate result: {'mse': 0.23904894292354584, 'rmse': 0.4861510396003723, 'default': 0.23904894292354584}  (Index 52)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:29] [32mIntermediate result: {'mse': 0.23855234682559967, 'rmse': 0.4855602979660034, 'default': 0.23855234682559967}  (Index 53)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:30] [32mIntermediate result: {'mse': 0.23780791461467743, 'rmse': 0.4848633408546448, 'default': 0.23780791461467743}  (Index 54)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:30] [32mIntermediate result: {'mse': 0.23760639131069183, 'rmse': 0.4847193956375122, 'default': 0.23760639131069183}  (Index 55)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:31] [32mIntermediate result: {'mse': 0.23728549480438232, 'rmse': 0.484283447265625, 'default': 0.23728549480438232}  (Index 56)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:31] [32mIntermediate result: {'mse': 0.23685266077518463, 'rmse': 0.48389220237731934, 'default': 0.23685266077518463}  (Index 57)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:32] [32mIntermediate result: {'mse': 0.23591220378875732, 'rmse': 0.4828217923641205, 'default': 0.23591220378875732}  (Index 58)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:32] [32mIntermediate result: {'mse': 0.23622599244117737, 'rmse': 0.48326990008354187, 'default': 0.23622599244117737}  (Index 59)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:33] [32mIntermediate result: {'mse': 0.23510517179965973, 'rmse': 0.48204952478408813, 'default': 0.23510517179965973}  (Index 60)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:33] [32mIntermediate result: {'mse': 0.23525649309158325, 'rmse': 0.4821036458015442, 'default': 0.23525649309158325}  (Index 61)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:34] [32mIntermediate result: {'mse': 0.2350260615348816, 'rmse': 0.48189666867256165, 'default': 0.2350260615348816}  (Index 62)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:34] [32mIntermediate result: {'mse': 0.23395009338855743, 'rmse': 0.48080822825431824, 'default': 0.23395009338855743}  (Index 63)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:35] [32mIntermediate result: {'mse': 0.23421259224414825, 'rmse': 0.4811541736125946, 'default': 0.23421259224414825}  (Index 64)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:35] [32mIntermediate result: {'mse': 0.2340325564146042, 'rmse': 0.48096171021461487, 'default': 0.2340325564146042}  (Index 65)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:36] [32mIntermediate result: {'mse': 0.2334955781698227, 'rmse': 0.48031601309776306, 'default': 0.2334955781698227}  (Index 66)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:36] [32mIntermediate result: {'mse': 0.23414911329746246, 'rmse': 0.48106849193573, 'default': 0.23414911329746246}  (Index 67)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:36] [32mIntermediate result: {'mse': 0.23326082527637482, 'rmse': 0.48016902804374695, 'default': 0.23326082527637482}  (Index 68)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:37] [32mIntermediate result: {'mse': 0.23283948004245758, 'rmse': 0.479686975479126, 'default': 0.23283948004245758}  (Index 69)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:37] [32mIntermediate result: {'mse': 0.23256756365299225, 'rmse': 0.47936493158340454, 'default': 0.23256756365299225}  (Index 70)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:38] [32mIntermediate result: {'mse': 0.2320507913827896, 'rmse': 0.478931725025177, 'default': 0.2320507913827896}  (Index 71)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:38] [32mIntermediate result: {'mse': 0.23130880296230316, 'rmse': 0.47817787528038025, 'default': 0.23130880296230316}  (Index 72)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:39] [32mIntermediate result: {'mse': 0.23189155757427216, 'rmse': 0.4788006842136383, 'default': 0.23189155757427216}  (Index 73)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:39] [32mIntermediate result: {'mse': 0.23117069900035858, 'rmse': 0.4780157506465912, 'default': 0.23117069900035858}  (Index 74)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:40] [32mIntermediate result: {'mse': 0.2308538258075714, 'rmse': 0.47777438163757324, 'default': 0.2308538258075714}  (Index 75)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:40] [32mIntermediate result: {'mse': 0.2302437275648117, 'rmse': 0.4770318269729614, 'default': 0.2302437275648117}  (Index 76)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:41] [32mIntermediate result: {'mse': 0.2295316755771637, 'rmse': 0.476388543844223, 'default': 0.2295316755771637}  (Index 77)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:41] [32mIntermediate result: {'mse': 0.22980760037899017, 'rmse': 0.47658753395080566, 'default': 0.22980760037899017}  (Index 78)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:42] [32mIntermediate result: {'mse': 0.2294721156358719, 'rmse': 0.47622016072273254, 'default': 0.2294721156358719}  (Index 79)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:42] [32mIntermediate result: {'mse': 0.22917309403419495, 'rmse': 0.4758780896663666, 'default': 0.22917309403419495}  (Index 80)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:43] [32mIntermediate result: {'mse': 0.22984448075294495, 'rmse': 0.47655433416366577, 'default': 0.22984448075294495}  (Index 81)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:43] [32mIntermediate result: {'mse': 0.22897514700889587, 'rmse': 0.47577399015426636, 'default': 0.22897514700889587}  (Index 82)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:44] [32mIntermediate result: {'mse': 0.22808344662189484, 'rmse': 0.4748058021068573, 'default': 0.22808344662189484}  (Index 83)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:44] [32mIntermediate result: {'mse': 0.22935955226421356, 'rmse': 0.4760175347328186, 'default': 0.22935955226421356}  (Index 84)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:45] [32mIntermediate result: {'mse': 0.22869113087654114, 'rmse': 0.475358784198761, 'default': 0.22869113087654114}  (Index 85)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:45] [32mIntermediate result: {'mse': 0.2276465743780136, 'rmse': 0.4742990732192993, 'default': 0.2276465743780136}  (Index 86)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:46] [32mIntermediate result: {'mse': 0.22794821858406067, 'rmse': 0.4747222661972046, 'default': 0.22794821858406067}  (Index 87)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:46] [32mIntermediate result: {'mse': 0.22747518122196198, 'rmse': 0.47413524985313416, 'default': 0.22747518122196198}  (Index 88)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:47] [32mIntermediate result: {'mse': 0.227542445063591, 'rmse': 0.47417908906936646, 'default': 0.227542445063591}  (Index 89)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:47] [32mIntermediate result: {'mse': 0.22736673057079315, 'rmse': 0.474004864692688, 'default': 0.22736673057079315}  (Index 90)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:48] [32mIntermediate result: {'mse': 0.22656434774398804, 'rmse': 0.4732023775577545, 'default': 0.22656434774398804}  (Index 91)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:48] [32mIntermediate result: {'mse': 0.22630280256271362, 'rmse': 0.47299519181251526, 'default': 0.22630280256271362}  (Index 92)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:49] [32mIntermediate result: {'mse': 0.22640058398246765, 'rmse': 0.47311702370643616, 'default': 0.22640058398246765}  (Index 93)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:49] [32mIntermediate result: {'mse': 0.2256111353635788, 'rmse': 0.4722817838191986, 'default': 0.2256111353635788}  (Index 94)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:50] [32mIntermediate result: {'mse': 0.22607029974460602, 'rmse': 0.4728262424468994, 'default': 0.22607029974460602}  (Index 95)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:50] [32mIntermediate result: {'mse': 0.22613631188869476, 'rmse': 0.47288572788238525, 'default': 0.22613631188869476}  (Index 96)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:51] [32mIntermediate result: {'mse': 0.2254074662923813, 'rmse': 0.4721101224422455, 'default': 0.2254074662923813}  (Index 97)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:51] [32mIntermediate result: {'mse': 0.22483696043491364, 'rmse': 0.47143906354904175, 'default': 0.22483696043491364}  (Index 98)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:52] [32mIntermediate result: {'mse': 0.22566144168376923, 'rmse': 0.4723599851131439, 'default': 0.22566144168376923}  (Index 99)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:52] [32mIntermediate result: {'mse': 0.22515663504600525, 'rmse': 0.4718613922595978, 'default': 0.22515663504600525}  (Index 100)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:53] [32mIntermediate result: {'mse': 0.2246062159538269, 'rmse': 0.4711928963661194, 'default': 0.2246062159538269}  (Index 101)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:53] [32mIntermediate result: {'mse': 0.22446966171264648, 'rmse': 0.4710986912250519, 'default': 0.22446966171264648}  (Index 102)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:54] [32mIntermediate result: {'mse': 0.22475598752498627, 'rmse': 0.471379816532135, 'default': 0.22475598752498627}  (Index 103)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:54] [32mIntermediate result: {'mse': 0.22461217641830444, 'rmse': 0.4712158441543579, 'default': 0.22461217641830444}  (Index 104)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:55] [32mIntermediate result: {'mse': 0.22380012273788452, 'rmse': 0.47046008706092834, 'default': 0.22380012273788452}  (Index 105)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:55] [32mIntermediate result: {'mse': 0.22503678500652313, 'rmse': 0.47161492705345154, 'default': 0.22503678500652313}  (Index 106)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:56] [32mIntermediate result: {'mse': 0.2237820029258728, 'rmse': 0.47039955854415894, 'default': 0.2237820029258728}  (Index 107)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:56] [32mIntermediate result: {'mse': 0.22346562147140503, 'rmse': 0.4699731767177582, 'default': 0.22346562147140503}  (Index 108)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:57] [32mIntermediate result: {'mse': 0.22309806942939758, 'rmse': 0.4696856141090393, 'default': 0.22309806942939758}  (Index 109)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:57] [32mIntermediate result: {'mse': 0.22301802039146423, 'rmse': 0.46953168511390686, 'default': 0.22301802039146423}  (Index 110)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:58] [32mIntermediate result: {'mse': 0.22385339438915253, 'rmse': 0.4703499376773834, 'default': 0.22385339438915253}  (Index 111)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:58] [32mIntermediate result: {'mse': 0.22412648797035217, 'rmse': 0.470671683549881, 'default': 0.22412648797035217}  (Index 112)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:59] [32mIntermediate result: {'mse': 0.2228439599275589, 'rmse': 0.4693170487880707, 'default': 0.2228439599275589}  (Index 113)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:51:59] [32mIntermediate result: {'mse': 0.221928209066391, 'rmse': 0.4684615731239319, 'default': 0.221928209066391}  (Index 114)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:00] [32mIntermediate result: {'mse': 0.22204053401947021, 'rmse': 0.4685004651546478, 'default': 0.22204053401947021}  (Index 115)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:00] [32mIntermediate result: {'mse': 0.22283132374286652, 'rmse': 0.46931663155555725, 'default': 0.22283132374286652}  (Index 116)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:01] [32mIntermediate result: {'mse': 0.2221902310848236, 'rmse': 0.4685981273651123, 'default': 0.2221902310848236}  (Index 117)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:01] [32mIntermediate result: {'mse': 0.22189871966838837, 'rmse': 0.4682951271533966, 'default': 0.22189871966838837}  (Index 118)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:02] [32mIntermediate result: {'mse': 0.22083759307861328, 'rmse': 0.46726804971694946, 'default': 0.22083759307861328}  (Index 119)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:02] [32mIntermediate result: {'mse': 0.22180484235286713, 'rmse': 0.46823403239250183, 'default': 0.22180484235286713}  (Index 120)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:03] [32mIntermediate result: {'mse': 0.2216862291097641, 'rmse': 0.4681755602359772, 'default': 0.2216862291097641}  (Index 121)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:03] [32mIntermediate result: {'mse': 0.2213190346956253, 'rmse': 0.467752605676651, 'default': 0.2213190346956253}  (Index 122)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:04] [32mIntermediate result: {'mse': 0.22062116861343384, 'rmse': 0.46705153584480286, 'default': 0.22062116861343384}  (Index 123)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:04] [32mIntermediate result: {'mse': 0.2209407091140747, 'rmse': 0.46737757325172424, 'default': 0.2209407091140747}  (Index 124)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:05] [32mIntermediate result: {'mse': 0.22139126062393188, 'rmse': 0.46779850125312805, 'default': 0.22139126062393188}  (Index 125)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:05] [32mIntermediate result: {'mse': 0.2203957736492157, 'rmse': 0.4667590260505676, 'default': 0.2203957736492157}  (Index 126)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:06] [32mIntermediate result: {'mse': 0.21995767951011658, 'rmse': 0.46631672978401184, 'default': 0.21995767951011658}  (Index 127)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:06] [32mIntermediate result: {'mse': 0.22009874880313873, 'rmse': 0.46649110317230225, 'default': 0.22009874880313873}  (Index 128)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:07] [32mIntermediate result: {'mse': 0.21960793435573578, 'rmse': 0.4659542143344879, 'default': 0.21960793435573578}  (Index 129)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:07] [32mIntermediate result: {'mse': 0.22011126577854156, 'rmse': 0.46641838550567627, 'default': 0.22011126577854156}  (Index 130)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:08] [32mIntermediate result: {'mse': 0.2202417254447937, 'rmse': 0.4666409194469452, 'default': 0.2202417254447937}  (Index 131)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:08] [32mIntermediate result: {'mse': 0.21956433355808258, 'rmse': 0.465904176235199, 'default': 0.21956433355808258}  (Index 132)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:09] [32mIntermediate result: {'mse': 0.21931777894496918, 'rmse': 0.46566882729530334, 'default': 0.21931777894496918}  (Index 133)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:09] [32mIntermediate result: {'mse': 0.21987959742546082, 'rmse': 0.46616753935813904, 'default': 0.21987959742546082}  (Index 134)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:10] [32mIntermediate result: {'mse': 0.2191673070192337, 'rmse': 0.46550196409225464, 'default': 0.2191673070192337}  (Index 135)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:10] [32mIntermediate result: {'mse': 0.21832235157489777, 'rmse': 0.46461671590805054, 'default': 0.21832235157489777}  (Index 136)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:11] [32mIntermediate result: {'mse': 0.21968114376068115, 'rmse': 0.46599531173706055, 'default': 0.21968114376068115}  (Index 137)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:11] [32mIntermediate result: {'mse': 0.22069810330867767, 'rmse': 0.46710506081581116, 'default': 0.22069810330867767}  (Index 138)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:12] [32mIntermediate result: {'mse': 0.2188902348279953, 'rmse': 0.465195894241333, 'default': 0.2188902348279953}  (Index 139)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:12] [32mIntermediate result: {'mse': 0.21821171045303345, 'rmse': 0.46444541215896606, 'default': 0.21821171045303345}  (Index 140)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:13] [32mIntermediate result: {'mse': 0.21851930022239685, 'rmse': 0.46476036310195923, 'default': 0.21851930022239685}  (Index 141)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:13] [32mIntermediate result: {'mse': 0.21785947680473328, 'rmse': 0.4641134738922119, 'default': 0.21785947680473328}  (Index 142)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:14] [32mIntermediate result: {'mse': 0.21870820224285126, 'rmse': 0.4649590849876404, 'default': 0.21870820224285126}  (Index 143)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:14] [32mIntermediate result: {'mse': 0.21815989911556244, 'rmse': 0.46439626812934875, 'default': 0.21815989911556244}  (Index 144)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:15] [32mIntermediate result: {'mse': 0.21790628135204315, 'rmse': 0.464165061712265, 'default': 0.21790628135204315}  (Index 145)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:15] [32mIntermediate result: {'mse': 0.2177692949771881, 'rmse': 0.4640076160430908, 'default': 0.2177692949771881}  (Index 146)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:16] [32mIntermediate result: {'mse': 0.2185133993625641, 'rmse': 0.464765727519989, 'default': 0.2185133993625641}  (Index 147)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:16] [32mIntermediate result: {'mse': 0.21641142666339874, 'rmse': 0.4626224935054779, 'default': 0.21641142666339874}  (Index 148)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:17] [32mIntermediate result: {'mse': 0.21811595559120178, 'rmse': 0.4643125534057617, 'default': 0.21811595559120178}  (Index 149)[0m


`Trainer.fit` stopped: `max_epochs=150` reached.


[2024-05-14 01:52:17] [32mFinal result: {'mse': 0.21811595559120178, 'rmse': 0.4643125534057617, 'default': 0.21811595559120178}[0m


# Evaluate final model on the test set

In [11]:
# Evaluation-only evaluator: the held-out test loader is supplied through
# `val_dataloaders` and no training loader is given, so metrics for the test
# split are reported under the evaluator's validation metric names.
test_evaluator = RegressionEvaluator(
    # Data to score against.
    val_dataloaders=loaders['test'],
    # Optimizer settings; presumably unused when no training loader is
    # provided -- TODO confirm against RegressionEvaluator.
    optimizer=torch.optim.AdamW,
    learning_rate=3e-4,
    weight_decay=1e-5,
    max_epochs=100,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [12]:
# Score the final model with the test-split evaluator and keep the returned
# metrics under a name; the bare expression on the last line renders them as
# the cell's output.
test_metrics = test_evaluator.evaluate(final_model)
test_metrics

[2024-05-14 01:52:59] [32mOnly validation dataloaders are available. Skip to validation.[0m


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 01:52:59] [32mIntermediate result: {'mse': 0.23070895671844482, 'rmse': 0.4781631529331207, 'default': 0.23070895671844482}  (Index 150)[0m
[2024-05-14 01:52:59] [32mFinal result: {'mse': 0.23070895671844482, 'rmse': 0.4781631529331207, 'default': 0.23070895671844482}[0m
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       val_default          0.23070895671844482
        val_loss            0.23070895671844482
         val_mse            0.23070895671844482
        val_rmse            0.4781631529331207
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss': 0.23070895671844482,
  'val_mse': 0.23070895671844482,
  'val_rmse': 0.4781631529331207,
  'val_default': 0.23070895671844482}]