Skip to content

Commit

Permalink
On Windows, prevent long trial directory names (#735)
Browse files Browse the repository at this point in the history
  • Loading branch information
tg2k committed Mar 8, 2024
1 parent 7776d00 commit 906ea9f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
7 changes: 6 additions & 1 deletion nbs/common.base_auto.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -288,14 +288,19 @@
" else:\n",
" device_dict = {'cpu':cpus}\n",
"\n",
" # on Windows, prevent long trial directory names\n",
" import platform\n",
" trial_dirname_creator=(lambda trial: f\"{trial.trainable_name}_{trial.trial_id}\") if platform.system() == 'Windows' else None\n",
"\n",
" tuner = tune.Tuner(\n",
" tune.with_resources(train_fn_with_parameters, device_dict),\n",
" run_config=air.RunConfig(callbacks=self.callbacks, verbose=verbose),\n",
" tune_config=tune.TuneConfig(\n",
" metric=\"loss\",\n",
" mode=\"min\",\n",
" num_samples=num_samples, \n",
" search_alg=search_alg\n",
" search_alg=search_alg,\n",
" trial_dirname_creator=trial_dirname_creator,\n",
" ),\n",
" param_space=config,\n",
" )\n",
Expand Down
5 changes: 5 additions & 0 deletions neuralforecast/common/_base_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -241,6 +241,10 @@ def _tune_model(
else:
device_dict = {"cpu": cpus}

# on Windows, prevent long trial directory names
import platform
trial_dirname_creator=(lambda trial: f"{trial.trainable_name}_{trial.trial_id}") if platform.system() == 'Windows' else None

tuner = tune.Tuner(
tune.with_resources(train_fn_with_parameters, device_dict),
run_config=air.RunConfig(callbacks=self.callbacks, verbose=verbose),
Expand All @@ -249,6 +253,7 @@ def _tune_model(
mode="min",
num_samples=num_samples,
search_alg=search_alg,
trial_dirname_creator=trial_dirname_creator,
),
param_space=config,
)
Expand Down

0 comments on commit 906ea9f

Please sign in to comment.