Add queue based observer #471

Merged · 16 commits · Jun 3, 2019
97 changes: 86 additions & 11 deletions sacred/observers/mongo.py
@@ -10,10 +10,15 @@
import time
from tempfile import NamedTemporaryFile

+import pymongo
+import pymongo.errors
+import gridfs

import sacred.optional as opt
from sacred.commandline_options import CommandLineOption
from sacred.dependencies import get_digest
from sacred.observers.base import RunObserver
+from sacred.observers.queue import QueueObserver
from sacred.serializer import flatten
from sacred.utils import ObserverError

@@ -57,8 +62,8 @@ class MongoObserver(RunObserver):
'search_spaces'}
VERSION = 'MongoObserver-0.7.0'

-    @staticmethod
-    def create(url=None, db_name='sacred', collection='runs',
+    @classmethod
+    def create(cls, url=None, db_name='sacred', collection='runs',
overwrite=None, priority=DEFAULT_MONGO_PRIORITY,
client=None, failure_dir=None, **kwargs):
"""Factory method for MongoObserver.
@@ -77,8 +82,6 @@ def create(url=None, db_name='sacred', collection='runs',
-------
An instantiated MongoObserver.
"""
-        import pymongo
-        import gridfs

if client is not None:
if not isinstance(client, pymongo.MongoClient):
@@ -95,11 +98,11 @@ def create(url=None, db_name='sacred', collection='runs',
runs_collection = database[collection]
metrics_collection = database["metrics"]
fs = gridfs.GridFS(database)
-        return MongoObserver(runs_collection,
-                             fs, overwrite=overwrite,
-                             metrics_collection=metrics_collection,
-                             failure_dir=failure_dir,
-                             priority=priority)
+        return cls(runs_collection,
+                   fs, overwrite=overwrite,
+                   metrics_collection=metrics_collection,
+                   failure_dir=failure_dir,
+                   priority=priority)

def __init__(self, runs_collection,
fs, overwrite=None, metrics_collection=None,
@@ -264,8 +267,6 @@ def log_metrics(self, metrics_by_name, info):
.append({"name": key, "id": str(result.upserted_id)})

def insert(self):
-        import pymongo.errors
-
if self.overwrite:
return self.save()

@@ -419,3 +420,77 @@ def parse_mongo_db_arg(cls, mongo_db):
kwargs[p] = g[p]

return kwargs


class QueueCompatibleMongoObserver(MongoObserver):
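    """MongoObserver adapted for queued, per-event delivery.

    Handlers raise on failure instead of retrying locally, so that the
    QueueObserver worker can wait and replay the failed event.
    """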

def log_metrics(self, metric_name, metrics_values, info):
"""Store new measurements to the database.

Take measurements and store them into
the metrics collection in the database.
Additionally, reference the metrics
in the info["metrics"] dictionary.
"""
if self.metrics is None:
# If, for whatever reason, the metrics collection has not been set
# do not try to save anything there.
return
query = {"run_id": self.run_entry['_id'],
"name": metric_name}
push = {"steps": {"$each": metrics_values["steps"]},
"values": {"$each": metrics_values["values"]},
"timestamps": {"$each": metrics_values["timestamps"]}}
update = {"$push": push}
result = self.metrics.update_one(query, update, upsert=True)
if result.upserted_id is not None:
# This is the first time we are storing this metric
info.setdefault("metrics", []) \
.append({"name": metric_name, "id": str(result.upserted_id)})

def save(self):
try:
self.runs.update_one({'_id': self.run_entry['_id']},
{'$set': self.run_entry})
        except pymongo.errors.InvalidDocument:
            raise ObserverError('Run contained an unserializable entry. '
                                '(most likely in the info)')

def final_save(self, attempts):
try:
self.runs.update_one({'_id': self.run_entry['_id']},
{'$set': self.run_entry}, upsert=True)
return

except pymongo.errors.InvalidDocument:
self.run_entry = force_bson_encodeable(self.run_entry)
print("Warning: Some of the entries of the run were not "
"BSON-serializable!\n They have been altered such that "
"they can be stored, but you should fix your experiment!"
"Most likely it is either the 'info' or the 'result'.",
file=sys.stderr)

with NamedTemporaryFile(suffix='.pickle', delete=False,
prefix='sacred_mongo_fail_') as f:
pickle.dump(self.run_entry, f)
print("Warning: saving to MongoDB failed! "
"Stored experiment entry in '{}'".format(f.name),
file=sys.stderr)

raise ObserverError("Warning: saving to MongoDB failed!")


class QueuedMongoObserver(QueueObserver):
@classmethod
def create(cls, interval=20, retry_interval=10, url=None, db_name='sacred',
collection='runs', overwrite=None,
priority=DEFAULT_MONGO_PRIORITY, client=None, **kwargs):
return cls(
QueueCompatibleMongoObserver.create(url=url, db_name=db_name,
collection=collection,
overwrite=overwrite,
priority=priority,
client=client, **kwargs),
interval=interval,
retry_interval=retry_interval,
)
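
For orientation, a minimal usage sketch of the queued observer added here (not part of the diff; the experiment name, URL, and database name are illustrative and assume a reachable MongoDB):

from sacred import Experiment
from sacred.observers.mongo import QueuedMongoObserver

ex = Experiment('demo')  # illustrative experiment name
# Buffer all MongoDB writes on a queue: a background worker flushes
# the queue every `interval` seconds and retries a failed event after
# `retry_interval` seconds instead of failing the run.
ex.observers.append(QueuedMongoObserver.create(url='localhost:27017',
                                               db_name='sacred',
                                               interval=20,
                                               retry_interval=10))

@ex.automain
def main():
    return 42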
116 changes: 116 additions & 0 deletions sacred/observers/queue.py
@@ -0,0 +1,116 @@
# coding=utf-8
from __future__ import division, print_function, unicode_literals
from collections import namedtuple
import sys
from sacred.observers.base import RunObserver
from sacred.utils import IntervalTimer

if sys.version_info[0] >= 3:
    from queue import Queue, Empty
else:
    from Queue import Queue, Empty

WrappedEvent = namedtuple("WrappedEvent", "name args kwargs")


class QueueObserver(RunObserver):
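    """Wrap an observer and buffer its events in a queue.

    A background worker drains the queue every `interval` seconds and
    delivers the events to the covered observer, retrying a failed
    event every `retry_interval` seconds instead of dropping it.
    """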

def __init__(self, covered_observer, interval=20, retry_interval=10):
self._covered_observer = covered_observer
self._retry_interval = retry_interval
self._interval = interval
self._queue = None
self._worker = None
self._stop_worker_event = None

def queued_event(self, *args, **kwargs):
self._queue.put(WrappedEvent("queued_event", args, kwargs))

def started_event(self, *args, **kwargs):
self._queue = Queue()
self._stop_worker_event, self._worker = IntervalTimer.create(
self._run,
interval=self._interval,
)
self._worker.start()

        # The started event is handled synchronously rather than queued,
        # because it is required to initialize the covered observer.
return self._covered_observer.started_event(*args, **kwargs)

def heartbeat_event(self, *args, **kwargs):
self._queue.put(WrappedEvent("heartbeat_event", args, kwargs))

def completed_event(self, *args, **kwargs):
self._queue.put(WrappedEvent("completed_event", args, kwargs))
self.join()

def interrupted_event(self, *args, **kwargs):
self._queue.put(WrappedEvent("interrupted_event", args, kwargs))
self.join()

def failed_event(self, *args, **kwargs):
self._queue.put(WrappedEvent("failed_event", args, kwargs))
self.join()

def resource_event(self, *args, **kwargs):
self._queue.put(WrappedEvent("resource_event", args, kwargs))

def artifact_event(self, *args, **kwargs):
self._queue.put(WrappedEvent("artifact_event", args, kwargs))

def log_metrics(self, metrics_by_name, info):
for metric_name, metric_values in metrics_by_name.items():
self._queue.put(
WrappedEvent(
"log_metrics",
[metric_name, metric_values, info],
{},
)
)

    def _run(self):
        while not self._queue.empty():
            try:
                event = self._queue.get_nowait()
            except Empty:
                # The queue was drained in the meantime, so just go
                # back to sleep.
                pass
else:
try:
                    method = getattr(self._covered_observer, event.name)
                except AttributeError:
                    # The covered observer does not implement a handler
                    # for this event, so discard the message.
                    self._queue.task_done()
else:
while True:
try:
method(*event.args, **event.kwargs)
                        except Exception:
                            # Something went wrong while processing the
                            # event, so wait for some time and then
                            # try again.
                            self._stop_worker_event.wait(self._retry_interval)
                            continue
else:
self._queue.task_done()
break

def join(self):
if self._queue is not None:
self._queue.join()
self._stop_worker_event.set()
self._worker.join(timeout=10)

def __getattr__(self, item):
return getattr(self._covered_observer, item)

def __eq__(self, other):
return self._covered_observer == other

def __ne__(self, other):
return not self._covered_observer == other
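
The wrapper itself is not Mongo-specific: any observer whose handlers tolerate queued, one-event-at-a-time delivery can be covered. A minimal sketch (FileStorageObserver and the 'my_runs' directory are illustrative; this PR only adapts the Mongo observer):

from sacred.observers import FileStorageObserver
from sacred.observers.queue import QueueObserver

# Deliver file-storage events from the background worker, draining
# the queue every 5 seconds and retrying failed events after 2.
queued = QueueObserver(FileStorageObserver.create('my_runs'),
                       interval=5, retry_interval=2)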
5 changes: 5 additions & 0 deletions sacred/run.py
@@ -413,6 +413,11 @@ def _final_call(self, observer, method, **kwargs):
# others
self.run_logger.error(tb.format_exc())

def _wait_for_observers(self):
"""Block until all observers finished processing."""
for observer in self.observers:
self._safe_call(observer, 'join')

def _warn_about_failed_observers(self):
for observer in self._failed_observers:
self.run_logger.warning("The observer '{}' failed at some point "
77 changes: 71 additions & 6 deletions tests/test_observers/failing_mongo_mock.py
@@ -1,24 +1,23 @@

import mongomock
import pymongo
import pymongo.errors



class FailingMongoClient(mongomock.MongoClient):
def __init__(self, max_calls_before_failure=2,
exception_to_raise=pymongo.errors.AutoReconnect, **kwargs):
super(FailingMongoClient, self).__init__(**kwargs)
self._max_calls_before_failure = max_calls_before_failure
-        self.exception_to_raise = exception_to_raise
+        self._exception_to_raise = exception_to_raise

def get_database(self, name, codec_options=None, read_preference=None,
write_concern=None):
db = self._databases.get(name)
if db is None:
db = self._databases[name] = FailingDatabase(
max_calls_before_failure=self._max_calls_before_failure,
-                exception_to_raise=self.exception_to_raise, client=self,
+                exception_to_raise=self._exception_to_raise, client=self,
name=name, )
return db

@@ -28,15 +27,15 @@ def __init__(self, max_calls_before_failure, exception_to_raise=None,
**kwargs):
super(FailingDatabase, self).__init__(**kwargs)
self._max_calls_before_failure = max_calls_before_failure
-        self.exception_to_raise = exception_to_raise
+        self._exception_to_raise = exception_to_raise

def get_collection(self, name, codec_options=None, read_preference=None,
write_concern=None):
collection = self._collections.get(name)
if collection is None:
collection = self._collections[name] = FailingCollection(
max_calls_before_failure=self._max_calls_before_failure,
-                exception_to_raise=self.exception_to_raise, db=self,
+                exception_to_raise=self._exception_to_raise, db=self,
name=name, )
return collection

@@ -45,7 +44,7 @@ class FailingCollection(mongomock.Collection):
def __init__(self, max_calls_before_failure, exception_to_raise, **kwargs):
super(FailingCollection, self).__init__(**kwargs)
self._max_calls_before_failure = max_calls_before_failure
-        self.exception_to_raise = exception_to_raise
+        self._exception_to_raise = exception_to_raise
self._calls = 0

def insert_one(self, document):
@@ -62,3 +61,69 @@ def update_one(self, filter, update, upsert=False):
else:
return super(FailingCollection, self).update_one(filter, update,
upsert)


class ReconnectingMongoClient(FailingMongoClient):
def __init__(self, max_calls_before_reconnect, **kwargs):
super(ReconnectingMongoClient, self).__init__(**kwargs)
self._max_calls_before_reconnect = max_calls_before_reconnect

def get_database(self, name, codec_options=None, read_preference=None,
write_concern=None):
db = self._databases.get(name)
if db is None:
db = self._databases[name] = ReconnectingDatabase(
max_calls_before_reconnect=self._max_calls_before_reconnect,
max_calls_before_failure=self._max_calls_before_failure,
exception_to_raise=self._exception_to_raise, client=self,
name=name, )
return db


class ReconnectingDatabase(FailingDatabase):
def __init__(self, max_calls_before_reconnect, **kwargs):
super(ReconnectingDatabase, self).__init__(**kwargs)
self._max_calls_before_reconnect = max_calls_before_reconnect

def get_collection(self, name, codec_options=None, read_preference=None,
write_concern=None):
collection = self._collections.get(name)
if collection is None:
collection = self._collections[name] = ReconnectingCollection(
max_calls_before_reconnect=self._max_calls_before_reconnect,
max_calls_before_failure=self._max_calls_before_failure,
exception_to_raise=self._exception_to_raise, db=self,
name=name, )
return collection


class ReconnectingCollection(FailingCollection):
def __init__(self, max_calls_before_reconnect, **kwargs):
super(ReconnectingCollection, self).__init__(**kwargs)
self._max_calls_before_reconnect = max_calls_before_reconnect

def insert_one(self, document):
self._calls += 1
if self._is_in_failure_range():
print(self.name, "insert no connection")
raise self._exception_to_raise
else:
print(self.name, "insert connection reestablished")
return mongomock.Collection.insert_one(self, document)

    def update_one(self, filter, update, upsert=False):
        self._calls += 1
        if self._is_in_failure_range():
            print(self.name, "update no connection")
            raise self._exception_to_raise
        else:
            print(self.name, "update connection reestablished")
            return mongomock.Collection.update_one(self, filter, update,
                                                   upsert)

def _is_in_failure_range(self):
return (self._max_calls_before_failure
< self._calls
<= self._max_calls_before_reconnect)
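
To make the mock's failure window concrete, a small sketch of the intended behavior (not part of the diff; the call counts are illustrative):

import pymongo.errors

client = ReconnectingMongoClient(
    max_calls_before_failure=2,    # calls 1-2 succeed
    max_calls_before_reconnect=5,  # calls 3-5 raise, calls 6+ succeed
    exception_to_raise=pymongo.errors.AutoReconnect,
)
runs = client.get_database('sacred').get_collection('runs')
for i in range(1, 8):
    try:
        runs.insert_one({'call': i})
    except pymongo.errors.AutoReconnect:
        pass  # a queued observer under test would retry these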