Skip to content

Commit

Permalink
Persist invocation results
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhukovAlexander committed Oct 10, 2016
1 parent e190496 commit 2303a5a
Show file tree
Hide file tree
Showing 6 changed files with 29 additions and 54 deletions.
39 changes: 7 additions & 32 deletions rafter/log.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import itertools
import json
from uuid import uuid4
import asyncio

import lmdb
from msgpack import packb, unpackb
Expand Down Expand Up @@ -163,49 +164,23 @@ def __len__(self):
return txn.stat()['entries']


class AsyncDictWrapper:

"""
async def notify(c):
async with c:
c.notify_all()
n = 0
def wait(c):
global n
if n >= 4:
return n
n += 1
asyncio.ensure_future(notify(c))
return False
async def get_await():
c = asyncio.Condition()
async with c:
print(await c.wait_for(lambda: wait(c)))
import asyncio
print(asyncio.get_event_loop().run_until_complete(get_await()))
"""
import asyncio


class AsyncDict(dict):

def __init__(self, *args, **kwargs, *, loop=None):
super().__init__(*args, **kwargs)
def __init__(self, d, loop=None):
self._d = d
self._loop = loop or asyncio.get_event_loop()
self._waiters = collections.defaultdict(lambda: asyncio.Condition(loop=self._loop))

async def wait_for(self, key):
try:
return super().__getitem__(key)
return self._d[key]
except KeyError:
c = self._waiters[key]
async with c:
return await c.wait_for(lambda: super(self.__class__, self).get(key))
return await c.wait_for(lambda: self._d.get(key))

async def set(self, key, value):
c = self._waiters[key]
async with c:
super().__setitem__(key, value)
self._d[key] = value
c.notify_all()
1 change: 1 addition & 0 deletions rafter/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class KwargsType(types.BaseType):
class LogEntry(MsgpackModel):
index = types.IntType()
term = types.IntType()
uuid = types.UUIDType()
command = types.StringType()
args = ArgsType()
kwargs = KwargsType()
Expand Down
27 changes: 14 additions & 13 deletions rafter/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
import sys
import random
import logging
import pickle
import os
import json
from collections import defaultdict
Expand Down Expand Up @@ -42,6 +41,7 @@ def __init__(self,
address=('0.0.0.0', 10000),
log=None,
storage=None,
invocations=None,
loop=None,
server_protocol=UPDProtocolMsgPackServer,
config=None,
Expand All @@ -52,6 +52,8 @@ def __init__(self,
self.log = log if log is not None else rlog.RaftLog()
self.storage = storage or rlog.Storage()

self.invocations = rlog.AsyncDictWrapper(invocations or {})

self.bootstrap = bootstrap

if self.bootstrap:
Expand Down Expand Up @@ -132,25 +134,24 @@ def handle(self, message_type, **kwargs):
"""Dispatch to the appropriate state method"""
return getattr(self.state, message_type)(**kwargs)

async def _apply_single(self, cmd, args, kwargs, index=None):
async def _apply_single(self, cmd, invocation_id, args, kwargs, index=None):

try:
res = await getattr(self.service, cmd).apply(*args, **kwargs) if cmd else None
except Exception as e:
logger.exception('Exception during command invocation')
raise
res = dict(result=await getattr(self.service, cmd).apply(*args, **kwargs) if cmd else None)

if index is None:
return res
except Exception as e:
logger.exception('Exception during a command invocation')
res = dict(error=True, msg=str(e))

# notify waiting client
can_apply = self.pending_events.get(index)
if can_apply:
can_apply.set()
finally:
if invocation_id is None:
return res['result']
await self.invocations.set(invocation_id, res)

def apply_commited(self, start, end):
return asyncio.ensure_future(asyncio.wait(map(
lambda entry: asyncio.ensure_future(self._apply_single(entry.command,
entry.uuid,
entry.args,
entry.kwargs,
index=entry.index)),
Expand Down Expand Up @@ -188,7 +189,7 @@ async def handle_write_command(self, slug, *args, **kwargs):
async def handle_read_command(self, command, *args, **kwargs):
if not self.state.is_leader():
raise NotLeaderException('This server is not a leader')
return await self._apply_single(command, args, kwargs)
return await self._apply_single(command, None, args, kwargs)

async def send_append_entries(self, entries=(), destination=('239.255.255.250', 10000)):
prev = self.log[entries[0].index - 1] if entries else None
Expand Down
6 changes: 3 additions & 3 deletions rafter/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,11 @@ def __get__(self, instance, owner):
self._server = instance._server
return self

async def __call__(self, *args, **kwargs):
async def __call__(self, *args, invocation_id=None, **kwargs):
if not self._service:
raise UnboundExposedCommand()
if self._write:
return await self._server.handle_write_command(self.slug, *args, **kwargs)
return await self._server.handle_write_command(self.slug, invocation_id, *args, **kwargs)
return await self._server.handle_read_command(self.slug, *args, **kwargs)

async def apply(self, *args, **kwargs):
Expand Down Expand Up @@ -117,8 +117,8 @@ def encode(data):
class JsonRpcHttpRequestHandler(aiohttp.server.ServerHttpProtocol):

def __init__(self, service, *args, **kwargs):
self._service = service
super().__init__(*args, **kwargs)
self._service = service

async def handle_request(self, message, payload):

Expand Down
1 change: 0 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@

setup(
name="rafter",
version="0.1.0",
keywords="python raft distributed replication",
packages=['rafter', ],
install_requires=open('requirements.txt', 'r').read().splitlines(),
Expand Down
9 changes: 4 additions & 5 deletions tests/test_server.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
from unittest import mock

import uuid
import asyncio

from rafter.server import RaftServer
Expand Down Expand Up @@ -51,15 +51,14 @@ def test_handle_calls_correct_state_method(self):
res = self.server.handle(method)
getattr(self.server.state, method).assert_called_with()


def test_maybe_commit_should_notify_clients(self):
entry = LogEntry(dict(index=0, term=self.server.term, command='test', args=(), kwargs={}))
entry = LogEntry(dict(index=0, uuid=uuid.uuid4(), term=self.server.term, command='test', args=(), kwargs={}))
self.server.log.append(entry)
self.server.pending_events[entry.index] = asyncio.Event()
waiter = asyncio.ensure_future(self.server.invocations.wait_for(entry.uuid))
res = self.loop.run_until_complete(
self.server.maybe_commit(self.server.id, self.server.term, entry.index)
)
self.assertTrue(self.server.pending_events[entry.index].is_set())
self.assertTrue(self.loop.run_until_complete(waiter))
self.assertEqual(self.server.log.commit_index, entry.index)

def test_handle_write_command_should_send_append_entries(self):
Expand Down

0 comments on commit 2303a5a

Please sign in to comment.