Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 34 additions & 4 deletions memcache/async_memcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,19 +16,39 @@ def __init__(
addr: Tuple[str, int],
*,
load_func: LoadFunc = load,
dump_func: DumpFunc = dump
dump_func: DumpFunc = dump,
username: Optional[str] = None,
password: Optional[str] = None,
):
self._addr = addr
self._load = load_func
self._dump = dump_func
self._username = username
self._password = password
self._connected = False

async def _connect(self) -> None:
self.reader, self.writer = await asyncio.open_connection(
self._addr[0], self._addr[1]
)
await self._auth()
self._connected = True

async def _auth(self) -> None:
if self._username is None or self._password is None:
return
auth_data = b"%s %s" % (
self._username.encode("utf-8"),
self._password.encode("utf-8"),
)
self.writer.write(b"set auth x 0 %d\r\n" % len(auth_data))
self.writer.write(auth_data)
self.writer.write(b"\r\n")
await self.writer.drain()
response = await self.reader.readline()
if response != b"STORED\r\n":
raise MemcacheError(response.rstrip(b"\r\n"))

async def flush_all(self) -> None:
if not self._connected:
await self._connect()
Expand Down Expand Up @@ -140,15 +160,21 @@ def __init__(
pool_size: Optional[int] = 23,
pool_timeout: Optional[int] = 1,
load_func: LoadFunc = load,
dump_func: DumpFunc = dump
dump_func: DumpFunc = dump,
username: Optional[str] = None,
password: Optional[str] = None,
):
addr = addr or ("localhost", 11211)
if isinstance(addr, list):
addrs: List[Addr] = addr
nodes: List[AsyncPool] = []
for addr in addrs:
create_connection = lambda: AsyncConnection(
addr, load_func=load_func, dump_func=dump_func
addr,
load_func=load_func,
dump_func=dump_func,
username=username,
password=password,
)
nodes.append(
AsyncPool(
Expand All @@ -159,7 +185,11 @@ def __init__(
elif isinstance(addr, tuple):
a: Addr = addr
create_connection = lambda: AsyncConnection(
a, load_func=load_func, dump_func=dump_func
a,
load_func=load_func,
dump_func=dump_func,
username=username,
password=password,
)
self._connections = hashring.HashRing(
[AsyncPool(create_connection, max_size=pool_size, timeout=pool_timeout)]
Expand Down
38 changes: 34 additions & 4 deletions memcache/memcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,37 @@ def __init__(
addr: Tuple[str, int],
*,
load_func: LoadFunc = load,
dump_func: DumpFunc = dump
dump_func: DumpFunc = dump,
username: Optional[str] = None,
password: Optional[str] = None,
):
self._addr = addr
self._load = load_func
self._dump = dump_func
self._username = username
self._password = password
self._connect()

def _connect(self) -> None:
self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
self.socket.connect(self._addr)
self.stream = self.socket.makefile(mode="rwb")
self._auth()

def _auth(self) -> None:
if self._username is None or self._password is None:
return
auth_data = b"%s %s" % (
self._username.encode("utf-8"),
self._password.encode("utf-8"),
)
self.stream.write(b"set auth x 0 %d\r\n" % len(auth_data))
self.stream.write(auth_data)
self.stream.write(b"\r\n")
self.stream.flush()
response = self.stream.readline()
if response != b"STORED\r\n":
raise MemcacheError(response.rstrip(NEWLINE))

def close(self) -> None:
self.stream.close()
Expand Down Expand Up @@ -143,15 +163,21 @@ def __init__(
pool_size: Optional[int] = 23,
pool_timeout: Optional[int] = 1,
load_func: LoadFunc = load,
dump_func: DumpFunc = dump
dump_func: DumpFunc = dump,
username: Optional[str] = None,
password: Optional[str] = None,
):
addr = addr or ("localhost", 11211)
if isinstance(addr, list):
addrs: List[Addr] = addr
nodes: List[Pool] = []
for addr in addrs:
create_connection = lambda: Connection(
addr, load_func=load_func, dump_func=dump_func
addr,
load_func=load_func,
dump_func=dump_func,
username=username,
password=password,
)
nodes.append(
Pool(create_connection, max_size=pool_size, timeout=pool_timeout)
Expand All @@ -160,7 +186,11 @@ def __init__(
elif isinstance(addr, tuple):
a: Addr = addr
create_connection = lambda: Connection(
a, load_func=load_func, dump_func=dump_func
a,
load_func=load_func,
dump_func=dump_func,
username=username,
password=password,
)
self._connections = hashring.HashRing(
[Pool(create_connection, max_size=pool_size, timeout=pool_timeout)]
Expand Down
6 changes: 5 additions & 1 deletion memcache/meta_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,11 @@ def load_header(line: bytes) -> "MetaResult":
rc = parts[0]
if rc == b"CLIENT_ERROR":
# Old ascii protocol error.
raise MemcacheError(line.removeprefix(b"CLIENT_ERROR ").removesuffix(b"\r\n").decode("utf-8"))
raise MemcacheError(
line.lstrip(b"CLIENT_ERROR ")
.rstrip()
.decode("utf-8")
)

flags = []
datalen = None
Expand Down