In [1]:
import argparse
import functools
import logging
import os
import sys
from typing import Sequence

import joblib
# add directory up to path to get main naming script
from optuna.pruners import MedianPruner
import copy
import time
import warnings
from typing import Dict, Any, Callable, Optional, Tuple

import numpy as np
import tensorflow as tf
import optuna
from optuna import Trial
from optuna.integration import TFKerasPruningCallback
from tensorflow.python.keras.callbacks import History

from dataclasses import dataclass, field, asdict


2024-10-02 08:33:57.996568: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0


In [8]:
DEFAULT_NCP_SEED = 22222  # fallback value of NCPParams.seed

# Model input image shape; presumably (height, width, channels) -- TODO confirm against the data loader.
IMAGE_SHAPE = (144, 256, 3)
# First two dimensions of IMAGE_SHAPE in swapped order, i.e. (256, 144); presumably the
# (width, height) size convention used by OpenCV.
IMAGE_SHAPE_CV = IMAGE_SHAPE[1::-1]

# helper classes that contain all the parameters in the generate_*_model functions
@dataclass
class ModelParams:
    """
    Parameters shared by the generate_*_model functions.

    Every field carries a default because a dataclass rejects a non-default field that follows a
    defaulted one.
    """

    # Sequence length. The False default is only a placeholder so the field has a default;
    # callers are expected to pass a real length (COMMON_MODEL_PARAMS uses 64).
    seq_len: int = False
    image_shape: Tuple[int, int, int] = IMAGE_SHAPE
    augmentation_params: Optional[Dict] = None
    batch_size: Optional[int] = None
    single_step: bool = False
    no_norm_layer: bool = False

@dataclass
class NCPParams(ModelParams):
    """ModelParams extended with the seed passed to the NCP model."""

    seed: int = field(default=DEFAULT_NCP_SEED)

# Training defaults. ncp_objective deep-copies this dict and overlays the caller's train_kwargs on it.
COMMON_TRAIN_PARAMS = dict(
    epochs=100,
    val_split=0.05,
    opt="adam",
    data_shift=16,
    data_stride=1,
    cached_data_dir="cached_data",
    save_period=20,
)
# Model defaults. ncp_objective expands these into NCPParams keyword arguments.
COMMON_MODEL_PARAMS = dict(
    seq_len=64,
    single_step=False,
    no_norm_layer=False,
    augmentation_params=dict(
        noise=0.05,
        sequence_params=dict(
            brightness=0.4,
            contrast=0.4,
            saturation=0.4,
        ),
    ),
)

In [9]:
def sum_val_train_loss(logs: Dict[str, float]) -> Optional[float]:
    """
    Pruning objective: the epoch's training loss plus its validation loss.

    @param logs: Keras epoch logs dict
    @return: logs["loss"] + logs["val_loss"], or None when either key is missing (for example an epoch
        without validation). KerasPruningCallbackFunction treats None as "warn and skip this epoch"
        instead of crashing the trial.
    """
    train_loss = logs.get("loss")
    val_loss = logs.get("val_loss")
    if train_loss is None or val_loss is None:
        return None
    return train_loss + val_loss

In [10]:
class KerasPruningCallbackFunction(TFKerasPruningCallback):
    """
    Optuna pruning callback for Keras that reports an arbitrary function of the epoch logs.

    The stock TFKerasPruningCallback monitors a single named metric. This subclass reports
    ``get_objective(logs)`` to the trial at the end of every epoch instead. It prunes the trial when the
    trial's pruner asks for it, or when the objective is NaN.
    """

    def __init__(self, trial: optuna.trial.Trial, get_objective: Callable) -> None:
        """
        @param trial: optuna trial that receives the intermediate objective values
        @param get_objective: maps the Keras epoch logs dict to a number, or to None when the objective
            cannot be computed from those logs (that epoch is then skipped with a warning)
        """
        # the parent's single-metric monitor is unused; the reported value comes from get_objective
        super().__init__(trial, "")
        self.get_objective = get_objective

    # adapted from on_epoch_end in optuna/integration/keras.py
    def on_epoch_end(self, epoch: int, logs: Optional[Dict[str, float]] = None) -> None:
        """
        Reports this epoch's objective to the trial and prunes if required.

        @param epoch: index of the epoch that just finished, used as the report step
        @param logs: Keras epoch logs; treated as empty when None
        @raise optuna.TrialPruned: if the objective is NaN or the trial's pruner says to stop
        """
        logs = logs or {}
        current_score = self.get_objective(logs)
        if current_score is None:
            # self._monitor is always "" for this subclass, so name the objective function instead
            objective_name = getattr(self.get_objective, "__name__", repr(self.get_objective))
            message = (
                f"The objective function '{objective_name}' returned None for the evaluation logs "
                f"(available keys: {sorted(logs)}), so nothing was reported for pruning at epoch {epoch}."
            )
            warnings.warn(message)
            return
        # logging a NaN objective leads to a crash, so prune the trial instead
        if np.isnan(current_score):
            message = f"Trial was pruned at epoch {epoch} because objective value was NaN"
            raise optuna.TrialPruned(message)

        self._trial.report(float(current_score), step=epoch)
        if self._trial.should_prune():
            message = f"Trial was pruned at epoch {epoch}."
            raise optuna.TrialPruned(message)

In [11]:
def calculate_objective(trial: Trial, result: Tuple[History, str]):
    """
    Calculates the objective value from the training history and logs loss summaries as trial user attrs.

    The objective is the smallest per-epoch sum of training loss and validation loss. The trial also
    records the train/val losses at that epoch and the best epochs for each loss taken individually.

    @param trial: optuna trial
    @param result: (history, time_str) tuple returned by the trainer. ``history.history`` holds the
        per-epoch "loss" and "val_loss" lists; ``time_str`` is stored as the "checkpoint_time_str" attr.
    @return: minimum over epochs of loss + val_loss, as a Python float
    @raise ValueError: if the history has no epoch with both a "loss" and a "val_loss" value
    """
    history, time_str = result
    trial.set_user_attr("checkpoint_time_str", time_str)

    # One row per epoch: column 0 = train loss, column 1 = val loss. zip pairs the lists by index and
    # stops at the shorter one. reshape keeps the array 2-D even when it is empty.
    losses = np.array(
        list(zip(history.history["loss"], history.history["val_loss"])), dtype=float
    ).reshape(-1, 2)
    if losses.shape[0] == 0:
        raise ValueError(
            "Cannot compute objective: training history has no epochs with both 'loss' and 'val_loss'"
        )

    loss_sums = losses.sum(axis=1)
    best_epoch = int(np.argmin(loss_sums))
    # float() stores plain Python floats rather than numpy scalars
    trial.set_user_attr("sum_train_loss", float(losses[best_epoch, 0]))
    trial.set_user_attr("sum_val_loss", float(losses[best_epoch, 1]))
    trial.set_user_attr("best_sum_epoch", best_epoch)

    # best epochs for the train and val losses taken individually
    best_train = int(np.argmin(losses[:, 0]))
    trial.set_user_attr("best_train_epoch", best_train)
    trial.set_user_attr("best_train_loss", float(losses[best_train, 0]))
    best_val = int(np.argmin(losses[:, 1]))
    trial.set_user_attr("best_val_epoch", best_val)
    trial.set_user_attr("best_val_loss", float(losses[best_val, 1]))

    # wall-clock timestamp (time.time()) at which the objective was computed -- not a duration
    trial.set_user_attr("trial_time", time.time())

    return float(loss_sums[best_epoch])

In [6]:
# optuna objective functions
def ncp_objective(trial: Trial, data_dir: str, batch_size: int, **train_kwargs: Any):
    """
    Optuna objective that trains an NCP model with trial-suggested hyperparameters.

    Samples the NCP seed, learning rate and decay rate from the trial. Trains with COMMON_MODEL_PARAMS and
    COMMON_TRAIN_PARAMS (train_kwargs override the latter), then scores the run with calculate_objective.

    @param trial: optuna trial that suggests the hyperparameters
    @param data_dir: dataset directory forwarded to train_model
    @param batch_size: batch size forwarded to train_model
    @param train_kwargs: extra train_model keyword arguments; they override COMMON_TRAIN_PARAMS entries
    @return: objective value computed by calculate_objective
    """
    # sample hyperparameters from the trial
    seeds_to_try = list(range(22221, 22230)) + [55555]
    ncp_seed = trial.suggest_categorical("ncp_seed", seeds_to_try)

    lr = trial.suggest_float("lr", low=1e-5, high=1e-2, log=True)
    decay_rate = trial.suggest_float("decay_rate", 0.85, 1)

    prune_callback = [KerasPruningCallbackFunction(trial, sum_val_train_loss)]

    # Deep-copy the shared defaults, as is already done for COMMON_TRAIN_PARAMS. Otherwise every trial's
    # model_params would alias the same nested augmentation dicts in COMMON_MODEL_PARAMS.
    model_params = NCPParams(seed=ncp_seed, **copy.deepcopy(COMMON_MODEL_PARAMS))
    # record the params before training so trials that get pruned or fail still carry them
    trial.set_user_attr("model_params", repr(model_params))

    merged_kwargs = copy.deepcopy(COMMON_TRAIN_PARAMS)
    merged_kwargs.update(train_kwargs)
    # train_model returns a (history, time_str) tuple, which calculate_objective unpacks
    result = train_model(lr=lr, decay_rate=decay_rate, callbacks=prune_callback,
                         model_params=model_params, data_dir=data_dir, batch_size=batch_size, **merged_kwargs)

    return calculate_objective(trial, result)

In [7]:
def optimize_hyperparameters(obj_fn: Callable, data_dir: str, study_name: str, n_trials: int, batch_size: int,
                             timeout: Optional[float] = None, storage_name: str = "sqlite:///hyperparam_tuning.db",
                             save_pkl: bool = False, train_kwargs: Optional[Dict[str, Any]] = None):
    """
    Runs, or resumes, an Optuna study that minimizes obj_fn until the study holds n_trials trials.

    The study is named ``f"{study_name}{obj_fn.__name__}"`` and uses a MedianPruner.
    - With ``save_pkl=False`` the study is stored in, or resumed from, the Optuna storage URL ``storage_name``.
    - With ``save_pkl=True`` the study lives in memory and is pickled with joblib to ``storage_name``
      relative to SCRIPT_DIR (the working directory if SCRIPT_DIR is undefined). It is resumed from that
      same file if the file exists.

    @param obj_fn: objective, called as obj_fn(trial, data_dir=..., batch_size=..., **train_kwargs)
    @param data_dir: forwarded to obj_fn
    @param study_name: prefix of the study name
    @param n_trials: total number of trials the study should contain, counting trials from earlier runs
    @param batch_size: forwarded to obj_fn
    @param timeout: optional time limit in seconds passed to study.optimize
    @param storage_name: Optuna storage URL, or the pickle filename when save_pkl is True
    @param save_pkl: keep the study in a joblib pickle instead of an Optuna storage backend
    @param train_kwargs: extra keyword arguments forwarded to obj_fn
    @return: the optuna study
    """
    optuna_logger = optuna.logging.get_logger("optuna")
    # re-running this in the same kernel would otherwise stack another stdout handler and duplicate log lines
    has_stdout_handler = any(
        isinstance(handler, logging.StreamHandler) and getattr(handler, "stream", None) is sys.stdout
        for handler in optuna_logger.handlers
    )
    if not has_stdout_handler:
        optuna_logger.addHandler(logging.StreamHandler(sys.stdout))

    study_name_network = f"{study_name}{obj_fn.__name__}"
    study_params = {
        "study_name": study_name_network,
        "load_if_exists": True,
        "direction": "minimize",
        "pruner": MedianPruner(n_warmup_steps=10, n_min_trials=3),
    }

    if save_pkl:
        try:
            base_dir = SCRIPT_DIR
        except NameError:
            # SCRIPT_DIR may not be defined when running as a notebook; use the working directory instead
            base_dir = os.getcwd()
        # the same path is used for loading and saving so a saved study is actually resumed
        path_relative = os.path.join(base_dir, storage_name)
        if os.path.exists(path_relative):
            # NOTE: joblib.load unpickles the file; only load study files you created yourself
            study = joblib.load(path_relative)
        else:
            print(f"No existing study found at path {path_relative}. Creating a new one")
            study = optuna.create_study(**study_params)
    else:
        study = optuna.create_study(storage=storage_name, **study_params)

    if train_kwargs is None:
        train_kwargs = {}

    # only continue training up to n_trials trials total; never ask for a negative number of trials
    remaining_trials = max(0, n_trials - len(study.trials))

    if remaining_trials > 0:
        study.optimize(functools.partial(obj_fn, data_dir=data_dir, batch_size=batch_size, **train_kwargs),
                       n_trials=remaining_trials, timeout=timeout)

    if save_pkl:
        joblib.dump(study, path_relative)
    return study