In [1]:
from run import main

In [2]:
import os
from dataset import RecipeDataset
# Reading every split from disk is slow, so each split is loaded exactly once
# here and the resulting dict is reused by later cells (via ``Arguments``)
# instead of being reloaded for every training run.
dataset_names = ['train', 'valid_clf', 'valid_cpl', 'test_clf', 'test_cpl']
recipe_datasets = dict(
    (split, RecipeDataset(os.path.join('./Container', split)))
    for split in dataset_names
)

In [3]:
class Arguments(object):
    """Attribute bag holding the hyper-parameters handed to ``main``.

    Construction starts from a fixed table of defaults, overlays any keyword
    arguments (which may also introduce keys absent from the table, such as
    ``data_dir``), and exposes every resulting entry as an instance attribute.
    """

    def __init__(self, **kwargs):
        # The table is rebuilt on every call, so overrides never leak between
        # instances. ``recipe_datasets`` is resolved as a global at this point,
        # and the dict object itself (not a copy) is stored.
        settings = {
            'batch_size': 64,
            'batch_size_eval': 2048,
            'n_epochs': 50,
            'lr': 1e-3,
            'weight_decay': 0,
            'step_size': 10,  # lr_scheduler
            'step_factor': 0.1,  # lr_scheduler
            'early_stop_patience': 30,  # early stop
            'seed': 42,
            'subset_length': None,
            'dim_embedding': 256,
            'dim_hidden': 256,
            'dropout': 0,
            'encoder_mode': 'HYBRID',  # one of 'FC', 'ATT', 'HYBRID'
            'pooler_mode': 'ATT',  # one of 'deepSets', 'ATT'
            'num_outputs_cpl': 2,
            'num_enc_layers': 2,
            'num_dec_layers': 2,
            'loss': 'MultiClassASLoss',  # 'CrossEntropyLoss', 'MultiClassFocalLoss', 'MultiClassASLoss'
            'optimizer_name': 'AdamW',
            'classify': True,
            'complete': True,
            'freeze_classify': False,
            'freeze_complete': False,
            'pretrained_model_path': None,
            'wandb_log': True,
            'verbose': True,
            'datasets': recipe_datasets,  # pre-loaded datasets
        }
        settings.update(kwargs)
        for name, value in settings.items():
            setattr(self, name, value)

    def update(self, **kwargs):
        """Overwrite or add attributes from ``kwargs``, print the config, and return ``self``.

        Returning the instance lets callers write ``main(args.update(...))``.
        """
        for name, value in kwargs.items():
            setattr(self, name, value)
        print(vars(self))
        return self

In [4]:
def run(**kwargs):
    """Build a config rooted at ``./Container``, overlay ``kwargs``, and pass it to ``main``.

    The overlay goes through ``Arguments.update``, which prints the resulting
    configuration before ``main`` is invoked. Nothing is returned.
    """
    config = Arguments(data_dir='./Container')
    config = config.update(**kwargs)
    main(config)

In [5]:
# Classification-only run: the completion task is switched off and wandb
# logging is disabled for this experiment.
run(
    classify=True,
    complete=False,
    batch_size=64,
    encoder_mode='HYBRID',
    pooler_mode='ATT',
    dropout=0.2,
    wandb_log=False,
)

{'batch_size': 64, 'batch_size_eval': 2048, 'n_epochs': 50, 'lr': 0.001, 'weight_decay': 0, 'step_size': 10, 'step_factor': 0.1, 'early_stop_patience': 30, 'seed': 42, 'subset_length': None, 'dim_embedding': 256, 'dim_hidden': 256, 'dropout': 0.2, 'encoder_mode': 'HYBRID', 'pooler_mode': 'ATT', 'num_outputs_cpl': 2, 'num_enc_layers': 2, 'num_dec_layers': 2, 'loss': 'MultiClassASLoss', 'optimizer_name': 'AdamW', 'classify': True, 'complete': False, 'freeze_classify': False, 'freeze_complete': False, 'pretrained_model_path': None, 'wandb_log': False, 'verbose': True, 'datasets': {'train': <dataset.RecipeDataset object at 0x13fb71f70>, 'valid_clf': <dataset.RecipeDataset object at 0x107b62850>, 'valid_cpl': <dataset.RecipeDataset object at 0x107b8ddc0>, 'test_clf': <dataset.RecipeDataset object at 0x12ba226d0>, 'test_cpl': <dataset.RecipeDataset object at 0x107e5ad00>}, 'data_dir': './Container'}
{'train': 23547, 'valid_clf': 7848, 'valid_cpl': 7848, 'test_clf': 3924, 'test_cpl': 3924}
de

  0%|          | 0/50 [00:00<?, ?it/s]

-----Training the model-----

Epoch 1/50
    *label_clf [ 0  9 13 13 16  3 12 16  7  5 13  9]
    *preds_clf [22 22 26 40 49 11 22 22 22 41 60 60]
      0% |  Loss_clf 4.9980 | Acc_clf 0.0000
    *label_clf [ 6 12  9  9 18 11  2  9  9  7 13  6]
    *preds_clf [16  5  9  9  3 11 16 16  9  7 13  9]
     27% |  Loss_clf 2.6049 | Acc_clf 0.3906
    *label_clf [11 13 12  2 18  9 13  5 13 16  0 16]
    *preds_clf [11  9  3  2 18  9 13 16 13 16 16 16]
     54% |  Loss_clf 2.1237 | Acc_clf 0.5625
    *label_clf [ 5  5  9  3  3  3  6  6  9 13  5  9]
    *preds_clf [ 9  9  9  7  3  3  9  9  9 13  9  9]
     81% |  Loss_clf 2.2157 | Acc_clf 0.4844
label [ 6  4  7  7 10 17  9 13  9  9]
preds [ 6 10  0  7 16  2  9 13  9  9]
label [11  5  4  0  9 16  9  4 11 19]
preds [11 14 16 16  5  2  9  9 11  7]


  2%|▏         | 1/50 [00:31<25:29, 31.21s/it]

TRAIN_CLF Loss 1.8301 Acc 0.6147 Topk 0.8966 F1macro 0.3851 F1micro 0.6147
VALID_CLF Loss 1.9078 Acc 0.5892 Topk 0.8835 F1macro 0.3659 F1micro 0.5892

Epoch 2/50
    *label_clf [16  5 11 13 13 13  7  5  7  7 11  3]
    *preds_clf [16  5 11  9 13 13  7  5  7  7 11  3]
      0% |  Loss_clf 1.9252 | Acc_clf 0.6562
    *label_clf [11  8 18  9  8 13  9 15 16 18  9 13]
    *preds_clf [12  5 18  6 16 13  9  5 16  7  9 13]
     27% |  Loss_clf 2.0458 | Acc_clf 0.4531
    *label_clf [ 9 16  9  7 16 13  9  2  7  2 14  1]
    *preds_clf [ 9 16  9  7  2 13  9  5 10  9  6 16]
     54% |  Loss_clf 1.6258 | Acc_clf 0.6406
    *label_clf [13 13 13 16  9 12 11  9  9 13  9  6]
    *preds_clf [13 13 13  1  9  3 11  7  9 13 14  6]
     81% |  Loss_clf 1.6626 | Acc_clf 0.6562
label [ 6  4  7  7 10 17  9 13  9  9]
preds [ 6  3 13  7 16 16  9 13  9  9]
label [11  5  4  0  9 16  9  4 11 19]
preds [11 13 16  2  9  2  6 17 11 16]


  4%|▍         | 2/50 [01:03<25:41, 32.12s/it]

TRAIN_CLF Loss 1.6367 Acc 0.6724 Topk 0.9223 F1macro 0.4822 F1micro 0.6724
VALID_CLF Loss 1.7656 Acc 0.6407 Topk 0.9027 F1macro 0.4505 F1micro 0.6407

Epoch 3/50
    *label_clf [13  9 16 13  5  8 13  1 12  5  5  7]
    *preds_clf [13  9 16 16 16  5 13 16 12  9 13  9]
      0% |  Loss_clf 1.6010 | Acc_clf 0.7031
    *label_clf [16 13  9  7  3  9  7 14 18 17 16  3]
    *preds_clf [16 13  9 16  3  9  7  6 18  9  2  3]
     27% |  Loss_clf 1.6506 | Acc_clf 0.6719
    *label_clf [ 0  8  5  6  9 14 13 10  1 11 13  1]
    *preds_clf [17 16  5  9  9 14 13 13  5  7 13 16]
     54% |  Loss_clf 1.6395 | Acc_clf 0.6094
    *label_clf [ 2  5 13  6 11 13 13  9 13  3  7 13]
    *preds_clf [ 2  5 13  6 11 13 13  9 13 19  7 13]
     81% |  Loss_clf 1.7681 | Acc_clf 0.5938
label [ 6  4  7  7 10 17  9 13  9  9]
preds [ 6  4 13  7  1 13  9 13  9 17]
label [11  5  4  0  9 16  9  4 11 19]
preds [11 16  1  2  9  2  9 17 11 13]


  6%|▌         | 3/50 [01:33<24:23, 31.13s/it]

TRAIN_CLF Loss 1.4988 Acc 0.7168 Topk 0.9402 F1macro 0.5665 F1micro 0.7168
VALID_CLF Loss 1.6699 Acc 0.6673 Topk 0.9131 F1macro 0.5034 F1micro 0.6673

Epoch 4/50
    *label_clf [ 9  9  9 13 13  3  9  8  9  3  9 17]
    *preds_clf [ 9  9  9 13 13  3  9  9  9  3  5 13]
      0% |  Loss_clf 1.4573 | Acc_clf 0.6875
    *label_clf [13 15 15 16  8  7  9 16  9 13  7 13]
    *preds_clf [13 19  8 16  1  7  9  0  9 13  7 13]
     27% |  Loss_clf 1.4379 | Acc_clf 0.7500
    *label_clf [12  3 14  9  7  8 18  9 14  2  3 11]
    *preds_clf [12  3 14  6  7  8 18  5 14  9  3 11]
     54% |  Loss_clf 1.3949 | Acc_clf 0.7344
    *label_clf [16 13 16 13  7 13 13  8  3  3 16 13]
    *preds_clf [16 13 16 13  7 13 13  8 13  3 16 13]
     81% |  Loss_clf 1.5752 | Acc_clf 0.7500
label [ 6  4  7  7 10 17  9 13  9  9]
preds [ 6  4  7  7  7 13  9 13  9  9]
label [11  5  4  0  9 16  9  4 11 19]
preds [11 13 16 16  9  2  9 17 11 13]


  8%|▊         | 4/50 [02:03<23:23, 30.50s/it]

TRAIN_CLF Loss 1.3673 Acc 0.7562 Topk 0.9558 F1macro 0.6277 F1micro 0.7562
VALID_CLF Loss 1.5892 Acc 0.6909 Topk 0.9224 F1macro 0.5520 F1micro 0.6909

Epoch 5/50
    *label_clf [ 9  9  9 11  5  6  6 18  9  9 16  2]
    *preds_clf [ 9  9  9  7  5  6  9 18  6  9 16 16]
      0% |  Loss_clf 1.5997 | Acc_clf 0.6406
    *label_clf [ 9  9  9  9 18  9  6  9 13  9  5 19]
    *preds_clf [ 9  9  9  9 18  9  6  9 13  9  5 19]
     27% |  Loss_clf 1.3410 | Acc_clf 0.7344
    *label_clf [ 5  2  9  4 11  4 16  5  3  2 16  0]
    *preds_clf [ 1  2  9  3 11 18 16  5  3  2 15 17]
     54% |  Loss_clf 1.3886 | Acc_clf 0.7031
    *label_clf [13 14 16  9  3  7  7  9  9  5 12 13]
    *preds_clf [13 14 16  9  3  7 14  9  5  5 12 13]
     81% |  Loss_clf 1.3509 | Acc_clf 0.8125
label [ 6  4  7  7 10 17  9 13  9  9]
preds [ 6  4  7  7  7  0  9 13  9  9]
label [11  5  4  0  9 16  9  4 11 19]
preds [11  1 15  4  9  2  9  4 11  9]


 10%|█         | 5/50 [02:34<22:53, 30.52s/it]

TRAIN_CLF Loss 1.3265 Acc 0.7630 Topk 0.9616 F1macro 0.6366 F1micro 0.7630
VALID_CLF Loss 1.5889 Acc 0.6922 Topk 0.9241 F1macro 0.5452 F1micro 0.6922

Epoch 6/50
    *label_clf [ 9  3  9 16 17  4  9  9  0 13 11 11]
    *preds_clf [ 9  3  9 16 17 18  9  9 13 13 11  3]
      0% |  Loss_clf 1.3034 | Acc_clf 0.7812
    *label_clf [18 16 16 13  7  7 12 18 16  8 17 11]
    *preds_clf [18 16 16 13  7 13  4  0 16  8 17  3]
     27% |  Loss_clf 1.2977 | Acc_clf 0.7344
    *label_clf [ 9 13  2  7 16  9 16 13 13 16  3 16]
    *preds_clf [ 9 13  2  7 16  9  9 13 13 16  3 16]
     54% |  Loss_clf 1.2639 | Acc_clf 0.7344
    *label_clf [ 9  6  6  9 13 15 13  9  9  1 15  7]
    *preds_clf [13  6  6  9 13 13 13  9  9  1  8  7]
     81% |  Loss_clf 1.0911 | Acc_clf 0.8438
label [ 6  4  7  7 10 17  9 13  9  9]
preds [ 6  4  7  7  1 13  9 13  9  9]
label [11  5  4  0  9 16  9  4 11 19]
preds [11 16 16  2  9  2  9  9 11  9]


 12%|█▏        | 6/50 [03:03<22:05, 30.12s/it]

TRAIN_CLF Loss 1.2460 Acc 0.7931 Topk 0.9690 F1macro 0.6843 F1micro 0.7931
VALID_CLF Loss 1.5626 Acc 0.7077 Topk 0.9269 F1macro 0.5776 F1micro 0.7077

Epoch 7/50
    *label_clf [ 3  5  9 14  3  7  9 13  3  5 16 11]
    *preds_clf [ 4  5  5  6  3  7  9 13 11  5 16  3]
      0% |  Loss_clf 1.3803 | Acc_clf 0.7656
    *label_clf [ 2  7  2  2 13 13  9  2 12 13 17  5]
    *preds_clf [16  7  9 16 13 13  9  2 12 13 17  5]
     27% |  Loss_clf 1.3899 | Acc_clf 0.7969
    *label_clf [ 9 17  9  7 13  9  2  9 10  9  9  3]
    *preds_clf [ 9  5  9 18 13  9 16  9 18  9  9  3]
     54% |  Loss_clf 1.6048 | Acc_clf 0.7031
    *label_clf [ 4 16  9  7  3  7  9 18  9  3 16 14]
    *preds_clf [10 16  9  7  3  7  9 18  9  3 16  6]
     81% |  Loss_clf 1.1891 | Acc_clf 0.8438
label [ 6  4  7  7 10 17  9 13  9  9]
preds [ 6  4 19  7 10 17  9 13  9  9]
label [11  5  4  0  9 16  9  4 11 19]
preds [11  1 16  4  9  2  6 17 11 13]


 14%|█▍        | 7/50 [03:32<21:15, 29.67s/it]

TRAIN_CLF Loss 1.2136 Acc 0.8063 Topk 0.9716 F1macro 0.7131 F1micro 0.8063
VALID_CLF Loss 1.5473 Acc 0.7041 Topk 0.9293 F1macro 0.5783 F1micro 0.7041

Epoch 8/50
    *label_clf [18 11  9  7 16 16  7 18  7  9  0  9]
    *preds_clf [18 11  9 14 16 16  7 18  7  9 16  9]
      0% |  Loss_clf 1.4681 | Acc_clf 0.7656
    *label_clf [ 2  4  6 19 12 16  9  5  5  5  5  5]
    *preds_clf [10  4  9 19 12 16  9  5  5  1  9 17]
     27% |  Loss_clf 1.3916 | Acc_clf 0.7656
    *label_clf [ 9 10 13 18 16  2  9 18  4 13 16  7]
    *preds_clf [ 9 10 13 13 16  2  9 19 18 13  3  7]
     54% |  Loss_clf 1.5095 | Acc_clf 0.7031
    *label_clf [13  7 16 11 15 16  5  9  5  9  5  9]
    *preds_clf [13  7 16 11  9 16  9 16  5  9  1  9]
     81% |  Loss_clf 1.0796 | Acc_clf 0.8594
label [ 6  4  7  7 10 17  9 13  9  9]
preds [ 6  4 13  7  8 13  9 13  9  9]


In [None]:
# Completion-only run. wandb_log is not overridden here, so it keeps the
# Arguments default of True.
run(
    classify=False,
    complete=True,
    batch_size=64,
    encoder_mode='HYBRID',
    pooler_mode='ATT',
    dropout=0.2,
)

In [None]:
# Joint run with both the classification and completion tasks enabled.
# wandb_log is not overridden, so it keeps the Arguments default of True.
run(
    classify=True,
    complete=True,
    batch_size=64,
    encoder_mode='HYBRID',
    pooler_mode='ATT',
    dropout=0.2,
)