From 330f3dc8ad47b5c3677a85d14581c6c77006d47c Mon Sep 17 00:00:00 2001 From: Chris Pacia Date: Thu, 10 Dec 2015 11:56:34 -0500 Subject: [PATCH] Refactor transferKeyValues --- dht/protocol.py | 34 +++++++++++++++++++++++++++++----- dht/tests/test_protocol.py | 24 ++++++++++++++++-------- 2 files changed, 45 insertions(+), 13 deletions(-) diff --git a/dht/protocol.py b/dht/protocol.py index d8647e03..a8bdd83b 100644 --- a/dht/protocol.py +++ b/dht/protocol.py @@ -5,7 +5,7 @@ import random -from twisted.internet import defer, reactor +from twisted.internet import reactor from zope.interface import implements import nacl.signing @@ -118,7 +118,7 @@ def rpc_inv(self, sender, serlialized_invs): try: i = objects.Inv() i.ParseFromString(inv) - if not self.storage.exists(i.keyword, i.valueKey): + if self.storage.getSpecific(i.keyword, i.valueKey) is None: ret.append(inv) except Exception: pass @@ -183,7 +183,26 @@ def transferKeyValues(self, node): is closer than the closest in that list, then store the key/value on the new node (per section 2.5 of the paper) """ - ds = [] + def send_values(inv_list): + values = [] + for requested_inv in inv_list: + try: + i = objects.Inv() + i.ParseFromString(requested_inv) + value = self.storage.getSpecific(i.keyword, i.valueKey) + if value is not None: + v = objects.Value() + v.keyword = i.keyword + v.valueKey = i.valueKey + v.serializedData = value + v.ttl = self.storage.get_ttl(i.keyword, i.valueKey) + values.append(v.SerializeToString()) + except Exception: + pass + if len(values) > 0: + self.callValues(node, values) + + inv = [] for keyword in self.storage.iterkeys(): keynode = Node(keyword) neighbors = self.router.findNeighbors(keynode, exclude=node) @@ -193,9 +212,14 @@ def transferKeyValues(self, node): if len(neighbors) == 0 \ or (newNodeClose and thisNodeClosest) \ or (thisNodeClosest and len(neighbors) < self.ksize): + # pylint: disable=W0612 for k, v in self.storage.iteritems(keyword): - ds.append(self.callStore(node, keyword, k, v, self.storage.get_ttl(keyword, k))) - return defer.gatherResults(ds) + i = objects.Inv() + i.keyword = keyword + i.valueKey = k + inv.append(i.SerializeToString()) + if len(inv) > 0: + self.callInv(node, inv).addCallback(send_values) def handleCallResponse(self, result, node): """ diff --git a/dht/tests/test_protocol.py b/dht/tests/test_protocol.py index 2746baea..1c578237 100644 --- a/dht/tests/test_protocol.py +++ b/dht/tests/test_protocol.py @@ -489,6 +489,9 @@ def test_transferKeyValues(self): self.protocol.storage[digest("keyword")] = ( digest("key"), self.protocol.sourceNode.getProto().SerializeToString(), 10) + self.protocol.storage[digest("keyword")] = ( + digest("key2"), self.protocol.sourceNode.getProto().SerializeToString(), 10) + self.protocol.transferKeyValues(Node(digest("id"), self.addr1[0], self.addr1[1])) self.clock.advance(1) @@ -498,19 +501,24 @@ def test_transferKeyValues(self): x = message.Message() x.ParseFromString(sent_message) + i = objects.Inv() + i.keyword = digest("keyword") + i.valueKey = digest("key") + + i2 = objects.Inv() + i2.keyword = digest("keyword") + i2.valueKey = digest("key2") + m = message.Message() m.sender.MergeFrom(self.protocol.sourceNode.getProto()) - m.command = message.Command.Value("STORE") + m.command = message.Command.Value("INV") m.protoVer = self.version - m.arguments.append(digest("keyword")) - m.arguments.append(digest("key")) - m.arguments.append(self.protocol.sourceNode.getProto().SerializeToString()) - m.arguments.append(str(10)) + m.arguments.append(i.SerializeToString()) + m.arguments.append(i2.SerializeToString()) self.assertEqual(x.sender, m.sender) self.assertEqual(x.command, m.command) - self.assertEqual(x.arguments[0], m.arguments[0]) - self.assertEqual(x.arguments[1], m.arguments[1]) - self.assertEqual(x.arguments[2], m.arguments[2]) + self.assertTrue(x.arguments[0] in m.arguments) + self.assertTrue(x.arguments[1] in m.arguments) def test_refreshIDs(self): node1 = Node(digest("id1"), "127.0.0.1", 12345, signed_pubkey=digest("key1"))