Skip to content

Commit

Permalink
Merge 4b948a3 into a181df5
Browse files Browse the repository at this point in the history
  • Loading branch information
gabrieldemarmiesse committed Aug 2, 2019
2 parents a181df5 + 4b948a3 commit 663da83
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 33 deletions.
14 changes: 8 additions & 6 deletions sacred/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import os.path
import re
import sys
import pathlib
from pathlib import Path

import pkg_resources

import sacred.optional as opt
from sacred import SETTINGS
from sacred.utils import is_subdir, iter_prefixes
from sacred.utils import iter_prefixes

MB = 1048576
MODULE_BLACKLIST = set(sys.builtin_module_names)
Expand Down Expand Up @@ -260,7 +260,7 @@ def create(cls, mod):

def convert_path_to_module_parts(path):
"""Convert path to a python file into list of module names."""
module_parts = list(pathlib.Path(path).parts)
module_parts = list(path.parts)
if module_parts[-1] in ['__init__.py', '__init__.pyc']:
# remove trailing __init__.py
module_parts = module_parts[:-1]
Expand Down Expand Up @@ -294,17 +294,19 @@ def is_local_source(filename, modname, experiment_path):
True if the module was imported locally from (a subdir of) the
experiment_path, and False otherwise.
"""
if not is_subdir(filename, experiment_path):
filename = Path(filename).resolve()
experiment_path = Path(experiment_path).resolve()
if experiment_path not in filename.parents:
return False
rel_path = os.path.relpath(filename, experiment_path)
rel_path = filename.relative_to(experiment_path)
path_parts = convert_path_to_module_parts(rel_path)

mod_parts = modname.split('.')
if path_parts == mod_parts:
return True
if len(path_parts) > len(mod_parts):
return False
abs_path_parts = convert_path_to_module_parts(os.path.abspath(filename))
abs_path_parts = convert_path_to_module_parts(filename)
return all([p == m for p, m in zip(reversed(abs_path_parts),
reversed(mod_parts))])

Expand Down
9 changes: 1 addition & 8 deletions sacred/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
"set_by_dotted_path", "get_by_dotted_path", "iter_path_splits",
"iter_prefixes", "join_paths", "is_prefix",
"convert_to_nested_dict", "convert_camel_case_to_snake_case",
"print_filtered_stacktrace", "is_subdir",
"print_filtered_stacktrace",
"optional_kwargs_decorator", "get_inheritors",
"apply_backspaces_and_linefeeds", "rel_path", "IntervalTimer",
"PathType"]
Expand Down Expand Up @@ -547,13 +547,6 @@ def filtered_traceback_format(tb_exception, chain=True):
yield from tb_exception.format_exception_only()


def is_subdir(path, directory):
path = os.path.abspath(os.path.realpath(path)) + os.sep
directory = os.path.abspath(os.path.realpath(directory)) + os.sep

return path.startswith(directory)


# noinspection PyUnusedLocal
@wrapt.decorator
def optional_kwargs_decorator(wrapped, instance=None, args=None, kwargs=None):
Expand Down
20 changes: 1 addition & 19 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from sacred.utils import (PATHCHANGE, convert_to_nested_dict,
get_by_dotted_path, is_prefix, is_subdir,
get_by_dotted_path, is_prefix,
iter_path_splits, iter_prefixes, iterate_flattened,
iterate_flattened_separately, join_paths,
recursive_update, set_by_dotted_path, get_inheritors,
Expand Down Expand Up @@ -101,24 +101,6 @@ def test_convert_to_nested_dict_nested():
{'a': {'b': {'foo': {'bar': 8, 'baz': 7}}}}


@pytest.mark.parametrize('path,parent,expected', [
('/var/test2', '/var/test', False),
('/var/test', '/var/test2', False),
('var/test2', 'var/test', False),
('var/test', 'var/test2', False),
('/var/test/sub', '/var/test', True),
('/var/test', '/var/test/sub', False),
('var/test/sub', 'var/test', True),
('var/test', 'var/test', True),
('var/test', 'var/test/fake_sub/..', True),
('var/test/sub/sub2/sub3/../..', 'var/test', True),
('var/test/sub', 'var/test/fake_sub/..', True),
('var/test', 'var/test/sub', False)
])
def test_is_subdirectory(path, parent, expected):
assert is_subdir(path, parent) == expected


def test_get_inheritors():
class A:
pass
Expand Down

0 comments on commit 663da83

Please sign in to comment.