diff --git a/rafter/server.py b/rafter/server.py index adcc9ea..81ae7a7 100644 --- a/rafter/server.py +++ b/rafter/server.py @@ -8,15 +8,17 @@ import os import json from collections import defaultdict +from uuid import uuid4 import uvloop from . import serverstate -from . import storage as storage +from . import storage as store from . import models from .network import UPDProtocolMsgPackServer, make_socket, ResetablePeriodicTask from .exceptions import NotLeaderException +from .utils import AsyncDictWrapper # @@ -41,7 +43,6 @@ def __init__(self, address=('0.0.0.0', 10000), log=None, storage=None, - invocations=None, loop=None, server_protocol=UPDProtocolMsgPackServer, config=None, @@ -49,15 +50,17 @@ def __init__(self, self.host, self.port = address - self.log = log if log is not None else storage.RaftLog() - self.storage = storage or storage.Storage() + self.log = log if log is not None else store.RaftLog() + self.storage = storage or store.PersistentDict() - self.invocations = storage.AsyncDictWrapper(invocations or {}) + a_wrapper = AsyncDictWrapper(storage or {}) + + self.wait_for, self.set_result = a_wrapper.wait_for, a_wrapper.set self.bootstrap = bootstrap if self.bootstrap: - self.storage.peers = {self.id: {'id': self.id}} + self.storage['peers'] = {self.id: {'id': self.id}} self.match_index = defaultdict(lambda: self.log.commit_index) self.next_index = defaultdict(lambda: self.log.commit_index + 1) @@ -88,27 +91,27 @@ def new_heartbeat(bootstraps=bootstraps): @property def id(self): # pragma: nocover - return self.storage.id + return self.storage.setdefault('id', uuid4().hex) @property def term(self): # pragma: nocover - return self.storage.term + return self.storage.setdefault('term', 0) @term.setter def term(self, value): # pragma: nocover - self.storage.term = value + self.storage['term'] = value @property def voted_for(self): # pragma: nocover - return self.storage.voted_for + return self.storage.get('voted_for') @voted_for.setter def voted_for(self, value): # pragma: nocover - self.storage.voted_for = value + self.storage['voted_for'] = value @property def peers(self): # pragma: nocover - return self.storage.peers + return self.storage.setdefault('peers', {}) def start(self): @@ -146,7 +149,7 @@ async def _apply_single(self, cmd, invocation_id, args, kwargs, index=None): finally: if invocation_id is None: return res['result'] - await self.invocations.set(invocation_id, res) + await self.set_result(invocation_id, res) def apply_commited(self, start, end): return asyncio.ensure_future(asyncio.wait(map( @@ -216,12 +219,12 @@ def broadcast_request_vote(self): def add_peer(self, peer): # - self.storage.peers = {**self.storage.peers, **{peer['id']: peer}} + self.storage['peers'] = {**self.storage['peers'], **{peer['id']: peer}} def remove_peer(self, peer_id): - peers = self.storage.peers + peers = self.storage['peers'] del peers[peer_id] - self.storage.peers = peers + self.storage['peers'] = peers def list_peers(self): - return list(self.storage.peers) + return list(self.peers) diff --git a/rafter/storage.py b/rafter/storage.py index c31244a..cc9a20b 100644 --- a/rafter/storage.py +++ b/rafter/storage.py @@ -19,29 +19,6 @@ def from_bytes(b): return int.from_bytes(b, sys.byteorder) -class MetaDataField: - def __init__(self, key, from_raw=lambda x: x, to_raw=lambda x: x, default=None): - self._key = key - self._default = default - self.from_raw = from_raw - self.to_raw = to_raw - - def __get__(self, instance, owner): - if instance is not None: - with instance.env.begin(db=instance.attrs_store) as txn: - val = txn.get(self._key) - if val is None: - self.__set__(instance, self._default) - return self._default - return self.from_raw(val) - return self - - def __set__(self, instance, value): - if instance is not None: - with instance.env.begin(write=True, db=instance.attrs_store) as txn: - txn.replace(self._key, self.to_raw(value)) - - # TODO: implement dynamic serizlizer class RaftLog(collections.abc.MutableSequence): """Implement raft log on top of the LMDB storage.""" @@ -114,35 +91,19 @@ def entry(self, term, command, args=(), kwargs=None): self.append(LogEntry(dict(index=len(self), term=term, command=command, args=args, kwargs=kwargs or {}))) return self[-1] - commit_index = MetaDataField(b'commit_index', from_raw=int, to_raw=lambda x: str(x).encode(), default=0) - - -class Storage: - - def __init__(self, env=None, db=None): - - self.env = env or lmdb.open('/tmp/rafter.lmdb', max_dbs=10) - self.attrs_store = self.env.open_db(b'store') - - term = MetaDataField(b'term', from_raw=int, to_raw=lambda x: str(x).encode(), default=0) - voted_for = MetaDataField(b'voted_for', from_raw=lambda x: x.decode(), default='') - peers = MetaDataField(b'peers', from_raw=lambda x: json.loads(x.decode()), to_raw=lambda x: json.dumps(x).encode(), default={}) - id = MetaDataField(b'id', to_raw=lambda x: x.encode(), from_raw=lambda x: x.decode(), default=uuid4().hex) - class PersistentDict(collections.abc.MutableMapping): - def __init__(self, env, db, *args, **kwargs): + def __init__(self, env=None, db=None, *args, **kwargs): super().__init__(*args, **kwargs) - self.env = env - self.db = db + self.env = env or lmdb.open('/tmp/rafter.lmdb', max_dbs=10) + self.db = db or self.env.open_db(b'storage') def __setitem__(self, key, value): with self.env.begin(write=True, db=self.db) as txn: txn.replace(packb(key), packb(value)) def __getitem__(self, key): - val = None with self.env.begin(db=self.db) as txn: val = txn.get(packb(key)) if val is None: diff --git a/rafter/utils.py b/rafter/utils.py new file mode 100644 index 0000000..d740a93 --- /dev/null +++ b/rafter/utils.py @@ -0,0 +1,24 @@ +import collections +import asyncio + + +class AsyncDictWrapper: + + def __init__(self, d, loop=None): + self._d = d + self._loop = loop or asyncio.get_event_loop() + self._waiters = collections.defaultdict(lambda: asyncio.Condition(loop=self._loop)) + + async def wait_for(self, key): + try: + return self._d[key] + except KeyError: + c = self._waiters[key] + async with c: + return await c.wait_for(lambda: self._d.get(key)) + + async def set(self, key, value): + c = self._waiters[key] + async with c: + self._d[key] = value + c.notify_all() \ No newline at end of file diff --git a/tests/mocks.py b/tests/mocks.py index 87f607d..35c7765 100644 --- a/tests/mocks.py +++ b/tests/mocks.py @@ -12,10 +12,8 @@ def entry(self, term, command, args, kwargs): return self[-1] -class Storage(mock.Mock): - term = 0 - id = 'testserver' - peers = {} +class Storage(dict): + pass async def foo(*args, **kwargs): return 'result' diff --git a/tests/test_log.py b/tests/test_log.py index d82144e..ee6cf9f 100644 --- a/tests/test_log.py +++ b/tests/test_log.py @@ -5,7 +5,7 @@ import lmdb from rafter.models import LogEntry -from rafter.storage import RaftLog, Storage, MetaDataField +from rafter.storage import RaftLog class LMDBLogTest(unittest.TestCase): @@ -16,11 +16,6 @@ def setUp(self): def tearDown(self): shutil.rmtree(self.db_dir) - def test_commit_index(self): - self.assertEqual(self.log.commit_index, 0) - self.log.commit_index += 1 - self.assertEqual(self.log.commit_index, 1) - def test_setitem(self): entry = LogEntry() self.log[0] = entry @@ -94,23 +89,3 @@ def test_cmp(self): self.assertFalse(self.log.cmp(entry.index - 1, entry.term - 1)) self.assertFalse(self.log.cmp(entry.index + 1, entry.term - 1)) self.assertTrue(self.log.cmp(entry.index - 1, entry.term + 1)) - - -class StorageTest(unittest.TestCase): - - def get_storage(self): - return Storage(env=lmdb.open(self.db_dir, max_dbs=10)) - - def setUp(self): - self.db_dir = tempfile.mkdtemp() - self.storage = self.get_storage() - - def tearDown(self): - shutil.rmtree(self.db_dir) - - def test_medadata_attribute(self): - self.assertIsInstance(Storage.id, MetaDataField) - - def test_defaults(self): - self.assertEqual(self.storage.term, 0) - self.assertEqual(self.storage.id, self.get_storage().id) diff --git a/tests/test_server.py b/tests/test_server.py index 90fbad6..f3e1a3d 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -37,6 +37,7 @@ def test_start_stop(self): def test_initial_heartbeat_calls_add_peer(self): with mock.patch('rafter.server.asyncio.ensure_future') as ensure_future: self.server.heartbeat(bootstraps=True) + ensure_future.assert_called_with(self.server.service.add_peer()) def test_heartbeat_should_schedule_ae(self): @@ -83,7 +84,7 @@ def test_broadcast_request_vote(self): def test_add_peer(self): self.server.add_peer({'id': 'peer-2'}) - self.assertIn('peer-2', self.server.peers) + self.assertIn(b'peer-2', self.server.peers) def test_remove_peer(self): with self.assertRaises(KeyError):