Skip to content

Commit

Permalink
added tests for the FileStorageObserver and changed behaviour of reso…
Browse files Browse the repository at this point in the history
…urce_event
  • Loading branch information
Qwlouse committed Jul 5, 2016
1 parent 8077c80 commit 17148ff
Show file tree
Hide file tree
Showing 2 changed files with 227 additions and 4 deletions.
26 changes: 22 additions & 4 deletions sacred/observers/file_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import tempfile
import json
from datetime import datetime
from shutil import copyfile

from sacred.commandline_options import CommandLineOption
from sacred.dependencies import get_digest
Expand All @@ -25,10 +26,11 @@ def json_serial(obj):
class FileStorageObserver(RunObserver):
VERSION = 'FileStorageObserver-0.7.0'

def __init__(self, basedir):
def __init__(self, basedir, resource_dir=None):
if not os.path.exists(basedir):
os.makedirs(basedir)
self.basedir = basedir
self.resource_dir = resource_dir or os.path.join(basedir, '_resources')
self.dir = None
self.run_entry = None
self.config = None
Expand Down Expand Up @@ -98,7 +100,6 @@ def save_json(self, obj, filename):
default=json_serial)

def save_file(self, filename, target_name=None):
from shutil import copyfile
target_name = target_name or os.path.basename(filename)
copyfile(filename, os.path.join(self.dir, target_name))

Expand Down Expand Up @@ -150,16 +151,33 @@ def failed_event(self, fail_time, fail_trace):
self.render_template()

def resource_event(self, filename):
self.save_file(filename)
if not os.path.exists(self.resource_dir):
os.makedirs(self.resource_dir)

res_name, ext = os.path.splitext(os.path.basename(filename))
md5hash = get_digest(filename)
self.run_entry['resources'].append((filename, md5hash))
store_name = res_name + '_' + md5hash + ext
store_path = os.path.join(self.resource_dir, store_name)

if not os.path.exists(store_path):
copyfile(filename, store_path)

self.run_entry['resources'].append((store_path, md5hash))
self.save_json(self.run_entry, 'run.json')

def artifact_event(self, name, filename):
self.save_file(filename, name)
self.run_entry['artifacts'].append(name)
self.save_json(self.run_entry, 'run.json')

def __eq__(self, other):
if isinstance(other, FileStorageObserver):
return self.basedir == other.basedir
return False

def __ne__(self, other):
return not self.__eq__(other)


class FileStorageOption(CommandLineOption):
"""Add a file-storage observer to the experiment."""
Expand Down
205 changes: 205 additions & 0 deletions tests/test_observers/test_file_storage_observer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,205 @@
#!/usr/bin/env python
# coding=utf-8
from __future__ import division, print_function, unicode_literals
import datetime
import hashlib
import json
import os.path
import tempfile

import pytest

from observers.file_storage import FileStorageObserver

T1 = datetime.datetime(1999, 5, 4, 3, 2, 1, 0)
T2 = datetime.datetime(1999, 5, 5, 5, 5, 5, 5)

@pytest.fixture()
def sample_run():
exp = {'name': 'test_exp', 'sources': [], 'doc': ''}
host = {'hostname': 'test_host', 'cpu_count': 1, 'python_version': '3.4'}
config = {'config': 'True', 'foo': 'bar', 'answer': 42}
command = 'run'
meta_info = {'comment': 'test run'}
return {
'_id': 'FEDCBA9876543210',
'ex_info': exp,
'command': command,
'host_info': host,
'start_time': T1,
'config': config,
'meta_info': meta_info,
}


@pytest.fixture()
def dir_obs(tmpdir):
return tmpdir, FileStorageObserver(tmpdir.strpath)


def test_fs_observer_started_event_creates_rundir(dir_obs, sample_run):
basedir, obs = dir_obs
sample_run['_id'] = None
_id = obs.started_event(**sample_run)
assert _id is not None
run_dir = basedir.join(_id)
assert run_dir.exists()
assert run_dir.join('cout.txt').exists()
config = json.loads(run_dir.join('config.json').read())
assert config == sample_run['config']

run = json.loads(run_dir.join('run.json').read())
assert run == {
'experiment': sample_run['ex_info'],
'command': sample_run['command'],
'host': sample_run['host_info'],
'start_time': T1.isoformat(),
'heartbeat': None,
'meta': sample_run['meta_info'],
"resources": [],
"artifacts": [],
"status": "RUNNING"
}


def test_fs_observer_started_event_uses_given_id(dir_obs, sample_run):
basedir, obs = dir_obs
_id = obs.started_event(**sample_run)
assert _id == sample_run['_id']
assert basedir.join(_id).exists()


def test_fs_observer_heartbeat_event_updates_run(dir_obs, sample_run):
basedir, obs = dir_obs
_id = obs.started_event(**sample_run)
run_dir = basedir.join(_id)
info = {'my_info': [1, 2, 3], 'nr': 7}
outp = 'some output'
obs.heartbeat_event(info=info, captured_out=outp, beat_time=T2)

assert run_dir.join('cout.txt').read() == outp
run = json.loads(run_dir.join('run.json').read())

assert run['heartbeat'] == T2.isoformat()

assert run_dir.join('info.json').exists()
i = json.loads(run_dir.join('info.json').read())
assert info == i


def test_fs_observer_completed_event_updates_run(dir_obs, sample_run):
basedir, obs = dir_obs
_id = obs.started_event(**sample_run)
run_dir = basedir.join(_id)

obs.completed_event(stop_time=T2, result=42)

run = json.loads(run_dir.join('run.json').read())
assert run['stop_time'] == T2.isoformat()
assert run['status'] == 'COMPLETED'
assert run['result'] == 42


def test_fs_observer_interrupted_event_updates_run(dir_obs, sample_run):
basedir, obs = dir_obs
_id = obs.started_event(**sample_run)
run_dir = basedir.join(_id)

obs.interrupted_event(interrupt_time=T2, status='CUSTOM_INTERRUPTION')

run = json.loads(run_dir.join('run.json').read())
assert run['stop_time'] == T2.isoformat()
assert run['status'] == 'CUSTOM_INTERRUPTION'


def test_fs_observer_failed_event_updates_run(dir_obs, sample_run):
basedir, obs = dir_obs
_id = obs.started_event(**sample_run)
run_dir = basedir.join(_id)

fail_trace = "lots of errors and\nso\non..."
obs.failed_event(fail_time=T2, fail_trace=fail_trace)

run = json.loads(run_dir.join('run.json').read())
assert run['stop_time'] == T2.isoformat()
assert run['status'] == 'FAILED'
assert run['fail_trace'] == fail_trace


def test_fs_observer_artifact_event(dir_obs, sample_run):
basedir, obs = dir_obs
_id = obs.started_event(**sample_run)
run_dir = basedir.join(_id)

with tempfile.NamedTemporaryFile(suffix='.py') as f:
f.write(b'foo\nbar')
f.flush()
obs.artifact_event('my_artifact.py', f.name)

artifact = run_dir.join('my_artifact.py')
assert artifact.exists()
assert artifact.read() == 'foo\nbar'

run = json.loads(run_dir.join('run.json').read())
assert len(run['artifacts']) == 1
assert run['artifacts'][0] == artifact.relto(run_dir)


def test_fs_observer_resource_event(dir_obs, sample_run):
basedir, obs = dir_obs
_id = obs.started_event(**sample_run)
run_dir = basedir.join(_id)

with tempfile.NamedTemporaryFile(suffix='.py') as f:
f.write(b'foo\nbar')
f.flush()
obs.resource_event(f.name)
md5sum = hashlib.md5(open(f.name, 'rb').read()).hexdigest()

res_dir = basedir.join('_resources')
assert res_dir.exists()
assert len(res_dir.listdir()) == 1
assert res_dir.listdir()[0].read() == 'foo\nbar'

run = json.loads(run_dir.join('run.json').read())
assert len(run['resources']) == 1
assert run['resources'][0] == [res_dir.listdir()[0].strpath, md5sum]


def test_fs_observer_resource_event_does_not_duplicate(dir_obs, sample_run):
basedir, obs = dir_obs
obs2 = FileStorageObserver(obs.basedir)
_id = obs.started_event(**sample_run)

with tempfile.NamedTemporaryFile(suffix='.py') as f:
f.write(b'foo\nbar')
f.flush()
obs.resource_event(f.name)
md5sum = hashlib.md5(open(f.name, 'rb').read()).hexdigest()
# let's have another run from a different observer
sample_run['_id'] = None
_id = obs2.started_event(**sample_run)
run_dir = basedir.join(_id)
obs2.resource_event(f.name)

res_dir = basedir.join('_resources')
assert res_dir.exists()
assert len(res_dir.listdir()) == 1
assert res_dir.listdir()[0].read() == 'foo\nbar'

run = json.loads(run_dir.join('run.json').read())
assert len(run['resources']) == 1
assert run['resources'][0] == [res_dir.listdir()[0].strpath, md5sum]


def test_fs_observer_equality(dir_obs):
basedir, obs = dir_obs
obs2 = FileStorageObserver(obs.basedir)
assert obs == obs2
assert not obs != obs2

assert not obs == 'foo'
assert obs != 'foo'



0 comments on commit 17148ff

Please sign in to comment.