From 2757680710a47023fba07d24d74356e2a2ab42da Mon Sep 17 00:00:00 2001 From: AN Long Date: Sun, 18 Sep 2022 21:16:18 +0800 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20add=20ascii=20protoco=20authenticat?= =?UTF-8?q?ion=20support?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- memcache/async_memcache.py | 38 ++++++++++++++++++++++++++++++++++---- memcache/memcache.py | 38 ++++++++++++++++++++++++++++++++++---- memcache/meta_command.py | 6 +++++- 3 files changed, 73 insertions(+), 9 deletions(-) diff --git a/memcache/async_memcache.py b/memcache/async_memcache.py index 2a87cec..37f026f 100644 --- a/memcache/async_memcache.py +++ b/memcache/async_memcache.py @@ -16,19 +16,39 @@ def __init__( addr: Tuple[str, int], *, load_func: LoadFunc = load, - dump_func: DumpFunc = dump + dump_func: DumpFunc = dump, + username: Optional[str] = None, + password: Optional[str] = None, ): self._addr = addr self._load = load_func self._dump = dump_func + self._username = username + self._password = password self._connected = False async def _connect(self) -> None: self.reader, self.writer = await asyncio.open_connection( self._addr[0], self._addr[1] ) + await self._auth() self._connected = True + async def _auth(self) -> None: + if self._username is None or self._password is None: + return + auth_data = b"%s %s" % ( + self._username.encode("utf-8"), + self._password.encode("utf-8"), + ) + self.writer.write(b"set auth x 0 %d\r\n" % len(auth_data)) + self.writer.write(auth_data) + self.writer.write(b"\r\n") + await self.writer.drain() + response = await self.reader.readline() + if response != b"STORED\r\n": + raise MemcacheError(response.rstrip(b"\r\n")) + async def flush_all(self) -> None: if not self._connected: await self._connect() @@ -140,7 +160,9 @@ def __init__( pool_size: Optional[int] = 23, pool_timeout: Optional[int] = 1, load_func: LoadFunc = load, - dump_func: DumpFunc = dump + dump_func: DumpFunc = dump, + username: Optional[str] = None, + password: Optional[str] = None, ): addr = addr or ("localhost", 11211) if isinstance(addr, list): @@ -148,7 +170,11 @@ def __init__( nodes: List[AsyncPool] = [] for addr in addrs: create_connection = lambda: AsyncConnection( - addr, load_func=load_func, dump_func=dump_func + addr, + load_func=load_func, + dump_func=dump_func, + username=username, + password=password, ) nodes.append( AsyncPool( @@ -159,7 +185,11 @@ def __init__( elif isinstance(addr, tuple): a: Addr = addr create_connection = lambda: AsyncConnection( - a, load_func=load_func, dump_func=dump_func + a, + load_func=load_func, + dump_func=dump_func, + username=username, + password=password, ) self._connections = hashring.HashRing( [AsyncPool(create_connection, max_size=pool_size, timeout=pool_timeout)] diff --git a/memcache/memcache.py b/memcache/memcache.py index 1af95be..827da1c 100644 --- a/memcache/memcache.py +++ b/memcache/memcache.py @@ -20,17 +20,37 @@ def __init__( addr: Tuple[str, int], *, load_func: LoadFunc = load, - dump_func: DumpFunc = dump + dump_func: DumpFunc = dump, + username: Optional[str] = None, + password: Optional[str] = None, ): self._addr = addr self._load = load_func self._dump = dump_func + self._username = username + self._password = password self._connect() def _connect(self) -> None: self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) self.socket.connect(self._addr) self.stream = self.socket.makefile(mode="rwb") + self._auth() + + def _auth(self) -> None: + if self._username is None or self._password is None: + return + auth_data = b"%s %s" % ( + self._username.encode("utf-8"), + self._password.encode("utf-8"), + ) + self.stream.write(b"set auth x 0 %d\r\n" % len(auth_data)) + self.stream.write(auth_data) + self.stream.write(b"\r\n") + self.stream.flush() + response = self.stream.readline() + if response != b"STORED\r\n": + raise MemcacheError(response.rstrip(NEWLINE)) def close(self) -> None: self.stream.close() @@ -143,7 +163,9 @@ def __init__( pool_size: Optional[int] = 23, pool_timeout: Optional[int] = 1, load_func: LoadFunc = load, - dump_func: DumpFunc = dump + dump_func: DumpFunc = dump, + username: Optional[str] = None, + password: Optional[str] = None, ): addr = addr or ("localhost", 11211) if isinstance(addr, list): @@ -151,7 +173,11 @@ def __init__( nodes: List[Pool] = [] for addr in addrs: create_connection = lambda: Connection( - addr, load_func=load_func, dump_func=dump_func + addr, + load_func=load_func, + dump_func=dump_func, + username=username, + password=password, ) nodes.append( Pool(create_connection, max_size=pool_size, timeout=pool_timeout) @@ -160,7 +186,11 @@ def __init__( elif isinstance(addr, tuple): a: Addr = addr create_connection = lambda: Connection( - a, load_func=load_func, dump_func=dump_func + a, + load_func=load_func, + dump_func=dump_func, + username=username, + password=password, ) self._connections = hashring.HashRing( [Pool(create_connection, max_size=pool_size, timeout=pool_timeout)] diff --git a/memcache/meta_command.py b/memcache/meta_command.py index 2fdbf4e..b1d2c92 100644 --- a/memcache/meta_command.py +++ b/memcache/meta_command.py @@ -51,7 +51,11 @@ def load_header(line: bytes) -> "MetaResult": rc = parts[0] if rc == b"CLIENT_ERROR": # Old ascii protocol error. - raise MemcacheError(line.removeprefix(b"CLIENT_ERROR ").removesuffix(b"\r\n").decode("utf-8")) + raise MemcacheError( + line.lstrip(b"CLIENT_ERROR ") + .rstrip() + .decode("utf-8") + ) flags = [] datalen = None