Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Using pathlib to simplify is_local_source. #559

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 8 additions & 6 deletions sacred/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import os.path
import re
import sys
import pathlib
from pathlib import Path

import pkg_resources

import sacred.optional as opt
from sacred import SETTINGS
from sacred.utils import is_subdir, iter_prefixes
from sacred.utils import iter_prefixes

MB = 1048576
MODULE_BLACKLIST = set(sys.builtin_module_names)
Expand Down Expand Up @@ -260,7 +260,7 @@ def create(cls, mod):

def convert_path_to_module_parts(path):
"""Convert path to a python file into list of module names."""
module_parts = list(pathlib.Path(path).parts)
module_parts = list(path.parts)
if module_parts[-1] in ['__init__.py', '__init__.pyc']:
# remove trailing __init__.py
module_parts = module_parts[:-1]
Expand Down Expand Up @@ -294,17 +294,19 @@ def is_local_source(filename, modname, experiment_path):
True if the module was imported locally from (a subdir of) the
experiment_path, and False otherwise.
"""
if not is_subdir(filename, experiment_path):
filename = Path(os.path.abspath(os.path.realpath(filename)))
experiment_path = Path(os.path.abspath(os.path.realpath(experiment_path)))
if experiment_path not in filename.parents:
return False
rel_path = os.path.relpath(filename, experiment_path)
rel_path = filename.relative_to(experiment_path)
path_parts = convert_path_to_module_parts(rel_path)

mod_parts = modname.split('.')
if path_parts == mod_parts:
return True
if len(path_parts) > len(mod_parts):
return False
abs_path_parts = convert_path_to_module_parts(os.path.abspath(filename))
abs_path_parts = convert_path_to_module_parts(filename)
return all([p == m for p, m in zip(reversed(abs_path_parts),
reversed(mod_parts))])

Expand Down
10 changes: 1 addition & 9 deletions sacred/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import importlib
import inspect
import logging
import os.path
import pkgutil
import re
import shlex
Expand All @@ -27,7 +26,7 @@
"set_by_dotted_path", "get_by_dotted_path", "iter_path_splits",
"iter_prefixes", "join_paths", "is_prefix",
"convert_to_nested_dict", "convert_camel_case_to_snake_case",
"print_filtered_stacktrace", "is_subdir",
"print_filtered_stacktrace",
"optional_kwargs_decorator", "get_inheritors",
"apply_backspaces_and_linefeeds", "rel_path", "IntervalTimer",
"PathType"]
Expand Down Expand Up @@ -547,13 +546,6 @@ def filtered_traceback_format(tb_exception, chain=True):
yield from tb_exception.format_exception_only()


def is_subdir(path, directory):
path = os.path.abspath(os.path.realpath(path)) + os.sep
directory = os.path.abspath(os.path.realpath(directory)) + os.sep

return path.startswith(directory)


# noinspection PyUnusedLocal
@wrapt.decorator
def optional_kwargs_decorator(wrapped, instance=None, args=None, kwargs=None):
Expand Down
20 changes: 1 addition & 19 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import pytest

from sacred.utils import (PATHCHANGE, convert_to_nested_dict,
get_by_dotted_path, is_prefix, is_subdir,
get_by_dotted_path, is_prefix,
iter_path_splits, iter_prefixes, iterate_flattened,
iterate_flattened_separately, join_paths,
recursive_update, set_by_dotted_path, get_inheritors,
Expand Down Expand Up @@ -101,24 +101,6 @@ def test_convert_to_nested_dict_nested():
{'a': {'b': {'foo': {'bar': 8, 'baz': 7}}}}


@pytest.mark.parametrize('path,parent,expected', [
('/var/test2', '/var/test', False),
('/var/test', '/var/test2', False),
('var/test2', 'var/test', False),
('var/test', 'var/test2', False),
('/var/test/sub', '/var/test', True),
('/var/test', '/var/test/sub', False),
('var/test/sub', 'var/test', True),
('var/test', 'var/test', True),
('var/test', 'var/test/fake_sub/..', True),
('var/test/sub/sub2/sub3/../..', 'var/test', True),
('var/test/sub', 'var/test/fake_sub/..', True),
('var/test', 'var/test/sub', False)
])
def test_is_subdirectory(path, parent, expected):
assert is_subdir(path, parent) == expected


def test_get_inheritors():
class A:
pass
Expand Down