Showing with 367 additions and 19 deletions.
  1. +49 −19 txyam/client.py
  2. +318 −0 txyam/tests/test_client.py
68 changes: 49 additions & 19 deletions txyam/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from twisted.internet.defer import inlineCallbacks, DeferredList, returnValue
from twisted.internet import reactor
from twisted.python import log
from twisted.python import failure, log

from txyam.utils import ketama, deferredDict
from txyam.factory import MemCacheClientFactory
Expand All @@ -25,13 +25,33 @@ def _wrap(cmd):
"""
Used to wrap all of the memcache methods (get,set,getMultiple,etc).
"""
def unwrap(result):
return result[None]

def wrapper(self, key, *args, **kwargs):
func = getattr(self.getClient(key), cmd)
return func(key, *args, **kwargs)
client = self.getClient(key)
request = {client: (None, cmd, (key,) + args, kwargs)}
return self._issueRequest(request).addCallback(unwrap)
return wrapper


def _issueRequest(request):
"""
Issue a named request to some clients.
This is primarily for testing purposes, so that wrappers can build up a
request and the tests can inspect the request.
"""
ret = {}
for client, (resultKey, method, args, kwargs) in request.iteritems():
method = getattr(client, method)
ret[resultKey] = method(*args, **kwargs)
return deferredDict(ret)


class YamClient:
_issueRequest = staticmethod(_issueRequest)

def __init__(self, hosts, connect=True):
"""
@param hosts: A C{list} of C{tuple}s containing hosts and ports.
Expand Down Expand Up @@ -80,27 +100,37 @@ def disconnect(self):
connection.transport.loseConnection()

def flushAll(self):
hosts = self.getActiveConnections()
log.msg("Flushing %i hosts" % len(hosts))
return DeferredList([host.flushAll() for host in hosts])
request = {}
for e, client in enumerate(self.getActiveConnections()):
request[client] = e, 'flushAll', (), {}
log.msg("Flushing %i hosts" % len(request))

def unwrap(result):
result = result.items()
result.sort()
return [(not isinstance(b, failure.Failure), b) for a, b in result]

return self._issueRequest(request).addCallback(unwrap)

def stats(self, arg=None):
ds = {}
def stats(self):
request = {}
for factory in self.factories:
if not factory.client is None:
hp = "%s:%i" % (factory.addr.host, factory.addr.port)
ds[hp] = factory.client.stats(arg)
log.msg("Getting stats on %i hosts" % len(ds))
return deferredDict(ds)
if factory.client is None:
continue
hp = "%s:%i" % (factory.addr.host, factory.addr.port)
request[factory.client] = hp, 'stats', (), {}
log.msg("Getting stats on %i hosts" % len(request))
return self._issueRequest(request)

def version(self):
ds = {}
request = {}
for factory in self.factories:
if not factory.client is None:
hp = "%s:%i" % (factory.addr.host, factory.addr.port)
ds[hp] = factory.client.version()
log.msg("Getting version on %i hosts" % len(ds))
return deferredDict(ds)
if factory.client is None:
continue
hp = "%s:%i" % (factory.addr.host, factory.addr.port)
request[factory.client] = hp, 'version', (), {}
log.msg("Getting version on %i hosts" % len(request))
return self._issueRequest(request)

def pickle(self, value, compress):
p = cPickle.dumps(value, cPickle.HIGHEST_PROTOCOL)
Expand Down
Loading