In [1]:
import torch
import torch.optim
import numpy as np

from nni.nas.space import model_context
from nni.nas.strategy import DARTS as DartsStrategy
from nni.nas.experiment import NasExperiment
from dataset.classification import fetch_data
from evaluators.classification import ClassificationEvaluator
from models.mlp import MLP

In [2]:
np.random.seed(0)
torch.random.manual_seed(0)

<torch._C.Generator at 0x76bd8234b1d0>

# Fetch dataset loaders

In [3]:
task_config, loaders = fetch_data(batch_size=1024, num_workers=4)

In [4]:
for split_name, loader in loaders.items():
    print(split_name, 'dataset size:', len(loader.dataset))

print('num_classes:', task_config.out_features)

train dataset size: 371847
val dataset size: 92962
test dataset size: 116203
num_classes: 7


# Training

In [6]:
evaluator = ClassificationEvaluator(
    learning_rate=3e-4,
    weight_decay=1e-5,
    optimizer=torch.optim.AdamW,
    train_dataloaders=loaders['train'],
    val_dataloaders=loaders['val'],
    num_classes=task_config.out_features,
    max_epochs=200,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [7]:
strategy = DartsStrategy()
model_space = MLP(d_in=task_config.in_features, d_out=task_config.out_features, dropout=0.1)
experiment = NasExperiment(model_space, evaluator, strategy)
experiment.run()

[2024-05-14 12:44:40] [32mConfig is not provided. Will try to infer.[0m
[2024-05-14 12:44:40] [32mStrategy is found to be a one-shot strategy. Setting execution engine to "sequential" and format to "raw".[0m


[2024-05-14 12:44:40] [32mCheckpoint saved to /home/sisha/nni-experiments/r0jz94ba/checkpoint.[0m
[2024-05-14 12:44:40] [32mExperiment initialized successfully. Starting exploration strategy...[0m


You are using a CUDA device ('NVIDIA GeForce RTX 4060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type                 | Params
---------------------------------------------------------
0 | training_module | ClassificationModule | 2.7 M 
---------------------------------------------------------
2.7 M     Trainable params
0         Non-trainable params
2.7 M     Total params
10.633    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=200` reached.


[2024-05-14 13:32:20] [32mWaiting for models submitted to engine to finish...[0m
[2024-05-14 13:32:20] [32mExperiment is completed.[0m


True

# Train final model

In [8]:
exported_arch = experiment.export_top_models(formatter='dict')[0]
print(exported_arch)

{'MLP/d_block': 32, 'MLP/in_act': 0, 'MLP/n_blocks': 3, 'MLP/blocks_act': 0}


In [9]:
with model_context(exported_arch):
    final_model = MLP(d_in=task_config.in_features, d_out=task_config.out_features, dropout=0.1)

In [10]:
evaluator = ClassificationEvaluator(
    learning_rate=3e-4,
    weight_decay=1e-5,
    optimizer=torch.optim.AdamW,
    train_dataloaders=loaders['train'],
    val_dataloaders=loaders['val'],
    num_classes=task_config.out_features,
    max_epochs=200,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [11]:
evaluator.fit(final_model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | metrics   | ModuleDict       | 0     
2 | _model    | MLP              | 5.2 K 
-----------------------------------------------
5.2 K     Trainable params
0         Non-trainable params
5.2 K     Total params
0.021     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:33:30] [32mIntermediate result: 0.6993932723999023  (Index 0)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:33:32] [32mIntermediate result: 0.7151309251785278  (Index 1)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:33:35] [32mIntermediate result: 0.7333319187164307  (Index 2)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:33:37] [32mIntermediate result: 0.7373765707015991  (Index 3)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:33:40] [32mIntermediate result: 0.7414642572402954  (Index 4)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:33:42] [32mIntermediate result: 0.7448204755783081  (Index 5)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:33:44] [32mIntermediate result: 0.7468643188476562  (Index 6)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:33:47] [32mIntermediate result: 0.749080240726471  (Index 7)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:33:49] [32mIntermediate result: 0.7526731491088867  (Index 8)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:33:52] [32mIntermediate result: 0.7547062039375305  (Index 9)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:33:54] [32mIntermediate result: 0.7576966881752014  (Index 10)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:33:56] [32mIntermediate result: 0.7598158121109009  (Index 11)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:33:59] [32mIntermediate result: 0.7619134783744812  (Index 12)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:01] [32mIntermediate result: 0.7643768191337585  (Index 13)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:03] [32mIntermediate result: 0.7662162780761719  (Index 14)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:06] [32mIntermediate result: 0.767937421798706  (Index 15)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:08] [32mIntermediate result: 0.7691314816474915  (Index 16)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:11] [32mIntermediate result: 0.7705944180488586  (Index 17)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:13] [32mIntermediate result: 0.7720251083374023  (Index 18)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:16] [32mIntermediate result: 0.7727888822555542  (Index 19)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:18] [32mIntermediate result: 0.7741873264312744  (Index 20)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:20] [32mIntermediate result: 0.7756825089454651  (Index 21)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:23] [32mIntermediate result: 0.7764462828636169  (Index 22)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:25] [32mIntermediate result: 0.7798132300376892  (Index 23)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:27] [32mIntermediate result: 0.7800821661949158  (Index 24)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:30] [32mIntermediate result: 0.780598521232605  (Index 25)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:32] [32mIntermediate result: 0.7828360199928284  (Index 26)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:35] [32mIntermediate result: 0.782986581325531  (Index 27)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:37] [32mIntermediate result: 0.7851380109786987  (Index 28)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:40] [32mIntermediate result: 0.7873109579086304  (Index 29)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:42] [32mIntermediate result: 0.7879671454429626  (Index 30)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:44] [32mIntermediate result: 0.7898818850517273  (Index 31)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:47] [32mIntermediate result: 0.7914954423904419  (Index 32)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:49] [32mIntermediate result: 0.7925603985786438  (Index 33)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:51] [32mIntermediate result: 0.792678713798523  (Index 34)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:54] [32mIntermediate result: 0.7939158082008362  (Index 35)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:57] [32mIntermediate result: 0.7951958775520325  (Index 36)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:34:59] [32mIntermediate result: 0.7969385385513306  (Index 37)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:01] [32mIntermediate result: 0.797089159488678  (Index 38)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:04] [32mIntermediate result: 0.7979282140731812  (Index 39)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:06] [32mIntermediate result: 0.7977775931358337  (Index 40)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:09] [32mIntermediate result: 0.7988747954368591  (Index 41)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:11] [32mIntermediate result: 0.7994879484176636  (Index 42)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:14] [32mIntermediate result: 0.8007035255432129  (Index 43)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:16] [32mIntermediate result: 0.8002086877822876  (Index 44)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:19] [32mIntermediate result: 0.800262451171875  (Index 45)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:21] [32mIntermediate result: 0.8022955656051636  (Index 46)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:23] [32mIntermediate result: 0.8028549551963806  (Index 47)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:26] [32mIntermediate result: 0.8028441667556763  (Index 48)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:28] [32mIntermediate result: 0.8038338422775269  (Index 49)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:31] [32mIntermediate result: 0.8053720593452454  (Index 50)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:33] [32mIntermediate result: 0.8053290843963623  (Index 51)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:36] [32mIntermediate result: 0.8064262866973877  (Index 52)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:38] [32mIntermediate result: 0.8068888187408447  (Index 53)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:40] [32mIntermediate result: 0.8057055473327637  (Index 54)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:43] [32mIntermediate result: 0.8076525926589966  (Index 55)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:45] [32mIntermediate result: 0.808125913143158  (Index 56)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:48] [32mIntermediate result: 0.8083195090293884  (Index 57)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:50] [32mIntermediate result: 0.8089219331741333  (Index 58)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:52] [32mIntermediate result: 0.8095781207084656  (Index 59)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:55] [32mIntermediate result: 0.8087175488471985  (Index 60)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:35:57] [32mIntermediate result: 0.8102773427963257  (Index 61)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:00] [32mIntermediate result: 0.8099545836448669  (Index 62)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:02] [32mIntermediate result: 0.8097286820411682  (Index 63)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:05] [32mIntermediate result: 0.8108044266700745  (Index 64)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:07] [32mIntermediate result: 0.8116865158081055  (Index 65)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:09] [32mIntermediate result: 0.8112024068832397  (Index 66)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:12] [32mIntermediate result: 0.8104171752929688  (Index 67)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:14] [32mIntermediate result: 0.8111271262168884  (Index 68)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:17] [32mIntermediate result: 0.8120307326316833  (Index 69)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:19] [32mIntermediate result: 0.8126546144485474  (Index 70)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:22] [32mIntermediate result: 0.8125686049461365  (Index 71)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:24] [32mIntermediate result: 0.8133861422538757  (Index 72)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:27] [32mIntermediate result: 0.8139992952346802  (Index 73)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:29] [32mIntermediate result: 0.813784122467041  (Index 74)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:31] [32mIntermediate result: 0.814246654510498  (Index 75)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:34] [32mIntermediate result: 0.8135797381401062  (Index 76)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:36] [32mIntermediate result: 0.8137518763542175  (Index 77)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:39] [32mIntermediate result: 0.8146877288818359  (Index 78)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:41] [32mIntermediate result: 0.8135905265808105  (Index 79)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:44] [32mIntermediate result: 0.8140960931777954  (Index 80)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:46] [32mIntermediate result: 0.8153116106987  (Index 81)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:49] [32mIntermediate result: 0.8150319457054138  (Index 82)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:51] [32mIntermediate result: 0.8159570693969727  (Index 83)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:53] [32mIntermediate result: 0.8152256011962891  (Index 84)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:56] [32mIntermediate result: 0.8160216212272644  (Index 85)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:36:58] [32mIntermediate result: 0.8164411187171936  (Index 86)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:01] [32mIntermediate result: 0.8165809512138367  (Index 87)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:03] [32mIntermediate result: 0.8167207837104797  (Index 88)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:06] [32mIntermediate result: 0.8166132569313049  (Index 89)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:08] [32mIntermediate result: 0.8169897198677063  (Index 90)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:11] [32mIntermediate result: 0.8176566958427429  (Index 91)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:13] [32mIntermediate result: 0.8181729912757874  (Index 92)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:16] [32mIntermediate result: 0.8180546760559082  (Index 93)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:18] [32mIntermediate result: 0.8183881640434265  (Index 94)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:21] [32mIntermediate result: 0.8188830018043518  (Index 95)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:23] [32mIntermediate result: 0.8188184499740601  (Index 96)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:26] [32mIntermediate result: 0.819141149520874  (Index 97)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:28] [32mIntermediate result: 0.81851726770401  (Index 98)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:30] [32mIntermediate result: 0.8191734552383423  (Index 99)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:33] [32mIntermediate result: 0.8195714354515076  (Index 100)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:35] [32mIntermediate result: 0.8186571002006531  (Index 101)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:38] [32mIntermediate result: 0.8187861442565918  (Index 102)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:40] [32mIntermediate result: 0.8204104900360107  (Index 103)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:43] [32mIntermediate result: 0.8187108635902405  (Index 104)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:45] [32mIntermediate result: 0.8194961547851562  (Index 105)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:47] [32mIntermediate result: 0.8200985193252563  (Index 106)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:50] [32mIntermediate result: 0.8198941349983215  (Index 107)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:52] [32mIntermediate result: 0.819722056388855  (Index 108)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:55] [32mIntermediate result: 0.8207116723060608  (Index 109)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:37:57] [32mIntermediate result: 0.8202491402626038  (Index 110)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:00] [32mIntermediate result: 0.8205503225326538  (Index 111)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:02] [32mIntermediate result: 0.8216367959976196  (Index 112)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:05] [32mIntermediate result: 0.8204750418663025  (Index 113)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:07] [32mIntermediate result: 0.8211742639541626  (Index 114)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:09] [32mIntermediate result: 0.8209375739097595  (Index 115)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:12] [32mIntermediate result: 0.8208622932434082  (Index 116)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:14] [32mIntermediate result: 0.8221316337585449  (Index 117)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:17] [32mIntermediate result: 0.8214646577835083  (Index 118)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:19] [32mIntermediate result: 0.8213786482810974  (Index 119)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:22] [32mIntermediate result: 0.8219594955444336  (Index 120)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:24] [32mIntermediate result: 0.8213140964508057  (Index 121)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:27] [32mIntermediate result: 0.8215615153312683  (Index 122)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:29] [32mIntermediate result: 0.8224436044692993  (Index 123)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:32] [32mIntermediate result: 0.8223037123680115  (Index 124)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:34] [32mIntermediate result: 0.8220133185386658  (Index 125)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:36] [32mIntermediate result: 0.8225188851356506  (Index 126)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:39] [32mIntermediate result: 0.8221531510353088  (Index 127)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:41] [32mIntermediate result: 0.8233579397201538  (Index 128)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:44] [32mIntermediate result: 0.8230567574501038  (Index 129)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:46] [32mIntermediate result: 0.8225080966949463  (Index 130)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:49] [32mIntermediate result: 0.8235730528831482  (Index 131)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:51] [32mIntermediate result: 0.8219918012619019  (Index 132)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:54] [32mIntermediate result: 0.823314905166626  (Index 133)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:56] [32mIntermediate result: 0.8231858015060425  (Index 134)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:38:59] [32mIntermediate result: 0.8226049542427063  (Index 135)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:01] [32mIntermediate result: 0.8230674862861633  (Index 136)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:04] [32mIntermediate result: 0.823454737663269  (Index 137)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:06] [32mIntermediate result: 0.8251113295555115  (Index 138)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:09] [32mIntermediate result: 0.8241431713104248  (Index 139)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:11] [32mIntermediate result: 0.8226587176322937  (Index 140)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:14] [32mIntermediate result: 0.8239710927009583  (Index 141)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:16] [32mIntermediate result: 0.8235515356063843  (Index 142)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:19] [32mIntermediate result: 0.823777437210083  (Index 143)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:21] [32mIntermediate result: 0.8238850235939026  (Index 144)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:23] [32mIntermediate result: 0.8252296447753906  (Index 145)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:26] [32mIntermediate result: 0.8244766592979431  (Index 146)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:28] [32mIntermediate result: 0.8242937922477722  (Index 147)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:31] [32mIntermediate result: 0.8245949745178223  (Index 148)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:33] [32mIntermediate result: 0.8247778415679932  (Index 149)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:36] [32mIntermediate result: 0.8247455954551697  (Index 150)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:38] [32mIntermediate result: 0.8254985809326172  (Index 151)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:41] [32mIntermediate result: 0.8249284625053406  (Index 152)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:43] [32mIntermediate result: 0.8251758813858032  (Index 153)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:46] [32mIntermediate result: 0.8251758813858032  (Index 154)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:48] [32mIntermediate result: 0.8261010050773621  (Index 155)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:51] [32mIntermediate result: 0.8249930143356323  (Index 156)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:53] [32mIntermediate result: 0.8252834677696228  (Index 157)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:56] [32mIntermediate result: 0.8262730836868286  (Index 158)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:39:58] [32mIntermediate result: 0.8254340291023254  (Index 159)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:01] [32mIntermediate result: 0.826025664806366  (Index 160)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:03] [32mIntermediate result: 0.825401782989502  (Index 161)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:06] [32mIntermediate result: 0.8256061673164368  (Index 162)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:08] [32mIntermediate result: 0.825218915939331  (Index 163)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:11] [32mIntermediate result: 0.8268970251083374  (Index 164)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:13] [32mIntermediate result: 0.8249714970588684  (Index 165)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:16] [32mIntermediate result: 0.8264667391777039  (Index 166)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:18] [32mIntermediate result: 0.8269615769386292  (Index 167)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:21] [32mIntermediate result: 0.8263806700706482  (Index 168)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:23] [32mIntermediate result: 0.8261655569076538  (Index 169)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:25] [32mIntermediate result: 0.8264667391777039  (Index 170)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:28] [32mIntermediate result: 0.8271551728248596  (Index 171)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:30] [32mIntermediate result: 0.8265528082847595  (Index 172)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:33] [32mIntermediate result: 0.8263914585113525  (Index 173)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:35] [32mIntermediate result: 0.8263161182403564  (Index 174)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:38] [32mIntermediate result: 0.8273703455924988  (Index 175)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:40] [32mIntermediate result: 0.8271551728248596  (Index 176)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:43] [32mIntermediate result: 0.8276392221450806  (Index 177)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:45] [32mIntermediate result: 0.826585054397583  (Index 178)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:47] [32mIntermediate result: 0.828187882900238  (Index 179)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:50] [32mIntermediate result: 0.8278221487998962  (Index 180)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:52] [32mIntermediate result: 0.8278113603591919  (Index 181)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:55] [32mIntermediate result: 0.828026533126831  (Index 182)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:40:57] [32mIntermediate result: 0.8281448483467102  (Index 183)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:41:00] [32mIntermediate result: 0.8269400596618652  (Index 184)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:41:02] [32mIntermediate result: 0.8278543949127197  (Index 185)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:41:05] [32mIntermediate result: 0.8270583748817444  (Index 186)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:41:07] [32mIntermediate result: 0.8272519707679749  (Index 187)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:41:10] [32mIntermediate result: 0.8275639414787292  (Index 188)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:41:12] [32mIntermediate result: 0.828209400177002  (Index 189)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:41:15] [32mIntermediate result: 0.8284029960632324  (Index 190)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:41:17] [32mIntermediate result: 0.826262354850769  (Index 191)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:41:19] [32mIntermediate result: 0.8283384442329407  (Index 192)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:41:22] [32mIntermediate result: 0.8284890651702881  (Index 193)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:41:24] [32mIntermediate result: 0.828026533126831  (Index 194)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:41:27] [32mIntermediate result: 0.827789843082428  (Index 195)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:41:30] [32mIntermediate result: 0.8282738924026489  (Index 196)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:41:32] [32mIntermediate result: 0.8280479907989502  (Index 197)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:41:35] [32mIntermediate result: 0.8281555771827698  (Index 198)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:41:37] [32mIntermediate result: 0.8290591835975647  (Index 199)[0m


`Trainer.fit` stopped: `max_epochs=200` reached.


[2024-05-14 13:41:37] [32mFinal result: 0.8290591835975647[0m


# Evaluate final model

In [12]:
test_evaluator = ClassificationEvaluator(
    learning_rate=3e-4,
    weight_decay=1e-5,
    optimizer=torch.optim.AdamW,
    val_dataloaders=loaders['test'],
    num_classes=task_config.out_features,
    max_epochs=200,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [13]:
test_evaluator.evaluate(final_model)

[2024-05-14 13:48:10] [32mOnly validation dataloaders are available. Skip to validation.[0m


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-14 13:48:11] [32mIntermediate result: 0.8291265964508057  (Index 200)[0m
[2024-05-14 13:48:11] [32mFinal result: 0.8291265964508057[0m
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         val_acc            0.8291265964508057
        val_loss            0.4151475727558136
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss': 0.4151475727558136, 'val_acc': 0.8291265964508057}]