Skip to content
This repository has been archived by the owner on Dec 10, 2018. It is now read-only.

Commit

Permalink
add tracking support
Browse files Browse the repository at this point in the history
  • Loading branch information
maralla committed Mar 9, 2015
1 parent d24912d commit 01184c0
Show file tree
Hide file tree
Showing 8 changed files with 528 additions and 3 deletions.
6 changes: 3 additions & 3 deletions MANIFEST.in
@@ -1,4 +1,4 @@
include README.rst CHANGES.rst
include thriftpy/protocol/cybin/*.pyx thriftpy/protocol/cybin/*.c thriftpy/protocol/cybin/*.h
include thriftpy/transport/*/*.pyx thriftpy/transport/*/*.c
include thriftpy/transport/*.pyx thriftpy/transport/*.pxd thriftpy/transport/*.c
recursive-include thriftpy/protocol/cybin *.pyx *.c *.h
recursive-include thriftpy/transport *.pyx *.pxd *.c
include thriftpy/contrib/tracking/tracking.thrift
1 change: 1 addition & 0 deletions setup.py
Expand Up @@ -68,6 +68,7 @@
author="Lx Yu",
author_email="i@lxyu.net",
packages=find_packages(exclude=['benchmark', 'docs', 'tests']),
package_data={"thriftpy": ["contrib/tracking/tracking.thrift"]},
entry_points={},
url="https://thriftpy.readthedocs.org/",
license="MIT",
Expand Down
283 changes: 283 additions & 0 deletions tests/test_tracking.py
@@ -0,0 +1,283 @@
# -*- coding: utf-8 -*-

from __future__ import absolute_import

import contextlib
import os
import multiprocessing
import time
import tempfile
import pickle
import thriftpy

try:
import dbm
except ImportError:
import dbm.ndbm as dbm

import pytest

from thriftpy.contrib.tracking import TTrackedProcessor, TTrackedClient, \
TrackerBase, trace_thrift
from thriftpy.contrib.tracking.tracker import ctx

from thriftpy.thrift import TProcessorFactory, TClient, TProcessor
from thriftpy.server import TThreadedServer
from thriftpy.transport import TServerSocket, TBufferedTransportFactory, \
TTransportException, TSocket
from thriftpy.protocol import TBinaryProtocolFactory


addressbook = thriftpy.load(os.path.join(os.path.dirname(__file__),
"addressbook.thrift"))
_, db_file = tempfile.mkstemp()


class SampleTracker(TrackerBase):
def record(self, header, exception):
db = dbm.open(db_file, 'w')

key = "%s:%d" % (header.request_id, header.seq)
db[key.encode("ascii")] = pickle.dumps(header.__dict__)
db.close()

tracker = SampleTracker("test_client", "test_server")


class Dispatcher(object):
def __init__(self):
self.ab = addressbook.AddressBook()
self.ab.people = {}

def ping(self):
return True

def hello(self, name):
return "hello %s" % name

def remove(self, name):
person = addressbook.Person(name="mary")
with client(port=6098) as c:
c.add(person)
return True

def add(self, person):
with client(port=6099) as c:
c.hello("jane")
return True

def get(self, name):
raise addressbook.PersonNotExistsError()


class TSampleServer(TThreadedServer):
def __init__(self, processor_factory, trans, trans_factory, prot_factory):
self.daemon = False
self.processor_factory = processor_factory
self.trans = trans

self.itrans_factory = self.otrans_factory = trans_factory
self.iprot_factory = self.oprot_factory = prot_factory
self.closed = False

def handle(self, client):
processor = self.processor_factory.get_processor()
itrans = self.itrans_factory.get_transport(client)
otrans = self.otrans_factory.get_transport(client)
iprot = self.iprot_factory.get_protocol(itrans)
oprot = self.oprot_factory.get_protocol(otrans)
try:
while True:
processor.process(iprot, oprot)
except TTransportException:
pass
except Exception:
raise

itrans.close()
otrans.close()


def gen_server(port=6029, tracker=tracker, processor=TTrackedProcessor):
args = [processor, addressbook.AddressBookService, Dispatcher()]
if tracker:
args.insert(1, tracker)
processor = TProcessorFactory(*args)
server_socket = TServerSocket(host="localhost", port=port)
server = TSampleServer(processor, server_socket,
prot_factory=TBinaryProtocolFactory(),
trans_factory=TBufferedTransportFactory())
ps = multiprocessing.Process(target=server.serve)
ps.start()
return ps, server


@pytest.fixture
def server(request):
ps, ser = gen_server()
time.sleep(0.15)

def fin():
if ps.is_alive():
ps.terminate()
request.addfinalizer(fin)
return ser


@pytest.fixture
def server1(request):
ps, ser = gen_server(port=6098)
time.sleep(0.15)

def fin():
if ps.is_alive():
ps.terminate()
request.addfinalizer(fin)
return ser


@pytest.fixture
def server2(request):
ps, ser = gen_server(port=6099)
time.sleep(0.15)

def fin():
if ps.is_alive():
ps.terminate()
request.addfinalizer(fin)
return ser


@pytest.fixture
def not_tracked_server(request):
ps, ser = gen_server(port=6030, tracker=None, processor=TProcessor)
time.sleep(0.15)

def fin():
if ps.is_alive():
ps.terminate()
request.addfinalizer(fin)
return ser


@contextlib.contextmanager
def client(client_class=TTrackedClient, port=6029):
socket = TSocket("localhost", port)

try:
trans = TBufferedTransportFactory().get_transport(socket)
proto = TBinaryProtocolFactory().get_protocol(trans)
trans.open()
args = [addressbook.AddressBookService, proto]
if client_class.__name__ == TTrackedClient.__name__:
args.insert(0, tracker)
yield client_class(*args)
finally:
trans.close()


@pytest.fixture
def dbm_db(request):
db = dbm.open(db_file, 'n')
db.close()

def fin():
try:
os.remove(db_file)
except OSError:
pass
request.addfinalizer(fin)


def test_negotiation(server):
with client() as c:
assert c._upgraded is True


def test_tracker(server, dbm_db):
with client() as c:
c.ping()

time.sleep(0.2)

db = dbm.open(db_file, 'r')
headers = list(db.keys())
assert len(headers) == 1

request_id = headers[0]
data = pickle.loads(db[request_id])

assert "start" in data and "end" in data
data.pop("start")
data.pop("end")

assert data == {
"request_id": request_id.decode("ascii").split(':')[0],
"seq": 0,
"client": "test_client",
"server": "test_server",
"api": "ping",
"status": True
}


def test_tracker_chain(server, server1, server2, dbm_db):
with client() as c:
c.remove("jane")

time.sleep(0.2)

db = dbm.open(db_file, 'r')
headers = list(db.keys())
assert len(headers) == 3

headers.sort()

header0 = pickle.loads(db[headers[0]])
header1 = pickle.loads(db[headers[1]])
header2 = pickle.loads(db[headers[2]])

assert header0["request_id"] == header1["request_id"] == \
header2["request_id"] == headers[0].decode("ascii").split(':')[0]
assert header0["seq"] == 0 and header1["seq"] == 1 and header2["seq"] == 2


def test_exception(server, dbm_db):
with pytest.raises(addressbook.PersonNotExistsError):
with client() as c:
c.get("jane")

db = dbm.open(db_file, 'r')
headers = list(db.keys())
assert len(headers) == 1

header = pickle.loads(db[headers[0]])
assert header["status"] is False


def test_not_tracked_client_tracked_server(server):
with client(TClient) as c:
c.ping()
c.hello("world")


def test_tracked_client_not_tracked_server(not_tracked_server):
with client(port=6030) as c:
assert c._upgraded is False
c.ping()
c.hello("cat")


def test_request_id_func():
ctx.__dict__.clear()

header = trace_thrift.RequestHeader()
header.request_id = "hello"
header.seq = 0

tracker = TrackerBase()
tracker.handle(header)

header2 = trace_thrift.RequestHeader()
tracker.gen_header(header2)
assert header2.request_id == "hello"
Empty file added thriftpy/contrib/__init__.py
Empty file.

0 comments on commit 01184c0

Please sign in to comment.