Skip to content

Commit

Permalink
Switch to the new storage implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
ZhukovAlexander committed Nov 6, 2016
1 parent 7103aa4 commit f35b1c5
Show file tree
Hide file tree
Showing 6 changed files with 52 additions and 90 deletions.
37 changes: 20 additions & 17 deletions rafter/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,17 @@
import os
import json
from collections import defaultdict
from uuid import uuid4

import uvloop


from . import serverstate
from . import storage as storage
from . import storage as store
from . import models
from .network import UPDProtocolMsgPackServer, make_socket, ResetablePeriodicTask
from .exceptions import NotLeaderException
from .utils import AsyncDictWrapper


# <http://stackoverflow.com/a/14058475/2183102>
Expand All @@ -41,23 +43,24 @@ def __init__(self,
address=('0.0.0.0', 10000),
log=None,
storage=None,
invocations=None,
loop=None,
server_protocol=UPDProtocolMsgPackServer,
config=None,
bootstrap=False):

self.host, self.port = address

self.log = log if log is not None else storage.RaftLog()
self.storage = storage or storage.Storage()
self.log = log if log is not None else store.RaftLog()
self.storage = storage or store.PersistentDict()

self.invocations = storage.AsyncDictWrapper(invocations or {})
a_wrapper = AsyncDictWrapper(storage or {})

self.wait_for, self.set_result = a_wrapper.wait_for, a_wrapper.set

self.bootstrap = bootstrap

if self.bootstrap:
self.storage.peers = {self.id: {'id': self.id}}
self.storage['peers'] = {self.id: {'id': self.id}}

self.match_index = defaultdict(lambda: self.log.commit_index)
self.next_index = defaultdict(lambda: self.log.commit_index + 1)
Expand Down Expand Up @@ -88,27 +91,27 @@ def new_heartbeat(bootstraps=bootstraps):

@property
def id(self): # pragma: nocover
return self.storage.id
return self.storage.setdefault('id', uuid4().hex)

@property
def term(self): # pragma: nocover
return self.storage.term
return self.storage.setdefault('term', 0)

@term.setter
def term(self, value): # pragma: nocover
self.storage.term = value
self.storage['term'] = value

@property
def voted_for(self): # pragma: nocover
return self.storage.voted_for
return self.storage.get('voted_for')

@voted_for.setter
def voted_for(self, value): # pragma: nocover
self.storage.voted_for = value
self.storage['voted_for'] = value

@property
def peers(self): # pragma: nocover
return self.storage.peers
return self.storage.setdefault('peers', {})

def start(self):

Expand Down Expand Up @@ -146,7 +149,7 @@ async def _apply_single(self, cmd, invocation_id, args, kwargs, index=None):
finally:
if invocation_id is None:
return res['result']
await self.invocations.set(invocation_id, res)
await self.set_result(invocation_id, res)

def apply_commited(self, start, end):
return asyncio.ensure_future(asyncio.wait(map(
Expand Down Expand Up @@ -216,12 +219,12 @@ def broadcast_request_vote(self):

def add_peer(self, peer):
# <http://stackoverflow.com/a/26853961/2183102>
self.storage.peers = {**self.storage.peers, **{peer['id']: peer}}
self.storage['peers'] = {**self.storage['peers'], **{peer['id']: peer}}

def remove_peer(self, peer_id):
peers = self.storage.peers
peers = self.storage['peers']
del peers[peer_id]
self.storage.peers = peers
self.storage['peers'] = peers

def list_peers(self):
return list(self.storage.peers)
return list(self.peers)
45 changes: 3 additions & 42 deletions rafter/storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,6 @@ def from_bytes(b):
return int.from_bytes(b, sys.byteorder)


class MetaDataField:
def __init__(self, key, from_raw=lambda x: x, to_raw=lambda x: x, default=None):
self._key = key
self._default = default
self.from_raw = from_raw
self.to_raw = to_raw

def __get__(self, instance, owner):
if instance is not None:
with instance.env.begin(db=instance.attrs_store) as txn:
val = txn.get(self._key)
if val is None:
self.__set__(instance, self._default)
return self._default
return self.from_raw(val)
return self

def __set__(self, instance, value):
if instance is not None:
with instance.env.begin(write=True, db=instance.attrs_store) as txn:
txn.replace(self._key, self.to_raw(value))


# TODO: implement dynamic serizlizer
class RaftLog(collections.abc.MutableSequence):
"""Implement raft log on top of the LMDB storage."""
Expand Down Expand Up @@ -114,35 +91,19 @@ def entry(self, term, command, args=(), kwargs=None):
self.append(LogEntry(dict(index=len(self), term=term, command=command, args=args, kwargs=kwargs or {})))
return self[-1]

commit_index = MetaDataField(b'commit_index', from_raw=int, to_raw=lambda x: str(x).encode(), default=0)


class Storage:

def __init__(self, env=None, db=None):

self.env = env or lmdb.open('/tmp/rafter.lmdb', max_dbs=10)
self.attrs_store = self.env.open_db(b'store')

term = MetaDataField(b'term', from_raw=int, to_raw=lambda x: str(x).encode(), default=0)
voted_for = MetaDataField(b'voted_for', from_raw=lambda x: x.decode(), default='')
peers = MetaDataField(b'peers', from_raw=lambda x: json.loads(x.decode()), to_raw=lambda x: json.dumps(x).encode(), default={})
id = MetaDataField(b'id', to_raw=lambda x: x.encode(), from_raw=lambda x: x.decode(), default=uuid4().hex)


class PersistentDict(collections.abc.MutableMapping):

def __init__(self, env, db, *args, **kwargs):
def __init__(self, env=None, db=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.env = env
self.db = db
self.env = env or lmdb.open('/tmp/rafter.lmdb', max_dbs=10)
self.db = db or self.env.open_db(b'storage')

def __setitem__(self, key, value):
with self.env.begin(write=True, db=self.db) as txn:
txn.replace(packb(key), packb(value))

def __getitem__(self, key):
val = None
with self.env.begin(db=self.db) as txn:
val = txn.get(packb(key))
if val is None:
Expand Down
24 changes: 24 additions & 0 deletions rafter/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import collections
import asyncio


class AsyncDictWrapper:

def __init__(self, d, loop=None):
self._d = d
self._loop = loop or asyncio.get_event_loop()
self._waiters = collections.defaultdict(lambda: asyncio.Condition(loop=self._loop))

async def wait_for(self, key):
try:
return self._d[key]
except KeyError:
c = self._waiters[key]
async with c:
return await c.wait_for(lambda: self._d.get(key))

async def set(self, key, value):
c = self._waiters[key]
async with c:
self._d[key] = value
c.notify_all()
6 changes: 2 additions & 4 deletions tests/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,8 @@ def entry(self, term, command, args, kwargs):
return self[-1]


class Storage(mock.Mock):
term = 0
id = 'testserver'
peers = {}
class Storage(dict):
pass

async def foo(*args, **kwargs):
return 'result'
Expand Down
27 changes: 1 addition & 26 deletions tests/test_log.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import lmdb

from rafter.models import LogEntry
from rafter.storage import RaftLog, Storage, MetaDataField
from rafter.storage import RaftLog


class LMDBLogTest(unittest.TestCase):
Expand All @@ -16,11 +16,6 @@ def setUp(self):
def tearDown(self):
shutil.rmtree(self.db_dir)

def test_commit_index(self):
self.assertEqual(self.log.commit_index, 0)
self.log.commit_index += 1
self.assertEqual(self.log.commit_index, 1)

def test_setitem(self):
entry = LogEntry()
self.log[0] = entry
Expand Down Expand Up @@ -94,23 +89,3 @@ def test_cmp(self):
self.assertFalse(self.log.cmp(entry.index - 1, entry.term - 1))
self.assertFalse(self.log.cmp(entry.index + 1, entry.term - 1))
self.assertTrue(self.log.cmp(entry.index - 1, entry.term + 1))


class StorageTest(unittest.TestCase):

def get_storage(self):
return Storage(env=lmdb.open(self.db_dir, max_dbs=10))

def setUp(self):
self.db_dir = tempfile.mkdtemp()
self.storage = self.get_storage()

def tearDown(self):
shutil.rmtree(self.db_dir)

def test_medadata_attribute(self):
self.assertIsInstance(Storage.id, MetaDataField)

def test_defaults(self):
self.assertEqual(self.storage.term, 0)
self.assertEqual(self.storage.id, self.get_storage().id)
3 changes: 2 additions & 1 deletion tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ def test_start_stop(self):
def test_initial_heartbeat_calls_add_peer(self):
with mock.patch('rafter.server.asyncio.ensure_future') as ensure_future:
self.server.heartbeat(bootstraps=True)

ensure_future.assert_called_with(self.server.service.add_peer())

def test_heartbeat_should_schedule_ae(self):
Expand Down Expand Up @@ -83,7 +84,7 @@ def test_broadcast_request_vote(self):

def test_add_peer(self):
self.server.add_peer({'id': 'peer-2'})
self.assertIn('peer-2', self.server.peers)
self.assertIn(b'peer-2', self.server.peers)

def test_remove_peer(self):
with self.assertRaises(KeyError):
Expand Down

0 comments on commit f35b1c5

Please sign in to comment.