From 01184c07f0a079a5fe1f7d44d5c302974bf062af Mon Sep 17 00:00:00 2001 From: maralla Date: Mon, 9 Mar 2015 17:40:01 +0800 Subject: [PATCH] add tracking support --- MANIFEST.in | 6 +- setup.py | 1 + tests/test_tracking.py | 283 ++++++++++++++++++++++ thriftpy/contrib/__init__.py | 0 thriftpy/contrib/tracking/__init__.py | 169 +++++++++++++ thriftpy/contrib/tracking/tracker.py | 34 +++ thriftpy/contrib/tracking/tracking.thrift | 27 +++ thriftpy/thrift.py | 11 + 8 files changed, 528 insertions(+), 3 deletions(-) create mode 100644 tests/test_tracking.py create mode 100644 thriftpy/contrib/__init__.py create mode 100644 thriftpy/contrib/tracking/__init__.py create mode 100644 thriftpy/contrib/tracking/tracker.py create mode 100644 thriftpy/contrib/tracking/tracking.thrift diff --git a/MANIFEST.in b/MANIFEST.in index 065c707..6b73242 100644 --- a/MANIFEST.in +++ b/MANIFEST.in @@ -1,4 +1,4 @@ include README.rst CHANGES.rst -include thriftpy/protocol/cybin/*.pyx thriftpy/protocol/cybin/*.c thriftpy/protocol/cybin/*.h -include thriftpy/transport/*/*.pyx thriftpy/transport/*/*.c -include thriftpy/transport/*.pyx thriftpy/transport/*.pxd thriftpy/transport/*.c +recursive-include thriftpy/protocol/cybin *.pyx *.c *.h +recursive-include thriftpy/transport *.pyx *.pxd *.c +include thriftpy/contrib/tracking/tracking.thrift diff --git a/setup.py b/setup.py index 10b38a4..78c4987 100644 --- a/setup.py +++ b/setup.py @@ -68,6 +68,7 @@ author="Lx Yu", author_email="i@lxyu.net", packages=find_packages(exclude=['benchmark', 'docs', 'tests']), + package_data={"thriftpy": ["contrib/tracking/tracking.thrift"]}, entry_points={}, url="https://thriftpy.readthedocs.org/", license="MIT", diff --git a/tests/test_tracking.py b/tests/test_tracking.py new file mode 100644 index 0000000..2395324 --- /dev/null +++ b/tests/test_tracking.py @@ -0,0 +1,283 @@ +# -*- coding: utf-8 -*- + +from __future__ import absolute_import + +import contextlib +import os +import multiprocessing +import time +import 
class TSampleServer(TThreadedServer):
    """Threaded test server that asks a factory for a processor per client.

    The base-class initializer is not called; every attribute used by
    this class is assigned directly in ``__init__``.

    :param processor_factory: object whose ``get_processor()`` returns the
        processor used for one client connection.
    :param trans: the listening server transport.
    :param trans_factory: factory used for both the input and the output
        transport of each client.
    :param prot_factory: factory used for both the input and the output
        protocol of each client.
    """

    def __init__(self, processor_factory, trans, trans_factory, prot_factory):
        self.daemon = False
        self.processor_factory = processor_factory
        self.trans = trans

        # The same factory serves both directions.
        self.itrans_factory = self.otrans_factory = trans_factory
        self.iprot_factory = self.oprot_factory = prot_factory
        self.closed = False

    def handle(self, client):
        """Serve requests from ``client`` until its transport fails.

        A ``TTransportException`` ends the loop quietly.  Any other
        exception propagates to the caller, but the transports are closed
        first so they are not leaked.
        """
        processor = self.processor_factory.get_processor()
        itrans = self.itrans_factory.get_transport(client)
        otrans = self.otrans_factory.get_transport(client)
        iprot = self.iprot_factory.get_protocol(itrans)
        oprot = self.oprot_factory.get_protocol(otrans)
        try:
            while True:
                processor.process(iprot, oprot)
        except TTransportException:
            pass
        finally:
            # Close both transports on every exit path; the nested
            # try/finally makes sure a failing close of the input side
            # does not prevent closing the output side.
            try:
                itrans.close()
            finally:
                otrans.close()
yield client_class(*args) + finally: + trans.close() + + +@pytest.fixture +def dbm_db(request): + db = dbm.open(db_file, 'n') + db.close() + + def fin(): + try: + os.remove(db_file) + except OSError: + pass + request.addfinalizer(fin) + + +def test_negotiation(server): + with client() as c: + assert c._upgraded is True + + +def test_tracker(server, dbm_db): + with client() as c: + c.ping() + + time.sleep(0.2) + + db = dbm.open(db_file, 'r') + headers = list(db.keys()) + assert len(headers) == 1 + + request_id = headers[0] + data = pickle.loads(db[request_id]) + + assert "start" in data and "end" in data + data.pop("start") + data.pop("end") + + assert data == { + "request_id": request_id.decode("ascii").split(':')[0], + "seq": 0, + "client": "test_client", + "server": "test_server", + "api": "ping", + "status": True + } + + +def test_tracker_chain(server, server1, server2, dbm_db): + with client() as c: + c.remove("jane") + + time.sleep(0.2) + + db = dbm.open(db_file, 'r') + headers = list(db.keys()) + assert len(headers) == 3 + + headers.sort() + + header0 = pickle.loads(db[headers[0]]) + header1 = pickle.loads(db[headers[1]]) + header2 = pickle.loads(db[headers[2]]) + + assert header0["request_id"] == header1["request_id"] == \ + header2["request_id"] == headers[0].decode("ascii").split(':')[0] + assert header0["seq"] == 0 and header1["seq"] == 1 and header2["seq"] == 2 + + +def test_exception(server, dbm_db): + with pytest.raises(addressbook.PersonNotExistsError): + with client() as c: + c.get("jane") + + db = dbm.open(db_file, 'r') + headers = list(db.keys()) + assert len(headers) == 1 + + header = pickle.loads(db[headers[0]]) + assert header["status"] is False + + +def test_not_tracked_client_tracked_server(server): + with client(TClient) as c: + c.ping() + c.hello("world") + + +def test_tracked_client_not_tracked_server(not_tracked_server): + with client(port=6030) as c: + assert c._upgraded is False + c.ping() + c.hello("cat") + + +def 
class TTrackedClient(TClient):
    """Thrift client that sends a tracking header with every request.

    On construction the client issues a special upgrade call named
    ``trace_method``.  When the peer answers it, the client is marked as
    upgraded: each later request is preceded by a ``RequestHeader`` and is
    reported to the tracker once it finishes.  When the peer rejects the
    call with ``UNKNOWN_METHOD`` the client stays un-upgraded and requests
    are forwarded to the base class without a header.

    :param tracker_handler: tracker object providing ``gen_header``,
        ``record`` and the ``client`` / ``server`` names.
    :param args: positional arguments passed on to ``TClient``.
    :param kwargs: keyword arguments passed on to ``TClient``.
    :raises TApplicationException: if the upgrade call fails with anything
        other than ``UNKNOWN_METHOD``.
    """

    def __init__(self, tracker_handler, *args, **kwargs):
        super(TTrackedClient, self).__init__(*args, **kwargs)

        self.tracer = tracker_handler
        self._upgraded = False

        try:
            self._negotiation()
            self._upgraded = True
        except TApplicationException as e:
            # An unknown-method reply just means the peer does not support
            # tracking; every other application error is a real failure.
            if e.type != TApplicationException.UNKNOWN_METHOD:
                raise

    def _negotiation(self):
        """Send the upgrade call and consume the peer's reply.

        :raises TApplicationException: the exception sent back by the peer
            when it answers with an EXCEPTION message.
        """
        self._oprot.write_message_begin(trace_method, TMessageType.CALL,
                                        self._seqid)
        trace_thrift.UpgradeArgs().write(self._oprot)
        self._oprot.write_message_end()
        self._oprot.trans.flush()

        _, msg_type, _ = self._iprot.read_message_begin()
        if msg_type == TMessageType.EXCEPTION:
            x = TApplicationException()
            x.read(self._iprot)
            self._iprot.read_message_end()
            raise x

        # The reply carries no data but must still be read off the wire.
        trace_thrift.UpgradeReply().read(self._iprot)
        self._iprot.read_message_end()

    def _send(self, _api, **kwargs):
        """Write the tracking header (when upgraded), then send the call."""
        if self._upgraded:
            header = trace_thrift.RequestHeader()
            self.tracer.gen_header(header)
            # Publish the header only once it is fully populated, so that
            # ``_req`` never reports a half-initialised one.
            self._header = header
            header.write(self._oprot)

        self.send_start = int(time.time() * 1000)
        super(TTrackedClient, self)._send(_api, **kwargs)

    def _req(self, _api, *args, **kwargs):
        """Perform a request and report its outcome to the tracker.

        Un-upgraded clients delegate straight to the base class.  For
        upgraded clients a ``RequestInfo`` is handed to
        ``self.tracer.record`` whether the call succeeds or raises; the
        original exception is always re-raised unchanged.
        """
        if not self._upgraded:
            return super(TTrackedClient, self)._req(_api, *args, **kwargs)

        # Clear per-call state so that a failure happening before ``_send``
        # runs is neither attributed to the previous call's header nor
        # turned into an AttributeError inside ``finally``.
        self._header = None
        self.send_start = None

        exception = None
        # Pre-set so ``finally`` can read it even when something that is
        # not an ``Exception`` subclass (e.g. KeyboardInterrupt) escapes.
        status = False
        try:
            res = super(TTrackedClient, self)._req(_api, *args, **kwargs)
            status = True
            return res
        except Exception as e:
            exception = e
            raise
        finally:
            header = self._header
            # Without a generated header there is no request id to report.
            if header is not None:
                finished = int(time.time() * 1000)
                started = self.send_start
                if started is None:
                    # Writing the header failed before the send timestamp
                    # was taken; report a zero-length call.
                    started = finished
                header_info = trace_thrift.RequestInfo(
                    request_id=header.request_id,
                    seq=header.seq,
                    client=self.tracer.client,
                    server=self.tracer.server,
                    api=_api,
                    status=status,
                    start=started,
                    end=finished
                )
                self.tracer.record(header_info, exception)
import threading
import uuid

# Per-thread storage for the header of the request currently being handled.
ctx = threading.local()


class TrackerBase(object):
    """Base tracker: propagates request ids and sequence numbers.

    ``handle`` stores an incoming header in the thread-local ``ctx``;
    ``gen_header`` fills an outgoing header from it, continuing the same
    request id with the next sequence number.  ``record`` is a hook for
    subclasses and does nothing here.

    :param client: name reported as the calling side.
    :param server: name reported as the called side.
    """

    def __init__(self, client=None, server=None):
        self.client, self.server = client, server

    def handle(self, header):
        """Remember ``header`` as the current thread's incoming header."""
        ctx.header = header

    def gen_header(self, header):
        """Populate ``header.request_id`` and ``header.seq`` in place.

        With no stored incoming header the sequence starts at 0;
        otherwise it is the stored header's ``seq`` plus one.
        """
        header.request_id = self.get_request_id()
        try:
            parent = ctx.header
        except AttributeError:
            header.seq = 0
        else:
            header.seq = parent.seq + 1

    def record(self, header, exception):
        """Hook called with the call info; the base class ignores it."""
        return None

    def get_request_id(self):
        """Return the stored header's request id, or a new UUID4 string."""
        try:
            parent = ctx.header
        except AttributeError:
            return str(uuid.uuid4())
        return parent.request_id


class ConsoleTracker(TrackerBase):
    """Tracker that prints every recorded header to standard output."""

    def record(self, header, exception):
        print(header)
class TProcessorFactory(object):
    """Builds processors on demand from a class and stored arguments.

    :param processor_class: callable used to create each processor.
    :param args: positional arguments passed to ``processor_class``.
    :param kwargs: keyword arguments passed to ``processor_class``.
    """

    def __init__(self, processor_class, *args, **kwargs):
        self.processor_class = processor_class
        self.args, self.kwargs = args, kwargs

    def get_processor(self):
        """Return a new processor created with the stored arguments."""
        build = self.processor_class
        return build(*self.args, **self.kwargs)