Skip to content

Commit

Permalink
added overwrite option to MongoObserver
Browse files Browse the repository at this point in the history
  • Loading branch information
Qwlouse committed Jan 7, 2016
1 parent eeced66 commit 8e807bd
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
27 changes: 21 additions & 6 deletions sacred/observers/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,18 +81,20 @@ def force_bson_encodeable(obj):

class MongoObserver(RunObserver):
@staticmethod
def create(url='localhost', db_name='sacred', prefix='default', **kwargs):
def create(url='localhost', db_name='sacred', prefix='default',
overwrite=None, **kwargs):
client = pymongo.MongoClient(url, **kwargs)
database = client[db_name]
for manipulator in SON_MANIPULATORS:
database.add_son_manipulator(manipulator)
runs_collection = database[prefix + '.runs']
fs = gridfs.GridFS(database, collection=prefix)
return MongoObserver(runs_collection, fs)
return MongoObserver(runs_collection, fs, overwrite=overwrite)

def __init__(self, runs_collection, fs):
def __init__(self, runs_collection, fs, overwrite=None):
self.runs = runs_collection
self.fs = fs
self.overwrite = overwrite
self.run_entry = None

def save(self):
Expand Down Expand Up @@ -129,6 +131,8 @@ def final_save(self, attempts=10):
file=sys.stderr)

def queued_event(self, ex_info, queue_time, config, comment):
if self.overwrite is not None:
raise RuntimeError("Can't overwrite with QUEUED run.")
self.run_entry = {
'experiment': dict(ex_info),
'queue_time': queue_time,
Expand All @@ -144,11 +148,22 @@ def queued_event(self, ex_info, queue_time, config, comment):
self.fs.put(f, filename=source_name)

def started_event(self, ex_info, host_info, start_time, config, comment):
self.run_entry = {
if self.overwrite is None:
self.run_entry = {
'queue_time': start_time
}
else:
if self.run_entry is not None:
raise RuntimeError("Cannot overwrite more than once!")
# sanity checks
if self.overwrite['experiment']['sources'] != ex_info['sources']:
raise RuntimeError("Sources don't match")
self.run_entry = self.overwrite

self.run_entry.update({
'experiment': dict(ex_info),
'host': dict(host_info),
'start_time': start_time,
'queue_time': start_time,
'config': config,
'comment': comment,
'status': 'RUNNING',
Expand All @@ -157,7 +172,7 @@ def started_event(self, ex_info, host_info, start_time, config, comment):
'captured_out': '',
'info': {},
'heartbeat': None
}
})

self.save()
for source_name, md5 in ex_info['sources']:
Expand Down
2 changes: 1 addition & 1 deletion tests/test_config/test_config_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def test_fixing_values(conf_dict):
assert conf_dict({'a': 100})['a'] == 100


@pytest.mark.parametrize("key", ["_underscore", "white space", 12, "12", "$f"])
@pytest.mark.parametrize("key", ["white space", 12, "12", "$f"])
def test_config_dict_raises_on_invalid_keys(key):
with pytest.raises(KeyError):
ConfigDict({key: True})
Expand Down

0 comments on commit 8e807bd

Please sign in to comment.