In [1]:
import os
import sys

import optuna

sys.path.insert(0, os.path.join(os.getcwd(), ".."))

from src.dataset import fetch_dataset
from src.train import run_training


In [2]:
# Number of Optuna trials to run in this session.
N_TRIALS = 50

# Storage URL handed to optuna.create_study() further down. It is read from the
# environment so that no database credentials are hardcoded in the notebook.
# When OPTUNA_STORAGE is unset this is None, which Optuna treats as in-memory
# storage (results are not persisted across kernel restarts).
STORAGE = os.environ.get("OPTUNA_STORAGE")

# Blank out every proxy variable (both lower- and upper-case spellings) for
# this process.
# NOTE(review): presumably so that model/dataset downloads or the storage
# backend are reached directly rather than through a proxy - confirm.
os.environ["http_proxy"] = ""
os.environ["https_proxy"] = ""
os.environ["HTTP_PROXY"] = ""
os.environ["HTTPS_PROXY"] = ""


In [None]:
def objective(trial: "optuna.trial.Trial"):
    """Sample one hyperparameter configuration from ``trial`` and train with it.

    The sampled values cover the pretrained checkpoint, the data split seed and
    batch size, the number of epochs, the optimizer settings (learning rate,
    amsgrad, weight decay, eps, betas) and the LR scheduler step size.

    Args:
        trial: The Optuna trial used to draw hyperparameters.

    Returns:
        The second element returned by ``run_training`` (``best_metrics``),
        which Optuna uses as the objective value of this trial.

    Note:
        ``trial`` is not forwarded to ``run_training`` and this function never
        calls ``trial.report``, so no intermediate values are reported from
        here for a pruner to act on.
    """
    # --- model / data ------------------------------------------------------
    checkpoint_name = trial.suggest_categorical(
        "checkpoint_name",
        (
            "bert-base-uncased",
            "distilbert-base-uncased",
            "unitary/toxic-bert",
            "roberta-base",
        ),
    )
    batch_size = trial.suggest_int("batch_size", low=6, high=10)
    epochs = trial.suggest_int("epochs", low=8, high=24)
    # The seed is itself a searched parameter; it is passed to fetch_dataset.
    random_seed = trial.suggest_int("random_seed", low=0, high=1000_000)

    # --- optimizer ---------------------------------------------------------
    # suggest_float(..., log=True) replaces the deprecated suggest_loguniform,
    # and plain suggest_float replaces the deprecated suggest_uniform.
    lr = trial.suggest_float("lr", 1e-6, 1e-4, log=True)
    amsgrad = trial.suggest_categorical("amsgrad", (True, False))
    weight_decay = trial.suggest_float("weight_decay", 1e-4, 0.1, log=True)
    eps = trial.suggest_float("eps", 1e-10, 1e-6, log=True)
    beta_1 = trial.suggest_float("beta_1", 0.0, 1.0)
    beta_2 = trial.suggest_float("beta_2", 0.0, 1.0)

    # --- LR schedule -------------------------------------------------------
    lr_step_size_factor = trial.suggest_int("lr_step_size_factor", low=3, high=6)

    print(f"Epoch: {epochs}, Batch: {batch_size}")

    # Third return value of fetch_dataset is not needed for the search.
    train_dataloader, valid_dataloader, _ = fetch_dataset(
        random_seed=random_seed, batch_size=batch_size
    )

    _, best_metrics = run_training(
        train_dataloader=train_dataloader,
        valid_dataloader=valid_dataloader,
        checkpoint_name=checkpoint_name,
        epochs=epochs,
        optimizer_parameters={
            "lr": lr,
            "amsgrad": amsgrad,
            "weight_decay": weight_decay,
            "eps": eps,
            "betas": [beta_1, beta_2],
        },
        # With epochs >= 8 and factor <= 6 the step size is always >= 1.
        lr_step_parameters={"step_size": epochs // lr_step_size_factor, "gamma": 0.1},
        logging_interval=1500,
    )

    return best_metrics


In [None]:
# Search strategy: Hyperband pruning combined with a multivariate TPE sampler.
# NOTE(review): the objective defined in this notebook does not report
# intermediate values, so the pruner may have nothing to act on - confirm.
hyperband_pruner = optuna.pruners.HyperbandPruner()
tpe_sampler = optuna.samplers.TPESampler(multivariate=True)

# Create the named study, or attach to it if it already exists in STORAGE
# (load_if_exists=True). The objective value is minimized.
# NOTE(review): presumably the value returned by `objective` is a loss-like
# metric where lower is better - verify against run_training.
study = optuna.create_study(
    storage=STORAGE,
    study_name="toxic-comment-classification",
    load_if_exists=True,
    direction="minimize",
    sampler=tpe_sampler,
    pruner=hyperband_pruner,
)

# Run N_TRIALS trials, forcing a garbage-collection pass after each one.
study.optimize(objective, n_trials=N_TRIALS, gc_after_trial=True)
