Skip to content

Commit

Permalink
move global tmp dir creation to pytest fixture
Browse files Browse the repository at this point in the history
  • Loading branch information
StrikerRUS committed Nov 22, 2019
1 parent 3d0e4fb commit d53e3d7
Show file tree
Hide file tree
Showing 4 changed files with 21 additions and 17 deletions.
8 changes: 8 additions & 0 deletions tests/conftest.py
@@ -1,5 +1,7 @@
import pytest

from tests import utils


def pytest_addoption(parser):
parser.addoption(
Expand All @@ -11,3 +13,9 @@ def pytest_addoption(parser):
@pytest.fixture
def is_fast(request):
return request.config.getoption("--fast")


@pytest.fixture(scope="module")
def global_tmp_dir():
with utils.tmp_dir() as directory:
yield directory
8 changes: 3 additions & 5 deletions tests/e2e/executors/base.py
Expand Up @@ -6,7 +6,6 @@
class BaseExecutor:

_resource_tmp_dir = None
_global_resource_tmp_dir = None

@contextlib.contextmanager
def prepare_then_cleanup(self):
Expand All @@ -23,7 +22,6 @@ def prepare(self):
raise NotImplementedError

@classmethod
def prepare_global(cls):
if cls._global_resource_tmp_dir is None:
with utils.tmp_dir() as tmp_dirpath:
cls._global_resource_tmp_dir = tmp_dirpath
def prepare_global(cls, **kwargs):
for key, value in kwargs.items():
setattr(cls, "_{}".format(key), value)
17 changes: 7 additions & 10 deletions tests/e2e/executors/visual_basic.py
Expand Up @@ -56,17 +56,16 @@ def predict(self, X):
return utils.predict_from_commandline(exec_args)

@classmethod
def prepare_global(cls):
super(VisualBasicExecutor, cls).prepare_global()
def prepare_global(cls, **kwargs):
super(VisualBasicExecutor, cls).prepare_global(**kwargs)
if cls.target_exec_dir is None:
cls.target_exec_dir = os.path.join(cls._global_resource_tmp_dir,
"bin")
cls.target_exec_dir = os.path.join(cls._global_tmp_dir, "bin")

subprocess.call([cls._dotnet,
"new",
"console",
"--output",
cls._global_resource_tmp_dir,
cls._global_tmp_dir,
"--name",
cls.project_name,
"--language",
Expand All @@ -81,18 +80,16 @@ def prepare(self):
print_code=print_code)
model_code = self.interpreter.interpret(self.model_ast)

model_file_name = os.path.join(self._global_resource_tmp_dir,
"Model.vb")
executor_file_name = os.path.join(self._global_resource_tmp_dir,
"Program.vb")
model_file_name = os.path.join(self._global_tmp_dir, "Model.vb")
executor_file_name = os.path.join(self._global_tmp_dir, "Program.vb")
with open(model_file_name, "w") as f:
f.write(model_code)
with open(executor_file_name, "w") as f:
f.write(executor_code)

subprocess.call([self._dotnet,
"build",
os.path.join(self._global_resource_tmp_dir,
os.path.join(self._global_tmp_dir,
"{}.vbproj".format(self.project_name)),
"--output",
self.target_exec_dir])
5 changes: 3 additions & 2 deletions tests/e2e/test_e2e.py
Expand Up @@ -173,15 +173,16 @@ def classification_binary(model):
# <empty>
)
def test_e2e(estimator, executor_cls, model_trainer, is_fast):
def test_e2e(estimator, executor_cls, model_trainer,
is_fast, global_tmp_dir):
sys.setrecursionlimit(RECURSION_LIMIT)

X_test, y_pred_true = model_trainer(estimator)
executor = executor_cls(estimator)

idxs_to_test = [0] if is_fast else range(len(X_test))

executor.prepare_global()
executor.prepare_global(global_tmp_dir=global_tmp_dir)
with executor.prepare_then_cleanup():
for idx in idxs_to_test:
y_pred_executed = executor.predict(X_test[idx])
Expand Down

0 comments on commit d53e3d7

Please sign in to comment.