Skip to content

Commit

Permalink
Ensureing cwd swicth in api.get_experiment; and exposed dataloader to…
Browse files Browse the repository at this point in the history
… the returned experiment
  • Loading branch information
ahmed-shariff committed Feb 16, 2020
1 parent 49a76a5 commit 977b52b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 19 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.org
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
- Added current_version, experiment_dir and dataloader to be experiments properties
- Bug fixes:
- Moved each version of an experiment to be executed in a separate sub-process
- Ensured change to cwd when using api.get_experiment and exposed dataloader to the returned experiment
** 2.0.a.3 [2019-07-16]
- Several methods from the base classes are not required to be implemented. They are:
- Experiment.setup_model
Expand Down
42 changes: 23 additions & 19 deletions mlpipeline/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,28 @@ def mlpipeline_execute_exeperiment(file_path,


# TODO: Need to track the root of the project, or this becomes kind of ridiculous.
def get_experiment(file_path, experiment_dir, version_name, mlflow_tracking_uri=None):
def get_experiment(file_path, experiment_dir, version_name, mlflow_tracking_uri=None, load_dataloader=False):
cwd = os.getcwd()
experiment_dir = os.path.abspath(experiment_dir)
file_path = os.path.relpath(os.path.abspath(file_path), experiment_dir)
print(f"Setting root directory to: {experiment_dir}")
print(f"Loading experiment: {file_path}")
os.chdir(experiment_dir)
experiment = _load_file_as_module(file_path).EXPERIMENT
experiment.name = os.path.relpath(file_path, experiment_dir)
version_spec = experiment.versions.get_version(version_name)
experiment_dir, _ = _get_experiment_dir(experiment.name.split(".")[-2],
version_spec,
None)
# if mlflow_tracking_uri is None:
# run_id = None
# else:
# run_id = _get_mlflow_run_id(mlflow_tracking_uri, experiment, False, version_name)
experiment._current_version = version_spec
experiment._experiment_dir = None
os.chdir(cwd)
try:
experiment_dir = os.path.abspath(experiment_dir)
file_path = os.path.relpath(os.path.abspath(file_path), experiment_dir)
print(f"Setting root directory to: {experiment_dir}")
print(f"Loading experiment: {file_path}")
os.chdir(experiment_dir)
experiment = _load_file_as_module(file_path).EXPERIMENT
experiment.name = os.path.relpath(file_path, experiment_dir)
version_spec = experiment.versions.get_version(version_name)
experiment_dir, _ = _get_experiment_dir(experiment.name.split(".")[-2],
version_spec,
None)
# if mlflow_tracking_uri is None:
# run_id = None
# else:
# run_id = _get_mlflow_run_id(mlflow_tracking_uri, experiment, False, version_name)
experiment._current_version = version_spec
experiment._experiment_dir = None
if load_dataloader:
experiment._dataloader = version_spec.dataloader()
finally:
os.chdir(cwd)
return experiment # , experiment_dir, run_id

0 comments on commit 977b52b

Please sign in to comment.