Skip to content

Commit

Permalink
added the command that was run to the DB entry
Browse files Browse the repository at this point in the history
  • Loading branch information
Qwlouse committed Jan 17, 2016
1 parent c1da46e commit 1098018
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 8 deletions.
4 changes: 2 additions & 2 deletions sacred/observers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
class RunObserver(object):
"""Defines the interface for all run observers."""

def queued_event(self, ex_info, queue_time, config, comment):
def queued_event(self, ex_info, command, queue_time, config, comment):
pass

def started_event(self, ex_info, host_info, start_time, config, comment):
def started_event(self, ex_info, command, host_info, start_time, config, comment):
pass

def heartbeat_event(self, info, captured_out, beat_time):
Expand Down
10 changes: 5 additions & 5 deletions sacred/observers/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,12 @@ def final_save(self, attempts=10):
"Stored experiment entry in '{}'".format(f.name),
file=sys.stderr)

def queued_event(self, ex_info, queue_time, config, comment):
def queued_event(self, ex_info, command, queue_time, config, comment):
if self.overwrite is not None:
raise RuntimeError("Can't overwrite with QUEUED run.")
self.run_entry = {
'experiment': dict(ex_info),
'command': command,
'queue_time': queue_time,
'config': config,
'comment': comment,
Expand All @@ -152,21 +153,20 @@ def queued_event(self, ex_info, queue_time, config, comment):
with open(source_name, 'rb') as f:
self.fs.put(f, filename=source_name)

def started_event(self, ex_info, host_info, start_time, config, comment):
def started_event(self, ex_info, command, host_info, start_time, config,
comment):
if self.overwrite is None:
self.run_entry = {
'queue_time': start_time
}
else:
if self.run_entry is not None:
raise RuntimeError("Cannot overwrite more than once!")
# sanity checks
if self.overwrite['experiment']['sources'] != ex_info['sources']:
raise RuntimeError("Sources don't match")
self.run_entry = self.overwrite

self.run_entry.update({
'experiment': dict(ex_info),
'command': command,
'host': dict(host_info),
'start_time': start_time,
'queue_time': start_time,
Expand Down
4 changes: 3 additions & 1 deletion sacred/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
import traceback

from sacred.randomness import set_global_seed
from sacred.utils import tee_output, ObserverError, TimeoutInterrupt
from sacred.utils import tee_output, ObserverError, TimeoutInterrupt, join_paths

__sacred__ = True # marks files that should be filtered from stack traces

Expand Down Expand Up @@ -195,6 +195,7 @@ def _emit_queued(self):
if hasattr(observer, 'queued_event'):
observer.queued_event(
ex_info=self.experiment_info,
command=join_paths(self.main_function.prefix, self.main_function.signature.name),
queue_time=queue_time,
config=self.config,
comment=self.comment
Expand All @@ -208,6 +209,7 @@ def _emit_started(self):
if hasattr(observer, 'started_event'):
observer.started_event(
ex_info=self.experiment_info,
command=join_paths(self.main_function.prefix, self.main_function.signature.name),
host_info=self.host_info,
start_time=self.start_time,
config=self.config,
Expand Down

0 comments on commit 1098018

Please sign in to comment.