Skip to content

Commit

Permalink
Add queue based observer (#471)
Browse files Browse the repository at this point in the history
* Add draft for queue based observer

* Wait for processing instead of put back

* Reformat a bit

* Make mongo observer more robust against failure

When cooperating with the queue observer

* Make retry interval configurable

* Add create method to QueuedMongoObserver

* Add Queue compatible mongo observer

* Remove print debug statements

* Different queue import for py27

* Fix flake8 error

* Fix more flake errors and delegate __ne__

* Remove unused import

* Remove reimport

* Fix python 2 error
  • Loading branch information
JarnoRFB committed Jun 3, 2019
1 parent e596102 commit 59f7e0d
Show file tree
Hide file tree
Showing 7 changed files with 675 additions and 17 deletions.
97 changes: 86 additions & 11 deletions sacred/observers/mongo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,15 @@
import time
from tempfile import NamedTemporaryFile

import pymongo
import pymongo.errors
import gridfs

import sacred.optional as opt
from sacred.commandline_options import CommandLineOption
from sacred.dependencies import get_digest
from sacred.observers.base import RunObserver
from sacred.observers.queue import QueueObserver
from sacred.serializer import flatten
from sacred.utils import ObserverError

Expand Down Expand Up @@ -57,8 +62,8 @@ class MongoObserver(RunObserver):
'search_spaces'}
VERSION = 'MongoObserver-0.7.0'

@staticmethod
def create(url=None, db_name='sacred', collection='runs',
@classmethod
def create(cls, url=None, db_name='sacred', collection='runs',
overwrite=None, priority=DEFAULT_MONGO_PRIORITY,
client=None, failure_dir=None, **kwargs):
"""Factory method for MongoObserver.
Expand All @@ -77,8 +82,6 @@ def create(url=None, db_name='sacred', collection='runs',
-------
An instantiated MongoObserver.
"""
import pymongo
import gridfs

if client is not None:
if not isinstance(client, pymongo.MongoClient):
Expand All @@ -95,11 +98,11 @@ def create(url=None, db_name='sacred', collection='runs',
runs_collection = database[collection]
metrics_collection = database["metrics"]
fs = gridfs.GridFS(database)
return MongoObserver(runs_collection,
fs, overwrite=overwrite,
metrics_collection=metrics_collection,
failure_dir=failure_dir,
priority=priority)
return cls(runs_collection,
fs, overwrite=overwrite,
metrics_collection=metrics_collection,
failure_dir=failure_dir,
priority=priority)

def __init__(self, runs_collection,
fs, overwrite=None, metrics_collection=None,
Expand Down Expand Up @@ -264,8 +267,6 @@ def log_metrics(self, metrics_by_name, info):
.append({"name": key, "id": str(result.upserted_id)})

def insert(self):
import pymongo.errors

if self.overwrite:
return self.save()

Expand Down Expand Up @@ -419,3 +420,77 @@ def parse_mongo_db_arg(cls, mongo_db):
kwargs[p] = g[p]

return kwargs


class QueueCompatibleMongoObserver(MongoObserver):
    """MongoObserver whose event handlers can safely be replayed from a queue.

    Intended as the covered observer of ``QueuedMongoObserver``: handlers
    raise on failure instead of retrying internally, so the queue worker
    owns the retry policy.
    """

    def log_metrics(self, metric_name, metrics_values, info):
        """Store new measurements for a single metric to the database.

        Take measurements and store them into
        the metrics collection in the database.
        Additionally, reference the metrics
        in the info["metrics"] dictionary.
        """
        if self.metrics is None:
            # If, for whatever reason, the metrics collection has not been set
            # do not try to save anything there.
            return
        query = {"run_id": self.run_entry['_id'],
                 "name": metric_name}
        push = {"steps": {"$each": metrics_values["steps"]},
                "values": {"$each": metrics_values["values"]},
                "timestamps": {"$each": metrics_values["timestamps"]}}
        update = {"$push": push}
        result = self.metrics.update_one(query, update, upsert=True)
        if result.upserted_id is not None:
            # This is the first time we are storing this metric; reference it
            # in the run's info so the run document links to the metric doc.
            info.setdefault("metrics", []) \
                .append({"name": metric_name, "id": str(result.upserted_id)})

    def save(self):
        """Write the current run entry to the runs collection.

        Raises
        ------
        ObserverError
            If the run entry contains values pymongo cannot serialize.
        """
        try:
            self.runs.update_one({'_id': self.run_entry['_id']},
                                 {'$set': self.run_entry})
        except pymongo.errors.InvalidDocument:
            # BUG FIX: the two string fragments previously concatenated to
            # "...entry.(most likely..." -- added the missing space.
            raise ObserverError('Run contained an unserializable entry. '
                                '(most likely in the info)')

    def final_save(self, attempts):
        """Try a final upsert of the run entry, falling back to a pickle dump.

        Parameters
        ----------
        attempts
            Unused here; kept for interface compatibility with
            ``MongoObserver.final_save``.

        Raises
        ------
        ObserverError
            Always raised when the upsert fails, after the entry has been
            made BSON-encodeable and dumped to a temporary pickle file.
        """
        try:
            self.runs.update_one({'_id': self.run_entry['_id']},
                                 {'$set': self.run_entry}, upsert=True)
            return

        except pymongo.errors.InvalidDocument:
            self.run_entry = force_bson_encodeable(self.run_entry)
            print("Warning: Some of the entries of the run were not "
                  "BSON-serializable!\n They have been altered such that "
                  "they can be stored, but you should fix your experiment!"
                  "Most likely it is either the 'info' or the 'result'.",
                  file=sys.stderr)

            # Persist what we could not store so the experiment data is not
            # silently lost.
            with NamedTemporaryFile(suffix='.pickle', delete=False,
                                    prefix='sacred_mongo_fail_') as f:
                pickle.dump(self.run_entry, f)
            print("Warning: saving to MongoDB failed! "
                  "Stored experiment entry in '{}'".format(f.name),
                  file=sys.stderr)

        raise ObserverError("Warning: saving to MongoDB failed!")


class QueuedMongoObserver(QueueObserver):
    """MongoObserver that forwards events through a background queue."""

    @classmethod
    def create(cls, interval=20, retry_interval=10, url=None, db_name='sacred',
               collection='runs', overwrite=None,
               priority=DEFAULT_MONGO_PRIORITY, client=None, **kwargs):
        """Factory method: build the covered observer, then wrap it."""
        covered_observer = QueueCompatibleMongoObserver.create(
            url=url,
            db_name=db_name,
            collection=collection,
            overwrite=overwrite,
            priority=priority,
            client=client,
            **kwargs)
        return cls(covered_observer,
                   interval=interval,
                   retry_interval=retry_interval)
116 changes: 116 additions & 0 deletions sacred/observers/queue.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
# coding=utf-8
from __future__ import division, print_function, unicode_literals
from collections import namedtuple
import sys
from sacred.observers.base import RunObserver
from sacred.utils import IntervalTimer

if sys.version_info[0] >= 3:
from queue import Queue
else:
from Queue import Queue

# A queued observer event: handler name plus the positional/keyword
# arguments to replay it with.
WrappedEvent = namedtuple("WrappedEvent", ["name", "args", "kwargs"])


class QueueObserver(RunObserver):
    """Wrap another observer and dispatch its events from a background queue.

    Events are placed on an internal queue and drained by a worker thread
    that wakes up every ``interval`` seconds.  If the covered observer
    raises while handling an event, the event is retried every
    ``retry_interval`` seconds until it succeeds, so transient failures
    (e.g. a lost database connection) do not drop events.
    """

    def __init__(self, covered_observer, interval=20, retry_interval=10):
        """Initialize the queue observer.

        Parameters
        ----------
        covered_observer
            The observer that actually processes the queued events.
        interval : int
            Seconds the worker sleeps between queue-draining runs.
        retry_interval : int
            Seconds to wait before retrying a failed event.
        """
        self._covered_observer = covered_observer
        self._retry_interval = retry_interval
        self._interval = interval
        # Queue and worker are created lazily in started_event.
        self._queue = None
        self._worker = None
        self._stop_worker_event = None

    def queued_event(self, *args, **kwargs):
        self._queue.put(WrappedEvent("queued_event", args, kwargs))

    def started_event(self, *args, **kwargs):
        self._queue = Queue()
        self._stop_worker_event, self._worker = IntervalTimer.create(
            self._run,
            interval=self._interval,
        )
        self._worker.start()

        # Putting the started event on the queue makes no sense
        # as it is required for initialization of the covered observer.
        return self._covered_observer.started_event(*args, **kwargs)

    def heartbeat_event(self, *args, **kwargs):
        self._queue.put(WrappedEvent("heartbeat_event", args, kwargs))

    def completed_event(self, *args, **kwargs):
        self._queue.put(WrappedEvent("completed_event", args, kwargs))
        self.join()

    def interrupted_event(self, *args, **kwargs):
        self._queue.put(WrappedEvent("interrupted_event", args, kwargs))
        self.join()

    def failed_event(self, *args, **kwargs):
        self._queue.put(WrappedEvent("failed_event", args, kwargs))
        self.join()

    def resource_event(self, *args, **kwargs):
        self._queue.put(WrappedEvent("resource_event", args, kwargs))

    def artifact_event(self, *args, **kwargs):
        self._queue.put(WrappedEvent("artifact_event", args, kwargs))

    def log_metrics(self, metrics_by_name, info):
        # Fan the metrics dict out into one queued event per metric so each
        # metric can be (re)tried independently by the worker.
        for metric_name, metric_values in metrics_by_name.items():
            self._queue.put(
                WrappedEvent(
                    "log_metrics",
                    [metric_name, metric_values, info],
                    {},
                )
            )

    def _run(self):
        """Worker loop: drain the queue into the covered observer."""
        while not self._queue.empty():
            # Single consumer: after the empty() check above, get() returns
            # immediately.  (A dead `except IndexError` around this call was
            # removed -- Queue.get never raises IndexError.)
            event = self._queue.get()
            try:
                method = getattr(self._covered_observer, event.name)
            except AttributeError:
                # BUG FIX: this was `except NameError`, which getattr never
                # raises, so a missing handler crashed the worker instead of
                # being skipped.  The covered observer does not implement an
                # event handler for this event, so discard the message.
                self._queue.task_done()
            else:
                while True:
                    try:
                        method(*event.args, **event.kwargs)
                    except Exception:
                        # BUG FIX: narrowed from a bare `except:`, which also
                        # swallowed KeyboardInterrupt/SystemExit.  Something
                        # went wrong during the processing of the event, so
                        # wait for some time and then try again.
                        self._stop_worker_event.wait(self._retry_interval)
                        continue
                    else:
                        self._queue.task_done()
                        break

    def join(self):
        """Block until every queued event is processed, then stop the worker."""
        if self._queue is not None:
            self._queue.join()
            self._stop_worker_event.set()
            self._worker.join(timeout=10)

    def __getattr__(self, item):
        # Delegate any other attribute access (e.g. priority) to the
        # covered observer.
        return getattr(self._covered_observer, item)

    def __eq__(self, other):
        return self._covered_observer == other

    def __ne__(self, other):
        # Explicit __ne__ because Python 2 does not derive it from __eq__.
        return not self._covered_observer == other
5 changes: 5 additions & 0 deletions sacred/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,11 @@ def _final_call(self, observer, method, **kwargs):
# others
self.run_logger.error(tb.format_exc())

def _wait_for_observers(self):
"""Block until all observers finished processing."""
for observer in self.observers:
self._safe_call(observer, 'join')

def _warn_about_failed_observers(self):
for observer in self._failed_observers:
self.run_logger.warning("The observer '{}' failed at some point "
Expand Down
77 changes: 71 additions & 6 deletions tests/test_observers/failing_mongo_mock.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,23 @@

import mongomock
import pymongo
import pymongo.errors



class FailingMongoClient(mongomock.MongoClient):
    """Mongo client mock that starts raising after a number of calls.

    Databases and collections obtained from it raise ``exception_to_raise``
    once more than ``max_calls_before_failure`` operations have been made.
    """

    def __init__(self, max_calls_before_failure=2,
                 exception_to_raise=pymongo.errors.AutoReconnect, **kwargs):
        super(FailingMongoClient, self).__init__(**kwargs)
        self._max_calls_before_failure = max_calls_before_failure
        # BUG FIX: the scraped diff left both the old public and the new
        # private attribute assignment; only the private one is kept.
        self._exception_to_raise = exception_to_raise

    def get_database(self, name, codec_options=None, read_preference=None,
                     write_concern=None):
        """Return the named database, creating a FailingDatabase on first use."""
        db = self._databases.get(name)
        if db is None:
            # BUG FIX: removed the duplicated exception_to_raise keyword
            # argument (diff residue) which made this a SyntaxError.
            db = self._databases[name] = FailingDatabase(
                max_calls_before_failure=self._max_calls_before_failure,
                exception_to_raise=self._exception_to_raise, client=self,
                name=name, )
        return db

Expand All @@ -28,15 +27,15 @@ def __init__(self, max_calls_before_failure, exception_to_raise=None,
**kwargs):
super(FailingDatabase, self).__init__(**kwargs)
self._max_calls_before_failure = max_calls_before_failure
self.exception_to_raise = exception_to_raise
self._exception_to_raise = exception_to_raise

def get_collection(self, name, codec_options=None, read_preference=None,
write_concern=None):
collection = self._collections.get(name)
if collection is None:
collection = self._collections[name] = FailingCollection(
max_calls_before_failure=self._max_calls_before_failure,
exception_to_raise=self.exception_to_raise, db=self,
exception_to_raise=self._exception_to_raise, db=self,
name=name, )
return collection

Expand All @@ -45,7 +44,7 @@ class FailingCollection(mongomock.Collection):
def __init__(self, max_calls_before_failure, exception_to_raise, **kwargs):
super(FailingCollection, self).__init__(**kwargs)
self._max_calls_before_failure = max_calls_before_failure
self.exception_to_raise = exception_to_raise
self._exception_to_raise = exception_to_raise
self._calls = 0

def insert_one(self, document):
Expand All @@ -62,3 +61,69 @@ def update_one(self, filter, update, upsert=False):
else:
return super(FailingCollection, self).update_one(filter, update,
upsert)


class ReconnectingMongoClient(FailingMongoClient):
    """Failing client whose connection "comes back" after a number of calls."""

    def __init__(self, max_calls_before_reconnect, **kwargs):
        super(ReconnectingMongoClient, self).__init__(**kwargs)
        self._max_calls_before_reconnect = max_calls_before_reconnect

    def get_database(self, name, codec_options=None, read_preference=None,
                     write_concern=None):
        """Return the named database, creating a ReconnectingDatabase lazily."""
        database = self._databases.get(name)
        if database is not None:
            return database
        database = ReconnectingDatabase(
            max_calls_before_reconnect=self._max_calls_before_reconnect,
            max_calls_before_failure=self._max_calls_before_failure,
            exception_to_raise=self._exception_to_raise, client=self,
            name=name, )
        self._databases[name] = database
        return database


class ReconnectingDatabase(FailingDatabase):
    """Failing database that hands out reconnecting collections."""

    def __init__(self, max_calls_before_reconnect, **kwargs):
        super(ReconnectingDatabase, self).__init__(**kwargs)
        self._max_calls_before_reconnect = max_calls_before_reconnect

    def get_collection(self, name, codec_options=None, read_preference=None,
                       write_concern=None):
        """Return the named collection, creating a ReconnectingCollection lazily."""
        existing = self._collections.get(name)
        if existing is not None:
            return existing
        created = ReconnectingCollection(
            max_calls_before_reconnect=self._max_calls_before_reconnect,
            max_calls_before_failure=self._max_calls_before_failure,
            exception_to_raise=self._exception_to_raise, db=self,
            name=name, )
        self._collections[name] = created
        return created


class ReconnectingCollection(FailingCollection):
    """Collection mock that fails for a window of calls, then recovers.

    Calls number (max_calls_before_failure, max_calls_before_reconnect]
    raise; afterwards operations succeed again, simulating a re-established
    connection.
    """

    def __init__(self, max_calls_before_reconnect, **kwargs):
        super(ReconnectingCollection, self).__init__(**kwargs)
        self._max_calls_before_reconnect = max_calls_before_reconnect

    def insert_one(self, document):
        self._calls += 1
        if self._is_in_failure_range():
            # Debug trace kept verbatim from the original mock.
            print(self.name, "insert no connection")
            raise self._exception_to_raise
        print(self.name, "insert connection reestablished")
        return mongomock.Collection.insert_one(self, document)

    def update_one(self, filter, update, upsert=False):
        self._calls += 1
        if self._is_in_failure_range():
            print(self.name, "update no connection")

            raise self._exception_to_raise
        print(self.name, "update connection reestablished")

        return mongomock.Collection.update_one(self, filter, update,
                                               upsert)

    def _is_in_failure_range(self):
        past_failure_threshold = self._calls > self._max_calls_before_failure
        not_yet_reconnected = self._calls <= self._max_calls_before_reconnect
        return past_failure_threshold and not_yet_reconnected
Loading

0 comments on commit 59f7e0d

Please sign in to comment.