Skip to content

Commit

Permalink
Add requirements file to the output folder (#280)
Browse files Browse the repository at this point in the history
* Add feature codebase

* Add integration tests

* Improve tests

* Add missing test

Co-authored-by: Jeremy Wohlwend <33673620+jeremyasapp@users.noreply.github.com>
  • Loading branch information
iitzco-asapp and jeremyasapp committed Mar 26, 2020
1 parent 150a4a2 commit 7b02dff
Show file tree
Hide file tree
Showing 5 changed files with 87 additions and 3 deletions.
1 change: 1 addition & 0 deletions flambe/compile/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@
CONFIG_FILE_NAME = 'config.yaml'
STASH_FILE_NAME = 'stash.pkl'
PROTOCOL_VERSION_FILE_NAME = 'protocol_version.txt'
REQUIREMENTS_FILE_NAME = 'requirements.txt'
6 changes: 5 additions & 1 deletion flambe/compile/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import torch

from flambe.compile.registrable import yaml
from flambe.compile.utils import write_deps
from flambe.compile.downloader import download_manager
from flambe.compile.extensions import import_modules, is_installed_module, install_extensions, \
setup_default_modules
Expand All @@ -18,7 +19,7 @@
FLAMBE_CONFIG_KEY, FLAMBE_DIRECTORIES_KEY, VERSION_KEY, \
HIGHEST_SERIALIZATION_PROTOCOL_VERSION, DEFAULT_SERIALIZATION_PROTOCOL_VERSION, \
DEFAULT_PROTOCOL, STATE_FILE_NAME, VERSION_FILE_NAME, SOURCE_FILE_NAME, CONFIG_FILE_NAME, \
PROTOCOL_VERSION_FILE_NAME, FLAMBE_STASH_KEY, STASH_FILE_NAME
PROTOCOL_VERSION_FILE_NAME, FLAMBE_STASH_KEY, STASH_FILE_NAME, REQUIREMENTS_FILE_NAME


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -293,6 +294,9 @@ def save_state_to_file(state: State,
f_proto.write(str(DEFAULT_SERIALIZATION_PROTOCOL_VERSION))
with open(os.path.join(current_path, STASH_FILE_NAME), 'wb') as f_stash:
torch.save(node.object_stash, f_stash, pickle_module, pickle_protocol)

write_deps(os.path.join(path, REQUIREMENTS_FILE_NAME))

if compress:
compressed_file_name = original_path + '.tar.gz'
with tarfile.open(name=compressed_file_name, mode='w:gz') as tar_gz:
Expand Down
41 changes: 40 additions & 1 deletion flambe/compile/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,12 @@
from typing import Type, Set, Any, Optional, List
from typing import Type, Set, Any, Optional, List, Iterable

from urllib.parse import urlparse

try:
from pip._internal.operations import freeze
except ImportError: # pip < 10.0
from pip.operations import freeze


def all_subclasses(class_: Type[Any]) -> Set[Type[Any]]:
"""Return a set of all subclasses for a given class object
Expand Down Expand Up @@ -90,3 +95,37 @@ def _is_url(resource: str) -> bool:
"""
scheme = urlparse(resource).scheme
return scheme != ''


def get_frozen_deps() -> Iterable[str]:
"""Get the frozen dependencies that are locally installed.
This should yield the same results as runnning 'pip freeze'.
Returns
-------
Iterable[str]
The frozen dependencies as strings.
"""
return freeze.freeze()


def write_deps(filename: str, deps: Optional[Iterable[str]] = None) -> None:
"""Write dependencies on a filename, following Python's convention
for requiremnets.
Parameters
----------
filename: str
The filename where the dependencies will be written.
deps: Optional[Iterable[str]]
Optional dependencies to write. If not provided,
this method will get the dependencies automatically.
This parameter should be used for testing purposes only.
"""
deps = deps or get_frozen_deps()
with open(filename, "w+") as f:
f.writelines('\n'.join(deps))
f.flush()
20 changes: 19 additions & 1 deletion tests/unit/compile/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import os

import torch
import tarfile
import dill
import mock
from ruamel.yaml.compat import StringIO
Expand Down Expand Up @@ -651,6 +652,24 @@ def test_module_save_and_load_roundtrip(self, basic_object, pickle_only, compres
check_mapping_equivalence(new_state, old_state)
check_mapping_equivalence(old_state._metadata, new_state._metadata, exclude_config=False)

@pytest.mark.parametrize("compress_save_file", [True, False])
@mock.patch('flambe.compile.utils.get_frozen_deps')
def test_module_save_requirements_file(self, mock_freeze, compress_save_file, basic_object):
mock_freeze.return_value = ['pkgA==1.2.3', 'pkgB']
old_obj = basic_object(from_config=True)
with tempfile.TemporaryDirectory() as root_path:
path = os.path.join(root_path, 'savefile.flambe')
save(old_obj, path, compress=compress_save_file, pickle_only=False)
if compress_save_file:
with tarfile.open(f"{path}.tar.gz", 'r:gz') as tar_gz:
tar_gz.extractall(path=root_path)

mock_freeze.assert_called_once()
assert os.path.exists(os.path.join(path, 'requirements.txt'))

with open(os.path.join(path, 'requirements.txt'), 'r') as f:
assert f.read() == 'pkgA==1.2.3\npkgB'

@pytest.mark.parametrize("pickle_only", [True, False])
@pytest.mark.parametrize("compress_save_file", [True, False])
def test_module_save_and_load_roundtrip_pytorch(self,
Expand All @@ -671,7 +690,6 @@ def test_module_save_and_load_roundtrip_pytorch(self,
check_mapping_equivalence(new_state, old_state)
check_mapping_equivalence(old_state._metadata, new_state._metadata, exclude_config=False)


def test_module_save_and_load_roundtrip_pytorch_only_bridge(self):
a = BasicStateful.compile(x=3)
b = 100
Expand Down
22 changes: 22 additions & 0 deletions tests/unit/compile/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
import tempfile
import mock


from flambe.compile.utils import write_deps


def test_write_deps():
dummy_dependencies = ['numpy==1.2.3', 'pip~=1.1.1', 'some_other-random dep']
with tempfile.NamedTemporaryFile() as tmpfile:
write_deps(tmpfile.name, dummy_dependencies)

assert tmpfile.read() == b'numpy==1.2.3\npip~=1.1.1\nsome_other-random dep'


@mock.patch('flambe.compile.utils.get_frozen_deps')
def test_write_deps_default(mock_deps):
mock_deps.return_value = ['numpy==1.2.3', 'pip~=1.1.1', 'some_other-random dep']
with tempfile.NamedTemporaryFile() as tmpfile:
write_deps(tmpfile.name)
assert tmpfile.read() == b'numpy==1.2.3\npip~=1.1.1\nsome_other-random dep'
mock_deps.assert_called_once()

0 comments on commit 7b02dff

Please sign in to comment.