# Tune NN

## Set Up

Import packages/globals

In [None]:
import sys
import os
import subprocess

sys.path.append("../")
from src.config import SEED, BASE_PATH
from src.data_utils import get_data
# Show the project root that all paths in this notebook are built from.
print(f"Path: {BASE_PATH}")
# Tuning stage selector. The tuning cell accepts 1 (300 trials) or 2 (150 trials),
# raises ValueError for anything else, and forwards the value to tune_nn.py
# through the --stage argument.
NN_TUNE_STAGE =1

## Build + Tune models

In [None]:
env = os.environ.copy()
env["PYTHONPATH"] = str(BASE_PATH)
script_path = BASE_PATH / "src" / "tune_nn.py"
match NN_TUNE_STAGE:
    case 1:
        n_trials = 300
    case 2:
        n_trials = 150
    case _:
        raise ValueError(
            f"STAGE in src/config.py must be one of [1,2]. Got {NN_TUNE_STAGE} instead."
        )
cmds = [
    [
        "uv",
        "run",
        str(script_path),
        "--X_path",
        str(BASE_PATH / "data" / "processed" / "base" / "X_train.parquet"),
        "--y_path",
        str(BASE_PATH / "data" / "processed"  / "base"/ "y_train.xlsx"),
        "--scoring_str",
        "roc_auc",
        "--log_path",
        str(BASE_PATH / "logs" / "phase_1"/ "nn.log"),
        "--results_path",
        str(
            BASE_PATH
            / "tune_results"
            / "phase_1"
            / "nn.json"
        ),
        "--n_trials",
        str(n_trials),
        "--seed",
        str(SEED),
        "--stage",
        str(NN_TUNE_STAGE)
    ]]

In [None]:
# Start one background process per command, using the prepared environment.
# Popen returns immediately, so the notebook stays usable while tuning runs.
procs = [subprocess.Popen(cmd, env=env) for cmd in cmds]
# Display the current status of each process: None while it is still running,
# otherwise its exit code.
[p.poll() for p in procs]

Monitor commands
- `None` means the process is still running; otherwise the exit code is shown
    - 0 indicates a successful exit
- Can be run as many times as you want

In [None]:
# Non-blocking status check: None for a running process, else its exit code.
[p.poll() for p in procs]

Kill processes
- Only run once
- Should output a non-zero exit code for each process
    - If not, run the monitoring cell above again

In [None]:
# Ask every launched process to stop.
for p in procs:
    p.terminate()
# terminate() does not wait for the process to exit, so poll() may still return
# None here if a process has not finished shutting down yet.
[p.poll() for p in procs]

## Train models + prelim results

In [None]:
from src.tune_nn import train_and_prelim_eval

# Input: the JSON file at the path the tuning command receives as --results_path.
tuned_results_file = BASE_PATH / "tune_results" / "phase_1" / "nn.json"
# Output: location passed as the model save path.
trained_model_file = BASE_PATH / "models" / "phase_1_trained" / "nn.pt"

data_dict = get_data(is_nomo=False)
train_and_prelim_eval(
    data_dict=data_dict,
    json_path=tuned_results_file,
    model_save_path=trained_model_file,
)