In [1]:
import random

import numpy as np
import torch
import torch.optim
from nni.nas.experiment import NasExperiment
from nni.nas.space import model_context
from nni.nas.strategy import DARTS as DartsStrategy

from dataset.classification import fetch_data
from evaluators.classification import ClassificationEvaluator
from models.mlp import MLP

In [2]:
# Seed every RNG the pipeline may touch so Restart & Run All is reproducible.
# torch.manual_seed is not the last expression of the cell, so its returned
# Generator is no longer echoed as stray cell output.
SEED = 0
random.seed(SEED)
torch.manual_seed(SEED)
np.random.seed(SEED)

<torch._C.Generator at 0x7a5037f8b1d0>

# Fetch dataset loaders

In [3]:
# Dataset / loader configuration.
# NOTE(review): 'ja' presumably selects the Jannis tabular benchmark (4 classes,
# ~84k rows, consistent with the split sizes printed below) — confirm in fetch_data.
DATASET_NAME = 'ja'
BATCH_SIZE = 256
NUM_WORKERS = 4  # presumably forwarded to the DataLoaders' worker processes

task_config, loaders = fetch_data(DATASET_NAME, batch_size=BATCH_SIZE, num_workers=NUM_WORKERS)

In [4]:
# Report how many examples each split holds, then the number of target classes.
for split, split_loader in loaders.items():
    print(f'{split} dataset size: {len(split_loader.dataset)}')

print(f'num_classes: {task_config.out_features}')

train dataset size: 53588
val dataset size: 13398
test dataset size: 16747
num_classes: 4


# Architecture search with DARTS

In [5]:
# Evaluator handed to the one-shot DARTS search in the next cell: AdamW on the
# train split, validation on the val split, 100 epochs.
evaluator = ClassificationEvaluator(
    optimizer=torch.optim.AdamW,
    learning_rate=3e-4,
    weight_decay=1e-5,
    num_classes=task_config.out_features,
    train_dataloaders=loaders['train'],
    val_dataloaders=loaders['val'],
    max_epochs=100,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [6]:
# One-shot DARTS search over the MLP model space, trained by `evaluator`.
strategy = DartsStrategy()
model_space = MLP(d_in=task_config.in_features, d_out=task_config.out_features, dropout=0.1)
experiment = NasExperiment(model_space, evaluator, strategy)

# run() blocks until the search ends and returns whether it succeeded. Fail
# loudly here instead of silently exporting an architecture from a failed run.
search_succeeded = experiment.run()
if not search_succeeded:
    raise RuntimeError('NAS experiment did not complete successfully')

[2024-05-15 15:00:49] [32mConfig is not provided. Will try to infer.[0m
[2024-05-15 15:00:49] [32mStrategy is found to be a one-shot strategy. Setting execution engine to "sequential" and format to "raw".[0m
[2024-05-15 15:00:50] [32mCheckpoint saved to /home/sisha/nni-experiments/1430n8ut/checkpoint.[0m
[2024-05-15 15:00:50] [32mExperiment initialized successfully. Starting exploration strategy...[0m


You are using a CUDA device ('NVIDIA GeForce RTX 4060 Ti') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name            | Type                 | Params
---------------------------------------------------------
0 | training_module | ClassificationModule | 2.7 M 
---------------------------------------------------------
2.7 M     Trainable params
0         Non-trainable params
2.7 M     Total params
10.627    Total estimated model params size (MB)


Training: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=100` reached.


[2024-05-15 15:14:11] [32mWaiting for models submitted to engine to finish...[0m
[2024-05-15 15:14:11] [32mExperiment is completed.[0m


True

# Train final model (the exported architecture, from scratch)

In [7]:
# Pick the first (top-ranked) architecture from the search, exported as a plain
# dict mapping each choice label to its selected value.
top_architectures = experiment.export_top_models(formatter='dict')
exported_arch = top_architectures[0]
print(exported_arch)

{'MLP/d_block': 32, 'MLP/in_act': 0, 'MLP/n_blocks': 1, 'MLP/blocks_act': 0}


In [8]:
# Re-instantiate the MLP space with every choice fixed to the exported
# architecture. The constructor arguments should match the search-time space.
with model_context(exported_arch):
    final_model = MLP(
        d_in=task_config.in_features,
        d_out=task_config.out_features,
        dropout=0.1,
    )

In [9]:
# Evaluator for retraining the freshly built final model for 150 epochs.
# NOTE(review): this rebinds `evaluator`, replacing the search-time evaluator
# created before the DARTS cell.
evaluator = ClassificationEvaluator(
    max_epochs=150,
    num_classes=task_config.out_features,
    optimizer=torch.optim.AdamW,
    learning_rate=3e-4,
    weight_decay=1e-5,
    train_dataloaders=loaders['train'],
    val_dataloaders=loaders['val'],
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [10]:
# Train the exported architecture; each epoch's validation accuracy is logged
# as an "Intermediate result" below.
# NOTE(review): final_model presumably keeps the last-epoch weights rather than
# the best validation epoch unless ClassificationEvaluator checkpoints — confirm.
evaluator.fit(final_model)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name      | Type             | Params
-----------------------------------------------
0 | criterion | CrossEntropyLoss | 0     
1 | metrics   | ModuleDict       | 0     
2 | _model    | MLP              | 2.9 K 
-----------------------------------------------
2.9 K     Trainable params
0         Non-trainable params
2.9 K     Total params
0.012     Total estimated model params size (MB)


Sanity Checking: |          | 0/? [00:00<?, ?it/s]

Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:14:52] [32mIntermediate result: 0.5675473809242249  (Index 0)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:14:53] [32mIntermediate result: 0.6273324489593506  (Index 1)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:14:54] [32mIntermediate result: 0.637557864189148  (Index 2)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:14:55] [32mIntermediate result: 0.6443499326705933  (Index 3)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:14:57] [32mIntermediate result: 0.6482310891151428  (Index 4)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:14:58] [32mIntermediate result: 0.653605043888092  (Index 5)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:14:59] [32mIntermediate result: 0.6562173366546631  (Index 6)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:00] [32mIntermediate result: 0.6595014333724976  (Index 7)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:01] [32mIntermediate result: 0.6600984930992126  (Index 8)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:03] [32mIntermediate result: 0.662263035774231  (Index 9)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:04] [32mIntermediate result: 0.6630094051361084  (Index 10)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:05] [32mIntermediate result: 0.6645768284797668  (Index 11)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:06] [32mIntermediate result: 0.6659202575683594  (Index 12)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:07] [32mIntermediate result: 0.6677862405776978  (Index 13)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:08] [32mIntermediate result: 0.6702492833137512  (Index 14)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:10] [32mIntermediate result: 0.6708464026451111  (Index 15)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:11] [32mIntermediate result: 0.6735333800315857  (Index 16)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:12] [32mIntermediate result: 0.6751007437705994  (Index 17)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:13] [32mIntermediate result: 0.677489161491394  (Index 18)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:14] [32mIntermediate result: 0.6789072751998901  (Index 19)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:16] [32mIntermediate result: 0.6802507638931274  (Index 20)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:17] [32mIntermediate result: 0.6827884912490845  (Index 21)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:18] [32mIntermediate result: 0.6812956929206848  (Index 22)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:19] [32mIntermediate result: 0.6847290396690369  (Index 23)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:21] [32mIntermediate result: 0.6851769089698792  (Index 24)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:22] [32mIntermediate result: 0.684952974319458  (Index 25)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:23] [32mIntermediate result: 0.6880877614021301  (Index 26)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:24] [32mIntermediate result: 0.6892820000648499  (Index 27)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:25] [32mIntermediate result: 0.6910732984542847  (Index 28)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:26] [32mIntermediate result: 0.6904761791229248  (Index 29)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:28] [32mIntermediate result: 0.6917450428009033  (Index 30)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:29] [32mIntermediate result: 0.6927153468132019  (Index 31)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:30] [32mIntermediate result: 0.693312406539917  (Index 32)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:31] [32mIntermediate result: 0.6930139064788818  (Index 33)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:33] [32mIntermediate result: 0.6959993839263916  (Index 34)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:34] [32mIntermediate result: 0.6955515742301941  (Index 35)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:35] [32mIntermediate result: 0.6947305798530579  (Index 36)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:36] [32mIntermediate result: 0.6946558952331543  (Index 37)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:38] [32mIntermediate result: 0.6962979435920715  (Index 38)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:39] [32mIntermediate result: 0.6965965032577515  (Index 39)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:40] [32mIntermediate result: 0.6968950629234314  (Index 40)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:41] [32mIntermediate result: 0.6971936225891113  (Index 41)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:42] [32mIntermediate result: 0.6974921822547913  (Index 42)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:44] [32mIntermediate result: 0.7012240886688232  (Index 43)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:45] [32mIntermediate result: 0.699582040309906  (Index 44)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:46] [32mIntermediate result: 0.7007015943527222  (Index 45)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:47] [32mIntermediate result: 0.701298713684082  (Index 46)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:48] [32mIntermediate result: 0.7013733386993408  (Index 47)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:50] [32mIntermediate result: 0.7014479637145996  (Index 48)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:51] [32mIntermediate result: 0.7030153870582581  (Index 49)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:52] [32mIntermediate result: 0.7030153870582581  (Index 50)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:53] [32mIntermediate result: 0.7014479637145996  (Index 51)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:54] [32mIntermediate result: 0.7039110064506531  (Index 52)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:56] [32mIntermediate result: 0.7048066854476929  (Index 53)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:57] [32mIntermediate result: 0.7039110064506531  (Index 54)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:58] [32mIntermediate result: 0.7046574354171753  (Index 55)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:15:59] [32mIntermediate result: 0.7040603160858154  (Index 56)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:01] [32mIntermediate result: 0.7027168273925781  (Index 57)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:02] [32mIntermediate result: 0.7034631967544556  (Index 58)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:03] [32mIntermediate result: 0.7054784297943115  (Index 59)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:04] [32mIntermediate result: 0.7047320604324341  (Index 60)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:05] [32mIntermediate result: 0.704209566116333  (Index 61)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:07] [32mIntermediate result: 0.7044335007667542  (Index 62)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:08] [32mIntermediate result: 0.7052544951438904  (Index 63)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:09] [32mIntermediate result: 0.7039856910705566  (Index 64)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:10] [32mIntermediate result: 0.7040603160858154  (Index 65)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:11] [32mIntermediate result: 0.7066726088523865  (Index 66)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:13] [32mIntermediate result: 0.7051052451133728  (Index 67)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:14] [32mIntermediate result: 0.7061501741409302  (Index 68)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:15] [32mIntermediate result: 0.7056276798248291  (Index 69)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:16] [32mIntermediate result: 0.7043588757514954  (Index 70)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:17] [32mIntermediate result: 0.7061501741409302  (Index 71)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:19] [32mIntermediate result: 0.7058516144752502  (Index 72)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:20] [32mIntermediate result: 0.7052544951438904  (Index 73)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:21] [32mIntermediate result: 0.7054038047790527  (Index 74)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:22] [32mIntermediate result: 0.7049559354782104  (Index 75)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:23] [32mIntermediate result: 0.70674729347229  (Index 76)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:25] [32mIntermediate result: 0.7058516144752502  (Index 77)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:26] [32mIntermediate result: 0.7054784297943115  (Index 78)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:27] [32mIntermediate result: 0.7063741087913513  (Index 79)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:28] [32mIntermediate result: 0.7048813104629517  (Index 80)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:29] [32mIntermediate result: 0.7061501741409302  (Index 81)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:31] [32mIntermediate result: 0.7064487338066101  (Index 82)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:32] [32mIntermediate result: 0.7060009241104126  (Index 83)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:33] [32mIntermediate result: 0.7077922224998474  (Index 84)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:34] [32mIntermediate result: 0.706224799156189  (Index 85)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:35] [32mIntermediate result: 0.705329179763794  (Index 86)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:37] [32mIntermediate result: 0.70674729347229  (Index 87)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:38] [32mIntermediate result: 0.7064487338066101  (Index 88)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:39] [32mIntermediate result: 0.705926239490509  (Index 89)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:40] [32mIntermediate result: 0.7048813104629517  (Index 90)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:42] [32mIntermediate result: 0.7068965435028076  (Index 91)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:43] [32mIntermediate result: 0.7057023644447327  (Index 92)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:44] [32mIntermediate result: 0.7049559354782104  (Index 93)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:45] [32mIntermediate result: 0.70704585313797  (Index 94)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:47] [32mIntermediate result: 0.705030620098114  (Index 95)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:48] [32mIntermediate result: 0.7068219184875488  (Index 96)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:49] [32mIntermediate result: 0.7062994241714478  (Index 97)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:50] [32mIntermediate result: 0.7083892822265625  (Index 98)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:51] [32mIntermediate result: 0.707941472530365  (Index 99)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:53] [32mIntermediate result: 0.7071204781532288  (Index 100)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:54] [32mIntermediate result: 0.7066726088523865  (Index 101)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:55] [32mIntermediate result: 0.7073443531990051  (Index 102)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:56] [32mIntermediate result: 0.7074936628341675  (Index 103)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:57] [32mIntermediate result: 0.707941472530365  (Index 104)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:16:59] [32mIntermediate result: 0.7077175974845886  (Index 105)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:00] [32mIntermediate result: 0.7074936628341675  (Index 106)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:01] [32mIntermediate result: 0.7078668475151062  (Index 107)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:02] [32mIntermediate result: 0.7083892822265625  (Index 108)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:04] [32mIntermediate result: 0.7073443531990051  (Index 109)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:05] [32mIntermediate result: 0.7078668475151062  (Index 110)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:06] [32mIntermediate result: 0.7065979838371277  (Index 111)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:07] [32mIntermediate result: 0.7071204781532288  (Index 112)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:08] [32mIntermediate result: 0.7077922224998474  (Index 113)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:10] [32mIntermediate result: 0.7066726088523865  (Index 114)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:11] [32mIntermediate result: 0.7066726088523865  (Index 115)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:12] [32mIntermediate result: 0.7077922224998474  (Index 116)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:13] [32mIntermediate result: 0.7082400321960449  (Index 117)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:14] [32mIntermediate result: 0.7083146572113037  (Index 118)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:16] [32mIntermediate result: 0.708762526512146  (Index 119)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:17] [32mIntermediate result: 0.7089117765426636  (Index 120)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:18] [32mIntermediate result: 0.7063741087913513  (Index 121)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:19] [32mIntermediate result: 0.707941472530365  (Index 122)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:20] [32mIntermediate result: 0.7065979838371277  (Index 123)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:22] [32mIntermediate result: 0.7077922224998474  (Index 124)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:23] [32mIntermediate result: 0.7071204781532288  (Index 125)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:24] [32mIntermediate result: 0.7084639668464661  (Index 126)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:25] [32mIntermediate result: 0.7085385918617249  (Index 127)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:26] [32mIntermediate result: 0.7083892822265625  (Index 128)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:28] [32mIntermediate result: 0.7081654071807861  (Index 129)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:29] [32mIntermediate result: 0.7086132168769836  (Index 130)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:30] [32mIntermediate result: 0.7095088958740234  (Index 131)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:31] [32mIntermediate result: 0.7097327709197998  (Index 132)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:33] [32mIntermediate result: 0.7101060152053833  (Index 133)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:34] [32mIntermediate result: 0.7093595862388611  (Index 134)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:35] [32mIntermediate result: 0.7080160975456238  (Index 135)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:36] [32mIntermediate result: 0.709658145904541  (Index 136)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:38] [32mIntermediate result: 0.709956705570221  (Index 137)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:39] [32mIntermediate result: 0.7097327709197998  (Index 138)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:40] [32mIntermediate result: 0.7107776999473572  (Index 139)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:41] [32mIntermediate result: 0.708762526512146  (Index 140)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:42] [32mIntermediate result: 0.7092849612236023  (Index 141)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:44] [32mIntermediate result: 0.7101806402206421  (Index 142)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:45] [32mIntermediate result: 0.7111509442329407  (Index 143)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:46] [32mIntermediate result: 0.7095088958740234  (Index 144)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:47] [32mIntermediate result: 0.7092849612236023  (Index 145)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:49] [32mIntermediate result: 0.7091357111930847  (Index 146)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:50] [32mIntermediate result: 0.7102552652359009  (Index 147)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:51] [32mIntermediate result: 0.7100313305854797  (Index 148)[0m


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:17:52] [32mIntermediate result: 0.7093595862388611  (Index 149)[0m


`Trainer.fit` stopped: `max_epochs=150` reached.


[2024-05-15 15:17:53] [32mFinal result: 0.7093595862388611[0m


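# Evaluate final model on the test set

The test split is passed to the evaluator as `val_dataloaders`, so the test metrics below are reported under the `val_acc` / `val_loss` names.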

In [11]:
# Evaluation-only evaluator: the test split is supplied as `val_dataloaders` and
# no train loader is given, so `evaluate` skips straight to validation and the
# test metrics come out labelled `val_*`. The optimizer / epoch settings are
# presumably unused by `evaluate`.
test_evaluator = ClassificationEvaluator(
    val_dataloaders=loaders['test'],
    num_classes=task_config.out_features,
    optimizer=torch.optim.AdamW,
    learning_rate=3e-4,
    weight_decay=1e-5,
    max_epochs=100,
)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [12]:
# Score the retrained model on the held-out test split and display the metrics.
test_metrics = test_evaluator.evaluate(final_model)
test_metrics

[2024-05-15 15:18:22] [32mOnly validation dataloaders are available. Skip to validation.[0m


LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |          | 0/? [00:00<?, ?it/s]

[2024-05-15 15:18:23] [32mIntermediate result: 0.7114109992980957  (Index 150)[0m
[2024-05-15 15:18:23] [32mFinal result: 0.7114109992980957[0m
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
     Validate metric           DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
         val_acc            0.7114109992980957
        val_loss            0.6988751888275146
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────


[{'val_loss': 0.6988751888275146, 'val_acc': 0.7114109992980957}]