In [1]:
import os
from run import main
from dataset import RecipeDataset
import wandb  # if it raises an error, reinstall wandb
import warnings
warnings.filterwarnings(action='ignore')

# Build one RecipeDataset per split. Loading takes some time, so it is done once
# here; the resulting dict is what `Args` uses as its default `datasets`.
dataset_names = ['train', 'valid_clf', 'valid_cpl', 'test_clf', 'test_cpl']
recipe_datasets = dict(
    (name, RecipeDataset(os.path.join('./Container', name)))
    for name in dataset_names
)

In [2]:
class Args:
    """Attribute bag holding every hyperparameter that `main` receives.

    Construction starts from the default table below; each keyword argument
    overrides the matching default or, if it is not in the table, adds a new
    attribute. Every entry ends up as a plain instance attribute.
    """

    def __init__(self, **kwargs):
        settings = dict(
            data_dir='./Container',
            batch_size=64,
            batch_size_eval=2048,
            n_epochs=200,
            lr=2e-4,
            weight_decay=0,
            step_size=10,               # lr_scheduler
            step_factor=0.25,           # lr_scheduler
            early_stop_patience=20,     # early stop
            seed=42,
            subset_length=None,         # Default: None
            dim_embedding=256,
            dim_hidden=256,
            num_inds=10,
            dropout=0,
            encoder_mode='HYBRID',      # 'FC' 'ISA' 'SA' 'HYBRID' 'HYBRID_SA'
            pooler_mode='PMA',          # 'SumPool' 'PMA'
            cpl_scheme='encoded',       # 'pooled', 'encoded'
            num_enc_layers=4,
            num_dec_layers=0,
            loss='CrossEntropyLoss',    # 'CrossEntropyLoss' 'MultiClassFocalLoss' 'MultiClassASLoss'
            optimizer_name='AdamW',
            classify=True,
            complete=True,
            freeze_classify=False,
            freeze_complete=False,
            freeze_encoder=False,
            pretrained_model_path=None,
            wandb_log=True,
            verbose=True,
            # Reads the notebook-level `recipe_datasets` every time an Args is
            # built, so that global must exist even when `datasets` is overridden.
            datasets=recipe_datasets,
            # If you have multiple GPUs, you can change this number
            # (e.g. gpu=3: "device = 'cuda:3'.")
            gpu=0,
        )
        # Caller-supplied values win over the defaults above.
        settings.update(kwargs)
        self.update(**settings)

    def update(self, **kwargs):
        """Set one instance attribute per keyword argument, overwriting existing values."""
        for name, value in kwargs.items():
            setattr(self, name, value)

def run(**kwargs):
    """Train/evaluate one configuration.

    Builds an `Args` from the defaults with `kwargs` applied as overrides and
    hands it to `main`. Returns None.
    """
    main(Args(**kwargs))

In [3]:
# Settings shared by the experiment cells below. num_enc_layers=8 is deeper than
# the Args default of 4.
# NOTE(review): a value here only takes effect if the cell passes it to `run`
# explicitly; otherwise the corresponding `Args` default is used.
wandb_log, encoder_mode, pooler_mode, num_enc_layers = False, 'HYBRID', 'PMA', 8

## Scheme (a) (`cpl_scheme=='pooled'`)

### Classification Only (Poor result)

In [None]:
# Scheme (a), cpl_scheme='pooled': classification objective only.
# pooler_mode is forwarded so the config-cell setting is honored instead of
# silently falling back to the Args default.
run(classify=True, complete=False, encoder_mode=encoder_mode, pooler_mode=pooler_mode,
    cpl_scheme='pooled', num_enc_layers=num_enc_layers, wandb_log=wandb_log)

### Completion Only (Poor result)

In [None]:
# Scheme (a), cpl_scheme='pooled': completion objective only.
# pooler_mode is forwarded so the config-cell setting is honored instead of
# silently falling back to the Args default.
run(classify=False, complete=True, encoder_mode=encoder_mode, pooler_mode=pooler_mode,
    cpl_scheme='pooled', num_enc_layers=num_enc_layers, wandb_log=wandb_log)

### Classification + Completion (Good for classification)

In [None]:
# Scheme (a), cpl_scheme='pooled': classification and completion together.
# pooler_mode is forwarded so the config-cell setting is honored instead of
# silently falling back to the Args default.
run(classify=True, complete=True, encoder_mode=encoder_mode, pooler_mode=pooler_mode,
    cpl_scheme='pooled', num_enc_layers=num_enc_layers, wandb_log=wandb_log)

## Scheme (b) (`cpl_scheme=='encoded'`)

### Classification Only (Poor result)

In [None]:
# Scheme (b), cpl_scheme='encoded': classification objective only.
# pooler_mode is forwarded so the config-cell setting is honored instead of
# silently falling back to the Args default.
run(classify=True, complete=False, encoder_mode=encoder_mode, pooler_mode=pooler_mode,
    cpl_scheme='encoded', num_enc_layers=num_enc_layers, wandb_log=wandb_log)

### Completion Only (Pretty good)

In [None]:
# Scheme (b), cpl_scheme='encoded': completion objective only.
# pooler_mode is forwarded so the config-cell setting is honored instead of
# silently falling back to the Args default.
run(classify=False, complete=True, encoder_mode=encoder_mode, pooler_mode=pooler_mode,
    cpl_scheme='encoded', num_enc_layers=num_enc_layers, wandb_log=wandb_log)

### Classification + Completion (Great for classification)

In [None]:
# Scheme (b), cpl_scheme='encoded': classification and completion together.
# pooler_mode is forwarded so the config-cell setting is honored instead of
# silently falling back to the Args default.
run(classify=True, complete=True, encoder_mode=encoder_mode, pooler_mode=pooler_mode,
    cpl_scheme='encoded', num_enc_layers=num_enc_layers, wandb_log=wandb_log)