Skip to content

Commit

Permalink
add exception and pretty json files
Browse files Browse the repository at this point in the history
  • Loading branch information
HamidrezaKmK committed Jun 2, 2023
1 parent 110c225 commit 96ad8fe
Showing 1 changed file with 36 additions and 6 deletions.
42 changes: 36 additions & 6 deletions dysweep/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import shutil
import os
import traceback
import inspect

SPLIT = '-'

Expand All @@ -24,7 +25,6 @@ class ResumableSweepConfig:
count: th.Optional[int] = None
resume: bool = False
run_name: th.Optional[str] = None
checkpoint_interval: int = 3600
sweep_id: th.Optional[th.Union[int, str]] = None
#
use_lightning_logger: bool = False
Expand Down Expand Up @@ -129,7 +129,8 @@ def modified_function():
name=run_name,
)
experiment_id = logger.experiment.id
sweep_config = hierarchical_config(logger.experiment.config)
sweep_config = hierarchical_config(
logger.experiment.config)
else:
import wandb
wandb.init(
Expand All @@ -143,23 +144,52 @@ def modified_function():
new_dir_name = f"{len(all_subdirs)+1}{SPLIT}{experiment_id}"

os.makedirs(checkpoint_dir / new_dir_name)

# dump a json in checkpoint_dir/run_id containing the sweep config
with open(checkpoint_dir / new_dir_name / "sweep_config.json", "w") as f:
json.dump(sweep_config, f)
json.dump(sweep_config, f, indent=4, sort_keys=True)

new_checkpoint_dir = checkpoint_dir / new_dir_name
except Exception as e:
print(traceback.format_exc())
raise e

if conf.use_lightning_logger:
# check the function signature matches
# the one we expect.
# in which there are two arguments with the first one
# named config and the second one named checkpoint_dir

# get the signature of the function
sig = inspect.signature(function)
# get the parameters of the function
params = sig.parameters
# check that the function has two parameters
if len(params) != 3 or list(params.keys())[0] != "config" or list(params.keys())[1] != "logger" or list(params.keys())[2] != "checkpoint_dir":
raise ValueError(
"the run function should have the exact following parameters in order: (config, logger, checkpoint_dir)")

ret = function(sweep_config, logger, new_checkpoint_dir)
else:
# check the function signature matches
# the one we expect.
# in which there are two arguments with the first one
# named config and the second one named checkpoint_dir

# get the signature of the function
sig = inspect.signature(function)
# get the parameters of the function
params = sig.parameters
# check that the function has two parameters
if len(params) != 2 or list(params.keys())[0] != "config" or list(params.keys())[1] != "checkpoint_dir":
raise ValueError(
"the run function should have the exact following parameters in order: (config, checkpoint_dir)")
ret = function(sweep_config, new_checkpoint_dir)

# remove the entire new_checkpoint_dir if the function has finished
# running.
shutil.copyfile(new_checkpoint_dir / "sweep_config.json",
checkpoint_dir / f"{experiment_id}-config.json")
shutil.rmtree(new_checkpoint_dir)

return ret
Expand All @@ -175,8 +205,8 @@ def modified_function():
else:
try:
sweep(conf.base_config, conf.sweep_configuration,
entity=conf.entity, project=conf.project)
entity=conf.entity, project=conf.project)
except Exception as e:
print("Exception at creation of sweep:")
print(traceback.format_exc())
raise e
raise e

0 comments on commit 96ad8fe

Please sign in to comment.