# Neural Network Construction + Tuning + Training

In [None]:
import sys
import warnings

sys.path.append("../")
from src.data_utils import get_data
from src.config import SEED, BASE_PATH
from src.tune import tune_model_nn, get_prelim_results, build_nn_estimator

import json
import optuna
from joblib import Parallel, delayed

## Set globals

In [None]:
# Per-outcome data returned by get_data, keyed by the short outcome name.
# The same short names are reused for the per-outcome log and result files.
OUTCOME_DICT = {
    short_name: get_data(dataset_key)
    for short_name, dataset_key in (
        ("surg", "outcome_surg"),
        ("bleed", "outcome_bleed"),
        ("asp", "outcome_asp"),
        ("mort", "outcome_mort"),
    )
}

# Search settings shared by every outcome's study.
SCORING = "roc_auc"
N_TRIALS = 450  # More trials to have ample search space for 2 vs 3 layers

# Output locations: tuning results (JSON) and per-outcome log files.
RESULT_PATH = BASE_PATH / "models" / "tune_results" / "nn"
LOG_DIR = BASE_PATH / "logs" / "nn"
LOG_DIR.mkdir(parents=True, exist_ok=True)

## Build + tune models

In [None]:
# Build one delayed tuning job per outcome, then run all of them in parallel.

# Create the results directory up front: save_path below points inside it and
# is presumably written by tune_model_nn at the end of a long run. This is a
# no-op when the directory already exists.
RESULT_PATH.mkdir(parents=True, exist_ok=True)

jobs = []
for outcome_name, outcome_data in OUTCOME_DICT.items():
    # Each run starts from a clean log file; warn so the overwrite is visible.
    log_path = LOG_DIR / f"{outcome_name}.log"
    if log_path.exists():
        warnings.warn(f"Over-writing logs at path: {log_path}")
        log_path.unlink()

    X_train = outcome_data["X_train"]
    # Flatten the labels to a 1-D array before handing them to the tuner.
    y_train = outcome_data["y_train"].values.ravel()

    # One independent study per outcome; the sampler is seeded with SEED and
    # the objective (SCORING) is maximized.
    study = optuna.create_study(
        study_name=f"NN_{outcome_name}_study",
        direction="maximize",
        sampler=optuna.samplers.TPESampler(seed=SEED),
        pruner=optuna.pruners.HyperbandPruner(min_resource=5, reduction_factor=3),
    )

    jobs.append(
        delayed(tune_model_nn)(
            X_train=X_train,
            y_train=y_train,
            scoring=SCORING,
            study=study,
            log_path=log_path,
            save_path=RESULT_PATH / f"{outcome_name}.json",
            n_trials=N_TRIALS,
        )
    )

# One worker per outcome so every study runs at the same time.
results = Parallel(n_jobs=len(jobs))(jobs)

## Get preliminary results

In [None]:
# For every outcome, load its saved tuning results from RESULT_PATH and pass
# them to get_prelim_results together with the NN model builder.
for outcome_name in OUTCOME_DICT:
    result_file = RESULT_PATH / f"{outcome_name}.json"
    # JSON is UTF-8 by specification; be explicit rather than rely on locale.
    with open(result_file, "r", encoding="utf-8") as f:
        dict_import = json.load(f)

    get_prelim_results(
        # Single-entry mapping: only the current outcome's results are passed.
        results_dict={outcome_name: dict_import},
        model_builder=build_nn_estimator,
        model_abrv="nn",
        outcome_dict=OUTCOME_DICT,
        model_save_dir=BASE_PATH / "models" / outcome_name,
    )