Skip to content

Commit

Permalink
Fix missing message handling because of wrong callback sequence (#14)
Browse files Browse the repository at this point in the history
  • Loading branch information
amyangfei committed Sep 10, 2022
1 parent 58d7e38 commit 7253905
Show file tree
Hide file tree
Showing 8 changed files with 109 additions and 39 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unit_test.yaml
Expand Up @@ -17,7 +17,7 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.5", "3.6", "3.7", "3.8", "3.9", "3.10"]
python-version: ["3.6", "3.7", "3.8", "3.9", "3.10"]

steps:
- uses: actions/checkout@v3
Expand Down
4 changes: 3 additions & 1 deletion gor/asyncio_impl.py
Expand Up @@ -5,12 +5,14 @@
import asyncio

from .base import Gor
from .callback import SimpleCallbackContainer


class AsyncioGor(Gor):

def __init__(self, *args, **kwargs):
super(AsyncioGor, self).__init__(*args, **kwargs)
chan_container = SimpleCallbackContainer()
super(AsyncioGor, self).__init__(chan_container, *args, **kwargs)
self.q = asyncio.Queue()
self.concurrency = kwargs.get('concurrency', 2)
self.tasks = []
Expand Down
21 changes: 7 additions & 14 deletions gor/base.py
Expand Up @@ -28,9 +28,9 @@ def __init__(self, _id, _type, meta, raw_meta, http):

class Gor(object):

def __init__(self, stderr=None):
self.stderr = stderr or sys.stderr
self.ch = {}
def __init__(self, chan_container, *args, **kwargs):
self.stderr = sys.stderr
self.chan_container = chan_container

def run(self):
raise NotImplementedError
Expand All @@ -39,12 +39,7 @@ def on(self, chan, callback, idx=None, **kwargs):
if idx is not None:
chan = chan + '#' + idx

self.ch.setdefault(chan, [])
self.ch[chan].append({
'created': datetime.datetime.now(),
'callback': callback,
'kwargs': kwargs,
})
self.chan_container.add(chan, callback, **kwargs)
return self

def emit(self, msg):
Expand All @@ -56,11 +51,9 @@ def emit(self, msg):
chan_prefix = chan_prefix_map[msg.type]
resp = msg
for chan_id in ['message', chan_prefix, chan_prefix + '#' + msg.id]:
if self.ch.get(chan_id):
for channel in self.ch[chan_id]:
r = channel['callback'](self, msg, **channel['kwargs'])
if r:
resp = r
r = self.chan_container.do_callback(self, chan_id, msg)
if r:
resp = r
if resp:
sys.stdout.write(gor_hex_data(resp))
sys.stdout.flush()
Expand Down
55 changes: 55 additions & 0 deletions gor/callback.py
@@ -0,0 +1,55 @@
# coding: utf-8

import sys
from datetime import datetime
from typing import Callable


class CallbackContainer(object):

def __init__(self, *args, **kwargs):
self.ch = None

def init_container(self):
raise NotImplementedError

def ensure_chan(self, chan: str):
raise NotImplementedError

def add(self, chan: str, callback: Callable, **kwargs):
self.ensure_chan(chan)
self.ch[chan].append({
'created': datetime.now(),
'callback': callback,
'kwargs': kwargs,
})

def do_callback(self, gor, chan_id: str, msg) -> str:
resp = ''
if self.ch.get(chan_id):
for channel in self.ch[chan_id]:
r = channel['callback'](gor, msg, **channel['kwargs'])
if r:
resp = r
return resp


class MultiProcessCallbackContainer(CallbackContainer):

def __init__(self, manager, *args, **kwargs):
super(CallbackContainer, self).__init__(*args, **kwargs)
self.manager = manager
self.ch = self.manager.dict()

def ensure_chan(self, chan: str):
self.ch.setdefault(chan, self.manager.list())


class SimpleCallbackContainer(CallbackContainer):

def __init__(self, *args, **kwargs):
super(CallbackContainer, self).__init__(*args, **kwargs)
self.ch = {}

def ensure_chan(self, chan: str):
self.ch.setdefault(chan, [])
39 changes: 25 additions & 14 deletions gor/multiprocess_impl.py
Expand Up @@ -4,17 +4,19 @@
import multiprocessing

from .base import Gor
from .callback import MultiProcessCallbackContainer


EXIT_MSG = ""

class MultiProcessGor(Gor):

def __init__(self, *args, **kwargs):
super(MultiProcessGor, self).__init__(*args, **kwargs)
self.q = multiprocessing.JoinableQueue()
chan_container = MultiProcessCallbackContainer(multiprocessing.Manager())
super(MultiProcessGor, self).__init__(chan_container, *args, **kwargs)
self.concurrency = kwargs.get('concurrency', 2)
self.workers = []
self.queues = []

def _stdin_reader(self):
while True:
Expand All @@ -26,29 +28,38 @@ def _stdin_reader(self):
if not line:
self._stop()
break
self.q.put(line)
msg = self.parse_message(line)
if msg:
# messages with the same id must be processed in serializable way
index = hash(msg.id) % len(self.queues)
self.queues[index].put(msg)

def _worker(self):
def _worker(self, queue):
while True:
line = self.q.get()
try:
if line == EXIT_MSG:
msg = queue.get()
except KeyboardInterrupt:
break
try:
if msg == EXIT_MSG:
return
msg = self.parse_message(line)
if msg:
self.emit(msg)
self.emit(msg)
finally:
self.q.task_done()
queue.task_done()

def _stop(self):
for _ in range(len(self.workers)):
self.q.put(EXIT_MSG)
self.q.join()
for queue in self.queues:
queue.put(EXIT_MSG)
queue.join()

def run(self):
for i in range(self.concurrency):
worker = multiprocessing.Process(target=self._worker)
queue = multiprocessing.JoinableQueue()
worker = multiprocessing.Process(target=self._worker, args=(queue,))
self.queues.append(queue)
self.workers.append(worker)
for worker in self.workers:
worker.start()
self._stdin_reader()
for worker in self.workers:
worker.join()
2 changes: 1 addition & 1 deletion tests/test_asyncio_gor.py
Expand Up @@ -41,7 +41,7 @@ def test_init(self):
self.gor.on('message', _incr_received, passby=passby)
self.gor.on('request', _incr_received, passby=passby)
self.gor.on('response', _incr_received, idx='2', passby=passby)
self.assertEqual(len(self.gor.ch), 3)
self.assertEqual(len(self.gor.chan_container.ch), 3)

req = self.gor.parse_message(binascii.hexlify(b'1 2 3\nGET / HTTP/1.1\r\n\r\n'))
resp = self.gor.parse_message(binascii.hexlify(b'2 2 3\nHTTP/1.1 200 OK\r\n\r\n'))
Expand Down
3 changes: 2 additions & 1 deletion tests/test_gor.py
Expand Up @@ -4,12 +4,13 @@
import unittest

from gor.base import Gor
from gor.callback import SimpleCallbackContainer


class TestCommon(unittest.TestCase):

def setUp(self):
self.gor = Gor()
self.gor = Gor(SimpleCallbackContainer())

def tearDown(self):
pass
Expand Down
22 changes: 15 additions & 7 deletions tests/test_multiprocess_gor.py
Expand Up @@ -10,26 +10,34 @@


class Counter(object):
def __init__(self):
def __init__(self, lock):
self.val = multiprocessing.Value('i', 0)
self.lock = lock

def increment(self, n=1):
with self.val.get_lock():
with self.lock:
self.val.value += n

@property
def value(self):
return self.val.value
with self.lock:
return self.val.value


counter = Counter(multiprocessing.Manager().Lock())
counters = {hash(counter): counter}


def _incr_received(proxy, msg, **kwargs):
kwargs['passby']['received'].increment()
h = kwargs['passby']['counter']
if h in counters:
counters[h].increment()


class TestMultiProcessGor(unittest.TestCase):

def setUp(self):
self.gor = MultiProcessGor()
self.gor = MultiProcessGor(concurrency=4)

def tearDown(self):
pass
Expand All @@ -43,13 +51,13 @@ def _proxy_coroutine(self, passby):

def test_run(self):
old_stdin = sys.stdin
passby = {'received': Counter()}
passby = {'counter': hash(counter)}
payload = "\n".join([
binascii.hexlify(b'1 2 3\nGET / HTTP/1.1\r\n\r\n').decode("utf-8"),
binascii.hexlify(b'2 2 3\nHTTP/1.1 200 OK\r\n\r\n').decode("utf-8"),
binascii.hexlify(b'2 3 3\nHTTP/1.1 200 OK\r\n\r\n').decode("utf-8"),
])
sys.stdin = io.StringIO(payload)
self._proxy_coroutine(passby)
self.assertEqual(passby['received'].value, 5)
self.assertEqual(counter.value, 5)
sys.stdin = old_stdin

0 comments on commit 7253905

Please sign in to comment.