Browse files

Middleware extended with per task context concept

It's now possible to implement a tracer system for example.
  • Loading branch information...
1 parent dbc8660 commit 9702bca89b457ae98de44631c7b5b17e08a65dca @bombela bombela committed Apr 4, 2012
Showing with 339 additions and 7 deletions.
  1. +274 −0 tests/test_middleware.py
  2. +13 −1 zerorpc/context.py
  3. +52 −6 zerorpc/core.py
View
274 tests/test_middleware.py
@@ -25,6 +25,9 @@
from nose.tools import assert_raises
import gevent
+import gevent.local
+import random
+import md5
from zerorpc import zmq
import zerorpc
@@ -250,3 +253,274 @@ def test(self, argument):
#FIXME: These seems to be broken
# publisher.close()
# subscriber.close()
+
+
+class Tracer:
+ '''Used by test_task_context_* tests'''
+ def __init__(self, identity):
+ self._identity = identity
+ self._locals = gevent.local.local()
+ self._log = []
+
+ @property
+ def trace_id(self):
+ return self._locals.__dict__.get('trace_id', None)
+
+ def load_task_context(self, event_header):
+ self._locals.trace_id = event_header.get('trace_id', None)
+ print self._identity, 'load_task_context', self.trace_id
+ self._log.append(('load', self.trace_id))
+
+ def get_task_context(self):
+ if self.trace_id is None:
+ # just an ugly code to generate a beautiful little hash.
+ self._locals.trace_id = '<{0}>'.format(md5.md5(
+ str(random.random())[3:]
+ ).hexdigest()[0:6].upper())
+ print self._identity, 'get_task_context! [make a new one]', self.trace_id
+ self._log.append(('new', self.trace_id))
+ else:
+ print self._identity, 'get_task_context! [reuse]', self.trace_id
+ self._log.append(('reuse', self.trace_id))
+ return { 'trace_id': self.trace_id }
+
+
+def test_task_context():
+ endpoint = random_ipc_endpoint()
+ srv_ctx = zerorpc.Context()
+ cli_ctx = zerorpc.Context()
+
+ srv_tracer = Tracer('[server]')
+ srv_ctx.register_middleware(srv_tracer)
+ cli_tracer = Tracer('[client]')
+ cli_ctx.register_middleware(cli_tracer)
+
+ class Srv:
+ def echo(self, msg):
+ return msg
+
+ @zerorpc.stream
+ def stream(self):
+ yield 42
+
+ srv = zerorpc.Server(Srv(), context=srv_ctx)
+ srv.bind(endpoint)
+ srv_task = gevent.spawn(srv.run)
+
+ c = zerorpc.Client(context=cli_ctx)
+ c.connect(endpoint)
+
+ assert c.echo('hello') == 'hello'
+ for x in c.stream():
+ assert x == 42
+
+ srv.stop()
+ srv_task.join()
+
+ assert cli_tracer._log == [
+ ('new', cli_tracer.trace_id),
+ ('reuse', cli_tracer.trace_id),
+ ]
+ assert srv_tracer._log == [
+ ('load', cli_tracer.trace_id),
+ ('reuse', cli_tracer.trace_id),
+ ('load', cli_tracer.trace_id),
+ ('reuse', cli_tracer.trace_id),
+ ]
+
+def test_task_context_relay():
+ endpoint1 = random_ipc_endpoint()
+ endpoint2 = random_ipc_endpoint()
+ srv_ctx = zerorpc.Context()
+ srv_relay_ctx = zerorpc.Context()
+ cli_ctx = zerorpc.Context()
+
+ srv_tracer = Tracer('[server]')
+ srv_ctx.register_middleware(srv_tracer)
+ srv_relay_tracer = Tracer('[server_relay]')
+ srv_relay_ctx.register_middleware(srv_relay_tracer)
+ cli_tracer = Tracer('[client]')
+ cli_ctx.register_middleware(cli_tracer)
+
+ class Srv:
+ def echo(self, msg):
+ return msg
+
+ srv = zerorpc.Server(Srv(), context=srv_ctx)
+ srv.bind(endpoint1)
+ srv_task = gevent.spawn(srv.run)
+
+ c_relay = zerorpc.Client(context=srv_relay_ctx)
+ c_relay.connect(endpoint1)
+
+ class SrvRelay:
+ def echo(self, msg):
+ return c_relay.echo('relay' + msg) + 'relayed'
+
+ srv_relay = zerorpc.Server(SrvRelay(), context=srv_relay_ctx)
+ srv_relay.bind(endpoint2)
+ srv_relay_task = gevent.spawn(srv_relay.run)
+
+ c = zerorpc.Client(context=cli_ctx)
+ c.connect(endpoint2)
+
+ assert c.echo('hello') == 'relayhellorelayed'
+
+ srv_relay.stop()
+ srv.stop()
+ srv_relay_task.join()
+ srv_task.join()
+
+ assert cli_tracer._log == [
+ ('new', cli_tracer.trace_id),
+ ]
+ assert srv_relay_tracer._log == [
+ ('load', cli_tracer.trace_id),
+ ('reuse', cli_tracer.trace_id),
+ ('reuse', cli_tracer.trace_id),
+ ]
+ assert srv_tracer._log == [
+ ('load', cli_tracer.trace_id),
+ ('reuse', cli_tracer.trace_id),
+ ]
+
+def test_task_context_relay_fork():
+ endpoint1 = random_ipc_endpoint()
+ endpoint2 = random_ipc_endpoint()
+ srv_ctx = zerorpc.Context()
+ srv_relay_ctx = zerorpc.Context()
+ cli_ctx = zerorpc.Context()
+
+ srv_tracer = Tracer('[server]')
+ srv_ctx.register_middleware(srv_tracer)
+ srv_relay_tracer = Tracer('[server_relay]')
+ srv_relay_ctx.register_middleware(srv_relay_tracer)
+ cli_tracer = Tracer('[client]')
+ cli_ctx.register_middleware(cli_tracer)
+
+ class Srv:
+ def echo(self, msg):
+ return msg
+
+ srv = zerorpc.Server(Srv(), context=srv_ctx)
+ srv.bind(endpoint1)
+ srv_task = gevent.spawn(srv.run)
+
+ c_relay = zerorpc.Client(context=srv_relay_ctx)
+ c_relay.connect(endpoint1)
+
+ class SrvRelay:
+ def echo(self, msg):
+ def dothework(msg):
+ return c_relay.echo(msg) + 'relayed'
+ g = gevent.spawn(zerorpc.fork_task_context(dothework,
+ srv_relay_ctx), 'relay' + msg)
+ print 'relaying in separate task:', g
+ r = g.get()
+ print 'back to main task'
+ return r
+
+ srv_relay = zerorpc.Server(SrvRelay(), context=srv_relay_ctx)
+ srv_relay.bind(endpoint2)
+ srv_relay_task = gevent.spawn(srv_relay.run)
+
+ c = zerorpc.Client(context=cli_ctx)
+ c.connect(endpoint2)
+
+ assert c.echo('hello') == 'relayhellorelayed'
+
+ srv_relay.stop()
+ srv.stop()
+ srv_relay_task.join()
+ srv_task.join()
+
+ assert cli_tracer._log == [
+ ('new', cli_tracer.trace_id),
+ ]
+ assert srv_relay_tracer._log == [
+ ('load', cli_tracer.trace_id),
+ ('reuse', cli_tracer.trace_id),
+ ('load', cli_tracer.trace_id),
+ ('reuse', cli_tracer.trace_id),
+ ('reuse', cli_tracer.trace_id),
+ ]
+ assert srv_tracer._log == [
+ ('load', cli_tracer.trace_id),
+ ('reuse', cli_tracer.trace_id),
+ ]
+
+
+def test_task_context_pushpull():
+ endpoint = random_ipc_endpoint()
+ puller_ctx = zerorpc.Context()
+ pusher_ctx = zerorpc.Context()
+
+ puller_tracer = Tracer('[puller]')
+ puller_ctx.register_middleware(puller_tracer)
+ pusher_tracer = Tracer('[pusher]')
+ pusher_ctx.register_middleware(pusher_tracer)
+
+ trigger = gevent.event.Event()
+
+ class Puller:
+ def echo(self, msg):
+ trigger.set()
+
+ puller = zerorpc.Puller(Puller(), context=puller_ctx)
+ puller.bind(endpoint)
+ puller_task = gevent.spawn(puller.run)
+
+ c = zerorpc.Pusher(context=pusher_ctx)
+ c.connect(endpoint)
+
+ trigger.clear()
+ c.echo('hello')
+ trigger.wait()
+
+ puller.stop()
+ puller_task.join()
+
+ assert pusher_tracer._log == [
+ ('new', pusher_tracer.trace_id),
+ ]
+ assert puller_tracer._log == [
+ ('load', pusher_tracer.trace_id),
+ ]
+
+
+def test_task_context_pubsub():
+ endpoint = random_ipc_endpoint()
+ subscriber_ctx = zerorpc.Context()
+ publisher_ctx = zerorpc.Context()
+
+ subscriber_tracer = Tracer('[subscriber]')
+ subscriber_ctx.register_middleware(subscriber_tracer)
+ publisher_tracer = Tracer('[publisher]')
+ publisher_ctx.register_middleware(publisher_tracer)
+
+ trigger = gevent.event.Event()
+
+ class Subscriber:
+ def echo(self, msg):
+ trigger.set()
+
+ subscriber = zerorpc.Subscriber(Subscriber(), context=subscriber_ctx)
+ subscriber.bind(endpoint)
+ subscriber_task = gevent.spawn(subscriber.run)
+
+ c = zerorpc.Publisher(context=publisher_ctx)
+ c.connect(endpoint)
+
+ trigger.clear()
+ c.echo('pub...')
+ trigger.wait()
+
+ subscriber.stop()
+ subscriber_task.join()
+
+ assert publisher_tracer._log == [
+ ('new', publisher_tracer.trace_id),
+ ]
+ assert subscriber_tracer._log == [
+ ('load', publisher_tracer.trace_id),
+ ]
View
14 zerorpc/context.py
@@ -37,7 +37,9 @@ def __init__(self):
self._middlewares_hooks = {
'resolve_endpoint': [],
'raise_error': [],
- 'call_procedure': []
+ 'call_procedure': [],
+ 'load_task_context': [],
+ 'get_task_context': [],
}
@staticmethod
@@ -86,3 +88,13 @@ def __call__(self, *args, **kwargs):
for functor in self._middlewares_hooks['call_procedure']:
procedure = chain(functor, procedure)
return procedure(*args, **kwargs)
+
+ def middleware_load_task_context(self, event_header):
+ for functor in self._middlewares_hooks['load_task_context']:
+ functor(event_header)
+
+ def middleware_get_task_context(self):
+ event_header = {}
+ for functor in self._middlewares_hooks['get_task_context']:
+ event_header.update(functor())
+ return event_header
View
58 zerorpc/core.py
@@ -131,6 +131,7 @@ def _async_task(self, initial_event):
bufchan = BufferedChannel(hbchan)
event = bufchan.recv()
try:
+ self._context.middleware_load_task_context(event.header)
functor = self._methods.get(event.name, None)
if functor is None:
raise NameError(event.name)
@@ -139,7 +140,8 @@ def _async_task(self, initial_event):
self._print_traceback(protocol_v1)
except Exception:
exception_info = self._print_traceback(protocol_v1)
- bufchan.emit('ERR', exception_info)
+ bufchan.emit('ERR', exception_info,
+ self._context.middleware_get_task_context())
finally:
bufchan.close()
bufchan.channel.close()
@@ -220,7 +222,8 @@ def __call__(self, method, *args, **kargs):
passive=self._passive_heartbeat)
bufchan = BufferedChannel(hbchan, inqueue_size=kargs.get('slots', 100))
- bufchan.emit(method, args)
+ xheader = self._context.middleware_get_task_context()
+ bufchan.emit(method, args, xheader)
try:
if kargs.get('async', False) is False:
@@ -274,7 +277,7 @@ class PatternReqRep():
def process_call(self, context, bufchan, event, functor):
result = context.middleware_call_procedure(functor, *event.args)
- bufchan.emit('OK', (result,))
+ bufchan.emit('OK', (result,), context.middleware_get_task_context())
def accept_answer(self, event):
return True
@@ -297,10 +300,11 @@ class rep(DecoratorBase):
class PatternReqStream():
def process_call(self, context, bufchan, event, functor):
+ xheader = context.middleware_get_task_context()
for result in iter(context.middleware_call_procedure(functor,
*event.args)):
- bufchan.emit('STREAM', result)
- bufchan.emit('STREAM_DONE', None)
+ bufchan.emit('STREAM', result, xheader)
+ bufchan.emit('STREAM_DONE', None, xheader)
def accept_answer(self, event):
return event.name in ('STREAM', 'STREAM_DONE')
@@ -358,7 +362,8 @@ def __init__(self, context=None, zmq_socket=zmq.PUSH):
super(Pusher, self).__init__(zmq_socket, context=context)
def __call__(self, method, *args):
- self._events.emit(method, args)
+ self._events.emit(method, args,
+ self._context.middleware_get_task_context())
def __getattr__(self, method):
return lambda *args: self(method, *args)
@@ -390,6 +395,7 @@ def _receiver(self):
try:
if event.name not in self._methods:
raise NameError(event.name)
+ self._context.middleware_load_task_context(event.header)
self._context.middleware_call_procedure(
self._methods[event.name],
*event.args)
@@ -420,3 +426,43 @@ def __init__(self, methods=None, context=None):
super(Subscriber, self).__init__(methods=methods, context=context,
zmq_socket=zmq.SUB)
self._events.setsockopt(zmq.SUBSCRIBE, '')
+
+
+def fork_task_context(functor, context=None):
+ '''Wrap a functor to transfer context.
+
+ Usage example:
+ gevent.spawn(zerorpc.fork_task_context(myfunction), args...)
+
+ The goal is to permit context "inheritance" from a task to another.
+ Consider the following example:
+
+ zerorpc.Server receive a new event
+ - task1 is created to handle this event this task will be linked
+ to the initial event context. zerorpc.Server does that for you.
+ - task1 make use of some zerorpc.Client instances, the initial
+ event context is transfered on every call.
+
+ - task1 spawn a new task2.
+ - task2 make use of some zerorpc.Client instances, it's a fresh
+ context. Thus there is no link to the initial context that
+ spawned task1.
+
+ - task1 spawn a new fork_task_context(task3).
+ - task3 make use of some zerorpc.Client instances, the initial
+ event context is transfered on every call.
+
+ A real use case is a distributed tracer. Each time a new event is
+ created, a trace_id is injected in it or copied from the current task
+ context. This permit passing the trace_id from a zerorpc.Server to
+ another via zerorpc.Client.
+
+ The simple rule to know if a task need to be wrapped is:
+ - if the new task will make any zerorpc call, it should be wrapped.
+ '''
+ context = context or Context.get_instance()
+ header = context.middleware_get_task_context()
+ def wrapped(*args, **kargs):
+ context.middleware_load_task_context(header)
+ return functor(*args, **kargs)
+ return wrapped

0 comments on commit 9702bca

Please sign in to comment.