diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 7295949c2..000000000 --- a/.travis.yml +++ /dev/null @@ -1,105 +0,0 @@ -dist: xenial -language: python - -addons: - apt: - packages: - - socat - -env: - global: - INSTALL_DIR: >- - $HOME/redis - PYTEST_ADDOPTS: >- - "-n $(( $(nproc) * 2 ))" - -python: -- "3.6" -- "3.7" -- "3.8" -- "nightly" -- "pypy3.6-7.1.1" -- "pypy3.6-7.2.0" - -stages: -- lint -- test -- examples - -jobs: - allow_failures: - - python: "nightly" - - python: "pypy3.6-7.2.0" - include: - # Add two extra tests with uvloop - - &UVLOOP - env: UVLOOP="y" - python: "3.6" - - <<: *UVLOOP - python: "3.7" - # Lint and spell-check docs - - stage: lint - cache: false - name: documentation spell check - python: "3.7" - addons: - apt: - packages: - - enchant - - aspell - - aspell-en - install: - - pip install -r docs/requirements.txt - - pip install -e. -c tests/requirements.txt - script: make spelling - - &FLAKE - name: flake - cache: false - python: "3.6" - install: - - pip install -r tests/requirements.txt - script: - - make flake - - <<: *FLAKE - python: "3.7" - - &MYPY - name: mypy - cache: false - python: "3.6" - install: - - pip install -r tests/requirements.txt -r tests/requirements-mypy.txt - script: - - make mypy - - <<: *MYPY - python: "3.7" - # Run examples - - &EXAMPLES - stage: examples - cache: false - python: 3.7 - install: - - pip install -e. -c tests/requirements.txt - script: - - make examples - -install: -- make ci-prune-old-redis -- make -j ci-build-redis -- | - if [ "$UVLOOP" = "y" ]; then - export PYTEST_ADDOPTS="$PYTEST_ADDOPTS --uvloop" - pip install uvloop - fi; -- pip install codecov -- pip install -r tests/requirements.txt -- pip install -e. -c tests/requirements.txt - -script: -- make ci-test - -cache: - directories: - - $HOME/redis/ - -after_script: -- codecov -f coverage.xml diff --git a/Makefile b/Makefile index e22666f3b..adeeb662c 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,8 @@ PYTHON ?= python3 PYTEST ?= pytest MYPY ?= mypy -REDIS_TAGS ?= 2.6.17 2.8.22 3.0.7 3.2.13 4.0.14 5.0.9 +REDIS_TAGS ?= 2.6.17 2.8.22 3.0.7 3.2.13 4.0.14 5.0.9 6.0.10 ARCHIVE_URL = https://github.com/antirez/redis/archive INSTALL_DIR ?= build @@ -31,7 +31,7 @@ mypy: $(MYPY) aioredis --ignore-missing-imports test: - $(PYTEST) + $(PYTEST) --timeout=60 cov coverage: $(PYTEST) --cov @@ -63,11 +63,7 @@ aioredis.egg-info: pip install -Ue .
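The new `--timeout=60` in the `test:` target comes from the pytest-timeout plugin, so a wedged async test now fails after a minute instead of hanging CI forever. A minimal sketch of what that guards against (hypothetical test module; assumes pytest-timeout and pytest-asyncio are installed):

```python
import asyncio

import pytest


@pytest.mark.asyncio
async def test_that_would_hang():
    # Without "pytest --timeout=60" this would block the run indefinitely;
    # with it, pytest-timeout aborts and fails the test after 60 seconds.
    await asyncio.Event().wait()
```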
-ifdef TRAVIS -examples: .start-redis $(EXAMPLES) -else examples: $(EXAMPLES) -endif $(EXAMPLES): @export REDIS_VERSION="$(redis-cli INFO SERVER | sed -n 2p)" @@ -88,11 +84,11 @@ certificate: ci-test: $(REDIS_TARGETS) $(PYTEST) \ - --cov --cov-report=xml -vvvs\ + --timeout=60 --cov --cov-report=xml -vvvs\ $(foreach T,$(REDIS_TARGETS),--redis-server=$T) ci-test-%: $(INSTALL_DIR)/%/redis-server - $(PYTEST) --cov --redis-server=$< + $(PYTEST) --timeout=60 --cov --redis-server=$< ci-build-redis: $(REDIS_TARGETS) diff --git a/aioredis/__init__.py b/aioredis/__init__.py index af087c938..12da3e8e9 100644 --- a/aioredis/__init__.py +++ b/aioredis/__init__.py @@ -1,61 +1,59 @@ -from .commands import GeoMember, GeoPoint, Redis, create_redis, create_redis_pool -from .connection import RedisConnection, create_connection -from .errors import ( - AuthError, - ChannelClosedError, - ConnectionClosedError, - ConnectionForcedCloseError, - MasterNotFoundError, - MasterReplyError, - MaxClientsError, - MultiExecError, - PipelineError, - PoolClosedError, - ProtocolError, +from aioredis.client import Redis, StrictRedis +from aioredis.connection import ( + BlockingConnectionPool, + Connection, + ConnectionPool, + SSLConnection, + UnixDomainSocketConnection, +) +from aioredis.exceptions import ( + AuthenticationError, + AuthenticationWrongNumberOfArgsError, + BusyLoadingError, + ChildDeadlockedError, + ConnectionError, + DataError, + InvalidResponse, + PubSubError, ReadOnlyError, RedisError, - ReplyError, - SlaveNotFoundError, - SlaveReplyError, - WatchVariableError, + ResponseError, + TimeoutError, + WatchError, ) -from .pool import ConnectionsPool, create_pool -from .pubsub import Channel -from .sentinel import RedisSentinel, create_sentinel +from aioredis.utils import from_url + + +def int_or_str(value): + try: + return int(value) + except ValueError: + return value -__version__ = "1.3.1" + +__version__ = "3.5.3" +VERSION = tuple(map(int_or_str, __version__.split("."))) __all__ = [ - # Factories - "create_connection", - "create_pool", - "create_redis", - "create_redis_pool", - "create_sentinel", - # Classes - "RedisConnection", - "ConnectionsPool", + "AuthenticationError", + "AuthenticationWrongNumberOfArgsError", + "BlockingConnectionPool", + "BusyLoadingError", + "ChildDeadlockedError", + "Connection", + "ConnectionError", + "ConnectionPool", + "DataError", + "from_url", + "InvalidResponse", + "PubSubError", + "ReadOnlyError", "Redis", - "GeoPoint", - "GeoMember", - "Channel", - "RedisSentinel", - # Errors "RedisError", - "ReplyError", - "MaxClientsError", - "AuthError", - "ProtocolError", - "PipelineError", - "MultiExecError", - "WatchVariableError", - "ConnectionClosedError", - "ConnectionForcedCloseError", - "PoolClosedError", - "ChannelClosedError", - "MasterNotFoundError", - "SlaveNotFoundError", - "ReadOnlyError", - "MasterReplyError", - "SlaveReplyError", + "ResponseError", + "SSLConnection", + "StrictRedis", + "TimeoutError", + "UnixDomainSocketConnection", + "WatchError", ] diff --git a/aioredis/abc.py b/aioredis/abc.py deleted file mode 100644 index a6f3767ae..000000000 --- a/aioredis/abc.py +++ /dev/null @@ -1,152 +0,0 @@ -"""The module provides connection and connections pool interfaces. - -These are intended to be used for implementing custom connection managers.
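Stepping back from the deleted `abc` module for a moment: the `__init__.py` hunk above is the user-facing summary of this port — the 1.x factory coroutines are gone and a redis-py style client takes their place. A hedged before/after sketch, assuming a Redis server on localhost:6379:

```python
import asyncio

import aioredis

# aioredis 1.x (removed above):
#     redis = await aioredis.create_redis_pool("redis://localhost")

# aioredis 2.x, using the new exports:
async def main():
    redis = aioredis.from_url("redis://localhost")
    await redis.set("greeting", "hello")
    print(await redis.get("greeting"))  # b"hello" unless decode_responses=True
    await redis.close()

asyncio.run(main())
```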
-""" -import abc - -__all__ = [ - "AbcConnection", - "AbcPool", - "AbcChannel", -] - - -class AbcConnection(abc.ABC): - """Abstract connection interface.""" - - @abc.abstractmethod - def execute(self, command, *args, **kwargs): - """Execute redis command.""" - - @abc.abstractmethod - def execute_pubsub(self, command, *args, **kwargs): - """Execute Redis (p)subscribe/(p)unsubscribe commands.""" - - @abc.abstractmethod - def close(self): - """Perform connection(s) close and resources cleanup.""" - - @abc.abstractmethod - async def wait_closed(self): - """ - Coroutine waiting until all resources are closed/released/cleaned up. - """ - - @property - @abc.abstractmethod - def closed(self): - """Flag indicating if connection is closing or already closed.""" - - @property - @abc.abstractmethod - def db(self): - """Current selected DB index.""" - - @property - @abc.abstractmethod - def encoding(self): - """Current set connection codec.""" - - @property - @abc.abstractmethod - def in_pubsub(self): - """Returns number of subscribed channels. - - Can be tested as bool indicating Pub/Sub mode state. - """ - - @property - @abc.abstractmethod - def pubsub_channels(self): - """Read-only channels dict.""" - - @property - @abc.abstractmethod - def pubsub_patterns(self): - """Read-only patterns dict.""" - - @property - @abc.abstractmethod - def address(self): - """Connection address.""" - - -class AbcPool(AbcConnection): - """Abstract connections pool interface. - - Inherited from AbcConnection so both have common interface - for executing Redis commands. - """ - - @abc.abstractmethod - def get_connection(self, command, args=()): - """ - Gets free connection from pool in a sync way. - - If no connection available — returns None. - """ - - @abc.abstractmethod - async def acquire(self, command=None, args=()): - """Acquires connection from pool.""" - - @abc.abstractmethod - def release(self, conn): - """Releases connection to pool. - - :param AbcConnection conn: Owned connection to be released. - """ - - @property - @abc.abstractmethod - def address(self): - """Connection address or None.""" - - -class AbcChannel(abc.ABC): - """Abstract Pub/Sub Channel interface.""" - - @property - @abc.abstractmethod - def name(self): - """Encoded channel name or pattern.""" - - @property - @abc.abstractmethod - def is_pattern(self): - """Boolean flag indicating if channel is pattern channel.""" - - @property - @abc.abstractmethod - def is_active(self): - """Flag indicating that channel has unreceived messages - and not marked as closed.""" - - @abc.abstractmethod - async def get(self): - """Wait and return new message. - - Will raise ``ChannelClosedError`` if channel is not active. - """ - - # wait_message is not required; details of implementation - # @abc.abstractmethod - # def wait_message(self): - # pass - - @abc.abstractmethod - def put_nowait(self, data): - """Send data to channel. - - Called by RedisConnection when new message received. - For pattern subscriptions data will be a tuple of - channel name and message itself. - """ - - @abc.abstractmethod - def close(self, exc=None): - """Marks Channel as closed, no more messages will be sent to it. - - Called by RedisConnection when channel is unsubscribed - or connection is closed. 
- """ diff --git a/aioredis/client.py b/aioredis/client.py new file mode 100644 index 000000000..904ef4d24 --- /dev/null +++ b/aioredis/client.py @@ -0,0 +1,4704 @@ +from __future__ import annotations + +import asyncio +import datetime +import hashlib +import inspect +import re +import threading +import time +import time as mod_time +import warnings +from itertools import chain +from typing import ( + Any, + AnyStr, + AsyncIterator, + Awaitable, + Callable, + Collection, + Dict, + Iterable, + List, + Mapping, + NoReturn, + Optional, + Protocol, + Sequence, + Tuple, + Type, + TypedDict, + TypeVar, + Union, +) + +from aioredis.connection import ( + Connection, + ConnectionPool, + EncodableT, + SSLConnection, + UnixDomainSocketConnection, +) +from aioredis.exceptions import ( + ConnectionError, + DataError, + ExecAbortError, + ModuleError, + NoScriptError, + PubSubError, + RedisError, + ResponseError, + TimeoutError, + WatchError, +) +from aioredis.lock import Lock +from aioredis.utils import safe_str, str_if_bytes + +SYM_EMPTY = b"" +EMPTY_RESPONSE = "EMPTY_RESPONSE" + + +def list_or_args(keys, args): + # returns a single new list combining keys and args + try: + iter(keys) + # a string or bytes instance can be iterated, but indicates + # keys wasn't passed as a list + if isinstance(keys, (bytes, str)): + keys = [keys] + else: + keys = list(keys) + except TypeError: + keys = [keys] + if args: + keys.extend(args) + return keys + + +def timestamp_to_datetime(response): + """Converts a unix timestamp to a Python datetime object""" + if not response: + return None + try: + response = int(response) + except ValueError: + return None + return datetime.datetime.fromtimestamp(response) + + +def string_keys_to_dict(key_string, callback): + return dict.fromkeys(key_string.split(), callback) + + +class CaseInsensitiveDict(dict): + """Case insensitive dict implementation. Assumes string keys only.""" + + def __init__(self, data): + for k, v in data.items(): + self[k.upper()] = v + + def __contains__(self, k): + return super().__contains__(k.upper()) + + def __delitem__(self, k): + super().__delitem__(k.upper()) + + def __getitem__(self, k): + return super().__getitem__(k.upper()) + + def get(self, k, default=None): + return super().get(k.upper(), default) + + def __setitem__(self, k, v): + super().__setitem__(k.upper(), v) + + def update(self, data): + data = CaseInsensitiveDict(data) + super().update(data) + + +def parse_debug_object(response): + """Parse the results of Redis's DEBUG OBJECT command into a Python dict""" + # The 'type' of the object is the first item in the response, but isn't + # prefixed with a name + response = str_if_bytes(response) + response = "type:" + response + response = dict(kv.split(":") for kv in response.split()) + + # parse some expected int values from the string response + # note: this cmd isn't spec'd so these may not appear in all redis versions + int_fields = ("refcount", "serializedlength", "lru", "lru_seconds_idle") + for field in int_fields: + if field in response: + response[field] = int(response[field]) + + return response + + +def parse_object(response, infotype): + """Parse the results of an OBJECT command""" + if infotype in ("idletime", "refcount"): + return int_or_none(response) + return response + + +def parse_info(response): + """Parse the result of Redis's INFO command into a Python dict""" + info = {} + response = str_if_bytes(response) + + def get_value(value): + if "," not in value or "=" not in value: + try: + if "." 
in value: + return float(value) + else: + return int(value) + except ValueError: + return value + else: + sub_dict = {} + for item in value.split(","): + k, v = item.rsplit("=", 1) + sub_dict[k] = get_value(v) + return sub_dict + + for line in response.splitlines(): + if line and not line.startswith("#"): + if line.find(":") != -1: + # Split, the info fields keys and values. + # Note that the value may contain ':'. but the 'host:' + # pseudo-command is the only case where the key contains ':' + key, value = line.split(":", 1) + if key == "cmdstat_host": + key, value = line.rsplit(":", 1) + + if key == "module": + # Hardcode a list for key 'modules' since there could be + # multiple lines that started with 'module' + info.setdefault("modules", []).append(get_value(value)) + else: + info[key] = get_value(value) + else: + # if the line isn't splittable, append it to the "__raw__" key + info.setdefault("__raw__", []).append(line) + + return info + + +def parse_memory_stats(response, **kwargs): + """Parse the results of MEMORY STATS""" + stats = pairs_to_dict(response, decode_keys=True, decode_string_values=True) + for key, value in stats.items(): + if key.startswith("db."): + stats[key] = pairs_to_dict( + value, decode_keys=True, decode_string_values=True + ) + return stats + + +SENTINEL_STATE_TYPES = { + "can-failover-its-master": int, + "config-epoch": int, + "down-after-milliseconds": int, + "failover-timeout": int, + "info-refresh": int, + "last-hello-message": int, + "last-ok-ping-reply": int, + "last-ping-reply": int, + "last-ping-sent": int, + "master-link-down-time": int, + "master-port": int, + "num-other-sentinels": int, + "num-slaves": int, + "o-down-time": int, + "pending-commands": int, + "parallel-syncs": int, + "port": int, + "quorum": int, + "role-reported-time": int, + "s-down-time": int, + "slave-priority": int, + "slave-repl-offset": int, + "voted-leader-epoch": int, +} + + +def parse_sentinel_state(item): + result = pairs_to_dict_typed(item, SENTINEL_STATE_TYPES) + flags = set(result["flags"].split(",")) + for name, flag in ( + ("is_master", "master"), + ("is_slave", "slave"), + ("is_sdown", "s_down"), + ("is_odown", "o_down"), + ("is_sentinel", "sentinel"), + ("is_disconnected", "disconnected"), + ("is_master_down", "master_down"), + ): + result[name] = flag in flags + return result + + +def parse_sentinel_master(response): + return parse_sentinel_state(map(str_if_bytes, response)) + + +def parse_sentinel_masters(response): + result = {} + for item in response: + state = parse_sentinel_state(map(str_if_bytes, item)) + result[state["name"]] = state + return result + + +def parse_sentinel_slaves_and_sentinels(response): + return [parse_sentinel_state(map(str_if_bytes, item)) for item in response] + + +def parse_sentinel_get_master(response): + return response and (response[0], int(response[1])) or None + + +def pairs_to_dict(response, decode_keys=False, decode_string_values=False): + """Create a dict given a list of key/value pairs""" + if response is None: + return {} + if decode_keys or decode_string_values: + # the iter form is faster, but I don't know how to make that work + # with a str_if_bytes() map + keys = response[::2] + if decode_keys: + keys = map(str_if_bytes, keys) + values = response[1::2] + if decode_string_values: + values = map(str_if_bytes, values) + return dict(zip(keys, values)) + else: + it = iter(response) + return dict(zip(it, it)) + + +def pairs_to_dict_typed(response, type_info): + it = iter(response) + result = {} + for key, value in zip(it, it): + 
if key in type_info: + try: + value = type_info[key](value) + except Exception: + # if for some reason the value can't be coerced, just use + # the string value + pass + result[key] = value + return result + + +def zset_score_pairs(response, **options): + """ + If ``withscores`` is specified in the options, return the response as + a list of (value, score) pairs + """ + if not response or not options.get("withscores"): + return response + score_cast_func = options.get("score_cast_func", float) + it = iter(response) + return list(zip(it, map(score_cast_func, it))) + + +def sort_return_tuples(response, **options): + """ + If ``groups`` is specified, return the response as a list of + n-element tuples with n being the value found in options['groups'] + """ + if not response or not options.get("groups"): + return response + n = options["groups"] + return list(zip(*[response[i::n] for i in range(n)])) + + +def int_or_none(response): + if response is None: + return None + return int(response) + + +def parse_stream_list(response): + if response is None: + return None + data = [] + for r in response: + if r is not None: + data.append((r[0], pairs_to_dict(r[1]))) + else: + data.append((None, None)) + return data + + +def pairs_to_dict_with_str_keys(response): + return pairs_to_dict(response, decode_keys=True) + + +def parse_list_of_dicts(response): + return list(map(pairs_to_dict_with_str_keys, response)) + + +def parse_xclaim(response, **options): + if options.get("parse_justid", False): + return response + return parse_stream_list(response) + + +def parse_xinfo_stream(response): + data = pairs_to_dict(response, decode_keys=True) + first = data["first-entry"] + if first is not None: + data["first-entry"] = (first[0], pairs_to_dict(first[1])) + last = data["last-entry"] + if last is not None: + data["last-entry"] = (last[0], pairs_to_dict(last[1])) + return data + + +def parse_xread(response): + if response is None: + return [] + return [[r[0], parse_stream_list(r[1])] for r in response] + + +def parse_xpending(response, **options): + if options.get("parse_detail", False): + return parse_xpending_range(response) + consumers = [{"name": n, "pending": int(p)} for n, p in response[3] or []] + return { + "pending": response[0], + "min": response[1], + "max": response[2], + "consumers": consumers, + } + + +def parse_xpending_range(response): + k = ("message_id", "consumer", "time_since_delivered", "times_delivered") + return [dict(zip(k, r)) for r in response] + + +def float_or_none(response): + if response is None: + return None + return float(response) + + +def bool_ok(response): + return str_if_bytes(response) == "OK" + + +def parse_zadd(response, **options): + if response is None: + return None + if options.get("as_score"): + return float(response) + return int(response) + + +def parse_client_list(response, **options): + clients = [] + for c in str_if_bytes(response).splitlines(): + # Values might contain '=' + clients.append(dict(pair.split("=", 1) for pair in c.split(" "))) + return clients + + +def parse_config_get(response, **options): + response = [str_if_bytes(i) if i is not None else None for i in response] + return response and pairs_to_dict(response) or {} + + +def parse_scan(response, **options): + cursor, r = response + return int(cursor), r + + +def parse_hscan(response, **options): + cursor, r = response + return int(cursor), r and pairs_to_dict(r) or {} + + +def parse_zscan(response, **options): + score_cast_func = options.get("score_cast_func", float) + cursor, r = response + it = 
iter(r) + return int(cursor), list(zip(it, map(score_cast_func, it))) + + +def parse_slowlog_get(response, **options): + space = " " if options.get("decode_responses", False) else b" " + return [ + { + "id": item[0], + "start_time": int(item[1]), + "duration": int(item[2]), + "command": space.join(item[3]), + } + for item in response + ] + + +def parse_cluster_info(response, **options): + response = str_if_bytes(response) + return dict(line.split(":") for line in response.splitlines() if line) + + +def _parse_node_line(line): + line_items = line.split(" ") + node_id, addr, flags, master_id, ping, pong, epoch, connected = line_items[:8] + slots = [sl.split("-") for sl in line_items[8:]] + node_dict = { + "node_id": node_id, + "flags": flags, + "master_id": master_id, + "last_ping_sent": ping, + "last_pong_rcvd": pong, + "epoch": epoch, + "slots": slots, + "connected": connected == "connected", + } + return addr, node_dict + + +def parse_cluster_nodes(response, **options): + raw_lines = str_if_bytes(response).splitlines() + return dict(_parse_node_line(line) for line in raw_lines) + + +def parse_georadius_generic(response, **options): + if options["store"] or options["store_dist"]: + # ``store`` and ``store_dist`` can't be combined + # with other command arguments. + return response + + if not isinstance(response, list): + response_list = [response] + else: + response_list = response + + if not options["withdist"] and not options["withcoord"] and not options["withhash"]: + # just a bunch of places + return response_list + + cast = { + "withdist": float, + "withcoord": lambda ll: (float(ll[0]), float(ll[1])), + "withhash": int, + } + + # zip each result with its casting function to get + # the proper native Python value. + f = [lambda x: x] + f += [cast[o] for o in ["withdist", "withhash", "withcoord"] if options[o]] + return [list(map(lambda fv: fv[0](fv[1]), zip(f, r))) for r in response_list] + + +def parse_pubsub_numsub(response, **options): + return list(zip(response[0::2], response[1::2])) + + +def parse_client_kill(response, **options): + if isinstance(response, int): + return response + return str_if_bytes(response) == "OK" + + +def parse_acl_getuser(response, **options): + if response is None: + return None + data = pairs_to_dict(response, decode_keys=True) + + # convert everything but user-defined data in 'keys' to native strings + data["flags"] = list(map(str_if_bytes, data["flags"])) + data["passwords"] = list(map(str_if_bytes, data["passwords"])) + data["commands"] = str_if_bytes(data["commands"]) + + # split 'commands' into separate 'categories' and 'commands' lists + commands, categories = [], [] + for command in data["commands"].split(" "): + if "@" in command: + categories.append(command) + else: + commands.append(command) + + data["commands"] = commands + data["categories"] = categories + data["enabled"] = "on" in data["flags"] + return data + + +def parse_acl_log(response, **options): + if response is None: + return None + if isinstance(response, list): + data = [] + for log in response: + log_data = pairs_to_dict(log, True, True) + client_info = log_data.get("client-info", "") + log_data["client-info"] = parse_client_info(client_info) + + # float() is lossy compared to the "double" in C + log_data["age-seconds"] = float(log_data["age-seconds"]) + data.append(log_data) + else: + data = bool_ok(response) + return data + + +def parse_client_info(value): + """ + Parse the client-info field of an ACL LOG entry, which has the format:
+ "key1=value1 key2=value2 key3=value3" + """ + client_info = {} + infos = value.split(" ") + for info in infos: + key, value = info.split("=") + client_info[key] = value + + # Those fields are definded as int in networking.c + for int_key in { + "id", + "age", + "idle", + "db", + "sub", + "psub", + "multi", + "qbuf", + "qbuf-free", + "obl", + "oll", + "omem", + }: + client_info[int_key] = int(client_info[int_key]) + return client_info + + +def parse_module_result(response): + if isinstance(response, ModuleError): + raise response + return True + + +class ResponseCallbackProtocol(Protocol): + def __call__(self, response: Any, **kwargs): + ... + + +class AsyncResponseCallbackProtocol(Protocol): + async def __call__(self, response: Any, **kwargs): + ... + + +ResponseCallbackT = Union[ResponseCallbackProtocol, AsyncResponseCallbackProtocol] + + +_R = TypeVar("_R") + + +class Redis: + """ + Implementation of the Redis protocol. + + This abstract class provides a Python interface to all Redis commands + and an implementation of the Redis protocol. + + Connection and Pipeline derive from this, implementing how + the commands are sent and received to the Redis server + """ + + RESPONSE_CALLBACKS = { + **string_keys_to_dict( + "AUTH EXPIRE EXPIREAT HEXISTS HMSET MOVE MSETNX PERSIST " + "PSETEX RENAMENX SISMEMBER SMOVE SETEX SETNX", + bool, + ), + **string_keys_to_dict( + "BITCOUNT BITPOS DECRBY DEL EXISTS GEOADD GETBIT HDEL HLEN " + "HSTRLEN INCRBY LINSERT LLEN LPUSHX PFADD PFCOUNT RPUSHX SADD " + "SCARD SDIFFSTORE SETBIT SETRANGE SINTERSTORE SREM STRLEN " + "SUNIONSTORE UNLINK XACK XDEL XLEN XTRIM ZCARD ZLEXCOUNT ZREM " + "ZREMRANGEBYLEX ZREMRANGEBYRANK ZREMRANGEBYSCORE", + int, + ), + **string_keys_to_dict("INCRBYFLOAT HINCRBYFLOAT", float), + **string_keys_to_dict( + # these return OK, or int if redis-server is >=1.3.4 + "LPUSH RPUSH", + lambda r: isinstance(r, int) and r or str_if_bytes(r) == "OK", + ), + **string_keys_to_dict("SORT", sort_return_tuples), + **string_keys_to_dict("ZSCORE ZINCRBY GEODIST", float_or_none), + **string_keys_to_dict( + "FLUSHALL FLUSHDB LSET LTRIM MSET PFMERGE READONLY READWRITE " + "RENAME SAVE SELECT SHUTDOWN SLAVEOF SWAPDB WATCH UNWATCH ", + bool_ok, + ), + **string_keys_to_dict("BLPOP BRPOP", lambda r: r and tuple(r) or None), + **string_keys_to_dict( + "SDIFF SINTER SMEMBERS SUNION", lambda r: r and set(r) or set() + ), + **string_keys_to_dict( + "ZPOPMAX ZPOPMIN ZRANGE ZRANGEBYSCORE ZREVRANGE ZREVRANGEBYSCORE", + zset_score_pairs, + ), + **string_keys_to_dict( + "BZPOPMIN BZPOPMAX", lambda r: r and (r[0], r[1], float(r[2])) or None + ), + **string_keys_to_dict("ZRANK ZREVRANK", int_or_none), + **string_keys_to_dict("XREVRANGE XRANGE", parse_stream_list), + **string_keys_to_dict("XREAD XREADGROUP", parse_xread), + **string_keys_to_dict("BGREWRITEAOF BGSAVE", lambda r: True), + "ACL CAT": lambda r: list(map(str_if_bytes, r)), + "ACL DELUSER": int, + "ACL GENPASS": str_if_bytes, + "ACL GETUSER": parse_acl_getuser, + "ACL LIST": lambda r: list(map(str_if_bytes, r)), + "ACL LOAD": bool_ok, + "ACL LOG": parse_acl_log, + "ACL SAVE": bool_ok, + "ACL SETUSER": bool_ok, + "ACL USERS": lambda r: list(map(str_if_bytes, r)), + "ACL WHOAMI": str_if_bytes, + "CLIENT GETNAME": str_if_bytes, + "CLIENT ID": int, + "CLIENT KILL": parse_client_kill, + "CLIENT LIST": parse_client_list, + "CLIENT SETNAME": bool_ok, + "CLIENT UNBLOCK": lambda r: r and int(r) == 1 or False, + "CLIENT PAUSE": bool_ok, + "CLUSTER ADDSLOTS": bool_ok, + "CLUSTER COUNT-FAILURE-REPORTS": lambda x: int(x), 
+ "CLUSTER COUNTKEYSINSLOT": lambda x: int(x), + "CLUSTER DELSLOTS": bool_ok, + "CLUSTER FAILOVER": bool_ok, + "CLUSTER FORGET": bool_ok, + "CLUSTER INFO": parse_cluster_info, + "CLUSTER KEYSLOT": lambda x: int(x), + "CLUSTER MEET": bool_ok, + "CLUSTER NODES": parse_cluster_nodes, + "CLUSTER REPLICATE": bool_ok, + "CLUSTER RESET": bool_ok, + "CLUSTER SAVECONFIG": bool_ok, + "CLUSTER SET-CONFIG-EPOCH": bool_ok, + "CLUSTER SETSLOT": bool_ok, + "CLUSTER SLAVES": parse_cluster_nodes, + "CONFIG GET": parse_config_get, + "CONFIG RESETSTAT": bool_ok, + "CONFIG SET": bool_ok, + "DEBUG OBJECT": parse_debug_object, + "GEOHASH": lambda r: list(map(str_if_bytes, r)), + "GEOPOS": lambda r: list( + map(lambda ll: (float(ll[0]), float(ll[1])) if ll is not None else None, r) + ), + "GEORADIUS": parse_georadius_generic, + "GEORADIUSBYMEMBER": parse_georadius_generic, + "HGETALL": lambda r: r and pairs_to_dict(r) or {}, + "HSCAN": parse_hscan, + "INFO": parse_info, + "LASTSAVE": timestamp_to_datetime, + "MEMORY PURGE": bool_ok, + "MEMORY STATS": parse_memory_stats, + "MEMORY USAGE": int_or_none, + "MODULE LOAD": parse_module_result, + "MODULE UNLOAD": parse_module_result, + "MODULE LIST": lambda r: [pairs_to_dict(m) for m in r], + "OBJECT": parse_object, + "PING": lambda r: str_if_bytes(r) == "PONG", + "PUBSUB NUMSUB": parse_pubsub_numsub, + "RANDOMKEY": lambda r: r and r or None, + "SCAN": parse_scan, + "SCRIPT EXISTS": lambda r: list(map(bool, r)), + "SCRIPT FLUSH": bool_ok, + "SCRIPT KILL": bool_ok, + "SCRIPT LOAD": str_if_bytes, + "SENTINEL GET-MASTER-ADDR-BY-NAME": parse_sentinel_get_master, + "SENTINEL MASTER": parse_sentinel_master, + "SENTINEL MASTERS": parse_sentinel_masters, + "SENTINEL MONITOR": bool_ok, + "SENTINEL REMOVE": bool_ok, + "SENTINEL SENTINELS": parse_sentinel_slaves_and_sentinels, + "SENTINEL SET": bool_ok, + "SENTINEL SLAVES": parse_sentinel_slaves_and_sentinels, + "SET": lambda r: r and str_if_bytes(r) == "OK", + "SLOWLOG GET": parse_slowlog_get, + "SLOWLOG LEN": int, + "SLOWLOG RESET": bool_ok, + "SSCAN": parse_scan, + "TIME": lambda x: (int(x[0]), int(x[1])), + "XCLAIM": parse_xclaim, + "XGROUP CREATE": bool_ok, + "XGROUP DELCONSUMER": int, + "XGROUP DESTROY": bool, + "XGROUP SETID": bool_ok, + "XINFO CONSUMERS": parse_list_of_dicts, + "XINFO GROUPS": parse_list_of_dicts, + "XINFO STREAM": parse_xinfo_stream, + "XPENDING": parse_xpending, + "ZADD": parse_zadd, + "ZSCAN": parse_zscan, + } + + @classmethod + def from_url(cls, url: str, **kwargs): + """ + Return a Redis client object configured from the given URL + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[[username]:[password]]@/path/to/socket.sock?db=0 + + Three URL schemes are supported: + + - `redis://` creates a TCP socket connection. See more at: + + - `rediss://` creates a SSL wrapped TCP socket connection. See more at: + + - ``unix://``: creates a Unix Domain Socket connection. + + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. 
+ + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. + + """ + connection_pool = ConnectionPool.from_url(url, **kwargs) + return cls(connection_pool=connection_pool) + + def __init__( + self, + *, + host: str = "localhost", + port: int = 6379, + db: Union[str, int] = 0, + password: str = None, + socket_timeout: float = None, + socket_connect_timeout: float = None, + socket_keepalive: bool = None, + socket_keepalive_options: Mapping = None, + connection_pool: ConnectionPool = None, + unix_socket_path: str = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + retry_on_timeout: bool = False, + ssl: bool = False, + ssl_keyfile: str = None, + ssl_certfile: str = None, + ssl_cert_reqs: str = "required", + ssl_ca_certs: str = None, + ssl_check_hostname: bool = False, + max_connections: int = None, + single_connection_client: bool = False, + health_check_interval: int = 0, + client_name: str = None, + username: str = None, + ): + kwargs: Dict[str, Any] + if not connection_pool: + kwargs = { + "db": db, + "username": username, + "password": password, + "socket_timeout": socket_timeout, + "encoding": encoding, + "encoding_errors": encoding_errors, + "decode_responses": decode_responses, + "retry_on_timeout": retry_on_timeout, + "max_connections": max_connections, + "health_check_interval": health_check_interval, + "client_name": client_name, + } + # based on input, set up appropriate connection args + if unix_socket_path is not None: + kwargs.update( + { + "path": unix_socket_path, + "connection_class": UnixDomainSocketConnection, + } + ) + else: + # TCP specific options + kwargs.update( + { + "host": host, + "port": port, + "socket_connect_timeout": socket_connect_timeout, + "socket_keepalive": socket_keepalive, + "socket_keepalive_options": socket_keepalive_options, + } + ) + + if ssl: + kwargs.update( + { + "connection_class": SSLConnection, + "ssl_keyfile": ssl_keyfile, + "ssl_certfile": ssl_certfile, + "ssl_cert_reqs": ssl_cert_reqs, + "ssl_ca_certs": ssl_ca_certs, + "ssl_check_hostname": ssl_check_hostname, + } + ) + connection_pool = ConnectionPool(**kwargs) + self.connection_pool = connection_pool + self.single_connection_client = single_connection_client + self.connection = None + + self.response_callbacks = CaseInsensitiveDict(self.__class__.RESPONSE_CALLBACKS) + + def __repr__(self): + return f"{self.__class__.__name__}<{self.connection_pool!r}>" + + def __await__(self): + return self.initialize().__await__() + + async def initialize(self): + if self.single_connection_client and self.connection is None: + self.connection = await self.connection_pool.get_connection("_") + return self + + def set_response_callback(self, command: str, callback: ResponseCallbackT): + """Set a custom Response Callback""" + self.response_callbacks[command] = callback + + def pipeline(self, transaction: bool = True, shard_hint: str = None) -> Pipeline: + """ + Return a new pipeline object that can queue multiple commands for + later execution.
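As an aside, the three db-selection routes described in the ``from_url`` docstring above can be illustrated with a hedged sketch (constructing a client does not connect, so no server is needed):

```python
import aioredis

r1 = aioredis.Redis.from_url("redis://localhost?db=2")   # 1. querystring option
r2 = aioredis.Redis.from_url("redis://localhost/1")      # 2. path component
r3 = aioredis.Redis.from_url("redis://localhost", db=3)  # 3. keyword argument
```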
``transaction`` indicates whether all commands + should be executed atomically. Apart from making a group of operations + atomic, pipelines are useful for reducing the back-and-forth overhead + between the client and server. + """ + return Pipeline( + self.connection_pool, self.response_callbacks, transaction, shard_hint + ) + + async def transaction( + self, + func: Callable[[Pipeline], Union[Any, Awaitable[Any]]], + *watches: str, + shard_hint: str = None, + value_from_callable: bool = False, + watch_delay: float = None, + ): + """ + Convenience method for executing the callable `func` as a transaction + while watching all keys specified in `watches`. The 'func' callable + should expect a single argument which is a Pipeline object. + """ + pipe: Pipeline + async with self.pipeline(True, shard_hint) as pipe: + while True: + try: + if watches: + await pipe.watch(*watches) + func_value = func(pipe) + if inspect.isawaitable(func_value): + func_value = await func_value + exec_value = await pipe.execute() + return func_value if value_from_callable else exec_value + except WatchError: + if watch_delay is not None and watch_delay > 0: + await asyncio.sleep(watch_delay) + continue + + def lock( + self, + name: str, + timeout: float = None, + sleep: float = 0.1, + blocking_timeout: float = None, + lock_class: Type[Lock] = None, + thread_local=True, + ) -> Lock: + """ + Return a new Lock object using key ``name`` that mimics + the behavior of threading.Lock. + + If specified, ``timeout`` indicates a maximum life for the lock. + By default, it will remain locked until release() is called. + + ``sleep`` indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + + ``lock_class`` forces the specified lock implementation. + + ``thread_local`` indicates whether the lock token is placed in + thread-local storage. By default, the token is placed in thread local + storage so that a thread only sees its token, not a token set by + another thread. Consider the following timeline: + + time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. + thread-1 sets the token to "abc" + time: 1, thread-2 blocks trying to acquire `my-lock` using the + Lock instance. + time: 5, thread-1 has not yet completed. redis expires the lock + key. + time: 5, thread-2 acquired `my-lock` now that it's available. + thread-2 sets the token to "xyz" + time: 6, thread-1 finishes its work and calls release(). if the + token is *not* stored in thread local storage, then + thread-1 would see the token value as "xyz" and would be + able to successfully release the thread-2's lock. + + In some use cases it's necessary to disable thread local storage. For + example, if you have code where one thread acquires a lock and passes + that lock instance to a worker thread to release later. If thread + local storage isn't disabled in this case, the worker thread won't see + the token set by the thread that acquired the lock. 
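A hedged sketch of the ``transaction`` helper above — an optimistic WATCH/MULTI/EXEC retry loop without writing it by hand. It assumes the Pipeline API ported from redis-py (immediate execution after ``watch()``, buffering after ``multi()``); ``redis`` is a connected client and the key name is hypothetical:

```python
async def reserve_seat(redis):
    async def op(pipe):  # called with the Pipeline in immediate (WATCH) mode
        free = int(await pipe.get("seats:free") or 0)
        pipe.multi()  # switch to buffered, transactional mode
        pipe.set("seats:free", free - 1)

    # retried automatically on WatchError, sleeping 0.1s between attempts
    await redis.transaction(op, "seats:free", watch_delay=0.1)
```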
Our assumption + is that these cases aren't common and as such default to using + thread local storage.""" + if lock_class is None: + lock_class = Lock + return lock_class( + self, + name, + timeout=timeout, + sleep=sleep, + blocking_timeout=blocking_timeout, + thread_local=thread_local, + ) + + def pubsub(self, **kwargs) -> PubSub: + """ + Return a Publish/Subscribe object. With this object, you can + subscribe to channels and listen for messages that get published to + them. + """ + return PubSub(self.connection_pool, **kwargs) + + def monitor(self) -> Monitor: + return Monitor(self.connection_pool) + + def client(self) -> Redis: + return self.__class__( + connection_pool=self.connection_pool, single_connection_client=True + ) + + async def __aenter__(self) -> Redis: + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.close() + + def __del__(self): + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(self.close()) + else: + loop.run_until_complete(self.close()) + except Exception: + pass + + async def close(self): + conn = self.connection + if conn: + self.connection = None + await self.connection_pool.release(conn) + + # COMMAND EXECUTION AND PROTOCOL PARSING + async def execute_command(self, *args, **options): + """Execute a command and return a parsed response""" + await self.initialize() + pool = self.connection_pool + command_name = args[0] + conn = self.connection or await pool.get_connection(command_name, **options) + try: + await conn.send_command(*args) + return await self.parse_response(conn, command_name, **options) + except (ConnectionError, TimeoutError) as e: + await conn.disconnect() + if not (conn.retry_on_timeout and isinstance(e, TimeoutError)): + raise + await conn.send_command(*args) + return await self.parse_response(conn, command_name, **options) + finally: + if not self.connection: + await pool.release(conn) + + async def parse_response( + self, connection: Connection, command_name: Union[str, bytes], **options + ): + """Parses a response from the Redis server""" + try: + response = await connection.read_response() + except ResponseError: + if EMPTY_RESPONSE in options: + return options[EMPTY_RESPONSE] + raise + if command_name in self.response_callbacks: + retval = self.response_callbacks[command_name](response, **options) + return await retval if inspect.isawaitable(retval) else retval + return response + + # SERVER INFORMATION + + # ACL methods + def acl_cat(self, category: str = None) -> Awaitable: + """ + Returns a list of categories or commands within a category. + + If ``category`` is not supplied, returns a list of all categories. + If ``category`` is supplied, returns a list of all commands within + that category. + """ + pieces: List[EncodableT] = [category] if category else [] + return self.execute_command("ACL CAT", *pieces) + + def acl_deluser(self, username: str) -> Awaitable: + """Delete the ACL for the specified ``username``""" + return self.execute_command("ACL DELUSER", username) + + def acl_genpass(self) -> Awaitable: + """Generate a random password value""" + return self.execute_command("ACL GENPASS") + + def acl_getuser(self, username: str) -> Awaitable: + """ + Get the ACL details for the specified ``username``. 
+ + If ``username`` does not exist, return None + """ + return self.execute_command("ACL GETUSER", username) + + def acl_list(self) -> Awaitable: + """Return a list of all ACLs on the server""" + return self.execute_command("ACL LIST") + + def acl_log(self, count: int = None) -> Awaitable: + """ + Get ACL logs as a list. + :param int count: Get logs[0:count]. + :rtype: List. + """ + args = [] + if count is not None: + if not isinstance(count, int): + raise DataError("ACL LOG count must be an integer") + args.append(count) + + return self.execute_command("ACL LOG", *args) + + def acl_log_reset(self) -> Awaitable: + """ + Reset ACL logs. + :rtype: Boolean. + """ + args = [b"RESET"] + return self.execute_command("ACL LOG", *args) + + def acl_load(self) -> Awaitable: + """ + Load ACL rules from the configured ``aclfile``. + + Note that the server must be configured with the ``aclfile`` + directive to be able to load ACL rules from an aclfile. + """ + return self.execute_command("ACL LOAD") + + def acl_save(self) -> Awaitable: + """ + Save ACL rules to the configured ``aclfile``. + + Note that the server must be configured with the ``aclfile`` + directive to be able to save ACL rules to an aclfile. + """ + return self.execute_command("ACL SAVE") + + def acl_setuser( # noqa: C901 + self, + username: str, + enabled: bool = False, + nopass: bool = False, + passwords: Iterable[str] = None, + hashed_passwords: Iterable[str] = None, + categories: Iterable[str] = None, + commands: Iterable[str] = None, + keys: Collection[str] = None, + reset: bool = False, + reset_keys: bool = False, + reset_passwords: bool = False, + ) -> Awaitable: + """ + Create or update an ACL user. + + Create or update the ACL for ``username``. If the user already exists, + the existing ACL is completely overwritten and replaced with the + specified values. + + ``enabled`` is a boolean indicating whether the user should be allowed + to authenticate or not. Defaults to ``False``. + + ``nopass`` is a boolean indicating whether the user can authenticate without + a password. This cannot be True if ``passwords`` are also specified. + + ``passwords`` if specified is a list of plain text passwords + to add to or remove from the user. Each password must be prefixed with + a '+' to add or a '-' to remove. For convenience, the value of + ``passwords`` can be a simple prefixed string when adding or + removing a single password. + + ``hashed_passwords`` if specified is a list of SHA-256 hashed passwords + to add to or remove from the user. Each hashed password must be + prefixed with a '+' to add or a '-' to remove. For convenience, + the value of ``hashed_passwords`` can be a simple prefixed string when + adding or removing a single password. + + ``categories`` if specified is a list of strings representing category + permissions. Each string must be prefixed with either a '+' to add the + category permission or a '-' to remove the category permission. + + ``commands`` if specified is a list of strings representing command + permissions. Each string must be prefixed with either a '+' to add the + command permission or a '-' to remove the command permission. + + ``keys`` if specified is a list of key patterns to grant the user + access to. Key patterns allow '*' to support wildcard matching. For + example, '*' grants access to all keys while 'cache:*' grants access + to all keys that are prefixed with 'cache:'. ``keys`` should not be + prefixed with a '~'.
+ + ``reset`` is a boolean indicating whether the user should be fully + reset prior to applying the new ACL. Setting this to True will + remove all existing passwords, flags and privileges from the user and + then apply the specified rules. If this is False, the user's existing + passwords, flags and privileges will be kept and any new specified + rules will be applied on top. + + ``reset_keys`` is a boolean indicating whether the user's key + permissions should be reset prior to applying any new key permissions + specified in ``keys``. If this is False, the user's existing + key permissions will be kept and any new specified key permissions + will be applied on top. + + ``reset_passwords`` is a boolean indicating whether to remove all + existing passwords and the 'nopass' flag from the user prior to + applying any new passwords specified in 'passwords' or + 'hashed_passwords'. If this is False, the user's existing passwords + and 'nopass' status will be kept and any new specified passwords + or hashed_passwords will be applied on top. + """ + encoder = self.connection_pool.get_encoder() + pieces: List[Union[str, bytes]] = [username] + + if reset: + pieces.append(b"reset") + + if reset_keys: + pieces.append(b"resetkeys") + + if reset_passwords: + pieces.append(b"resetpass") + + if enabled: + pieces.append(b"on") + else: + pieces.append(b"off") + + if (passwords or hashed_passwords) and nopass: + raise DataError( + "Cannot set 'nopass' and supply 'passwords' or 'hashed_passwords'" + ) + + if passwords: + # as most users will have only one password, allow remove_passwords + # to be specified as a simple string or a list + passwords = list_or_args(passwords, []) + for i, password in enumerate(passwords): + password = encoder.encode(password) + if password.startswith(b"+"): + pieces.append(b">%s" % password[1:]) + elif password.startswith(b"-"): + pieces.append(b"<%s" % password[1:]) + else: + raise DataError( + 'Password %d must be prefixed with a ' + '"+" to add or a "-" to remove' % i + ) + + if hashed_passwords: + # as most users will have only one password, allow remove_passwords + # to be specified as a simple string or a list + hashed_passwords = list_or_args(hashed_passwords, []) + for i, hashed_password in enumerate(hashed_passwords): + hashed_password = encoder.encode(hashed_password) + if hashed_password.startswith(b"+"): + pieces.append(b"#%s" % hashed_password[1:]) + elif hashed_password.startswith(b"-"): + pieces.append(b"!%s" % hashed_password[1:]) + else: + raise DataError( + 'Hashed password %d must be prefixed ' + 'with a "+" to add or a "-" to remove' % i + ) + + if nopass: + pieces.append(b"nopass") + + if categories: + for category in categories: + category = encoder.encode(category) + # categories can be prefixed with one of (+@, +, -@, -) + if category.startswith(b"+@"): + pieces.append(category) + elif category.startswith(b"+"): + pieces.append(b"+@%s" % category[1:]) + elif category.startswith(b"-@"): + pieces.append(category) + elif category.startswith(b"-"): + pieces.append(b"-@%s" % category[1:]) + else: + raise DataError( + 'Category "%s" must be prefixed with ' + '"+" or "-"' % encoder.decode(category, force=True) + ) + if commands: + for cmd in commands: + cmd = encoder.encode(cmd) + if not cmd.startswith(b"+") and not cmd.startswith(b"-"): + raise DataError( + 'Command "%s" must be prefixed with ' + '"+" or "-"' % encoder.decode(cmd, force=True) + ) + pieces.append(cmd) + + if keys: + for key in keys: + key = encoder.encode(key) + pieces.append(b"~%s" %
key) + + return self.execute_command("ACL SETUSER", *pieces) + + def acl_users(self) -> Awaitable: + """Returns a list of all registered users on the server.""" + return self.execute_command("ACL USERS") + + def acl_whoami(self) -> Awaitable: + """Get the username for the current connection""" + return self.execute_command("ACL WHOAMI") + + def bgrewriteaof(self) -> Awaitable: + """Tell the Redis server to rewrite the AOF file from data in memory.""" + return self.execute_command("BGREWRITEAOF") + + def bgsave(self) -> Awaitable: + """ + Tell the Redis server to save its data to disk. Unlike save(), + this method is asynchronous and returns immediately. + """ + return self.execute_command("BGSAVE") + + def client_kill(self, address: str) -> Awaitable: + """Disconnects the client at ``address`` (ip:port)""" + return self.execute_command("CLIENT KILL", address) + + def client_kill_filter( + self, _id: str = None, _type: str = None, addr: str = None, skipme: bool = None + ) -> Awaitable: + """ + Disconnects client(s) using a variety of filter options + :param _id: Kills a client by its unique ID field + :param _type: Kills a client by type where type is one of 'normal', + 'master', 'slave' or 'pubsub' + :param addr: Kills a client by its 'address:port' + :param skipme: If True, then the client calling the command + will not get killed even if it is identified by one of the filter + options. If skipme is not provided, the server defaults to skipme=True + """ + args = [] + if _type is not None: + client_types = ("normal", "master", "slave", "pubsub") + if str(_type).lower() not in client_types: + raise DataError(f"CLIENT KILL type must be one of {client_types!r}") + args.extend((b"TYPE", _type)) + if skipme is not None: + if not isinstance(skipme, bool): + raise DataError("CLIENT KILL skipme must be a bool") + if skipme: + args.extend((b"SKIPME", b"YES")) + else: + args.extend((b"SKIPME", b"NO")) + if _id is not None: + args.extend((b"ID", _id)) + if addr is not None: + args.extend((b"ADDR", addr)) + if not args: + raise DataError( + "CLIENT KILL <filter> <value> ... ... <filter> <value> " + "must specify at least one filter" + ) + return self.execute_command("CLIENT KILL", *args) + + def client_list(self, _type: str = None) -> Awaitable: + """ + Returns a list of currently connected clients. + If a client type is specified, only that type will be returned. + :param _type: optional. one of the client types (normal, master, + replica, pubsub) + """ + if _type is not None: + client_types = ("normal", "master", "replica", "pubsub") + if str(_type).lower() not in client_types: + raise DataError(f"CLIENT LIST _type must be one of {client_types!r}") + return self.execute_command("CLIENT LIST", b"TYPE", _type) + return self.execute_command("CLIENT LIST") + + def client_getname(self) -> Awaitable: + """Returns the current connection name""" + return self.execute_command("CLIENT GETNAME") + + def client_id(self) -> Awaitable: + """Returns the current connection id""" + return self.execute_command("CLIENT ID") + + def client_setname(self, name: str) -> Awaitable: + """Sets the current connection name""" + return self.execute_command("CLIENT SETNAME", name) + + def client_unblock(self, client_id: int, error: bool = False) -> Awaitable: + """ + Unblocks a connection by its client id. + If ``error`` is True, unblocks the client with a special error message. + If ``error`` is False (default), the client is unblocked using the + regular timeout mechanism.
+ """ + args = ["CLIENT UNBLOCK", int(client_id)] + if error: + args.append(b"ERROR") + return self.execute_command(*args) + + def client_pause(self, timeout: int) -> Awaitable: + """ + Suspend all the Redis clients for the specified amount of time + :param timeout: milliseconds to pause clients + """ + if not isinstance(timeout, int): + raise DataError("CLIENT PAUSE timeout must be an integer") + return self.execute_command("CLIENT PAUSE", str(timeout)) + + def readwrite(self) -> Awaitable: + """Disables read queries for a connection to a Redis Cluster slave node""" + return self.execute_command("READWRITE") + + def readonly(self) -> Awaitable: + """Enables read queries for a connection to a Redis Cluster replica node""" + return self.execute_command("READONLY") + + def config_get(self, pattern: str = "*") -> Awaitable: + """Return a dictionary of configuration based on the ``pattern``""" + return self.execute_command("CONFIG GET", pattern) + + def config_set(self, name: str, value: EncodableT) -> Awaitable: + """Set config item ``name`` with ``value``""" + return self.execute_command("CONFIG SET", name, value) + + def config_resetstat(self) -> Awaitable: + """Reset runtime statistics""" + return self.execute_command("CONFIG RESETSTAT") + + def config_rewrite(self) -> Awaitable: + """Rewrite config file with the minimal change to reflect running config""" + return self.execute_command("CONFIG REWRITE") + + def dbsize(self) -> Awaitable: + """Returns the number of keys in the current database""" + return self.execute_command("DBSIZE") + + def debug_object(self, key: str) -> Awaitable: + """Returns version specific meta information about a given key""" + return self.execute_command("DEBUG OBJECT", key) + + def echo(self, value: EncodableT) -> Awaitable: + """Echo the string back from the server""" + return self.execute_command("ECHO", value) + + def flushall(self, asynchronous: bool = False) -> Awaitable: + """ + Delete all keys in all databases on the current host. + + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. + """ + args = [] + if asynchronous: + args.append(b"ASYNC") + return self.execute_command("FLUSHALL", *args) + + def flushdb(self, asynchronous: bool = False) -> Awaitable: + """ + Delete all keys in the current database. + + ``asynchronous`` indicates whether the operation is + executed asynchronously by the server. 
+ """ + args = [] + if asynchronous: + args.append(b"ASYNC") + return self.execute_command("FLUSHDB", *args) + + def swapdb(self, first: int, second: int) -> Awaitable: + """Swap two databases""" + return self.execute_command("SWAPDB", first, second) + + def info(self, section: str = None) -> Awaitable: + """ + Returns a dictionary containing information about the Redis server + + The ``section`` option can be used to select a specific section + of information + + The section option is not supported by older versions of Redis Server, + and will generate ResponseError + """ + if section is None: + return self.execute_command("INFO") + else: + return self.execute_command("INFO", section) + + def lastsave(self) -> Awaitable: + """ + Return a Python datetime object representing the last time the + Redis database was saved to disk + """ + return self.execute_command("LASTSAVE") + + def migrate( + self, + host: str, + port: int, + keys: Collection[str], + destination_db: int, + timeout: int, + copy: bool = False, + replace: bool = False, + auth: str = None, + ) -> Awaitable: + """ + Migrate 1 or more keys from the current Redis server to a different + server specified by the ``host``, ``port`` and ``destination_db``. + + The ``timeout``, specified in milliseconds, indicates the maximum + time the connection between the two servers can be idle before the + command is interrupted. + + If ``copy`` is True, the specified ``keys`` are NOT deleted from + the source server. + + If ``replace`` is True, this operation will overwrite the keys + on the destination server if they exist. + + If ``auth`` is specified, authenticate to the destination server with + the password provided. + """ + keys = list_or_args(keys, []) + if not keys: + raise DataError("MIGRATE requires at least one key") + pieces: List[EncodableT] = [] + if copy: + pieces.append(b"COPY") + if replace: + pieces.append(b"REPLACE") + if auth: + pieces.append(b"AUTH") + pieces.append(auth) + pieces.append(b"KEYS") + pieces.extend(keys) + return self.execute_command( + "MIGRATE", host, port, "", destination_db, timeout, *pieces + ) + + def object(self, infotype: str, key: str) -> Awaitable: + """Return the encoding, idletime, or refcount about the key""" + return self.execute_command("OBJECT", infotype, key, infotype=infotype) + + def memory_stats(self) -> Awaitable: + """Return a dictionary of memory stats""" + return self.execute_command("MEMORY STATS") + + def memory_usage(self, key: str, samples: int = None) -> Awaitable: + """ + Return the total memory usage for key, its value and associated + administrative overheads. + + For nested data structures, ``samples`` is the number of elements to + sample. If left unspecified, the server's default is 5. Use 0 to sample + all elements. 
+ """ + args = [] + if isinstance(samples, int): + args.extend([b"SAMPLES", samples]) + return self.execute_command("MEMORY USAGE", key, *args) + + def memory_purge(self) -> Awaitable: + """Attempts to purge dirty pages for reclamation by allocator""" + return self.execute_command("MEMORY PURGE") + + def ping(self) -> Awaitable: + """Ping the Redis server""" + return self.execute_command("PING") + + def save(self) -> Awaitable: + """ + Tell the Redis server to save its data to disk, + blocking until the save is complete + """ + return self.execute_command("SAVE") + + def sentinel_get_master_addr_by_name(self, service_name: str) -> Awaitable: + """Returns a (host, port) pair for the given ``service_name``""" + return self.execute_command("SENTINEL GET-MASTER-ADDR-BY-NAME", service_name) + + def sentinel_master(self, service_name: str) -> Awaitable: + """Returns a dictionary containing the specified masters state.""" + return self.execute_command("SENTINEL MASTER", service_name) + + def sentinel_masters(self) -> Awaitable: + """Returns a list of dictionaries containing each master's state.""" + return self.execute_command("SENTINEL MASTERS") + + def sentinel_monitor(self, name: str, ip: str, port: int, quorum: int) -> Awaitable: + """Add a new master to Sentinel to be monitored""" + return self.execute_command("SENTINEL MONITOR", name, ip, port, quorum) + + def sentinel_remove(self, name: str) -> Awaitable: + """Remove a master from Sentinel's monitoring""" + return self.execute_command("SENTINEL REMOVE", name) + + def sentinel_sentinels(self, service_name: str) -> Awaitable: + """Returns a list of sentinels for ``service_name``""" + return self.execute_command("SENTINEL SENTINELS", service_name) + + def sentinel_set(self, name: str, option: str, value: EncodableT) -> Awaitable: + """Set Sentinel monitoring parameters for a given master""" + return self.execute_command("SENTINEL SET", name, option, value) + + def sentinel_slaves(self, service_name: str) -> Awaitable: + """Returns a list of slaves for ``service_name``""" + return self.execute_command("SENTINEL SLAVES", service_name) + + def shutdown(self, save: bool = False, nosave: bool = False) -> Awaitable: + """Shutdown the Redis server. If Redis has persistence configured, + data will be flushed before shutdown. If the "save" option is set, + a data flush will be attempted even if there is no persistence + configured. If the "nosave" option is set, no data flush will be + attempted. The "save" and "nosave" options cannot both be set. + """ + if save and nosave: + raise DataError("SHUTDOWN save and nosave cannot both be set") + args = ["SHUTDOWN"] + if save: + args.append("SAVE") + if nosave: + args.append("NOSAVE") + try: + self.execute_command(*args) + except ConnectionError: + # a ConnectionError here is expected + return + raise RedisError("SHUTDOWN seems to have failed.") + + def slaveof(self, host: str = None, port: int = None) -> Awaitable: + """ + Set the server to be a replicated slave of the instance identified + by the ``host`` and ``port``. If called without arguments, the + instance is promoted to a master instead. + """ + if host is None and port is None: + return self.execute_command("SLAVEOF", b"NO", b"ONE") + return self.execute_command("SLAVEOF", host, port) + + def slowlog_get(self, num: int = None) -> Awaitable: + """ + Get the entries from the slowlog. If ``num`` is specified, get the + most recent ``num`` items. 
+ """ + args: List[EncodableT] = ["SLOWLOG GET"] + if num is not None: + args.append(num) + decode_responses = self.connection_pool.connection_kwargs.get( + "decode_responses", False + ) + return self.execute_command(*args, decode_responses=decode_responses) + + def slowlog_len(self) -> Awaitable: + """Get the number of items in the slowlog""" + return self.execute_command("SLOWLOG LEN") + + def slowlog_reset(self) -> Awaitable: + """Remove all items in the slowlog""" + return self.execute_command("SLOWLOG RESET") + + def time(self) -> Awaitable: + """ + Returns the server time as a 2-item tuple of ints: + (seconds since epoch, microseconds into this second). + """ + return self.execute_command("TIME") + + def wait(self, num_replicas: int, timeout: int) -> Awaitable: + """ + Redis synchronous replication + That returns the number of replicas that processed the query when + we finally have at least ``num_replicas``, or when the ``timeout`` was + reached. + """ + return self.execute_command("WAIT", num_replicas, timeout) + + # BASIC KEY COMMANDS + def append(self, key: str, value: EncodableT) -> Awaitable: + """ + Appends the string ``value`` to the value at ``key``. If ``key`` + doesn't already exist, create it with a value of ``value``. + Returns the new length of the value at ``key``. + """ + return self.execute_command("APPEND", key, value) + + def bitcount(self, key: str, start: int = None, end: int = None) -> Awaitable: + """ + Returns the count of set bits in the value of ``key``. Optional + ``start`` and ``end`` paramaters indicate which bytes to consider + """ + params: List[EncodableT] = [key] + if start is not None and end is not None: + params.append(start) + params.append(end) + elif (start is not None and end is None) or (end is not None and start is None): + raise DataError("Both start and end must be specified") + return self.execute_command("BITCOUNT", *params) + + def bitfield(self, key: str, default_overflow: str = None) -> BitFieldOperation: + """ + Return a BitFieldOperation instance to conveniently construct one or + more bitfield operations on ``key``. + """ + return BitFieldOperation(self, key, default_overflow=default_overflow) + + def bitop(self, operation: str, dest: str, *keys: str) -> Awaitable: + """ + Perform a bitwise operation using ``operation`` between ``keys`` and + store the result in ``dest``. + """ + return self.execute_command("BITOP", operation, dest, *keys) + + def bitpos( + self, key: str, bit: int, start: int = None, end: int = None + ) -> Awaitable: + """ + Return the position of the first bit set to 1 or 0 in a string. + ``start`` and ``end`` difines search range. The range is interpreted + as a range of bytes and not a range of bits, so start=0 and end=2 + means to look at the first three bytes. + """ + if bit not in (0, 1): + raise DataError("bit must be 0 or 1") + params = [key, bit] + + start is not None and params.append(start) + + if start is not None and end is not None: + params.append(end) + elif start is None and end is not None: + raise DataError("start argument is not set, " "when end is specified") + return self.execute_command("BITPOS", *params) + + def decr(self, name: str, amount: int = 1) -> Awaitable: + """ + Decrements the value of ``key`` by ``amount``. If no key exists, + the value will be initialized as 0 - ``amount`` + """ + # An alias for ``decr()``, because it is already implemented + # as DECRBY redis command. 
+        return self.decrby(name, amount)
+
+    def decrby(self, name: str, amount: int = 1) -> Awaitable:
+        """
+        Decrements the value of ``key`` by ``amount``. If no key exists,
+        the value will be initialized as 0 - ``amount``
+        """
+        return self.execute_command("DECRBY", name, amount)
+
+    def delete(self, *names: str) -> Awaitable:
+        """Delete one or more keys specified by ``names``"""
+        return self.execute_command("DEL", *names)
+
+    def dump(self, name: str) -> Awaitable:
+        """
+        Return a serialized version of the value stored at the specified key.
+        If key does not exist a nil bulk reply is returned.
+        """
+        return self.execute_command("DUMP", name)
+
+    def exists(self, *names: str) -> Awaitable:
+        """Returns the number of ``names`` that exist"""
+        return self.execute_command("EXISTS", *names)
+
+    def expire(self, name: str, time: Union[int, datetime.timedelta]) -> Awaitable:
+        """
+        Set an expire flag on key ``name`` for ``time`` seconds. ``time``
+        can be represented by an integer or a Python timedelta object.
+        """
+        if isinstance(time, datetime.timedelta):
+            time = int(time.total_seconds())
+        return self.execute_command("EXPIRE", name, time)
+
+    def expireat(self, name: str, when: Union[float, datetime.datetime]) -> Awaitable:
+        """
+        Set an expire flag on key ``name``. ``when`` can be represented
+        as an integer indicating unix time or a Python datetime object.
+        """
+        if isinstance(when, datetime.datetime):
+            when = int(mod_time.mktime(when.timetuple()))
+        return self.execute_command("EXPIREAT", name, when)
+
+    def get(self, name: str) -> Awaitable:
+        """
+        Return the value at key ``name``, or None if the key doesn't exist
+        """
+        return self.execute_command("GET", name)
+
+    def getbit(self, name: str, offset: int) -> Awaitable:
+        """Returns a boolean indicating the value of ``offset`` in ``name``"""
+        return self.execute_command("GETBIT", name, offset)
+
+    def getrange(self, key: str, start: int, end: int) -> Awaitable:
+        """
+        Returns the substring of the string value stored at ``key``,
+        determined by the offsets ``start`` and ``end`` (both are inclusive)
+        """
+        return self.execute_command("GETRANGE", key, start, end)
+
+    def getset(self, name: str, value: EncodableT) -> Awaitable:
+        """
+        Sets the value at key ``name`` to ``value``
+        and returns the old value at key ``name`` atomically.
+        """
+        return self.execute_command("GETSET", name, value)
+
+    def incr(self, name: str, amount: int = 1) -> Awaitable:
+        """
+        Increments the value of ``key`` by ``amount``. If no key exists,
+        the value will be initialized as ``amount``
+        """
+        return self.incrby(name, amount)
+
+    def incrby(self, name: str, amount: int = 1) -> Awaitable:
+        """
+        Increments the value of ``key`` by ``amount``. If no key exists,
+        the value will be initialized as ``amount``
+        """
+        # ``incr`` delegates to this method, because INCR is already
+        # implemented as the INCRBY redis command.
+        return self.execute_command("INCRBY", name, amount)
+
+    def incrbyfloat(self, name: str, amount: float = 1.0) -> Awaitable:
+        """
+        Increments the value at key ``name`` by floating ``amount``.
+        If no key exists, the value will be initialized as ``amount``
+        """
+        return self.execute_command("INCRBYFLOAT", name, amount)
+
+    def keys(self, pattern: str = "*") -> Awaitable:
+        """Returns a list of keys matching ``pattern``"""
+        return self.execute_command("KEYS", pattern)
+
+    def mget(self, keys: str, *args: EncodableT) -> Awaitable:
+        """
+        Returns a list of values ordered identically to ``keys``
+        """
+        args = list_or_args(keys, args)
+        options: Dict[str, Union[EncodableT, Iterable[EncodableT]]] = {}
+        if not args:
+            options[EMPTY_RESPONSE] = []
+        return self.execute_command("MGET", *args, **options)
+
+    def mset(self, mapping: Mapping[str, EncodableT]) -> Awaitable:
+        """
+        Sets key/values based on a mapping. Mapping is a dictionary of
+        key/value pairs. Both keys and values should be strings or types that
+        can be cast to a string via str().
+        """
+        items: List[EncodableT] = []
+        for pair in mapping.items():
+            items.extend(pair)
+        return self.execute_command("MSET", *items)
+
+    def msetnx(self, mapping: Mapping[str, EncodableT]) -> Awaitable:
+        """
+        Sets key/values based on a mapping if none of the keys are already set.
+        Mapping is a dictionary of key/value pairs. Both keys and values
+        should be strings or types that can be cast to a string via str().
+        Returns a boolean indicating if the operation was successful.
+        """
+        items: List[EncodableT] = []
+        for pair in mapping.items():
+            items.extend(pair)
+        return self.execute_command("MSETNX", *items)
+
+    def move(self, name: str, db: int) -> Awaitable:
+        """Moves the key ``name`` to a different Redis database ``db``"""
+        return self.execute_command("MOVE", name, db)
+
+    def persist(self, name: str) -> Awaitable:
+        """Removes an expiration on ``name``"""
+        return self.execute_command("PERSIST", name)
+
+    def pexpire(self, name: str, time: Union[int, datetime.timedelta]) -> Awaitable:
+        """
+        Set an expire flag on key ``name`` for ``time`` milliseconds.
+        ``time`` can be represented by an integer or a Python timedelta
+        object.
+        """
+        if isinstance(time, datetime.timedelta):
+            time = int(time.total_seconds() * 1000)
+        return self.execute_command("PEXPIRE", name, time)
+
+    def pexpireat(self, name: str, when: Union[float, datetime.datetime]) -> Awaitable:
+        """
+        Set an expire flag on key ``name``. ``when`` can be represented
+        as an integer representing unix time in milliseconds (unix time * 1000)
+        or a Python datetime object.
+        """
+        if isinstance(when, datetime.datetime):
+            ms = int(when.microsecond / 1000)
+            when = int(mod_time.mktime(when.timetuple())) * 1000 + ms
+        return self.execute_command("PEXPIREAT", name, when)
+
+    def psetex(
+        self, name: str, time_ms: Union[int, datetime.timedelta], value: EncodableT
+    ) -> Awaitable:
+        """
+        Set the value of key ``name`` to ``value`` that expires in ``time_ms``
+        milliseconds. ``time_ms`` can be represented by an integer or a Python
+        timedelta object
+        """
+        if isinstance(time_ms, datetime.timedelta):
+            time_ms = int(time_ms.total_seconds() * 1000)
+        return self.execute_command("PSETEX", name, time_ms, value)
+
+    def pttl(self, name: str) -> Awaitable:
+        """Returns the number of milliseconds until the key ``name`` will expire"""
+        return self.execute_command("PTTL", name)
+
+    def randomkey(self) -> Awaitable:
+        """Returns the name of a random key"""
+        return self.execute_command("RANDOMKEY")
+
+    def rename(self, src: str, dst: str) -> Awaitable:
+        """
+        Rename key ``src`` to ``dst``
+        """
+        return self.execute_command("RENAME", src, dst)
+
+    def renamenx(self, src: str, dst: str) -> Awaitable:
+        """Rename key ``src`` to ``dst`` if ``dst`` doesn't already exist"""
+        return self.execute_command("RENAMENX", src, dst)
+
+    def restore(
+        self,
+        name: str,
+        ttl: float,
+        value: EncodableT,
+        replace: bool = False,
+        absttl: bool = False,
+    ) -> Awaitable:
+        """
+        Create a key using the provided serialized value, previously obtained
+        using DUMP.
+
+        ``replace`` allows an existing key on ``name`` to be overridden. If
+        it's not specified an error is raised on collision.
+
+        ``absttl`` if True, the specified ``ttl`` should represent an absolute
+        Unix timestamp in milliseconds at which the key will expire. (Redis 5.0
+        or greater).
+        """
+        params = [name, ttl, value]
+        if replace:
+            params.append("REPLACE")
+        if absttl:
+            params.append("ABSTTL")
+        return self.execute_command("RESTORE", *params)
+
+    def set(
+        self,
+        name: str,
+        value: EncodableT,
+        ex: int = None,
+        px: int = None,
+        nx: bool = False,
+        xx: bool = False,
+        keepttl: bool = False,
+    ) -> Awaitable:
+        """
+        Set the value at key ``name`` to ``value``
+
+        ``ex`` sets an expire flag on key ``name`` for ``ex`` seconds.
+
+        ``px`` sets an expire flag on key ``name`` for ``px`` milliseconds.
+
+        ``nx`` if set to True, set the value at key ``name`` to ``value`` only
+        if it does not exist.
+
+        ``xx`` if set to True, set the value at key ``name`` to ``value`` only
+        if it already exists.
+
+        ``keepttl`` if True, retain the time to live associated with the key.
+        (Available since Redis 6.0)
+        """
+        pieces: List[EncodableT] = [name, value]
+        if ex is not None:
+            pieces.append("EX")
+            if isinstance(ex, datetime.timedelta):
+                ex = int(ex.total_seconds())
+            pieces.append(ex)
+        if px is not None:
+            pieces.append("PX")
+            if isinstance(px, datetime.timedelta):
+                px = int(px.total_seconds() * 1000)
+            pieces.append(px)
+
+        if nx:
+            pieces.append("NX")
+        if xx:
+            pieces.append("XX")
+
+        if keepttl:
+            pieces.append("KEEPTTL")
+
+        return self.execute_command("SET", *pieces)
+
+    def setbit(self, name: str, offset: int, value: int) -> Awaitable:
+        """
+        Flag the ``offset`` in ``name`` as ``value``. Returns a boolean
+        indicating the previous value of ``offset``.
+        """
+        value = 1 if value else 0
+        return self.execute_command("SETBIT", name, offset, value)
+
+    def setex(
+        self, name: str, time: Union[int, datetime.timedelta], value: EncodableT
+    ) -> Awaitable:
+        """
+        Set the value of key ``name`` to ``value`` that expires in ``time``
+        seconds. ``time`` can be represented by an integer or a Python
+        timedelta object.
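+
+        For example (a sketch that assumes an ``aioredis.Redis`` instance
+        ``redis`` inside a coroutine)::
+
+            await redis.setex("session:42", datetime.timedelta(minutes=5), "token")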
+ """ + if isinstance(time, datetime.timedelta): + time = int(time.total_seconds()) + return self.execute_command("SETEX", name, time, value) + + def setnx(self, name: str, value: EncodableT) -> Awaitable: + """Set the value of key ``name`` to ``value`` if key doesn't exist""" + return self.execute_command("SETNX", name, value) + + def setrange(self, name: str, offset: int, value: EncodableT) -> Awaitable: + """ + Overwrite bytes in the value of ``name`` starting at ``offset`` with + ``value``. If ``offset`` plus the length of ``value`` exceeds the + length of the original value, the new value will be larger than before. + If ``offset`` exceeds the length of the original value, null bytes + will be used to pad between the end of the previous value and the start + of what's being injected. + + Returns the length of the new string. + """ + return self.execute_command("SETRANGE", name, offset, value) + + def strlen(self, name: str) -> Awaitable: + """Return the number of bytes stored in the value of ``name``""" + return self.execute_command("STRLEN", name) + + def substr(self, name: str, start: int, end: int = -1) -> Awaitable: + """ + Return a substring of the string at key ``name``. ``start`` and ``end`` + are 0-based integers specifying the portion of the string to return. + """ + return self.execute_command("SUBSTR", name, start, end) + + def touch(self, *args: str) -> Awaitable: + """ + Alters the last access time of a key(s) ``*args``. A key is ignored + if it does not exist. + """ + return self.execute_command("TOUCH", *args) + + def ttl(self, name: str) -> Awaitable: + """Returns the number of seconds until the key ``name`` will expire""" + return self.execute_command("TTL", name) + + def type(self, name: str) -> Awaitable: + """Returns the type of key ``name``""" + return self.execute_command("TYPE", name) + + def unlink(self, *names: str) -> Awaitable: + """Unlink one or more keys specified by ``names``""" + return self.execute_command("UNLINK", *names) + + # LIST COMMANDS + def blpop(self, keys: str, timeout: int = 0) -> Awaitable: + """ + LPOP a value off of the first non-empty list + named in the ``keys`` list. + + If none of the lists in ``keys`` has a value to LPOP, then block + for ``timeout`` seconds, or until a value gets pushed on to one + of the lists. + + If timeout is 0, then block indefinitely. + """ + if timeout is None: + timeout = 0 + keys = list_or_args(keys, None) + keys.append(timeout) + return self.execute_command("BLPOP", *keys) + + def brpop(self, keys: str, timeout: int = 0) -> Awaitable: + """ + RPOP a value off of the first non-empty list + named in the ``keys`` list. + + If none of the lists in ``keys`` has a value to RPOP, then block + for ``timeout`` seconds, or until a value gets pushed on to one + of the lists. + + If timeout is 0, then block indefinitely. + """ + if timeout is None: + timeout = 0 + keys = list_or_args(keys, None) + keys.append(timeout) + return self.execute_command("BRPOP", *keys) + + def brpoplpush(self, src: str, dst: str, timeout: int = 0) -> Awaitable: + """ + Pop a value off the tail of ``src``, push it on the head of ``dst`` + and then return it. + + This command blocks until a value is in ``src`` or until ``timeout`` + seconds elapse, whichever is first. A ``timeout`` value of 0 blocks + forever. 
+ """ + if timeout is None: + timeout = 0 + return self.execute_command("BRPOPLPUSH", src, dst, timeout) + + def lindex(self, name: str, index: int) -> Awaitable: + """ + Return the item from list ``name`` at position ``index`` + + Negative indexes are supported and will return an item at the + end of the list + """ + return self.execute_command("LINDEX", name, index) + + def linsert( + self, name: str, where: str, refvalue: EncodableT, value: EncodableT + ) -> Awaitable: + """ + Insert ``value`` in list ``name`` either immediately before or after + [``where``] ``refvalue`` + + Returns the new length of the list on success or -1 if ``refvalue`` + is not in the list. + """ + return self.execute_command("LINSERT", name, where, refvalue, value) + + def llen(self, name: str) -> Awaitable: + """Return the length of the list ``name``""" + return self.execute_command("LLEN", name) + + def lpop(self, name: str) -> Awaitable: + """Remove and return the first item of the list ``name``""" + return self.execute_command("LPOP", name) + + def lpush(self, name: str, *values: EncodableT) -> Awaitable: + """Push ``values`` onto the head of the list ``name``""" + return self.execute_command("LPUSH", name, *values) + + def lpushx(self, name: str, value: str) -> Awaitable: + """Push ``value`` onto the head of the list ``name`` if ``name`` exists""" + return self.execute_command("LPUSHX", name, value) + + def lrange(self, name: str, start: int, end: int) -> Awaitable: + """ + Return a slice of the list ``name`` between + position ``start`` and ``end`` + + ``start`` and ``end`` can be negative numbers just like + Python slicing notation + """ + return self.execute_command("LRANGE", name, start, end) + + def lrem(self, name: str, count: int, value: EncodableT) -> Awaitable: + """ + Remove the first ``count`` occurrences of elements equal to ``value`` + from the list stored at ``name``. + + The count argument influences the operation in the following ways: + count > 0: Remove elements equal to value moving from head to tail. + count < 0: Remove elements equal to value moving from tail to head. + count = 0: Remove all elements equal to value. + """ + return self.execute_command("LREM", name, count, value) + + def lset(self, name: str, index: int, value: EncodableT) -> Awaitable: + """Set ``position`` of list ``name`` to ``value``""" + return self.execute_command("LSET", name, index, value) + + def ltrim(self, name: str, start: int, end: int) -> Awaitable: + """ + Trim the list ``name``, removing all values not within the slice + between ``start`` and ``end`` + + ``start`` and ``end`` can be negative numbers just like + Python slicing notation + """ + return self.execute_command("LTRIM", name, start, end) + + def rpop(self, name: str) -> Awaitable: + """Remove and return the last item of the list ``name``""" + return self.execute_command("RPOP", name) + + def rpoplpush(self, src: str, dst: str) -> Awaitable: + """ + RPOP a value off of the ``src`` list and atomically LPUSH it + on to the ``dst`` list. Returns the value. 
+ """ + return self.execute_command("RPOPLPUSH", src, dst) + + def rpush(self, name: str, *values: EncodableT) -> Awaitable: + """Push ``values`` onto the tail of the list ``name``""" + return self.execute_command("RPUSH", name, *values) + + def rpushx(self, name: str, value: EncodableT) -> Awaitable: + """Push ``value`` onto the tail of the list ``name`` if ``name`` exists""" + return self.execute_command("RPUSHX", name, value) + + def lpos( + self, + name: str, + value: EncodableT, + rank: int = None, + count: int = None, + maxlen: int = None, + ) -> Awaitable: + """ + Get position of ``value`` within the list ``name`` + + If specified, ``rank`` indicates the "rank" of the first element to + return in case there are multiple copies of ``value`` in the list. + By default, LPOS returns the position of the first occurrence of + ``value`` in the list. When ``rank`` 2, LPOS returns the position of + the second ``value`` in the list. If ``rank`` is negative, LPOS + searches the list in reverse. For example, -1 would return the + position of the last occurrence of ``value`` and -2 would return the + position of the next to last occurrence of ``value``. + + If specified, ``count`` indicates that LPOS should return a list of + up to ``count`` positions. A ``count`` of 2 would return a list of + up to 2 positions. A ``count`` of 0 returns a list of all positions + matching ``value``. When ``count`` is specified and but ``value`` + does not exist in the list, an empty list is returned. + + If specified, ``maxlen`` indicates the maximum number of list + elements to scan. A ``maxlen`` of 1000 will only return the + position(s) of items within the first 1000 entries in the list. + A ``maxlen`` of 0 (the default) will scan the entire list. + """ + pieces: List[EncodableT] = [name, value] + if rank is not None: + pieces.extend(["RANK", rank]) + + if count is not None: + pieces.extend(["COUNT", count]) + + if maxlen is not None: + pieces.extend(["MAXLEN", maxlen]) + + return self.execute_command("LPOS", *pieces) + + def sort( + self, + name: str, + start: int = None, + num: int = None, + by: str = None, + get: Collection[str] = None, + desc: bool = False, + alpha: bool = False, + store: str = None, + groups: bool = False, + ) -> Awaitable: + """ + Sort and return the list, set or sorted set at ``name``. + + ``start`` and ``num`` allow for paging through the sorted data + + ``by`` allows using an external key to weight and sort the items. + Use an "*" to indicate where in the key the item value is located + + ``get`` allows for returning items from external keys rather than the + sorted data itself. Use an "*" to indicate where in the key + the item value is located + + ``desc`` allows for reversing the sort + + ``alpha`` allows for sorting lexicographically rather than numerically + + ``store`` allows for storing the result of the sort into + the key ``store`` + + ``groups`` if set to True and if ``get`` contains at least two + elements, sort will return a list of tuples, each containing the + values fetched from the arguments to ``get``. + + """ + if (start is not None and num is None) or (num is not None and start is None): + raise DataError("``start`` and ``num`` must both be specified") + + pieces: List[EncodableT] = [name] + if by is not None: + pieces.append(b"BY") + pieces.append(by) + if start is not None and num is not None: + pieces.append(b"LIMIT") + pieces.append(start) + pieces.append(num) + if get is not None: + # If get is a string assume we want to get a single value. 
+            # Otherwise assume it's an iterable and we want to get multiple
+            # values. We can't just iterate blindly because strings are
+            # iterable.
+            if isinstance(get, (bytes, str)):
+                pieces.append(b"GET")
+                pieces.append(get)
+            else:
+                for g in get:
+                    pieces.append(b"GET")
+                    pieces.append(g)
+        if desc:
+            pieces.append(b"DESC")
+        if alpha:
+            pieces.append(b"ALPHA")
+        if store is not None:
+            pieces.append(b"STORE")
+            pieces.append(store)
+
+        if groups:
+            if not get or isinstance(get, (bytes, str)) or len(get) < 2:
+                raise DataError(
+                    'when using "groups" the "get" argument '
+                    "must be specified and contain at least "
+                    "two keys"
+                )
+
+        options = {"groups": len(get) if groups else None}
+        return self.execute_command("SORT", *pieces, **options)
+
+    # SCAN COMMANDS
+    def scan(
+        self, cursor: int = 0, match: str = None, count: int = None, _type: str = None
+    ) -> Awaitable:
+        """
+        Incrementally return lists of key names. Also return a cursor
+        indicating the scan position.
+
+        ``match`` allows for filtering the keys by pattern
+
+        ``count`` provides a hint to Redis about the number of keys to
+        return per batch.
+
+        ``_type`` filters the returned values by a particular Redis type.
+        Stock Redis instances allow for the following types:
+        HASH, LIST, SET, STREAM, STRING, ZSET
+        Additionally, Redis modules can expose other types as well.
+        """
+        pieces: List[EncodableT] = [cursor]
+        if match is not None:
+            pieces.extend([b"MATCH", match])
+        if count is not None:
+            pieces.extend([b"COUNT", count])
+        if _type is not None:
+            pieces.extend([b"TYPE", _type])
+        return self.execute_command("SCAN", *pieces)
+
+    async def scan_iter(
+        self, match: str = None, count: int = None, _type: str = None
+    ) -> AsyncIterator:
+        """
+        Make an iterator using the SCAN command so that the client doesn't
+        need to remember the cursor position.
+
+        ``match`` allows for filtering the keys by pattern
+
+        ``count`` provides a hint to Redis about the number of keys to
+        return per batch.
+
+        ``_type`` filters the returned values by a particular Redis type.
+        Stock Redis instances allow for the following types:
+        HASH, LIST, SET, STREAM, STRING, ZSET
+        Additionally, Redis modules can expose other types as well.
+        """
+        cursor = "0"
+        while cursor != 0:
+            cursor, data = await self.scan(
+                cursor=cursor, match=match, count=count, _type=_type
+            )
+            for d in data:
+                yield d
+
+    def sscan(
+        self, name: str, cursor: int = 0, match: str = None, count: int = None
+    ) -> Awaitable:
+        """
+        Incrementally return lists of elements in a set. Also return a cursor
+        indicating the scan position.
+
+        ``match`` allows for filtering the keys by pattern
+
+        ``count`` provides a hint to Redis about the number of elements to
+        return per batch.
+        """
+        pieces: List[EncodableT] = [name, cursor]
+        if match is not None:
+            pieces.extend([b"MATCH", match])
+        if count is not None:
+            pieces.extend([b"COUNT", count])
+        return self.execute_command("SSCAN", *pieces)
+
+    async def sscan_iter(
+        self, name: str, match: str = None, count: int = None
+    ) -> AsyncIterator:
+        """
+        Make an iterator using the SSCAN command so that the client doesn't
+        need to remember the cursor position.
+
+        ``match`` allows for filtering the keys by pattern
+
+        ``count`` provides a hint to Redis about the number of elements to
+        return per batch.
+        """
+        cursor = "0"
+        while cursor != 0:
+            cursor, data = await self.sscan(
+                name, cursor=cursor, match=match, count=count
+            )
+            for d in data:
+                yield d
+
+    def hscan(
+        self, name: str, cursor: int = 0, match: str = None, count: int = None
+    ) -> Awaitable:
+        """
+        Incrementally return key/value slices in a hash. Also return a cursor
+        indicating the scan position.
+
+        ``match`` allows for filtering the keys by pattern
+
+        ``count`` provides a hint to Redis about the number of elements to
+        return per batch.
+        """
+        pieces: List[EncodableT] = [name, cursor]
+        if match is not None:
+            pieces.extend([b"MATCH", match])
+        if count is not None:
+            pieces.extend([b"COUNT", count])
+        return self.execute_command("HSCAN", *pieces)
+
+    async def hscan_iter(
+        self, name: str, match: str = None, count: int = None
+    ) -> AsyncIterator:
+        """
+        Make an iterator using the HSCAN command so that the client doesn't
+        need to remember the cursor position.
+
+        ``match`` allows for filtering the keys by pattern
+
+        ``count`` provides a hint to Redis about the number of elements to
+        return per batch.
+        """
+        cursor = "0"
+        while cursor != 0:
+            cursor, data = await self.hscan(
+                name, cursor=cursor, match=match, count=count
+            )
+            for it in data.items():
+                yield it
+
+    def zscan(
+        self,
+        name: str,
+        cursor: int = 0,
+        match: str = None,
+        count: int = None,
+        score_cast_func: Union[Type, Callable] = float,
+    ) -> Awaitable:
+        """
+        Incrementally return lists of elements in a sorted set. Also return a
+        cursor indicating the scan position.
+
+        ``match`` allows for filtering the keys by pattern
+
+        ``count`` provides a hint to Redis about the number of elements to
+        return per batch.
+
+        ``score_cast_func`` a callable used to cast the score return value
+        """
+        pieces: List[EncodableT] = [name, cursor]
+        if match is not None:
+            pieces.extend([b"MATCH", match])
+        if count is not None:
+            pieces.extend([b"COUNT", count])
+        options = {"score_cast_func": score_cast_func}
+        return self.execute_command("ZSCAN", *pieces, **options)
+
+    async def zscan_iter(
+        self,
+        name: str,
+        match: str = None,
+        count: int = None,
+        score_cast_func: Union[Type, Callable] = float,
+    ) -> AsyncIterator:
+        """
+        Make an iterator using the ZSCAN command so that the client doesn't
+        need to remember the cursor position.
+
+        ``match`` allows for filtering the keys by pattern
+
+        ``count`` provides a hint to Redis about the number of elements to
+        return per batch.
+
+        ``score_cast_func`` a callable used to cast the score return value
+        """
+        cursor = "0"
+        while cursor != 0:
+            cursor, data = await self.zscan(
+                name,
+                cursor=cursor,
+                match=match,
+                count=count,
+                score_cast_func=score_cast_func,
+            )
+            for d in data:
+                yield d
+
+    # SET COMMANDS
+    def sadd(self, name: str, *values: EncodableT) -> Awaitable:
+        """Add ``value(s)`` to set ``name``"""
+        return self.execute_command("SADD", name, *values)
+
+    def scard(self, name: str) -> Awaitable:
+        """Return the number of elements in set ``name``"""
+        return self.execute_command("SCARD", name)
+
+    def sdiff(self, keys: str, *args: EncodableT) -> Awaitable:
+        """Return the difference of sets specified by ``keys``"""
+        args = list_or_args(keys, args)
+        return self.execute_command("SDIFF", *args)
+
+    def sdiffstore(
+        self, dest: str, keys: Collection[str], *args: EncodableT
+    ) -> Awaitable:
+        """
+        Store the difference of sets specified by ``keys`` into a new
+        set named ``dest``. Returns the number of keys in the new set.
+        """
+        args = list_or_args(keys, args)
+        return self.execute_command("SDIFFSTORE", dest, *args)
+
+    def sinter(self, keys: str, *args: EncodableT) -> Awaitable:
+        """Return the intersection of sets specified by ``keys``"""
+        args = list_or_args(keys, args)
+        return self.execute_command("SINTER", *args)
+
+    def sinterstore(
+        self, dest: str, keys: Collection[str], *args: EncodableT
+    ) -> Awaitable:
+        """
+        Store the intersection of sets specified by ``keys`` into a new
+        set named ``dest``. Returns the number of keys in the new set.
+        """
+        args = list_or_args(keys, args)
+        return self.execute_command("SINTERSTORE", dest, *args)
+
+    def sismember(self, name: str, value: EncodableT) -> Awaitable:
+        """Return a boolean indicating if ``value`` is a member of set ``name``"""
+        return self.execute_command("SISMEMBER", name, value)
+
+    def smembers(self, name: str) -> Awaitable:
+        """Return all members of the set ``name``"""
+        return self.execute_command("SMEMBERS", name)
+
+    def smove(self, src: str, dst: str, value: EncodableT) -> Awaitable:
+        """Move ``value`` from set ``src`` to set ``dst`` atomically"""
+        return self.execute_command("SMOVE", src, dst, value)
+
+    def spop(self, name: str, count: int = None) -> Awaitable:
+        """Remove and return a random member of set ``name``"""
+        args = [count] if count is not None else []
+        return self.execute_command("SPOP", name, *args)
+
+    def srandmember(self, name: str, number: int = None) -> Awaitable:
+        """
+        If ``number`` is None, returns a random member of set ``name``.
+
+        If ``number`` is supplied, returns a list of ``number`` random
+        members of set ``name``. Note this is only available when running
+        Redis 2.6+.
+        """
+        args = [number] if number is not None else []
+        return self.execute_command("SRANDMEMBER", name, *args)
+
+    def srem(self, name: str, *values: EncodableT) -> Awaitable:
+        """Remove ``values`` from set ``name``"""
+        return self.execute_command("SREM", name, *values)
+
+    def sunion(self, keys: Collection[str], *args: EncodableT) -> Awaitable:
+        """Return the union of sets specified by ``keys``"""
+        args = list_or_args(keys, args)
+        return self.execute_command("SUNION", *args)
+
+    def sunionstore(
+        self, dest: str, keys: Collection[str], *args: EncodableT
+    ) -> Awaitable:
+        """
+        Store the union of sets specified by ``keys`` into a new
+        set named ``dest``. Returns the number of keys in the new set.
+        """
+        args = list_or_args(keys, args)
+        return self.execute_command("SUNIONSTORE", dest, *args)
+
+    # STREAMS COMMANDS
+    def xack(self, name: str, groupname: str, *ids: str) -> Awaitable:
+        """
+        Acknowledges the successful processing of one or more messages.
+        name: name of the stream.
+        groupname: name of the consumer group.
+        *ids: message ids to acknowledge.
+        """
+        return self.execute_command("XACK", name, groupname, *ids)
+
+    def xadd(
+        self,
+        name: str,
+        fields: Dict[str, EncodableT],
+        id: str = "*",
+        maxlen: int = None,
+        approximate: bool = True,
+    ) -> Awaitable:
+        """
+        Add to a stream.
+        name: name of the stream
+        fields: dict of field/value pairs to insert into the stream
+        id: Location to insert this record. By default it is appended.
+        maxlen: truncate old stream members beyond this size
+        approximate: actual stream length may be slightly more than maxlen
+        """
+        pieces: List[EncodableT] = []
+        if maxlen is not None:
+            if not isinstance(maxlen, int) or maxlen < 1:
+                raise DataError("XADD maxlen must be a positive integer")
+            pieces.append(b"MAXLEN")
+            if approximate:
+                pieces.append(b"~")
+            pieces.append(str(maxlen))
+        pieces.append(id)
+        if not isinstance(fields, dict) or len(fields) == 0:
+            raise DataError("XADD fields must be a non-empty dict")
+        for pair in fields.items():
+            pieces.extend(pair)
+        return self.execute_command("XADD", name, *pieces)
+
+    def xclaim(
+        self,
+        name: str,
+        groupname: str,
+        consumername: str,
+        min_idle_time: int,
+        message_ids: Sequence[str],
+        idle: int = None,
+        time: int = None,
+        retrycount: int = None,
+        force: bool = False,
+        justid: bool = False,
+    ) -> Awaitable:
+        """
+        Changes the ownership of a pending message.
+        name: name of the stream.
+        groupname: name of the consumer group.
+        consumername: name of a consumer that claims the message.
+        min_idle_time: only claim messages that have been idle for at least
+        this many milliseconds
+        message_ids: non-empty list or tuple of message IDs to claim
+        idle: optional. Set the idle time (last time it was delivered) of the
+        message in ms
+        time: optional integer. This is the same as idle but instead of a
+        relative amount of milliseconds, it sets the idle time to a specific
+        Unix time (in milliseconds).
+        retrycount: optional integer. set the retry counter to the specified
+        value. This counter is incremented every time a message is delivered
+        again.
+        force: optional boolean, false by default. Creates the pending message
+        entry in the PEL even if certain specified IDs are not already in the
+        PEL assigned to a different client.
+        justid: optional boolean, false by default. Return just an array of IDs
+        of messages successfully claimed, without returning the actual message
+        """
+        if not isinstance(min_idle_time, int) or min_idle_time < 0:
+            raise DataError("XCLAIM min_idle_time must be a non-negative integer")
+        if not isinstance(message_ids, (list, tuple)) or not message_ids:
+            raise DataError(
+                "XCLAIM message_ids must be a non empty list or "
+                "tuple of message IDs to claim"
+            )
+
+        kwargs = {}
+        pieces: List[EncodableT] = [name, groupname, consumername, str(min_idle_time)]
+        pieces.extend(list(message_ids))
+
+        if idle is not None:
+            if not isinstance(idle, int):
+                raise DataError("XCLAIM idle must be an integer")
+            pieces.extend((b"IDLE", str(idle)))
+        if time is not None:
+            if not isinstance(time, int):
+                raise DataError("XCLAIM time must be an integer")
+            pieces.extend((b"TIME", str(time)))
+        if retrycount is not None:
+            if not isinstance(retrycount, int):
+                raise DataError("XCLAIM retrycount must be an integer")
+            pieces.extend((b"RETRYCOUNT", str(retrycount)))
+
+        if force:
+            if not isinstance(force, bool):
+                raise DataError("XCLAIM force must be a boolean")
+            pieces.append(b"FORCE")
+        if justid:
+            if not isinstance(justid, bool):
+                raise DataError("XCLAIM justid must be a boolean")
+            pieces.append(b"JUSTID")
+            kwargs["parse_justid"] = True
+        return self.execute_command("XCLAIM", *pieces, **kwargs)
+
+    def xdel(self, name: str, *ids: str) -> Awaitable:
+        """
+        Deletes one or more messages from a stream.
+        name: name of the stream.
+        *ids: message ids to delete.
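+
+        For example (a sketch that assumes an ``aioredis.Redis`` instance
+        ``redis`` inside a coroutine)::
+
+            entry_id = await redis.xadd("events", {"type": "click"})
+            await redis.xdel("events", entry_id)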
+ """ + return self.execute_command("XDEL", name, *ids) + + def xgroup_create( + self, name: str, groupname: str, id: str = "$", mkstream: bool = False + ) -> Awaitable: + """ + Create a new consumer group associated with a stream. + name: name of the stream. + groupname: name of the consumer group. + id: ID of the last item in the stream to consider already delivered. + """ + pieces: List[EncodableT] = ["XGROUP CREATE", name, groupname, id] + if mkstream: + pieces.append(b"MKSTREAM") + return self.execute_command(*pieces) + + def xgroup_delconsumer( + self, name: str, groupname: str, consumername: str + ) -> Awaitable: + """ + Remove a specific consumer from a consumer group. + Returns the number of pending messages that the consumer had before it + was deleted. + name: name of the stream. + groupname: name of the consumer group. + consumername: name of consumer to delete + """ + return self.execute_command("XGROUP DELCONSUMER", name, groupname, consumername) + + def xgroup_destroy(self, name: str, groupname: str) -> Awaitable: + """ + Destroy a consumer group. + name: name of the stream. + groupname: name of the consumer group. + """ + return self.execute_command("XGROUP DESTROY", name, groupname) + + def xgroup_setid(self, name: str, groupname: str, id: str) -> Awaitable: + """ + Set the consumer group last delivered ID to something else. + name: name of the stream. + groupname: name of the consumer group. + id: ID of the last item in the stream to consider already delivered. + """ + return self.execute_command("XGROUP SETID", name, groupname, id) + + def xinfo_consumers(self, name: str, groupname: str) -> Awaitable: + """ + Returns general information about the consumers in the group. + name: name of the stream. + groupname: name of the consumer group. + """ + return self.execute_command("XINFO CONSUMERS", name, groupname) + + def xinfo_groups(self, name: str) -> Awaitable: + """ + Returns general information about the consumer groups of the stream. + name: name of the stream. + """ + return self.execute_command("XINFO GROUPS", name) + + def xinfo_stream(self, name: str) -> Awaitable: + """ + Returns general information about the stream. + name: name of the stream. + """ + return self.execute_command("XINFO STREAM", name) + + def xlen(self, name: str) -> Awaitable: + """ + Returns the number of elements in a given stream. + """ + return self.execute_command("XLEN", name) + + def xpending(self, name: str, groupname: str) -> Awaitable: + """ + Returns information about pending messages of a group. + name: name of the stream. + groupname: name of the consumer group. + """ + return self.execute_command("XPENDING", name, groupname) + + def xpending_range( + self, + name: str, + groupname: str, + min: int, + max: int, + count: int, + consumername: str = None, + ) -> Awaitable: + """ + Returns information about pending messages, in a range. + name: name of the stream. + groupname: name of the consumer group. + min: minimum stream ID. + max: maximum stream ID. + count: number of messages to return + consumername: name of a consumer to filter by (optional). + """ + pieces: List[EncodableT] = [name, groupname] + if min is not None or max is not None or count is not None: + if min is None or max is None or count is None: + raise DataError( + "XPENDING must be provided with min, max " + "and count parameters, or none of them. 
" + ) + if not isinstance(count, int) or count < -1: + raise DataError("XPENDING count must be a integer >= -1") + pieces.extend((min, max, str(count))) + if consumername is not None: + if min is None or max is None or count is None: + raise DataError( + "if XPENDING is provided with consumername," + " it must be provided with min, max and" + " count parameters" + ) + pieces.append(consumername) + return self.execute_command("XPENDING", *pieces, parse_detail=True) + + def xrange( + self, name: str, min: str = "-", max: str = "+", count: int = None + ) -> Awaitable: + """ + Read stream values within an interval. + name: name of the stream. + start: first stream ID. defaults to '-', + meaning the earliest available. + finish: last stream ID. defaults to '+', + meaning the latest available. + count: if set, only return this many items, beginning with the + earliest available. + """ + pieces: List[EncodableT] = [min, max] + if count is not None: + if not isinstance(count, int) or count < 1: + raise DataError("XRANGE count must be a positive integer") + pieces.append(b"COUNT") + pieces.append(str(count)) + + return self.execute_command("XRANGE", name, *pieces) + + def xread( + self, streams: Dict[str, str], count: int = None, block: int = None + ) -> Awaitable: + """ + Block and monitor multiple streams for new data. + streams: a dict of stream names to stream IDs, where + IDs indicate the last ID already seen. + count: if set, only return this many items, beginning with the + earliest available. + block: number of milliseconds to wait, if nothing already present. + """ + pieces: List[EncodableT] = [] + if block is not None: + if not isinstance(block, int) or block < 0: + raise DataError("XREAD block must be a non-negative integer") + pieces.append(b"BLOCK") + pieces.append(str(block)) + if count is not None: + if not isinstance(count, int) or count < 1: + raise DataError("XREAD count must be a positive integer") + pieces.append(b"COUNT") + pieces.append(str(count)) + if not isinstance(streams, dict) or len(streams) == 0: + raise DataError("XREAD streams must be a non empty dict") + pieces.append(b"STREAMS") + keys, values = zip(*streams.items()) + pieces.extend(keys) + pieces.extend(values) + return self.execute_command("XREAD", *pieces) + + def xreadgroup( + self, + groupname: str, + consumername: str, + streams: Dict[str, str], + count: int = None, + block: int = None, + noack: bool = False, + ) -> Awaitable: + """ + Read from a stream via a consumer group. + groupname: name of the consumer group. + consumername: name of the requesting consumer. + streams: a dict of stream names to stream IDs, where + IDs indicate the last ID already seen. + count: if set, only return this many items, beginning with the + earliest available. + block: number of milliseconds to wait, if nothing already present. 
+        noack: do not add messages to the PEL
+        """
+        pieces: List[EncodableT] = [b"GROUP", groupname, consumername]
+        if count is not None:
+            if not isinstance(count, int) or count < 1:
+                raise DataError("XREADGROUP count must be a positive integer")
+            pieces.append(b"COUNT")
+            pieces.append(str(count))
+        if block is not None:
+            if not isinstance(block, int) or block < 0:
+                raise DataError("XREADGROUP block must be a non-negative integer")
+            pieces.append(b"BLOCK")
+            pieces.append(str(block))
+        if noack:
+            pieces.append(b"NOACK")
+        if not isinstance(streams, dict) or len(streams) == 0:
+            raise DataError("XREADGROUP streams must be a non empty dict")
+        pieces.append(b"STREAMS")
+        pieces.extend(streams.keys())
+        pieces.extend(streams.values())
+        return self.execute_command("XREADGROUP", *pieces)
+
+    def xrevrange(
+        self, name: str, max: str = "+", min: str = "-", count: int = None
+    ) -> Awaitable:
+        """
+        Read stream values within an interval, in reverse order.
+        name: name of the stream
+        max: first stream ID. defaults to '+',
+        meaning the latest available.
+        min: last stream ID. defaults to '-',
+        meaning the earliest available.
+        count: if set, only return this many items, beginning with the
+        latest available.
+        """
+        pieces: List[EncodableT] = [max, min]
+        if count is not None:
+            if not isinstance(count, int) or count < 1:
+                raise DataError("XREVRANGE count must be a positive integer")
+            pieces.append(b"COUNT")
+            pieces.append(str(count))
+
+        return self.execute_command("XREVRANGE", name, *pieces)
+
+    def xtrim(self, name: str, maxlen: int, approximate: bool = True) -> Awaitable:
+        """
+        Trims old messages from a stream.
+        name: name of the stream.
+        maxlen: truncate old stream messages beyond this size
+        approximate: actual stream length may be slightly more than maxlen
+        """
+        pieces: List[EncodableT] = [b"MAXLEN"]
+        if approximate:
+            pieces.append(b"~")
+        pieces.append(maxlen)
+        return self.execute_command("XTRIM", name, *pieces)
+
+    # SORTED SET COMMANDS
+    def zadd(
+        self,
+        name: str,
+        mapping: Mapping[str, EncodableT],
+        nx: bool = False,
+        xx: bool = False,
+        ch: bool = False,
+        incr: bool = False,
+    ) -> Awaitable:
+        """
+        Set any number of element-name, score pairs to the key ``name``. Pairs
+        are specified as a dict of element-names keys to score values.
+
+        ``nx`` forces ZADD to only create new elements and not to update
+        scores for elements that already exist.
+
+        ``xx`` forces ZADD to only update scores of elements that already
+        exist. New elements will not be added.
+
+        ``ch`` modifies the return value to be the numbers of elements changed.
+        Changed elements include new elements that were added and elements
+        whose scores changed.
+
+        ``incr`` modifies ZADD to behave like ZINCRBY. In this mode only a
+        single element/score pair can be specified and the score is the amount
+        the existing score will be incremented by. When using this mode the
+        return value of ZADD will be the new score of the element.
+
+        The return value of ZADD varies based on the mode specified. With no
+        options, ZADD returns the number of new elements added to the sorted
+        set.
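+
+        For example (a sketch that assumes an ``aioredis.Redis`` instance
+        ``redis`` inside a coroutine)::
+
+            await redis.zadd("board", {"alice": 12, "bob": 7})
+            await redis.zadd("board", {"alice": 1}, incr=True)  # new score: 13.0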
+ """ + if not mapping: + raise DataError("ZADD requires at least one element/score pair") + if nx and xx: + raise DataError("ZADD allows either 'nx' or 'xx', not both") + if incr and len(mapping) != 1: + raise DataError( + "ZADD option 'incr' only works when passing a " + "single element/score pair" + ) + pieces: List[EncodableT] = [] + options = {} + if nx: + pieces.append(b"NX") + if xx: + pieces.append(b"XX") + if ch: + pieces.append(b"CH") + if incr: + pieces.append(b"INCR") + options["as_score"] = True + for pair in mapping.items(): + pieces.append(pair[1]) + pieces.append(pair[0]) + return self.execute_command("ZADD", name, *pieces, **options) + + def zcard(self, name: str) -> Awaitable: + """Return the number of elements in the sorted set ``name``""" + return self.execute_command("ZCARD", name) + + def zcount(self, name: str, min: int, max: int) -> Awaitable: + """ + Returns the number of elements in the sorted set at key ``name`` with + a score between ``min`` and ``max``. + """ + return self.execute_command("ZCOUNT", name, min, max) + + def zincrby(self, name: str, amount: float, value: EncodableT) -> Awaitable: + """Increment the score of ``value`` in sorted set ``name`` by ``amount``""" + return self.execute_command("ZINCRBY", name, amount, value) + + def zinterstore( + self, dest: str, keys: Collection[str], aggregate: str = None + ) -> Awaitable: + """ + Intersect multiple sorted sets specified by ``keys`` into + a new sorted set, ``dest``. Scores in the destination will be + aggregated based on the ``aggregate``, or SUM if none is provided. + """ + return self._zaggregate("ZINTERSTORE", dest, keys, aggregate) + + def zlexcount(self, name: str, min: str, max: str) -> Awaitable: + """ + Return the number of items in the sorted set ``name`` between the + lexicographical range ``min`` and ``max``. + """ + return self.execute_command("ZLEXCOUNT", name, min, max) + + def zpopmax(self, name: str, count: int = None) -> Awaitable: + """ + Remove and return up to ``count`` members with the highest scores + from the sorted set ``name``. + """ + args = (count is not None) and [count] or [] + options = {"withscores": True} + return self.execute_command("ZPOPMAX", name, *args, **options) + + def zpopmin(self, name: str, count: int = None) -> Awaitable: + """ + Remove and return up to ``count`` members with the lowest scores + from the sorted set ``name``. + """ + args = (count is not None) and [count] or [] + options = {"withscores": True} + return self.execute_command("ZPOPMIN", name, *args, **options) + + def bzpopmax(self, keys: Collection[str], timeout: int = 0) -> Awaitable: + """ + ZPOPMAX a value off of the first non-empty sorted set + named in the ``keys`` list. + + If none of the sorted sets in ``keys`` has a value to ZPOPMAX, + then block for ``timeout`` seconds, or until a member gets added + to one of the sorted sets. + + If timeout is 0, then block indefinitely. + """ + if timeout is None: + timeout = 0 + keys = list_or_args(keys, None) + keys.append(timeout) + return self.execute_command("BZPOPMAX", *keys) + + def bzpopmin(self, keys: Collection[str], timeout: int = 0) -> Awaitable: + """ + ZPOPMIN a value off of the first non-empty sorted set + named in the ``keys`` list. + + If none of the sorted sets in ``keys`` has a value to ZPOPMIN, + then block for ``timeout`` seconds, or until a member gets added + to one of the sorted sets. + + If timeout is 0, then block indefinitely. 
+ """ + if timeout is None: + timeout = 0 + klist: List[EncodableT] = list_or_args(keys, None) + klist.append(timeout) + return self.execute_command("BZPOPMIN", *klist) + + def zrange( + self, + name: str, + start: int, + end: int, + desc: bool = False, + withscores: bool = False, + score_cast_func: Union[Type, Callable] = float, + ) -> Awaitable: + """ + Return a range of values from sorted set ``name`` between + ``start`` and ``end`` sorted in ascending order. + + ``start`` and ``end`` can be negative, indicating the end of the range. + + ``desc`` a boolean indicating whether to sort the results descendingly + + ``withscores`` indicates to return the scores along with the values. + The return type is a list of (value, score) pairs + + ``score_cast_func`` a callable used to cast the score return value + """ + if desc: + return self.zrevrange(name, start, end, withscores, score_cast_func) + pieces: List[EncodableT] = ["ZRANGE", name, start, end] + if withscores: + pieces.append(b"WITHSCORES") + options = {"withscores": withscores, "score_cast_func": score_cast_func} + return self.execute_command(*pieces, **options) + + def zrangebylex( + self, name: str, min: int, max: int, start: int = None, num: int = None + ) -> Awaitable: + """ + Return the lexicographical range of values from sorted set ``name`` + between ``min`` and ``max``. + + If ``start`` and ``num`` are specified, then return a slice of the + range. + """ + if (start is not None and num is None) or (num is not None and start is None): + raise DataError("``start`` and ``num`` must both be specified") + pieces: List[EncodableT] = ["ZRANGEBYLEX", name, min, max] + if start is not None and num is not None: + pieces.extend([b"LIMIT", start, num]) + return self.execute_command(*pieces) + + def zrevrangebylex( + self, name: str, max: int, min: int, start: int = None, num: int = None + ) -> Awaitable: + """ + Return the reversed lexicographical range of values from sorted set + ``name`` between ``max`` and ``min``. + + If ``start`` and ``num`` are specified, then return a slice of the + range. + """ + if (start is not None and num is None) or (num is not None and start is None): + raise DataError("``start`` and ``num`` must both be specified") + pieces: List[EncodableT] = ["ZREVRANGEBYLEX", name, max, min] + if start is not None and num is not None: + pieces.extend([b"LIMIT", start, num]) + return self.execute_command(*pieces) + + def zrangebyscore( + self, + name: str, + min: int, + max: int, + start: int = None, + num: int = None, + withscores: bool = False, + score_cast_func: Union[Type, Callable] = float, + ) -> Awaitable: + """ + Return a range of values from the sorted set ``name`` with scores + between ``min`` and ``max``. + + If ``start`` and ``num`` are specified, then return a slice + of the range. + + ``withscores`` indicates to return the scores along with the values. 
+        The return type is a list of (value, score) pairs
+
+        ``score_cast_func`` a callable used to cast the score return value
+        """
+        if (start is not None and num is None) or (num is not None and start is None):
+            raise DataError("``start`` and ``num`` must both be specified")
+        pieces: List[EncodableT] = ["ZRANGEBYSCORE", name, min, max]
+        if start is not None and num is not None:
+            pieces.extend([b"LIMIT", start, num])
+        if withscores:
+            pieces.append(b"WITHSCORES")
+        options = {"withscores": withscores, "score_cast_func": score_cast_func}
+        return self.execute_command(*pieces, **options)
+
+    def zrank(self, name: str, value: EncodableT) -> Awaitable:
+        """
+        Returns a 0-based value indicating the rank of ``value`` in sorted set
+        ``name``
+        """
+        return self.execute_command("ZRANK", name, value)
+
+    def zrem(self, name: str, *values: EncodableT) -> Awaitable:
+        """Remove member ``values`` from sorted set ``name``"""
+        return self.execute_command("ZREM", name, *values)
+
+    def zremrangebylex(self, name: str, min: str, max: str) -> Awaitable:
+        """
+        Remove all elements in the sorted set ``name`` between the
+        lexicographical range specified by ``min`` and ``max``.
+
+        Returns the number of elements removed.
+        """
+        return self.execute_command("ZREMRANGEBYLEX", name, min, max)
+
+    def zremrangebyrank(self, name: str, min: int, max: int) -> Awaitable:
+        """
+        Remove all elements in the sorted set ``name`` with ranks between
+        ``min`` and ``max``. Values are 0-based, ordered from smallest score
+        to largest. Values can be negative indicating the highest scores.
+        Returns the number of elements removed
+        """
+        return self.execute_command("ZREMRANGEBYRANK", name, min, max)
+
+    def zremrangebyscore(self, name: str, min: int, max: int) -> Awaitable:
+        """
+        Remove all elements in the sorted set ``name`` with scores
+        between ``min`` and ``max``. Returns the number of elements removed.
+        """
+        return self.execute_command("ZREMRANGEBYSCORE", name, min, max)
+
+    def zrevrange(
+        self,
+        name: str,
+        start: int,
+        end: int,
+        withscores: bool = False,
+        score_cast_func: Union[Type, Callable] = float,
+    ) -> Awaitable:
+        """
+        Return a range of values from sorted set ``name`` between
+        ``start`` and ``end`` sorted in descending order.
+
+        ``start`` and ``end`` can be negative, indicating the end of the range.
+
+        ``withscores`` indicates to return the scores along with the values
+        The return type is a list of (value, score) pairs
+
+        ``score_cast_func`` a callable used to cast the score return value
+        """
+        pieces: List[EncodableT] = ["ZREVRANGE", name, start, end]
+        if withscores:
+            pieces.append(b"WITHSCORES")
+        options = {"withscores": withscores, "score_cast_func": score_cast_func}
+        return self.execute_command(*pieces, **options)
+
+    def zrevrangebyscore(
+        self,
+        name: str,
+        min: int,
+        max: int,
+        start: int = None,
+        num: int = None,
+        withscores: bool = False,
+        score_cast_func: Union[Type, Callable] = float,
+    ) -> Awaitable:
+        """
+        Return a range of values from the sorted set ``name`` with scores
+        between ``min`` and ``max`` in descending order.
+
+        If ``start`` and ``num`` are specified, then return a slice
+        of the range.
+
+        ``withscores`` indicates to return the scores along with the values.
+        The return type is a list of (value, score) pairs
+
+        ``score_cast_func`` a callable used to cast the score return value
+        """
+        if (start is not None and num is None) or (num is not None and start is None):
+            raise DataError("``start`` and ``num`` must both be specified")
+        pieces: List[EncodableT] = ["ZREVRANGEBYSCORE", name, min, max]
+        if start is not None and num is not None:
+            pieces.extend([b"LIMIT", start, num])
+        if withscores:
+            pieces.append(b"WITHSCORES")
+        options = {"withscores": withscores, "score_cast_func": score_cast_func}
+        return self.execute_command(*pieces, **options)
+
+    def zrevrank(self, name: str, value: EncodableT) -> Awaitable:
+        """
+        Returns a 0-based value indicating the descending rank of
+        ``value`` in sorted set ``name``
+        """
+        return self.execute_command("ZREVRANK", name, value)
+
+    def zscore(self, name: str, value: EncodableT) -> Awaitable:
+        """Return the score of element ``value`` in sorted set ``name``"""
+        return self.execute_command("ZSCORE", name, value)
+
+    def zunionstore(
+        self, dest: str, keys: Collection[str], aggregate: str = None
+    ) -> Awaitable:
+        """
+        Union multiple sorted sets specified by ``keys`` into
+        a new sorted set, ``dest``. Scores in the destination will be
+        aggregated based on the ``aggregate``, or SUM if none is provided.
+        """
+        return self._zaggregate("ZUNIONSTORE", dest, keys, aggregate)
+
+    def _zaggregate(
+        self, command: str, dest: str, keys: Collection[str], aggregate: str = None
+    ) -> Awaitable:
+        pieces: List[EncodableT] = [command, dest, len(keys)]
+        if isinstance(keys, dict):
+            keys, weights = keys.keys(), keys.values()
+        else:
+            weights = None
+        pieces.extend(keys)
+        if weights:
+            pieces.append(b"WEIGHTS")
+            pieces.extend(weights)
+        if aggregate:
+            pieces.append(b"AGGREGATE")
+            pieces.append(aggregate)
+        return self.execute_command(*pieces)
+
+    # HYPERLOGLOG COMMANDS
+    def pfadd(self, name: str, *values: EncodableT) -> Awaitable:
+        """Adds the specified elements to the specified HyperLogLog."""
+        return self.execute_command("PFADD", name, *values)
+
+    def pfcount(self, *sources: str) -> Awaitable:
+        """
+        Return the approximated cardinality of
+        the set observed by the HyperLogLog at key(s).
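+
+        For example (a sketch that assumes an ``aioredis.Redis`` instance
+        ``redis`` inside a coroutine)::
+
+            await redis.pfadd("visitors", "a", "b", "c")
+            count = await redis.pfcount("visitors")  # approximately 3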
+ """ + return self.execute_command("PFCOUNT", *sources) + + def pfmerge(self, dest: str, *sources: str) -> Awaitable: + """Merge N different HyperLogLogs into a single one.""" + return self.execute_command("PFMERGE", dest, *sources) + + # HASH COMMANDS + def hdel(self, name: str, *keys: str) -> Awaitable: + """Delete ``keys`` from hash ``name``""" + return self.execute_command("HDEL", name, *keys) + + def hexists(self, name: str, key: str) -> Awaitable: + """Returns a boolean indicating if ``key`` exists within hash ``name``""" + return self.execute_command("HEXISTS", name, key) + + def hget(self, name: str, key: str) -> Awaitable: + """Return the value of ``key`` within the hash ``name``""" + return self.execute_command("HGET", name, key) + + def hgetall(self, name: str) -> Awaitable: + """Return a Python dict of the hash's name/value pairs""" + return self.execute_command("HGETALL", name) + + def hincrby(self, name: str, key: str, amount: int = 1) -> Awaitable: + """Increment the value of ``key`` in hash ``name`` by ``amount``""" + return self.execute_command("HINCRBY", name, key, amount) + + def hincrbyfloat(self, name: str, key: str, amount: float = 1.0) -> Awaitable: + """ + Increment the value of ``key`` in hash ``name`` by floating ``amount`` + """ + return self.execute_command("HINCRBYFLOAT", name, key, amount) + + def hkeys(self, name: str) -> Awaitable: + """Return the list of keys within hash ``name``""" + return self.execute_command("HKEYS", name) + + def hlen(self, name: str) -> Awaitable: + """Return the number of elements in hash ``name``""" + return self.execute_command("HLEN", name) + + def hset( + self, + name: str, + key: str = None, + value: EncodableT = None, + mapping: Mapping[str, EncodableT] = None, + ) -> Awaitable: + """ + Set ``key`` to ``value`` within hash ``name``, + ``mapping`` accepts a dict of key/value pairs that that will be + added to hash ``name``. + Returns the number of fields that were added. + """ + if key is None and not mapping: + raise DataError("'hset' with no key value pairs") + items: List[EncodableT] = [] + if key is not None: + items.extend((key, value)) + if mapping: + for pair in mapping.items(): + items.extend(pair) + + return self.execute_command("HSET", name, *items) + + def hsetnx(self, name: str, key: str, value: EncodableT) -> Awaitable: + """ + Set ``key`` to ``value`` within hash ``name`` if ``key`` does not + exist. Returns 1 if HSETNX created a field, otherwise 0. + """ + return self.execute_command("HSETNX", name, key, value) + + def hmset(self, name: str, mapping: Mapping[str, EncodableT]) -> Awaitable: + """ + Set key to value within hash ``name`` for each corresponding + key and value from the ``mapping`` dict. + """ + warnings.warn( + "%s.hmset() is deprecated. Use %s.hset() instead." 
+ % (self.__class__.__name__, self.__class__.__name__), + DeprecationWarning, + stacklevel=2, + ) + if not mapping: + raise DataError("'hmset' with 'mapping' of length 0") + items = [] + for pair in mapping.items(): + items.extend(pair) + return self.execute_command("HMSET", name, *items) + + def hmget(self, name: str, keys: Sequence[str], *args: str) -> Awaitable: + """Returns a list of values ordered identically to ``keys``""" + args = list_or_args(keys, args) + return self.execute_command("HMGET", name, *args) + + def hvals(self, name: str) -> Awaitable: + """Return the list of values within hash ``name``""" + return self.execute_command("HVALS", name) + + def hstrlen(self, name: str, key: str) -> Awaitable: + """ + Return the number of bytes stored in the value of ``key`` + within hash ``name`` + """ + return self.execute_command("HSTRLEN", name, key) + + def publish(self, channel: str, message: EncodableT) -> Awaitable: + """ + Publish ``message`` on ``channel``. + Returns the number of subscribers the message was delivered to. + """ + return self.execute_command("PUBLISH", channel, message) + + def pubsub_channels(self, pattern: str = "*") -> Awaitable: + """ + Return a list of channels that have at least one subscriber + """ + return self.execute_command("PUBSUB CHANNELS", pattern) + + def pubsub_numpat(self) -> Awaitable: + """ + Returns the number of subscriptions to patterns + """ + return self.execute_command("PUBSUB NUMPAT") + + def pubsub_numsub(self, *args) -> Awaitable: + """ + Return a list of (channel, number of subscribers) tuples + for each channel given in ``*args`` + """ + return self.execute_command("PUBSUB NUMSUB", *args) + + def cluster(self, cluster_arg: str, *args: str) -> Awaitable: + return self.execute_command(f"CLUSTER {cluster_arg.upper()}", *args) + + def eval(self, script: str, numkeys: int, *keys_and_args: str) -> Awaitable: + """ + Execute the Lua ``script``, specifying the ``numkeys`` the script + will touch and the key names and argument values in ``keys_and_args``. + Returns the result of the script. + + In practice, use the object returned by ``register_script``. This + function exists purely for Redis API completion. + """ + return self.execute_command("EVAL", script, numkeys, *keys_and_args) + + def evalsha(self, sha: str, numkeys: int, *keys_and_args: str) -> Awaitable: + """ + Use the ``sha`` to execute a Lua script already registered via EVAL + or SCRIPT LOAD. Specify the ``numkeys`` the script will touch and the + key names and argument values in ``keys_and_args``. Returns the result + of the script. + + In practice, use the object returned by ``register_script``. This + function exists purely for Redis API completion. + """ + return self.execute_command("EVALSHA", sha, numkeys, *keys_and_args) + + def script_exists(self, *args: str) -> Awaitable: + """ + Check if a script exists in the script cache by specifying the SHAs of + each script as ``args``. Returns a list of boolean values indicating if + if each already script exists in the cache. + """ + return self.execute_command("SCRIPT EXISTS", *args) + + def script_flush(self) -> Awaitable: + """Flush all scripts from the script cache""" + return self.execute_command("SCRIPT FLUSH") + + def script_kill(self) -> Awaitable: + """Kill the currently executing Lua script""" + return self.execute_command("SCRIPT KILL") + + def script_load(self, script: str) -> Awaitable: + """Load a Lua ``script`` into the script cache. 
Returns the SHA."""
+        return self.execute_command("SCRIPT LOAD", script)
+
+    def register_script(self, script: str) -> Script:
+        """
+        Register a Lua ``script`` specifying the ``keys`` it will touch.
+        Returns a Script object that is callable and hides the complexity of
+        dealing with scripts, keys, and shas. This is the preferred way to
+        work with Lua scripts.
+        """
+        return Script(self, script)
+
+    # GEO COMMANDS
+    def geoadd(self, name: str, *values: EncodableT) -> Awaitable:
+        """
+        Add the specified geospatial items to the specified key identified
+        by the ``name`` argument. The geospatial items are given as ordered
+        members of the ``values`` argument; each place is formed by the
+        triad longitude, latitude and name.
+        """
+        if len(values) % 3 != 0:
+            raise DataError("GEOADD requires places with lon, lat and name values")
+        return self.execute_command("GEOADD", name, *values)
+
+    def geodist(
+        self, name: str, place1: str, place2: str, unit: str = None
+    ) -> Awaitable:
+        """
+        Return the distance between ``place1`` and ``place2`` members of the
+        ``name`` key.
+        The unit must be one of the following: m, km, mi, ft. By default
+        meters are used.
+        """
+        pieces: List[EncodableT] = [name, place1, place2]
+        if unit and unit not in ("m", "km", "mi", "ft"):
+            raise DataError("GEODIST invalid unit")
+        elif unit:
+            pieces.append(unit)
+        return self.execute_command("GEODIST", *pieces)
+
+    def geohash(self, name: str, *values: EncodableT) -> Awaitable:
+        """
+        Return the geo hash string for each item of ``values`` members of
+        the specified key identified by the ``name`` argument.
+        """
+        return self.execute_command("GEOHASH", name, *values)
+
+    def geopos(self, name: str, *values: EncodableT) -> Awaitable:
+        """
+        Return the positions of each item of ``values`` as members of
+        the specified key identified by the ``name`` argument. Each position
+        is represented by a (longitude, latitude) pair.
+        """
+        return self.execute_command("GEOPOS", name, *values)
+
+    def georadius(
+        self,
+        name: str,
+        longitude: float,
+        latitude: float,
+        radius: float,
+        unit: str = None,
+        withdist: bool = False,
+        withcoord: bool = False,
+        withhash: bool = False,
+        count: int = None,
+        sort: str = None,
+        store: str = None,
+        store_dist: str = None,
+    ) -> Awaitable:
+        """
+        Return the members of the specified key identified by the
+        ``name`` argument which are within the borders of the area specified
+        with the ``latitude`` and ``longitude`` location and the maximum
+        distance from the center specified by the ``radius`` value.
+
+        The unit must be one of the following: m, km, mi, ft. By default
+        meters are used.
+
+        ``withdist`` indicates to return the distances of each place.
+
+        ``withcoord`` indicates to return the latitude and longitude of
+        each place.
+
+        ``withhash`` indicates to return the geohash string of each place.
+
+        ``count`` indicates to return up to N matching elements.
+
+        ``sort`` indicates to return the places in a sorted way, ASC for
+        nearest to farthest and DESC for farthest to nearest.
+
+        ``store`` indicates to save the place names in a sorted set named
+        with a specific key; each element of the destination sorted set is
+        populated with the score taken from the original geo sorted set.
+
+        ``store_dist`` indicates to save the place names in a sorted set
+        named with a specific key; unlike ``store``, the destination score
+        is set to the distance.
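+
+        An illustrative sketch (assuming a connected client ``redis`` and a
+        ``Sicily`` geo set populated via ``geoadd``; the names are
+        hypothetical):
+
+            >>> await redis.geoadd("Sicily", 13.361389, 38.115556, "Palermo")
+            >>> await redis.georadius(
+            ...     "Sicily", 15, 37, 200, unit="km", withdist=True
+            ... )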
+ """ + return self._georadiusgeneric( + "GEORADIUS", + name, + longitude, + latitude, + radius, + unit=unit, + withdist=withdist, + withcoord=withcoord, + withhash=withhash, + count=count, + sort=sort, + store=store, + store_dist=store_dist, + ) + + def georadiusbymember( + self, + name: str, + member: str, + radius: float, + unit: str = None, + withdist: bool = False, + withcoord: bool = False, + withhash: bool = False, + count: bool = None, + sort: str = None, + store: str = None, + store_dist: str = None, + ) -> Awaitable: + """ + This command is exactly like ``georadius`` with the sole difference + that instead of taking, as the center of the area to query, a longitude + and latitude value, it takes the name of a member already existing + inside the geospatial index represented by the sorted set. + """ + return self._georadiusgeneric( + "GEORADIUSBYMEMBER", + name, + member, + radius, + unit=unit, + withdist=withdist, + withcoord=withcoord, + withhash=withhash, + count=count, + sort=sort, + store=store, + store_dist=store_dist, + ) + + def _georadiusgeneric( + self, command: str, *args: EncodableT, **kwargs: EncodableT + ) -> Awaitable: + pieces: List[EncodableT] = list(args) + if kwargs["unit"] and kwargs["unit"] not in ("m", "km", "mi", "ft"): + raise DataError("GEORADIUS invalid unit") + elif kwargs["unit"]: + pieces.append(kwargs["unit"]) + else: + pieces.append( + "m", + ) + + for arg_name, byte_repr in ( + ("withdist", b"WITHDIST"), + ("withcoord", b"WITHCOORD"), + ("withhash", b"WITHHASH"), + ): + if kwargs[arg_name]: + pieces.append(byte_repr) + + if kwargs["count"]: + pieces.extend([b"COUNT", kwargs["count"]]) + + if kwargs["sort"]: + if kwargs["sort"] == "ASC": + pieces.append(b"ASC") + elif kwargs["sort"] == "DESC": + pieces.append(b"DESC") + else: + raise DataError("GEORADIUS invalid sort") + + if kwargs["store"] and kwargs["store_dist"]: + raise DataError("GEORADIUS store and store_dist cant be set" " together") + + if kwargs["store"]: + pieces.extend([b"STORE", kwargs["store"]]) + + if kwargs["store_dist"]: + pieces.extend([b"STOREDIST", kwargs["store_dist"]]) + + return self.execute_command(command, *pieces, **kwargs) + + # MODULE COMMANDS + def module_load(self, path: str) -> Awaitable: + """ + Loads the module from ``path``. + Raises ``ModuleError`` if a module is not found at ``path``. + """ + return self.execute_command("MODULE LOAD", path) + + def module_unload(self, name: str) -> Awaitable: + """ + Unloads the module ``name``. + Raises ``ModuleError`` if ``name`` is not in loaded modules. + """ + return self.execute_command("MODULE UNLOAD", name) + + def module_list(self) -> Awaitable: + """ + Returns a list of dictionaries containing the name and version of + all loaded modules. + """ + return self.execute_command("MODULE LIST") + + +StrictRedis = Redis + + +class MonitorCommandInfo(TypedDict): + time: float + db: int + client_address: str + client_port: str + client_type: str + command: str + + +class Monitor: + """ + Monitor is useful for handling the MONITOR command to the redis server. + next_command() method returns one command from monitor + listen() method yields commands from monitor. + """ + + monitor_re = re.compile(r"\[(\d+) (.*)\] (.*)") + command_re = re.compile(r'"(.*?)(? 
MonitorCommandInfo: + """Parse the response from a monitor command""" + await self.connect() + response = await self.connection.read_response() + if isinstance(response, bytes): + response = self.connection.encoder.decode(response, force=True) + command_time, command_data = response.split(" ", 1) + m = self.monitor_re.match(command_data) + db_id, client_info, command = m.groups() + command = " ".join(self.command_re.findall(command)) + # Redis escapes double quotes because each piece of the command + # string is surrounded by double quotes. We don't have that + # requirement so remove the escaping and leave the quote. + command = command.replace('\\"', '"') + + if client_info == "lua": + client_address = "lua" + client_port = "" + client_type = "lua" + elif client_info.startswith("unix"): + client_address = "unix" + client_port = client_info[5:] + client_type = "unix" + else: + # use rsplit as ipv6 addresses contain colons + client_address, client_port = client_info.rsplit(":", 1) + client_type = "tcp" + return { + "time": float(command_time), + "db": int(db_id), + "client_address": client_address, + "client_port": client_port, + "client_type": client_type, + "command": command, + } + + async def listen(self) -> AsyncIterator[MonitorCommandInfo]: + """Listen for commands coming to the server.""" + while True: + yield await self.next_command() + + +class PubSub: + """ + PubSub provides publish, subscribe and listen support to Redis channels. + + After subscribing to one or more channels, the listen() method will block + until a message arrives on one of the subscribed channels. That message + will be returned and it's safe to start listening again. + """ + + PUBLISH_MESSAGE_TYPES = ("message", "pmessage") + UNSUBSCRIBE_MESSAGE_TYPES = ("unsubscribe", "punsubscribe") + HEALTH_CHECK_MESSAGE = "redis-py-health-check" + + def __init__( + self, + connection_pool: ConnectionPool, + shard_hint: str = None, + ignore_subscribe_messages: bool = False, + ): + self.connection_pool = connection_pool + self.shard_hint = shard_hint + self.ignore_subscribe_messages = ignore_subscribe_messages + self.connection = None + # we need to know the encoding options for this connection in order + # to lookup channel and pattern names for callback handlers. 
+ self.encoder = self.connection_pool.get_encoder() + if self.encoder.decode_responses: + self.health_check_response = ["pong", self.HEALTH_CHECK_MESSAGE] + else: + self.health_check_response = [ + b"pong", + self.encoder.encode(self.HEALTH_CHECK_MESSAGE), + ] + self.channels = {} + self.pending_unsubscribe_channels = set() + self.patterns = {} + self.pending_unsubscribe_patterns = set() + self._lock = asyncio.Lock() + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.reset() + + def __del__(self): + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(self.close()) + else: + loop.run_until_complete(self.close()) + except Exception: + pass + + async def reset(self): + async with self._lock: + if self.connection: + await self.connection.disconnect() + self.connection.clear_connect_callbacks() + await self.connection_pool.release(self.connection) + self.connection = None + self.channels = {} + self.pending_unsubscribe_channels = set() + self.patterns = {} + self.pending_unsubscribe_patterns = set() + + def close(self) -> Awaitable[NoReturn]: + return self.reset() + + async def on_connect(self, connection: Connection): + """Re-subscribe to any channels and patterns previously subscribed to""" + # NOTE: for python3, we can't pass bytestrings as keyword arguments + # so we need to decode channel/pattern names back to unicode strings + # before passing them to [p]subscribe. + self.pending_unsubscribe_channels.clear() + self.pending_unsubscribe_patterns.clear() + if self.channels: + channels = {} + for k, v in self.channels.items(): + channels[self.encoder.decode(k, force=True)] = v + await self.subscribe(**channels) + if self.patterns: + patterns = {} + for k, v in self.patterns.items(): + patterns[self.encoder.decode(k, force=True)] = v + await self.psubscribe(**patterns) + + @property + def subscribed(self): + """Indicates if there are subscriptions to any channels or patterns""" + return bool(self.channels or self.patterns) + + async def execute_command(self, *args: str): + """Execute a publish/subscribe command""" + + # NOTE: don't parse the response in this function -- it could pull a + # legitimate message off the stack if the connection is already + # subscribed to one or more channels + + if self.connection is None: + self.connection = await self.connection_pool.get_connection( + "pubsub", self.shard_hint + ) + # register a callback that re-subscribes to any channels we + # were listening to when we were disconnected + self.connection.register_connect_callback(self.on_connect) + connection = self.connection + kwargs = {"check_health": not self.subscribed} + await self._execute(connection, connection.send_command, *args, **kwargs) + + async def _execute(self, connection, command, *args, **kwargs): + try: + return await command(*args, **kwargs) + except (ConnectionError, TimeoutError) as e: + await connection.disconnect() + if not (connection.retry_on_timeout and isinstance(e, TimeoutError)): + raise + # Connect manually here. If the Redis server is down, this will + # fail and raise a ConnectionError as desired. 
+ await connection.connect() + # the ``on_connect`` callback should haven been called by the + # connection to resubscribe us to any channels and patterns we were + # previously listening to + return await command(*args, **kwargs) + + async def parse_response(self, block: bool = True, timeout: int = 0): + """Parse the response from a publish/subscribe command""" + conn = self.connection + if conn is None: + raise RuntimeError( + "pubsub connection not set: " + "did you forget to call subscribe() or psubscribe()?" + ) + + await self.check_health() + + if not block and not await conn.can_read(timeout=timeout): + return None + response = await self._execute(conn, conn.read_response) + + if conn.health_check_interval and response == self.health_check_response: + # ignore the health check message as user might not expect it + return None + return response + + async def check_health(self): + conn = self.connection + if conn is None: + raise RuntimeError( + "pubsub connection not set: " + "did you forget to call subscribe() or psubscribe()?" + ) + + if conn.health_check_interval and time.time() > conn.next_health_check: + await conn.send_command( + "PING", self.HEALTH_CHECK_MESSAGE, check_health=False + ) + + def _normalize_keys(self, data: Mapping[EncodableT, EncodableT]): + """ + normalize channel/pattern names to be either bytes or strings + based on whether responses are automatically decoded. this saves us + from coercing the value for each message coming in. + """ + encode = self.encoder.encode + decode = self.encoder.decode + return {decode(encode(k)): v for k, v in data.items()} + + async def psubscribe(self, *args: Union[str, bytes], **kwargs: EncodableT): + """ + Subscribe to channel patterns. Patterns supplied as keyword arguments + expect a pattern name as the key and a callable as the value. A + pattern's callable will be invoked automatically when a message is + received on that pattern rather than producing a message via + ``listen()``. + """ + if args: + args = list_or_args(args[0], args[1:]) + new_patterns: Dict[Union[str, bytes], EncodableT] = dict.fromkeys(args) + new_patterns.update(kwargs) + ret_val = await self.execute_command("PSUBSCRIBE", *new_patterns.keys()) + # update the patterns dict AFTER we send the command. we don't want to + # subscribe twice to these patterns, once for the command and again + # for the reconnection. + new_patterns = self._normalize_keys(new_patterns) + self.patterns.update(new_patterns) + self.pending_unsubscribe_patterns.difference_update(new_patterns) + return ret_val + + def punsubscribe(self, *args: EncodableT) -> Awaitable: + """ + Unsubscribe from the supplied patterns. If empty, unsubscribe from + all patterns. + """ + if args: + args = list_or_args(args[0], args[1:]) + patterns = self._normalize_keys(dict.fromkeys(args)) + else: + patterns = self.patterns + self.pending_unsubscribe_patterns.update(patterns) + return self.execute_command("PUNSUBSCRIBE", *args) + + async def subscribe(self, *args, **kwargs): + """ + Subscribe to channels. Channels supplied as keyword arguments expect + a channel name as the key and a callable as the value. A channel's + callable will be invoked automatically when a message is received on + that channel rather than producing a message via ``listen()`` or + ``get_message()``. 
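+
+        An illustrative sketch (assuming ``pubsub`` was obtained from the
+        client's ``pubsub()`` factory; the names are hypothetical):
+
+            >>> def handler(message):
+            ...     print(message["data"])
+            >>> await pubsub.subscribe("alerts", news=handler)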
+ """ + if args: + args = list_or_args(args[0], args[1:]) + new_channels = dict.fromkeys(args) + new_channels.update(kwargs) + ret_val = await self.execute_command("SUBSCRIBE", *new_channels.keys()) + # update the channels dict AFTER we send the command. we don't want to + # subscribe twice to these channels, once for the command and again + # for the reconnection. + new_channels = self._normalize_keys(new_channels) + self.channels.update(new_channels) + self.pending_unsubscribe_channels.difference_update(new_channels) + return ret_val + + def unsubscribe(self, *args) -> Awaitable: + """ + Unsubscribe from the supplied channels. If empty, unsubscribe from + all channels + """ + if args: + args = list_or_args(args[0], args[1:]) + channels = self._normalize_keys(dict.fromkeys(args)) + else: + channels = self.channels + self.pending_unsubscribe_channels.update(channels) + return self.execute_command("UNSUBSCRIBE", *args) + + async def listen(self) -> AsyncIterator: + """Listen for messages on channels this client has been subscribed to""" + while self.subscribed: + response = self.handle_message(await self.parse_response(block=True)) + if response is not None: + yield response + + async def get_message( + self, ignore_subscribe_messages: bool = False, timeout: float = 0.0 + ): + """ + Get the next message if one is available, otherwise None. + + If timeout is specified, the system will wait for `timeout` seconds + before returning. Timeout should be specified as a floating point + number. + """ + response = await self.parse_response(block=False, timeout=timeout) + if response: + return self.handle_message(response, ignore_subscribe_messages) + return None + + def ping(self, message=None) -> Awaitable: + """ + Ping the Redis server + """ + message = "" if message is None else message + return self.execute_command("PING", message) + + def handle_message(self, response, ignore_subscribe_messages=False): + """ + Parses a pub/sub message. If the channel or pattern was subscribed to + with a message handler, the handler is invoked instead of a parsed + message being returned. + """ + message_type = str_if_bytes(response[0]) + if message_type == "pmessage": + message = { + "type": message_type, + "pattern": response[1], + "channel": response[2], + "data": response[3], + } + elif message_type == "pong": + message = { + "type": message_type, + "pattern": None, + "channel": None, + "data": response[1], + } + else: + message = { + "type": message_type, + "pattern": None, + "channel": response[1], + "data": response[2], + } + + # if this is an unsubscribe message, remove it from memory + if message_type in self.UNSUBSCRIBE_MESSAGE_TYPES: + if message_type == "punsubscribe": + pattern = response[1] + if pattern in self.pending_unsubscribe_patterns: + self.pending_unsubscribe_patterns.remove(pattern) + self.patterns.pop(pattern, None) + else: + channel = response[1] + if channel in self.pending_unsubscribe_channels: + self.pending_unsubscribe_channels.remove(channel) + self.channels.pop(channel, None) + + if message_type in self.PUBLISH_MESSAGE_TYPES: + # if there's a message handler, invoke it + if message_type == "pmessage": + handler = self.patterns.get(message["pattern"], None) + else: + handler = self.channels.get(message["channel"], None) + if handler: + handler(message) + return None + elif message_type != "pong": + # this is a subscribe/unsubscribe message. 
+                # ignore it if we don't want subscribe/unsubscribe messages
+                if ignore_subscribe_messages or self.ignore_subscribe_messages:
+                    return None
+
+        return message
+
+    def run_in_thread(
+        self, daemon: bool = False, exception_handler: Callable = None
+    ) -> PubSubWorkerThread:
+        for channel, handler in self.channels.items():
+            if handler is None:
+                raise PubSubError("Channel: '%s' has no handler registered" % channel)
+        for pattern, handler in self.patterns.items():
+            if handler is None:
+                raise PubSubError("Pattern: '%s' has no handler registered" % pattern)
+
+        thread = PubSubWorkerThread(
+            self, daemon=daemon, exception_handler=exception_handler
+        )
+        thread.start()
+        return thread
+
+
+class PubsubWorkerExceptionHandler(Protocol):
+    def __call__(self, e: BaseException, pubsub: PubSub, t: PubSubWorkerThread):
+        ...
+
+
+class AsyncPubsubWorkerExceptionHandler(Protocol):
+    async def __call__(self, e: BaseException, pubsub: PubSub, t: PubSubWorkerThread):
+        ...
+
+
+PSWorkerThreadExcHandlerT = Union[
+    PubsubWorkerExceptionHandler, AsyncPubsubWorkerExceptionHandler
+]
+
+
+class PubSubWorkerThread(threading.Thread):
+    def __init__(
+        self,
+        pubsub: PubSub,
+        daemon: bool = False,
+        poll_timeout: float = 1.0,
+        exception_handler: PSWorkerThreadExcHandlerT = None,
+    ):
+        super().__init__()
+        self.daemon = daemon
+        self.pubsub = pubsub
+        self.poll_timeout = poll_timeout
+        self.exception_handler = exception_handler
+        self._running = threading.Event()
+        # Make sure we have the current thread loop before we
+        # fork into the new thread. If no loop has been set on the connection
+        # pool, use the current default event loop.
+        self.loop = pubsub.connection_pool.loop or asyncio.get_event_loop()
+
+    async def _run(self):
+        pubsub = self.pubsub
+        while self._running.is_set():
+            try:
+                await pubsub.get_message(
+                    ignore_subscribe_messages=True, timeout=self.poll_timeout
+                )
+            except BaseException as e:
+                if self.exception_handler is None:
+                    raise
+                res = self.exception_handler(e, pubsub, self)
+                if inspect.isawaitable(res):
+                    await res
+        await pubsub.close()
+
+    def run(self):
+        if self._running.is_set():
+            return
+        self._running.set()
+        future = asyncio.run_coroutine_threadsafe(self._run(), self.loop)
+        return future.result()
+
+    def stop(self):
+        # trip the flag so the run loop exits. the run loop will
+        # close the pubsub connection, which disconnects the socket
+        # and returns the connection to the pool.
+        self._running.clear()
+
+
+CommandT = Tuple[Tuple[Union[str, bytes], ...], Mapping[str, Any]]
+CommandStackT = List[CommandT]
+
+
+class Pipeline(Redis):
+    """
+    Pipelines provide a way to transmit multiple commands to the Redis server
+    in one transmission. This is convenient for batch processing, such as
+    saving all the values in a list to Redis.
+
+    All commands executed within a pipeline are wrapped with MULTI and EXEC
+    calls. This guarantees all commands executed in the pipeline will be
+    executed atomically.
+
+    Any command raising an exception does *not* halt the execution of
+    subsequent commands in the pipeline. Instead, the exception is caught
+    and its instance is placed into the response list returned by execute().
+    Code iterating over the response list should be able to deal with an
+    instance of an exception as a potential value. In general, these will be
+    ResponseError exceptions, such as those raised when issuing a command
+    on a key of a different datatype.
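+
+    An illustrative sketch (assuming the client exposes a ``pipeline()``
+    factory, as in redis-py; the names are hypothetical):
+
+        >>> pipe = redis.pipeline(transaction=True)
+        >>> results = await pipe.set("foo", "bar").incr("counter").execute()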
+ """ + + UNWATCH_COMMANDS = {"DISCARD", "EXEC", "UNWATCH"} + + def __init__( + self, + connection_pool: ConnectionPool, + response_callbacks: Mapping[str, ResponseCallbackT], + transaction: bool, + shard_hint: Optional[str], + ): + self.connection_pool = connection_pool + self.connection = None + self.response_callbacks = response_callbacks + self.transaction = transaction + self.shard_hint = shard_hint + self.watching = False + self.command_stack = [] + self.scripts = set() + self.explicit_transaction = False + + async def __aenter__(self) -> Pipeline: + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.reset() + + def __del__(self): + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + loop.create_task(self.reset()) + else: + loop.run_until_complete(self.reset()) + except Exception: + pass + + def __len__(self): + return len(self.command_stack) + + def __bool__(self): + """Pipeline instances should always evaluate to True""" + return True + + async def reset(self): + self.command_stack = [] + self.scripts = set() + # make sure to reset the connection state in the event that we were + # watching something + if self.watching and self.connection: + try: + # call this manually since our unwatch or + # immediate_execute_command methods can call reset() + await self.connection.send_command("UNWATCH") + await self.connection.read_response() + except ConnectionError: + # disconnect will also remove any previous WATCHes + self.connection.disconnect() + # clean up the other instance attributes + self.watching = False + self.explicit_transaction = False + # we can safely return the connection to the pool here since we're + # sure we're no longer WATCHing anything + if self.connection: + await self.connection_pool.release(self.connection) + self.connection = None + + def multi(self): + """ + Start a transactional block of the pipeline after WATCH commands + are issued. End the transactional block with `execute`. + """ + if self.explicit_transaction: + raise RedisError("Cannot issue nested calls to MULTI") + if self.command_stack: + raise RedisError( + "Commands without an initial WATCH have already " "been issued" + ) + self.explicit_transaction = True + + def execute_command(self, *args, **kwargs) -> Union[Pipeline, Awaitable[Pipeline]]: + if (self.watching or args[0] == "WATCH") and not self.explicit_transaction: + return self.immediate_execute_command(*args, **kwargs) + return self.pipeline_execute_command(*args, **kwargs) + + async def immediate_execute_command(self, *args, **options): + """ + Execute a command immediately, but don't auto-retry on a + ConnectionError if we're already WATCHing a variable. Used when + issuing WATCH or subsequent commands retrieving their values but before + MULTI is called. + """ + command_name = args[0] + conn = self.connection + # if this is the first call, we need a connection + if not conn: + conn = await self.connection_pool.get_connection( + command_name, self.shard_hint + ) + self.connection = conn + try: + await conn.send_command(*args) + return await self.parse_response(conn, command_name, **options) + except (ConnectionError, TimeoutError) as e: + await conn.disconnect() + # if we were already watching a variable, the watch is no longer + # valid since this connection has died. raise a WatchError, which + # indicates the user should retry this transaction. 
+            if self.watching:
+                await self.reset()
+                raise WatchError(
+                    "A ConnectionError occurred while watching one or more keys"
+                ) from e
+            # if retry_on_timeout is not set, or the error is not
+            # a TimeoutError, raise it
+            if not (conn.retry_on_timeout and isinstance(e, TimeoutError)):
+                await self.reset()
+                raise
+
+            # retry_on_timeout is set, this is a TimeoutError and we are not
+            # already WATCHing any variables. retry the command.
+            try:
+                await conn.send_command(*args)
+                return await self.parse_response(conn, command_name, **options)
+            except (ConnectionError, TimeoutError):
+                # a subsequent failure should simply be raised
+                await self.reset()
+                raise
+        except asyncio.CancelledError:
+            await conn.disconnect()
+            raise
+
+    def pipeline_execute_command(self, *args, **options):
+        """
+        Stage a command to be executed when execute() is next called
+
+        Returns the current Pipeline object back so commands can be
+        chained together, such as:
+
+        pipe = pipe.set('foo', 'bar').incr('baz').decr('bang')
+
+        At some other point, you can then run: pipe.execute(),
+        which will execute all commands queued in the pipe.
+        """
+        self.command_stack.append((args, options))
+        return self
+
+    async def _execute_transaction(
+        self, connection: Connection, commands: CommandStackT, raise_on_error
+    ):
+        cmds = chain([(("MULTI",), {})], commands, [(("EXEC",), {})])
+        all_cmds = connection.pack_commands(
+            [args for args, options in cmds if EMPTY_RESPONSE not in options]
+        )
+        await connection.send_packed_command(all_cmds)
+        errors = []
+
+        # parse off the response for MULTI
+        # NOTE: we need to handle ResponseErrors here and continue
+        # so that we read all the additional command messages from
+        # the socket
+        try:
+            await self.parse_response(connection, "_")
+        except ResponseError as err:
+            errors.append((0, err))
+
+        # and all the other commands
+        for i, command in enumerate(commands):
+            if EMPTY_RESPONSE in command[1]:
+                errors.append((i, command[1][EMPTY_RESPONSE]))
+            else:
+                try:
+                    await self.parse_response(connection, "_")
+                except ResponseError as err:
+                    self.annotate_exception(err, i + 1, command[0])
+                    errors.append((i, err))
+
+        # parse the EXEC.
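+        # EXEC replies with either an array containing one reply per queued
+        # command, or a null reply when a WATCHed key changed and the
+        # transaction was aborted (surfaced below as a WatchError).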
+ try: + response = await self.parse_response(connection, "_") + except ExecAbortError as err: + if errors: + raise errors[0][1] from err + raise + + # EXEC clears any watched keys + self.watching = False + + if response is None: + raise WatchError("Watched variable changed.") from None + + # put any parse errors into the response + for i, e in errors: + response.insert(i, e) + + if len(response) != len(commands): + await self.connection.disconnect() + raise ResponseError( + "Wrong number of response items from pipeline execution" + ) from None + + # find any errors in the response and raise if necessary + if raise_on_error: + self.raise_first_error(commands, response) + + # We have to run response callbacks manually + data = [] + for r, cmd in zip(response, commands): + if not isinstance(r, Exception): + args, options = cmd + command_name = args[0] + if command_name in self.response_callbacks: + r = self.response_callbacks[command_name](r, **options) + if inspect.isawaitable(r): + r = await r + data.append(r) + return data + + async def _execute_pipeline( + self, connection: Connection, commands: CommandStackT, raise_on_error: bool + ): + # build up all commands into a single request to increase network perf + all_cmds = connection.pack_commands([args for args, _ in commands]) + await connection.send_packed_command(all_cmds) + + response = [] + for args, options in commands: + try: + response.append( + await self.parse_response(connection, args[0], **options) + ) + except ResponseError as e: + response.append(e) + + if raise_on_error: + self.raise_first_error(commands, response) + return response + + def raise_first_error(self, commands: CommandStackT, response: Iterable[Any]): + for i, r in enumerate(response): + if isinstance(r, ResponseError): + self.annotate_exception(r, i + 1, commands[i][0]) + raise r + + def annotate_exception(self, exception: Exception, number: int, command: str): + cmd = " ".join(map(safe_str, command)) + msg = f"Command # {number} ({cmd}) of pipeline caused error: {exception.args}" + exception.args = (msg,) + exception.args[1:] + + def parse_response( + self, connection: Connection, command_name: Union[str, bytes], **options + ): + result = Redis.parse_response(self, connection, command_name, **options) + if command_name in self.UNWATCH_COMMANDS: + self.watching = False + elif command_name == "WATCH": + self.watching = True + return result + + async def load_scripts(self): + # make sure all scripts that are about to be run on this pipeline exist + scripts = list(self.scripts) + immediate = self.immediate_execute_command + shas = [s.sha for s in scripts] + # we can't use the normal script_* methods because they would just + # get buffered in the pipeline. 
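+        # SCRIPT EXISTS replies with one flag per sha; any script whose flag
+        # is falsy is then registered via SCRIPT LOAD before the pipeline runs.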
+        exists = await immediate("SCRIPT EXISTS", *shas)
+        if not all(exists):
+            for s, exist in zip(scripts, exists):
+                if not exist:
+                    s.sha = await immediate("SCRIPT LOAD", s.script)
+
+    async def execute(self, raise_on_error: bool = True):
+        """Execute all the commands in the current pipeline"""
+        stack = self.command_stack
+        if not stack and not self.watching:
+            return []
+        if self.scripts:
+            await self.load_scripts()
+        if self.transaction or self.explicit_transaction:
+            execute = self._execute_transaction
+        else:
+            execute = self._execute_pipeline
+
+        conn = self.connection
+        if not conn:
+            conn = await self.connection_pool.get_connection("MULTI", self.shard_hint)
+            # assign to self.connection so reset() releases the connection
+            # back to the pool after we're done
+            self.connection = conn
+
+        try:
+            return await execute(conn, stack, raise_on_error)
+        except (ConnectionError, TimeoutError) as e:
+            await conn.disconnect()
+            # if we were watching a variable, the watch is no longer valid
+            # since this connection has died. raise a WatchError, which
+            # indicates the user should retry this transaction.
+            if self.watching:
+                raise WatchError(
+                    "A ConnectionError occurred while watching one or more keys"
+                ) from e
+            # if retry_on_timeout is not set, or the error is not
+            # a TimeoutError, raise it
+            if not (conn.retry_on_timeout and isinstance(e, TimeoutError)):
+                raise
+            # retry a TimeoutError when retry_on_timeout is set
+            return await execute(conn, stack, raise_on_error)
+        finally:
+            await self.reset()
+
+    async def watch(self, *names: str):
+        """Watches the values at keys ``names``"""
+        if self.explicit_transaction:
+            raise RedisError("Cannot issue a WATCH after a MULTI")
+        return await self.execute_command("WATCH", *names)
+
+    async def unwatch(self):
+        """Unwatches all previously specified keys"""
+        return self.watching and await self.execute_command("UNWATCH") or True
+
+
+class Script:
+    """An executable Lua script object returned by ``register_script``"""
+
+    def __init__(self, registered_client: Redis, script: AnyStr):
+        self.registered_client = registered_client
+        self.script: AnyStr = script
+        # Precalculate and store the SHA1 hex digest of the script.
+
+        if isinstance(script, str):
+            # We need the encoding from the client in order to generate an
+            # accurate byte representation of the script
+            encoder = registered_client.connection_pool.get_encoder()
+            script = encoder.encode(script)
+        self.sha = hashlib.sha1(script).hexdigest()
+
+    async def __call__(
+        self,
+        keys: Collection[str] = None,
+        args: Iterable[str] = None,
+        client: Redis = None,
+    ):
+        """Execute the script, passing any required ``args``"""
+        keys = keys or []
+        args = args or []
+        if client is None:
+            client = self.registered_client
+        args = tuple(keys) + tuple(args)
+        # make sure the Redis server knows about the script
+        if isinstance(client, Pipeline):
+            # Make sure the pipeline can register the script before executing.
+            client.scripts.add(self)
+            return client.evalsha(self.sha, len(keys), *args)
+        try:
+            return await client.evalsha(self.sha, len(keys), *args)
+        except NoScriptError:
+            # Maybe the client is pointed to a different server than the
+            # client that created this instance?
+            # Overwrite the sha just in case there was a discrepancy.
+            self.sha = await client.script_load(self.script)
+            return await client.evalsha(self.sha, len(keys), *args)
+
+
+class BitFieldOperation:
+    """
+    Command builder for BITFIELD commands.
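+
+    An illustrative sketch (assuming the client exposes a ``bitfield()``
+    factory returning this builder, as in redis-py; names are hypothetical):
+
+        >>> bf = redis.bitfield("stats")
+        >>> await bf.set("u8", "#0", 255).get("u8", "#0").execute()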
+ """ + + def __init__(self, client: Redis, key: str, default_overflow: str = None): + self.client = client + self.key = key + self._default_overflow = default_overflow + self.operations: List[Tuple[EncodableT, ...]] = [] + self._last_overflow = "WRAP" + self.reset() + + def reset(self): + """ + Reset the state of the instance to when it was constructed + """ + self.operations = [] + self._last_overflow = "WRAP" + self.overflow(self._default_overflow or self._last_overflow) + + def overflow(self, overflow: str): + """ + Update the overflow algorithm of successive INCRBY operations + :param overflow: Overflow algorithm, one of WRAP, SAT, FAIL. See the + Redis docs for descriptions of these algorithmsself. + :returns: a :py:class:`BitFieldOperation` instance. + """ + overflow = overflow.upper() + if overflow != self._last_overflow: + self._last_overflow = overflow + self.operations.append(("OVERFLOW", overflow)) + return self + + def incrby(self, fmt: str, offset: str, increment: int, overflow: str = None): + """ + Increment a bitfield by a given amount. + :param fmt: format-string for the bitfield being updated, e.g. 'u8' + for an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :param int increment: value to increment the bitfield by. + :param str overflow: overflow algorithm. Defaults to WRAP, but other + acceptable values are SAT and FAIL. See the Redis docs for + descriptions of these algorithms. + :returns: a :py:class:`BitFieldOperation` instance. + """ + if overflow is not None: + self.overflow(overflow) + + self.operations.append(("INCRBY", fmt, offset, increment)) + return self + + def get(self, fmt: str, offset: EncodableT): + """ + Get the value of a given bitfield. + :param fmt: format-string for the bitfield being read, e.g. 'u8' for + an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :returns: a :py:class:`BitFieldOperation` instance. + """ + self.operations.append(("GET", fmt, offset)) + return self + + def set(self, fmt: str, offset: EncodableT, value: int): + """ + Set the value of a given bitfield. + :param fmt: format-string for the bitfield being read, e.g. 'u8' for + an unsigned 8-bit integer. + :param offset: offset (in number of bits). If prefixed with a + '#', this is an offset multiplier, e.g. given the arguments + fmt='u8', offset='#2', the offset will be 16. + :param int value: value to set at the given position. + :returns: a :py:class:`BitFieldOperation` instance. + """ + self.operations.append(("SET", fmt, offset, value)) + return self + + @property + def command(self): + cmd = ["BITFIELD", self.key] + for ops in self.operations: + cmd.extend(ops) + return cmd + + def execute(self): + """ + Execute the operation(s) in a single BITFIELD command. The return value + is a list of values corresponding to each operation. If the client + used to create this instance was a pipeline, the list of values + will be present within the pipeline's execute. 
+ """ + command = self.command + self.reset() + return self.client.execute_command(*command) diff --git a/aioredis/commands/__init__.py b/aioredis/commands/__init__.py deleted file mode 100644 index 7350a219e..000000000 --- a/aioredis/commands/__init__.py +++ /dev/null @@ -1,240 +0,0 @@ -from aioredis.abc import AbcPool -from aioredis.connection import create_connection -from aioredis.pool import create_pool -from aioredis.util import _NOTSET, wait_ok - -from .cluster import ClusterCommandsMixin -from .generic import GenericCommandsMixin -from .geo import GeoCommandsMixin, GeoMember, GeoPoint -from .hash import HashCommandsMixin -from .hyperloglog import HyperLogLogCommandsMixin -from .list import ListCommandsMixin -from .pubsub import PubSubCommandsMixin -from .scripting import ScriptingCommandsMixin -from .server import ServerCommandsMixin -from .set import SetCommandsMixin -from .sorted_set import SortedSetCommandsMixin -from .streams import StreamCommandsMixin -from .string import StringCommandsMixin -from .transaction import MultiExec, Pipeline, TransactionsCommandsMixin - -__all__ = [ - "create_redis", - "create_redis_pool", - "Redis", - "Pipeline", - "MultiExec", - "GeoPoint", - "GeoMember", -] - - -class Redis( - GenericCommandsMixin, - StringCommandsMixin, - HyperLogLogCommandsMixin, - SetCommandsMixin, - HashCommandsMixin, - TransactionsCommandsMixin, - SortedSetCommandsMixin, - ListCommandsMixin, - ScriptingCommandsMixin, - ServerCommandsMixin, - PubSubCommandsMixin, - ClusterCommandsMixin, - GeoCommandsMixin, - StreamCommandsMixin, -): - """High-level Redis interface. - - Gathers in one place Redis commands implemented in mixins. - - For commands details see: http://redis.io/commands/#connection - """ - - def __init__(self, pool_or_conn): - self._pool_or_conn = pool_or_conn - - def __repr__(self): - return f"<{self.__class__.__name__} {self._pool_or_conn!r}>" - - def execute(self, command, *args, **kwargs): - return self._pool_or_conn.execute(command, *args, **kwargs) - - def close(self): - """Close client connections.""" - self._pool_or_conn.close() - - async def wait_closed(self): - """Coroutine waiting until underlying connections are closed.""" - await self._pool_or_conn.wait_closed() - - @property - def db(self): - """Currently selected db index.""" - return self._pool_or_conn.db - - @property - def encoding(self): - """Current set codec or None.""" - return self._pool_or_conn.encoding - - @property - def connection(self): - """Either :class:`aioredis.RedisConnection`, - or :class:`aioredis.ConnectionsPool` instance. - """ - return self._pool_or_conn - - @property - def address(self): - """Redis connection address (if applicable).""" - return self._pool_or_conn.address - - @property - def in_transaction(self): - """Set to True when MULTI command was issued.""" - # XXX: this must be bound to real connection - return self._pool_or_conn.in_transaction - - @property - def closed(self): - """True if connection is closed.""" - return self._pool_or_conn.closed - - def auth(self, password): - """Authenticate to server. - - This method wraps call to :meth:`aioredis.RedisConnection.auth()` - """ - return self._pool_or_conn.auth(password) - - def echo(self, message, *, encoding=_NOTSET): - """Echo the given string.""" - return self.execute("ECHO", message, encoding=encoding) - - def ping(self, message=_NOTSET, *, encoding=_NOTSET): - """Ping the server. - - Accept optional echo message. 
- """ - if message is not _NOTSET: - args = (message,) - else: - args = () - return self.execute("PING", *args, encoding=encoding) - - def quit(self): - """Close the connection.""" - # TODO: warn when using pool - return self.execute("QUIT") - - def select(self, db): - """Change the selected database.""" - return self._pool_or_conn.select(db) - - def swapdb(self, from_index, to_index): - return wait_ok(self.execute(b"SWAPDB", from_index, to_index)) - - def __await__(self): - if isinstance(self._pool_or_conn, AbcPool): - conn = yield from self._pool_or_conn.acquire().__await__() - release = self._pool_or_conn.release - else: - # TODO: probably a lock is needed here if _pool_or_conn - # is Connection instance. - conn = self._pool_or_conn - release = None - return ContextRedis(conn, release) - - -class ContextRedis(Redis): - """An instance of Redis class bound to single connection.""" - - def __init__(self, conn, release_cb=None): - super().__init__(conn) - self._release_callback = release_cb - - def __enter__(self): - return self - - def __exit__(self, *exc_info): - if self._release_callback is not None: - conn, self._pool_or_conn = self._pool_or_conn, None - self._release_callback(conn) - - def __await__(self): - return ContextRedis(self._pool_or_conn) - yield - - -async def create_redis( - address, - *, - db=None, - password=None, - ssl=None, - encoding=None, - commands_factory=Redis, - parser=None, - timeout=None, - connection_cls=None, - loop=None, - name=None, -): - """Creates high-level Redis interface. - - This function is a coroutine. - """ - conn = await create_connection( - address, - db=db, - password=password, - ssl=ssl, - encoding=encoding, - parser=parser, - timeout=timeout, - connection_cls=connection_cls, - loop=loop, - name=name, - ) - return commands_factory(conn) - - -async def create_redis_pool( - address, - *, - db=None, - password=None, - ssl=None, - encoding=None, - commands_factory=Redis, - minsize=1, - maxsize=10, - parser=None, - timeout=None, - pool_cls=None, - connection_cls=None, - loop=None, - name=None, -): - """Creates high-level Redis interface. - - This function is a coroutine. - """ - pool = await create_pool( - address, - db=db, - password=password, - ssl=ssl, - encoding=encoding, - minsize=minsize, - maxsize=maxsize, - parser=parser, - create_connection_timeout=timeout, - pool_cls=pool_cls, - connection_cls=connection_cls, - loop=loop, - name=name, - ) - return commands_factory(pool) diff --git a/aioredis/commands/cluster.py b/aioredis/commands/cluster.py deleted file mode 100644 index eed078495..000000000 --- a/aioredis/commands/cluster.py +++ /dev/null @@ -1,101 +0,0 @@ -from aioredis.util import wait_ok - - -class ClusterCommandsMixin: - """Cluster commands mixin. 
- - For commands details see: http://redis.io/commands#cluster - """ - - def cluster_add_slots(self, slot, *slots): - """Assign new hash slots to receiving node.""" - slots = (slot,) + slots - if not all(isinstance(s, int) for s in slots): - raise TypeError("All parameters must be of type int") - fut = self.execute(b"CLUSTER", b"ADDSLOTS", *slots) - return wait_ok(fut) - - def cluster_count_failure_reports(self, node_id): - """Return the number of failure reports active for a given node.""" - return self.execute(b"CLUSTER", b"COUNT-FAILURE-REPORTS", node_id) - - def cluster_count_key_in_slots(self, slot): - """Return the number of local keys in the specified hash slot.""" - if not isinstance(slot, int): - raise TypeError( - "Expected slot to be of type int, got {}".format(type(slot)) - ) - return self.execute(b"CLUSTER", b"COUNTKEYSINSLOT", slot) - - def cluster_del_slots(self, slot, *slots): - """Set hash slots as unbound in receiving node.""" - slots = (slot,) + slots - if not all(isinstance(s, int) for s in slots): - raise TypeError("All parameters must be of type int") - fut = self.execute(b"CLUSTER", b"DELSLOTS", *slots) - return wait_ok(fut) - - def cluster_failover(self): - """Forces a slave to perform a manual failover of its master.""" - pass # TODO: Implement - - def cluster_forget(self, node_id): - """Remove a node from the nodes table.""" - fut = self.execute(b"CLUSTER", b"FORGET", node_id) - return wait_ok(fut) - - def cluster_get_keys_in_slots(self, slot, count, *, encoding): - """Return local key names in the specified hash slot.""" - return self.execute( - b"CLUSTER", b"GETKEYSINSLOT", slot, count, encoding=encoding - ) - - def cluster_info(self): - """Provides info about Redis Cluster node state.""" - pass # TODO: Implement - - def cluster_keyslot(self, key): - """Returns the hash slot of the specified key.""" - return self.execute(b"CLUSTER", b"KEYSLOT", key) - - def cluster_meet(self, ip, port): - """Force a node cluster to handshake with another node.""" - fut = self.execute(b"CLUSTER", b"MEET", ip, port) - return wait_ok(fut) - - def cluster_nodes(self): - """Get Cluster config for the node.""" - pass # TODO: Implement - - def cluster_replicate(self, node_id): - """Reconfigure a node as a slave of the specified master node.""" - fut = self.execute(b"CLUSTER", b"REPLICATE", node_id) - return wait_ok(fut) - - def cluster_reset(self, *, hard=False): - """Reset a Redis Cluster node.""" - reset = hard and b"HARD" or b"SOFT" - fut = self.execute(b"CLUSTER", b"RESET", reset) - return wait_ok(fut) - - def cluster_save_config(self): - """Force the node to save cluster state on disk.""" - fut = self.execute(b"CLUSTER", b"SAVECONFIG") - return wait_ok(fut) - - def cluster_set_config_epoch(self, config_epoch): - """Set the configuration epoch in a new node.""" - fut = self.execute(b"CLUSTER", b"SET-CONFIG-EPOCH", config_epoch) - return wait_ok(fut) - - def cluster_setslot(self, slot, command, node_id): - """Bind a hash slot to specified node.""" - pass # TODO: Implement - - def cluster_slaves(self, node_id): - """List slave nodes of the specified master node.""" - pass # TODO: Implement - - def cluster_slots(self): - """Get array of Cluster slot to node mappings.""" - pass # TODO: Implement diff --git a/aioredis/commands/generic.py b/aioredis/commands/generic.py deleted file mode 100644 index 204d696ac..000000000 --- a/aioredis/commands/generic.py +++ /dev/null @@ -1,308 +0,0 @@ -from aioredis.util import _NOTSET, _ScanIter, wait_convert, wait_ok - - -class GenericCommandsMixin: - 
"""Generic commands mixin. - - For commands details see: http://redis.io/commands/#generic - """ - - def delete(self, key, *keys): - """Delete a key.""" - fut = self.execute(b"DEL", key, *keys) - return wait_convert(fut, int) - - def dump(self, key): - """Dump a key.""" - return self.execute(b"DUMP", key) - - def exists(self, key, *keys): - """Check if key(s) exists. - - .. versionchanged:: v0.2.9 - Accept multiple keys; **return** type **changed** from bool to int. - """ - return self.execute(b"EXISTS", key, *keys) - - def expire(self, key, timeout): - """Set a timeout on key. - - if timeout is float it will be multiplied by 1000 - coerced to int and passed to `pexpire` method. - - Otherwise raises TypeError if timeout argument is not int. - """ - if isinstance(timeout, float): - return self.pexpire(key, int(timeout * 1000)) - if not isinstance(timeout, int): - raise TypeError(f"timeout argument must be int, not {timeout!r}") - fut = self.execute(b"EXPIRE", key, timeout) - return wait_convert(fut, bool) - - def expireat(self, key, timestamp): - """Set expire timestamp on a key. - - if timeout is float it will be multiplied by 1000 - coerced to int and passed to `pexpireat` method. - - Otherwise raises TypeError if timestamp argument is not int. - """ - if isinstance(timestamp, float): - return self.pexpireat(key, int(timestamp * 1000)) - if not isinstance(timestamp, int): - raise TypeError(f"timestamp argument must be int, not {timestamp!r}") - fut = self.execute(b"EXPIREAT", key, timestamp) - return wait_convert(fut, bool) - - def keys(self, pattern, *, encoding=_NOTSET): - """Returns all keys matching pattern.""" - return self.execute(b"KEYS", pattern, encoding=encoding) - - def migrate(self, host, port, key, dest_db, timeout, *, copy=False, replace=False): - """Atomically transfer a key from a Redis instance to another one.""" - if not isinstance(host, str): - raise TypeError("host argument must be str") - if not isinstance(timeout, int): - raise TypeError("timeout argument must be int") - if not isinstance(dest_db, int): - raise TypeError("dest_db argument must be int") - if not host: - raise ValueError("Got empty host") - if dest_db < 0: - raise ValueError("dest_db must be greater equal 0") - if timeout < 0: - raise ValueError("timeout must be greater equal 0") - - flags = [] - if copy: - flags.append(b"COPY") - if replace: - flags.append(b"REPLACE") - fut = self.execute(b"MIGRATE", host, port, key, dest_db, timeout, *flags) - return wait_ok(fut) - - def migrate_keys( - self, host, port, keys, dest_db, timeout, *, copy=False, replace=False - ): - """Atomically transfer keys from one Redis instance to another one. - - Keys argument must be list/tuple of keys to migrate. 
- """ - if not isinstance(host, str): - raise TypeError("host argument must be str") - if not isinstance(timeout, int): - raise TypeError("timeout argument must be int") - if not isinstance(dest_db, int): - raise TypeError("dest_db argument must be int") - if not isinstance(keys, (list, tuple)): - raise TypeError("keys argument must be list or tuple") - if not host: - raise ValueError("Got empty host") - if dest_db < 0: - raise ValueError("dest_db must be greater equal 0") - if timeout < 0: - raise ValueError("timeout must be greater equal 0") - if not keys: - raise ValueError("keys must not be empty") - - flags = [] - if copy: - flags.append(b"COPY") - if replace: - flags.append(b"REPLACE") - flags.append(b"KEYS") - flags.extend(keys) - fut = self.execute(b"MIGRATE", host, port, "", dest_db, timeout, *flags) - return wait_ok(fut) - - def move(self, key, db): - """Move key from currently selected database to specified destination. - - :raises TypeError: if db is not int - :raises ValueError: if db is less than 0 - """ - if not isinstance(db, int): - raise TypeError(f"db argument must be int, not {db!r}") - if db < 0: - raise ValueError(f"db argument must be not less than 0, {db!r}") - fut = self.execute(b"MOVE", key, db) - return wait_convert(fut, bool) - - def object_refcount(self, key): - """Returns the number of references of the value associated - with the specified key (OBJECT REFCOUNT). - """ - return self.execute(b"OBJECT", b"REFCOUNT", key) - - def object_encoding(self, key): - """Returns the kind of internal representation used in order - to store the value associated with a key (OBJECT ENCODING). - """ - return self.execute(b"OBJECT", b"ENCODING", key, encoding="utf-8") - - def object_idletime(self, key): - """Returns the number of seconds since the object is not requested - by read or write operations (OBJECT IDLETIME). - """ - return self.execute(b"OBJECT", b"IDLETIME", key) - - def persist(self, key): - """Remove the existing timeout on key.""" - fut = self.execute(b"PERSIST", key) - return wait_convert(fut, bool) - - def pexpire(self, key, timeout): - """Set a milliseconds timeout on key. - - :raises TypeError: if timeout is not int - """ - if not isinstance(timeout, int): - raise TypeError(f"timeout argument must be int, not {timeout!r}") - fut = self.execute(b"PEXPIRE", key, timeout) - return wait_convert(fut, bool) - - def pexpireat(self, key, timestamp): - """Set expire timestamp on key, timestamp in milliseconds. - - :raises TypeError: if timeout is not int - """ - if not isinstance(timestamp, int): - raise TypeError(f"timestamp argument must be int, not {timestamp!r}") - fut = self.execute(b"PEXPIREAT", key, timestamp) - return wait_convert(fut, bool) - - def pttl(self, key): - """Returns time-to-live for a key, in milliseconds. - - Special return values (starting with Redis 2.8): - - * command returns -2 if the key does not exist. - * command returns -1 if the key exists but has no associated expire. - """ - # TODO: maybe convert negative values to: - # -2 to None - no key - # -1 to False - no expire - return self.execute(b"PTTL", key) - - def randomkey(self, *, encoding=_NOTSET): - """Return a random key from the currently selected database.""" - return self.execute(b"RANDOMKEY", encoding=encoding) - - def rename(self, key, newkey): - """Renames key to newkey. 
- - :raises ValueError: if key == newkey - """ - if key == newkey: - raise ValueError("key and newkey are the same") - fut = self.execute(b"RENAME", key, newkey) - return wait_ok(fut) - - def renamenx(self, key, newkey): - """Renames key to newkey only if newkey does not exist. - - :raises ValueError: if key == newkey - """ - if key == newkey: - raise ValueError("key and newkey are the same") - fut = self.execute(b"RENAMENX", key, newkey) - return wait_convert(fut, bool) - - def restore(self, key, ttl, value): - """Creates a key associated with a value that is obtained via DUMP.""" - return self.execute(b"RESTORE", key, ttl, value) - - def scan(self, cursor=0, match=None, count=None, key_type=None): - """Incrementally iterate the keys space. - - Usage example: - - >>> match = 'something*' - >>> cur = b'0' - >>> while cur: - ... cur, keys = await redis.scan(cur, match=match) - ... for key in keys: - ... print('Matched:', key) - - """ - args = [] - if match is not None: - args += [b"MATCH", match] - if count is not None: - args += [b"COUNT", count] - if key_type is not None: - args += [b"TYPE", key_type] - fut = self.execute(b"SCAN", cursor, *args) - return wait_convert(fut, lambda o: (int(o[0]), o[1])) - - def iscan(self, *, match=None, count=None): - """Incrementally iterate the keys space using async for. - - Usage example: - - >>> async for key in redis.iscan(match='something*'): - ... print('Matched:', key) - - """ - return _ScanIter(lambda cur: self.scan(cur, match=match, count=count)) - - def sort( - self, - key, - *get_patterns, - by=None, - offset=None, - count=None, - asc=None, - alpha=False, - store=None, - ): - """Sort the elements in a list, set or sorted set.""" - args = [] - if by is not None: - args += [b"BY", by] - if offset is not None and count is not None: - args += [b"LIMIT", offset, count] - if get_patterns: - args += sum(([b"GET", pattern] for pattern in get_patterns), []) - if asc is not None: - args += [asc is True and b"ASC" or b"DESC"] - if alpha: - args += [b"ALPHA"] - if store is not None: - args += [b"STORE", store] - return self.execute(b"SORT", key, *args) - - def touch(self, key, *keys): - """Alters the last access time of a key(s). - - Returns the number of keys that were touched. - """ - return self.execute(b"TOUCH", key, *keys) - - def ttl(self, key): - """Returns time-to-live for a key, in seconds. - - Special return values (starting with Redis 2.8): - * command returns -2 if the key does not exist. - * command returns -1 if the key exists but has no associated expire. - """ - # TODO: maybe convert negative values to: - # -2 to None - no key - # -1 to False - no expire - return self.execute(b"TTL", key) - - def type(self, key): - """Returns the string representation of the value's type stored at key.""" - # NOTE: for non-existent keys TYPE returns b'none' - return self.execute(b"TYPE", key) - - def unlink(self, key, *keys): - """Delete a key asynchronously in another thread.""" - return wait_convert(self.execute(b"UNLINK", key, *keys), int) - - def wait(self, numslaves, timeout): - """Wait for the synchronous replication of all the write - commands sent in the context of the current connection. 
- """ - return self.execute(b"WAIT", numslaves, timeout) diff --git a/aioredis/commands/geo.py b/aioredis/commands/geo.py deleted file mode 100644 index 7c5e83dde..000000000 --- a/aioredis/commands/geo.py +++ /dev/null @@ -1,225 +0,0 @@ -from collections import namedtuple - -from aioredis.util import _NOTSET, wait_convert - -GeoPoint = namedtuple("GeoPoint", ("longitude", "latitude")) -GeoMember = namedtuple("GeoMember", ("member", "dist", "hash", "coord")) - - -class GeoCommandsMixin: - """Geo commands mixin. - - For commands details see: http://redis.io/commands#geo - """ - - def geoadd(self, key, longitude, latitude, member, *args, **kwargs): - """Add one or more geospatial items in the geospatial index represented - using a sorted set. - - :rtype: int - """ - return self.execute( - b"GEOADD", key, longitude, latitude, member, *args, **kwargs - ) - - def geohash(self, key, member, *members, **kwargs): - """Returns members of a geospatial index as standard geohash strings. - - :rtype: list[str or bytes or None] - """ - return self.execute(b"GEOHASH", key, member, *members, **kwargs) - - def geopos(self, key, member, *members, **kwargs): - """Returns longitude and latitude of members of a geospatial index. - - :rtype: list[GeoPoint or None] - """ - fut = self.execute(b"GEOPOS", key, member, *members, **kwargs) - return wait_convert(fut, make_geopos) - - def geodist(self, key, member1, member2, unit="m"): - """Returns the distance between two members of a geospatial index. - - :rtype: list[float or None] - """ - fut = self.execute(b"GEODIST", key, member1, member2, unit) - return wait_convert(fut, make_geodist) - - def georadius( - self, - key, - longitude, - latitude, - radius, - unit="m", - *, - with_dist=False, - with_hash=False, - with_coord=False, - count=None, - sort=None, - encoding=_NOTSET - ): - """Query a sorted set representing a geospatial index to fetch members - matching a given maximum distance from a point. - - Return value follows Redis convention: - - * if none of ``WITH*`` flags are set -- list of strings returned: - - >>> await redis.georadius('Sicily', 15, 37, 200, 'km') - [b"Palermo", b"Catania"] - - * if any flag (or all) is set -- list of named tuples returned: - - >>> await redis.georadius('Sicily', 15, 37, 200, 'km', - ... with_dist=True) - [GeoMember(name=b"Palermo", dist=190.4424, hash=None, coord=None), - GeoMember(name=b"Catania", dist=56.4413, hash=None, coord=None)] - - :raises TypeError: radius is not float or int - :raises TypeError: count is not int - :raises ValueError: if unit not equal ``m``, ``km``, ``mi`` or ``ft`` - :raises ValueError: if sort not equal ``ASC`` or ``DESC`` - - :rtype: list[str] or list[GeoMember] - """ - args = validate_georadius_options( - radius, unit, with_dist, with_hash, with_coord, count, sort - ) - - fut = self.execute( - b"GEORADIUS", - key, - longitude, - latitude, - radius, - unit, - *args, - encoding=encoding - ) - if with_dist or with_hash or with_coord: - return wait_convert( - fut, - make_geomember, - with_dist=with_dist, - with_hash=with_hash, - with_coord=with_coord, - ) - return fut - - def georadiusbymember( - self, - key, - member, - radius, - unit="m", - *, - with_dist=False, - with_hash=False, - with_coord=False, - count=None, - sort=None, - encoding=_NOTSET - ): - """Query a sorted set representing a geospatial index to fetch members - matching a given maximum distance from a member. 
-
-        Return value follows Redis convention:
-
-        * if none of ``WITH*`` flags are set -- list of strings returned:
-
-            >>> await redis.georadiusbymember('Sicily', 'Palermo', 200, 'km')
-            [b"Palermo", b"Catania"]
-
-        * if any flag (or all) is set -- list of named tuples returned:
-
-            >>> await redis.georadiusbymember('Sicily', 'Palermo', 200, 'km',
-            ...                               with_dist=True)
-            [GeoMember(name=b"Palermo", dist=190.4424, hash=None, coord=None),
-             GeoMember(name=b"Catania", dist=56.4413, hash=None, coord=None)]
-
-        :raises TypeError: radius is not float or int
-        :raises TypeError: count is not int
-        :raises ValueError: if unit not equal ``m``, ``km``, ``mi`` or ``ft``
-        :raises ValueError: if sort not equal ``ASC`` or ``DESC``
-
-        :rtype: list[str] or list[GeoMember]
-        """
-        args = validate_georadius_options(
-            radius, unit, with_dist, with_hash, with_coord, count, sort
-        )
-
-        fut = self.execute(
-            b"GEORADIUSBYMEMBER", key, member, radius, unit, *args, encoding=encoding
-        )
-        if with_dist or with_hash or with_coord:
-            return wait_convert(
-                fut,
-                make_geomember,
-                with_dist=with_dist,
-                with_hash=with_hash,
-                with_coord=with_coord,
-            )
-        return fut
-
-
-def validate_georadius_options(
-    radius, unit, with_dist, with_hash, with_coord, count, sort
-):
-    args = []
-
-    if with_dist:
-        args.append(b"WITHDIST")
-    if with_hash:
-        args.append(b"WITHHASH")
-    if with_coord:
-        args.append(b"WITHCOORD")
-
-    if unit not in ["m", "km", "mi", "ft"]:
-        raise ValueError("unit argument must be 'm', 'km', 'mi' or 'ft'")
-    if not isinstance(radius, (int, float)):
-        raise TypeError("radius argument must be int or float")
-    if count:
-        if not isinstance(count, int):
-            raise TypeError("count argument must be int")
-        args += [b"COUNT", count]
-    if sort:
-        if sort not in ["ASC", "DESC"]:
-            raise ValueError("sort argument must be 'ASC' or 'DESC'")
-        args.append(sort)
-    return args
-
-
-def make_geocoord(value):
-    if isinstance(value, list):
-        return GeoPoint(*map(float, value))
-    return value
-
-
-def make_geodist(value):
-    if value:
-        return float(value)
-    return value
-
-
-def make_geopos(value):
-    return [make_geocoord(val) for val in value]
-
-
-def make_geomember(value, with_dist, with_coord, with_hash):
-    res_rows = []
-
-    for row in value:
-        name = row.pop(0)
-        dist = hash_ = coord = None
-        if with_dist:
-            dist = float(row.pop(0))
-        if with_hash:
-            hash_ = int(row.pop(0))
-        if with_coord:
-            coord = GeoPoint(*map(float, row.pop(0)))
-
-        res_rows.append(GeoMember(name, dist, hash_, coord))
-
-    return res_rows
diff --git a/aioredis/commands/hash.py b/aioredis/commands/hash.py
deleted file mode 100644
index 8edfacf10..000000000
--- a/aioredis/commands/hash.py
+++ /dev/null
@@ -1,177 +0,0 @@
-import warnings
-from itertools import chain
-
-from aioredis.util import _NOTSET, _ScanIter, wait_convert, wait_make_dict, wait_ok
-
-
-class HashCommandsMixin:
-    """Hash commands mixin.
- - For commands details see: http://redis.io/commands#hash - """ - - def hdel(self, key, field, *fields): - """Delete one or more hash fields.""" - return self.execute(b"HDEL", key, field, *fields) - - def hexists(self, key, field): - """Determine if hash field exists.""" - fut = self.execute(b"HEXISTS", key, field) - return wait_convert(fut, bool) - - def hget(self, key, field, *, encoding=_NOTSET): - """Get the value of a hash field.""" - return self.execute(b"HGET", key, field, encoding=encoding) - - def hgetall(self, key, *, encoding=_NOTSET): - """Get all the fields and values in a hash.""" - fut = self.execute(b"HGETALL", key, encoding=encoding) - return wait_make_dict(fut) - - def hincrby(self, key, field, increment=1): - """Increment the integer value of a hash field by the given number.""" - return self.execute(b"HINCRBY", key, field, increment) - - def hincrbyfloat(self, key, field, increment=1.0): - """Increment the float value of a hash field by the given number.""" - fut = self.execute(b"HINCRBYFLOAT", key, field, increment) - return wait_convert(fut, float) - - def hkeys(self, key, *, encoding=_NOTSET): - """Get all the fields in a hash.""" - return self.execute(b"HKEYS", key, encoding=encoding) - - def hlen(self, key): - """Get the number of fields in a hash.""" - return self.execute(b"HLEN", key) - - def hmget(self, key, field, *fields, encoding=_NOTSET): - """Get the values of all the given fields.""" - return self.execute(b"HMGET", key, field, *fields, encoding=encoding) - - def hmset(self, key, field, value, *pairs): - """Set multiple hash fields to multiple values. - - .. deprecated:: - HMSET is deprecated since redis 4.0.0, use HSET instead. - - """ - warnings.warn( - "%s.hmset() is deprecated since redis 4.0.0, use %s.hset() instead" - % (self.__class__.__name__, self.__class__.__name__), - DeprecationWarning, - ) - if len(pairs) % 2 != 0: - raise TypeError("length of pairs must be even number") - return wait_ok(self.execute(b"HMSET", key, field, value, *pairs)) - - def hmset_dict(self, key, *args, **kwargs): - """Set multiple hash fields to multiple values. - - .. deprecated:: - HMSET is deprecated since redis 4.0.0, use HSET instead. - - dict can be passed as first positional argument: - - >>> await redis.hmset_dict( - ... 'key', {'field1': 'value1', 'field2': 'value2'}) - - or keyword arguments can be used: - - >>> await redis.hmset_dict( - ... 'key', field1='value1', field2='value2') - - or dict argument can be mixed with kwargs: - - >>> await redis.hmset_dict( - ... 'key', {'field1': 'value1'}, field2='value2') - - .. 
note:: ``dict`` and ``kwargs`` are not merged into a single dictionary;
-           if both are specified and both have the same key(s) -- ``kwargs`` wins:
-
-            >>> await redis.hmset_dict('key', {'foo': 'bar'}, foo='baz')
-            >>> await redis.hget('key', 'foo', encoding='utf-8')
-            'baz'
-
-        """
-        warnings.warn(
-            "%s.hmset() is deprecated since redis 4.0.0, use %s.hset() instead"
-            % (self.__class__.__name__, self.__class__.__name__),
-            DeprecationWarning,
-        )
-
-        if not args and not kwargs:
-            raise TypeError("args or kwargs must be specified")
-        pairs = ()
-        if len(args) > 1:
-            raise TypeError("single positional argument allowed")
-        elif len(args) == 1:
-            if not isinstance(args[0], dict):
-                raise TypeError("args[0] must be dict")
-            elif not args[0] and not kwargs:
-                raise ValueError("args[0] is empty dict")
-            pairs = chain.from_iterable(args[0].items())
-        kwargs_pairs = chain.from_iterable(kwargs.items())
-        return wait_ok(self.execute(b"HMSET", key, *chain(pairs, kwargs_pairs)))
-
-    def hset(self, key, field=None, value=None, mapping=None):
-        """Set multiple hash fields to multiple values.
-
-        Setting a single hash field to a value:
-        >>> await redis.hset('key', 'some_field', 'some_value')
-
-        Setting values for multiple hash fields at once:
-        >>> await redis.hset('key', mapping={'field1': 'abc', 'field2': 'def'})
-
-        .. note:: Using both the field/value pair and mapping at the same time
-           will also work.
-
-        """
-        if not field and not mapping:
-            raise ValueError("hset needs either a field/value pair or mapping")
-        if mapping and not isinstance(mapping, dict):
-            raise TypeError("'mapping' should be dict")
-
-        items = []
-        if field:
-            items.extend((field, value))
-        if mapping:
-            for item in mapping.items():
-                items.extend(item)
-        return self.execute(b"HSET", key, *items)
-
-    def hsetnx(self, key, field, value):
-        """Set the value of a hash field, only if the field does not exist."""
-        return self.execute(b"HSETNX", key, field, value)
-
-    def hvals(self, key, *, encoding=_NOTSET):
-        """Get all the values in a hash."""
-        return self.execute(b"HVALS", key, encoding=encoding)
-
-    def hscan(self, key, cursor=0, match=None, count=None):
-        """Incrementally iterate hash fields and associated values."""
-        args = [key, cursor]
-        match is not None and args.extend([b"MATCH", match])
-        count is not None and args.extend([b"COUNT", count])
-        fut = self.execute(b"HSCAN", *args)
-        return wait_convert(fut, _make_pairs)
-
-    def ihscan(self, key, *, match=None, count=None):
-        """Incrementally iterate hash fields and values using async for.
-
-        Usage example:
-
-        >>> async for name, val in redis.ihscan(key, match='something*'):
-        ...     print('Matched:', name, '->', val)
-
-        """
-        return _ScanIter(lambda cur: self.hscan(key, cur, match=match, count=count))
-
-    def hstrlen(self, key, field):
-        """Get the length of the value of a hash field."""
-        return self.execute(b"HSTRLEN", key, field)
-
-
-def _make_pairs(obj):
-    it = iter(obj[1])
-    return (int(obj[0]), list(zip(it, it)))
diff --git a/aioredis/commands/hyperloglog.py b/aioredis/commands/hyperloglog.py
deleted file mode 100644
index 7b2a9d992..000000000
--- a/aioredis/commands/hyperloglog.py
+++ /dev/null
@@ -1,23 +0,0 @@
-from aioredis.util import wait_ok
-
-
-class HyperLogLogCommandsMixin:
-    """HyperLogLog commands mixin.
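The hset() wrapper above is the replacement the deprecation warnings point to; a short sketch of both call forms against the removed 1.x API, local server and key names assumed:

    async def save_user(redis):
        # single field/value pair
        await redis.hset('user:1', 'name', 'alice')
        # several fields at once via mapping=, replacing deprecated hmset()
        await redis.hset('user:1', mapping={'age': 30, 'city': 'Oslo'})
        return await redis.hgetall('user:1', encoding='utf-8')
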
- - For commands details see: http://redis.io/commands#hyperloglog - """ - - def pfadd(self, key, value, *values): - """Adds the specified elements to the specified HyperLogLog.""" - return self.execute(b"PFADD", key, value, *values) - - def pfcount(self, key, *keys): - """Return the approximated cardinality of - the set(s) observed by the HyperLogLog at key(s). - """ - return self.execute(b"PFCOUNT", key, *keys) - - def pfmerge(self, destkey, sourcekey, *sourcekeys): - """Merge N different HyperLogLogs into a single one.""" - fut = self.execute(b"PFMERGE", destkey, sourcekey, *sourcekeys) - return wait_ok(fut) diff --git a/aioredis/commands/list.py b/aioredis/commands/list.py deleted file mode 100644 index f86206890..000000000 --- a/aioredis/commands/list.py +++ /dev/null @@ -1,153 +0,0 @@ -from aioredis.util import _NOTSET, wait_ok - - -class ListCommandsMixin: - """List commands mixin. - - For commands details see: http://redis.io/commands#list - """ - - def blpop(self, key, *keys, timeout=0, encoding=_NOTSET): - """Remove and get the first element in a list, or block until - one is available. - - :raises TypeError: if timeout is not int - :raises ValueError: if timeout is less than 0 - """ - if not isinstance(timeout, int): - raise TypeError("timeout argument must be int") - if timeout < 0: - raise ValueError("timeout must be greater equal 0") - args = keys + (timeout,) - return self.execute(b"BLPOP", key, *args, encoding=encoding) - - def brpop(self, key, *keys, timeout=0, encoding=_NOTSET): - """Remove and get the last element in a list, or block until one - is available. - - :raises TypeError: if timeout is not int - :raises ValueError: if timeout is less than 0 - """ - if not isinstance(timeout, int): - raise TypeError("timeout argument must be int") - if timeout < 0: - raise ValueError("timeout must be greater equal 0") - args = keys + (timeout,) - return self.execute(b"BRPOP", key, *args, encoding=encoding) - - def brpoplpush(self, sourcekey, destkey, timeout=0, encoding=_NOTSET): - """Remove and get the last element in a list, or block until one - is available. - - :raises TypeError: if timeout is not int - :raises ValueError: if timeout is less than 0 - """ - if not isinstance(timeout, int): - raise TypeError("timeout argument must be int") - if timeout < 0: - raise ValueError("timeout must be greater equal 0") - return self.execute( - b"BRPOPLPUSH", sourcekey, destkey, timeout, encoding=encoding - ) - - def lindex(self, key, index, *, encoding=_NOTSET): - """Get an element from a list by its index. - - :raises TypeError: if index is not int - """ - if not isinstance(index, int): - raise TypeError("index argument must be int") - return self.execute(b"LINDEX", key, index, encoding=encoding) - - def linsert(self, key, pivot, value, before=False): - """Inserts value in the list stored at key either before or - after the reference value pivot. - """ - where = b"AFTER" if not before else b"BEFORE" - return self.execute(b"LINSERT", key, where, pivot, value) - - def llen(self, key): - """Returns the length of the list stored at key.""" - return self.execute(b"LLEN", key) - - def lpop(self, key, *, encoding=_NOTSET): - """Removes and returns the first element of the list stored at key.""" - return self.execute(b"LPOP", key, encoding=encoding) - - def lpush(self, key, value, *values): - """Insert all the specified values at the head of the list - stored at key. 
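The blocking list pops above are the building block for simple queue consumers; a sketch in the removed 1.x style ('jobs' is an assumed queue name):

    async def worker(redis):
        while True:
            # BLPOP blocks until an element arrives; timeout=0 waits forever.
            # The reply is [key, value], so the source list is always known.
            key, job = await redis.blpop('jobs', timeout=0, encoding='utf-8')
            print('got', job, 'from', key)
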
- """ - return self.execute(b"LPUSH", key, value, *values) - - def lpushx(self, key, value): - """Inserts value at the head of the list stored at key, only if key - already exists and holds a list. - """ - return self.execute(b"LPUSHX", key, value) - - def lrange(self, key, start, stop, *, encoding=_NOTSET): - """Returns the specified elements of the list stored at key. - - :raises TypeError: if start or stop is not int - """ - if not isinstance(start, int): - raise TypeError("start argument must be int") - if not isinstance(stop, int): - raise TypeError("stop argument must be int") - return self.execute(b"LRANGE", key, start, stop, encoding=encoding) - - def lrem(self, key, count, value): - """Removes the first count occurrences of elements equal to value - from the list stored at key. - - :raises TypeError: if count is not int - """ - if not isinstance(count, int): - raise TypeError("count argument must be int") - return self.execute(b"LREM", key, count, value) - - def lset(self, key, index, value): - """Sets the list element at index to value. - - :raises TypeError: if index is not int - """ - if not isinstance(index, int): - raise TypeError("index argument must be int") - return self.execute(b"LSET", key, index, value) - - def ltrim(self, key, start, stop): - """Trim an existing list so that it will contain only the specified - range of elements specified. - - :raises TypeError: if start or stop is not int - """ - if not isinstance(start, int): - raise TypeError("start argument must be int") - if not isinstance(stop, int): - raise TypeError("stop argument must be int") - fut = self.execute(b"LTRIM", key, start, stop) - return wait_ok(fut) - - def rpop(self, key, *, encoding=_NOTSET): - """Removes and returns the last element of the list stored at key.""" - return self.execute(b"RPOP", key, encoding=encoding) - - def rpoplpush(self, sourcekey, destkey, *, encoding=_NOTSET): - """Atomically returns and removes the last element (tail) of the - list stored at source, and pushes the element at the first element - (head) of the list stored at destination. - """ - return self.execute(b"RPOPLPUSH", sourcekey, destkey, encoding=encoding) - - def rpush(self, key, value, *values): - """Insert all the specified values at the tail of the list - stored at key. - """ - return self.execute(b"RPUSH", key, value, *values) - - def rpushx(self, key, value): - """Inserts value at the tail of the list stored at key, only if - key already exists and holds a list. - """ - return self.execute(b"RPUSHX", key, value) diff --git a/aioredis/commands/pubsub.py b/aioredis/commands/pubsub.py deleted file mode 100644 index 9a4416cff..000000000 --- a/aioredis/commands/pubsub.py +++ /dev/null @@ -1,112 +0,0 @@ -import json - -from aioredis.util import wait_make_dict - - -class PubSubCommandsMixin: - """Pub/Sub commands mixin. - - For commands details see: http://redis.io/commands/#pubsub - """ - - def publish(self, channel, message): - """Post a message to channel.""" - return self.execute(b"PUBLISH", channel, message) - - def publish_json(self, channel, obj, encoder=json.dumps): - """Post a JSON-encoded message to channel.""" - return self.publish(channel, encoder(obj)) - - def subscribe(self, channel, *channels): - """Switch connection to Pub/Sub mode and - subscribe to specified channels. - - Arguments can be instances of :class:`~aioredis.Channel`. - - Returns :func:`asyncio.gather()` coroutine which when done will return - a list of :class:`~aioredis.Channel` objects. 
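rpoplpush() above enables the classic reliable-queue pattern; a sketch where process_job is a hypothetical application callback:

    async def pop_reliably(redis, process_job):
        # Park the job on a per-worker list so it survives a crash between
        # "pop" and "done"; remove it only once the work has succeeded.
        job = await redis.rpoplpush('jobs', 'jobs:processing')
        if job is not None:
            await process_job(job)
            await redis.lrem('jobs:processing', 1, job)
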
-        """
-        conn = self._pool_or_conn
-        return wait_return_channels(
-            conn.execute_pubsub(b"SUBSCRIBE", channel, *channels),
-            conn,
-            "pubsub_channels",
-        )
-
-    def unsubscribe(self, channel, *channels):
-        """Unsubscribe from specific channels.
-
-        Arguments can be instances of :class:`~aioredis.Channel`.
-        """
-        conn = self._pool_or_conn
-        return conn.execute_pubsub(b"UNSUBSCRIBE", channel, *channels)
-
-    def psubscribe(self, pattern, *patterns):
-        """Switch connection to Pub/Sub mode and
-        subscribe to specified patterns.
-
-        Arguments can be instances of :class:`~aioredis.Channel`.
-
-        Returns :func:`asyncio.gather()` coroutine which when done will return
-        a list of subscribed :class:`~aioredis.Channel` objects with
-        ``is_pattern`` property set to ``True``.
-        """
-        conn = self._pool_or_conn
-        return wait_return_channels(
-            conn.execute_pubsub(b"PSUBSCRIBE", pattern, *patterns),
-            conn,
-            "pubsub_patterns",
-        )
-
-    def punsubscribe(self, pattern, *patterns):
-        """Unsubscribe from specific patterns.
-
-        Arguments can be instances of :class:`~aioredis.Channel`.
-        """
-        conn = self._pool_or_conn
-        return conn.execute_pubsub(b"PUNSUBSCRIBE", pattern, *patterns)
-
-    def pubsub_channels(self, pattern=None):
-        """Lists the currently active channels."""
-        args = [b"PUBSUB", b"CHANNELS"]
-        if pattern is not None:
-            args.append(pattern)
-        return self.execute(*args)
-
-    def pubsub_numsub(self, *channels):
-        """Returns the number of subscribers for the specified channels."""
-        return wait_make_dict(self.execute(b"PUBSUB", b"NUMSUB", *channels))
-
-    def pubsub_numpat(self):
-        """Returns the number of subscriptions to patterns."""
-        return self.execute(b"PUBSUB", b"NUMPAT")
-
-    @property
-    def channels(self):
-        """Returns read-only channels dict.
-
-        See :attr:`~aioredis.RedisConnection.pubsub_channels`
-        """
-        return self._pool_or_conn.pubsub_channels
-
-    @property
-    def patterns(self):
-        """Returns read-only patterns dict.
-
-        See :attr:`~aioredis.RedisConnection.pubsub_patterns`
-        """
-        return self._pool_or_conn.pubsub_patterns
-
-    @property
-    def in_pubsub(self):
-        """Indicates that connection is in PUB/SUB mode.
-
-        Provides the number of subscribed channels.
-        """
-        return self._pool_or_conn.in_pubsub
-
-
-async def wait_return_channels(fut, conn, field):
-    res = await fut
-    channels_dict = getattr(conn, field)
-    return [channels_dict[name] for cmd, name, count in res]
diff --git a/aioredis/commands/scripting.py b/aioredis/commands/scripting.py
deleted file mode 100644
index 92a5177eb..000000000
--- a/aioredis/commands/scripting.py
+++ /dev/null
@@ -1,34 +0,0 @@
-from aioredis.util import wait_ok
-
-
-class ScriptingCommandsMixin:
-    """Scripting commands mixin.
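subscribe() above hands back Channel objects; a minimal reader/writer sketch in the removed 1.x style (channel name assumed):

    async def reader(redis):
        # subscribe() returns one Channel per requested channel name
        channel, = await redis.subscribe('news')
        while await channel.wait_message():
            msg = await channel.get(encoding='utf-8')
            print('news:', msg)

    async def writer(redis):
        await redis.publish('news', 'hello')
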
- - For commands details see: http://redis.io/commands#scripting - """ - - def eval(self, script, keys=[], args=[]): - """Execute a Lua script server side.""" - return self.execute(b"EVAL", script, len(keys), *(keys + args)) - - def evalsha(self, digest, keys=[], args=[]): - """Execute a Lua script server side by its SHA1 digest.""" - return self.execute(b"EVALSHA", digest, len(keys), *(keys + args)) - - def script_exists(self, digest, *digests): - """Check existence of scripts in the script cache.""" - return self.execute(b"SCRIPT", b"EXISTS", digest, *digests) - - def script_kill(self): - """Kill the script currently in execution.""" - fut = self.execute(b"SCRIPT", b"KILL") - return wait_ok(fut) - - def script_flush(self): - """Remove all the scripts from the script cache.""" - fut = self.execute(b"SCRIPT", b"FLUSH") - return wait_ok(fut) - - def script_load(self, script): - """Load the specified Lua script into the script cache.""" - return self.execute(b"SCRIPT", b"LOAD", script) diff --git a/aioredis/commands/server.py b/aioredis/commands/server.py deleted file mode 100644 index 0355879d6..000000000 --- a/aioredis/commands/server.py +++ /dev/null @@ -1,301 +0,0 @@ -from collections import namedtuple - -from aioredis.util import _NOTSET, wait_convert, wait_make_dict, wait_ok - - -class ServerCommandsMixin: - """Server commands mixin. - - For commands details see: http://redis.io/commands/#server - """ - - SHUTDOWN_SAVE = "SHUTDOWN_SAVE" - SHUTDOWN_NOSAVE = "SHUTDOWN_NOSAVE" - - def bgrewriteaof(self): - """Asynchronously rewrite the append-only file.""" - fut = self.execute(b"BGREWRITEAOF") - return wait_ok(fut) - - def bgsave(self): - """Asynchronously save the dataset to disk.""" - fut = self.execute(b"BGSAVE") - return wait_ok(fut) - - def client_kill(self): - """Kill the connection of a client. - - .. warning:: Not Implemented - """ - raise NotImplementedError - - def client_list(self): - """Get the list of client connections. - - Returns list of ClientInfo named tuples. - """ - fut = self.execute(b"CLIENT", b"LIST", encoding="utf-8") - return wait_convert(fut, to_tuples) - - def client_getname(self, encoding=_NOTSET): - """Get the current connection name.""" - return self.execute(b"CLIENT", b"GETNAME", encoding=encoding) - - def client_pause(self, timeout): - """Stop processing commands from clients for *timeout* milliseconds. 
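The scripting wrappers above combine naturally into a load-once, call-by-digest pattern; a sketch that falls back to EVAL when the script cache was flushed (ReplyError is the 1.x error class surfaced for NOSCRIPT replies):

    import aioredis

    SCRIPT = "return redis.call('INCRBY', KEYS[1], ARGV[1])"

    async def incr_by_script(redis, key, amount):
        digest = await redis.script_load(SCRIPT)
        try:
            return await redis.evalsha(digest, keys=[key], args=[amount])
        except aioredis.ReplyError:
            # script cache flushed (NOSCRIPT) -- send the source instead
            return await redis.eval(SCRIPT, keys=[key], args=[amount])
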
- - :raises TypeError: if timeout is not int - :raises ValueError: if timeout is less than 0 - """ - if not isinstance(timeout, int): - raise TypeError("timeout argument must be int") - if timeout < 0: - raise ValueError("timeout must be greater equal 0") - fut = self.execute(b"CLIENT", b"PAUSE", timeout) - return wait_ok(fut) - - def client_reply(self): - raise NotImplementedError() - - def client_setname(self, name): - """Set the current connection name.""" - fut = self.execute(b"CLIENT", b"SETNAME", name) - return wait_ok(fut) - - def command(self): - """Get array of Redis commands.""" - # TODO: convert result - return self.execute(b"COMMAND", encoding="utf-8") - - def command_count(self): - """Get total number of Redis commands.""" - return self.execute(b"COMMAND", b"COUNT") - - def command_getkeys(self, command, *args, encoding="utf-8"): - """Extract keys given a full Redis command.""" - return self.execute(b"COMMAND", b"GETKEYS", command, *args, encoding=encoding) - - def command_info(self, command, *commands): - """Get array of specific Redis command details.""" - return self.execute(b"COMMAND", b"INFO", command, *commands, encoding="utf-8") - - def config_get(self, parameter="*"): - """Get the value of a configuration parameter(s). - - If called without argument will return all parameters. - - :raises TypeError: if parameter is not string - """ - if not isinstance(parameter, str): - raise TypeError("parameter must be str") - fut = self.execute(b"CONFIG", b"GET", parameter, encoding="utf-8") - return wait_make_dict(fut) - - def config_rewrite(self): - """Rewrite the configuration file with the in memory configuration.""" - fut = self.execute(b"CONFIG", b"REWRITE") - return wait_ok(fut) - - def config_set(self, parameter, value): - """Set a configuration parameter to the given value.""" - if not isinstance(parameter, str): - raise TypeError("parameter must be str") - fut = self.execute(b"CONFIG", b"SET", parameter, value) - return wait_ok(fut) - - def config_resetstat(self): - """Reset the stats returned by INFO.""" - fut = self.execute(b"CONFIG", b"RESETSTAT") - return wait_ok(fut) - - def dbsize(self): - """Return the number of keys in the selected database.""" - return self.execute(b"DBSIZE") - - def debug_sleep(self, timeout): - """Suspend connection for timeout seconds.""" - fut = self.execute(b"DEBUG", b"SLEEP", timeout) - return wait_ok(fut) - - def debug_object(self, key): - """Get debugging information about a key.""" - return self.execute(b"DEBUG", b"OBJECT", key) - - def debug_segfault(self, key): - """Make the server crash.""" - # won't test, this probably works - return self.execute(b"DEBUG", "SEGFAULT") # pragma: no cover - - def flushall(self, async_op=False): - """ - Remove all keys from all databases. - - :param async_op: lets the entire dataset to be freed asynchronously. \ - Defaults to False - """ - if async_op: - fut = self.execute(b"FLUSHALL", b"ASYNC") - else: - fut = self.execute(b"FLUSHALL") - return wait_ok(fut) - - def flushdb(self, async_op=False): - """ - Remove all keys from the current database. - - :param async_op: lets a single database to be freed asynchronously. \ - Defaults to False - """ - if async_op: - fut = self.execute(b"FLUSHDB", b"ASYNC") - else: - fut = self.execute(b"FLUSHDB") - return wait_ok(fut) - - def info(self, section="default"): - """Get information and statistics about the server. - - If called without argument will return default set of sections. 
- For available sections, see http://redis.io/commands/INFO - - :raises ValueError: if section is invalid - - """ - if not section: - raise ValueError("invalid section") - fut = self.execute(b"INFO", section, encoding="utf-8") - return wait_convert(fut, parse_info) - - def lastsave(self): - """Get the UNIX time stamp of the last successful save to disk.""" - return self.execute(b"LASTSAVE") - - def monitor(self): - """Listen for all requests received by the server in real time. - - .. warning:: - Will not be implemented for now. - """ - # NOTE: will not implement for now; - raise NotImplementedError - - def role(self): - """Return the role of the server instance. - - Returns named tuples describing role of the instance. - For fields information see http://redis.io/commands/role#output-format - """ - fut = self.execute(b"ROLE", encoding="utf-8") - return wait_convert(fut, parse_role) - - def save(self): - """Synchronously save the dataset to disk.""" - return self.execute(b"SAVE") - - def shutdown(self, save=None): - """Synchronously save the dataset to disk and then - shut down the server. - """ - if save is self.SHUTDOWN_SAVE: - return self.execute(b"SHUTDOWN", b"SAVE") - elif save is self.SHUTDOWN_NOSAVE: - return self.execute(b"SHUTDOWN", b"NOSAVE") - else: - return self.execute(b"SHUTDOWN") - - def slaveof(self, host, port=None): - """Make the server a slave of another instance, - or promote it as master. - - Calling ``slaveof(None)`` will send ``SLAVEOF NO ONE``. - - .. versionchanged:: v0.2.6 - ``slaveof()`` form deprecated - in favour of explicit ``slaveof(None)``. - """ - if host is None and port is None: - return self.execute(b"SLAVEOF", b"NO", b"ONE") - return self.execute(b"SLAVEOF", host, port) - - def slowlog_get(self, length=None): - """Returns the Redis slow queries log.""" - if length is not None: - if not isinstance(length, int): - raise TypeError("length must be int or None") - return self.execute(b"SLOWLOG", b"GET", length) - else: - return self.execute(b"SLOWLOG", b"GET") - - def slowlog_len(self): - """Returns length of Redis slow queries log.""" - return self.execute(b"SLOWLOG", b"LEN") - - def slowlog_reset(self): - """Resets Redis slow queries log.""" - fut = self.execute(b"SLOWLOG", b"RESET") - return wait_ok(fut) - - def sync(self): - """Redis-server internal command used for replication.""" - return self.execute(b"SYNC") - - def time(self): - """Return current server time.""" - fut = self.execute(b"TIME") - return wait_convert(fut, to_time) - - -def _split(s): - k, v = s.split("=") - return k.replace("-", "_"), v - - -def to_time(obj): - return int(obj[0]) + int(obj[1]) * 1e-6 - - -def to_tuples(value): - line, *lines = value.splitlines(False) - line = list(map(_split, line.split(" "))) - ClientInfo = namedtuple("ClientInfo", " ".join(k for k, v in line)) - # TODO: parse flags and other known fields - result = [ClientInfo(**dict(line))] - for line in lines: - result.append(ClientInfo(**dict(map(_split, line.split(" "))))) - return result - - -def parse_info(info): - res = {} - for block in info.split("\r\n\r\n"): - section, *block = block.strip().splitlines() - section = section[2:].lower() - res[section] = tmp = {} - for line in block: - key, value = line.split(":", 1) - if "," in line and "=" in line: - value = dict(map(lambda i: i.split("="), value.split(","))) - tmp[key] = value - return res - - -# XXX: may change in future -# (may be hard to maintain for new/old redis versions) -MasterInfo = namedtuple("MasterInfo", "role replication_offset slaves") 
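parse_info() above and parse_role() just below are pure functions, so their output shape can be checked without a server; a sketch assuming aioredis 1.x is importable:

    from aioredis.commands.server import parse_info, parse_role

    raw = "# Server\r\nredis_version:5.0.9\r\nuptime_in_seconds:42\r\n"
    info = parse_info(raw)
    print(info['server']['redis_version'])   # '5.0.9'

    # a ROLE reply from a master, per https://redis.io/commands/role
    role = parse_role(["master", 3129659, [["127.0.0.1", "9001", "3129242"]]])
    print(role.slaves[0].port)               # 9001
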
-MasterSlaveInfo = namedtuple("MasterSlaveInfo", "ip port ack_offset") - -SlaveInfo = namedtuple("SlaveInfo", "role master_ip master_port state received") - -SentinelInfo = namedtuple("SentinelInfo", "role masters") - - -def parse_role(role): - type_ = role[0] - if type_ == "master": - slaves = [MasterSlaveInfo(s[0], int(s[1]), int(s[2])) for s in role[2]] - return MasterInfo(role[0], int(role[1]), slaves) - elif type_ == "slave": - return SlaveInfo(role[0], role[1], int(role[2]), role[3], int(role[4])) - elif type_ == "sentinel": - return SentinelInfo(*role) - return role diff --git a/aioredis/commands/set.py b/aioredis/commands/set.py deleted file mode 100644 index ee7d51d6a..000000000 --- a/aioredis/commands/set.py +++ /dev/null @@ -1,88 +0,0 @@ -from aioredis.util import _NOTSET, _ScanIter, wait_convert - - -class SetCommandsMixin: - """Set commands mixin. - - For commands details see: http://redis.io/commands#set - """ - - def sadd(self, key, member, *members): - """Add one or more members to a set.""" - return self.execute(b"SADD", key, member, *members) - - def scard(self, key): - """Get the number of members in a set.""" - return self.execute(b"SCARD", key) - - def sdiff(self, key, *keys): - """Subtract multiple sets.""" - return self.execute(b"SDIFF", key, *keys) - - def sdiffstore(self, destkey, key, *keys): - """Subtract multiple sets and store the resulting set in a key.""" - return self.execute(b"SDIFFSTORE", destkey, key, *keys) - - def sinter(self, key, *keys): - """Intersect multiple sets.""" - return self.execute(b"SINTER", key, *keys) - - def sinterstore(self, destkey, key, *keys): - """Intersect multiple sets and store the resulting set in a key.""" - return self.execute(b"SINTERSTORE", destkey, key, *keys) - - def sismember(self, key, member): - """Determine if a given value is a member of a set.""" - return self.execute(b"SISMEMBER", key, member) - - def smembers(self, key, *, encoding=_NOTSET): - """Get all the members in a set.""" - return self.execute(b"SMEMBERS", key, encoding=encoding) - - def smove(self, sourcekey, destkey, member): - """Move a member from one set to another.""" - return self.execute(b"SMOVE", sourcekey, destkey, member) - - def spop(self, key, count=None, *, encoding=_NOTSET): - """Remove and return one or multiple random members from a set.""" - args = [key] - if count is not None: - args.append(count) - return self.execute(b"SPOP", *args, encoding=encoding) - - def srandmember(self, key, count=None, *, encoding=_NOTSET): - """Get one or multiple random members from a set.""" - args = [key] - count is not None and args.append(count) - return self.execute(b"SRANDMEMBER", *args, encoding=encoding) - - def srem(self, key, member, *members): - """Remove one or more members from a set.""" - return self.execute(b"SREM", key, member, *members) - - def sunion(self, key, *keys): - """Add multiple sets.""" - return self.execute(b"SUNION", key, *keys) - - def sunionstore(self, destkey, key, *keys): - """Add multiple sets and store the resulting set in a key.""" - return self.execute(b"SUNIONSTORE", destkey, key, *keys) - - def sscan(self, key, cursor=0, match=None, count=None): - """Incrementally iterate Set elements.""" - tokens = [key, cursor] - match is not None and tokens.extend([b"MATCH", match]) - count is not None and tokens.extend([b"COUNT", count]) - fut = self.execute(b"SSCAN", *tokens) - return wait_convert(fut, lambda obj: (int(obj[0]), obj[1])) - - def isscan(self, key, *, match=None, count=None): - """Incrementally iterate set elements using 
async for. - - Usage example: - - >>> async for val in redis.isscan(key, match='something*'): - ... print('Matched:', val) - - """ - return _ScanIter(lambda cur: self.sscan(key, cur, match=match, count=count)) diff --git a/aioredis/commands/sorted_set.py b/aioredis/commands/sorted_set.py deleted file mode 100644 index 079b5731a..000000000 --- a/aioredis/commands/sorted_set.py +++ /dev/null @@ -1,502 +0,0 @@ -from aioredis.util import _NOTSET, _ScanIter, wait_convert - - -class SortedSetCommandsMixin: - """Sorted Sets commands mixin. - - For commands details see: http://redis.io/commands/#sorted_set - """ - - ZSET_EXCLUDE_MIN = "ZSET_EXCLUDE_MIN" - ZSET_EXCLUDE_MAX = "ZSET_EXCLUDE_MAX" - ZSET_EXCLUDE_BOTH = "ZSET_EXCLUDE_BOTH" - - ZSET_AGGREGATE_SUM = "ZSET_AGGREGATE_SUM" - ZSET_AGGREGATE_MIN = "ZSET_AGGREGATE_MIN" - ZSET_AGGREGATE_MAX = "ZSET_AGGREGATE_MAX" - - ZSET_IF_NOT_EXIST = "ZSET_IF_NOT_EXIST" # NX - ZSET_IF_EXIST = "ZSET_IF_EXIST" # XX - - def bzpopmax(self, key, *keys, timeout=0, encoding=_NOTSET): - """Remove and get an element with the highest score in the sorted set, - or block until one is available. - - :raises TypeError: if timeout is not int - :raises ValueError: if timeout is less than 0 - """ - if not isinstance(timeout, int): - raise TypeError("timeout argument must be int") - if timeout < 0: - raise ValueError("timeout must be greater equal 0") - args = keys + (timeout,) - return self.execute(b"BZPOPMAX", key, *args, encoding=encoding) - - def bzpopmin(self, key, *keys, timeout=0, encoding=_NOTSET): - """Remove and get an element with the lowest score in the sorted set, - or block until one is available. - - :raises TypeError: if timeout is not int - :raises ValueError: if timeout is less than 0 - """ - if not isinstance(timeout, int): - raise TypeError("timeout argument must be int") - if timeout < 0: - raise ValueError("timeout must be greater equal 0") - args = keys + (timeout,) - return self.execute(b"BZPOPMIN", key, *args, encoding=encoding) - - def zadd(self, key, score, member, *pairs, exist=None, changed=False, incr=False): - """Add one or more members to a sorted set or update its score. - - :raises TypeError: score not int or float - :raises TypeError: length of pairs is not even number - """ - if not isinstance(score, (int, float)): - raise TypeError("score argument must be int or float") - if len(pairs) % 2 != 0: - raise TypeError("length of pairs must be even number") - - scores = (item for i, item in enumerate(pairs) if i % 2 == 0) - if any(not isinstance(s, (int, float)) for s in scores): - raise TypeError("all scores must be int or float") - - args = [] - if exist is self.ZSET_IF_EXIST: - args.append(b"XX") - elif exist is self.ZSET_IF_NOT_EXIST: - args.append(b"NX") - - if changed: - args.append(b"CH") - - if incr: - if pairs: - raise ValueError( - "only one score-element pair " "can be specified in this mode" - ) - args.append(b"INCR") - - args.extend([score, member]) - if pairs: - args.extend(pairs) - return self.execute(b"ZADD", key, *args) - - def zcard(self, key): - """Get the number of members in a sorted set.""" - return self.execute(b"ZCARD", key) - - def zcount(self, key, min=float("-inf"), max=float("inf"), *, exclude=None): - """Count the members in a sorted set with scores - within the given values. 
- - :raises TypeError: min or max is not float or int - :raises ValueError: if min greater than max - """ - if not isinstance(min, (int, float)): - raise TypeError("min argument must be int or float") - if not isinstance(max, (int, float)): - raise TypeError("max argument must be int or float") - if min > max: - raise ValueError("min could not be greater than max") - return self.execute(b"ZCOUNT", key, *_encode_min_max(exclude, min, max)) - - def zincrby(self, key, increment, member): - """Increment the score of a member in a sorted set. - - :raises TypeError: increment is not float or int - """ - if not isinstance(increment, (int, float)): - raise TypeError("increment argument must be int or float") - fut = self.execute(b"ZINCRBY", key, increment, member) - return wait_convert(fut, int_or_float) - - def zinterstore(self, destkey, key, *keys, with_weights=False, aggregate=None): - """Intersect multiple sorted sets and store result in a new key. - - :param bool with_weights: when set to true each key must be a tuple - in form of (key, weight) - """ - keys = (key,) + keys - numkeys = len(keys) - args = [] - if with_weights: - assert all( - isinstance(val, (list, tuple)) for val in keys - ), "All key arguments must be (key, weight) tuples" - weights = ["WEIGHTS"] - for key, weight in keys: - args.append(key) - weights.append(weight) - args.extend(weights) - else: - args.extend(keys) - - if aggregate is self.ZSET_AGGREGATE_SUM: - args.extend(("AGGREGATE", "SUM")) - elif aggregate is self.ZSET_AGGREGATE_MAX: - args.extend(("AGGREGATE", "MAX")) - elif aggregate is self.ZSET_AGGREGATE_MIN: - args.extend(("AGGREGATE", "MIN")) - fut = self.execute(b"ZINTERSTORE", destkey, numkeys, *args) - return fut - - def zlexcount(self, key, min=b"-", max=b"+"): - """Count the number of members in a sorted set between a given - lexicographical range. - """ - return self.execute(b"ZLEXCOUNT", key, min, max) - - def zrange(self, key, start=0, stop=-1, withscores=False, encoding=_NOTSET): - """Return a range of members in a sorted set, by index. - - :raises TypeError: if start is not int - :raises TypeError: if stop is not int - """ - if not isinstance(start, int): - raise TypeError("start argument must be int") - if not isinstance(stop, int): - raise TypeError("stop argument must be int") - if withscores: - args = [b"WITHSCORES"] - else: - args = [] - fut = self.execute(b"ZRANGE", key, start, stop, *args, encoding=encoding) - if withscores: - return wait_convert(fut, pairs_int_or_float) - return fut - - def zrangebylex( - self, - key, - min=b"-", - max=b"+", - offset=None, - count=None, - encoding=_NOTSET, - ): - """Return a range of members in a sorted set, by lexicographical range. 
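zadd() and zinterstore() above compose as follows; a sketch of the flat score/member argument order and the with_weights form (removed 1.x API, key names assumed):

    async def combine_votes(redis):
        await redis.zadd('votes:a', 1, 'x', 2, 'y')   # score precedes member
        await redis.zadd('votes:b', 10, 'y')
        # with_weights=True: every key argument is a (key, weight) tuple
        await redis.zinterstore(
            'votes:combined', ('votes:a', 2), ('votes:b', 1), with_weights=True
        )
        return await redis.zrange('votes:combined', withscores=True)
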
- - :raises TypeError: if both offset and count are not specified - :raises TypeError: if offset is not int - :raises TypeError: if count is not int - """ - if (offset is not None and count is None) or ( - count is not None and offset is None - ): - raise TypeError("offset and count must both be specified") - if offset is not None and not isinstance(offset, int): - raise TypeError("offset argument must be int") - if count is not None and not isinstance(count, int): - raise TypeError("count argument must be int") - - args = [] - if offset is not None and count is not None: - args.extend([b"LIMIT", offset, count]) - - return self.execute(b"ZRANGEBYLEX", key, min, max, *args, encoding=encoding) - - def zrangebyscore( - self, - key, - min=float("-inf"), - max=float("inf"), - withscores=False, - offset=None, - count=None, - *, - exclude=None, - encoding=_NOTSET, - ): - """Return a range of members in a sorted set, by score. - - :raises TypeError: if min or max is not float or int - :raises TypeError: if both offset and count are not specified - :raises TypeError: if offset is not int - :raises TypeError: if count is not int - """ - if not isinstance(min, (int, float)): - raise TypeError("min argument must be int or float") - if not isinstance(max, (int, float)): - raise TypeError("max argument must be int or float") - - if (offset is not None and count is None) or ( - count is not None and offset is None - ): - raise TypeError("offset and count must both be specified") - if offset is not None and not isinstance(offset, int): - raise TypeError("offset argument must be int") - if count is not None and not isinstance(count, int): - raise TypeError("count argument must be int") - - min, max = _encode_min_max(exclude, min, max) - - args = [] - if withscores: - args = [b"WITHSCORES"] - if offset is not None and count is not None: - args.extend([b"LIMIT", offset, count]) - fut = self.execute(b"ZRANGEBYSCORE", key, min, max, *args, encoding=encoding) - if withscores: - return wait_convert(fut, pairs_int_or_float) - return fut - - def zrank(self, key, member): - """Determine the index of a member in a sorted set.""" - return self.execute(b"ZRANK", key, member) - - def zrem(self, key, member, *members): - """Remove one or more members from a sorted set.""" - return self.execute(b"ZREM", key, member, *members) - - def zremrangebylex(self, key, min=b"-", max=b"+"): - """Remove all members in a sorted set between the given - lexicographical range. - """ - return self.execute(b"ZREMRANGEBYLEX", key, min, max) - - def zremrangebyrank(self, key, start, stop): - """Remove all members in a sorted set within the given indexes. - - :raises TypeError: if start is not int - :raises TypeError: if stop is not int - """ - if not isinstance(start, int): - raise TypeError("start argument must be int") - if not isinstance(stop, int): - raise TypeError("stop argument must be int") - return self.execute(b"ZREMRANGEBYRANK", key, start, stop) - - def zremrangebyscore( - self, key, min=float("-inf"), max=float("inf"), *, exclude=None - ): - """Remove all members in a sorted set within the given scores. 
- - :raises TypeError: if min or max is not int or float - """ - if not isinstance(min, (int, float)): - raise TypeError("min argument must be int or float") - if not isinstance(max, (int, float)): - raise TypeError("max argument must be int or float") - - min, max = _encode_min_max(exclude, min, max) - return self.execute(b"ZREMRANGEBYSCORE", key, min, max) - - def zrevrange(self, key, start, stop, withscores=False, encoding=_NOTSET): - """Return a range of members in a sorted set, by index, - with scores ordered from high to low. - - :raises TypeError: if start or stop is not int - """ - if not isinstance(start, int): - raise TypeError("start argument must be int") - if not isinstance(stop, int): - raise TypeError("stop argument must be int") - if withscores: - args = [b"WITHSCORES"] - else: - args = [] - fut = self.execute(b"ZREVRANGE", key, start, stop, *args, encoding=encoding) - if withscores: - return wait_convert(fut, pairs_int_or_float) - return fut - - def zrevrangebyscore( - self, - key, - max=float("inf"), - min=float("-inf"), - *, - exclude=None, - withscores=False, - offset=None, - count=None, - encoding=_NOTSET, - ): - """Return a range of members in a sorted set, by score, - with scores ordered from high to low. - - :raises TypeError: if min or max is not float or int - :raises TypeError: if both offset and count are not specified - :raises TypeError: if offset is not int - :raises TypeError: if count is not int - """ - if not isinstance(min, (int, float)): - raise TypeError("min argument must be int or float") - if not isinstance(max, (int, float)): - raise TypeError("max argument must be int or float") - - if (offset is not None and count is None) or ( - count is not None and offset is None - ): - raise TypeError("offset and count must both be specified") - if offset is not None and not isinstance(offset, int): - raise TypeError("offset argument must be int") - if count is not None and not isinstance(count, int): - raise TypeError("count argument must be int") - - min, max = _encode_min_max(exclude, min, max) - - args = [] - if withscores: - args = [b"WITHSCORES"] - if offset is not None and count is not None: - args.extend([b"LIMIT", offset, count]) - fut = self.execute(b"ZREVRANGEBYSCORE", key, max, min, *args, encoding=encoding) - if withscores: - return wait_convert(fut, pairs_int_or_float) - return fut - - def zrevrangebylex( - self, - key, - min=b"-", - max=b"+", - offset=None, - count=None, - encoding=_NOTSET, - ): - """Return a range of members in a sorted set, by lexicographical range - from high to low. - - :raises TypeError: if both offset and count are not specified - :raises TypeError: if offset is not int - :raises TypeError: if count is not int - """ - if (offset is not None and count is None) or ( - count is not None and offset is None - ): - raise TypeError("offset and count must both be specified") - if offset is not None and not isinstance(offset, int): - raise TypeError("offset argument must be int") - if count is not None and not isinstance(count, int): - raise TypeError("count argument must be int") - - args = [] - if offset is not None and count is not None: - args.extend([b"LIMIT", offset, count]) - - return self.execute(b"ZREVRANGEBYLEX", key, max, min, *args, encoding=encoding) - - def zrevrank(self, key, member): - """Determine the index of a member in a sorted set, with - scores ordered from high to low. 
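The reverse range above is the usual leaderboard query; a sketch using the exclude flag defined at the top of this mixin (removed 1.x API, key name assumed):

    async def top_ten(redis):
        # highest scores first, paged, strictly above zero (exclusive min)
        return await redis.zrevrangebyscore(
            'leaderboard', max=float('inf'), min=0,
            withscores=True, offset=0, count=10,
            exclude=redis.ZSET_EXCLUDE_MIN,
        )
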
- """ - return self.execute(b"ZREVRANK", key, member) - - def zscore(self, key, member): - """Get the score associated with the given member in a sorted set.""" - fut = self.execute(b"ZSCORE", key, member) - return wait_convert(fut, optional_int_or_float) - - def zunionstore(self, destkey, key, *keys, with_weights=False, aggregate=None): - """Add multiple sorted sets and store result in a new key.""" - keys = (key,) + keys - numkeys = len(keys) - args = [] - if with_weights: - assert all( - isinstance(val, (list, tuple)) for val in keys - ), "All key arguments must be (key, weight) tuples" - weights = ["WEIGHTS"] - for key, weight in keys: - args.append(key) - weights.append(weight) - args.extend(weights) - else: - args.extend(keys) - - if aggregate is self.ZSET_AGGREGATE_SUM: - args.extend(("AGGREGATE", "SUM")) - elif aggregate is self.ZSET_AGGREGATE_MAX: - args.extend(("AGGREGATE", "MAX")) - elif aggregate is self.ZSET_AGGREGATE_MIN: - args.extend(("AGGREGATE", "MIN")) - fut = self.execute(b"ZUNIONSTORE", destkey, numkeys, *args) - return fut - - def zscan(self, key, cursor=0, match=None, count=None): - """Incrementally iterate sorted sets elements and associated scores.""" - args = [] - if match is not None: - args += [b"MATCH", match] - if count is not None: - args += [b"COUNT", count] - fut = self.execute(b"ZSCAN", key, cursor, *args) - - def _converter(obj): - return (int(obj[0]), pairs_int_or_float(obj[1])) - - return wait_convert(fut, _converter) - - def izscan(self, key, *, match=None, count=None): - """Incrementally iterate sorted set items using async for. - - Usage example: - - >>> async for val, score in redis.izscan(key, match='something*'): - ... print('Matched:', val, ':', score) - - """ - return _ScanIter(lambda cur: self.zscan(key, cur, match=match, count=count)) - - def zpopmin(self, key, count=None, *, encoding=_NOTSET): - """Removes and returns up to count members with the lowest scores - in the sorted set stored at key. - - :raises TypeError: if count is not int - """ - if count is not None and not isinstance(count, int): - raise TypeError("count argument must be int") - - args = [] - if count is not None: - args.extend([count]) - - fut = self.execute(b"ZPOPMIN", key, *args, encoding=encoding) - return fut - - def zpopmax(self, key, count=None, *, encoding=_NOTSET): - """Removes and returns up to count members with the highest scores - in the sorted set stored at key. 
- - :raises TypeError: if count is not int - """ - if count is not None and not isinstance(count, int): - raise TypeError("count argument must be int") - - args = [] - if count is not None: - args.extend([count]) - - fut = self.execute(b"ZPOPMAX", key, *args, encoding=encoding) - return fut - - -def _encode_min_max(flag, min, max): - if flag is SortedSetCommandsMixin.ZSET_EXCLUDE_MIN: - return f"({min}", max - elif flag is SortedSetCommandsMixin.ZSET_EXCLUDE_MAX: - return min, f"({max}" - elif flag is SortedSetCommandsMixin.ZSET_EXCLUDE_BOTH: - return f"({min}", f"({max}" - return min, max - - -def int_or_float(value): - assert isinstance(value, (str, bytes)), "raw_value must be bytes" - try: - return int(value) - except ValueError: - return float(value) - - -def optional_int_or_float(value): - if value is None: - return value - return int_or_float(value) - - -def pairs_int_or_float(value): - it = iter(value) - return [(val, int_or_float(score)) for val, score in zip(it, it)] diff --git a/aioredis/commands/streams.py b/aioredis/commands/streams.py deleted file mode 100644 index e52c5994f..000000000 --- a/aioredis/commands/streams.py +++ /dev/null @@ -1,286 +0,0 @@ -from collections import OrderedDict - -from aioredis.util import wait_convert, wait_make_dict, wait_ok - - -def fields_to_dict(fields, type_=OrderedDict): - """Convert a flat list of key/values into an OrderedDict""" - fields_iterator = iter(fields) - return type_(zip(fields_iterator, fields_iterator)) - - -def parse_messages(messages): - """Parse messages as returned by Redis into something useful - - Messages returned by XRANGE arrive in the form: - - [ - [message_id, [key1, value1, key2, value2, ...]], - ... - ] - - Here we parse this into: - - [ - [message_id, OrderedDict( - (key1, value1), - (key2, value2), - ... - )], - ... - ] - - """ - parsed_messages = [] - for message in messages: - if message is None: - # In some conditions redis will return a NIL message - parsed_messages.append((None, {})) - else: - mid, values = message - values = values or {} - parsed_messages.append((mid, fields_to_dict(values))) - - return parsed_messages - - -def parse_messages_by_stream(messages_by_stream): - """Parse messages returned by stream - - Messages returned by XREAD arrive in the form: - [stream_name, - [ - [message_id, [key1, value1, key2, value2, ...]], - ... - ], - ... - ] - - Here we parse this into (with the help of the above parse_messages() - function): - - [ - [stream_name, message_id, OrderedDict( - (key1, value1), - (key2, value2),. - ... - )], - ... - ] - - """ - if messages_by_stream is None: - return [] - - parsed = [] - for stream, messages in messages_by_stream: - for message_id, fields in parse_messages(messages): - parsed.append((stream, message_id, fields)) - return parsed - - -def parse_lists_to_dicts(lists): - """ Convert [[a, 1, b, 2], ...] 
into [{a:1, b: 2}, ...]""" - return [fields_to_dict(l, type_=dict) for l in lists] # noqa: E741 - - -class StreamCommandsMixin: - """Stream commands mixin - - Streams are available in Redis since v5.0 - """ - - def xadd(self, stream, fields, message_id=b"*", max_len=None, exact_len=False): - """Add a message to a stream.""" - args = [] - if max_len is not None: - if exact_len: - args.extend((b"MAXLEN", max_len)) - else: - args.extend((b"MAXLEN", b"~", max_len)) - - args.append(message_id) - - for k, v in fields.items(): - args.extend([k, v]) - return self.execute(b"XADD", stream, *args) - - def xrange(self, stream, start="-", stop="+", count=None): - """Retrieve messages from a stream.""" - if count is not None: - extra = ["COUNT", count] - else: - extra = [] - fut = self.execute(b"XRANGE", stream, start, stop, *extra) - return wait_convert(fut, parse_messages) - - def xrevrange(self, stream, start="+", stop="-", count=None): - """Retrieve messages from a stream in reverse order.""" - if count is not None: - extra = ["COUNT", count] - else: - extra = [] - fut = self.execute(b"XREVRANGE", stream, start, stop, *extra) - return wait_convert(fut, parse_messages) - - def xread(self, streams, timeout=0, count=None, latest_ids=None): - """Perform a blocking read on the given stream - - :raises ValueError: if the length of streams and latest_ids do - not match - """ - args = self._xread(streams, timeout, count, latest_ids) - fut = self.execute(b"XREAD", *args) - return wait_convert(fut, parse_messages_by_stream) - - def xread_group( - self, - group_name, - consumer_name, - streams, - timeout=0, - count=None, - latest_ids=None, - no_ack=False, - ): - """Perform a blocking read on the given stream as part of a consumer group - - :raises ValueError: if the length of streams and latest_ids do - not match - """ - args = self._xread(streams, timeout, count, latest_ids, no_ack) - fut = self.execute(b"XREADGROUP", b"GROUP", group_name, consumer_name, *args) - return wait_convert(fut, parse_messages_by_stream) - - def xgroup_create(self, stream, group_name, latest_id="$", mkstream=False): - """Create a consumer group""" - args = [b"CREATE", stream, group_name, latest_id] - if mkstream: - args.append(b"MKSTREAM") - fut = self.execute(b"XGROUP", *args) - return wait_ok(fut) - - def xgroup_setid(self, stream, group_name, latest_id="$"): - """Set the latest ID for a consumer group""" - fut = self.execute(b"XGROUP", b"SETID", stream, group_name, latest_id) - return wait_ok(fut) - - def xgroup_destroy(self, stream, group_name): - """Delete a consumer group""" - fut = self.execute(b"XGROUP", b"DESTROY", stream, group_name) - return wait_ok(fut) - - def xgroup_delconsumer(self, stream, group_name, consumer_name): - """Delete a specific consumer from a group""" - fut = self.execute(b"XGROUP", b"DELCONSUMER", stream, group_name, consumer_name) - return wait_convert(fut, int) - - def xpending( - self, stream, group_name, start=None, stop=None, count=None, consumer=None - ): - """Get information on pending messages for a stream - - Returned data will vary depending on the presence (or not) - of the start/stop/count parameters. 
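fields_to_dict() and parse_messages() above are pure helpers, so their reshaping can be verified offline (assuming aioredis 1.x is importable):

    from aioredis.commands.streams import parse_messages

    # an XRANGE-style reply: [message_id, flat [field, value, ...] list]
    reply = [[b'1600000000000-0', [b'temp', b'21.5', b'unit', b'C']]]
    print(parse_messages(reply))
    # [(b'1600000000000-0', OrderedDict([(b'temp', b'21.5'), (b'unit', b'C')]))]
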
For more details see:
-        https://redis.io/commands/xpending
-
-        :raises ValueError: if the start/stop/count parameters are only
-                            partially specified
-        """
-        # Returns: total pel messages, min id, max id, count
-        ssc = [start, stop, count]
-        ssc_count = len([v for v in ssc if v is not None])
-        if ssc_count != 3 and ssc_count != 0:
-            raise ValueError(
-                "Either specify none or all of the start/stop/count arguments"
-            )
-        if not any(ssc):
-            ssc = []
-
-        args = [stream, group_name] + ssc
-        if consumer:
-            args.append(consumer)
-        return self.execute(b"XPENDING", *args)
-
-    def xclaim(self, stream, group_name, consumer_name, min_idle_time, id, *ids):
-        """Claim a message for a given consumer"""
-        fut = self.execute(
-            b"XCLAIM", stream, group_name, consumer_name, min_idle_time, id, *ids
-        )
-        return wait_convert(fut, parse_messages)
-
-    def xack(self, stream, group_name, id, *ids):
-        """Acknowledge a message for a given consumer group"""
-        return self.execute(b"XACK", stream, group_name, id, *ids)
-
-    def xdel(self, stream, id, *ids):
-        """Removes the specified entries(IDs) from a stream"""
-        return self.execute(b"XDEL", stream, id, *ids)
-
-    def xtrim(self, stream, max_len, exact_len=False):
-        """Trims the stream to a given number of items, evicting older items"""
-        args = []
-        if exact_len:
-            args.extend((b"MAXLEN", max_len))
-        else:
-            args.extend((b"MAXLEN", b"~", max_len))
-        return self.execute(b"XTRIM", stream, *args)
-
-    def xlen(self, stream):
-        """Returns the number of entries inside a stream"""
-        return self.execute(b"XLEN", stream)
-
-    def xinfo(self, stream):
-        """Retrieve information about the given stream.
-
-        An alias for ``xinfo_stream()``
-        """
-        return self.xinfo_stream(stream)
-
-    def xinfo_consumers(self, stream, group_name):
-        """Retrieve consumers of a consumer group"""
-        fut = self.execute(b"XINFO", b"CONSUMERS", stream, group_name)
-
-        return wait_convert(fut, parse_lists_to_dicts)
-
-    def xinfo_groups(self, stream):
-        """Retrieve the consumer groups for a stream"""
-        fut = self.execute(b"XINFO", b"GROUPS", stream)
-        return wait_convert(fut, parse_lists_to_dicts)
-
-    def xinfo_stream(self, stream):
-        """Retrieve information about the given stream."""
-        fut = self.execute(b"XINFO", b"STREAM", stream)
-        return wait_make_dict(fut)
-
-    def xinfo_help(self):
-        """Retrieve help regarding the ``XINFO`` sub-commands"""
-        fut = self.execute(b"XINFO", b"HELP")
-        return wait_convert(fut, lambda l: b"\n".join(l))
-
-    def _xread(self, streams, timeout=0, count=None, latest_ids=None, no_ack=False):
-        """Wraps up common functionality between ``xread()``
-        and ``xread_group()``
-
-        You should probably be using ``xread()`` or ``xread_group()`` directly.
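The group wrappers above chain into the usual consume/ack loop; a sketch against the removed 1.x API (stream, group and consumer names are assumed; latest_ids=['>'] requests only never-delivered entries):

    async def consume(redis):
        # MKSTREAM creates the stream if missing; re-creating an existing
        # group raises a ReplyError (BUSYGROUP), which real code would ignore
        await redis.xgroup_create('events', 'workers', mkstream=True)
        while True:
            entries = await redis.xread_group(
                'workers', 'consumer-1', ['events'],
                timeout=5000, count=10, latest_ids=['>'],
            )
            for stream, message_id, fields in entries:
                print(stream, message_id, fields)
                await redis.xack('events', 'workers', message_id)
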
- """ - if latest_ids is None: - latest_ids = ["$"] * len(streams) - if len(streams) != len(latest_ids): - raise ValueError( - "The streams and latest_ids parameters must be of the " "same length" - ) - - count_args = [b"COUNT", count] if count else [] - if timeout is None: - block_args = [] - elif not isinstance(timeout, int): - raise TypeError(f"timeout argument must be int, not {timeout!r}") - else: - block_args = [b"BLOCK", timeout] - - noack_args = [b"NOACK"] if no_ack else [] - - return ( - count_args + block_args + noack_args + [b"STREAMS"] + streams + latest_ids - ) diff --git a/aioredis/commands/string.py b/aioredis/commands/string.py deleted file mode 100644 index ec277884e..000000000 --- a/aioredis/commands/string.py +++ /dev/null @@ -1,248 +0,0 @@ -from itertools import chain - -from aioredis.util import _NOTSET, wait_convert, wait_ok - - -class StringCommandsMixin: - """String commands mixin. - - For commands details see: http://redis.io/commands/#string - """ - - SET_IF_NOT_EXIST = "SET_IF_NOT_EXIST" # NX - SET_IF_EXIST = "SET_IF_EXIST" # XX - - def append(self, key, value): - """Append a value to key.""" - return self.execute(b"APPEND", key, value) - - def bitcount(self, key, start=None, end=None): - """Count set bits in a string. - - :raises TypeError: if only start or end specified. - """ - if start is None and end is not None: - raise TypeError("both start and stop must be specified") - elif start is not None and end is None: - raise TypeError("both start and stop must be specified") - elif start is not None and end is not None: - args = (start, end) - else: - args = () - return self.execute(b"BITCOUNT", key, *args) - - def bitfield(self): - raise NotImplementedError() - - def bitop_and(self, dest, key, *keys): - """Perform bitwise AND operations between strings.""" - return self.execute(b"BITOP", b"AND", dest, key, *keys) - - def bitop_or(self, dest, key, *keys): - """Perform bitwise OR operations between strings.""" - return self.execute(b"BITOP", b"OR", dest, key, *keys) - - def bitop_xor(self, dest, key, *keys): - """Perform bitwise XOR operations between strings.""" - return self.execute(b"BITOP", b"XOR", dest, key, *keys) - - def bitop_not(self, dest, key): - """Perform bitwise NOT operations between strings.""" - return self.execute(b"BITOP", b"NOT", dest, key) - - def bitpos(self, key, bit, start=None, end=None): - """Find first bit set or clear in a string. - - :raises ValueError: if bit is not 0 or 1 - """ - if bit not in (1, 0): - raise ValueError("bit argument must be either 1 or 0") - bytes_range = [] - if start is not None: - bytes_range.append(start) - if end is not None: - if start is None: - bytes_range = [0, end] - else: - bytes_range.append(end) - return self.execute(b"BITPOS", key, bit, *bytes_range) - - def decr(self, key): - """Decrement the integer value of a key by one.""" - return self.execute(b"DECR", key) - - def decrby(self, key, decrement): - """Decrement the integer value of a key by the given number. - - :raises TypeError: if decrement is not int - """ - if not isinstance(decrement, int): - raise TypeError("decrement must be of type int") - return self.execute(b"DECRBY", key, decrement) - - def get(self, key, *, encoding=_NOTSET): - """Get the value of a key.""" - return self.execute(b"GET", key, encoding=encoding) - - def getbit(self, key, offset): - """Returns the bit value at offset in the string value stored at key. 
-
-        :raises TypeError: if offset is not int
-        :raises ValueError: if offset is less than 0
-        """
-        if not isinstance(offset, int):
-            raise TypeError("offset argument must be int")
-        if offset < 0:
-            raise ValueError("offset must be greater than or equal to 0")
-        return self.execute(b"GETBIT", key, offset)
-
-    def getrange(self, key, start, end, *, encoding=_NOTSET):
-        """Get a substring of the string stored at a key.
-
-        :raises TypeError: if start or end is not int
-        """
-        if not isinstance(start, int):
-            raise TypeError("start argument must be int")
-        if not isinstance(end, int):
-            raise TypeError("end argument must be int")
-        return self.execute(b"GETRANGE", key, start, end, encoding=encoding)
-
-    def getset(self, key, value, *, encoding=_NOTSET):
-        """Set the string value of a key and return its old value."""
-        return self.execute(b"GETSET", key, value, encoding=encoding)
-
-    def incr(self, key):
-        """Increment the integer value of a key by one."""
-        return self.execute(b"INCR", key)
-
-    def incrby(self, key, increment):
-        """Increment the integer value of a key by the given amount.
-
-        :raises TypeError: if increment is not int
-        """
-        if not isinstance(increment, int):
-            raise TypeError("increment must be of type int")
-        return self.execute(b"INCRBY", key, increment)
-
-    def incrbyfloat(self, key, increment):
-        """Increment the float value of a key by the given amount.
-
-        :raises TypeError: if increment is not float
-        """
-        if not isinstance(increment, float):
-            raise TypeError("increment must be of type float")
-        fut = self.execute(b"INCRBYFLOAT", key, increment)
-        return wait_convert(fut, float)
-
-    def mget(self, key, *keys, encoding=_NOTSET):
-        """Get the values of all the given keys."""
-        return self.execute(b"MGET", key, *keys, encoding=encoding)
-
-    def mset(self, *args):
-        """Set multiple keys to multiple values or unpack dict to keys & values.
-
-        :raises TypeError: if len of args is not an even number
-        :raises TypeError: if len of args equals 1 and it is not a dict
-        """
-        data = args
-        if len(args) == 1:
-            if not isinstance(args[0], dict):
-                raise TypeError("if a single argument is given it must be a dict")
-            data = chain.from_iterable(args[0].items())
-        elif len(args) % 2 != 0:
-            raise TypeError("length of pairs must be an even number")
-        fut = self.execute(b"MSET", *data)
-        return wait_ok(fut)
-
-    def msetnx(self, key, value, *pairs):
-        """Set multiple keys to multiple values,
-        only if none of the keys exist.
-
-        :raises TypeError: if len of pairs is not an even number
-        """
-        if len(pairs) % 2 != 0:
-            raise TypeError("length of pairs must be an even number")
-        return self.execute(b"MSETNX", key, value, *pairs)
-
-    def psetex(self, key, milliseconds, value):
-        """Set the value and expiration in milliseconds of a key.
-
-        :raises TypeError: if milliseconds is not int
-        """
-        if not isinstance(milliseconds, int):
-            raise TypeError("milliseconds argument must be int")
-        fut = self.execute(b"PSETEX", key, milliseconds, value)
-        return wait_ok(fut)
-
-    def set(self, key, value, *, expire=0, pexpire=0, exist=None):
-        """Set the string value of a key.
- - :raises TypeError: if expire or pexpire is not int - """ - if expire and not isinstance(expire, int): - raise TypeError("expire argument must be int") - if pexpire and not isinstance(pexpire, int): - raise TypeError("pexpire argument must be int") - - args = [] - if expire: - args[:] = [b"EX", expire] - if pexpire: - args[:] = [b"PX", pexpire] - - if exist is self.SET_IF_EXIST: - args.append(b"XX") - elif exist is self.SET_IF_NOT_EXIST: - args.append(b"NX") - fut = self.execute(b"SET", key, value, *args) - return wait_ok(fut) - - def setbit(self, key, offset, value): - """Sets or clears the bit at offset in the string value stored at key. - - :raises TypeError: if offset is not int - :raises ValueError: if offset is less than 0 or value is not 0 or 1 - """ - if not isinstance(offset, int): - raise TypeError("offset argument must be int") - if offset < 0: - raise ValueError("offset must be greater equal 0") - if value not in (0, 1): - raise ValueError("value argument must be either 1 or 0") - return self.execute(b"SETBIT", key, offset, value) - - def setex(self, key, seconds, value): - """Set the value and expiration of a key. - - If seconds is float it will be multiplied by 1000 - coerced to int and passed to `psetex` method. - - :raises TypeError: if seconds is neither int nor float - """ - if isinstance(seconds, float): - return self.psetex(key, int(seconds * 1000), value) - if not isinstance(seconds, int): - raise TypeError("milliseconds argument must be int") - fut = self.execute(b"SETEX", key, seconds, value) - return wait_ok(fut) - - def setnx(self, key, value): - """Set the value of a key, only if the key does not exist.""" - fut = self.execute(b"SETNX", key, value) - return wait_convert(fut, bool) - - def setrange(self, key, offset, value): - """Overwrite part of a string at key starting at the specified offset. - - :raises TypeError: if offset is not int - :raises ValueError: if offset less than 0 - """ - if not isinstance(offset, int): - raise TypeError("offset argument must be int") - if offset < 0: - raise ValueError("offset must be greater equal 0") - return self.execute(b"SETRANGE", key, offset, value) - - def strlen(self, key): - """Get the length of the value stored in a key.""" - return self.execute(b"STRLEN", key) diff --git a/aioredis/commands/transaction.py b/aioredis/commands/transaction.py deleted file mode 100644 index 440b22f06..000000000 --- a/aioredis/commands/transaction.py +++ /dev/null @@ -1,314 +0,0 @@ -import asyncio -import functools - -from ..abc import AbcPool -from ..errors import ConnectionClosedError, MultiExecError, PipelineError, RedisError -from ..util import _set_exception, get_event_loop, wait_ok - - -class TransactionsCommandsMixin: - """Transaction commands mixin. - - For commands details see: http://redis.io/commands/#transactions - - Transactions HOWTO: - - >>> tr = redis.multi_exec() - >>> result_future1 = tr.incr('foo') - >>> result_future2 = tr.incr('bar') - >>> try: - ... result = await tr.execute() - ... except MultiExecError: - ... pass # check what happened - >>> result1 = await result_future1 - >>> result2 = await result_future2 - >>> assert result == [result1, result2] - """ - - def unwatch(self): - """Forget about all watched keys.""" - fut = self._pool_or_conn.execute(b"UNWATCH") - return wait_ok(fut) - - def watch(self, key, *keys): - """Watch the given keys to determine execution of the MULTI/EXEC block.""" - # FIXME: we can send watch through one connection and then issue - # 'multi/exec' command through other. 
- # Possible fix: - # "Remember" a connection that was used for 'watch' command - # and then send 'multi / exec / discard' through it. - fut = self._pool_or_conn.execute(b"WATCH", key, *keys) - return wait_ok(fut) - - def multi_exec(self): - """Returns MULTI/EXEC pipeline wrapper. - - Usage: - - >>> tr = redis.multi_exec() - >>> fut1 = tr.incr('foo') # NO `await` as it will block forever! - >>> fut2 = tr.incr('bar') - >>> result = await tr.execute() - >>> result - [1, 1] - >>> await asyncio.gather(fut1, fut2) - [1, 1] - """ - return MultiExec(self._pool_or_conn, self.__class__) - - def pipeline(self): - """Returns :class:`Pipeline` object to execute bulk of commands. - - It is provided for convenience. - Commands can be pipelined without it. - - Example: - - >>> pipe = redis.pipeline() - >>> fut1 = pipe.incr('foo') # NO `await` as it will block forever! - >>> fut2 = pipe.incr('bar') - >>> result = await pipe.execute() - >>> result - [1, 1] - >>> await asyncio.gather(fut1, fut2) - [1, 1] - >>> # - >>> # The same can be done without pipeline: - >>> # - >>> fut1 = redis.incr('foo') # the 'INCRY foo' command already sent - >>> fut2 = redis.incr('bar') - >>> await asyncio.gather(fut1, fut2) - [2, 2] - """ - return Pipeline(self._pool_or_conn, self.__class__) - - -class _RedisBuffer: - def __init__(self, pipeline, *, loop=None): - # TODO: deprecation note - # if loop is None: - # loop = asyncio.get_event_loop() - self._pipeline = pipeline - - def execute(self, cmd, *args, **kw): - fut = get_event_loop().create_future() - self._pipeline.append((fut, cmd, args, kw)) - return fut - - # TODO: add here or remove in connection methods like `select`, `auth` etc - - -class Pipeline: - """Commands pipeline. - - Usage: - - >>> pipe = redis.pipeline() - >>> fut1 = pipe.incr('foo') - >>> fut2 = pipe.incr('bar') - >>> await pipe.execute() - [1, 1] - >>> await fut1 - 1 - >>> await fut2 - 1 - """ - - error_class = PipelineError - - def __init__( - self, pool_or_connection, commands_factory=lambda conn: conn, *, loop=None - ): - # TODO: deprecation note - # if loop is None: - # loop = asyncio.get_event_loop() - self._pool_or_conn = pool_or_connection - self._pipeline = [] - self._results = [] - self._buffer = _RedisBuffer(self._pipeline) - self._redis = commands_factory(self._buffer) - self._done = False - - def __getattr__(self, name): - assert not self._done, "Pipeline already executed. Create new one." - attr = getattr(self._redis, name) - if callable(attr): - - @functools.wraps(attr) - def wrapper(*args, **kw): - try: - task = asyncio.ensure_future(attr(*args, **kw)) - except Exception as exc: - task = get_event_loop().create_future() - task.set_exception(exc) - self._results.append(task) - return task - - return wrapper - return attr - - async def execute(self, *, return_exceptions=False): - """Execute all buffered commands. - - Any exception that is raised by any command is caught and - raised later when processing results. - - Exceptions can also be returned in result if - `return_exceptions` flag is set to True. - """ - assert not self._done, "Pipeline already executed. Create new one." 
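
A minimal usage sketch of the `return_exceptions` flag handled below (keys
are illustrative; `redis` is a 1.x client created elsewhere):

    pipe = redis.pipeline()
    pipe.set('foo', 'bar')
    pipe.incr('foo')  # INCR on a non-integer value fails server-side
    res = await pipe.execute(return_exceptions=True)
    assert isinstance(res[1], RedisError)

With the default `return_exceptions=False`, the same failure would instead
surface as a `PipelineError` carrying the collected errors.
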
-        self._done = True
-
-        if self._pipeline:
-            if isinstance(self._pool_or_conn, AbcPool):
-                async with self._pool_or_conn.get() as conn:
-                    return await self._do_execute(
-                        conn, return_exceptions=return_exceptions
-                    )
-            else:
-                return await self._do_execute(
-                    self._pool_or_conn, return_exceptions=return_exceptions
-                )
-        else:
-            return await self._gather_result(return_exceptions)
-
-    async def _do_execute(self, conn, *, return_exceptions=False):
-        await asyncio.gather(*self._send_pipeline(conn), return_exceptions=True)
-        return await self._gather_result(return_exceptions)
-
-    async def _gather_result(self, return_exceptions):
-        errors = []
-        results = []
-        for fut in self._results:
-            try:
-                res = await fut
-                results.append(res)
-            except Exception as exc:
-                errors.append(exc)
-                results.append(exc)
-        if errors and not return_exceptions:
-            raise self.error_class(errors)
-        return results
-
-    def _send_pipeline(self, conn):
-        with conn._buffered():
-            for fut, cmd, args, kw in self._pipeline:
-                try:
-                    result_fut = conn.execute(cmd, *args, **kw)
-                    result_fut.add_done_callback(
-                        functools.partial(self._check_result, waiter=fut)
-                    )
-                except Exception as exc:
-                    fut.set_exception(exc)
-                else:
-                    yield result_fut
-
-    def _check_result(self, fut, waiter):
-        if fut.cancelled():
-            waiter.cancel()
-        elif fut.exception():
-            waiter.set_exception(fut.exception())
-        else:
-            waiter.set_result(fut.result())
-
-
-class MultiExec(Pipeline):
-    """Multi/Exec pipeline wrapper.
-
-    Usage:
-
-    >>> tr = redis.multi_exec()
-    >>> f1 = tr.incr('foo')
-    >>> f2 = tr.incr('bar')
-    >>> # A)
-    >>> await tr.execute()
-    >>> res1 = await f1
-    >>> res2 = await f2
-    >>> # or B)
-    >>> res1, res2 = await tr.execute()
-
-    and of course try/except:
-
-    >>> tr = redis.multi_exec()
-    >>> f1 = tr.incr('1')  # won't raise any exception (why?)
-    >>> try:
-    ...     res = await tr.execute()
-    ... except RedisError:
-    ...     pass
-    >>> assert f1.done()
-    >>> assert f1.result() is res
-
-    >>> tr = redis.multi_exec()
-    >>> wait_ok_coro = tr.mset('1')
-    >>> try:
-    ...     ok1 = await tr.execute()
-    ... except RedisError:
-    ...
pass # handle it - >>> ok2 = await wait_ok_coro - >>> # for this to work `wait_ok_coro` must be wrapped in Future - """ - - error_class = MultiExecError - - async def _do_execute(self, conn, *, return_exceptions=False): - self._waiters = waiters = [] - with conn._buffered(): - multi = conn.execute("MULTI") - coros = list(self._send_pipeline(conn)) - exec_ = conn.execute("EXEC") - gather = asyncio.gather(multi, *coros, return_exceptions=True) - last_error = None - try: - await asyncio.shield(gather) - except asyncio.CancelledError: - await gather - except Exception as err: - last_error = err - raise - finally: - if conn.closed: - if last_error is None: - last_error = ConnectionClosedError() - for fut in waiters: - _set_exception(fut, last_error) - # fut.cancel() - for fut in self._results: - if not fut.done(): - fut.set_exception(last_error) - # fut.cancel() - else: - try: - results = await exec_ - except RedisError as err: - for fut in waiters: - fut.set_exception(err) - else: - assert len(results) == len(waiters), ( - "Results does not match waiters", - results, - waiters, - ) - self._resolve_waiters(results, return_exceptions) - return await self._gather_result(return_exceptions) - - def _resolve_waiters(self, results, return_exceptions): - errors = [] - for val, fut in zip(results, self._waiters): - if isinstance(val, RedisError): - fut.set_exception(val) - errors.append(val) - else: - fut.set_result(val) - if errors and not return_exceptions: - raise MultiExecError(errors) - - def _check_result(self, fut, waiter): - assert waiter not in self._waiters, (fut, waiter, self._waiters) - assert not waiter.done(), waiter - if fut.cancelled(): # await gather was cancelled - waiter.cancel() - elif fut.exception(): # server replied with error - waiter.set_exception(fut.exception()) - elif fut.result() in {b"QUEUED", "QUEUED"}: - # got result, it should be QUEUED - self._waiters.append(waiter) diff --git a/aioredis/connection.py b/aioredis/connection.py index dc86293cb..d97c4b2b0 100644 --- a/aioredis/connection.py +++ b/aioredis/connection.py @@ -1,585 +1,1598 @@ +from __future__ import annotations + import asyncio +import errno +import inspect +import io +import os import socket -import sys -import types +import ssl +import threading +import time import warnings -from collections import deque -from contextlib import contextmanager -from functools import partial - -from .abc import AbcChannel, AbcConnection -from .errors import ( - ConnectionClosedError, - ConnectionForcedCloseError, - MaxClientsError, - ProtocolError, +from distutils.version import StrictVersion +from itertools import chain +from typing import ( + Any, + Iterable, + List, + Mapping, + Optional, + Protocol, + Set, + Tuple, + Type, + TypedDict, + TypeVar, + Union, +) +from urllib.parse import ParseResult, parse_qs, unquote, urlparse + +import async_timeout + +from .exceptions import ( + AuthenticationError, + AuthenticationWrongNumberOfArgsError, + BusyLoadingError, + ChildDeadlockedError, + ConnectionError, + DataError, + ExecAbortError, + InvalidResponse, + ModuleError, + NoPermissionError, + NoScriptError, ReadOnlyError, RedisError, - ReplyError, - WatchVariableError, -) -from .log import logger -from .parser import Reader -from .pubsub import Channel -from .stream import open_connection, open_unix_connection -from .util import ( - _NOTSET, - _set_exception, - _set_result, - coerced_keys_dict, - decode, - encode_command, - get_event_loop, - parse_url, - wait_ok, + ResponseError, + TimeoutError, ) +from .utils import 
str_if_bytes + +NONBLOCKING_EXCEPTION_ERROR_NUMBERS = { + BlockingIOError: errno.EWOULDBLOCK, + ssl.SSLWantReadError: 2, + ssl.SSLWantWriteError: 2, + ssl.SSLError: 2, +} + +NONBLOCKING_EXCEPTIONS = tuple(NONBLOCKING_EXCEPTION_ERROR_NUMBERS.keys()) + +try: + import hiredis + +except (ImportError, ModuleNotFoundError): + HIREDIS_AVAILABLE = False +else: + HIREDIS_AVAILABLE = True + hiredis_version = StrictVersion(hiredis.__version__) + if hiredis_version < StrictVersion("1.0.0"): + warnings.warn( + "aioredis supports hiredis @ 1.0.0 or higher. " + f"You have hiredis @ {hiredis.__version__}. " + "Pure-python parser will be used instead." + ) + HIREDIS_AVAILABLE = False -__all__ = ["create_connection", "RedisConnection"] +SYM_STAR = b"*" +SYM_DOLLAR = b"$" +SYM_CRLF = b"\r\n" +SYM_LF = b"\n" +SYM_EMPTY = b"" -MAX_CHUNK_SIZE = 65536 +SERVER_CLOSED_CONNECTION_ERROR = "Connection closed by server." -_PUBSUB_COMMANDS = ( - "SUBSCRIBE", - b"SUBSCRIBE", - "PSUBSCRIBE", - b"PSUBSCRIBE", - "UNSUBSCRIBE", - b"UNSUBSCRIBE", - "PUNSUBSCRIBE", - b"PUNSUBSCRIBE", +SENTINEL = object() +MODULE_LOAD_ERROR = "Error loading the extension. Please check the server logs." +NO_SUCH_MODULE_ERROR = "Error unloading module: no such module with that name" +MODULE_UNLOAD_NOT_POSSIBLE_ERROR = "Error unloading module: operation not possible." +MODULE_EXPORTS_DATA_TYPES_ERROR = ( + "Error unloading module: the module " + "exports one or more module-side data " + "types, can't unload" ) +EncodedT = Union[bytes, memoryview] +DecodedT = Union[str, int, float] +EncodableT = Union[EncodedT, DecodedT, None] -async def create_connection( - address, - *, - db=None, - password=None, - ssl=None, - encoding=None, - parser=None, - loop=None, - timeout=None, - connection_cls=None, - name=None, -): - """Creates redis connection. - - Opens connection to Redis server specified by address argument. - Address argument can be one of the following: - * A tuple representing (host, port) pair for TCP connections; - * A string representing either Redis URI or unix domain socket path. - - SSL argument is passed through to asyncio.create_connection. - By default SSL/TLS is not used. - - By default any timeout is applied at the connection stage, however - you can set a limitted time used trying to open a connection via - the `timeout` Kw. - - Encoding argument can be used to decode byte-replies to strings. - By default no decoding is done. - - Parser parameter can be used to pass custom Redis protocol parser class. - By default hiredis.Reader is used (unless it is missing or platform - is not CPython). - - Return value is RedisConnection instance or a connection_cls if it is - given. - - This function is a coroutine. 
- """ - assert isinstance(address, (tuple, list, str)), "tuple or str expected" - if isinstance(address, str): - address, options = parse_url(address) - logger.debug("Parsed Redis URI %r", address) - db = options.setdefault("db", db) - password = options.setdefault("password", password) - encoding = options.setdefault("encoding", encoding) - timeout = options.setdefault("timeout", timeout) - if "ssl" in options: - assert options["ssl"] or (not options["ssl"] and not ssl), ( - "Conflicting ssl options are set", - options["ssl"], - ssl, + +class Encoder: + """Encode strings to bytes-like and decode bytes-like to strings""" + + __slots__ = "encoding", "encoding_errors", "decode_responses" + + def __init__(self, encoding: str, encoding_errors: str, decode_responses: bool): + self.encoding = encoding + self.encoding_errors = encoding_errors + self.decode_responses = decode_responses + + def encode(self, value: EncodableT) -> EncodedT: + """Return a bytestring or bytes-like representation of the value""" + if isinstance(value, (bytes, memoryview)): + return value + if isinstance(value, bool): + # special case bool since it is a subclass of int + raise DataError( + "Invalid input of type: 'bool'. " + "Convert to a bytes, string, int or float first." ) - ssl = ssl or options["ssl"] + if isinstance(value, (int, float)): + return repr(value).encode() + if not isinstance(value, str): + # a value we don't know how to deal with. throw an error + typename = value.__class__.__name__ + raise DataError( + f"Invalid input of type: {typename!r}. " + "Convert to a bytes, string, int or float first." + ) + if isinstance(value, str): + return value.encode(self.encoding, self.encoding_errors) + return value - if timeout is not None and timeout <= 0: - raise ValueError("Timeout has to be None or a number greater than 0") + def decode(self, value: EncodableT, force=False) -> EncodableT: + """Return a unicode string from the bytes-like representation""" + if self.decode_responses or force: + if isinstance(value, memoryview): + return value.tobytes().decode(self.encoding, self.encoding_errors) + if isinstance(value, bytes): + return value.decode(self.encoding, self.encoding_errors) + return value + + +ExceptionMappingT = Mapping[str, Union[Type[Exception], Mapping[str, Type[Exception]]]] - if connection_cls: - assert issubclass( - connection_cls, AbcConnection - ), "connection_class does not meet the AbcConnection contract" - cls = connection_cls - else: - cls = RedisConnection - if loop is not None and sys.version_info >= (3, 8, 0): - warnings.warn("The loop argument is deprecated", DeprecationWarning) +class BaseParser: + """Plain Python parsing class""" + + __slots__ = "_stream", "_buffer", "_read_size" + + EXCEPTION_CLASSES: ExceptionMappingT = { + "ERR": { + "max number of clients reached": ConnectionError, + "Client sent AUTH, but no password is set": AuthenticationError, + "invalid password": AuthenticationError, + # some Redis server versions report invalid command syntax + # in lowercase + "wrong number of arguments for 'auth' command": AuthenticationWrongNumberOfArgsError, + # some Redis server versions report invalid command syntax + # in uppercase + "wrong number of arguments for 'AUTH' command": AuthenticationWrongNumberOfArgsError, + MODULE_LOAD_ERROR: ModuleError, + MODULE_EXPORTS_DATA_TYPES_ERROR: ModuleError, + NO_SUCH_MODULE_ERROR: ModuleError, + MODULE_UNLOAD_NOT_POSSIBLE_ERROR: ModuleError, + }, + "EXECABORT": ExecAbortError, + "LOADING": BusyLoadingError, + "NOSCRIPT": NoScriptError, + 
"READONLY": ReadOnlyError, + "NOAUTH": AuthenticationError, + "NOPERM": NoPermissionError, + } + + def __init__(self, socket_read_size: int): + self._stream: Optional[asyncio.StreamReader] = None + self._buffer: Optional[SocketBuffer] = None + self._read_size = socket_read_size + + def __del__(self): + try: + self.on_disconnect() + except Exception: + pass + + def parse_error(self, response: str) -> ResponseError: + """Parse an error response""" + error_code = response.split(" ")[0] + if error_code in self.EXCEPTION_CLASSES: + response = response[len(error_code) + 1 :] + exception_class = self.EXCEPTION_CLASSES[error_code] + if isinstance(exception_class, dict): + exception_class = exception_class.get(response, ResponseError) + return exception_class(response) + return ResponseError(response) + + def on_disconnect(self): + raise NotImplementedError() + + def on_connect(self, connection: Connection): + raise NotImplementedError() + + async def can_read(self, timeout: float) -> bool: + raise NotImplementedError() + + async def read_response(self) -> Union[EncodableT, ResponseError, None]: + raise NotImplementedError() + + +class SocketBuffer: + """Async-friendly re-impl of redis-py's SocketBuffer. + + TODO: We're currently passing through two buffers, + the asyncio.StreamReader and this. I imagine we can reduce the layers here + while maintaining compliance with prior art. + """ + + def __init__( + self, + stream_reader: asyncio.StreamReader, + socket_read_size: int, + socket_timeout: float, + ): + self._stream = stream_reader + self.socket_read_size = socket_read_size + self.socket_timeout = socket_timeout + self._buffer = io.BytesIO() + # number of bytes written to the buffer from the socket + self.bytes_written = 0 + # number of bytes read from the buffer + self.bytes_read = 0 + + @property + def length(self): + return self.bytes_written - self.bytes_read - if isinstance(address, (list, tuple)): - host, port = address - logger.debug("Creating tcp connection to %r", address) - reader, writer = await asyncio.wait_for( - open_connection(host, port, limit=MAX_CHUNK_SIZE, ssl=ssl), timeout + async def _read_from_socket( + self, + length: int = None, + timeout: Optional[float] = SENTINEL, # type: ignore + raise_on_timeout: bool = True, + ) -> bool: + buf = self._buffer + buf.seek(self.bytes_written) + marker = 0 + timeout = timeout if timeout is not SENTINEL else self.socket_timeout + + try: + while True: + async with async_timeout.timeout(timeout): + data = await self._stream.read(self.socket_read_size) + # an empty string indicates the server shutdown the socket + if isinstance(data, bytes) and len(data) == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + buf.write(data) + data_length = len(data) + self.bytes_written += data_length + marker += data_length + + if length is not None and length > marker: + continue + return True + except (socket.timeout, asyncio.TimeoutError): + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. 
+ allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + + async def can_read(self, timeout: float) -> bool: + return bool(self.length) or await self._read_from_socket( + timeout=timeout, raise_on_timeout=False ) - sock = writer.transport.get_extra_info("socket") - if sock is not None: - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - address = sock.getpeername() - address = tuple(address[:2]) - else: - logger.debug("Creating unix connection to %r", address) - reader, writer = await asyncio.wait_for( - open_unix_connection(address, ssl=ssl, limit=MAX_CHUNK_SIZE), timeout + + async def read(self, length: int) -> bytes: + length = length + 2 # make sure to read the \r\n terminator + # make sure we've read enough data from the socket + if length > self.length: + await self._read_from_socket(length - self.length) + + self._buffer.seek(self.bytes_read) + data = self._buffer.read(length) + self.bytes_read += len(data) + + # purge the buffer when we've consumed it all so it doesn't + # grow forever + if self.bytes_read == self.bytes_written: + self.purge() + + return data[:-2] + + async def readline(self) -> bytes: + buf = self._buffer + buf.seek(self.bytes_read) + data = buf.readline() + while not data.endswith(SYM_CRLF): + # there's more data in the socket that we need + await self._read_from_socket() + buf.seek(self.bytes_read) + data = buf.readline() + + self.bytes_read += len(data) + + # purge the buffer when we've consumed it all so it doesn't + # grow forever + if self.bytes_read == self.bytes_written: + self.purge() + + return data[:-2] + + def purge(self): + self._buffer.seek(0) + self._buffer.truncate() + self.bytes_written = 0 + self.bytes_read = 0 + + def close(self): + try: + self.purge() + self._buffer.close() + except Exception: + # issue #633 suggests the purge/close somehow raised a + # BadFileDescriptor error. Perhaps the client ran out of + # memory or something else? It's probably OK to ignore + # any error being raised from purge/close since we're + # removing the reference to the instance below. 
+ pass + self._buffer = None + self._stream = None + + +class PythonParser(BaseParser): + """Plain Python parsing class""" + + __slots__ = BaseParser.__slots__ + ("encoder",) + + def __init__(self, socket_read_size: int): + super().__init__(socket_read_size) + self.encoder: Optional[Encoder] = None + + def on_connect(self, connection: Connection): + """Called when the stream connects""" + self._stream = connection._reader + self._buffer = SocketBuffer( + self._stream, self._read_size, connection.socket_timeout ) - sock = writer.transport.get_extra_info("socket") - if sock is not None: - address = sock.getpeername() + self.encoder = connection.encoder + + def on_disconnect(self): + """Called when the stream disconnects""" + if self._stream is not None: + self._stream = None + if self._buffer is not None: + self._buffer.close() + self._buffer = None + self.encoder = None - conn = cls(reader, writer, encoding=encoding, address=address, parser=parser) + async def can_read(self, timeout: float): + return self._buffer and bool(await self._buffer.can_read(timeout)) - try: - if password is not None: - await conn.auth(password) - if db is not None: - await conn.select(db) - if name is not None: - await conn.setname(name) - except Exception: - conn.close() - await conn.wait_closed() - raise - return conn + async def read_response(self) -> Union[EncodableT, ResponseError, None]: + if not self._buffer: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + raw = await self._buffer.readline() + if not raw: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + response: Any + byte, response = raw[:1], raw[1:] + if byte not in (b"-", b"+", b":", b"$", b"*"): + raise InvalidResponse(f"Protocol Error: {raw!r}") -class RedisConnection(AbcConnection): - """Redis connection.""" + # server returned an error + if byte == b"-": + response = response.decode("utf-8", errors="replace") + error = self.parse_error(response) + # if the error is a ConnectionError, raise immediately so the user + # is notified + if isinstance(error, ConnectionError): + raise error + # otherwise, we're dealing with a ResponseError that might belong + # inside a pipeline response. the connection's read_response() + # and/or the pipeline's execute() will raise this error if + # necessary, so just return the exception instance here. + return error + # single value + elif byte == b"+": + pass + # int value + elif byte == b":": + response = int(response) + # bulk response + elif byte == b"$": + length = int(response) + if length == -1: + return None + response = await self._buffer.read(length) + # multi-bulk response + elif byte == b"*": + length = int(response) + if length == -1: + return None + response = [(await self.read_response()) for i in range(length)] + if isinstance(response, bytes): + response = self.encoder.decode(response) + return response + + +class HiredisParser(BaseParser): + """Parser class for connections using Hiredis""" + + __slots__ = BaseParser.__slots__ + ("_next_response", "_reader", "_socket_timeout") + + def __init__(self, socket_read_size: int): + if not HIREDIS_AVAILABLE: + raise RedisError("Hiredis is not available.") + super().__init__(socket_read_size=socket_read_size) + self._next_response = ... 
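
The hiredis primitive this parser wraps, as a minimal sketch (assumes the
hiredis package is installed; `gets()` returns False until a complete reply
has been fed):

    import hiredis

    reader = hiredis.Reader()
    reader.feed(b"+OK\r\n")
    reader.gets()  # -> b'OK'
    reader.gets()  # -> False, nothing buffered yet
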
+ self._reader: Optional[hiredis.Reader] = None + self._socket_timeout: Optional[float] = None + + def on_connect(self, connection: Connection): + self._stream = connection._reader + kwargs = { + "protocolError": InvalidResponse, + "replyError": self.parse_error, + } + if connection.encoder.decode_responses: + kwargs.update( + encoding=connection.encoder.encoding, + errors=connection.encoder.encoding_errors, + ) + + self._reader = hiredis.Reader(**kwargs) + self._next_response = False + self._socket_timeout = connection.socket_timeout + + def on_disconnect(self): + self._stream = None + self._reader = None + self._next_response = False + + async def can_read(self, timeout: float): + if not self._reader: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) + + if self._next_response is False: + self._next_response = self._reader.gets() + if self._next_response is False: + return await self.read_from_socket(timeout=timeout, raise_on_timeout=False) + return True + + async def read_from_socket( + self, timeout: Optional[float] = SENTINEL, raise_on_timeout: bool = True + ): + timeout = self._socket_timeout if timeout is SENTINEL else timeout + try: + async with async_timeout.timeout(timeout): + buffer = await self._stream.read(self._read_size) + if not isinstance(buffer, bytes) or len(buffer) == 0: + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + self._reader.feed(buffer) + # data was read from the socket and added to the buffer. + # return True to indicate that data was read. + return True + except asyncio.CancelledError: + raise + except (socket.timeout, asyncio.TimeoutError): + if raise_on_timeout: + raise TimeoutError("Timeout reading from socket") from None + return False + except NONBLOCKING_EXCEPTIONS as ex: + # if we're in nonblocking mode and the recv raises a + # blocking error, simply return False indicating that + # there's no data to be read. otherwise raise the + # original exception. + allowed = NONBLOCKING_EXCEPTION_ERROR_NUMBERS.get(ex.__class__, -1) + if not raise_on_timeout and ex.errno == allowed: + return False + raise ConnectionError(f"Error while reading from socket: {ex.args}") + + async def read_response(self) -> EncodableT: + if not self._stream or not self._reader: + self.on_disconnect() + raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None + + # _next_response might be cached from a can_read() call + if self._next_response is not False: + response = self._next_response + self._next_response = False + return response + + response = self._reader.gets() + while response is False: + await self.read_from_socket() + response = self._reader.gets() + + # if the response is a ConnectionError or the response is a list and + # the first item is a ConnectionError, raise it as something bad + # happened + if isinstance(response, ConnectionError): + raise response + elif ( + isinstance(response, list) + and response + and isinstance(response[0], ConnectionError) + ): + raise response[0] + return response + + +DefaultParser: Type[Union[PythonParser, HiredisParser]] +if HIREDIS_AVAILABLE: + DefaultParser = HiredisParser +else: + DefaultParser = PythonParser + + +class ConnectCallbackProtocol(Protocol): + def __call__(self, connection: Connection): + ... + + +class AsyncConnectCallbackProtocol(Protocol): + async def __call__(self, connection: Connection): + ... 
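
A minimal sketch of callbacks satisfying these protocols (names are
illustrative; registration goes through Connection.register_connect_callback,
defined below, and awaitable callbacks are awaited by Connection.connect()):

    def on_connect_sync(connection: "Connection") -> None:
        # runs synchronously after a (re)connect
        print("connected:", connection)

    async def on_connect_async(connection: "Connection") -> None:
        # awaited because connect() checks inspect.isawaitable()
        await connection.send_command("PING")
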
+ + +ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol] + + +class Connection: + """Manages TCP communication to and from a Redis server""" + + __slots__ = ( + "pid", + "host", + "port", + "db", + "username", + "client_name", + "password", + "socket_timeout", + "socket_connect_timeout", + "socket_keepalive", + "socket_keepalive_options", + "socket_type", + "retry_on_timeout", + "health_check_interval", + "next_health_check", + "last_active_at", + "encoder", + "ssl_context", + "_reader", + "_writer", + "_parser", + "_connect_callbacks", + "_buffer_cutoff", + "_loop", + "__dict__", + ) def __init__( self, - reader, - writer, *, - address, - encoding=None, - parser=None, - loop=None, - name=None, + host: str = "localhost", + port: Union[str, int] = 6379, + db: Union[str, int] = 0, + password: str = None, + socket_timeout: float = None, + socket_connect_timeout: float = None, + socket_keepalive: bool = False, + socket_keepalive_options: dict = None, + socket_type: int = 0, + retry_on_timeout: bool = False, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + parser_class: Type[BaseParser] = DefaultParser, + socket_read_size: int = 65536, + health_check_interval: int = 0, + client_name: str = None, + username: str = None, + encoder_class: Type[Encoder] = Encoder, + loop: asyncio.AbstractEventLoop = None, ): - if loop is not None and sys.version_info >= (3, 8): - warnings.warn("The loop argument is deprecated", DeprecationWarning) - if parser is None: - parser = Reader - assert callable(parser), ("Parser argument is not callable", parser) - self._reader = reader - self._writer = writer - self._address = address - self._waiters = deque() - self._reader.set_parser( - parser(protocolError=ProtocolError, replyError=ReplyError) + self.pid = os.getpid() + self.host = host + self.port = int(port) + self.db = db + self.username = username + self.client_name = client_name + self.password = password + self.socket_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None + self.socket_keepalive = socket_keepalive + self.socket_keepalive_options = socket_keepalive_options or {} + self.socket_type = socket_type + self.retry_on_timeout = retry_on_timeout + self.health_check_interval = health_check_interval + self.next_health_check = 0 + self.ssl_context: Optional[RedisSSLContext] = None + self.encoder = encoder_class(encoding, encoding_errors, decode_responses) + self._reader: Optional[asyncio.StreamReader] = None + self._writer: Optional[asyncio.StreamWriter] = None + self._parser = parser_class( + socket_read_size=socket_read_size, ) - self._reader_task = asyncio.ensure_future(self._read_data()) - self._close_msg = None - self._db = 0 - self._closing = False - self._closed = False - self._close_state = asyncio.Event() - self._reader_task.add_done_callback(lambda x: self._close_state.set()) - self._in_transaction = None - self._transaction_error = None # XXX: never used? 
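
A construction sketch for the new Connection class whose signature appears
above (illustrative values; the handshake helpers are defined later in this
file):

    conn = Connection(host="localhost", port=6379, db=0, decode_responses=True)
    # await conn.connect()
    # await conn.send_command("PING")
    # await conn.read_response()  # -> 'PONG'
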
- self._in_pubsub = 0 - self._pubsub_channels = coerced_keys_dict() - self._pubsub_patterns = coerced_keys_dict() - self._encoding = encoding - self._pipeline_buffer = None - self._name = name + self._connect_callbacks: List[ConnectCallbackT] = [] + self._buffer_cutoff = 6000 + self._loop = loop def __repr__(self): - return f"" + repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces())) + return f"{self.__class__.__name__}<{repr_args}>" - async def _read_data(self): - """Response reader task.""" - last_error = ConnectionClosedError("Connection has been closed by server") - while not self._reader.at_eof(): - try: - obj = await self._reader.readobj() - except asyncio.CancelledError: - # NOTE: reader can get cancelled from `close()` method only. - last_error = RuntimeError("this is unexpected") - break - except ProtocolError as exc: - # ProtocolError is fatal - # so connection must be closed - if self._in_transaction is not None: - self._transaction_error = exc - last_error = exc - break - except Exception as exc: - # NOTE: for QUIT command connection error can be received - # before response - last_error = exc - break - else: - if (obj == b"" or obj is None) and self._reader.at_eof(): - logger.debug( - "Connection has been closed by server," " response: %r", obj - ) - last_error = ConnectionClosedError("Reader at end of file") - break - - if isinstance(obj, MaxClientsError): - last_error = obj - break - if self._in_pubsub: - self._process_pubsub(obj) + def repr_pieces(self): + pieces = [("host", self.host), ("port", self.port), ("db", self.db)] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + def __del__(self): + try: + if self.is_connected: + loop = self._loop or asyncio.get_event_loop() + coro = self.disconnect() + if loop.is_running(): + loop.create_task(coro) else: - self._process_data(obj) - self._closing = True - get_event_loop().call_soon(self._do_close, last_error) - - def _process_data(self, obj): - """Processes command results.""" - assert len(self._waiters) > 0, (type(obj), obj) - waiter, encoding, cb = self._waiters.popleft() - if isinstance(obj, RedisError): - if isinstance(obj, ReplyError): - if obj.args[0].startswith("READONLY"): - obj = ReadOnlyError(obj.args[0]) - _set_exception(waiter, obj) - if self._in_transaction is not None: - self._transaction_error = obj - else: - if encoding is not None: - try: - obj = decode(obj, encoding) - except Exception as exc: - _set_exception(waiter, exc) - return - if cb is not None: - try: - obj = cb(obj) - except Exception as exc: - _set_exception(waiter, exc) - return - _set_result(waiter, obj) - if self._in_transaction is not None: - self._in_transaction.append((encoding, cb)) - - def _process_pubsub(self, obj, *, process_waiters=True): - """Processes pubsub messages.""" - kind, *args, data = obj - if kind in (b"subscribe", b"unsubscribe"): - (chan,) = args - if process_waiters and self._in_pubsub and self._waiters: - self._process_data(obj) - if kind == b"unsubscribe": - ch = self._pubsub_channels.pop(chan, None) - if ch: - ch.close() - self._in_pubsub = data - elif kind in (b"psubscribe", b"punsubscribe"): - (chan,) = args - if process_waiters and self._in_pubsub and self._waiters: - self._process_data(obj) - if kind == b"punsubscribe": - ch = self._pubsub_patterns.pop(chan, None) - if ch: - ch.close() - self._in_pubsub = data - elif kind == b"message": - (chan,) = args - self._pubsub_channels[chan].put_nowait(data) - elif kind == b"pmessage": - pattern, chan = args - 
self._pubsub_patterns[pattern].put_nowait((chan, data)) - elif kind == b"pong": - if process_waiters and self._in_pubsub and self._waiters: - self._process_data(data or b"PONG") - else: - logger.warning("Unknown pubsub message received %r", obj) - - @contextmanager - def _buffered(self): - # XXX: we must ensure that no await happens - # as long as we buffer commands. - # Probably we can set some error-raising callback on enter - # and remove it on exit - # if some await happens in between -> throw an error. - # This is creepy solution, 'cause some one might want to await - # on some other source except redis. - # So we must only raise error we someone tries to await - # pending aioredis future - # One of solutions is to return coroutine instead of a future - # in `execute` method. - # In a coroutine we can check if buffering is enabled and raise error. - - # TODO: describe in docs difference in pipeline mode for - # conn.execute vs pipeline.execute() - if self._pipeline_buffer is None: - self._pipeline_buffer = bytearray() + loop.run_until_complete(self.disconnect()) + except Exception: + pass + + @property + def is_connected(self): + return bool(self._reader and self._writer) + + def register_connect_callback(self, callback): + self._connect_callbacks.append(callback) + + def clear_connect_callbacks(self): + self._connect_callbacks = [] + + async def connect(self): + """Connects to the Redis server if not already connected""" + if self.is_connected: + return + try: + await self._connect() + except asyncio.CancelledError: + raise + except (socket.timeout, asyncio.TimeoutError): + raise TimeoutError("Timeout connecting to server") + except OSError as e: + raise ConnectionError(self._error_message(e)) + except Exception as exc: + raise ConnectionError(exc) from exc + + try: + await self.on_connect() + except RedisError: + # clean up after any error in on_connect + await self.disconnect() + raise + + # run any user callbacks. right now the only internal callback + # is for pubsub channel/pattern resubscription + for callback in self._connect_callbacks: + task = callback(self) + if task and inspect.isawaitable(task): + await task + + async def _connect(self): + """Create a TCP socket connection""" + with async_timeout.timeout(self.socket_connect_timeout): + reader, writer = await asyncio.open_connection( + host=self.host, port=self.port, ssl=self.ssl_context, loop=self._loop + ) + self._reader = reader + self._writer = writer + sock = writer.transport.get_extra_info("socket") + if sock is not None: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) try: - yield self - buf = self._pipeline_buffer - self._writer.write(buf) - finally: - self._pipeline_buffer = None - else: - yield self - - def execute(self, command, *args, encoding=_NOTSET): - """Executes redis command and returns Future waiting for the answer. - - Raises: - * TypeError if any of args can not be encoded as bytes. - * ReplyError on redis '-ERR' responses. - * ProtocolError when response can not be decoded meaning connection - is broken. - * ConnectionClosedError when either client or server has closed the - connection. 
- """ - if self._reader is None or self._reader.at_eof(): - msg = self._close_msg or "Connection closed or corrupted" - raise ConnectionClosedError(msg) - if command is None: - raise TypeError("command must not be None") - if None in args: - raise TypeError("args must not contain None") - command = command.upper().strip() - is_pubsub = command in _PUBSUB_COMMANDS - is_ping = command in ("PING", b"PING") - if self._in_pubsub and not (is_pubsub or is_ping): - raise RedisError("Connection in SUBSCRIBE mode") - elif is_pubsub: - logger.warning("Deprecated. Use `execute_pubsub` method directly") - return self.execute_pubsub(command, *args) - - if command in ("SELECT", b"SELECT"): - cb = partial(self._set_db, args=args) - elif command in ("MULTI", b"MULTI"): - cb = self._start_transaction - elif command in ("EXEC", b"EXEC"): - cb = partial(self._end_transaction, discard=False) - encoding = None - elif command in ("DISCARD", b"DISCARD"): - cb = partial(self._end_transaction, discard=True) - else: - cb = None - if encoding is _NOTSET: - encoding = self._encoding - fut = get_event_loop().create_future() - if self._pipeline_buffer is None: - self._writer.write(encode_command(command, *args)) - else: - encode_command(command, *args, buf=self._pipeline_buffer) - self._waiters.append((fut, encoding, cb)) - return fut + # TCP_KEEPALIVE + if self.socket_keepalive: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + for k, v in self.socket_keepalive_options.items(): + sock.setsockopt(socket.SOL_TCP, k, v) + # set the socket_timeout now that we're connected + if self.socket_timeout is not None: + sock.settimeout(self.socket_timeout) - def execute_pubsub(self, command, *channels): - """Executes redis (p)subscribe/(p)unsubscribe commands. + except (OSError, TypeError): + # `socket_keepalive_options` might contain invalid options + # causing an error. Do not leave the connection open. + writer.close() + raise - Returns asyncio.gather coroutine waiting for all channels/patterns - to receive answers. - """ - command = command.upper().strip() - assert command in _PUBSUB_COMMANDS, ("Pub/Sub command expected", command) - if self._reader is None or self._reader.at_eof(): - raise ConnectionClosedError("Connection closed or corrupted") - if None in set(channels): - raise TypeError("args must not contain None") - if not len(channels): - raise TypeError("No channels/patterns supplied") - is_pattern = len(command) in (10, 12) - mkchannel = partial(Channel, is_pattern=is_pattern) - channels = [ - ch if isinstance(ch, AbcChannel) else mkchannel(ch) for ch in channels - ] - if not all(ch.is_pattern == is_pattern for ch in channels): - raise ValueError(f"Not all channels {channels} match command {command}") - cmd = encode_command(command, *(ch.name for ch in channels)) - res = [] - for ch in channels: - fut = get_event_loop().create_future() - res.append(fut) - cb = partial(self._update_pubsub, ch=ch) - self._waiters.append((fut, None, cb)) - if self._pipeline_buffer is None: - self._writer.write(cmd) + def _error_message(self, exception): + # args for socket.error can either be (errno, "message") + # or just "message" + if len(exception.args) == 1: + return f"Error connecting to {self.host}:{self.port}. {exception.args[0]}." else: - self._pipeline_buffer.extend(cmd) - return asyncio.gather(*res) + return ( + f"Error {exception.args[0]} connecting to {self.host}:{self.port}. " + f"{exception.args[0]}." 
+ ) - def close(self): - """Close connection.""" - self._do_close(ConnectionForcedCloseError()) + async def on_connect(self): + """Initialize the connection, authenticate and select a database""" + self._parser.on_connect(self) - def _do_close(self, exc): - if self._closed: + # if username and/or password are set, authenticate + if self.username or self.password: + if self.username: + auth_args = (self.username, self.password or "") + else: + auth_args = (self.password,) + # avoid checking health here -- PING will fail if we try + # to check the health prior to the AUTH + await self.send_command("AUTH", *auth_args, check_health=False) + + try: + auth_response = await self.read_response() + except AuthenticationWrongNumberOfArgsError: + # a username and password were specified but the Redis + # server seems to be < 6.0.0 which expects a single password + # arg. retry auth with just the password. + # https://github.com/andymccurdy/redis-py/issues/1274 + await self.send_command("AUTH", self.password, check_health=False) + auth_response = await self.read_response() + + if str_if_bytes(auth_response) != "OK": + raise AuthenticationError("Invalid Username or Password") + + # if a client_name is given, set it + if self.client_name: + await self.send_command("CLIENT", "SETNAME", self.client_name) + if str_if_bytes(self.read_response()) != "OK": + raise ConnectionError("Error setting client name") + + # if a database is specified, switch to it + if self.db: + await self.send_command("SELECT", self.db) + if str_if_bytes(await self.read_response()) != "OK": + raise ConnectionError("Invalid Database") + + async def disconnect(self): + """Disconnects from the Redis server""" + self._parser.on_disconnect() + if not self.is_connected: return - self._closed = True - self._closing = False - self._writer.transport.close() - self._reader_task.cancel() - self._reader_task = None - self._writer = None + try: + if os.getpid() == self.pid: + self._writer.close() + await self._writer.wait_closed() + except OSError: + pass self._reader = None - self._pipeline_buffer = None + self._writer = None - if exc is not None: - self._close_msg = str(exc) + async def check_health(self): + """Check the health of the connection with a PING/PONG""" + if self.health_check_interval and time.time() > self.next_health_check: + try: + await self.send_command("PING", check_health=False) + if str_if_bytes(await self.read_response()) != "PONG": + raise ConnectionError("Bad response from PING health check") + except (ConnectionError, TimeoutError) as err: + await self.disconnect() + try: + await self.send_command("PING", check_health=False) + if str_if_bytes(await self.read_response()) != "PONG": + raise ConnectionError( + "Bad response from PING health check" + ) from None + except BaseException as err2: + raise err2 from err - while self._waiters: - waiter, *spam = self._waiters.popleft() - logger.debug("Cancelling waiter %r", (waiter, spam)) - if exc is None: - _set_exception(waiter, ConnectionForcedCloseError()) + async def send_packed_command( + self, + command: Union[bytes, str, Iterable[Union[bytes, str]]], + check_health: bool = True, + ): + """Send an already packed command to the Redis server""" + if not self._writer: + await self.connect() + # guard against health check recursion + if check_health: + await self.check_health() + try: + if isinstance(command, str): + command = command.encode() + if isinstance(command, bytes): + command = [command] + self._writer.writelines(command) + await self._writer.drain() + except 
asyncio.TimeoutError: + await self.disconnect() + raise TimeoutError("Timeout writing to socket") from None + except OSError as e: + await self.disconnect() + if len(e.args) == 1: + errno, errmsg = "UNKNOWN", e.args[0] else: - _set_exception(waiter, exc) - while self._pubsub_channels: - _, ch = self._pubsub_channels.popitem() - logger.debug("Closing pubsub channel %r", ch) - ch.close(exc) - while self._pubsub_patterns: - _, ch = self._pubsub_patterns.popitem() - logger.debug("Closing pubsub pattern %r", ch) - ch.close(exc) + errno = e.args[0] + errmsg = e.args[1] + raise ConnectionError( + f"Error {errno} while writing to socket. {errmsg}." + ) from e + except BaseException: + await self.disconnect() + raise + + async def send_command(self, *args, **kwargs): + """Pack and send a command to the Redis server""" + if not self.is_connected: + await self.connect() + await self.send_packed_command( + self.pack_command(*args), check_health=kwargs.get("check_health", True) + ) + + async def can_read(self, timeout: float = 0): + """Poll the socket to see if there's data that can be read.""" + if not self.is_connected: + await self.connect() + return await self._parser.can_read(timeout) + + async def read_response(self): + """Read the response from a previously sent command""" + try: + with async_timeout.timeout(self.socket_timeout): + response = await self._parser.read_response() + except (asyncio.TimeoutError, asyncio.CancelledError): + await self.disconnect() + raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") + except BaseException: + await self.disconnect() + raise + + if self.health_check_interval: + self.next_health_check = time.time() + self.health_check_interval + + if isinstance(response, ResponseError): + raise response from None + return response + + def pack_command(self, *args: EncodableT) -> List[bytes]: + """Pack a series of arguments into the Redis protocol""" + output = [] + # the client might have included 1 or more literal arguments in + # the command name, e.g., 'CONFIG GET'. The Redis server expects these + # arguments to be sent separately, so split the first argument + # manually. These arguments should be bytestrings so that they are + # not encoded. 
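
A sketch of the framing this method produces (RESP arrays of bulk strings;
small arguments are coalesced into a single buffer, and a multi-word command
name is split into separate arguments as described above):

    conn.pack_command("SET", "foo", "bar")
    # -> [b'*3\r\n$3\r\nSET\r\n$3\r\nfoo\r\n$3\r\nbar\r\n']
    conn.pack_command("CONFIG GET", "maxmemory")
    # 'CONFIG GET' becomes two arguments before encoding
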
+ if isinstance(args[0], str): + args = tuple(args[0].encode().split()) + args[1:] + elif b" " in args[0]: + args = tuple(args[0].split()) + args[1:] + + buff = SYM_EMPTY.join((SYM_STAR, str(len(args)).encode(), SYM_CRLF)) + + buffer_cutoff = self._buffer_cutoff + for arg in map(self.encoder.encode, args): + # to avoid large string mallocs, chunk the command into the + # output list if we're sending large values or memoryviews + arg_length = len(arg) + if ( + len(buff) > buffer_cutoff + or arg_length > buffer_cutoff + or isinstance(arg, memoryview) + ): + buff = SYM_EMPTY.join( + (buff, SYM_DOLLAR, str(arg_length).encode(), SYM_CRLF) + ) + output.append(buff) + output.append(arg) + buff = SYM_CRLF + else: + buff = SYM_EMPTY.join( + ( + buff, + SYM_DOLLAR, + str(arg_length).encode(), + SYM_CRLF, + arg, + SYM_CRLF, + ) + ) + output.append(buff) + return output + + def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes]: + """Pack multiple commands into the Redis protocol""" + output: List[bytes] = [] + pieces: List[bytes] = [] + buffer_length = 0 + buffer_cutoff = self._buffer_cutoff + + for cmd in commands: + for chunk in self.pack_command(*cmd): + chunklen = len(chunk) + if ( + buffer_length > buffer_cutoff + or chunklen > buffer_cutoff + or isinstance(chunk, memoryview) + ): + output.append(SYM_EMPTY.join(pieces)) + buffer_length = 0 + pieces = [] + + if chunklen > buffer_cutoff or isinstance(chunk, memoryview): + output.append(chunk) + else: + pieces.append(chunk) + buffer_length += chunklen + + if pieces: + output.append(SYM_EMPTY.join(pieces)) + return output + + +class SSLConnection(Connection): + def __init__( + self, + ssl_keyfile: str = None, + ssl_certfile: str = None, + ssl_cert_reqs: str = "required", + ssl_ca_certs: str = None, + ssl_check_hostname: bool = False, + **kwargs, + ): + super().__init__(**kwargs) + self.ssl_context = RedisSSLContext( + keyfile=ssl_keyfile, + certfile=ssl_certfile, + cert_reqs=ssl_cert_reqs, + ca_certs=ssl_ca_certs, + check_hostname=ssl_check_hostname, + ) + + @property + def keyfile(self): + return self.ssl_context.keyfile @property - def closed(self): - """True if connection is closed.""" - closed = self._closing or self._closed - if not closed and self._reader and self._reader.at_eof(): - self._closing = closed = True - get_event_loop().call_soon(self._do_close, None) - return closed - - async def wait_closed(self): - """Coroutine waiting until connection is closed.""" - await self._close_state.wait() + def certfile(self): + return self.ssl_context.certfile @property - def db(self): - """Currently selected db index.""" - return self._db + def cert_reqs(self): + return self.ssl_context.cert_reqs @property - def encoding(self): - """Current set codec or None.""" - return self._encoding + def ca_certs(self): + return self.ssl_context.ca_certs @property - def address(self): - """Redis server address, either host-port tuple or str.""" - return self._address - - def select(self, db): - """Change the selected database for the current connection.""" - if not isinstance(db, int): - raise TypeError(f"DB must be of int type, not {db!r}") - if db < 0: - raise ValueError(f"DB must be greater or equal 0, got {db!r}") - fut = self.execute("SELECT", db) - return wait_ok(fut) - - def _set_db(self, ok, args): - assert ok in {b"OK", "OK"}, ("Unexpected result of SELECT", ok) - self._db = args[0] - return ok - - def _start_transaction(self, ok): - assert self._in_transaction is None, ( - "Connection is already in transaction", - 
self._in_transaction, + def check_hostname(self): + return self.ssl_context.check_hostname + + +class RedisSSLContext: + __slots__ = ( + "keyfile", + "certfile", + "cert_reqs", + "ca_certs", + "context", + "check_hostname", + ) + + def __init__( + self, + keyfile: str = None, + certfile: str = None, + cert_reqs: str = None, + ca_certs: str = None, + check_hostname: bool = False, + ): + self.keyfile = keyfile + self.certfile = certfile + if cert_reqs is None: + self.cert_reqs = ssl.CERT_NONE + elif isinstance(cert_reqs, str): + CERT_REQS = { + "none": ssl.CERT_NONE, + "optional": ssl.CERT_OPTIONAL, + "required": ssl.CERT_REQUIRED, + } + if cert_reqs not in CERT_REQS: + raise RedisError( + "Invalid SSL Certificate Requirements Flag: %s" % cert_reqs + ) + self.cert_reqs = CERT_REQS[cert_reqs] + self.ca_certs = ca_certs + self.check_hostname = check_hostname + self.context = None + + def get(self) -> ssl.SSLContext: + if not self.context: + context = ssl.create_default_context() + context.check_hostname = self.check_hostname + context.verify_mode = self.cert_reqs + if self.certfile and self.keyfile: + context.load_cert_chain(certfile=self.certfile, keyfile=self.keyfile) + if self.ca_certs: + context.load_verify_locations(self.ca_certs) + self.context = context + return self.context + + +class UnixDomainSocketConnection(Connection): + def __init__( + self, + *, + path: str = "", + db: Union[str, int] = 0, + username: str = None, + password: str = None, + socket_timeout: float = None, + socket_connect_timeout: float = None, + encoding: str = "utf-8", + encoding_errors: str = "strict", + decode_responses: bool = False, + retry_on_timeout: bool = False, + parser_class: Type[BaseParser] = DefaultParser, + socket_read_size: int = 65536, + health_check_interval: float = 0.0, + client_name=None, + loop: asyncio.AbstractEventLoop = None, + ): + self.pid = os.getpid() + self.path = path + self.db = db + self.username = username + self.client_name = client_name + self.password = password + self.socket_timeout = socket_timeout + # _connect() below reads this; fall back to socket_timeout when unset + self._connect_timeout = socket_connect_timeout or socket_timeout + self.retry_on_timeout = retry_on_timeout + self.health_check_interval = health_check_interval + self.next_health_check = 0 + self.encoder = Encoder(encoding, encoding_errors, decode_responses) + self._sock = None + self._parser = parser_class(socket_read_size=socket_read_size) + self._connect_callbacks = [] + self._buffer_cutoff = 6000 + self._loop = loop + + def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]: + pieces = [ + ("path", self.path), + ("db", self.db), + ] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + async def _connect(self): + with async_timeout.timeout(self._connect_timeout): + reader, writer = await asyncio.open_unix_connection(path=self.path) + self._reader = reader + self._writer = writer + await self.on_connect() + + def _error_message(self, exception): + # args for socket.error can either be (errno, "message") + # or just "message" + if len(exception.args) == 1: + return f"Error connecting to unix socket: {self.path}. {exception.args[0]}." + else: + return ( + f"Error {exception.args[0]} connecting to unix socket: " + f"{self.path}. {exception.args[1]}."
+ ) + + +FALSE_STRINGS = ("0", "F", "FALSE", "N", "NO") + + +def to_bool(value) -> Optional[bool]: + if value is None or value == "": + return None + if isinstance(value, str) and value.upper() in FALSE_STRINGS: + return False + return bool(value) + + +URL_QUERY_ARGUMENT_PARSERS = { + "db": int, + "socket_timeout": float, + "socket_connect_timeout": float, + "socket_keepalive": to_bool, + "retry_on_timeout": to_bool, + "max_connections": int, + "health_check_interval": int, + "ssl_check_hostname": to_bool, +} + + +class ConnectKwargs(TypedDict, total=False): + username: str + password: str + connection_class: Type[Connection] + host: str + port: int + db: int + + +def parse_url(url: str) -> ConnectKwargs: + parsed: ParseResult = urlparse(url) + kwargs: ConnectKwargs = {} + + for name, value in parse_qs(parsed.query).items(): + if value and len(value) > 0: + value = unquote(value[0]) + parser = URL_QUERY_ARGUMENT_PARSERS.get(name) + if parser: + try: + kwargs[name] = parser(value) + except (TypeError, ValueError): + raise ValueError("Invalid value for `%s` in connection URL." % name) + else: + kwargs[name] = value + + if parsed.username: + kwargs["username"] = unquote(parsed.username) + if parsed.password: + kwargs["password"] = unquote(parsed.password) + + # We only support redis://, rediss:// and unix:// schemes. + if parsed.scheme == "unix": + if parsed.path: + kwargs["path"] = unquote(parsed.path) + kwargs["connection_class"] = UnixDomainSocketConnection + + elif parsed.scheme in ("redis", "rediss"): + if parsed.hostname: + kwargs["host"] = unquote(parsed.hostname) + if parsed.port: + kwargs["port"] = int(parsed.port) + + # If there's a path argument, use it as the db argument if a + # querystring value wasn't specified + if parsed.path and "db" not in kwargs: + try: + kwargs["db"] = int(unquote(parsed.path).replace("/", "")) + except (AttributeError, ValueError): + pass + + if parsed.scheme == "rediss": + kwargs["connection_class"] = SSLConnection + else: + valid_schemes = "redis://, rediss://, unix://" + raise ValueError( + "Redis URL must specify one of the following " + "schemes (%s)" % valid_schemes ) - self._in_transaction = deque() - self._transaction_error = None - return ok - - def _end_transaction(self, obj, discard): - assert self._in_transaction is not None, ( - "Connection is not in transaction", - obj, + + return kwargs + + +_CP = TypeVar("_CP") + + +class ConnectionPool: + """ + Create a connection pool. If ``max_connections`` is set, then this + object raises :py:class:`~redis.ConnectionError` when the pool's + limit is reached. + + By default, TCP connections are created unless ``connection_class`` + is specified. Use :py:class:`~redis.UnixDomainSocketConnection` for + unix sockets. + + Any additional keyword arguments are passed to the constructor of + ``connection_class``. + """ + + @classmethod + def from_url(cls: Type[_CP], url: str, **kwargs) -> _CP: + """ + Return a connection pool configured from the given URL. + + For example:: + + redis://[[username]:[password]]@localhost:6379/0 + rediss://[[username]:[password]]@localhost:6379/0 + unix://[[username]:[password]]@/path/to/socket.sock?db=0 + + Three URL schemes are supported: + + - ``redis://`` creates a TCP socket connection. See more at: + <https://www.iana.org/assignments/uri-schemes/prov/redis> + - ``rediss://`` creates an SSL wrapped TCP socket connection. See more at: + <https://www.iana.org/assignments/uri-schemes/prov/rediss> + - ``unix://``: creates a Unix Domain Socket connection.
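A rough illustration of what ``parse_url`` returns for each of the three schemes, assuming the query parsers defined above (expected results shown as comments)::

    from aioredis.connection import parse_url

    parse_url("redis://user:secret@localhost:6379/2")
    # {'username': 'user', 'password': 'secret', 'host': 'localhost',
    #  'port': 6379, 'db': 2}

    parse_url("rediss://example.com?socket_timeout=1.5&retry_on_timeout=yes")
    # {'socket_timeout': 1.5, 'retry_on_timeout': True, 'host': 'example.com',
    #  'connection_class': <class 'aioredis.connection.SSLConnection'>}

    parse_url("unix:///tmp/redis.sock?db=3")
    # {'db': 3, 'path': '/tmp/redis.sock',
    #  'connection_class': <class 'aioredis.connection.UnixDomainSocketConnection'>}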
+ + The username, password, hostname, path and all querystring values + are passed through urllib.parse.unquote in order to replace any + percent-encoded values with their corresponding characters. + + There are several ways to specify a database number. The first value + found will be used: + 1. A ``db`` querystring option, e.g. redis://localhost?db=0 + 2. If using the redis:// or rediss:// schemes, the path argument + of the url, e.g. redis://localhost/0 + 3. A ``db`` keyword argument to this function. + + If none of these options are specified, the default db=0 is used. + + All querystring options are cast to their appropriate Python types. + Boolean arguments can be specified with string values "True"/"False" + or "Yes"/"No". Values that cannot be properly cast cause a + ``ValueError`` to be raised. Once parsed, the querystring arguments + and keyword arguments are passed to the ``ConnectionPool``'s + class initializer. In the case of conflicting arguments, querystring + arguments always win. + """ + url_options = parse_url(url) + kwargs.update(url_options) + return cls(**kwargs) + + def __init__( + self, + connection_class: Type[Connection] = Connection, + max_connections: int = None, + **connection_kwargs, + ): + max_connections = max_connections or 2 ** 31 + if not isinstance(max_connections, int) or max_connections < 0: + raise ValueError('"max_connections" must be a positive integer') + + self.connection_class = connection_class + self.connection_kwargs = connection_kwargs + self.max_connections = max_connections + + # a lock to protect the critical section in _checkpid(). + # this lock is acquired when the process id changes, such as + # after a fork. during this time, multiple threads in the child + # process could attempt to acquire this lock. the first thread + # to acquire the lock will reset the data structures and lock + # object of this pool. subsequent threads acquiring this lock + # will notice the first thread already did the work and simply + # release the lock. + self._fork_lock = threading.Lock() + self._lock: asyncio.Lock + self._created_connections: int + self._available_connections: List[Connection] + self._in_use_connections: Set[Connection] + self.reset() + self.loop = self.connection_kwargs.get("loop") + self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder) + + def __repr__(self): + return ( + f"{self.__class__.__name__}" + f"<{self.connection_class(**self.connection_kwargs)!r}>" ) - self._transaction_error = None - recall, self._in_transaction = self._in_transaction, None - recall.popleft() # ignore first (its _start_transaction) - if discard: - return obj - assert isinstance(obj, list) or (obj is None and not discard), ( - "Unexpected MULTI/EXEC result", - obj, - recall, + + def reset(self): + self._lock = asyncio.Lock() + self._created_connections = 0 + self._available_connections = [] + self._in_use_connections = set() + + # this must be the last operation in this method. while reset() is + # called when holding _fork_lock, other threads in this process + # can call _checkpid() which compares self.pid and os.getpid() without + # holding any lock (for performance reasons). keeping this assignment + # as the last operation ensures that those other threads will also + # notice a pid difference and block waiting for the first thread to + # release _fork_lock. 
when each of these threads eventually acquire + # _fork_lock, they will notice that another thread already called + # reset() and they will immediately release _fork_lock and continue on. + self.pid = os.getpid() + + def _checkpid(self): + # _checkpid() attempts to keep ConnectionPool fork-safe on modern + # systems. this is called by all ConnectionPool methods that + # manipulate the pool's state such as get_connection() and release(). + # + # _checkpid() determines whether the process has forked by comparing + # the current process id to the process id saved on the ConnectionPool + # instance. if these values are the same, _checkpid() simply returns. + # + # when the process ids differ, _checkpid() assumes that the process + # has forked and that we're now running in the child process. the child + # process cannot use the parent's file descriptors (e.g., sockets). + # therefore, when _checkpid() sees the process id change, it calls + # reset() in order to reinitialize the child's ConnectionPool. this + # will cause the child to make all new connection objects. + # + # _checkpid() is protected by self._fork_lock to ensure that multiple + # threads in the child process do not call reset() multiple times. + # + # there is an extremely small chance this could fail in the following + # scenario: + # 1. process A calls _checkpid() for the first time and acquires + # self._fork_lock. + # 2. while holding self._fork_lock, process A forks (the fork() + # could happen in a different thread owned by process A) + # 3. process B (the forked child process) inherits the + # ConnectionPool's state from the parent. that state includes + # a locked _fork_lock. process B will not be notified when + # process A releases the _fork_lock and will thus never be + # able to acquire the _fork_lock. + # + # to mitigate this possible deadlock, _checkpid() will only wait 5 + # seconds to acquire _fork_lock. if _fork_lock cannot be acquired in + # that time it is assumed that the child is deadlocked and a + # redis.ChildDeadlockedError error is raised. + if self.pid != os.getpid(): + acquired = self._fork_lock.acquire(timeout=5) + if not acquired: + raise ChildDeadlockedError + # reset() the instance for the new process if another thread + # hasn't already done so + try: + if self.pid != os.getpid(): + self.reset() + finally: + self._fork_lock.release() + + async def get_connection(self, command_name, *keys, **options): + """Get a connection from the pool""" + self._checkpid() + async with self._lock: + try: + connection = self._available_connections.pop() + except IndexError: + connection = self.make_connection() + self._in_use_connections.add(connection) + + try: + # ensure this connection is connected to Redis + await connection.connect() + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the + # pool before all data has been read or the socket has been + # closed. either way, reconnect and verify everything is good. 
+ try: + if await connection.can_read(): + raise ConnectionError("Connection has data") from None + except ConnectionError: + await connection.disconnect() + await connection.connect() + if await connection.can_read(): + raise ConnectionError("Connection not ready") from None + except BaseException: + # release the connection back to the pool so that we don't + # leak it + await self.release(connection) + raise + + return connection + + def get_encoder(self): + """Return an encoder based on encoding settings""" + kwargs = self.connection_kwargs + return self.encoder_class( + encoding=kwargs.get("encoding", "utf-8"), + encoding_errors=kwargs.get("encoding_errors", "strict"), + decode_responses=kwargs.get("decode_responses", False), ) - # TODO: need to be able to re-try transaction - if obj is None: - err = WatchVariableError("WATCH variable has changed") - obj = [err] * len(recall) - assert len(obj) == len(recall), ( - "Wrong number of result items in mutli-exec", - obj, - recall, + + def make_connection(self): + """Create a new connection""" + if self._created_connections >= self.max_connections: + raise ConnectionError("Too many connections") + self._created_connections += 1 + return self.connection_class(**self.connection_kwargs) + + async def release(self, connection: Connection): + """Releases the connection back to the pool""" + self._checkpid() + async with self._lock: + try: + self._in_use_connections.remove(connection) + except KeyError: + # Gracefully fail when a connection is returned to this pool + # that the pool doesn't actually own + pass + + if self.owns_connection(connection): + self._available_connections.append(connection) + else: + # pool doesn't own this connection. do not add it back + # to the pool and decrement the count so that another + # connection can take its place if needed + self._created_connections -= 1 + await connection.disconnect() + return + + def owns_connection(self, connection: Connection): + return connection.pid == self.pid + + async def disconnect(self, inuse_connections: bool = True): + """ + Disconnects connections in the pool + + If ``inuse_connections`` is True, disconnect connections that are + currently in use, potentially by other threads. Otherwise only disconnect + connections that are idle in the pool. + """ + self._checkpid() + async with self._lock: + if inuse_connections: + connections = chain( + self._available_connections, self._in_use_connections + ) + else: + connections = self._available_connections + resp = await asyncio.gather( + *(connection.disconnect() for connection in connections), + return_exceptions=True, + ) + exc = next((r for r in resp if isinstance(r, BaseException)), None) + if exc: + raise exc + + +class BlockingConnectionPool(ConnectionPool): + """ + Thread-safe blocking connection pool:: + + >>> from aioredis.client import Redis + >>> client = Redis(connection_pool=BlockingConnectionPool()) + + It performs the same function as the default + :py:class:`~redis.ConnectionPool` implementation, in that + it maintains a pool of reusable connections that can be shared by + multiple redis clients (safely across threads if required). + + The difference is that, in the event that a client tries to get a + connection from the pool when all of the connections are in use, rather than + raising a :py:class:`~redis.ConnectionError` (as the default + :py:class:`~redis.ConnectionPool` implementation does), it + makes the client wait ("blocks") for a specified number of seconds until + a connection becomes available.
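A sketch of a checkout cycle built on the pieces above; the running server on ``localhost:6379`` is an assumption, and ``PING`` stands in for any command::

    import asyncio
    from aioredis.connection import ConnectionPool
    from aioredis.exceptions import ConnectionError  # aioredis' subclass of the builtin

    async def main():
        # URL options override keyword arguments, so db ends up 0 here.
        pool = ConnectionPool.from_url(
            "redis://localhost:6379/0", db=1, max_connections=1
        )
        conn = await pool.get_connection("PING")
        await conn.send_command("PING")
        print(await conn.read_response())      # b'PONG'
        try:
            await pool.get_connection("PING")  # second checkout exceeds the limit
        except ConnectionError as exc:
            print(exc)                         # Too many connections
        await pool.release(conn)
        await pool.disconnect()

    asyncio.run(main())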
+ + Use ``max_connections`` to increase / decrease the pool size:: + + >>> pool = BlockingConnectionPool(max_connections=10) + + Use ``timeout`` to tell it either how many seconds to wait for a connection + to become available, or to block forever:: + + >>> # Block forever. + >>> pool = BlockingConnectionPool(timeout=None) + + >>> # Raise a ``ConnectionError`` after five seconds if a connection is + >>> # not available. + >>> pool = BlockingConnectionPool(timeout=5) + """ + + def __init__( + self, + max_connections: int = 50, + timeout: Optional[int] = 20, + connection_class: Type[Connection] = Connection, + queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, + **connection_kwargs, + ): + + self.queue_class = queue_class + self.timeout = timeout + self._connections: List[Connection] + super().__init__( + connection_class=connection_class, + max_connections=max_connections, + **connection_kwargs, ) - res = [] - for o, (encoding, cb) in zip(obj, recall): - if not isinstance(o, RedisError): - try: - if encoding: - o = decode(o, encoding) - if cb: - o = cb(o) - except Exception as err: - res.append(err) - continue - res.append(o) - return res - - def _update_pubsub(self, obj, *, ch): - kind, *pattern, channel, subscriptions = obj - self._in_pubsub, was_in_pubsub = subscriptions, self._in_pubsub - # XXX: the channels/patterns storage should be refactored. - # if code which supposed to read from channel/pattern - # failed (exception in reader or else) than - # the channel object will still reside in memory - # and leak memory (messages will be put in queue). - if kind == b"subscribe" and channel not in self._pubsub_channels: - self._pubsub_channels[channel] = ch - elif kind == b"psubscribe" and channel not in self._pubsub_patterns: - self._pubsub_patterns[channel] = ch - if not was_in_pubsub: - self._process_pubsub(obj, process_waiters=False) - return obj - @property - def in_transaction(self): - """Set to True when MULTI command was issued.""" - return self._in_transaction is not None + def reset(self): + # disconnect() takes self._lock, so this override must recreate + # the lock exactly as ConnectionPool.reset() does + self._lock = asyncio.Lock() + # Create and fill up a thread-safe queue with ``None`` values. + self.pool = self.queue_class(self.max_connections) + while True: + try: + self.pool.put_nowait(None) + except asyncio.QueueFull: + break - @property - def in_pubsub(self): - """Indicates that connection is in PUB/SUB mode. + # Keep a list of actual connection instances so that we can + # disconnect them later. + self._connections = [] + + # this must be the last operation in this method. while reset() is + # called when holding _fork_lock, other threads in this process + # can call _checkpid() which compares self.pid and os.getpid() without + # holding any lock (for performance reasons). keeping this assignment + # as the last operation ensures that those other threads will also + # notice a pid difference and block waiting for the first thread to + # release _fork_lock. when each of these threads eventually acquire + # _fork_lock, they will notice that another thread already called + # reset() and they will immediately release _fork_lock and continue on. + self.pid = os.getpid() - Provides the number of subscribed channels. + def make_connection(self): + """Make a fresh connection.""" + connection = self.connection_class(**self.connection_kwargs) + self._connections.append(connection) + return connection + + async def get_connection(self, command_name, *keys, **options): """ - return self._in_pubsub + Get a connection, blocking for ``self.timeout`` until a connection + is available from the pool.
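The queue discipline that the next docstring paragraph relies on, in miniature: a LIFO queue pre-filled with ``None`` placeholders means released connections are reused before fresh ones are created (stand-alone sketch, no Redis required)::

    import asyncio

    async def demo():
        pool: asyncio.LifoQueue = asyncio.LifoQueue(3)
        for _ in range(3):
            pool.put_nowait(None)        # placeholder: "create on demand"
        assert await pool.get() is None  # nothing live yet -> caller makes one
        pool.put_nowait("conn-1")        # a real connection is released back
        assert await pool.get() == "conn-1"  # LIFO: reused before the Nones

    asyncio.run(demo())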
- @property - def pubsub_channels(self): - """Returns read-only channels dict.""" - return types.MappingProxyType(self._pubsub_channels) + If the connection returned is ``None`` then creates a new connection. + Because we use a last-in first-out queue, the existing connections + (having been returned to the pool after the initial ``None`` values + were added) will be returned before ``None`` values. This means we only + create new connections when we need to, i.e.: the actual number of + connections will only increase in response to demand. + """ + # Make sure we haven't changed process. + self._checkpid() - @property - def pubsub_patterns(self): - """Returns read-only patterns dict.""" - return types.MappingProxyType(self._pubsub_patterns) - - def auth(self, password): - """Authenticate to server.""" - fut = self.execute("AUTH", password) - return wait_ok(fut) - - def setname(self, name): - """Set the current connection name.""" - fut = self.execute(b"CLIENT", b"SETNAME", name) - return wait_ok(fut) + # Try and get a connection from the pool. If one isn't available within + # self.timeout then raise a ``ConnectionError``. + connection = None + try: + async with async_timeout.timeout(self.timeout): + connection = await self.pool.get() + except (asyncio.QueueEmpty, asyncio.TimeoutError): + # Note that this is not caught by the redis client and will be + # raised unless handled by application code. If you want never to + # block, pass ``timeout=0`` so an exhausted pool fails fast instead. + raise ConnectionError("No connection available.") + + # If the ``connection`` is actually ``None`` then that's a cue to make + # a new connection to add to the pool. + if connection is None: + connection = self.make_connection() + + try: + # ensure this connection is connected to Redis + await connection.connect() + # connections that the pool provides should be ready to send + # a command. if not, the connection was either returned to the + # pool before all data has been read or the socket has been + # closed. either way, reconnect and verify everything is good. + try: + if await connection.can_read(): + raise ConnectionError("Connection has data") from None + except ConnectionError: + await connection.disconnect() + await connection.connect() + if await connection.can_read(): + raise ConnectionError("Connection not ready") from None + except BaseException: + # release the connection back to the pool so that we don't leak it + await self.release(connection) + raise + + return connection + + async def release(self, connection: Connection): + """Releases the connection back to the pool.""" + # Make sure we haven't changed process. + self._checkpid() + if not self.owns_connection(connection): + # pool doesn't own this connection. do not add it back + # to the pool. instead add a None value which is a placeholder + # that will cause the pool to recreate the connection if + # it's needed. + await connection.disconnect() + self.pool.put_nowait(None) + return + + # Put the connection back into the pool. + try: + self.pool.put_nowait(connection) + except asyncio.QueueFull: + # perhaps the pool has been reset() after a fork?
regardless, + # we don't want this connection + pass + + async def disconnect(self, inuse_connections: bool = True): + """Disconnects all connections in the pool.""" + self._checkpid() + async with self._lock: + resp = await asyncio.gather( + *(connection.disconnect() for connection in self._connections), + return_exceptions=True, + ) + exc = next((r for r in resp if isinstance(r, BaseException)), None) + if exc: + raise exc diff --git a/aioredis/errors.py b/aioredis/errors.py deleted file mode 100644 index 6a97216dd..000000000 --- a/aioredis/errors.py +++ /dev/null @@ -1,113 +0,0 @@ -from typing import Optional, Sequence - -__all__ = ( - "RedisError", - "ProtocolError", - "ReplyError", - "MaxClientsError", - "AuthError", - "PipelineError", - "MultiExecError", - "WatchVariableError", - "ChannelClosedError", - "ConnectionClosedError", - "ConnectionForcedCloseError", - "PoolClosedError", - "MasterNotFoundError", - "SlaveNotFoundError", - "ReadOnlyError", -) - - -class RedisError(Exception): - """Base exception class for aioredis exceptions.""" - - -class ProtocolError(RedisError): - """Raised when protocol error occurs.""" - - -class ReplyError(RedisError): - """Raised for redis error replies (-ERR).""" - - MATCH_REPLY = None # type: Optional[Sequence[str]] - - def __new__(cls, msg, *args): - for klass in cls.__subclasses__(): - if msg and klass.MATCH_REPLY and msg.startswith(klass.MATCH_REPLY): - return klass(msg, *args) - return super().__new__(cls, msg, *args) - - -class MaxClientsError(ReplyError): - """Raised for redis server when the maximum number of client has been - reached.""" - - MATCH_REPLY = "ERR max number of clients reached" - - -class AuthError(ReplyError): - """Raised when authentication errors occurs.""" - - MATCH_REPLY = ( - "NOAUTH ", - "ERR invalid password", - "ERR Client sent AUTH, but no password is set", - ) - - -class BusyGroupError(ReplyError): - """Raised if Consumer Group name already exists.""" - - MATCH_REPLY = "BUSYGROUP Consumer Group name already exists" - - -class PipelineError(RedisError): - """Raised if command within pipeline raised error.""" - - def __init__(self, errors): - super().__init__(f"{self.__class__.__name__} errors:", errors) - - -class MultiExecError(PipelineError): - """Raised if command within MULTI/EXEC block caused error.""" - - -class WatchVariableError(MultiExecError): - """Raised if watched variable changed (EXEC returns None).""" - - -class ChannelClosedError(RedisError): - """Raised when Pub/Sub channel is unsubscribed and messages queue is empty.""" - - -class ReadOnlyError(RedisError): - """Raised from slave when read-only mode is enabled""" - - -class MasterNotFoundError(RedisError): - """Raised for sentinel master not found error.""" - - -class SlaveNotFoundError(RedisError): - """Raised for sentinel slave not found error.""" - - -class MasterReplyError(RedisError): - """Raised by sentinel client for master error replies.""" - - -class SlaveReplyError(RedisError): - """Raised by sentinel client for slave error replies.""" - - -class ConnectionClosedError(RedisError): - """Raised if connection to server was closed.""" - - -class ConnectionForcedCloseError(ConnectionClosedError): - """Raised if connection was closed with .close() method.""" - - -class PoolClosedError(RedisError): - """Raised if pool is closed.""" diff --git a/aioredis/exceptions.py b/aioredis/exceptions.py new file mode 100644 index 000000000..4aeba5e68 --- /dev/null +++ b/aioredis/exceptions.py @@ -0,0 +1,92 @@ +"""Core exceptions raised by the Redis client""" 
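One property of the module below worth noting: because the new classes multiple-inherit from the stdlib exceptions, pre-existing ``except asyncio.TimeoutError:`` or ``except OSError:`` handlers keep catching aioredis failures after the migration::

    import asyncio
    from aioredis.exceptions import ConnectionError, RedisError, TimeoutError

    assert issubclass(TimeoutError, asyncio.TimeoutError)
    assert issubclass(TimeoutError, RedisError)
    assert issubclass(ConnectionError, OSError)  # via builtins.ConnectionError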
+import asyncio +import builtins + + +class RedisError(Exception): + pass + + +class ConnectionError(builtins.ConnectionError, RedisError): + pass + + +class TimeoutError(asyncio.TimeoutError, builtins.TimeoutError, RedisError): + pass + + +class AuthenticationError(ConnectionError): + pass + + +class BusyLoadingError(ConnectionError): + pass + + +class InvalidResponse(RedisError): + pass + + +class ResponseError(RedisError): + pass + + +class DataError(RedisError): + pass + + +class PubSubError(RedisError): + pass + + +class WatchError(RedisError): + pass + + +class NoScriptError(ResponseError): + pass + + +class ExecAbortError(ResponseError): + pass + + +class ReadOnlyError(ResponseError): + pass + + +class NoPermissionError(ResponseError): + pass + + +class ModuleError(ResponseError): + pass + + +class LockError(RedisError, ValueError): + """Errors acquiring or releasing a lock""" + + # NOTE: For backwards compatibility, this class derives from ValueError. + # This was originally chosen to behave like threading.Lock. + pass + + +class LockNotOwnedError(LockError): + """Error trying to extend or release a lock that is (no longer) owned""" + + pass + + +class ChildDeadlockedError(Exception): + """Error indicating that a child process is deadlocked after a fork()""" + + pass + + +class AuthenticationWrongNumberOfArgsError(ResponseError): + """ + An error to indicate that the wrong number of args + were sent to the AUTH command + """ + + pass diff --git a/aioredis/lock.py b/aioredis/lock.py new file mode 100644 index 000000000..8ee023674 --- /dev/null +++ b/aioredis/lock.py @@ -0,0 +1,310 @@ +from __future__ import annotations + +import asyncio +import threading +import time as mod_time +import uuid +from types import SimpleNamespace +from typing import TYPE_CHECKING, Awaitable, NoReturn, Union + +from aioredis.exceptions import LockError, LockNotOwnedError + +if TYPE_CHECKING: + from aioredis import Redis + + +class Lock: + """ + A shared, distributed Lock. Using Redis for locking allows the Lock + to be shared across processes and/or machines. + + It's left to the user to resolve deadlock issues and make sure + multiple clients play nicely together.
+ """ + + lua_release = None + lua_extend = None + lua_reacquire = None + + # KEYS[1] - lock name + # ARGV[1] - token + # return 1 if the lock was released, otherwise 0 + LUA_RELEASE_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('del', KEYS[1]) + return 1 + """ + + # KEYS[1] - lock name + # ARGV[1] - token + # ARGV[2] - additional milliseconds + # ARGV[3] - "0" if the additional time should be added to the lock's + # existing ttl or "1" if the existing ttl should be replaced + # return 1 if the locks time was extended, otherwise 0 + LUA_EXTEND_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + local expiration = redis.call('pttl', KEYS[1]) + if not expiration then + expiration = 0 + end + if expiration < 0 then + return 0 + end + + local newttl = ARGV[2] + if ARGV[3] == "0" then + newttl = ARGV[2] + expiration + end + redis.call('pexpire', KEYS[1], newttl) + return 1 + """ + + # KEYS[1] - lock name + # ARGV[1] - token + # ARGV[2] - milliseconds + # return 1 if the locks time was reacquired, otherwise 0 + LUA_REACQUIRE_SCRIPT = """ + local token = redis.call('get', KEYS[1]) + if not token or token ~= ARGV[1] then + return 0 + end + redis.call('pexpire', KEYS[1], ARGV[2]) + return 1 + """ + + def __init__( + self, + redis: Redis, + name: str, + timeout: float = None, + sleep: float = 0.1, + blocking: bool = True, + blocking_timeout: float = None, + thread_local: bool = True, + ): + """ + Create a new Lock instance named ``name`` using the Redis client + supplied by ``redis``. + + ``timeout`` indicates a maximum life for the lock. + By default, it will remain locked until release() is called. + ``timeout`` can be specified as a float or integer, both representing + the number of seconds to wait. + + ``sleep`` indicates the amount of time to sleep per loop iteration + when the lock is in blocking mode and another client is currently + holding the lock. + + ``blocking`` indicates whether calling ``acquire`` should block until + the lock has been acquired or to fail immediately, causing ``acquire`` + to return False and the lock not being acquired. Defaults to True. + Note this value can be overridden by passing a ``blocking`` + argument to ``acquire``. + + ``blocking_timeout`` indicates the maximum amount of time in seconds to + spend trying to acquire the lock. A value of ``None`` indicates + continue trying forever. ``blocking_timeout`` can be specified as a + float or integer, both representing the number of seconds to wait. + + ``thread_local`` indicates whether the lock token is placed in + thread-local storage. By default, the token is placed in thread local + storage so that a thread only sees its token, not a token set by + another thread. Consider the following timeline: + + time: 0, thread-1 acquires `my-lock`, with a timeout of 5 seconds. + thread-1 sets the token to "abc" + time: 1, thread-2 blocks trying to acquire `my-lock` using the + Lock instance. + time: 5, thread-1 has not yet completed. redis expires the lock + key. + time: 5, thread-2 acquired `my-lock` now that it's available. + thread-2 sets the token to "xyz" + time: 6, thread-1 finishes its work and calls release(). if the + token is *not* stored in thread local storage, then + thread-1 would see the token value as "xyz" and would be + able to successfully release the thread-2's lock. + + In some use cases it's necessary to disable thread local storage. 
For + example, if you have code where one thread acquires a lock and passes + that lock instance to a worker thread to release later. If thread + local storage isn't disabled in this case, the worker thread won't see + the token set by the thread that acquired the lock. Our assumption + is that these cases aren't common and as such default to using + thread local storage. + """ + self.redis = redis + self.name = name + self.timeout = timeout + self.sleep = sleep + self.blocking = blocking + self.blocking_timeout = blocking_timeout + self.thread_local = bool(thread_local) + self.local = threading.local() if self.thread_local else SimpleNamespace() + self.local.token = None + self.register_scripts() + + def register_scripts(self): + cls = self.__class__ + client = self.redis + if cls.lua_release is None: + cls.lua_release = client.register_script(cls.LUA_RELEASE_SCRIPT) + if cls.lua_extend is None: + cls.lua_extend = client.register_script(cls.LUA_EXTEND_SCRIPT) + if cls.lua_reacquire is None: + cls.lua_reacquire = client.register_script(cls.LUA_REACQUIRE_SCRIPT) + + async def __aenter__(self): + # force blocking, as otherwise the user would have to check whether + # the lock was actually acquired or not. + if await self.acquire(blocking=True): + return self + raise LockError("Unable to acquire lock within the time specified") + + async def __aexit__(self, exc_type, exc_value, traceback): + await self.release() + + async def acquire( + self, + blocking: bool = None, + blocking_timeout: float = None, + token: Union[str, bytes] = None, + ): + """ + Use Redis to hold a shared, distributed lock named ``name``. + Returns True once the lock is acquired. + + If ``blocking`` is False, always return immediately. If the lock + was acquired, return True, otherwise return False. + + ``blocking_timeout`` specifies the maximum number of seconds to + wait trying to acquire the lock. + + ``token`` specifies the token value to be used. If provided, token + must be a bytes object or a string that can be encoded to a bytes + object with the default encoding. If a token isn't specified, a UUID + will be generated. + """ + sleep = self.sleep + if token is None: + token = uuid.uuid1().hex.encode() + else: + encoder = self.redis.connection_pool.get_encoder() + token = encoder.encode(token) + if blocking is None: + blocking = self.blocking + if blocking_timeout is None: + blocking_timeout = self.blocking_timeout + stop_trying_at = None + if blocking_timeout is not None: + stop_trying_at = mod_time.monotonic() + blocking_timeout + while True: + if await self.do_acquire(token): + self.local.token = token + return True + if not blocking: + return False + next_try_at = mod_time.monotonic() + sleep + if stop_trying_at is not None and next_try_at > stop_trying_at: + return False + await asyncio.sleep(sleep) # must yield to the event loop, not block it + + async def do_acquire(self, token: Union[str, bytes]) -> bool: + if self.timeout: + # convert to milliseconds + timeout = int(self.timeout * 1000) + else: + timeout = None + if await self.redis.set(self.name, token, nx=True, px=timeout): + return True + return False + + async def locked(self) -> bool: + """ + Returns True if this key is locked by any process, otherwise False. + """ + return await self.redis.get(self.name) is not None + + async def owned(self) -> bool: + """ + Returns True if this key is locked by this lock, otherwise False.
+ """ + stored_token = await self.redis.get(self.name) + # need to always compare bytes to bytes + # TODO: this can be simplified when the context manager is finished + if stored_token and not isinstance(stored_token, bytes): + encoder = self.redis.connection_pool.get_encoder() + stored_token = encoder.encode(stored_token) + return self.local.token is not None and stored_token == self.local.token + + def release(self) -> Awaitable[NoReturn]: + """Releases the already acquired lock""" + expected_token = self.local.token + if expected_token is None: + raise LockError("Cannot release an unlocked lock") + self.local.token = None + return self.do_release(expected_token) + + async def do_release(self, expected_token: bytes): + if not bool( + await self.lua_release( + keys=[self.name], args=[expected_token], client=self.redis + ) + ): + raise LockNotOwnedError("Cannot release a lock" " that's no longer owned") + + def extend( + self, additional_time: float, replace_ttl: bool = False + ) -> Awaitable[bool]: + """ + Adds more time to an already acquired lock. + + ``additional_time`` can be specified as an integer or a float, both + representing the number of seconds to add. + + ``replace_ttl`` if False (the default), add `additional_time` to + the lock's existing ttl. If True, replace the lock's ttl with + `additional_time`. + """ + if self.local.token is None: + raise LockError("Cannot extend an unlocked lock") + if self.timeout is None: + raise LockError("Cannot extend a lock with no timeout") + return self.do_extend(additional_time, replace_ttl) + + async def do_extend(self, additional_time, replace_ttl) -> bool: + additional_time = int(additional_time * 1000) + if not bool( + await self.lua_extend( + keys=[self.name], + args=[self.local.token, additional_time, replace_ttl and "1" or "0"], + client=self.redis, + ) + ): + raise LockNotOwnedError("Cannot extend a lock that's" " no longer owned") + return True + + def reacquire(self) -> Awaitable[bool]: + """ + Resets a TTL of an already acquired lock back to a timeout value. + """ + if self.local.token is None: + raise LockError("Cannot reacquire an unlocked lock") + if self.timeout is None: + raise LockError("Cannot reacquire a lock with no timeout") + return self.do_reacquire() + + async def do_reacquire(self) -> bool: + timeout = int(self.timeout * 1000) + if not bool( + await self.lua_reacquire( + keys=[self.name], args=[self.local.token, timeout], client=self.redis + ) + ): + raise LockNotOwnedError("Cannot reacquire a lock that's" " no longer owned") + return True diff --git a/aioredis/locks.py b/aioredis/locks.py deleted file mode 100644 index 20d01e807..000000000 --- a/aioredis/locks.py +++ /dev/null @@ -1,78 +0,0 @@ -import asyncio -import sys -from asyncio.locks import Lock as _Lock - -# Fixes an issue with all Python versions that leaves pending waiters -# without being awakened when the first waiter is canceled. -# Code adapted from the PR https://github.com/python/cpython/pull/1031 -# Waiting once it is merged to make a proper condition to relay on -# the stdlib implementation or this one patched - -# Fixes an issue with multiple lock acquire https://bugs.python.org/issue32734 -# Code adapted from the PR https://github.com/python/cpython/pull/5466 - - -class Lock(_Lock): - - if sys.version_info < (3, 6, 5): - - async def acquire(self): - """Acquire a lock. - - This method blocks until the lock is unlocked, then sets it to - locked and returns True. 
- """ - if not self._locked and ( - self._waiters is None or all(w.cancelled() for w in self._waiters) - ): - self._locked = True - return True - - fut = self._loop.create_future() - self._waiters.append(fut) - - # Finally block should be called before the CancelledError - # handling as we don't want CancelledError to call - # _wake_up_first() and attempt to wake up itself. - try: - try: - await fut - finally: - self._waiters.remove(fut) - except asyncio.CancelledError: - if not self._locked: - self._wake_up_first() - raise - - self._locked = True - return True - - def release(self): - """Release a lock. - - When the lock is locked, reset it to unlocked, and return. - If any other coroutines are blocked waiting for the lock to become - unlocked, allow exactly one of them to proceed. - - When invoked on an unlocked lock, a RuntimeError is raised. - - There is no return value. - """ - if self._locked: - self._locked = False - self._wake_up_first() - else: - raise RuntimeError("Lock is not acquired.") - - def _wake_up_first(self): - """Wake up the first waiter if it isn't done.""" - try: - fut = next(iter(self._waiters or [])) - except StopIteration: - return - - # .done() necessarily means that a waiter will wake up later on and - # either take the lock, or, if it was cancelled and lock wasn't - # taken already, will hit this again and wake up a new waiter. - if not fut.done(): - fut.set_result(True) diff --git a/aioredis/parser.py b/aioredis/parser.py deleted file mode 100644 index d0ca374b0..000000000 --- a/aioredis/parser.py +++ /dev/null @@ -1,177 +0,0 @@ -from typing import Callable, Generator, Iterator, Optional - -from .errors import ProtocolError, ReplyError - -__all__ = [ - "Reader", - "PyReader", -] - - -class PyReader: - """Pure-Python Redis protocol parser that follows hiredis.Reader - interface (except setmaxbuf/getmaxbuf). - """ - - def __init__( - self, - protocolError: Callable = ProtocolError, - replyError: Callable = ReplyError, - encoding: Optional[str] = None, - ): - if not callable(protocolError): - raise TypeError("Expected a callable") - if not callable(replyError): - raise TypeError("Expected a callable") - self._parser = Parser(protocolError, replyError, encoding) - - def feed(self, data, o: int = 0, l: int = -1): # noqa: E741 - """Feed data to parser.""" - if l == -1: # noqa: E741 - l = len(data) - o # noqa: E741 - if o < 0 or l < 0: - raise ValueError("negative input") - if o + l > len(data): - raise ValueError("input is larger than buffer size") - self._parser.buf.extend(data[o : o + l]) - - def gets(self): - """Get parsed value or False otherwise. - - Error replies are return as replyError exceptions (not raised). - Protocol errors are raised. - """ - return self._parser.parse_one() - - def setmaxbuf(self, size: Optional[int]) -> None: - """No-op.""" - pass - - def getmaxbuf(self) -> int: - """No-op.""" - return 0 - - -class Parser: - def __init__( - self, protocolError: Callable, replyError: Callable, encoding: Optional[str] - ): - - self.buf = bytearray() # type: bytearray - self.pos = 0 # type: int - self.protocolError = protocolError # type: Callable - self.replyError = replyError # type: Callable - self.encoding = encoding # type: Optional[str] - self._err = None - self._gen = None # type: Optional[Generator] - - def waitsome(self, size: int) -> Iterator[bool]: - # keep yielding false until at least `size` bytes added to buf. 
- while len(self.buf) < self.pos + size: - yield False - - def waitany(self) -> Iterator[bool]: - yield from self.waitsome(len(self.buf) + 1) - - def readone(self): - if not self.buf[self.pos : self.pos + 1]: - yield from self.waitany() - val = self.buf[self.pos : self.pos + 1] - self.pos += 1 - return val - - def readline(self, size: Optional[int] = None): - if size is not None: - if len(self.buf) < size + 2 + self.pos: - yield from self.waitsome(size + 2) - offset = self.pos + size - if self.buf[offset : offset + 2] != b"\r\n": - raise self.error("Expected b'\r\n'") - else: - offset = self.buf.find(b"\r\n", self.pos) - while offset < 0: - yield from self.waitany() - offset = self.buf.find(b"\r\n", self.pos) - val = self.buf[self.pos : offset] - self.pos = 0 - del self.buf[: offset + 2] - return val - - def readint(self): - try: - return int((yield from self.readline())) - except ValueError as exc: - raise self.error(exc) - - def error(self, msg): - self._err = self.protocolError(msg) - return self._err - - # TODO: too complex. Clean this up. - def parse(self, is_bulk: bool = False): # noqa: C901 - if self._err is not None: - raise self._err - ctl = yield from self.readone() - if ctl == b"+": - val = yield from self.readline() - if self.encoding is not None: - try: - return val.decode(self.encoding) - except UnicodeDecodeError: - pass - return bytes(val) - elif ctl == b"-": - val = yield from self.readline() - return self.replyError(val.decode("utf-8")) - elif ctl == b":": - return (yield from self.readint()) - elif ctl == b"$": - val = yield from self.readint() - if val == -1: - return None - val = yield from self.readline(val) - if self.encoding: - try: - return val.decode(self.encoding) - except UnicodeDecodeError: - pass - return bytes(val) - elif ctl == b"*": - val = yield from self.readint() - if val == -1: - return None - bulk_array = [] - error = None - for _ in range(val): - try: - bulk_array.append((yield from self.parse(is_bulk=True))) - except LookupError as err: - if error is None: - error = err - if error is not None: - raise error - return bulk_array - else: - raise self.error(f"Invalid first byte: {ctl!r}") - - def parse_one(self): - if self._gen is None: - self._gen = self.parse() - try: - self._gen.send(None) - except StopIteration as exc: - self._gen = None - return exc.value - except Exception: - self._gen = None - raise - else: - return False - - -try: - import hiredis - - Reader = hiredis.Reader -except ImportError: - Reader = PyReader diff --git a/aioredis/pool.py b/aioredis/pool.py deleted file mode 100644 index e24e2763a..000000000 --- a/aioredis/pool.py +++ /dev/null @@ -1,527 +0,0 @@ -import asyncio -import collections -import sys -import types -import warnings - -from .abc import AbcPool -from .connection import _PUBSUB_COMMANDS, create_connection -from .errors import PoolClosedError -from .locks import Lock -from .log import logger -from .util import CloseEvent, parse_url - - -async def create_pool( - address, - *, - db=None, - password=None, - ssl=None, - encoding=None, - minsize=1, - maxsize=10, - parser=None, - loop=None, - create_connection_timeout=None, - pool_cls=None, - connection_cls=None, - name=None -): - # FIXME: rewrite docstring - """Creates Redis Pool. - - By default it creates pool of Redis instances, but it is - also possible to create pool of plain connections by passing - ``lambda conn: conn`` as commands_factory. - - *commands_factory* parameter is deprecated since v0.2.9 - - All arguments are the same as for create_connection. 
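For readers tracking the migration: the closest new-style equivalent of the pool factory being deleted here is roughly the following sketch (``Redis`` plus its ``set``/``get`` methods are assumed from the new client API; the 1.x call appears only as a comment)::

    import asyncio
    from aioredis import Redis
    from aioredis.connection import ConnectionPool

    # 1.x (removed):
    #   redis = await aioredis.create_redis_pool("redis://localhost", encoding="utf-8")
    async def main():
        pool = ConnectionPool.from_url("redis://localhost", decode_responses=True)
        redis = Redis(connection_pool=pool)
        await redis.set("key", "value")
        print(await redis.get("key"))  # 'value' (decoded, not bytes)
        await pool.disconnect()

    asyncio.run(main())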
- - Returns RedisPool instance or a pool_cls if it is given. - """ - if pool_cls: - assert issubclass( - pool_cls, AbcPool - ), "pool_class does not meet the AbcPool contract" - cls = pool_cls - else: - cls = ConnectionsPool - if isinstance(address, str): - address, options = parse_url(address) - db = options.setdefault("db", db) - password = options.setdefault("password", password) - encoding = options.setdefault("encoding", encoding) - create_connection_timeout = options.setdefault( - "timeout", create_connection_timeout - ) - if "ssl" in options: - assert options["ssl"] or (not options["ssl"] and not ssl), ( - "Conflicting ssl options are set", - options["ssl"], - ssl, - ) - ssl = ssl or options["ssl"] - # TODO: minsize/maxsize - - pool = cls( - address, - db, - password, - encoding, - minsize=minsize, - maxsize=maxsize, - ssl=ssl, - parser=parser, - create_connection_timeout=create_connection_timeout, - connection_cls=connection_cls, - loop=loop, - name=name, - ) - try: - await pool._fill_free(override_min=False) - except Exception: - pool.close() - await pool.wait_closed() - raise - return pool - - -class ConnectionsPool(AbcPool): - """Redis connections pool.""" - - def __init__( - self, - address, - db=None, - password=None, - encoding=None, - *, - minsize, - maxsize, - ssl=None, - parser=None, - create_connection_timeout=None, - connection_cls=None, - loop=None, - name=None - ): - assert isinstance(minsize, int) and minsize >= 0, ( - "minsize must be int >= 0", - minsize, - type(minsize), - ) - assert maxsize is not None, "Arbitrary pool size is disallowed." - assert isinstance(maxsize, int) and maxsize > 0, ( - "maxsize must be int > 0", - maxsize, - type(maxsize), - ) - assert minsize <= maxsize, ("Invalid pool min/max sizes", minsize, maxsize) - if loop is not None and sys.version_info >= (3, 8): - warnings.warn("The loop argument is deprecated", DeprecationWarning) - self._address = address - self._db = db - self._password = password - self._ssl = ssl - self._encoding = encoding - self._parser_class = parser - self._minsize = minsize - self._create_connection_timeout = create_connection_timeout - self._pool = collections.deque(maxlen=maxsize) - self._used = set() - self._acquiring = 0 - self._cond = asyncio.Condition(lock=Lock()) - self._close_state = CloseEvent(self._do_close) - self._pubsub_conn = None - self._connection_cls = connection_cls - self._name = name - - def __repr__(self): - return "<{} [db:{}, size:[{}:{}], free:{}]>".format( - self.__class__.__name__, self.db, self.minsize, self.maxsize, self.freesize - ) - - @property - def minsize(self): - """Minimum pool size.""" - return self._minsize - - @property - def maxsize(self): - """Maximum pool size.""" - return self._pool.maxlen - - @property - def size(self): - """Current pool size.""" - return self.freesize + len(self._used) + self._acquiring - - @property - def freesize(self): - """Current number of free connections.""" - return len(self._pool) - - @property - def address(self): - return self._address - - async def clear(self): - """Clear pool connections. - - Close and remove all free connections. 
- """ - async with self._cond: - await self._do_clear() - - async def _do_clear(self): - waiters = [] - while self._pool: - conn = self._pool.popleft() - conn.close() - waiters.append(conn.wait_closed()) - await asyncio.gather(*waiters) - - async def _do_close(self): - async with self._cond: - assert not self._acquiring, self._acquiring - waiters = [] - while self._pool: - conn = self._pool.popleft() - conn.close() - waiters.append(conn.wait_closed()) - for conn in self._used: - conn.close() - waiters.append(conn.wait_closed()) - await asyncio.gather(*waiters) - # TODO: close _pubsub_conn connection - logger.debug("Closed %d connection(s)", len(waiters)) - - def close(self): - """Close all free and in-progress connections and mark pool as closed.""" - if not self._close_state.is_set(): - self._close_state.set() - - @property - def closed(self): - """True if pool is closed.""" - return self._close_state.is_set() - - async def wait_closed(self): - """Wait until pool gets closed.""" - await self._close_state.wait() - - @property - def db(self): - """Currently selected db index.""" - return self._db or 0 - - @property - def encoding(self): - """Current set codec or None.""" - return self._encoding - - def execute(self, command, *args, **kw): - """Executes redis command in a free connection and returns - future waiting for result. - - Picks connection from free pool and send command through - that connection. - If no connection is found, returns coroutine waiting for - free connection to execute command. - """ - conn, address = self.get_connection(command, args) - if conn is not None: - fut = conn.execute(command, *args, **kw) - return self._check_result(fut, command, args, kw) - else: - coro = self._wait_execute(address, command, args, kw) - return self._check_result(coro, command, args, kw) - - def execute_pubsub(self, command, *channels): - """Executes Redis (p)subscribe/(p)unsubscribe commands. - - ConnectionsPool picks separate connection for pub/sub - and uses it until explicitly closed or disconnected - (unsubscribing from all channels/patterns will leave connection - locked for pub/sub use). - - There is no auto-reconnect for this PUB/SUB connection. - - Returns asyncio.gather coroutine waiting for all channels/patterns - to receive answers. - """ - conn, address = self.get_connection(command) - if conn is not None: - return conn.execute_pubsub(command, *channels) - else: - return self._wait_execute_pubsub(address, command, channels, {}) - - def get_connection(self, command, args=()): - """Get free connection from pool. - - Returns connection. - """ - # TODO: find a better way to determine if connection is free - # and not havily used. - command = command.upper().strip() - is_pubsub = command in _PUBSUB_COMMANDS - if is_pubsub and self._pubsub_conn: - if not self._pubsub_conn.closed: - return self._pubsub_conn, self._pubsub_conn.address - self._used.remove(self._pubsub_conn) - self._pubsub_conn = None - for i in range(self.freesize): - conn = self._pool[0] - self._pool.rotate(1) - if conn.closed: # or conn._waiters: (eg: busy connection) - continue - if conn.in_pubsub: - continue - if is_pubsub: - self._pubsub_conn = conn - self._pool.remove(conn) - self._used.add(conn) - return conn, conn.address - return None, self._address # figure out - - def _check_result(self, fut, *data): - """Hook to check result or catch exception (like MovedError). - - This method can be coroutine. 
- """ - return fut - - async def _wait_execute(self, address, command, args, kw): - """Acquire connection and execute command.""" - conn = await self.acquire(command, args) - try: - return await conn.execute(command, *args, **kw) - finally: - self.release(conn) - - async def _wait_execute_pubsub(self, address, command, args, kw): - if self.closed: - raise PoolClosedError("Pool is closed") - assert self._pubsub_conn is None or self._pubsub_conn.closed, ( - "Expected no or closed connection", - self._pubsub_conn, - ) - async with self._cond: - if self.closed: - raise PoolClosedError("Pool is closed") - if self._pubsub_conn is None or self._pubsub_conn.closed: - conn = await self._create_new_connection(address) - self._pubsub_conn = conn - conn = self._pubsub_conn - return await conn.execute_pubsub(command, *args, **kw) - - async def select(self, db): - """Changes db index for all free connections. - - All previously acquired connections will be closed when released. - """ - res = True - async with self._cond: - for i in range(self.freesize): - res = res and (await self._pool[i].select(db)) - self._db = db - return res - - async def auth(self, password): - self._password = password - async with self._cond: - for i in range(self.freesize): - await self._pool[i].auth(password) - - async def setname(self, name): - """Set the current connection name.""" - self._name = name - async with self._cond: - for i in range(self.freesize): - await self._pool[i].setname(name) - - @property - def in_pubsub(self): - if self._pubsub_conn and not self._pubsub_conn.closed: - return self._pubsub_conn.in_pubsub - return 0 - - @property - def pubsub_channels(self): - if self._pubsub_conn and not self._pubsub_conn.closed: - return self._pubsub_conn.pubsub_channels - return types.MappingProxyType({}) - - @property - def pubsub_patterns(self): - if self._pubsub_conn and not self._pubsub_conn.closed: - return self._pubsub_conn.pubsub_patterns - return types.MappingProxyType({}) - - async def acquire(self, command=None, args=()): - """Acquires a connection from free pool. - - Creates new connection if needed. - """ - if self.closed: - raise PoolClosedError("Pool is closed") - async with self._cond: - if self.closed: - raise PoolClosedError("Pool is closed") - while True: - await self._fill_free(override_min=True) - if self.freesize: - conn = self._pool.popleft() - assert not conn.closed, conn - assert conn not in self._used, (conn, self._used) - self._used.add(conn) - return conn - else: - await self._cond.wait() - - def release(self, conn): - """Returns used connection back into pool. - - When returned connection has db index that differs from one in pool - the connection will be closed and dropped. - When queue of free connections is full the connection will be dropped. - """ - assert conn in self._used, ("Invalid connection, maybe from other pool", conn) - self._used.remove(conn) - if not conn.closed: - if conn.in_transaction: - logger.warning("Connection %r is in transaction, closing it.", conn) - conn.close() - elif conn.in_pubsub: - logger.warning("Connection %r is in subscribe mode, closing it.", conn) - conn.close() - elif conn._waiters: - logger.warning("Connection %r has pending commands, closing it.", conn) - conn.close() - elif conn.db == self.db: - if self.maxsize and self.freesize < self.maxsize: - self._pool.append(conn) - else: - # consider this connection as old and close it. 
- conn.close() - else: - conn.close() - # FIXME: check event loop is not closed - asyncio.ensure_future(self._wakeup()) - - def _drop_closed(self): - for i in range(self.freesize): - conn = self._pool[0] - if conn.closed: - self._pool.popleft() - else: - self._pool.rotate(-1) - - async def _fill_free(self, *, override_min): - # drop closed connections first - self._drop_closed() - # address = self._address - while self.size < self.minsize: - self._acquiring += 1 - try: - conn = await self._create_new_connection(self._address) - # check the healthy of that connection, if - # something went wrong just trigger the Exception - await conn.execute("ping") - self._pool.append(conn) - finally: - self._acquiring -= 1 - # connection may be closed at yield point - self._drop_closed() - if self.freesize: - return - if override_min: - while not self._pool and self.size < self.maxsize: - self._acquiring += 1 - try: - conn = await self._create_new_connection(self._address) - self._pool.append(conn) - finally: - self._acquiring -= 1 - # connection may be closed at yield point - self._drop_closed() - - def _create_new_connection(self, address): - return create_connection( - address, - db=self._db, - password=self._password, - ssl=self._ssl, - encoding=self._encoding, - parser=self._parser_class, - timeout=self._create_connection_timeout, - connection_cls=self._connection_cls, - name=self._name, - ) - - async def _wakeup(self, closing_conn=None): - async with self._cond: - self._cond.notify() - if closing_conn is not None: - await closing_conn.wait_closed() - - def __enter__(self): - raise RuntimeError("'await' should be used as a context manager expression") - - def __exit__(self, *args): - pass # pragma: nocover - - def __await__(self): - # To make `with await pool` work - conn = yield from self.acquire().__await__() - return _ConnectionContextManager(self, conn) - - def get(self): - """Return async context manager for working with connection. - - async with pool.get() as conn: - await conn.execute('get', 'my-key') - """ - return _AsyncConnectionContextManager(self) - - -class _ConnectionContextManager: - - __slots__ = ("_pool", "_conn") - - def __init__(self, pool, conn): - self._pool = pool - self._conn = conn - - def __enter__(self): - return self._conn - - def __exit__(self, exc_type, exc_value, tb): - try: - self._pool.release(self._conn) - finally: - self._pool = None - self._conn = None - - -class _AsyncConnectionContextManager: - - __slots__ = ("_pool", "_conn") - - def __init__(self, pool): - self._pool = pool - self._conn = None - - async def __aenter__(self): - conn = await self._pool.acquire() - self._conn = conn - return self._conn - - async def __aexit__(self, exc_type, exc_value, tb): - try: - self._pool.release(self._conn) - finally: - self._pool = None - self._conn = None diff --git a/aioredis/pubsub.py b/aioredis/pubsub.py deleted file mode 100644 index 135c2dea0..000000000 --- a/aioredis/pubsub.py +++ /dev/null @@ -1,457 +0,0 @@ -import asyncio -import collections -import json -import sys -import types -import warnings - -from .abc import AbcChannel -from .errors import ChannelClosedError -from .log import logger -from .util import _converters # , _set_result - -__all__ = [ - "Channel", - "EndOfStream", - "Receiver", -] - - -# End of pubsub messages stream marker. 
-EndOfStream = object() - - -class Channel(AbcChannel): - """Wrapper around asyncio.Queue.""" - - def __init__(self, name, is_pattern, loop=None): - if loop is not None and sys.version_info >= (3, 8): - warnings.warn("The loop argument is deprecated", DeprecationWarning) - self._queue = ClosableQueue() - self._name = _converters[type(name)](name) - self._is_pattern = is_pattern - - def __repr__(self): - return "<{} name:{!r}, is_pattern:{}, qsize:{}>".format( - self.__class__.__name__, self._name, self._is_pattern, self._queue.qsize() - ) - - @property - def name(self): - """Encoded channel name/pattern.""" - return self._name - - @property - def is_pattern(self): - """Set to True if channel is subscribed to pattern.""" - return self._is_pattern - - @property - def is_active(self): - """Returns True until there are messages in channel or - connection is subscribed to it. - - Can be used with ``while``: - - >>> ch = conn.pubsub_channels['chan:1'] - >>> while ch.is_active: - ... msg = await ch.get() # may stuck for a long time - - """ - return not self._queue.exhausted - - async def get(self, *, encoding=None, decoder=None): - """Coroutine that waits for and returns a message. - - :raises aioredis.ChannelClosedError: If channel is unsubscribed - and has no messages. - """ - assert decoder is None or callable(decoder), decoder - if self._queue.exhausted: - raise ChannelClosedError() - msg = await self._queue.get() - if msg is EndOfStream: - # TODO: maybe we need an explicit marker for "end of stream" - # currently, returning None may overlap with - # possible return value from `decoder` - # so the user would have to check `ch.is_active` - # to determine if its EoS or payload - return - if self._is_pattern: - dest_channel, msg = msg - if encoding is not None: - msg = msg.decode(encoding) - if decoder is not None: - msg = decoder(msg) - if self._is_pattern: - return dest_channel, msg - return msg - - async def get_json(self, encoding="utf-8", decoder=json.loads): - """Shortcut to get JSON messages.""" - return await self.get(encoding=encoding, decoder=decoder) - - def iter(self, *, encoding=None, decoder=None): - """Same as get method but its native coroutine. - - Usage example: - - >>> async for msg in ch.iter(): - ... print(msg) - """ - return _IterHelper( - self, is_active=lambda ch: ch.is_active, encoding=encoding, decoder=decoder - ) - - async def wait_message(self): - """Waits for message to become available in channel - or channel is closed (unsubscribed). - - Possible usage: - - >>> while (await ch.wait_message()): - ... msg = await ch.get() - """ - if not self.is_active: - return False - if not self._queue.empty(): - return True - await self._queue.wait() - return self.is_active - - # internal methods - - def put_nowait(self, data): - self._queue.put(data) - - def close(self, exc=None): - """Marks channel as inactive. - - Internal method, will be called from connection - on `unsubscribe` command. - """ - if not self._queue.closed: - self._queue.close() - - -class _IterHelper: - - __slots__ = ("_ch", "_is_active", "_args", "_kw") - - def __init__(self, ch, is_active, *args, **kw): - self._ch = ch - self._is_active = is_active - self._args = args - self._kw = kw - - def __aiter__(self): - return self - - async def __anext__(self): - if not self._is_active(self._ch): - raise StopAsyncIteration - msg = await self._ch.get(*self._args, **self._kw) - if msg is None: - raise StopAsyncIteration - return msg - - -class Receiver: - """Multi-producers, single-consumer Pub/Sub queue. 
-
-    Can be used in cases where a single consumer task
-    must read messages from several different channels
-    (where pattern subscriptions may not work well
-    or channels can be added/removed dynamically).
-
-    Example use case:
-
-    >>> from aioredis.pubsub import Receiver
-    >>> from aioredis.abc import AbcChannel
-    >>> mpsc = Receiver()
-    >>> async def reader(mpsc):
-    ...     async for channel, msg in mpsc.iter():
-    ...         assert isinstance(channel, AbcChannel)
-    ...         print("Got {!r} in channel {!r}".format(msg, channel))
-    >>> asyncio.ensure_future(reader(mpsc))
-    >>> await redis.subscribe(mpsc.channel('channel:1'),
-    ...                       mpsc.channel('channel:3'),
-    ...                       mpsc.channel('channel:5'))
-    >>> await redis.psubscribe(mpsc.pattern('hello*'))
-    >>> # publishing 'Hello world' into 'hello-channel'
-    >>> # will print this message:
-    Got b'Hello world' in channel b'hello-channel'
-    >>> # when all is done:
-    >>> await redis.unsubscribe('channel:1', 'channel:3', 'channel:5')
-    >>> await redis.punsubscribe('hello*')
-    >>> mpsc.stop()
-    >>> # any message received after stop() will be ignored.
-    """
-
-    def __init__(self, loop=None, on_close=None):
-        assert on_close is None or callable(on_close), (
-            "on_close must be None or callable",
-            on_close,
-        )
-        if loop is not None:
-            warnings.warn("The loop argument is deprecated", DeprecationWarning)
-        if on_close is None:
-            on_close = self.check_stop
-        self._queue = ClosableQueue()
-        self._refs = {}
-        self._on_close = on_close
-
-    def __repr__(self):
-        return "<Receiver is_active:{}, senders:{}, qsize:{}>".format(
-            self.is_active, len(self._refs), self._queue.qsize()
-        )
-
-    def channel(self, name):
-        """Create a channel.
-
-        Returns ``_Sender`` object implementing
-        :class:`~aioredis.abc.AbcChannel`.
-        """
-        enc_name = _converters[type(name)](name)
-        if (enc_name, False) not in self._refs:
-            ch = _Sender(self, enc_name, is_pattern=False)
-            self._refs[(enc_name, False)] = ch
-            return ch
-        return self._refs[(enc_name, False)]
-
-    def pattern(self, pattern):
-        """Create a pattern channel.
-
-        Returns ``_Sender`` object implementing
-        :class:`~aioredis.abc.AbcChannel`.
-        """
-        enc_pattern = _converters[type(pattern)](pattern)
-        if (enc_pattern, True) not in self._refs:
-            ch = _Sender(self, enc_pattern, is_pattern=True)
-            self._refs[(enc_pattern, True)] = ch
-        return self._refs[(enc_pattern, True)]
-
-    @property
-    def channels(self):
-        """Read-only channels dict."""
-        return types.MappingProxyType(
-            {ch.name: ch for ch in self._refs.values() if not ch.is_pattern}
-        )
-
-    @property
-    def patterns(self):
-        """Read-only patterns dict."""
-        return types.MappingProxyType(
-            {ch.name: ch for ch in self._refs.values() if ch.is_pattern}
-        )
-
-    async def get(self, *, encoding=None, decoder=None):
-        """Wait for and return pub/sub message from one of channels.
-
-        Return value is either:
-
-        * tuple of two elements: channel & message;
-
-        * tuple of three elements: pattern channel, (target channel & message);
-
-        * or None in case Receiver is not active or has just been stopped.
-
-        :raises aioredis.ChannelClosedError: If listener is stopped
-        and all messages have been received.
-        """
-        # TODO: add note about raised exception and end marker.
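-        # Illustrative only: the two shapes a successful call returns
-        # (see the docstring above):
-        #   ch, msg = await mpsc.get()          # plain channel subscription
-        #   ch, (dest, msg) = await mpsc.get()  # pattern subscription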
- # Flow before ClosableQueue: - # - ch.get() -> message - # - ch.close() -> ch.put(None) - # - ch.get() -> None - # - ch.get() -> ChannelClosedError - # Current flow: - # - ch.get() -> message - # - ch.close() -> ch._closed = True - # - ch.get() -> ChannelClosedError - assert decoder is None or callable(decoder), decoder - if self._queue.exhausted: - raise ChannelClosedError() - obj = await self._queue.get() - if obj is EndOfStream: - return - ch, msg = obj - if ch.is_pattern: - dest_ch, msg = msg - if encoding is not None: - msg = msg.decode(encoding) - if decoder is not None: - msg = decoder(msg) - if ch.is_pattern: - return ch, (dest_ch, msg) - return ch, msg - - async def wait_message(self): - """Blocks until new message appear.""" - if not self._queue.empty(): - return True - if self._queue.closed: - return False - await self._queue.wait() - return self.is_active - - @property - def is_active(self): - """Returns True if listener has any active subscription.""" - if self._queue.exhausted: - return False - return any(ch.is_active for ch in self._refs.values()) - - def stop(self): - """Stop receiving messages. - - All new messages after this call will be ignored, - so you must call unsubscribe before stopping this listener. - """ - self._queue.close() - # TODO: discard all senders as they might still be active. - # Channels storage in Connection should be refactored: - # if we drop _Senders here they will still be subscribed - # and will reside in memory although messages will be discarded. - - def iter(self, *, encoding=None, decoder=None): - """Returns async iterator. - - Usage example: - - >>> async for ch, msg in mpsc.iter(): - ... print(ch, msg) - """ - return _IterHelper( - self, - is_active=lambda r: not r._queue.exhausted, - encoding=encoding, - decoder=decoder, - ) - - def check_stop(self, channel, exc=None): - """TBD""" - # NOTE: this is a fast-path implementation, - # if overridden, implementation should use public API: - # - # if self.is_active and not (self.channels or self.patterns): - if not self._refs: - self.stop() - - # internal methods - - def _put_nowait(self, data, *, sender): - if self._queue.closed: - logger.warning( - "Pub/Sub listener message after stop:" " sender: %r, data: %r", - sender, - data, - ) - return - self._queue.put((sender, data)) - - def _close(self, sender, exc=None): - self._refs.pop((sender.name, sender.is_pattern)) - self._on_close(sender, exc=exc) - - -class _Sender(AbcChannel): - """Write-Only Channel. - - Does not allow direct ``.get()`` calls. - """ - - def __init__(self, receiver, name, is_pattern): - self._receiver = receiver - self._name = _converters[type(name)](name) - self._is_pattern = is_pattern - self._closed = False - - def __repr__(self): - return "<{} name:{!r}, is_pattern:{}, receiver:{!r}>".format( - self.__class__.__name__, self._name, self._is_pattern, self._receiver - ) - - @property - def name(self): - """Encoded channel name or pattern.""" - return self._name - - @property - def is_pattern(self): - """Set to True if channel is subscribed to pattern.""" - return self._is_pattern - - @property - def is_active(self): - return not self._closed - - async def get(self, *, encoding=None, decoder=None): - raise RuntimeError("MPSC channel does not allow direct get() calls") - - def put_nowait(self, data): - self._receiver._put_nowait(data, sender=self) - - def close(self, exc=None): - # TODO: close() is exclusive so we can not share same _Sender - # between different connections. - # This needs to be fixed. 
-        if self._closed:
-            return
-        self._closed = True
-        self._receiver._close(self, exc=exc)
-
-
-class ClosableQueue:
-    def __init__(self):
-        self._queue = collections.deque()
-        self._event = asyncio.Event()
-        self._closed = False
-
-    async def wait(self):
-        while not (self._queue or self._closed):
-            await self._event.wait()
-        return True
-
-    async def get(self):
-        await self.wait()
-        assert self._queue or self._closed, (
-            "Unexpected queue state",
-            self._queue,
-            self._closed,
-        )
-        if not self._queue and self._closed:
-            return EndOfStream
-        item = self._queue.popleft()
-        if not self._queue:
-            self._event.clear()
-        return item
-
-    def put(self, item):
-        if self._closed:
-            return
-        self._queue.append(item)
-        self._event.set()
-
-    def close(self):
-        """Mark queue as closed and notify all waiters."""
-        self._closed = True
-        self._event.set()
-
-    @property
-    def closed(self):
-        return self._closed
-
-    @property
-    def exhausted(self):
-        return self._closed and not self._queue
-
-    def empty(self):
-        return not self._queue
-
-    def qsize(self):
-        return len(self._queue)
-
-    def __repr__(self):
-        closed = "closed" if self._closed else "open"
-        return "<Queue {} size:{}>".format(closed, len(self._queue))
diff --git a/aioredis/sentinel.py b/aioredis/sentinel.py
new file mode 100644
index 000000000..ed916600f
--- /dev/null
+++ b/aioredis/sentinel.py
@@ -0,0 +1,324 @@
+import random
+import weakref
+from typing import AsyncIterator, Iterable, Mapping, Sequence, Tuple, Type
+
+from aioredis.client import Redis
+from aioredis.connection import Connection, ConnectionPool, EncodableT
+from aioredis.exceptions import (
+    ConnectionError,
+    ReadOnlyError,
+    ResponseError,
+    TimeoutError,
+)
+from aioredis.utils import str_if_bytes
+
+
+class MasterNotFoundError(ConnectionError):
+    pass
+
+
+class SlaveNotFoundError(ConnectionError):
+    pass
+
+
+class SentinelManagedConnection(Connection):
+    def __init__(self, **kwargs):
+        self.connection_pool = kwargs.pop("connection_pool")
+        super().__init__(**kwargs)
+
+    def __repr__(self):
+        pool = self.connection_pool
+        s = f"{self.__class__.__name__}<service={pool.service_name}"
+        if self.host:
+            s += f",host={self.host},port={self.port}"
+        return s + ">"
+
+    async def connect_to(self, address):
+        self.host, self.port = address
+        await super().connect()
+        if self.connection_pool.check_connection:
+            await self.send_command("PING")
+            if str_if_bytes(await self.read_response()) != "PONG":
+                raise ConnectionError("PING failed")
+
+    async def connect(self):
+        if self._reader:
+            return  # already connected
+        if self.connection_pool.is_master:
+            await self.connect_to(await self.connection_pool.get_master_address())
+        else:
+            async for slave in self.connection_pool.rotate_slaves():
+                try:
+                    return await self.connect_to(slave)
+                except ConnectionError:
+                    continue
+            raise SlaveNotFoundError  # should never be reached
+
+    async def read_response(self):
+        try:
+            return await super().read_response()
+        except ReadOnlyError:
+            if self.connection_pool.is_master:
+                # When talking to a master, a ReadOnlyError likely
+                # indicates that the previous master that we're still connected
+                # to has been demoted to a slave and there's a new master.
+                # Calling disconnect will force the connection to re-query
+                # sentinel during the next connect() attempt.
+                await self.disconnect()
+                raise ConnectionError("The previous master is now a slave")
+            raise
+
+
+class SentinelConnectionPool(ConnectionPool):
+    """
+    Sentinel-backed connection pool.
+
+    If the ``check_connection`` flag is set to True, SentinelManagedConnection
+    sends a PING command right after establishing the connection.
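+
+    The flag is normally supplied through ``Sentinel.master_for`` or
+    ``Sentinel.slave_for`` (a sketch; ``mymaster`` is a placeholder
+    service name):
+
+    >>> sentinel = Sentinel([('localhost', 26379)])
+    >>> master = sentinel.master_for('mymaster', check_connection=True)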
+ """ + + def __init__(self, service_name, sentinel_manager, **kwargs): + kwargs["connection_class"] = kwargs.get( + "connection_class", SentinelManagedConnection + ) + self.is_master = kwargs.pop("is_master", True) + self.check_connection = kwargs.pop("check_connection", False) + super().__init__(**kwargs) + self.connection_kwargs["connection_pool"] = weakref.proxy(self) + self.service_name = service_name + self.sentinel_manager = sentinel_manager + self.master_address = None + self.slave_rr_counter = None + + def __repr__(self): + return ( + f"{self.__class__.__name__}" + f"" + ) + + def reset(self): + super().reset() + self.master_address = None + self.slave_rr_counter = None + + def owns_connection(self, connection: SentinelManagedConnection): + check = not self.is_master or ( + self.is_master and self.master_address == (connection.host, connection.port) + ) + return check and super().owns_connection(connection) + + async def get_master_address(self): + master_address = await self.sentinel_manager.discover_master(self.service_name) + if self.is_master: + if self.master_address != master_address: + self.master_address = master_address + # disconnect any idle connections so that they reconnect + # to the new master the next time that they are used. + await self.disconnect(inuse_connections=False) + return master_address + + async def rotate_slaves(self) -> AsyncIterator: + """Round-robin slave balancer""" + slaves = await self.sentinel_manager.discover_slaves(self.service_name) + if slaves: + if self.slave_rr_counter is None: + self.slave_rr_counter = random.randint(0, len(slaves) - 1) + for _ in range(len(slaves)): + self.slave_rr_counter = (self.slave_rr_counter + 1) % len(slaves) + slave = slaves[self.slave_rr_counter] + yield slave + # Fallback to the master connection + try: + yield await self.get_master_address() + except MasterNotFoundError: + pass + raise SlaveNotFoundError(f"No slave found for {self.service_name!r}") + + +class Sentinel: + """ + Redis Sentinel cluster client + + >>> from aioredis.sentinel import Sentinel + >>> sentinel = Sentinel([('localhost', 26379)], socket_timeout=0.1) + >>> master = sentinel.master_for('mymaster', socket_timeout=0.1) + >>> await master.set('foo', 'bar') + >>> slave = sentinel.slave_for('mymaster', socket_timeout=0.1) + >>> await slave.get('foo') + b'bar' + + ``sentinels`` is a list of sentinel nodes. Each node is represented by + a pair (hostname, port). + + ``min_other_sentinels`` defined a minimum number of peers for a sentinel. + When querying a sentinel, if it doesn't meet this threshold, responses + from that sentinel won't be considered valid. + + ``sentinel_kwargs`` is a dictionary of connection arguments used when + connecting to sentinel instances. Any argument that can be passed to + a normal Redis connection can be specified here. If ``sentinel_kwargs`` is + not specified, any socket_timeout and socket_keepalive options specified + in ``connection_kwargs`` will be used. + + ``connection_kwargs`` are keyword arguments that will be used when + establishing a connection to a Redis server. 
+ """ + + def __init__( + self, + sentinels, + min_other_sentinels=0, + sentinel_kwargs=None, + **connection_kwargs, + ): + # if sentinel_kwargs isn't defined, use the socket_* options from + # connection_kwargs + if sentinel_kwargs is None: + sentinel_kwargs = { + k: v for k, v in connection_kwargs.items() if k.startswith("socket_") + } + self.sentinel_kwargs = sentinel_kwargs + + self.sentinels = [ + Redis(host=hostname, port=port, **self.sentinel_kwargs) + for hostname, port in sentinels + ] + self.min_other_sentinels = min_other_sentinels + self.connection_kwargs = connection_kwargs + + def __repr__(self): + sentinel_addresses = [] + for sentinel in self.sentinels: + sentinel_addresses.append( + f"{sentinel.connection_pool.connection_kwargs['host']}:" + f"{sentinel.connection_pool.connection_kwargs['port']}" + ) + return f"{self.__class__.__name__}" + + def check_master_state(self, state: dict, service_name: str) -> bool: + if not state["is_master"] or state["is_sdown"] or state["is_odown"]: + return False + # Check if our sentinel doesn't see other nodes + if state["num-other-sentinels"] < self.min_other_sentinels: + return False + return True + + async def discover_master(self, service_name: str): + """ + Asks sentinel servers for the Redis master's address corresponding + to the service labeled ``service_name``. + + Returns a pair (address, port) or raises MasterNotFoundError if no + master is found. + """ + for sentinel_no, sentinel in enumerate(self.sentinels): + try: + masters = await sentinel.sentinel_masters() + except (ConnectionError, TimeoutError): + continue + state = masters.get(service_name) + if state and self.check_master_state(state, service_name): + # Put this sentinel at the top of the list + self.sentinels[0], self.sentinels[sentinel_no] = ( + sentinel, + self.sentinels[0], + ) + return state["ip"], state["port"] + raise MasterNotFoundError(f"No master found for {service_name!r}") + + def filter_slaves( + self, slaves: Iterable[Mapping] + ) -> Sequence[Tuple[EncodableT, EncodableT]]: + """Remove slaves that are in an ODOWN or SDOWN state""" + slaves_alive = [] + for slave in slaves: + if slave["is_odown"] or slave["is_sdown"]: + continue + slaves_alive.append((slave["ip"], slave["port"])) + return slaves_alive + + async def discover_slaves( + self, service_name: str + ) -> Sequence[Tuple[EncodableT, EncodableT]]: + """Returns a list of alive slaves for service ``service_name``""" + for sentinel in self.sentinels: + try: + slaves = await sentinel.sentinel_slaves(service_name) + except (ConnectionError, ResponseError, TimeoutError): + continue + slaves = self.filter_slaves(slaves) + if slaves: + return slaves + return [] + + def master_for( + self, + service_name: str, + redis_class: Type[Redis] = Redis, + connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool, + **kwargs, + ): + """ + Returns a redis client instance for the ``service_name`` master. + + A :py:class:`~redis.sentinel.SentinelConnectionPool` class is + used to retrive the master's address before establishing a new + connection. + + NOTE: If the master's address has changed, any cached connections to + the old master are closed. + + By default clients will be a :py:class:`~redis.Redis` instance. + Specify a different class to the ``redis_class`` argument if you + desire something different. + + The ``connection_pool_class`` specifies the connection pool to + use. The :py:class:`~redis.sentinel.SentinelConnectionPool` + will be used by default. 
+
+        All other keyword arguments are merged with any connection_kwargs
+        passed to this class and passed to the connection pool as keyword
+        arguments to be used to initialize Redis connections.
+        """
+        kwargs["is_master"] = True
+        connection_kwargs = dict(self.connection_kwargs)
+        connection_kwargs.update(kwargs)
+        return redis_class(
+            connection_pool=connection_pool_class(
+                service_name, self, **connection_kwargs
+            )
+        )
+
+    def slave_for(
+        self,
+        service_name: str,
+        redis_class: Type[Redis] = Redis,
+        connection_pool_class: Type[SentinelConnectionPool] = SentinelConnectionPool,
+        **kwargs,
+    ):
+        """
+        Returns a redis client instance for the ``service_name`` slave(s).
+
+        A SentinelConnectionPool class is used to retrieve the slave's
+        address before establishing a new connection.
+
+        By default clients will be a :py:class:`~aioredis.Redis` instance.
+        Specify a different class to the ``redis_class`` argument if you
+        desire something different.
+
+        The ``connection_pool_class`` specifies the connection pool to use.
+        The SentinelConnectionPool will be used by default.
+
+        All other keyword arguments are merged with any connection_kwargs
+        passed to this class and passed to the connection pool as keyword
+        arguments to be used to initialize Redis connections.
+        """
+        kwargs["is_master"] = False
+        connection_kwargs = dict(self.connection_kwargs)
+        connection_kwargs.update(kwargs)
+        return redis_class(
+            connection_pool=connection_pool_class(
+                service_name, self, **connection_kwargs
+            )
+        )
diff --git a/aioredis/sentinel/__init__.py b/aioredis/sentinel/__init__.py
deleted file mode 100644
index 54b209a1b..000000000
--- a/aioredis/sentinel/__init__.py
+++ /dev/null
@@ -1,9 +0,0 @@
-from .commands import RedisSentinel, create_sentinel
-from .pool import SentinelPool, create_sentinel_pool
-
-__all__ = [
-    "create_sentinel",
-    "create_sentinel_pool",
-    "RedisSentinel",
-    "SentinelPool",
-]
diff --git a/aioredis/sentinel/commands.py b/aioredis/sentinel/commands.py
deleted file mode 100644
index b48b61077..000000000
--- a/aioredis/sentinel/commands.py
+++ /dev/null
@@ -1,201 +0,0 @@
-import asyncio
-
-from ..commands import Redis
-from ..util import wait_convert, wait_ok
-from .pool import create_sentinel_pool
-
-
-async def create_sentinel(
-    sentinels,
-    *,
-    db=None,
-    password=None,
-    encoding=None,
-    minsize=1,
-    maxsize=10,
-    ssl=None,
-    timeout=0.2,
-    loop=None
-):
-    """Creates Redis Sentinel client.
-
-    `sentinels` is a list of sentinel nodes.
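-
-    Usage sketch (1.x API, illustrative only):
-
-    >>> sentinel = await create_sentinel([('localhost', 26379)])
-    >>> redis = sentinel.master_for('mymaster')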
- """ - - if loop is None: - loop = asyncio.get_event_loop() - - pool = await create_sentinel_pool( - sentinels, - db=db, - password=password, - encoding=encoding, - minsize=minsize, - maxsize=maxsize, - ssl=ssl, - timeout=timeout, - loop=loop, - ) - return RedisSentinel(pool) - - -class RedisSentinel: - """Redis Sentinel client.""" - - def __init__(self, pool): - self._pool = pool - - def close(self): - """Close client connections.""" - self._pool.close() - - async def wait_closed(self): - """Coroutine waiting until underlying connections are closed.""" - await self._pool.wait_closed() - - @property - def closed(self): - """True if connection is closed.""" - return self._pool.closed - - def master_for(self, name): - """Returns Redis client to master Redis server.""" - # TODO: make class configurable - return Redis(self._pool.master_for(name)) - - def slave_for(self, name): - """Returns Redis client to slave Redis server.""" - # TODO: make class configurable - return Redis(self._pool.slave_for(name)) - - def execute(self, command, *args, **kwargs): - """Execute Sentinel command. - - It will be prefixed with SENTINEL automatically. - """ - return self._pool.execute(b"SENTINEL", command, *args, **kwargs) - - async def ping(self): - """Send PING command to Sentinel instance(s).""" - # TODO: add kwargs allowing to pick sentinel to send command to. - return await self._pool.execute(b"PING") - - def master(self, name): - """Returns a dictionary containing the specified masters state.""" - fut = self.execute(b"MASTER", name, encoding="utf-8") - return wait_convert(fut, parse_sentinel_master) - - def master_address(self, name): - """Returns a (host, port) pair for the given ``name``.""" - fut = self.execute(b"get-master-addr-by-name", name, encoding="utf-8") - return wait_convert(fut, parse_address) - - def masters(self): - """Returns a list of dictionaries containing each master's state.""" - fut = self.execute(b"MASTERS", encoding="utf-8") - # TODO: process masters: we can adjust internal state - return wait_convert(fut, parse_sentinel_masters) - - def slaves(self, name): - """Returns a list of slaves for ``name``.""" - fut = self.execute(b"SLAVES", name, encoding="utf-8") - return wait_convert(fut, parse_sentinel_slaves_and_sentinels) - - def sentinels(self, name): - """Returns a list of sentinels for ``name``.""" - fut = self.execute(b"SENTINELS", name, encoding="utf-8") - return wait_convert(fut, parse_sentinel_slaves_and_sentinels) - - def monitor(self, name, ip, port, quorum): - """Add a new master to Sentinel to be monitored.""" - fut = self.execute(b"MONITOR", name, ip, port, quorum) - return wait_ok(fut) - - def remove(self, name): - """Remove a master from Sentinel's monitoring.""" - fut = self.execute(b"REMOVE", name) - return wait_ok(fut) - - def set(self, name, option, value): - """Set Sentinel monitoring parameters for a given master.""" - fut = self.execute(b"SET", name, option, value) - return wait_ok(fut) - - def failover(self, name): - """Force a failover of a named master.""" - fut = self.execute(b"FAILOVER", name) - return wait_ok(fut) - - def check_quorum(self, name): - """ - Check if the current Sentinel configuration is able - to reach the quorum needed to failover a master, - and the majority needed to authorize the failover. 
- """ - return self.execute(b"CKQUORUM", name) - - -SENTINEL_STATE_TYPES = { - "can-failover-its-master": int, - "config-epoch": int, - "down-after-milliseconds": int, - "failover-timeout": int, - "info-refresh": int, - "last-hello-message": int, - "last-ok-ping-reply": int, - "last-ping-reply": int, - "last-ping-sent": int, - "master-link-down-time": int, - "master-port": int, - "num-other-sentinels": int, - "num-slaves": int, - "o-down-time": int, - "pending-commands": int, - "link-pending-commands": int, - "link-refcount": int, - "parallel-syncs": int, - "port": int, - "quorum": int, - "role-reported-time": int, - "s-down-time": int, - "slave-priority": int, - "slave-repl-offset": int, - "voted-leader-epoch": int, - "flags": lambda s: frozenset(s.split(",")), # TODO: make flags enum? -} - - -def pairs_to_dict_typed(response, type_info): - it = iter(response) - result = {} - for key, value in zip(it, it): - if key in type_info: - try: - value = type_info[key](value) - except (TypeError, ValueError): - # if for some reason the value can't be coerced, just use - # the string value - pass - result[key] = value - return result - - -def parse_sentinel_masters(response): - result = {} - for item in response: - state = pairs_to_dict_typed(item, SENTINEL_STATE_TYPES) - result[state["name"]] = state - return result - - -def parse_sentinel_slaves_and_sentinels(response): - return [pairs_to_dict_typed(item, SENTINEL_STATE_TYPES) for item in response] - - -def parse_sentinel_master(response): - return pairs_to_dict_typed(response, SENTINEL_STATE_TYPES) - - -def parse_address(value): - if value is not None: - return (value[0], int(value[1])) diff --git a/aioredis/sentinel/pool.py b/aioredis/sentinel/pool.py deleted file mode 100644 index f6b4fc0c9..000000000 --- a/aioredis/sentinel/pool.py +++ /dev/null @@ -1,514 +0,0 @@ -import asyncio -import contextlib - -from async_timeout import timeout as async_timeout - -from ..errors import ( - MasterNotFoundError, - MasterReplyError, - PoolClosedError, - RedisError, - SlaveNotFoundError, - SlaveReplyError, -) -from ..log import sentinel_logger -from ..pool import ConnectionsPool, create_pool -from ..pubsub import Receiver -from ..util import CloseEvent - -# Address marker for discovery -_NON_DISCOVERED = object() - -_logger = sentinel_logger.getChild("monitor") - - -async def create_sentinel_pool( - sentinels, - *, - db=None, - password=None, - encoding=None, - minsize=1, - maxsize=10, - ssl=None, - parser=None, - timeout=0.2, - loop=None, -): - """Create SentinelPool.""" - # FIXME: revise default timeout value - assert isinstance(sentinels, (list, tuple)), sentinels - # TODO: deprecation note - # if loop is None: - # loop = asyncio.get_event_loop() - - pool = SentinelPool( - sentinels, - db=db, - password=password, - ssl=ssl, - encoding=encoding, - parser=parser, - minsize=minsize, - maxsize=maxsize, - timeout=timeout, - loop=loop, - ) - await pool.discover() - return pool - - -class SentinelPool: - """Sentinel connections pool. - - Holds connection pools to known and discovered (TBD) Sentinels - as well as services' connections. 
- """ - - def __init__( - self, - sentinels, - *, - db=None, - password=None, - ssl=None, - encoding=None, - parser=None, - minsize, - maxsize, - timeout, - loop=None, - ): - # TODO: deprecation note - # if loop is None: - # loop = asyncio.get_event_loop() - # TODO: add connection/discover timeouts; - # and what to do if no master is found: - # (raise error or try forever or try until timeout) - - # XXX: _sentinels is unordered - self._sentinels = set(sentinels) - self._timeout = timeout - self._pools = [] # list of sentinel pools - self._masters = {} - self._slaves = {} - self._parser_class = parser - self._redis_db = db - self._redis_password = password - self._redis_ssl = ssl - self._redis_encoding = encoding - self._redis_minsize = minsize - self._redis_maxsize = maxsize - self._close_state = CloseEvent(self._do_close) - self._close_waiter = None - self._monitor = monitor = Receiver() - - async def echo_events(): - try: - while await monitor.wait_message(): - _, (ev, data) = await monitor.get(encoding="utf-8") - ev = ev.decode("utf-8") - _logger.debug("%s: %s", ev, data) - if ev in ("+odown",): - typ, name, *tail = data.split(" ") - if typ == "master": - self._need_rediscover(name) - # TODO: parse messages; - # watch +new-epoch which signals `failover in progres` - # freeze reconnection - # wait / discover new master (find proper way) - # unfreeze reconnection - # - # discover master in default way - # get-master-addr... - # connnect - # role - # etc... - except asyncio.CancelledError: - pass - - self._monitor_task = asyncio.ensure_future(echo_events()) - - @property - def discover_timeout(self): - """Timeout (seconds) for Redis/Sentinel command calls during - master/slave address discovery. - """ - return self._timeout - - def master_for(self, service): - """Returns wrapper to master's pool for requested service.""" - # TODO: make it coroutine and connect minsize connections - if service not in self._masters: - self._masters[service] = ManagedPool( - self, - service, - is_master=True, - db=self._redis_db, - password=self._redis_password, - encoding=self._redis_encoding, - minsize=self._redis_minsize, - maxsize=self._redis_maxsize, - ssl=self._redis_ssl, - parser=self._parser_class, - ) - return self._masters[service] - - def slave_for(self, service): - """Returns wrapper to slave's pool for requested service.""" - # TODO: make it coroutine and connect minsize connections - if service not in self._slaves: - self._slaves[service] = ManagedPool( - self, - service, - is_master=False, - db=self._redis_db, - password=self._redis_password, - encoding=self._redis_encoding, - minsize=self._redis_minsize, - maxsize=self._redis_maxsize, - ssl=self._redis_ssl, - parser=self._parser_class, - ) - return self._slaves[service] - - def execute(self, command, *args, **kwargs): - """Execute sentinel command.""" - # TODO: choose pool - # kwargs can be used to control which sentinel to use - if self.closed: - raise PoolClosedError("Sentinel pool is closed") - for pool in self._pools: - return pool.execute(command, *args, **kwargs) - # how to handle errors and pick other pool? - # is the only way to make it coroutine? 
- - @property - def closed(self): - """True if pool is closed or closing.""" - return self._close_state.is_set() - - def close(self): - """Close all controlled connections (both sentinel and redis).""" - if not self._close_state.is_set(): - self._close_state.set() - - async def _do_close(self): - # TODO: lock - tasks = [] - task, self._monitor_task = self._monitor_task, None - task.cancel() - tasks.append(task) - while self._pools: - pool = self._pools.pop(0) - pool.close() - tasks.append(pool.wait_closed()) - while self._masters: - _, pool = self._masters.popitem() - pool.close() - tasks.append(pool.wait_closed()) - while self._slaves: - _, pool = self._slaves.popitem() - pool.close() - tasks.append(pool.wait_closed()) - await asyncio.gather(*tasks) - - async def wait_closed(self): - """Wait until pool gets closed.""" - await self._close_state.wait() - - async def discover(self, timeout=None): # TODO: better name? - """Discover sentinels and all monitored services within given timeout. - - If no sentinels discovered within timeout: TimeoutError is raised. - If some sentinels were discovered but not all — it is ok. - If not all monitored services (masters/slaves) discovered - (or connections established) — it is ok. - TBD: what if some sentinels/services unreachable; - """ - # TODO: check not closed - # TODO: discovery must be done with some customizable timeout. - if timeout is None: - timeout = self.discover_timeout - pools = [] - tasks = [ - self._connect_sentinel(addr, timeout, pools) for addr in self._sentinels - ] - await asyncio.gather(*tasks, return_exceptions=True) - - if not pools: - raise Exception("Could not connect to any sentinel") - pools, self._pools[:] = self._pools[:], pools - # TODO: close current connections - for pool in pools: - pool.close() - await pool.wait_closed() - - # TODO: discover peer sentinels - for pool in self._pools: - await pool.execute_pubsub(b"psubscribe", self._monitor.pattern("*")) - - async def _connect_sentinel(self, address, timeout, pools): - """Try to connect to specified Sentinel returning either - connections pool or exception. - """ - try: - with async_timeout(timeout): - pool = await create_pool( - address, - minsize=1, - maxsize=2, - parser=self._parser_class, - ) - pools.append(pool) - return pool - except asyncio.TimeoutError as err: - sentinel_logger.debug( - "Failed to connect to Sentinel(%r) within %ss timeout", address, timeout - ) - return err - except Exception as err: - sentinel_logger.debug("Error connecting to Sentinel(%r): %r", address, err) - return err - - async def discover_master(self, service, timeout): - """Perform Master discovery for specified service.""" - # TODO: get lock - idle_timeout = timeout - # FIXME: single timeout used 4 times; - # meaning discovery can take up to: - # 3 * timeout * (sentinels count) - # - # having one global timeout also can leed to - # a problem when not all sentinels are checked. 
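-        #   (for example, with timeout=0.2 and 3 sentinels this is
-        #   up to 3 * 0.2 * 3 = 1.8 seconds)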
- - # use a copy, cause pools can change - pools = self._pools[:] - for sentinel in pools: - try: - with async_timeout(timeout): - address = await self._get_masters_address(sentinel, service) - - pool = self._masters[service] - with async_timeout(timeout), contextlib.ExitStack() as stack: - conn = await pool._create_new_connection(address) - stack.callback(conn.close) - await self._verify_service_role(conn, "master") - stack.pop_all() - - return conn - except asyncio.CancelledError: - # we must correctly handle CancelledError(s): - # application may be stopped or function can be cancelled - # by outer timeout, so we must stop the look up. - raise - except asyncio.TimeoutError: - continue - except DiscoverError as err: - sentinel_logger.debug( - "DiscoverError(%r, %s): %r", sentinel, service, err - ) - await asyncio.sleep(idle_timeout) - continue - except RedisError as err: - raise MasterReplyError(f"Service {service} error", err) - except Exception: - # TODO: clear (drop) connections to schedule reconnect - await asyncio.sleep(idle_timeout) - continue - # Otherwise - raise MasterNotFoundError(f"No master found for {service}") - - async def discover_slave(self, service, timeout, **kwargs): - """Perform Slave discovery for specified service.""" - # TODO: use kwargs to change how slaves are picked up - # (eg: round-robin, priority, random, etc) - idle_timeout = timeout - pools = self._pools[:] - for sentinel in pools: - try: - with async_timeout(timeout): - address = await self._get_slave_address( - sentinel, service - ) # add **kwargs - pool = self._slaves[service] - with async_timeout(timeout), contextlib.ExitStack() as stack: - conn = await pool._create_new_connection(address) - stack.callback(conn.close) - await self._verify_service_role(conn, "slave") - stack.pop_all() - return conn - except asyncio.CancelledError: - raise - except asyncio.TimeoutError: - continue - except DiscoverError: - await asyncio.sleep(idle_timeout) - continue - except RedisError as err: - raise SlaveReplyError(f"Service {service} error", err) - except Exception: - await asyncio.sleep(idle_timeout) - continue - raise SlaveNotFoundError(f"No slave found for {service}") - - async def _get_masters_address(self, sentinel, service): - # NOTE: we don't use `get-master-addr-by-name` - # as it can provide stale data so we repeat - # after redis-py and check service flags. 
- state = await sentinel.execute( - b"sentinel", b"master", service, encoding="utf-8" - ) - if not state: - raise UnknownService() - state = make_dict(state) - address = state["ip"], int(state["port"]) - flags = set(state["flags"].split(",")) - if {"s_down", "o_down", "disconnected"} & flags: - raise BadState(state) - return address - - async def _get_slave_address(self, sentinel, service): - # Find and return single slave address - slaves = await sentinel.execute( - b"sentinel", b"slaves", service, encoding="utf-8" - ) - if not slaves: - raise UnknownService() - for state in map(make_dict, slaves): - address = state["ip"], int(state["port"]) - flags = set(state["flags"].split(",")) - if {"s_down", "o_down", "disconnected"} & flags: - continue - return address - raise BadState() # XXX: only last state - - async def _verify_service_role(self, conn, role): - res = await conn.execute(b"role", encoding="utf-8") - if res[0] != role: - raise RoleMismatch(res) - - def _need_rediscover(self, service): - sentinel_logger.debug("Must redisover service %s", service) - pool = self._masters.get(service) - if pool: - pool.need_rediscover() - pool = self._slaves.get(service) - if pool: - pool.need_rediscover() - - -class ManagedPool(ConnectionsPool): - def __init__( - self, - sentinel, - service, - is_master, - db=None, - password=None, - encoding=None, - parser=None, - *, - minsize, - maxsize, - ssl=None, - loop=None, - ): - super().__init__( - _NON_DISCOVERED, - db=db, - password=password, - encoding=encoding, - minsize=minsize, - maxsize=maxsize, - ssl=ssl, - parser=parser, - loop=loop, - ) - assert self._address is _NON_DISCOVERED - self._sentinel = sentinel - self._service = service - self._is_master = is_master - # self._discover_timeout = .2 - - @property - def address(self): - if self._address is _NON_DISCOVERED: - return - return self._address - - def get_connection(self, command, args=()): - if self._address is _NON_DISCOVERED: - return None, _NON_DISCOVERED - return super().get_connection(command, args) - - async def _create_new_connection(self, address): - if address is _NON_DISCOVERED: - # Perform service discovery. - # Returns Connection or raises error if no service can be found. 
- await self._do_clear() # make `clear` blocking - - if self._is_master: - conn = await self._sentinel.discover_master( - self._service, timeout=self._sentinel.discover_timeout - ) - else: - conn = await self._sentinel.discover_slave( - self._service, timeout=self._sentinel.discover_timeout - ) - self._address = conn.address - sentinel_logger.debug( - "Discoverred new address %r for %s", conn.address, self._service - ) - return conn - return await super()._create_new_connection(address) - - def _drop_closed(self): - diff = len(self._pool) - super()._drop_closed() - diff -= len(self._pool) - if diff: - # closed connections were in pool: - # * reset address; - # * notify sentinel pool - sentinel_logger.debug( - "Dropped %d closed connnection(s); must rediscover", diff - ) - self._sentinel._need_rediscover(self._service) - - async def acquire(self, command=None, args=()): - if self._address is _NON_DISCOVERED: - await self.clear() - return await super().acquire(command, args) - - def release(self, conn): - was_closed = conn.closed - super().release(conn) - # if connection was closed while used and not by release() - if was_closed: - sentinel_logger.debug("Released closed connection; must rediscover") - self._sentinel._need_rediscover(self._service) - - def need_rediscover(self): - self._address = _NON_DISCOVERED - - -def make_dict(plain_list): - it = iter(plain_list) - return dict(zip(it, it)) - - -class DiscoverError(Exception): - """Internal errors for masters/slaves discovery.""" - - -class BadState(DiscoverError): - """Bad master's / slave's state read from sentinel.""" - - -class UnknownService(DiscoverError): - """Service is not monitored by specific sentinel.""" - - -class RoleMismatch(DiscoverError): - """Service reported to have other Role.""" diff --git a/aioredis/stream.py b/aioredis/stream.py deleted file mode 100644 index 5670f0fea..000000000 --- a/aioredis/stream.py +++ /dev/null @@ -1,108 +0,0 @@ -import asyncio -import sys -import warnings - -from .util import get_event_loop - -__all__ = [ - "open_connection", - "open_unix_connection", - "StreamReader", -] - - -async def open_connection( - host=None, port=None, *, limit, loop=None, parser=None, **kwds -): - # XXX: parser is not used (yet) - if loop is not None and sys.version_info >= (3, 8): - warnings.warn("The loop argument is deprecated", DeprecationWarning) - reader = StreamReader(limit=limit) - protocol = asyncio.StreamReaderProtocol(reader) - transport, _ = await get_event_loop().create_connection( - lambda: protocol, host, port, **kwds - ) - writer = asyncio.StreamWriter(transport, protocol, reader, loop=get_event_loop()) - return reader, writer - - -async def open_unix_connection(address, *, limit, loop=None, parser=None, **kwds): - # XXX: parser is not used (yet) - if loop is not None and sys.version_info >= (3, 8): - warnings.warn("The loop argument is deprecated", DeprecationWarning) - reader = StreamReader(limit=limit) - protocol = asyncio.StreamReaderProtocol(reader) - transport, _ = await get_event_loop().create_unix_connection( - lambda: protocol, address, **kwds - ) - writer = asyncio.StreamWriter(transport, protocol, reader, loop=get_event_loop()) - return reader, writer - - -class StreamReader(asyncio.StreamReader): - """ - Override the official StreamReader to address the - following issue: http://bugs.python.org/issue30861 - - Also it leverages to get rid of the dobule buffer and - get rid of one coroutine step. Data flows from the buffer - to the Redis parser directly. 
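-
-    Usage sketch (illustrative only; assumes a local Redis server and the
-    ``hiredis`` parser):
-
-    >>> reader, writer = await open_connection('localhost', 6379, limit=2 ** 16)
-    >>> reader.set_parser(hiredis.Reader())
-    >>> writer.write(b'PING\r\n')
-    >>> await reader.readobj()
-    b'PONG'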
- """ - - _parser = None - - def set_parser(self, parser): - self._parser = parser - if self._buffer: - self._parser.feed(self._buffer) - del self._buffer[:] - - def feed_data(self, data): - assert not self._eof, "feed_data after feed_eof" - - if not data: - return - if self._parser is None: - # XXX: hopefully it's only a small error message - self._buffer.extend(data) - return - self._parser.feed(data) - self._wakeup_waiter() - - # TODO: implement pause the read. Its needed - # expose the len of the buffer from hiredis - # to make it possible. - - async def readobj(self): - """ - Return a parsed Redis object or an exception - when something wrong happened. - """ - assert self._parser is not None, "set_parser must be called" - while True: - obj = self._parser.gets() - - if obj is not False: - # TODO: implement resume the read - - # Return any valid object and the Nil->None - # case. When its False there is nothing there - # to be parsed and we have to wait for more data. - return obj - - if self._exception: - raise self._exception - - if self._eof: - break - - await self._wait_for_data("readobj") - # NOTE: after break we return None which must be handled as b'' - - async def _read_not_allowed(self, *args, **kwargs): - raise RuntimeError("Use readobj") - - read = _read_not_allowed - readline = _read_not_allowed - readuntil = _read_not_allowed - readexactly = _read_not_allowed diff --git a/aioredis/util.py b/aioredis/util.py deleted file mode 100644 index 7a5887e5e..000000000 --- a/aioredis/util.py +++ /dev/null @@ -1,247 +0,0 @@ -import asyncio -import sys -from urllib.parse import parse_qsl, urlparse - -from .log import logger - -_NOTSET = object() - -IS_PY38 = sys.version_info >= (3, 8) - -# NOTE: never put here anything else; -# just this basic types -_converters = { - bytes: lambda val: val, - bytearray: lambda val: val, - str: lambda val: val.encode(), - int: lambda val: b"%d" % val, - float: lambda val: b"%r" % val, -} - - -def encode_command(*args, buf=None): - """Encodes arguments into redis bulk-strings array. - - Raises TypeError if any of args not of bytearray, bytes, float, int, or str - type. 
- """ - if buf is None: - buf = bytearray() - buf.extend(b"*%d\r\n" % len(args)) - - try: - for arg in args: - barg = _converters[type(arg)](arg) - buf.extend(b"$%d\r\n%s\r\n" % (len(barg), barg)) - except KeyError: - raise TypeError( - "Argument {!r} expected to be of bytearray, bytes," - " float, int, or str type".format(arg) - ) - return buf - - -def decode(obj, encoding): - if isinstance(obj, bytes): - return obj.decode(encoding) - elif isinstance(obj, list): - return [decode(o, encoding) for o in obj] - return obj - - -async def wait_ok(fut): - res = await fut - if res in (b"QUEUED", "QUEUED"): - return res - return res in (b"OK", "OK") - - -async def wait_convert(fut, type_, **kwargs): - result = await fut - if result in (b"QUEUED", "QUEUED"): - return result - return type_(result, **kwargs) - - -async def wait_make_dict(fut): - res = await fut - if res in (b"QUEUED", "QUEUED"): - return res - it = iter(res) - return dict(zip(it, it)) - - -class coerced_keys_dict(dict): - def __getitem__(self, other): - if not isinstance(other, bytes): - other = _converters[type(other)](other) - return dict.__getitem__(self, other) - - def __contains__(self, other): - if not isinstance(other, bytes): - other = _converters[type(other)](other) - return dict.__contains__(self, other) - - -class _ScanIter: - - __slots__ = ("_scan", "_cur", "_ret") - - def __init__(self, scan): - self._scan = scan - self._cur = b"0" - self._ret = [] - - def __aiter__(self): - return self - - async def __anext__(self): - while not self._ret and self._cur: - self._cur, self._ret = await self._scan(self._cur) - if not self._cur and not self._ret: - raise StopAsyncIteration - else: - ret = self._ret.pop(0) - return ret - - -def _set_result(fut, result, *info): - if fut.done(): - logger.debug("Waiter future is already done %r %r", fut, info) - assert fut.cancelled(), ("waiting future is in wrong state", fut, result, info) - else: - fut.set_result(result) - - -def _set_exception(fut, exception): - if fut.done(): - logger.debug("Waiter future is already done %r", fut) - assert fut.cancelled(), ("waiting future is in wrong state", fut, exception) - else: - fut.set_exception(exception) - - -def parse_url(url): - """Parse Redis connection URI. - - Parse according to IANA specs: - * https://www.iana.org/assignments/uri-schemes/prov/redis - * https://www.iana.org/assignments/uri-schemes/prov/rediss - - Also more rules applied: - - * empty scheme is treated as unix socket path no further parsing is done. - - * 'unix://' scheme is treated as unix socket path and parsed. - - * Multiple query parameter values and blank values are considered error. - - * DB number specified as path and as query parameter is considered error. - - * Password specified in userinfo and as query parameter is - considered error. 
- """ - r = urlparse(url) - - assert r.scheme in ("", "redis", "rediss", "unix"), ( - "Unsupported URI scheme", - r.scheme, - ) - if r.scheme == "": - return url, {} - query = {} - for p, v in parse_qsl(r.query, keep_blank_values=True): - assert p not in query, ("Multiple parameters are not allowed", p, v) - assert v, ("Empty parameters are not allowed", p, v) - query[p] = v - - if r.scheme == "unix": - assert r.path, ("Empty path is not allowed", url) - assert not r.netloc, ("Netlocation is not allowed for unix scheme", r.netloc) - return r.path, _parse_uri_options(query, "", r.password) - - address = (r.hostname or "localhost", int(r.port or 6379)) - path = r.path - if path.startswith("/"): - path = r.path[1:] - options = _parse_uri_options(query, path, r.password) - if r.scheme == "rediss": - options["ssl"] = True - return address, options - - -def _parse_uri_options(params, path, password): - def parse_db_num(val): - if not val: - return - assert val.isdecimal(), ("Invalid decimal integer", val) - assert val == "0" or not val.startswith("0"), ( - "Expected integer without leading zeroes", - val, - ) - return int(val) - - options = {} - - db1 = parse_db_num(path) - db2 = parse_db_num(params.get("db")) - assert db1 is None or db2 is None, ( - "Single DB value expected, got path and query", - db1, - db2, - ) - if db1 is not None: - options["db"] = db1 - elif db2 is not None: - options["db"] = db2 - - password2 = params.get("password") - assert ( - not password or not password2 - ), "Single password value is expected, got in net location and query" - if password: - options["password"] = password - elif password2: - options["password"] = password2 - - if "encoding" in params: - options["encoding"] = params["encoding"] - if "ssl" in params: - assert params["ssl"] in ("true", "false"), ( - "Expected 'ssl' param to be 'true' or 'false' only", - params["ssl"], - ) - options["ssl"] = params["ssl"] == "true" - - if "timeout" in params: - options["timeout"] = float(params["timeout"]) - return options - - -class CloseEvent: - def __init__(self, on_close): - self._close_init = asyncio.Event() - self._close_done = asyncio.Event() - self._on_close = on_close - - async def wait(self): - await self._close_init.wait() - await self._close_done.wait() - - def is_set(self): - return self._close_done.is_set() or self._close_init.is_set() - - def set(self): - if self._close_init.is_set(): - return - - task = asyncio.ensure_future(self._on_close()) - task.add_done_callback(self._cleanup) - self._close_init.set() - - def _cleanup(self, task): - self._on_close = None - self._close_done.set() - - -get_event_loop = getattr(asyncio, "get_running_loop", asyncio.get_event_loop) diff --git a/aioredis/utils.py b/aioredis/utils.py new file mode 100644 index 000000000..4f37a1abe --- /dev/null +++ b/aioredis/utils.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from contextlib import asynccontextmanager +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from aioredis import Redis + from aioredis.client import Pipeline + + +try: + import hiredis # noqa + + HIREDIS_AVAILABLE = True +except ImportError: + HIREDIS_AVAILABLE = False + + +def from_url(url, **kwargs): + """ + Returns an active Redis client generated from the given database URL. + + Will attempt to extract the database id from the path url fragment, if + none is provided. 
+ """ + from aioredis.client import Redis + + return Redis.from_url(url, **kwargs) + + +@asynccontextmanager +async def pipeline(redis_obj: Redis) -> Pipeline: + p = redis_obj.pipeline() + yield p + await p.execute() + + +def str_if_bytes(value): + return ( + value.decode("utf-8", errors="replace") if isinstance(value, bytes) else value + ) + + +def safe_str(value): + return str(str_if_bytes(value)) diff --git a/examples/blocking.py b/examples/blocking.py index d6f8992ff..6ea08c397 100644 --- a/examples/blocking.py +++ b/examples/blocking.py @@ -5,7 +5,7 @@ async def blocking_commands(): # Redis client bound to pool of connections (auto-reconnecting). - redis = await aioredis.create_redis_pool("redis://localhost") + redis = aioredis.Redis.from_url("redis://localhost") async def get_message(): # Redis blocking commands block the connection they are on @@ -26,8 +26,7 @@ async def get_message(): print(future.result()) # gracefully closing underlying connection - redis.close() - await redis.wait_closed() + await redis.close() if __name__ == "__main__": diff --git a/examples/commands.py b/examples/commands.py index 515771ea9..fe89eefbc 100644 --- a/examples/commands.py +++ b/examples/commands.py @@ -5,26 +5,24 @@ async def main(): # Redis client bound to single connection (no auto reconnection). - redis = await aioredis.create_redis("redis://localhost") + redis = aioredis.Redis(host="localhost", single_connection_client=True) await redis.set("my-key", "value") val = await redis.get("my-key") print(val) # gracefully closing underlying connection - redis.close() - await redis.wait_closed() + await redis.close() async def redis_pool(): # Redis client bound to pool of connections (auto-reconnecting). - redis = await aioredis.create_redis_pool("redis://localhost") + redis = aioredis.Redis.from_url("redis://localhost") await redis.set("my-key", "value") val = await redis.get("my-key") print(val) # gracefully closing underlying connection - redis.close() - await redis.wait_closed() + await redis.close() if __name__ == "__main__": diff --git a/examples/connection.py b/examples/connection.py index 9ce7a8646..f3173cdcb 100644 --- a/examples/connection.py +++ b/examples/connection.py @@ -4,22 +4,20 @@ async def main(): - conn = await aioredis.create_connection("redis://localhost", encoding="utf-8") + conn = aioredis.Redis.from_url( + "redis://localhost", encoding="utf-8", decode_responses=True + ) - ok = await conn.execute("set", "my-key", "some value") - assert ok == "OK", ok + ok = await conn.execute_command("set", "my-key", "some value") + assert ok is True - str_value = await conn.execute("get", "my-key") - raw_value = await conn.execute("get", "my-key", encoding=None) + str_value = await conn.execute_command("get", "my-key") assert str_value == "some value" - assert raw_value == b"some value" print("str value:", str_value) - print("raw value:", raw_value) # optionally close connection - conn.close() - await conn.wait_closed() + await conn.close() if __name__ == "__main__": diff --git a/examples/getting_started/00_connect.py b/examples/getting_started/00_connect.py index 3aae91f18..6a9554fff 100644 --- a/examples/getting_started/00_connect.py +++ b/examples/getting_started/00_connect.py @@ -4,13 +4,12 @@ async def main(): - redis = await aioredis.create_redis_pool("redis://localhost") + redis = await aioredis.Redis.from_url("redis://localhost") await redis.set("my-key", "value") value = await redis.get("my-key", encoding="utf-8") print(value) - redis.close() - await redis.wait_closed() + await 
redis.close() asyncio.run(main()) diff --git a/examples/getting_started/01_decoding.py b/examples/getting_started/01_decoding.py index 9ab179282..54ca3ad98 100644 --- a/examples/getting_started/01_decoding.py +++ b/examples/getting_started/01_decoding.py @@ -4,7 +4,7 @@ async def main(): - redis = await aioredis.create_redis_pool("redis://localhost") + redis = await aioredis.Redis.from_url("redis://localhost") await redis.set("key", "string-value") bin_value = await redis.get("key") assert bin_value == b"string-value" @@ -12,8 +12,7 @@ async def main(): str_value = await redis.get("key", encoding="utf-8") assert str_value == "string-value" - redis.close() - await redis.wait_closed() + await redis.close() asyncio.run(main()) diff --git a/examples/getting_started/02_decoding.py b/examples/getting_started/02_decoding.py index 02c6fda23..54df5c49a 100644 --- a/examples/getting_started/02_decoding.py +++ b/examples/getting_started/02_decoding.py @@ -4,7 +4,7 @@ async def main(): - redis = await aioredis.create_redis_pool("redis://localhost") + redis = await aioredis.Redis.from_url("redis://localhost") await redis.hmset_dict("hash", key1="value1", key2="value2", key3=123) @@ -15,8 +15,7 @@ async def main(): "key3": "123", # note that Redis returns int as string } - redis.close() - await redis.wait_closed() + await redis.close() asyncio.run(main()) diff --git a/examples/getting_started/03_multiexec.py b/examples/getting_started/03_multiexec.py index b09520859..8737557b5 100644 --- a/examples/getting_started/03_multiexec.py +++ b/examples/getting_started/03_multiexec.py @@ -4,12 +4,9 @@ async def main(): - redis = await aioredis.create_redis_pool("redis://localhost") - - tr = redis.multi_exec() - tr.set("key1", "value1") - tr.set("key2", "value2") - ok1, ok2 = await tr.execute() + redis = await aioredis.Redis.from_url("redis://localhost") + async with redis.pipeline(transaction=True) as pipe: + ok1, ok2 = await (pipe.set("key1", "value1").set("key2", "value2").execute()) assert ok1 assert ok2 diff --git a/examples/getting_started/04_pubsub.py b/examples/getting_started/04_pubsub.py index b9ced7f1d..9cccdd0ec 100644 --- a/examples/getting_started/04_pubsub.py +++ b/examples/getting_started/04_pubsub.py @@ -1,27 +1,38 @@ import asyncio +import async_timeout + import aioredis +STOPWORD = "STOP" -async def main(): - redis = await aioredis.create_redis_pool("redis://localhost") - ch1, ch2 = await redis.subscribe("channel:1", "channel:2") - assert isinstance(ch1, aioredis.Channel) - assert isinstance(ch2, aioredis.Channel) +async def reader(channel: aioredis.client.PubSub): + while True: + try: + async with async_timeout.timeout(1): + message = await channel.get_message(ignore_subscribe_messages=True) + if message is not None: + print(f"(Reader) Message Received: {message}") + if message["data"] == STOPWORD: + print("(Reader) STOP") + break + await asyncio.sleep(0.01) + except asyncio.TimeoutError: + pass - async def reader(channel): - async for message in channel.iter(): - print("Got message:", message) - asyncio.get_running_loop().create_task(reader(ch1)) - asyncio.get_running_loop().create_task(reader(ch2)) +async def main(): + redis = aioredis.Redis.from_url("redis://localhost") + pubsub = redis.pubsub() + await pubsub.subscribe("channel:1", "channel:2") + + asyncio.create_task(reader(pubsub)) await redis.publish("channel:1", "Hello") await redis.publish("channel:2", "World") - - redis.close() - await redis.wait_closed() + await redis.publish("channel:1", STOPWORD) + await redis.close() 
diff --git a/examples/getting_started/05_pubsub.py b/examples/getting_started/05_pubsub.py
index 315593841..cf9a95e76 100644
--- a/examples/getting_started/05_pubsub.py
+++ b/examples/getting_started/05_pubsub.py
@@ -1,25 +1,38 @@
 import asyncio

+import async_timeout
+
 import aioredis

+STOPWORD = "STOP"

-async def main():
-    redis = await aioredis.create_redis_pool("redis://localhost")
-    (ch,) = await redis.psubscribe("channel:*")
-    assert isinstance(ch, aioredis.Channel)

+async def reader(channel: aioredis.client.PubSub):
+    while True:
+        try:
+            async with async_timeout.timeout(1):
+                message = await channel.get_message(ignore_subscribe_messages=True)
+                if message is not None:
+                    print(f"(Reader) Message Received: {message}")
+                    if message["data"] == STOPWORD:
+                        print("(Reader) STOP")
+                        break
+                await asyncio.sleep(0.01)
+        except asyncio.TimeoutError:
+            pass

-    async def reader(channel):
-        async for ch, message in channel.iter():
-            print("Got message in channel:", ch, ":", message)

-    asyncio.get_running_loop().create_task(reader(ch))

+async def main():
+    redis = aioredis.Redis.from_url("redis://localhost", decode_responses=True)
+    pubsub = redis.pubsub()
+    await pubsub.psubscribe("channel:*")
+
+    asyncio.create_task(reader(pubsub))

     await redis.publish("channel:1", "Hello")
     await redis.publish("channel:2", "World")
-
-    redis.close()
-    await redis.wait_closed()
+    await redis.publish("channel:1", STOPWORD)
+    await redis.close()


 asyncio.run(main())
diff --git a/examples/getting_started/06_sentinel.py b/examples/getting_started/06_sentinel.py
index 0d03be155..0039dc682 100644
--- a/examples/getting_started/06_sentinel.py
+++ b/examples/getting_started/06_sentinel.py
@@ -1,10 +1,10 @@
 import asyncio

-import aioredis
+import aioredis.sentinel


 async def main():
-    sentinel = await aioredis.create_sentinel(
-        ["redis://localhost:26379", "redis://sentinel2:26379"]
-    )
+    sentinel = aioredis.sentinel.Sentinel(
+        [("localhost", 26379), ("sentinel2", 26379)]
+    )
     redis = sentinel.master_for("mymaster")
diff --git a/examples/pipeline.py b/examples/pipeline.py
index 6b9f60134..7d60ebf66 100644
--- a/examples/pipeline.py
+++ b/examples/pipeline.py
@@ -4,7 +4,7 @@

 async def main():
-    redis = await aioredis.create_redis("redis://localhost")
+    redis = aioredis.Redis.from_url("redis://localhost")

     # No pipelining;
     async def wait_each_command():
@@ -13,7 +13,7 @@ async def wait_each_command():
         return val, cnt

     # Sending multiple commands and then gathering results
-    async def pipelined():
+    async def concurrent():
         fut1 = redis.get("foo")  # issue command and return future
         fut2 = redis.incr("bar")  # issue command and return future
         # block until results are available
@@ -23,22 +23,26 @@

     # Explicit pipeline
     async def explicit_pipeline():
         pipe = redis.pipeline()
-        fut1 = pipe.get("foo")
-        fut2 = pipe.incr("bar")
+        pipe.get("foo").incr("bar")
         result = await pipe.execute()
-        val, cnt = await asyncio.gather(fut1, fut2)
-        assert result == [val, cnt]
-        return val, cnt
+        return result
+
+    async def context_pipeline():
+        async with redis.pipeline() as pipe:
+            pipe.get("foo").incr("bar")
+            result = await pipe.execute()
+        return result

     res = await wait_each_command()
     print(res)
-    res = await pipelined()
+    res = await concurrent()
     print(res)
     res = await explicit_pipeline()
     print(res)
+    res = await context_pipeline()
+    print(res)

-    redis.close()
-    await redis.wait_closed()
+    await redis.close()


 if __name__ == "__main__":
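Note on the pipelined() -> concurrent() rename above: concurrent() still pays one network round trip per command and merely awaits the results together, while a Pipeline buffers commands client-side and flushes them as a single batch. A sketch of the batch behaviour, assuming any connected client:

    async def one_round_trip(redis):
        pipe = redis.pipeline()
        pipe.set("foo", "1")  # buffered locally, nothing sent yet
        pipe.incr("foo")
        ok, value = await pipe.execute()  # one batch, results in order
        return ok, value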
diff --git a/examples/pool.py b/examples/pool.py
index b8ce1e8a5..1313b1ee7 100644
--- a/examples/pool.py
+++ b/examples/pool.py
@@ -4,13 +4,12 @@

 async def main():
-    pool = await aioredis.create_pool("redis://localhost", minsize=5, maxsize=10)
-    with await pool as conn:  # low-level redis connection
-        await conn.execute("set", "my-key", "value")
-        val = await conn.execute("get", "my-key")
+    redis = aioredis.Redis.from_url("redis://localhost", max_connections=10)
+    async with redis as r:
+        await r.execute_command("set", "my-key", "value")
+        val = await r.execute_command("get", "my-key")
         print("raw value:", val)

-    pool.close()
-    await pool.wait_closed()  # closing all open connections
+    await redis.close()


 if __name__ == "__main__":
diff --git a/examples/pool_pubsub.py b/examples/pool_pubsub.py
index a359761d7..91408d39c 100644
--- a/examples/pool_pubsub.py
+++ b/examples/pool_pubsub.py
@@ -1,68 +1,67 @@
 import asyncio

+import async_timeout
+
 import aioredis

 STOPWORD = "STOP"


 async def pubsub():
-    pool = await aioredis.create_pool("redis://localhost", minsize=5, maxsize=10)
-
-    async def reader(channel):
-        while await channel.wait_message():
-            msg = await channel.get(encoding="utf-8")
-            # ... process message ...
-            print(f"message in {channel.name}: {msg}")
-
-            if msg == STOPWORD:
-                return
-
-    with await pool as conn:
-        await conn.execute_pubsub("subscribe", "channel:1")
-        channel = conn.pubsub_channels["channel:1"]
-        await reader(channel)  # wait for reader to complete
-        await conn.execute_pubsub("unsubscribe", "channel:1")
-
-    # Explicit connection usage
-    conn = await pool.acquire()
-    try:
-        await conn.execute_pubsub("subscribe", "channel:1")
-        channel = conn.pubsub_channels["channel:1"]
-        await reader(channel)  # wait for reader to complete
-        await conn.execute_pubsub("unsubscribe", "channel:1")
-    finally:
-        pool.release(conn)
-
-    pool.close()
-    await pool.wait_closed()  # closing all open connections
-
-
-def main():
-    loop = asyncio.get_event_loop()
-    tsk = asyncio.ensure_future(pubsub(), loop=loop)
+    redis = aioredis.Redis.from_url(
+        "redis://localhost", max_connections=10, decode_responses=True
+    )
+    psub = redis.pubsub()
+
+    async def reader(channel: aioredis.client.PubSub):
+        while True:
+            try:
+                async with async_timeout.timeout(1):
+                    message = await channel.get_message(ignore_subscribe_messages=True)
+                    if message is not None:
+                        print(f"(Reader) Message Received: {message}")
+                        if message["data"] == STOPWORD:
+                            print("(Reader) STOP")
+                            break
+                    await asyncio.sleep(0.01)
+            except asyncio.TimeoutError:
+                pass
+
+    async with psub as p:
+        await p.subscribe("channel:1")
+        await reader(p)  # wait for reader to complete
+        await p.unsubscribe("channel:1")
+
+    # closing all open connections
+    await psub.close()
+    await redis.close()
+
+
+async def main():
+    tsk = asyncio.create_task(pubsub())

     async def publish():
-        pub = await aioredis.create_redis("redis://localhost")
+        pub = aioredis.Redis.from_url("redis://localhost", decode_responses=True)
         while not tsk.done():
             # wait for clients to subscribe
             while True:
-                subs = await pub.pubsub_numsub("channel:1")
-                if subs[b"channel:1"] == 1:
+                subs = dict(await pub.pubsub_numsub("channel:1"))
+                if subs["channel:1"] == 1:
                     break
-                await asyncio.sleep(0, loop=loop)
+                await asyncio.sleep(0)
             # publish some messages
             for msg in ["one", "two", "three"]:
+                print(f"(Publisher) Publishing Message: {msg}")
                 await pub.publish("channel:1", msg)
             # send stop word
             await pub.publish("channel:1", STOPWORD)
-        pub.close()
-        await pub.wait_closed()
+        await pub.close()

-    loop.run_until_complete(asyncio.gather(publish(), tsk, loop=loop))
+    await publish()


 if __name__ == "__main__":
     import os

     if "redis_version:2.6" not in os.environ.get("REDIS_VERSION", ""):
-        main()
+        asyncio.run(main())
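Note: the dict(...) wrapper in the publisher exists because pubsub_numsub() in the ported client returns (channel, count) pairs rather than a mapping. A sketch, assuming decode_responses=True:

    async def subscriber_count(pub):
        pairs = await pub.pubsub_numsub("channel:1")  # e.g. [("channel:1", 1)]
        return dict(pairs).get("channel:1", 0)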
os.environ.get("REDIS_VERSION", ""): - main() + asyncio.run(main()) diff --git a/examples/pubsub.py b/examples/pubsub.py deleted file mode 100644 index ef0629438..000000000 --- a/examples/pubsub.py +++ /dev/null @@ -1,30 +0,0 @@ -import asyncio - -import aioredis - - -async def reader(ch): - while await ch.wait_message(): - msg = await ch.get_json() - print("Got Message:", msg) - - -async def main(): - pub = await aioredis.create_redis("redis://localhost") - sub = await aioredis.create_redis("redis://localhost") - res = await sub.subscribe("chan:1") - ch1 = res[0] - - tsk = asyncio.ensure_future(reader(ch1)) - - res = await pub.publish_json("chan:1", ["Hello", "world"]) - assert res == 1 - - await sub.unsubscribe("chan:1") - await tsk - sub.close() - pub.close() - - -if __name__ == "__main__": - asyncio.run(main()) diff --git a/examples/pubsub2.py b/examples/pubsub2.py deleted file mode 100644 index c7e8697d3..000000000 --- a/examples/pubsub2.py +++ /dev/null @@ -1,55 +0,0 @@ -import asyncio - -import aioredis - - -async def pubsub(): - sub = await aioredis.create_redis("redis://localhost") - - ch1, ch2 = await sub.subscribe("channel:1", "channel:2") - assert isinstance(ch1, aioredis.Channel) - assert isinstance(ch2, aioredis.Channel) - - async def async_reader(channel): - while await channel.wait_message(): - msg = await channel.get(encoding="utf-8") - # ... process message ... - print(f"message in {channel.name}: {msg}") - - tsk1 = asyncio.ensure_future(async_reader(ch1)) - - # Or alternatively: - - async def async_reader2(channel): - while True: - msg = await channel.get(encoding="utf-8") - if msg is None: - break - # ... process message ... - print(f"message in {channel.name}: {msg}") - - tsk2 = asyncio.ensure_future(async_reader2(ch2)) - - # Publish messages and terminate - pub = await aioredis.create_redis("redis://localhost") - while True: - channels = await pub.pubsub_channels("channel:*") - if len(channels) == 2: - break - - for msg in ("Hello", ",", "world!"): - for ch in ("channel:1", "channel:2"): - await pub.publish(ch, msg) - await asyncio.sleep(0.1) - pub.close() - sub.close() - await pub.wait_closed() - await sub.wait_closed() - await asyncio.gather(tsk1, tsk2) - - -if __name__ == "__main__": - import os - - if "redis_version:2.6" not in os.environ.get("REDIS_VERSION", ""): - asyncio.run(pubsub()) diff --git a/examples/scan.py b/examples/scan.py index 548dc235c..67c73eb13 100644 --- a/examples/scan.py +++ b/examples/scan.py @@ -5,16 +5,15 @@ async def main(): """Scan command example.""" - redis = await aioredis.create_redis("redis://localhost") + redis = aioredis.Redis.from_url("redis://localhost") - await redis.mset("key:1", "value1", "key:2", "value2") + await redis.mset({"key:1": "value1", "key:2": "value2"}) cur = b"0" # set initial cursor to 0 while cur: cur, keys = await redis.scan(cur, match="key:*") print("Iteration results:", keys) - redis.close() - await redis.wait_closed() + await redis.close() if __name__ == "__main__": diff --git a/examples/sentinel.py b/examples/sentinel.py index 68adc3909..e23b35c5a 100644 --- a/examples/sentinel.py +++ b/examples/sentinel.py @@ -1,18 +1,16 @@ import asyncio -import aioredis +import aioredis.sentinel async def main(): - sentinel_client = await aioredis.create_sentinel([("localhost", 26379)]) + sentinel_client = aioredis.sentinel.Sentinel([("localhost", 26379)]) - master_redis = sentinel_client.master_for("mymaster") - info = await master_redis.role() + master_redis: aioredis.Redis = sentinel_client.master_for("mymaster") + 
diff --git a/examples/sentinel.py b/examples/sentinel.py
index 68adc3909..e23b35c5a 100644
--- a/examples/sentinel.py
+++ b/examples/sentinel.py
@@ -1,18 +1,16 @@
 import asyncio

-import aioredis
+import aioredis.sentinel


 async def main():
-    sentinel_client = await aioredis.create_sentinel([("localhost", 26379)])
+    sentinel_client = aioredis.sentinel.Sentinel([("localhost", 26379)])

-    master_redis = sentinel_client.master_for("mymaster")
-    info = await master_redis.role()
+    master_redis: aioredis.Redis = sentinel_client.master_for("mymaster")
+    info = await master_redis.execute_command("ROLE")
     print("Master role:", info)
-    assert info.role == "master"

-    sentinel_client.close()
-    await sentinel_client.wait_closed()
+    await sentinel_client.close()


 if __name__ == "__main__":
diff --git a/examples/transaction.py b/examples/transaction.py
index 4d335ed7d..439babf88 100644
--- a/examples/transaction.py
+++ b/examples/transaction.py
@@ -4,18 +4,13 @@

 async def main():
-    redis = await aioredis.create_redis("redis://localhost")
+    redis = aioredis.Redis.from_url("redis://localhost")
     await redis.delete("foo", "bar")

-    tr = redis.multi_exec()
-    fut1 = tr.incr("foo")
-    fut2 = tr.incr("bar")
-    res = await tr.execute()
-    res2 = await asyncio.gather(fut1, fut2)
+    async with redis.pipeline(transaction=True) as pipe:
+        res = await (pipe.incr("foo").incr("bar").execute())
     print(res)
-    assert res == res2

-    redis.close()
-    await redis.wait_closed()
+    await redis.close()


 if __name__ == "__main__":
diff --git a/examples/transaction2.py b/examples/transaction2.py
deleted file mode 100644
index aed07f803..000000000
--- a/examples/transaction2.py
+++ /dev/null
@@ -1,23 +0,0 @@
-import asyncio
-
-import aioredis
-
-
-async def main():
-    redis = await aioredis.create_redis("redis://localhost")
-
-    async def transaction():
-        tr = redis.multi_exec()
-        future1 = tr.set("foo", "123")
-        future2 = tr.set("bar", "321")
-        result = await tr.execute()
-        assert result == await asyncio.gather(future1, future2)
-        return result
-
-    await transaction()
-    redis.close()
-    await redis.wait_closed()
-
-
-if __name__ == "__main__":
-    asyncio.run(main())
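Note: pipeline(transaction=True) covers the plain MULTI/EXEC flow that multi_exec() used to provide; optimistic locking additionally uses WATCH, following redis-py. A sketch with a hypothetical helper (incr_safely is not part of this patch), assuming WatchError is raised when a watched key changes:

    async def incr_safely(redis, key):
        async with redis.pipeline(transaction=True) as pipe:
            while True:
                try:
                    await pipe.watch(key)  # immediate-execution mode
                    current = int(await pipe.get(key) or 0)
                    pipe.multi()  # back to buffered, transactional mode
                    pipe.set(key, current + 1)
                    await pipe.execute()
                    return current + 1
                except aioredis.WatchError:
                    continue  # watched key changed; retry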
diff --git a/setup.py b/setup.py
index 1ebd55abb..208108d3f 100644
--- a/setup.py
+++ b/setup.py
@@ -48,7 +48,11 @@ def read_version():
     url="https://github.com/aio-libs/aioredis",
     license="MIT",
     packages=find_packages(exclude=["tests"]),
-    install_requires=["async-timeout", 'hiredis; implementation_name=="cpython"'],
+    install_requires=[
+        "async-timeout",
+        'hiredis>=1.0; implementation_name=="cpython"',
+        "typing-extensions",
+    ],
     python_requires=">=3.6",
     include_package_data=True,
 )
diff --git a/tests/coerced_keys_dict_test.py b/tests/coerced_keys_dict_test.py
deleted file mode 100644
index 5c64cbc1b..000000000
--- a/tests/coerced_keys_dict_test.py
+++ /dev/null
@@ -1,55 +0,0 @@
-import pytest
-
-from aioredis.util import coerced_keys_dict
-
-
-def test_simple():
-    d = coerced_keys_dict()
-    assert d == {}
-
-    d = coerced_keys_dict({b"a": "b", b"c": "d"})
-    assert "a" in d
-    assert b"a" in d
-    assert "c" in d
-    assert b"c" in d
-    assert d == {b"a": "b", b"c": "d"}
-
-
-def test_invalid_init():
-    d = coerced_keys_dict({"foo": "bar"})
-    assert d == {"foo": "bar"}
-
-    assert "foo" not in d
-    assert b"foo" not in d
-    with pytest.raises(KeyError):
-        d["foo"]
-    with pytest.raises(KeyError):
-        d[b"foo"]
-
-    d = coerced_keys_dict()
-    d.update({"foo": "bar"})
-    assert d == {"foo": "bar"}
-
-    assert "foo" not in d
-    assert b"foo" not in d
-    with pytest.raises(KeyError):
-        d["foo"]
-    with pytest.raises(KeyError):
-        d[b"foo"]
-
-
-def test_valid_init():
-    d = coerced_keys_dict({b"foo": "bar"})
-    assert d == {b"foo": "bar"}
-    assert "foo" in d
-    assert b"foo" in d
-    assert d["foo"] == "bar"
-    assert d[b"foo"] == "bar"
-
-    d = coerced_keys_dict()
-    d.update({b"foo": "bar"})
-    assert d == {b"foo": "bar"}
-    assert "foo" in d
-    assert b"foo" in d
-    assert d["foo"] == "bar"
-    assert d[b"foo"] == "bar"
diff --git a/tests/conftest.py b/tests/conftest.py
index e0b1292c6..09e8d7c8d 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,213 +1,26 @@
 import argparse
 import asyncio
-import atexit
-import contextlib
-import os
-import socket
-import ssl
-import subprocess
-import sys
-import tempfile
-import time
-from collections import namedtuple
-from urllib.parse import urlencode, urlunparse
+import random
+from distutils.version import StrictVersion
+from typing import Type
+from unittest.mock import AsyncMock
+from urllib.parse import urlparse

 import pytest

 import aioredis
-import aioredis.sentinel
+from aioredis.client import Monitor
+from aioredis.connection import parse_url

-TCPAddress = namedtuple("TCPAddress", "host port")
+# redis 6 release candidates report a version number of 5.9.x. Use this
+# constant for skip_if decorators as a placeholder until 6.0.0 is officially
+# released
+REDIS_6_VERSION = "5.9.0"

-RedisServer = namedtuple("RedisServer", "name tcp_address unixsocket version password")
-SentinelServer = namedtuple(
-    "SentinelServer", "name tcp_address unixsocket version masters"
-)
+REDIS_INFO = {}
+default_redis_url = "redis://localhost:6379/9"

-# Public fixtures
-
-
-@pytest.fixture(scope="session")
-def event_loop():
-    """Creates new event loop."""
-    loop = asyncio.get_event_loop_policy().new_event_loop()
-    yield loop
-    loop.close()
-
-
-def _unused_tcp_port():
-    """Find an unused localhost TCP port from 1024-65535 and return it."""
-    with contextlib.closing(socket.socket()) as sock:
-        sock.bind(("127.0.0.1", 0))
-        return sock.getsockname()[1]
-
-
-@pytest.fixture(scope="session")
-def tcp_port_factory():
-    """A factory function, producing different unused TCP ports."""
-    produced = set()
-
-    def factory():
-        """Return an unused port."""
-        port = _unused_tcp_port()
-
-        while port in produced:
-            port = _unused_tcp_port()
-
-        produced.add(port)
-
-        return port
-
-    return factory
-
-
-@pytest.fixture
-def create_connection(_closable):
-    """Wrapper around aioredis.create_connection."""
-
-    async def f(*args, **kw):
-        conn = await aioredis.create_connection(*args, **kw)
-        _closable(conn)
-        return conn
-
-    return f
-
-
-@pytest.fixture(
-    params=[aioredis.create_redis, aioredis.create_redis_pool], ids=["single", "pool"]
-)
-def create_redis(_closable, request):
-    """Wrapper around aioredis.create_redis."""
-    factory = request.param
-
-    async def f(*args, **kw):
-        redis = await factory(*args, **kw)
-        _closable(redis)
-        return redis
-
-    return f
-
-
-@pytest.fixture
-def create_pool(_closable):
-    """Wrapper around aioredis.create_pool."""
-
-    async def f(*args, **kw):
-        redis = await aioredis.create_pool(*args, **kw)
-        _closable(redis)
-        return redis
-
-    return f
-
-
-@pytest.fixture
-def create_sentinel(_closable):
-    """Helper instantiating RedisSentinel client."""
-
-    async def f(*args, **kw):
-        # make it fail fast on slow CIs (if timeout argument is omitted)
-        kw.setdefault("timeout", 0.001)
-        client = await aioredis.sentinel.create_sentinel(*args, **kw)
-        _closable(client)
-        return client
-
-    return f
-
-
-@pytest.fixture
-async def pool(create_pool, server):
-    """Returns RedisPool instance."""
-    return await create_pool(server.tcp_address)
-
-
-@pytest.fixture
-async def redis(create_redis, server):
-    """Returns Redis client instance."""
-    redis = await create_redis(server.tcp_address)
-    await redis.flushall()
-    yield redis
-
-
-@pytest.fixture
-async def redis_sentinel(create_sentinel, sentinel):
-    """Returns Redis Sentinel client instance."""
-    redis_sentinel = await create_sentinel([sentinel.tcp_address], timeout=2)
-    assert await redis_sentinel.ping() == b"PONG"
-    return redis_sentinel
-
-
-@pytest.fixture
-def _closable(event_loop):
-    conns = []
-
-    async def close():
-        waiters = []
-        while conns:
-            conn = conns.pop(0)
-            conn.close()
-            waiters.append(conn.wait_closed())
-        if waiters:
-            await asyncio.gather(*waiters)
-
-    try:
-        yield conns.append
-    finally:
-        event_loop.run_until_complete(close())
-
-
-@pytest.fixture(scope="session")
-def server(start_server):
-    """Starts redis-server instance."""
-    return start_server("A")
-
-
-@pytest.fixture(scope="session")
-def serverB(start_server):
-    """Starts redis-server instance."""
-    return start_server("B")
-
-
-@pytest.fixture(scope="session")
-def sentinel(start_sentinel, request, start_server):
-    """Starts redis-sentinel instance with one master -- masterA."""
-    # Adding main+replica for normal (no failover) tests:
-    main_no_fail = start_server("main-no-fail")
-    start_server("replica-no-fail", slaveof=main_no_fail)
-    # Adding master+slave for failover test;
-    mainA = start_server("mainA")
-    start_server("replicaA", slaveof=mainA)
-    return start_sentinel("main", mainA, main_no_fail)
-
-
-@pytest.fixture(params=["path", "query"])
-def server_tcp_url(server, request):
-    def make(**kwargs):
-        netloc = "{0.host}:{0.port}".format(server.tcp_address)
-        path = ""
-        if request.param == "path":
-            if "password" in kwargs:
-                netloc = ":{0}@{1.host}:{1.port}".format(
-                    kwargs.pop("password"), server.tcp_address
-                )
-            if "db" in kwargs:
-                path = "/{}".format(kwargs.pop("db"))
-        query = urlencode(kwargs)
-        return urlunparse(("redis", netloc, path, "", query, ""))
-
-    return make
-
-
-@pytest.fixture
-def server_unix_url(server):
-    def make(**kwargs):
-        query = urlencode(kwargs)
-        return urlunparse(("unix", "", server.unixsocket, "", query, ""))
-
-    return make
-
-
-# Internal stuff
-#

 # Taken from python3.9
 class BooleanOptionalAction(argparse.Action):
@@ -256,406 +69,211 @@ def format_usage(self):


 def pytest_addoption(parser):
     parser.addoption(
-        "--redis-server",
-        default=[],
-        action="append",
-        help="Path to redis-server executable," " defaults to `%(default)s`",
-    )
-    parser.addoption(
-        "--ssl-cafile",
-        default="tests/ssl/cafile.crt",
-        help="Path to testing SSL CA file",
-    )
-    parser.addoption(
-        "--ssl-dhparam",
-        default="tests/ssl/dhparam.pem",
-        help="Path to testing SSL DH params file",
-    )
-    parser.addoption(
-        "--ssl-cert", default="tests/ssl/cert.pem", help="Path to testing SSL CERT file"
+        "--redis-url",
+        default=default_redis_url,
+        action="store",
+        help="Redis connection string, defaults to `%(default)s`",
     )
     parser.addoption(
         "--uvloop", action=BooleanOptionalAction, help="Run tests with uvloop"
     )


-def _read_server_version(redis_bin):
-    args = [redis_bin, "--version"]
-    with subprocess.Popen(args, stdout=subprocess.PIPE) as proc:
-        version = proc.stdout.readline().decode("utf-8")
-    for part in version.split():
-        if part.startswith("v="):
-            break
-    else:
-        raise RuntimeError(f"No version info can be found in {version}")
-    return tuple(map(int, part[2:].split(".")))
+async def _get_info(redis_url):
+    client = aioredis.Redis.from_url(redis_url)
+    info = await client.info()
+    await client.connection_pool.disconnect()
+    return info


-@contextlib.contextmanager
-def config_writer(path):
-    with open(path, "wt") as f:
+def pytest_sessionstart(session):
+    use_uvloop = session.config.getoption("--uvloop")

-        def write(*args):
-            print(*args, file=f)
+    if use_uvloop:
+        try:
+            import uvloop

-        yield write
+            uvloop.install()
+        except ImportError as e:
+            raise RuntimeError(
+                "Can not import uvloop, make sure it is installed"
+            ) from e
+
+    redis_url = session.config.getoption("--redis-url")
+    info = asyncio.get_event_loop().run_until_complete(_get_info(redis_url))
+    version = info["redis_version"]
+    arch_bits = info["arch_bits"]
+    REDIS_INFO["version"] = version
+    REDIS_INFO["arch_bits"] = arch_bits


-REDIS_SERVERS = []
-VERSIONS = {}
+def skip_if_server_version_lt(min_version):
+    redis_version = REDIS_INFO["version"]
+    check = StrictVersion(redis_version) < StrictVersion(min_version)
+    return pytest.mark.skipif(check, reason=f"Redis version required >= {min_version}")


-def format_version(srv):
-    return "redis_v{}".format(".".join(map(str, VERSIONS[srv])))
+def skip_if_server_version_gte(min_version):
+    redis_version = REDIS_INFO["version"]
+    check = StrictVersion(redis_version) >= StrictVersion(min_version)
+    return pytest.mark.skipif(check, reason=f"Redis version required < {min_version}")
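Note: the two helpers above are meant to decorate tests directly, since each returns a pytest.mark.skipif. A sketch (the test body is illustrative only):

    @skip_if_server_version_lt("5.0.0")
    @pytest.mark.asyncio
    async def test_client_id(r):
        assert isinstance(await r.client_id(), int)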
print(name, ":", log, end="") - if "The server is now ready to accept connections " in log: - break - if slaveof is not None: - for _ in timeout(10): - log = f.readline() - if log and verbose: - print(name, ":", log, end="") - if "sync: Finished with success" in log: - break - info = RedisServer(name, tcp_address, unixsocket, version, password) - servers.setdefault(name, info) - print("Ready REDIS", name) - return info - - return maker +def skip_unless_arch_bits(arch_bits): + return pytest.mark.skipif( + REDIS_INFO["arch_bits"] != arch_bits, + reason=f"server is not {arch_bits}-bit", + ) -@pytest.fixture(scope="session") -def start_sentinel(_proc, request, tcp_port_factory, server_bin): - """Starts Redis Sentinel instances.""" - version = _read_server_version(server_bin) - verbose = request.config.getoption("-v") > 3 - - sentinels = {} - - def timeout(t): - end = time.time() + t - while time.time() <= end: - yield True - raise RuntimeError("Redis startup timeout expired") - - def maker( - name, - *masters, - quorum=1, - noslaves=False, - down_after_milliseconds=3000, - failover_timeout=1000, - ): - key = (name,) + masters - if key in sentinels: - return sentinels[key] - port = tcp_port_factory() - tcp_address = TCPAddress("localhost", port) - data_dir = tempfile.gettempdir() - config = os.path.join(data_dir, f"aioredis-sentinel.{port}.conf") - stdout_file = os.path.join(data_dir, f"aioredis-sentinel.{port}.stdout") - tmp_files = [config, stdout_file] - if sys.platform == "win32": - unixsocket = None - else: - unixsocket = os.path.join(data_dir, f"aioredis-sentinel.{port}.sock") - tmp_files.append(unixsocket) - - with config_writer(config) as write: - write("daemonize no") - write('save ""') - write("port", port) - if unixsocket: - write("unixsocket", unixsocket) - write("loglevel debug") - for master in masters: - write( - "sentinel monitor", - master.name, - "127.0.0.1", - master.tcp_address.port, - quorum, - ) - write( - "sentinel down-after-milliseconds", - master.name, - down_after_milliseconds, - ) - write("sentinel failover-timeout", master.name, failover_timeout) - write("sentinel auth-pass", master.name, master.password) - - f = open(stdout_file, "w") - atexit.register(f.close) - proc = _proc( - server_bin, - config, - "--sentinel", - stdout=f, - stderr=subprocess.STDOUT, - _clear_tmp_files=tmp_files, - ) - # XXX: wait sentinel see all masters and slaves; - all_masters = {m.name for m in masters} - if noslaves: - all_slaves = {} - else: - all_slaves = {m.name for m in masters} - with open(stdout_file) as f: - for _ in timeout(30): - assert proc.poll() is None, ("Process terminated", proc.returncode) - log = f.readline() - if log and verbose: - print(name, ":", log, end="") - for m in masters: - if f"# +monitor master {m.name}" in log: - all_masters.discard(m.name) - if "* +slave slave" in log and f"@ {m.name}" in log: - all_slaves.discard(m.name) - if not all_masters and not all_slaves: - break +async def _get_client( + cls: Type[aioredis.Redis], + request, + event_loop: asyncio.AbstractEventLoop, + single_connection_client: bool = True, + flushdb: bool = True, + **kwargs, +) -> aioredis.Redis: + """ + Helper for fixtures or tests that need a Redis client + + Uses the "--redis-url" command line argument for connection info. Unlike + ConnectionPool.from_url, keyword arguments to this function override + values specified in the URL. 
+ """ + redis_url = request.config.getoption("--redis-url") + url_options = parse_url(redis_url) + url_options.update(kwargs) + pool = aioredis.ConnectionPool(**url_options) + client: aioredis.Redis = cls(connection_pool=pool) + if single_connection_client: + client = client.client() + await client.initialize() + if request: + + def teardown(): + async def ateardown(): + if flushdb: + try: + await client.flushdb() + except aioredis.ConnectionError: + # handle cases where a test disconnected a client + # just manually retry the flushdb + await client.flushdb() + await client.close() + await client.connection_pool.disconnect() + + if event_loop.is_running(): + event_loop.create_task(ateardown()) else: - raise RuntimeError("Could not start Sentinel") + event_loop.run_until_complete(ateardown()) - masters = {m.name: m for m in masters} - info = SentinelServer(name, tcp_address, unixsocket, version, masters) - sentinels.setdefault(key, info) - return info + request.addfinalizer(teardown) + return client - return maker +@pytest.fixture() +async def r(request, event_loop): + async with (await _get_client(aioredis.Redis, request, event_loop)) as client: + yield client -@pytest.fixture(scope="session") -def ssl_proxy(_proc, request, tcp_port_factory): - by_port = {} - - cafile = os.path.abspath(request.config.getoption("--ssl-cafile")) - certfile = os.path.abspath(request.config.getoption("--ssl-cert")) - dhfile = os.path.abspath(request.config.getoption("--ssl-dhparam")) - assert os.path.exists( - cafile - ), "Missing SSL CA file, run `make certificate` to generate new one" - assert os.path.exists( - certfile - ), "Missing SSL CERT file, run `make certificate` to generate new one" - assert os.path.exists( - dhfile - ), "Missing SSL DH params, run `make certificate` to generate new one" - - ssl_ctx = ssl.create_default_context(cafile=cafile) - ssl_ctx.check_hostname = False - ssl_ctx.verify_mode = ssl.CERT_NONE - ssl_ctx.load_dh_params(dhfile) - - def sockat(unsecure_port): - if unsecure_port in by_port: - return by_port[unsecure_port] - - secure_port = tcp_port_factory() - _proc( - "/usr/bin/socat", - "openssl-listen:{port}," - "dhparam={param}," - "cert={cert},verify=0,fork".format( - port=secure_port, param=dhfile, cert=certfile - ), - f"tcp-connect:localhost:{unsecure_port}", - ) - time.sleep(1) # XXX - by_port[unsecure_port] = secure_port, ssl_ctx - return secure_port, ssl_ctx - - return sockat - - -@pytest.yield_fixture(scope="session") -def _proc(): - processes = [] - tmp_files = set() - - def run(*commandline, _clear_tmp_files=(), **kwargs): - proc = subprocess.Popen(commandline, **kwargs) - processes.append(proc) - tmp_files.update(_clear_tmp_files) - return proc - - try: - yield run - finally: - while processes: - proc = processes.pop(0) - proc.terminate() - proc.wait() - for path in tmp_files: - try: - os.remove(path) - except OSError: - pass - - -def pytest_collection_modifyitems(session, config, items): - skip_by_version = [] - for item in items[:]: - marker = item.get_closest_marker("redis_version") - if marker is not None: - try: - version = VERSIONS[item.callspec.getparam("server_bin")] - except (KeyError, ValueError, AttributeError): - # TODO: throw noisy warning - continue - if version < marker.kwargs["version"]: - skip_by_version.append(item) - item.add_marker(pytest.mark.skip(reason=marker.kwargs["reason"])) - if "ssl_proxy" in item.fixturenames: - item.add_marker( - pytest.mark.skipif( - "not os.path.exists('/usr/bin/socat')", - reason="socat package required (apt-get install 
socat)", - ) - ) - if len(items) != len(skip_by_version): - for i in skip_by_version: - items.remove(i) - - -def pytest_configure(config): - bins = config.getoption("--redis-server")[:] - cmd = "which redis-server" - if not bins: - with os.popen(cmd) as pipe: - path = pipe.read().rstrip() - assert path, ( - "There is no redis-server on your computer." " Please install it first" - ) - REDIS_SERVERS[:] = [path] - else: - REDIS_SERVERS[:] = bins - VERSIONS.update({srv: _read_server_version(srv) for srv in REDIS_SERVERS}) - assert VERSIONS, ("Expected to detect redis versions", REDIS_SERVERS) +@pytest.fixture() +async def r2(request, event_loop): + """A second client for tests that need multiple""" + async with (await _get_client(aioredis.Redis, request, event_loop)) as client: + yield client - class DynamicFixturePlugin: - @pytest.fixture(scope="session", params=REDIS_SERVERS, ids=format_version) - def server_bin(self, request): - """Common for start_server and start_sentinel - server bin path parameter. - """ - return request.param - config.pluginmanager.register(DynamicFixturePlugin(), "server-bin-fixture") +def _gen_cluster_mock_resp(r, response): + connection = AsyncMock() + connection.read_response.return_value = response + r.connection = connection + return r - if config.getoption("--uvloop"): - try: - import uvloop - except ImportError: - raise RuntimeError("Can not import uvloop, make sure it is installed") - asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) + +@pytest.fixture() +async def mock_cluster_resp_ok(request, event_loop, **kwargs): + r = await _get_client(aioredis.Redis, request, event_loop, **kwargs) + return _gen_cluster_mock_resp(r, "OK") + + +@pytest.fixture() +async def mock_cluster_resp_int(request, event_loop, **kwargs): + r = await _get_client(aioredis.Redis, request, event_loop, **kwargs) + return _gen_cluster_mock_resp(r, "2") + + +@pytest.fixture() +async def mock_cluster_resp_info(request, event_loop, **kwargs): + r = await _get_client(aioredis.Redis, request, event_loop, **kwargs) + response = ( + "cluster_state:ok\r\ncluster_slots_assigned:16384\r\n" + "cluster_slots_ok:16384\r\ncluster_slots_pfail:0\r\n" + "cluster_slots_fail:0\r\ncluster_known_nodes:7\r\n" + "cluster_size:3\r\ncluster_current_epoch:7\r\n" + "cluster_my_epoch:2\r\ncluster_stats_messages_sent:170262\r\n" + "cluster_stats_messages_received:105653\r\n" + ) + return _gen_cluster_mock_resp(r, response) + + +@pytest.fixture() +async def mock_cluster_resp_nodes(request, event_loop, **kwargs): + r = await _get_client(aioredis.Redis, request, event_loop, **kwargs) + response = ( + "c8253bae761cb1ecb2b61857d85dfe455a0fec8b 172.17.0.7:7006 " + "slave aa90da731f673a99617dfe930306549a09f83a6b 0 " + "1447836263059 5 connected\n" + "9bd595fe4821a0e8d6b99d70faa660638a7612b3 172.17.0.7:7008 " + "master - 0 1447836264065 0 connected\n" + "aa90da731f673a99617dfe930306549a09f83a6b 172.17.0.7:7003 " + "myself,master - 0 0 2 connected 5461-10922\n" + "1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 " + "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 " + "1447836262556 3 connected\n" + "4ad9a12e63e8f0207025eeba2354bcf4c85e5b22 172.17.0.7:7005 " + "master - 0 1447836262555 7 connected 0-5460\n" + "19efe5a631f3296fdf21a5441680f893e8cc96ec 172.17.0.7:7004 " + "master - 0 1447836263562 3 connected 10923-16383\n" + "fbb23ed8cfa23f17eaf27ff7d0c410492a1093d6 172.17.0.7:7002 " + "master,fail - 1447829446956 1447829444948 1 disconnected\n" + ) + return _gen_cluster_mock_resp(r, response) + + 
+@pytest.fixture()
+async def mock_cluster_resp_slaves(request, event_loop, **kwargs):
+    r = await _get_client(aioredis.Redis, request, event_loop, **kwargs)
+    response = (
+        "['1df047e5a594f945d82fc140be97a1452bcbf93e 172.17.0.7:7007 "
+        "slave 19efe5a631f3296fdf21a5441680f893e8cc96ec 0 "
+        "1447836789290 3 connected']"
+    )
+    return _gen_cluster_mock_resp(r, response)
+
+
+@pytest.fixture(scope="session")
+def master_host(request):
+    url = request.config.getoption("--redis-url")
+    parts = urlparse(url)
+    yield parts.hostname
+
+
+async def wait_for_command(client: aioredis.Redis, monitor: Monitor, command: str):
+    # issue a command with a key name that's local to this process.
+    # if we find a command with our key before the command we're waiting
+    # for, something went wrong
+    redis_version = REDIS_INFO["version"]
+    if StrictVersion(redis_version) >= StrictVersion("5.0.0"):
+        id_str = str(await client.client_id())
+    else:
+        id_str = "%08x" % random.randrange(2 ** 32)
+    key = "__REDIS-PY-%s__" % id_str
+    await client.get(key)
+    while True:
+        monitor_response = await monitor.next_command()
+        if command in monitor_response["command"]:
+            return monitor_response
+        if key in monitor_response["command"]:
+            return None
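Note: wait_for_command() pairs with the client's MONITOR support; a sketch of the intended use, assuming Redis.monitor() works as an async context manager in the ported client:

    @pytest.mark.asyncio
    async def test_ping_is_monitored(r):
        async with r.monitor() as monitor:
            await r.ping()
            response = await wait_for_command(r, monitor, "PING")
            assert response is not None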
diff --git a/tests/connection_commands_test.py b/tests/connection_commands_test.py
deleted file mode 100644
index 0bad1424e..000000000
--- a/tests/connection_commands_test.py
+++ /dev/null
@@ -1,114 +0,0 @@
-import asyncio
-
-import pytest
-
-from aioredis import ConnectionClosedError, Redis, ReplyError
-from aioredis.pool import ConnectionsPool
-from tests.testutils import redis_version
-
-
-@pytest.mark.asyncio
-async def test_repr(create_redis, server):
-    redis = await create_redis(server.tcp_address, db=1)
-    assert repr(redis) in {
-        "<Redis <RedisConnection [db:1]>>",
-        "<Redis <ConnectionsPool [db:1, size:[1:10], free:1]>>",
-    }
-
-    redis = await create_redis(server.tcp_address, db=0)
-    assert repr(redis) in {
-        "<Redis <RedisConnection [db:0]>>",
-        "<Redis <ConnectionsPool [db:0, size:[1:10], free:1]>>",
-    }
-
-
-@pytest.mark.asyncio
-async def test_auth(redis):
-    expected_message = "ERR Client sent AUTH, but no password is set"
-    with pytest.raises(ReplyError, match=expected_message):
-        await redis.auth("")
-
-
-@pytest.mark.asyncio
-async def test_echo(redis):
-    resp = await redis.echo("ECHO")
-    assert resp == b"ECHO"
-
-    with pytest.raises(TypeError):
-        await redis.echo(None)
-
-
-@pytest.mark.asyncio
-async def test_ping(redis):
-    assert await redis.ping() == b"PONG"
-
-
-@pytest.mark.asyncio
-async def test_quit(redis):
-    expected = (ConnectionClosedError, ConnectionError)
-    try:
-        assert b"OK" == await redis.quit()
-    except expected:
-        pass
-
-    if not isinstance(redis.connection, ConnectionsPool):
-        # reader task may not yet been cancelled and _do_close not called
-        # so the ConnectionClosedError may be raised (or ConnectionError)
-        with pytest.raises(expected):
-            try:
-                await redis.ping()
-            except asyncio.CancelledError:
-                assert False, "Cancelled error must not be raised"
-
-        # wait one loop iteration until it get surely closed
-        await asyncio.sleep(0)
-        assert redis.connection.closed
-
-        with pytest.raises(ConnectionClosedError):
-            await redis.ping()
-
-
-@pytest.mark.asyncio
-async def test_select(redis):
-    assert redis.db == 0
-
-    resp = await redis.select(1)
-    assert resp is True
-    assert redis.db == 1
-    assert redis.connection.db == 1
-
-
-@pytest.mark.asyncio
-async def test_encoding(create_redis, server):
-    redis = await create_redis(server.tcp_address, db=1, encoding="utf-8")
-    assert redis.encoding == "utf-8"
-
-
-@pytest.mark.asyncio
-async def test_yield_from_backwards_compatibility(create_redis, server):
-    redis = await create_redis(server.tcp_address)
-
-    assert isinstance(redis, Redis)
-    # TODO: there should not be warning
-    # with pytest.warns(UserWarning):
-    with await redis as client:
-        assert isinstance(client, Redis)
-        assert client is not redis
-        assert await client.ping()
-
-
-@redis_version(4, 0, 0, reason="SWAPDB is available since redis>=4.0.0")
-@pytest.mark.asyncio
-async def test_swapdb(create_redis, start_server):
-    server = start_server("swapdb_1")
-    cli1 = await create_redis(server.tcp_address, db=0)
-    cli2 = await create_redis(server.tcp_address, db=1)
-
-    await cli1.flushall()
-    assert await cli1.set("key", "val") is True
-    assert await cli1.exists("key")
-    assert not await cli2.exists("key")
-
-    assert await cli1.swapdb(0, 1) is True
-    assert not await cli1.exists("key")
-    assert await cli2.exists("key")
diff --git a/tests/connection_test.py b/tests/connection_test.py
deleted file mode 100644
index 09155ace7..000000000
--- a/tests/connection_test.py
+++ /dev/null
@@ -1,547 +0,0 @@
-import asyncio
-import sys
-from unittest import mock
-from unittest.mock import patch
-
-import pytest
-
-from aioredis import (
-    Channel,
-    ConnectionClosedError,
-    MaxClientsError,
-    ProtocolError,
-    RedisConnection,
-    RedisError,
-    ReplyError,
-)
-from tests.testutils import delay_exc, redis_version, select_opener
-
-
-@pytest.mark.asyncio
-async def test_connect_tcp(request, create_connection, server):
-    conn = await create_connection(server.tcp_address)
-    assert conn.db == 0
-    assert isinstance(conn.address, tuple)
-    assert conn.address[0] in ("127.0.0.1", "::1")
-    assert conn.address[1] == server.tcp_address.port
-    assert str(conn) == "<RedisConnection [db:0]>"
-
-    conn = await create_connection(["localhost", server.tcp_address.port])
-    assert conn.db == 0
-    assert isinstance(conn.address, tuple)
-    assert conn.address[0] in ("127.0.0.1", "::1")
-    assert conn.address[1] == server.tcp_address.port
-    assert str(conn) == "<RedisConnection [db:0]>"
-
-
-@pytest.mark.asyncio
-async def test_connect_inject_connection_cls(request, create_connection, server):
-    class MyConnection(RedisConnection):
-        pass
-
-    conn = await create_connection(server.tcp_address, connection_cls=MyConnection)
-
-    assert isinstance(conn, MyConnection)
-
-
-@pytest.mark.asyncio
-async def test_connect_inject_connection_cls_invalid(create_connection, server):
-    with pytest.raises(AssertionError):
-        await create_connection(server.tcp_address, connection_cls=type)
-
-
-@pytest.mark.asyncio
-async def test_tcp_connect_timeout(create_connection, server):
-    target, address = select_opener("tcp", server)
-    with patch("aioredis.connection.open_connection", delay_exc(target, secs=0.2)):
-        with pytest.raises(asyncio.TimeoutError):
-            await create_connection(address, timeout=0.1)
-
-
-@pytest.mark.asyncio
-async def test_connect_tcp_invalid_timeout(request, create_connection, server):
-    with pytest.raises(ValueError):
-        await create_connection(server.tcp_address, timeout=0)
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="No unixsocket on Windows")
-@pytest.mark.asyncio
-async def test_connect_unixsocket(create_connection, server):
-    conn = await create_connection(server.unixsocket, db=0)
-    assert conn.db == 0
-    assert conn.address == server.unixsocket
-    assert str(conn) == "<RedisConnection [db:0]>"
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="No unixsocket on Windows")
-@pytest.mark.asyncio
-async def test_connect_unixsocket_timeout(create_connection, server):
-    target, address = select_opener("unix", server)
-    with patch("aioredis.connection.open_unix_connection", delay_exc(target, secs=0.2)):
-        with pytest.raises(asyncio.TimeoutError):
-            await create_connection(address, timeout=0.1)
-
-
-@redis_version(2, 8, 0, reason="maxclients config setting")
-@pytest.mark.asyncio
-async def test_connect_maxclients(create_connection, start_server):
-    server = start_server("server-maxclients")
-    conn = await create_connection(server.tcp_address)
-    await conn.execute(b"CONFIG", b"SET", "maxclients", 1)
-
-    errors = (MaxClientsError, ConnectionClosedError, ConnectionError)
-    with pytest.raises(errors):
-        conn2 = await create_connection(server.tcp_address)
-        await conn2.execute("ping")
-
-
-@pytest.mark.asyncio
-async def test_select_db(create_connection, server):
-    address = server.tcp_address
-    conn = await create_connection(address)
-    assert conn.db == 0
-
-    with pytest.raises(ValueError):
-        await create_connection(address, db=-1)
-    with pytest.raises(TypeError):
-        await create_connection(address, db=1.0)
-    with pytest.raises(TypeError):
-        await create_connection(address, db="bad value")
-    with pytest.raises(TypeError):
-        conn = await create_connection(address, db=None)
-        await conn.select(None)
-    with pytest.raises(ReplyError):
-        await create_connection(address, db=100000)
-
-    await conn.select(1)
-    assert conn.db == 1
-    await conn.select(2)
-    assert conn.db == 2
-    await conn.execute("select", 0)
-    assert conn.db == 0
-    await conn.execute(b"select", 1)
-    assert conn.db == 1
-
-
-@pytest.mark.asyncio
-async def test_protocol_error(create_connection, server):
-    conn = await create_connection(server.tcp_address)
-
-    reader = conn._reader
-
-    with pytest.raises(ProtocolError):
-        reader.feed_data(b"not good redis protocol response")
-        await conn.select(1)
-
-    assert len(conn._waiters) == 0
-
-
-def test_close_connection__tcp(create_connection, event_loop, server):
-    conn = event_loop.run_until_complete(create_connection(server.tcp_address))
-    conn.close()
-    with pytest.raises(ConnectionClosedError):
-        event_loop.run_until_complete(conn.select(1))
-
-    conn = event_loop.run_until_complete(create_connection(server.tcp_address))
-    conn.close()
-    fut = None
-    with pytest.raises(ConnectionClosedError):
-        fut = conn.select(1)
-    assert fut is None
-
-    conn = event_loop.run_until_complete(create_connection(server.tcp_address))
-    conn.close()
-    with pytest.raises(ConnectionClosedError):
-        conn.execute_pubsub("subscribe", "channel:1")
-
-
-@pytest.mark.skipif(sys.platform == "win32", reason="No unixsocket on Windows")
-@pytest.mark.asyncio
-async def test_close_connection__socket(create_connection, server):
-    conn = await create_connection(server.unixsocket)
-    conn.close()
-    with pytest.raises(ConnectionClosedError):
-        await conn.select(1)
-
-    conn = await create_connection(server.unixsocket)
-    conn.close()
-    with pytest.raises(ConnectionClosedError):
-        await conn.execute_pubsub("subscribe", "channel:1")
-
-
-@pytest.mark.asyncio
-async def test_closed_connection_with_none_reader(create_connection, server):
-    address = server.tcp_address
-    conn = await create_connection(address)
-    stored_reader = conn._reader
-    conn._reader = None
-    with pytest.raises(ConnectionClosedError):
-        await conn.execute("blpop", "test", 0)
-    conn._reader = stored_reader
-    conn.close()
-
-    conn = await create_connection(address)
-    stored_reader = conn._reader
-    conn._reader = None
-    with pytest.raises(ConnectionClosedError):
-        await conn.execute_pubsub("subscribe", "channel:1")
-    conn._reader = stored_reader
-    conn.close()
-
-
-@pytest.mark.asyncio
-async def test_wait_closed(create_connection, server):
-    address = server.tcp_address
-    conn = await create_connection(address)
-    reader_task = conn._reader_task
-    conn.close()
-    assert not reader_task.done()
-    await conn.wait_closed()
-    assert reader_task.done()
-
-
-@pytest.mark.asyncio
-async def test_cancel_wait_closed(create_connection, event_loop, server):
-    # Regression test: Don't throw error if wait_closed() is cancelled.
-    address = server.tcp_address
-    conn = await create_connection(address)
-    reader_task = conn._reader_task
-    conn.close()
-    task = asyncio.ensure_future(conn.wait_closed())
-
-    # Make sure the task is cancelled
-    # after it has been started by the loop.
-    event_loop.call_soon(task.cancel)
-
-    await conn.wait_closed()
-    assert reader_task.done()
-
-
-@pytest.mark.asyncio
-async def test_auth(create_connection, server):
-    conn = await create_connection(server.tcp_address)
-
-    res = await conn.execute("CONFIG", "SET", "requirepass", "pass")
-    assert res == b"OK"
-
-    conn2 = await create_connection(server.tcp_address)
-
-    with pytest.raises(ReplyError):
-        await conn2.select(1)
-
-    res = await conn2.auth("pass")
-    assert res is True
-    res = await conn2.select(1)
-    assert res is True
-
-    conn3 = await create_connection(server.tcp_address, password="pass")
-
-    res = await conn3.select(1)
-    assert res is True
-
-    res = await conn2.execute("CONFIG", "SET", "requirepass", "")
-    assert res == b"OK"
-
-
-@pytest.mark.asyncio
-async def test_decoding(create_connection, server):
-    conn = await create_connection(server.tcp_address, encoding="utf-8")
-    assert conn.encoding == "utf-8"
-    res = await conn.execute("set", "{prefix}:key1", "value")
-    assert res == "OK"
-    res = await conn.execute("get", "{prefix}:key1")
-    assert res == "value"
-
-    res = await conn.execute("set", "{prefix}:key1", b"bin-value")
-    assert res == "OK"
-    res = await conn.execute("get", "{prefix}:key1")
-    assert res == "bin-value"
-
-    res = await conn.execute("get", "{prefix}:key1", encoding="ascii")
-    assert res == "bin-value"
-    res = await conn.execute("get", "{prefix}:key1", encoding=None)
-    assert res == b"bin-value"
-
-    with pytest.raises(UnicodeDecodeError):
-        await conn.execute("set", "{prefix}:key1", "значение")
-        await conn.execute("get", "{prefix}:key1", encoding="ascii")
-
-    conn2 = await create_connection(server.tcp_address)
-    res = await conn2.execute("get", "{prefix}:key1", encoding="utf-8")
-    assert res == "значение"
-
-
-@pytest.mark.asyncio
-async def test_execute_exceptions(create_connection, server):
-    conn = await create_connection(server.tcp_address)
-    with pytest.raises(TypeError):
-        await conn.execute(None)
-    with pytest.raises(TypeError):
-        await conn.execute("ECHO", None)
-    with pytest.raises(TypeError):
-        await conn.execute("GET", ("a", "b"))
-    assert len(conn._waiters) == 0
-
-
-@pytest.mark.asyncio
-async def test_subscribe_unsubscribe(create_connection, server):
-    conn = await create_connection(server.tcp_address)
-
-    assert conn.in_pubsub == 0
-
-    res = await conn.execute("subscribe", "chan:1")
-    assert res == [[b"subscribe", b"chan:1", 1]]
-
-    assert conn.in_pubsub == 1
-
-    res = await conn.execute("unsubscribe", "chan:1")
-    assert res == [[b"unsubscribe", b"chan:1", 0]]
-    assert conn.in_pubsub == 0
-
-    res = await conn.execute("subscribe", "chan:1", "chan:2")
-    assert res == [
-        [b"subscribe", b"chan:1", 1],
-        [b"subscribe", b"chan:2", 2],
-    ]
-    assert conn.in_pubsub == 2
-
-    res = await conn.execute("unsubscribe", "non-existent")
-    assert res == [[b"unsubscribe", b"non-existent", 2]]
-    assert conn.in_pubsub == 2
-
-    res = await conn.execute("unsubscribe", "chan:1")
-    assert res == [[b"unsubscribe", b"chan:1", 1]]
-    assert conn.in_pubsub == 1
conn.execute("unsubscribe", "chan:1") - assert res == [[b"unsubscribe", b"chan:1", 1]] - assert conn.in_pubsub == 1 - - -@pytest.mark.asyncio -async def test_psubscribe_punsubscribe(create_connection, server): - conn = await create_connection(server.tcp_address) - res = await conn.execute("psubscribe", "chan:*") - assert res == [[b"psubscribe", b"chan:*", 1]] - assert conn.in_pubsub == 1 - - -@pytest.mark.asyncio -async def test_bad_command_in_pubsub(create_connection, server): - conn = await create_connection(server.tcp_address) - - res = await conn.execute("subscribe", "chan:1") - assert res == [[b"subscribe", b"chan:1", 1]] - - msg = "Connection in SUBSCRIBE mode" - with pytest.raises(RedisError, match=msg): - await conn.execute("select", 1) - with pytest.raises(RedisError, match=msg): - conn.execute("get") - - -@pytest.mark.asyncio -async def test_pubsub_messages(create_connection, server): - sub = await create_connection(server.tcp_address) - pub = await create_connection(server.tcp_address) - res = await sub.execute("subscribe", "chan:1") - assert res == [[b"subscribe", b"chan:1", 1]] - - assert b"chan:1" in sub.pubsub_channels - chan = sub.pubsub_channels[b"chan:1"] - assert str(chan) == "" - assert chan.name == b"chan:1" - assert chan.is_active is True - - res = await pub.execute("publish", "chan:1", "Hello!") - assert res == 1 - msg = await chan.get() - assert msg == b"Hello!" - - res = await sub.execute("psubscribe", "chan:*") - assert res == [[b"psubscribe", b"chan:*", 2]] - assert b"chan:*" in sub.pubsub_patterns - chan2 = sub.pubsub_patterns[b"chan:*"] - assert chan2.name == b"chan:*" - assert chan2.is_active is True - - res = await pub.execute("publish", "chan:1", "Hello!") - assert res == 2 - - msg = await chan.get() - assert msg == b"Hello!" - dest_chan, msg = await chan2.get() - assert dest_chan == b"chan:1" - assert msg == b"Hello!" 
-
-
-@pytest.mark.asyncio
-async def test_multiple_subscribe_unsubscribe(create_connection, server):
-    sub = await create_connection(server.tcp_address)
-
-    res = await sub.execute_pubsub("subscribe", "chan:1")
-    ch = sub.pubsub_channels["chan:1"]
-    assert res == [[b"subscribe", b"chan:1", 1]]
-    res = await sub.execute_pubsub("subscribe", b"chan:1")
-    assert res == [[b"subscribe", b"chan:1", 1]]
-    assert ch is sub.pubsub_channels["chan:1"]
-    res = await sub.execute_pubsub("subscribe", ch)
-    assert res == [[b"subscribe", b"chan:1", 1]]
-    assert ch is sub.pubsub_channels["chan:1"]
-
-    res = await sub.execute_pubsub("unsubscribe", "chan:1")
-    assert res == [[b"unsubscribe", b"chan:1", 0]]
-    res = await sub.execute_pubsub("unsubscribe", "chan:1")
-    assert res == [[b"unsubscribe", b"chan:1", 0]]
-
-    res = await sub.execute_pubsub("psubscribe", "chan:*")
-    assert res == [[b"psubscribe", b"chan:*", 1]]
-    res = await sub.execute_pubsub("psubscribe", "chan:*")
-    assert res == [[b"psubscribe", b"chan:*", 1]]
-
-    res = await sub.execute_pubsub("punsubscribe", "chan:*")
-    assert res == [[b"punsubscribe", b"chan:*", 0]]
-    res = await sub.execute_pubsub("punsubscribe", "chan:*")
-    assert res == [[b"punsubscribe", b"chan:*", 0]]
-
-
-@pytest.mark.asyncio
-async def test_execute_pubsub_errors(create_connection, server):
-    sub = await create_connection(server.tcp_address)
-
-    with pytest.raises(TypeError):
-        sub.execute_pubsub("subscribe", "chan:1", None)
-    with pytest.raises(TypeError):
-        sub.execute_pubsub("subscribe")
-    with pytest.raises(ValueError):
-        sub.execute_pubsub("subscribe", Channel("chan:1", is_pattern=True))
-    with pytest.raises(ValueError):
-        sub.execute_pubsub("unsubscribe", Channel("chan:1", is_pattern=True))
-    with pytest.raises(ValueError):
-        sub.execute_pubsub("psubscribe", Channel("chan:1", is_pattern=False))
-    with pytest.raises(ValueError):
-        sub.execute_pubsub("punsubscribe", Channel("chan:1", is_pattern=False))
-
-
-@pytest.mark.asyncio
-async def test_multi_exec(create_connection, server):
-    conn = await create_connection(server.tcp_address)
-
-    ok = await conn.execute("set", "foo", "bar")
-    assert ok == b"OK"
-
-    ok = await conn.execute("MULTI")
-    assert ok == b"OK"
-    queued = await conn.execute("getset", "foo", "baz")
-    assert queued == b"QUEUED"
-    res = await conn.execute("EXEC")
-    assert res == [b"bar"]
-
-    ok = await conn.execute("MULTI")
-    assert ok == b"OK"
-    queued = await conn.execute("getset", "foo", "baz")
-    assert queued == b"QUEUED"
-    res = await conn.execute("DISCARD")
-    assert res == b"OK"
-
-
-@pytest.mark.asyncio
-async def test_multi_exec__enc(create_connection, server):
-    conn = await create_connection(server.tcp_address, encoding="utf-8")
-
-    ok = await conn.execute("set", "foo", "bar")
-    assert ok == "OK"
-
-    ok = await conn.execute("MULTI")
-    assert ok == "OK"
-    queued = await conn.execute("getset", "foo", "baz")
-    assert queued == "QUEUED"
-    res = await conn.execute("EXEC")
-    assert res == ["bar"]
-
-    ok = await conn.execute("MULTI")
-    assert ok == "OK"
-    queued = await conn.execute("getset", "foo", "baz")
-    assert queued == "QUEUED"
-    res = await conn.execute("DISCARD")
-    assert res == "OK"
-
-
-@pytest.mark.asyncio
-async def test_connection_parser_argument(create_connection, server):
-    klass = mock.MagicMock()
-    klass.return_value = reader = mock.Mock()
-    conn = await create_connection(server.tcp_address, parser=klass)
-
-    assert klass.mock_calls == [
-        mock.call(protocolError=ProtocolError, replyError=ReplyError),
-    ]
-
-    response = [False]
-
-    def feed_gets(data, **kwargs):
-        response[0] = data
-
-    reader.gets.side_effect = lambda *args, **kwargs: response[0]
-    reader.feed.side_effect = feed_gets
-    assert b"+PONG\r\n" == await conn.execute("ping")
-
-
-@pytest.mark.asyncio
-async def test_connection_idle_close(create_connection, start_server):
-    server = start_server("idle")
-    conn = await create_connection(server.tcp_address)
-    ok = await conn.execute("config", "set", "timeout", 1)
-    assert ok == b"OK"
-
-    await asyncio.sleep(6)
-
-    with pytest.raises(ConnectionClosedError):
-        assert await conn.execute("ping") is None
-
-
-@pytest.mark.parametrize(
-    "kwargs",
-    [
-        {},
-        {"db": 1},
-        {"encoding": "utf-8"},
-    ],
-    ids=repr,
-)
-@pytest.mark.asyncio
-async def test_create_connection__tcp_url(create_connection, server_tcp_url, kwargs):
-    url = server_tcp_url(**kwargs)
-    db = kwargs.get("db", 0)
-    enc = kwargs.get("encoding", None)
-    conn = await create_connection(url)
-    pong = b"PONG" if not enc else b"PONG".decode(enc)
-    assert await conn.execute("ping") == pong
-    assert conn.db == db
-    assert conn.encoding == enc
-
-
-@pytest.mark.skipif('sys.platform == "win32"', reason="No unix sockets on Windows")
-@pytest.mark.parametrize(
-    "kwargs",
-    [
-        {},
-        {"db": 1},
-        {"encoding": "utf-8"},
-    ],
-    ids=repr,
-)
-@pytest.mark.asyncio
-async def test_create_connection__unix_url(create_connection, server_unix_url, kwargs):
-    url = server_unix_url(**kwargs)
-    db = kwargs.get("db", 0)
-    enc = kwargs.get("encoding", None)
-    conn = await create_connection(url)
-    pong = b"PONG" if not enc else b"PONG".decode(enc)
-    assert await conn.execute("ping") == pong
-    assert conn.db == db
-    assert conn.encoding == enc
-
-
-@pytest.mark.asyncio
-async def test_connect_setname(request, create_connection, server):
-    name = "test"
-    conn = await create_connection(server.tcp_address, name=name)
-    res = await conn.execute(b"CLIENT", b"GETNAME")
-    assert res == bytes(name, "utf-8")
diff --git a/tests/encode_command_test.py b/tests/encode_command_test.py
deleted file mode 100644
index fb7292967..000000000
--- a/tests/encode_command_test.py
+++ /dev/null
@@ -1,63 +0,0 @@
-import pytest
-
-from aioredis.util import encode_command
-
-
-def test_encode_bytes():
-    res = encode_command(b"Hello")
-    assert res == b"*1\r\n$5\r\nHello\r\n"
-
-    res = encode_command(b"Hello", b"World")
-    assert res == b"*2\r\n$5\r\nHello\r\n$5\r\nWorld\r\n"
-
-    res = encode_command(b"\0")
-    assert res == b"*1\r\n$1\r\n\0\r\n"
-
-    res = encode_command(bytearray(b"Hello\r\n"))
-    assert res == b"*1\r\n$7\r\nHello\r\n\r\n"
-
-
-def test_encode_bytearray():
-    res = encode_command(bytearray(b"Hello"))
-    assert res == b"*1\r\n$5\r\nHello\r\n"
-
-    res = encode_command(bytearray(b"Hello"), bytearray(b"world"))
-    assert res == b"*2\r\n$5\r\nHello\r\n$5\r\nworld\r\n"
-
-
-def test_encode_str():
-    res = encode_command("Hello")
-    assert res == b"*1\r\n$5\r\nHello\r\n"
-
-    res = encode_command("Hello", "world")
-    assert res == b"*2\r\n$5\r\nHello\r\n$5\r\nworld\r\n"
-
-
-def test_encode_int():
-    res = encode_command(1)
-    assert res == b"*1\r\n$1\r\n1\r\n"
-
-    res = encode_command(-1)
-    assert res == b"*1\r\n$2\r\n-1\r\n"
-
-
-def test_encode_float():
-    res = encode_command(1.0)
-    assert res == b"*1\r\n$3\r\n1.0\r\n"
-
-    res = encode_command(-1.0)
-    assert res == b"*1\r\n$4\r\n-1.0\r\n"
-
-
-def test_encode_empty():
-    res = encode_command()
-    assert res == b"*0\r\n"
-
-
-def test_encode_errors():
-    with pytest.raises(TypeError):
-        encode_command(dict())
-    with pytest.raises(TypeError):
encode_command(list()) - with pytest.raises(TypeError): - encode_command(None) diff --git a/tests/errors_test.py b/tests/errors_test.py deleted file mode 100644 index eb00e669b..000000000 --- a/tests/errors_test.py +++ /dev/null @@ -1,16 +0,0 @@ -from aioredis.errors import MaxClientsError, ReplyError - - -def test_return_default_class(): - assert isinstance(ReplyError(None), ReplyError) - - -def test_return_adhoc_class(): - class MyError(ReplyError): - MATCH_REPLY = "my error" - - assert isinstance(ReplyError("my error"), MyError) - - -def test_return_max_clients_error(): - assert isinstance(ReplyError("ERR max number of clients reached"), MaxClientsError) diff --git a/tests/generic_commands_test.py b/tests/generic_commands_test.py deleted file mode 100644 index 8666c19e3..000000000 --- a/tests/generic_commands_test.py +++ /dev/null @@ -1,832 +0,0 @@ -import asyncio -import math -import sys -import time -from unittest import mock - -import pytest - -from aioredis import ReplyError -from tests.testutils import redis_version - - -async def add(redis, key, value): - ok = await redis.connection.execute("set", key, value) - assert ok == b"OK" - - -@pytest.mark.asyncio -async def test_delete(redis): - await add(redis, "my-key", 123) - await add(redis, "other-key", 123) - - res = await redis.delete("my-key", "non-existent-key") - assert res == 1 - - res = await redis.delete("other-key", "other-key") - assert res == 1 - - with pytest.raises(TypeError): - await redis.delete(None) - - with pytest.raises(TypeError): - await redis.delete("my-key", "my-key", None) - - -@pytest.mark.asyncio -async def test_dump(redis): - await add(redis, "my-key", 123) - - data = await redis.dump("my-key") - assert data == mock.ANY - assert isinstance(data, (bytes, bytearray)) - assert len(data) > 0 - - data = await redis.dump("non-existent-key") - assert data is None - - with pytest.raises(TypeError): - await redis.dump(None) - - -@pytest.mark.asyncio -async def test_exists(redis, server): - await add(redis, "my-key", 123) - - res = await redis.exists("my-key") - assert isinstance(res, int) - assert res == 1 - - res = await redis.exists("non-existent-key") - assert isinstance(res, int) - assert res == 0 - - with pytest.raises(TypeError): - await redis.exists(None) - if server.version < (3, 0, 3): - with pytest.raises(ReplyError): - await redis.exists("key-1", "key-2") - - -@redis_version(3, 0, 3, reason="Multi-key EXISTS available since redis>=3.0.3") -@pytest.mark.asyncio -async def test_exists_multiple(redis): - await add(redis, "my-key", 123) - - res = await redis.exists("my-key", "other-key") - assert isinstance(res, int) - assert res == 1 - - res = await redis.exists("my-key", "my-key") - assert isinstance(res, int) - assert res == 2 - - res = await redis.exists("foo", "bar") - assert isinstance(res, int) - assert res == 0 - - -@pytest.mark.asyncio -async def test_expire(redis): - await add(redis, "my-key", 132) - - res = await redis.expire("my-key", 10) - assert res is True - - res = await redis.connection.execute("TTL", "my-key") - assert res >= 10 - - await redis.expire("my-key", -1) - res = await redis.exists("my-key") - assert not res - - res = await redis.expire("other-key", 1000) - assert res is False - - await add(redis, "my-key", 1) - res = await redis.expire("my-key", 10.0) - assert res is True - res = await redis.connection.execute("TTL", "my-key") - assert res >= 10 - - with pytest.raises(TypeError): - await redis.expire(None, 123) - with pytest.raises(TypeError): - await redis.expire("my-key", 
"timeout") - - -@pytest.mark.asyncio -async def test_expireat(redis): - await add(redis, "my-key", 123) - now = math.ceil(time.time()) - - fut1 = redis.expireat("my-key", now + 10) - fut2 = redis.connection.execute("TTL", "my-key") - assert (await fut1) is True - assert (await fut2) >= 10 - - now = time.time() - fut1 = redis.expireat("my-key", now + 10) - fut2 = redis.connection.execute("TTL", "my-key") - assert (await fut1) is True - assert (await fut2) >= 10 - - res = await redis.expireat("my-key", -1) - assert res is True - - res = await redis.exists("my-key") - assert not res - - await add(redis, "my-key", 123) - - res = await redis.expireat("my-key", 0) - assert res is True - - res = await redis.exists("my-key") - assert not res - - await add(redis, "my-key", 123) - with pytest.raises(TypeError): - await redis.expireat(None, 123) - with pytest.raises(TypeError): - await redis.expireat("my-key", "timestamp") - - -@pytest.mark.asyncio -async def test_keys(redis): - res = await redis.keys("*pattern*") - assert res == [] - - await redis.connection.execute("FLUSHDB") - res = await redis.keys("*") - assert res == [] - - await add(redis, "my-key-1", 1) - await add(redis, "my-key-ab", 1) - - res = await redis.keys("my-key-?") - assert res == [b"my-key-1"] - res = await redis.keys("my-key-*") - assert sorted(res) == [b"my-key-1", b"my-key-ab"] - - # test with encoding param - res = await redis.keys("my-key-*", encoding="utf-8") - assert sorted(res) == ["my-key-1", "my-key-ab"] - - with pytest.raises(TypeError): - await redis.keys(None) - - -@pytest.mark.asyncio -async def test_migrate(create_redis, server, serverB): - redisA = await create_redis(server.tcp_address) - redisB = await create_redis(serverB.tcp_address, db=2) - - await add(redisA, "my-key", 123) - - await redisB.delete("my-key") - assert await redisA.exists("my-key") - assert not (await redisB.exists("my-key")) - - ok = await redisA.migrate("localhost", serverB.tcp_address.port, "my-key", 2, 1000) - assert ok is True - assert not (await redisA.exists("my-key")) - assert await redisB.exists("my-key") - - with pytest.raises(TypeError, match="host .* str"): - await redisA.migrate(None, 1234, "key", 1, 23) - with pytest.raises(TypeError, match="args .* None"): - await redisA.migrate("host", "1234", None, 1, 123) - with pytest.raises(TypeError, match="dest_db .* int"): - await redisA.migrate("host", 123, "key", 1.0, 123) - with pytest.raises(TypeError, match="timeout .* int"): - await redisA.migrate("host", "1234", "key", 2, None) - with pytest.raises(ValueError, match="Got empty host"): - await redisA.migrate("", "123", "key", 1, 123) - with pytest.raises(ValueError, match="dest_db .* greater equal 0"): - await redisA.migrate("host", 6379, "key", -1, 1000) - with pytest.raises(ValueError, match="timeout .* greater equal 0"): - await redisA.migrate("host", 6379, "key", 1, -1000) - - -@redis_version(3, 0, 0, reason="Copy/Replace flags available since Redis 3.0") -@pytest.mark.asyncio -async def test_migrate_copy_replace(create_redis, server, serverB): - redisA = await create_redis(server.tcp_address) - redisB = await create_redis(serverB.tcp_address, db=0) - - await add(redisA, "my-key", 123) - await redisB.delete("my-key") - - ok = await redisA.migrate( - "localhost", serverB.tcp_address.port, "my-key", 0, 1000, copy=True - ) - assert ok is True - assert (await redisA.get("my-key")) == b"123" - assert (await redisB.get("my-key")) == b"123" - - assert await redisA.set("my-key", "val") - ok = await redisA.migrate( - "localhost", 
serverB.tcp_address.port, "my-key", 2, 1000, replace=True - ) - assert (await redisA.get("my-key")) is None - assert await redisB.get("my-key") - - -@redis_version(3, 0, 6, reason="MIGRATE…KEYS available since Redis 3.0.6") -@pytest.mark.skipif( - sys.platform == "win32", reason="Seems to be unavailable in win32 build" -) -@pytest.mark.asyncio -async def test_migrate_keys(create_redis, server, serverB): - redisA = await create_redis(server.tcp_address) - redisB = await create_redis(serverB.tcp_address, db=0) - - await add(redisA, "key1", 123) - await add(redisA, "key2", 123) - await add(redisA, "key3", 123) - await redisB.delete("key1", "key2", "key3") - - ok = await redisA.migrate_keys( - "localhost", - serverB.tcp_address.port, - ("key1", "key2", "key3", "non-existing-key"), - dest_db=0, - timeout=1000, - ) - assert ok is True - - assert (await redisB.get("key1")) == b"123" - assert (await redisB.get("key2")) == b"123" - assert (await redisB.get("key3")) == b"123" - assert (await redisA.get("key1")) is None - assert (await redisA.get("key2")) is None - assert (await redisA.get("key3")) is None - - ok = await redisA.migrate_keys( - "localhost", - serverB.tcp_address.port, - ("key1", "key2", "key3"), - dest_db=0, - timeout=1000, - ) - assert not ok - ok = await redisB.migrate_keys( - "localhost", - server.tcp_address.port, - ("key1", "key2", "key3"), - dest_db=0, - timeout=1000, - copy=True, - ) - assert ok - assert (await redisB.get("key1")) == b"123" - assert (await redisB.get("key2")) == b"123" - assert (await redisB.get("key3")) == b"123" - assert (await redisA.get("key1")) == b"123" - assert (await redisA.get("key2")) == b"123" - assert (await redisA.get("key3")) == b"123" - - assert await redisA.set("key1", "val") - assert await redisA.set("key2", "val") - assert await redisA.set("key3", "val") - ok = await redisA.migrate_keys( - "localhost", - serverB.tcp_address.port, - ("key1", "key2", "key3", "non-existing-key"), - dest_db=0, - timeout=1000, - replace=True, - ) - assert ok is True - - assert (await redisB.get("key1")) == b"val" - assert (await redisB.get("key2")) == b"val" - assert (await redisB.get("key3")) == b"val" - assert (await redisA.get("key1")) is None - assert (await redisA.get("key2")) is None - assert (await redisA.get("key3")) is None - - -@pytest.mark.asyncio -async def test_migrate__exceptions(redis, server, unused_tcp_port): - await add(redis, "my-key", 123) - - assert await redis.exists("my-key") - - with pytest.raises(ReplyError, match="IOERR .* timeout .*"): - assert not ( - await redis.migrate( - "localhost", unused_tcp_port, "my-key", dest_db=30, timeout=10 - ) - ) - - -@redis_version(3, 0, 6, reason="MIGRATE…KEYS available since Redis 3.0.6") -@pytest.mark.skipif( - sys.platform == "win32", reason="Seems to be unavailable in win32 build" -) -@pytest.mark.asyncio -async def test_migrate_keys__errors(redis): - with pytest.raises(TypeError, match="host .* str"): - await redis.migrate_keys(None, 1234, "key", 1, 23) - with pytest.raises(TypeError, match="keys .* list or tuple"): - await redis.migrate_keys("host", "1234", None, 1, 123) - with pytest.raises(TypeError, match="dest_db .* int"): - await redis.migrate_keys("host", 123, ("key",), 1.0, 123) - with pytest.raises(TypeError, match="timeout .* int"): - await redis.migrate_keys("host", "1234", ("key",), 2, None) - with pytest.raises(ValueError, match="Got empty host"): - await redis.migrate_keys("", "123", ("key",), 1, 123) - with pytest.raises(ValueError, match="dest_db .* greater equal 0"): - await 
redis.migrate_keys("host", 6379, ("key",), -1, 1000) - with pytest.raises(ValueError, match="timeout .* greater equal 0"): - await redis.migrate_keys("host", 6379, ("key",), 1, -1000) - with pytest.raises(ValueError, match="keys .* empty"): - await redis.migrate_keys("host", "1234", (), 2, 123) - - -@pytest.mark.asyncio -async def test_move(redis): - await add(redis, "my-key", 123) - - assert redis.db == 0 - res = await redis.move("my-key", 1) - assert res is True - - with pytest.raises(TypeError): - await redis.move(None, 1) - with pytest.raises(TypeError): - await redis.move("my-key", None) - with pytest.raises(ValueError): - await redis.move("my-key", -1) - with pytest.raises(TypeError): - await redis.move("my-key", "not db") - - -@pytest.mark.asyncio -async def test_object_refcount(redis): - await add(redis, "foo", "bar") - - res = await redis.object_refcount("foo") - assert res == 1 - res = await redis.object_refcount("non-existent-key") - assert res is None - - with pytest.raises(TypeError): - await redis.object_refcount(None) - - -@pytest.mark.asyncio -async def test_object_encoding(redis, server): - await add(redis, "foo", "bar") - - res = await redis.object_encoding("foo") - - if server.version < (3, 0, 0): - assert res == "raw" - else: - assert res == "embstr" - - res = await redis.incr("key") - assert res == 1 - res = await redis.object_encoding("key") - assert res == "int" - res = await redis.object_encoding("non-existent-key") - assert res is None - - with pytest.raises(TypeError): - await redis.object_encoding(None) - - -@redis_version(3, 0, 0, reason="Older Redis version has lower idle time resolution") -@pytest.mark.timeout(20) -@pytest.mark.asyncio -async def test_object_idletime(redis, server): - await add(redis, "foo", "bar") - - res = await redis.object_idletime("foo") - # NOTE: sometimes travis-ci is too slow - assert res >= 0 - - res = 0 - while not res: - res = await redis.object_idletime("foo") - await asyncio.sleep(0.5) - assert res >= 1 - - res = await redis.object_idletime("non-existent-key") - assert res is None - - with pytest.raises(TypeError): - await redis.object_idletime(None) - - -@pytest.mark.asyncio -async def test_persist(redis): - await add(redis, "my-key", 123) - res = await redis.expire("my-key", 10) - assert res is True - - res = await redis.persist("my-key") - assert res is True - - res = await redis.connection.execute("TTL", "my-key") - assert res == -1 - - with pytest.raises(TypeError): - await redis.persist(None) - - -@pytest.mark.asyncio -async def test_pexpire(redis): - await add(redis, "my-key", 123) - res = await redis.pexpire("my-key", 100) - assert res is True - - res = await redis.connection.execute("TTL", "my-key") - assert res == 0 - res = await redis.connection.execute("PTTL", "my-key") - assert res > 0 - - await add(redis, "my-key", 123) - res = await redis.pexpire("my-key", 1) - assert res is True - - # XXX: tests now looks strange to me. 
- await asyncio.sleep(0.2) - - res = await redis.exists("my-key") - assert not res - - with pytest.raises(TypeError): - await redis.pexpire(None, 0) - with pytest.raises(TypeError): - await redis.pexpire("my-key", 1.0) - - -@pytest.mark.asyncio -async def test_pexpireat(redis): - await add(redis, "my-key", 123) - now = int((await redis.time()) * 1000) - fut1 = redis.pexpireat("my-key", now + 2000) - fut2 = redis.ttl("my-key") - fut3 = redis.pttl("my-key") - assert await fut1 is True - assert await fut2 == 2 - assert 1000 < await fut3 <= 2000 - - with pytest.raises(TypeError): - await redis.pexpireat(None, 1234) - with pytest.raises(TypeError): - await redis.pexpireat("key", "timestamp") - with pytest.raises(TypeError): - await redis.pexpireat("key", 1000.0) - - -@pytest.mark.asyncio -async def test_pttl(redis, server): - await add(redis, "key", "val") - res = await redis.pttl("key") - assert res == -1 - res = await redis.pttl("non-existent-key") - if server.version < (2, 8, 0): - assert res == -1 - else: - assert res == -2 - - await redis.pexpire("key", 500) - res = await redis.pttl("key") - assert 400 < res <= 500 - - with pytest.raises(TypeError): - await redis.pttl(None) - - -@pytest.mark.asyncio -async def test_randomkey(redis): - await add(redis, "key:1", 123) - await add(redis, "key:2", 123) - await add(redis, "key:3", 123) - - res = await redis.randomkey() - assert res in [b"key:1", b"key:2", b"key:3"] - - # test with encoding param - res = await redis.randomkey(encoding="utf-8") - assert res in ["key:1", "key:2", "key:3"] - - await redis.connection.execute("flushdb") - res = await redis.randomkey() - assert res is None - - -@pytest.mark.asyncio -async def test_rename(redis, server): - await add(redis, "foo", "bar") - await redis.delete("bar") - - res = await redis.rename("foo", "bar") - assert res is True - - with pytest.raises(ReplyError, match="ERR no such key"): - await redis.rename("foo", "bar") - with pytest.raises(TypeError): - await redis.rename(None, "bar") - with pytest.raises(TypeError): - await redis.rename("foo", None) - with pytest.raises(ValueError): - await redis.rename("foo", "foo") - - if server.version < (3, 2): - with pytest.raises(ReplyError, match=".* objects are the same"): - await redis.rename("bar", b"bar") - - -@pytest.mark.asyncio -async def test_renamenx(redis, server): - await redis.delete("foo", "bar") - await add(redis, "foo", 123) - - res = await redis.renamenx("foo", "bar") - assert res is True - - await add(redis, "foo", 123) - res = await redis.renamenx("foo", "bar") - assert res is False - - with pytest.raises(ReplyError, match="ERR no such key"): - await redis.renamenx("baz", "foo") - with pytest.raises(TypeError): - await redis.renamenx(None, "foo") - with pytest.raises(TypeError): - await redis.renamenx("foo", None) - with pytest.raises(ValueError): - await redis.renamenx("foo", "foo") - - if server.version < (3, 2): - with pytest.raises(ReplyError, match=".* objects are the same"): - await redis.renamenx("foo", b"foo") - - -@pytest.mark.asyncio -async def test_restore(redis): - ok = await redis.set("key", "value") - assert ok - dump = await redis.dump("key") - assert dump is not None - ok = await redis.delete("key") - assert ok - assert b"OK" == (await redis.restore("key", 0, dump)) - assert (await redis.get("key")) == b"value" - - -@redis_version(2, 8, 0, reason="SCAN is available since redis>=2.8.0") -@pytest.mark.asyncio -async def test_scan(redis): - for i in range(1, 11): - foo_or_bar = "bar" if i % 3 else "foo" - key = 
f"key:scan:{foo_or_bar}:{i}".encode("utf-8") - await add(redis, key, i) - - cursor, values = await redis.scan() - # values should be *>=* just in case some other tests left - # test keys - assert len(values) >= 10 - - cursor, test_values = b"0", [] - while cursor: - cursor, values = await redis.scan(cursor=cursor, match=b"key:scan:foo*") - test_values.extend(values) - assert len(test_values) == 3 - - cursor, test_values = b"0", [] - while cursor: - cursor, values = await redis.scan(cursor=cursor, match=b"key:scan:bar:*") - test_values.extend(values) - assert len(test_values) == 7 - # SCAN family functions do not guarantee that the number of - # elements returned per call are in a given range. So here - # just dummy test, that *count* argument does not break something - cursor = b"0" - test_values = [] - while cursor: - cursor, values = await redis.scan(cursor=cursor, match=b"key:scan:*", count=2) - - test_values.extend(values) - assert len(test_values) == 10 - - -async def zadd(redis, key, value): - ok = await redis.connection.execute("zadd", key, value, value) - assert ok == 1 - - -@redis_version(6, 0, 0, reason="SCAN ... TYPE is available since redis>=6.0.0") -@pytest.mark.asyncio -async def test_scan_type(redis): - for i in range(1, 11): - foo_or_bar = "bar" if i % 3 else "foo" - key = f"key:scan:{foo_or_bar}:{i}".encode("utf-8") - print(key) - await zadd(redis, key, i) - - cursor, test_values = b"0", [] - while cursor: - cursor, values = await redis.scan( - cursor=cursor, match=b"key:scan:bar:*", key_type=b"zset" - ) - test_values.extend(values) - assert len(test_values) == 7 - - -@pytest.mark.asyncio -async def test_sort(redis): - async def _make_list(key, items): - await redis.delete(key) - for i in items: - await redis.rpush(key, i) - - await _make_list("a", "4231") - res = await redis.sort("a") - assert res == [b"1", b"2", b"3", b"4"] - - res = await redis.sort("a", offset=2, count=2) - assert res == [b"3", b"4"] - - res = await redis.sort("a", asc=b"DESC") - assert res == [b"4", b"3", b"2", b"1"] - - await _make_list("a", "dbca") - res = await redis.sort("a", asc=b"DESC", alpha=True, offset=2, count=2) - assert res == [b"b", b"a"] - - await redis.set("key:1", 10) - await redis.set("key:2", 4) - await redis.set("key:3", 7) - await _make_list("a", "321") - - res = await redis.sort("a", by="key:*") - assert res == [b"2", b"3", b"1"] - - res = await redis.sort("a", by="nosort") - assert res == [b"3", b"2", b"1"] - - res = await redis.sort("a", by="key:*", store="sorted_a") - assert res == 3 - res = await redis.lrange("sorted_a", 0, -1) - assert res == [b"2", b"3", b"1"] - - await redis.set("value:1", 20) - await redis.set("value:2", 30) - await redis.set("value:3", 40) - res = await redis.sort("a", "value:*", by="key:*") - assert res == [b"30", b"40", b"20"] - - await redis.hset("data_1", "weight", 30) - await redis.hset("data_2", "weight", 20) - await redis.hset("data_3", "weight", 10) - await redis.hset("hash_1", "field", 20) - await redis.hset("hash_2", "field", 30) - await redis.hset("hash_3", "field", 10) - res = await redis.sort("a", "hash_*->field", by="data_*->weight") - assert res == [b"10", b"30", b"20"] - - -@redis_version(3, 2, 1, reason="TOUCH is available since redis>=3.2.1") -@pytest.mark.timeout(20) -@pytest.mark.asyncio -async def test_touch(redis): - await add(redis, "key", "val") - res = 0 - while not res: - res = await redis.object_idletime("key") - await asyncio.sleep(0.5) - assert res > 0 - assert await redis.touch("key", "key", "key") == 3 - res2 = await 
redis.object_idletime("key") - assert 0 <= res2 < res - - -@pytest.mark.asyncio -async def test_ttl(redis, server): - await add(redis, "key", "val") - res = await redis.ttl("key") - assert res == -1 - res = await redis.ttl("non-existent-key") - if server.version < (2, 8, 0): - assert res == -1 - else: - assert res == -2 - - await redis.expire("key", 10) - res = await redis.ttl("key") - assert res >= 9 - - with pytest.raises(TypeError): - await redis.ttl(None) - - -@pytest.mark.asyncio -async def test_type(redis): - await add(redis, "key", "val") - res = await redis.type("key") - assert res == b"string" - - await redis.delete("key") - await redis.incr("key") - res = await redis.type("key") - assert res == b"string" - - await redis.delete("key") - await redis.sadd("key", "val") - res = await redis.type("key") - assert res == b"set" - - res = await redis.type("non-existent-key") - assert res == b"none" - - with pytest.raises(TypeError): - await redis.type(None) - - -@redis_version(2, 8, 0, reason="SCAN is available since redis>=2.8.0") -@pytest.mark.asyncio -async def test_iscan(redis): - full = set() - foo = set() - bar = set() - for i in range(1, 11): - is_bar = i % 3 - foo_or_bar = "bar" if is_bar else "foo" - key = f"key:scan:{foo_or_bar}:{i}".encode("utf-8") - full.add(key) - if is_bar: - bar.add(key) - else: - foo.add(key) - assert await redis.set(key, i) is True - - async def coro(cmd): - lst = [] - async for i in cmd: - lst.append(i) - return lst - - ret = await coro(redis.iscan()) - assert len(ret) >= 10 - - ret = await coro(redis.iscan(match="key:scan:*")) - assert 10 == len(ret) - assert set(ret) == full - - ret = await coro(redis.iscan(match="key:scan:foo*")) - assert set(ret) == foo - - ret = await coro(redis.iscan(match="key:scan:bar*")) - assert set(ret) == bar - - # SCAN family functions do not guarantee that the number of - # elements returned per call are in a given range. 
So here - # just dummy test, that *count* argument does not break something - - ret = await coro(redis.iscan(match="key:scan:*", count=2)) - assert 10 == len(ret) - assert set(ret) == full - - -@redis_version(4, 0, 0, reason="UNLINK is available since redis>=4.0.0") -@pytest.mark.asyncio -async def test_unlink(redis): - await add(redis, "my-key", 123) - await add(redis, "other-key", 123) - - res = await redis.unlink("my-key", "non-existent-key") - assert res == 1 - - res = await redis.unlink("other-key", "other-key") - assert res == 1 - - with pytest.raises(TypeError): - await redis.unlink(None) - - with pytest.raises(TypeError): - await redis.unlink("my-key", "my-key", None) - - -@redis_version(3, 0, 0, reason="WAIT is available since redis>=3.0.0") -@pytest.mark.asyncio -async def test_wait(redis): - await add(redis, "key", "val1") - start = await redis.time() - res = await redis.wait(1, 400) - end = await redis.time() - assert res == 0 - assert end - start >= 0.4 - - await add(redis, "key", "val2") - start = await redis.time() - res = await redis.wait(0, 400) - end = await redis.time() - assert res == 0 - assert end - start < 0.4 diff --git a/tests/geo_commands_test.py b/tests/geo_commands_test.py deleted file mode 100644 index 3ee8b4fa0..000000000 --- a/tests/geo_commands_test.py +++ /dev/null @@ -1,489 +0,0 @@ -import pytest - -from aioredis import GeoMember, GeoPoint -from tests.testutils import redis_version - - -@redis_version(3, 2, 0, reason="GEOADD is available since redis >= 3.2.0") -@pytest.mark.asyncio -async def test_geoadd(redis): - res = await redis.geoadd("geodata", 13.361389, 38.115556, "Palermo") - assert res == 1 - - res = await redis.geoadd( - "geodata", 15.087269, 37.502669, "Catania", 12.424315, 37.802105, "Marsala" - ) - assert res == 2 - - -@redis_version(3, 2, 0, reason="GEODIST is available since redis >= 3.2.0") -@pytest.mark.asyncio -async def test_geodist(redis): - res = await redis.geoadd( - "geodata", 13.361389, 38.115556, "Palermo", 15.087269, 37.502669, "Catania" - ) - assert res == 2 - - res = await redis.geodist("geodata", "Palermo", "Catania") - assert res == 166274.1516 - - res = await redis.geodist("geodata", "Palermo", "Catania", "km") - assert res == 166.2742 - - -@redis_version(3, 2, 0, reason="GEOHASH is available since redis >= 3.2.0") -@pytest.mark.asyncio -async def test_geohash(redis): - res = await redis.geoadd( - "geodata", 13.361389, 38.115556, "Palermo", 15.087269, 37.502669, "Catania" - ) - assert res == 2 - - res = await redis.geohash("geodata", "Palermo", encoding="utf-8") - assert res == ["sqc8b49rny0"] - - res = await redis.geohash("geodata", "Palermo", "Catania", encoding="utf-8") - assert res == ["sqc8b49rny0", "sqdtr74hyu0"] - - -@redis_version(3, 2, 0, reason="GEOPOS is available since redis >= 3.2.0") -@pytest.mark.asyncio -async def test_geopos(redis): - res = await redis.geoadd( - "geodata", 13.361389, 38.115556, "Palermo", 15.087269, 37.502669, "Catania" - ) - assert res == 2 - - res = await redis.geopos("geodata", "Palermo") - assert res == [ - GeoPoint(longitude=13.36138933897018433, latitude=38.11555639549629859) - ] - - res = await redis.geopos("geodata", "Catania", "Palermo") - assert res == [ - GeoPoint(longitude=15.087267458438873, latitude=37.50266842333162), - GeoPoint(longitude=13.36138933897018433, latitude=38.11555639549629859), - ] - - -@redis_version(3, 2, 0, reason="GEO* is available since redis >= 3.2.0") -@pytest.mark.asyncio -async def test_geo_not_exist_members(redis): - res = await redis.geoadd("geodata", 
13.361389, 38.115556, "Palermo") - assert res == 1 - - res = await redis.geoadd( - "geodata", 15.087269, 37.502669, "Catania", 12.424315, 37.802105, "Marsala" - ) - assert res == 2 - - res = await redis.geohash("geodata", "NotExistMember") - assert res == [None] - - res = await redis.geodist("geodata", "NotExistMember", "Catania") - assert res is None - - res = await redis.geopos("geodata", "Palermo", "NotExistMember", "Catania") - assert res == [ - GeoPoint(longitude=13.36138933897018433, latitude=38.11555639549629859), - None, - GeoPoint(longitude=15.087267458438873, latitude=37.50266842333162), - ] - - -@redis_version(3, 2, 0, reason="GEORADIUS is available since redis >= 3.2.0") -@pytest.mark.asyncio -async def test_georadius_validation(redis): - res = await redis.geoadd( - "geodata", 13.361389, 38.115556, "Palermo", 15.087269, 37.502669, "Catania" - ) - assert res == 2 - - with pytest.raises(TypeError): - res = await redis.georadius( - "geodata", 15, 37, 200, "km", count=1.3, encoding="utf-8" - ) - with pytest.raises(TypeError): - res = await redis.georadius("geodata", 15, 37, "200", "km", encoding="utf-8") - with pytest.raises(ValueError): - res = await redis.georadius("geodata", 15, 37, 200, "k", encoding="utf-8") - with pytest.raises(ValueError): - res = await redis.georadius( - "geodata", 15, 37, 200, "km", sort="DESV", encoding="utf-8" - ) - - -@redis_version(3, 2, 0, reason="GEORADIUS is available since redis >= 3.2.0") -@pytest.mark.asyncio -async def test_georadius(redis): - res = await redis.geoadd( - "geodata", 13.361389, 38.115556, "Palermo", 15.087269, 37.502669, "Catania" - ) - assert res == 2 - - res = await redis.georadius("geodata", 15, 37, 200, "km", encoding="utf-8") - assert res == ["Palermo", "Catania"] - - res = await redis.georadius("geodata", 15, 37, 200, "km", count=1, encoding="utf-8") - assert res == ["Catania"] - - res = await redis.georadius( - "geodata", 15, 37, 200, "km", sort="ASC", encoding="utf-8" - ) - assert res == ["Catania", "Palermo"] - - res = await redis.georadius( - "geodata", 15, 37, 200, "km", with_dist=True, encoding="utf-8" - ) - assert res == [ - GeoMember(member="Palermo", dist=190.4424, coord=None, hash=None), - GeoMember(member="Catania", dist=56.4413, coord=None, hash=None), - ] - - res = await redis.georadius( - "geodata", 15, 37, 200, "km", with_dist=True, with_coord=True, encoding="utf-8" - ) - assert res == [ - GeoMember( - member="Palermo", - dist=190.4424, - hash=None, - coord=GeoPoint( - longitude=13.36138933897018433, latitude=38.11555639549629859 - ), - ), - GeoMember( - member="Catania", - dist=56.4413, - hash=None, - coord=GeoPoint(longitude=15.087267458438873, latitude=37.50266842333162), - ), - ] - - res = await redis.georadius( - "geodata", - 15, - 37, - 200, - "km", - with_dist=True, - with_coord=True, - with_hash=True, - encoding="utf-8", - ) - assert res == [ - GeoMember( - member="Palermo", - dist=190.4424, - hash=3479099956230698, - coord=GeoPoint( - longitude=13.36138933897018433, latitude=38.11555639549629859 - ), - ), - GeoMember( - member="Catania", - dist=56.4413, - hash=3479447370796909, - coord=GeoPoint(longitude=15.087267458438873, latitude=37.50266842333162), - ), - ] - - res = await redis.georadius( - "geodata", 15, 37, 200, "km", with_coord=True, with_hash=True, encoding="utf-8" - ) - assert res == [ - GeoMember( - member="Palermo", - dist=None, - hash=3479099956230698, - coord=GeoPoint( - longitude=13.36138933897018433, latitude=38.11555639549629859 - ), - ), - GeoMember( - member="Catania", - dist=None, 
- hash=3479447370796909, - coord=GeoPoint(longitude=15.087267458438873, latitude=37.50266842333162), - ), - ] - - res = await redis.georadius( - "geodata", 15, 37, 200, "km", with_coord=True, encoding="utf-8" - ) - assert res == [ - GeoMember( - member="Palermo", - dist=None, - hash=None, - coord=GeoPoint( - longitude=13.36138933897018433, latitude=38.11555639549629859 - ), - ), - GeoMember( - member="Catania", - dist=None, - hash=None, - coord=GeoPoint(longitude=15.087267458438873, latitude=37.50266842333162), - ), - ] - - res = await redis.georadius( - "geodata", - 15, - 37, - 200, - "km", - count=1, - sort="DESC", - with_hash=True, - encoding="utf-8", - ) - assert res == [ - GeoMember(member="Palermo", dist=None, hash=3479099956230698, coord=None) - ] - - -@redis_version(3, 2, 0, reason="GEORADIUSBYMEMBER is available since redis >= 3.2.0") -@pytest.mark.asyncio -async def test_georadiusbymember(redis): - res = await redis.geoadd( - "geodata", 13.361389, 38.115556, "Palermo", 15.087269, 37.502669, "Catania" - ) - assert res == 2 - - res = await redis.georadiusbymember( - "geodata", "Palermo", 200, "km", with_dist=True, encoding="utf-8" - ) - assert res == [ - GeoMember(member="Palermo", dist=0.0, coord=None, hash=None), - GeoMember(member="Catania", dist=166.2742, coord=None, hash=None), - ] - res = await redis.georadiusbymember( - "geodata", "Palermo", 200, "km", encoding="utf-8" - ) - assert res == ["Palermo", "Catania"] - - res = await redis.georadiusbymember( - "geodata", - "Palermo", - 200, - "km", - with_dist=True, - with_coord=True, - encoding="utf-8", - ) - assert res == [ - GeoMember( - member="Palermo", - dist=0.0, - hash=None, - coord=GeoPoint(13.361389338970184, 38.1155563954963), - ), - GeoMember( - member="Catania", - dist=166.2742, - hash=None, - coord=GeoPoint(15.087267458438873, 37.50266842333162), - ), - ] - - res = await redis.georadiusbymember( - "geodata", - "Palermo", - 200, - "km", - with_dist=True, - with_coord=True, - with_hash=True, - encoding="utf-8", - ) - assert res == [ - GeoMember( - member="Palermo", - dist=0.0, - hash=3479099956230698, - coord=GeoPoint(13.361389338970184, 38.1155563954963), - ), - GeoMember( - member="Catania", - dist=166.2742, - hash=3479447370796909, - coord=GeoPoint(15.087267458438873, 37.50266842333162), - ), - ] - - -@redis_version(3, 2, 0, reason="GEOHASH is available since redis >= 3.2.0") -@pytest.mark.asyncio -async def test_geohash_binary(redis): - res = await redis.geoadd( - "geodata", 13.361389, 38.115556, "Palermo", 15.087269, 37.502669, "Catania" - ) - assert res == 2 - - res = await redis.geohash("geodata", "Palermo") - assert res == [b"sqc8b49rny0"] - - res = await redis.geohash("geodata", "Palermo", "Catania") - assert res == [b"sqc8b49rny0", b"sqdtr74hyu0"] - - -@redis_version(3, 2, 0, reason="GEORADIUS is available since redis >= 3.2.0") -@pytest.mark.asyncio -async def test_georadius_binary(redis): - res = await redis.geoadd( - "geodata", 13.361389, 38.115556, "Palermo", 15.087269, 37.502669, "Catania" - ) - assert res == 2 - - res = await redis.georadius("geodata", 15, 37, 200, "km") - assert res == [b"Palermo", b"Catania"] - - res = await redis.georadius("geodata", 15, 37, 200, "km", count=1) - assert res == [b"Catania"] - - res = await redis.georadius("geodata", 15, 37, 200, "km", sort="ASC") - assert res == [b"Catania", b"Palermo"] - - res = await redis.georadius("geodata", 15, 37, 200, "km", with_dist=True) - assert res == [ - GeoMember(member=b"Palermo", dist=190.4424, coord=None, hash=None), - 
GeoMember(member=b"Catania", dist=56.4413, coord=None, hash=None), - ] - - res = await redis.georadius( - "geodata", 15, 37, 200, "km", with_dist=True, with_coord=True - ) - assert res == [ - GeoMember( - member=b"Palermo", - dist=190.4424, - hash=None, - coord=GeoPoint( - longitude=13.36138933897018433, latitude=38.11555639549629859 - ), - ), - GeoMember( - member=b"Catania", - dist=56.4413, - hash=None, - coord=GeoPoint(longitude=15.087267458438873, latitude=37.50266842333162), - ), - ] - - res = await redis.georadius( - "geodata", 15, 37, 200, "km", with_dist=True, with_coord=True, with_hash=True - ) - assert res == [ - GeoMember( - member=b"Palermo", - dist=190.4424, - hash=3479099956230698, - coord=GeoPoint( - longitude=13.36138933897018433, latitude=38.11555639549629859 - ), - ), - GeoMember( - member=b"Catania", - dist=56.4413, - hash=3479447370796909, - coord=GeoPoint(longitude=15.087267458438873, latitude=37.50266842333162), - ), - ] - - res = await redis.georadius( - "geodata", 15, 37, 200, "km", with_coord=True, with_hash=True - ) - assert res == [ - GeoMember( - member=b"Palermo", - dist=None, - hash=3479099956230698, - coord=GeoPoint( - longitude=13.36138933897018433, latitude=38.11555639549629859 - ), - ), - GeoMember( - member=b"Catania", - dist=None, - hash=3479447370796909, - coord=GeoPoint(longitude=15.087267458438873, latitude=37.50266842333162), - ), - ] - - res = await redis.georadius("geodata", 15, 37, 200, "km", with_coord=True) - assert res == [ - GeoMember( - member=b"Palermo", - dist=None, - hash=None, - coord=GeoPoint( - longitude=13.36138933897018433, latitude=38.11555639549629859 - ), - ), - GeoMember( - member=b"Catania", - dist=None, - hash=None, - coord=GeoPoint(longitude=15.087267458438873, latitude=37.50266842333162), - ), - ] - - res = await redis.georadius( - "geodata", 15, 37, 200, "km", count=1, sort="DESC", with_hash=True - ) - assert res == [ - GeoMember(member=b"Palermo", dist=None, hash=3479099956230698, coord=None) - ] - - -@redis_version(3, 2, 0, reason="GEORADIUSBYMEMBER is available since redis >= 3.2.0") -@pytest.mark.asyncio -async def test_georadiusbymember_binary(redis): - res = await redis.geoadd( - "geodata", 13.361389, 38.115556, "Palermo", 15.087269, 37.502669, "Catania" - ) - assert res == 2 - - res = await redis.georadiusbymember("geodata", "Palermo", 200, "km", with_dist=True) - assert res == [ - GeoMember(member=b"Palermo", dist=0.0, coord=None, hash=None), - GeoMember(member=b"Catania", dist=166.2742, coord=None, hash=None), - ] - - res = await redis.georadiusbymember( - "geodata", "Palermo", 200, "km", with_dist=True, with_coord=True - ) - assert res == [ - GeoMember( - member=b"Palermo", - dist=0.0, - hash=None, - coord=GeoPoint(13.361389338970184, 38.1155563954963), - ), - GeoMember( - member=b"Catania", - dist=166.2742, - hash=None, - coord=GeoPoint(15.087267458438873, 37.50266842333162), - ), - ] - - res = await redis.georadiusbymember( - "geodata", "Palermo", 200, "km", with_dist=True, with_coord=True, with_hash=True - ) - assert res == [ - GeoMember( - member=b"Palermo", - dist=0.0, - hash=3479099956230698, - coord=GeoPoint(13.361389338970184, 38.1155563954963), - ), - GeoMember( - member=b"Catania", - dist=166.2742, - hash=3479447370796909, - coord=GeoPoint(15.087267458438873, 37.50266842333162), - ), - ] diff --git a/tests/hash_commands_test.py b/tests/hash_commands_test.py deleted file mode 100644 index 96ff550c6..000000000 --- a/tests/hash_commands_test.py +++ /dev/null @@ -1,535 +0,0 @@ -import pytest - -from aioredis 
import ReplyError -from tests.testutils import redis_version - - -@pytest.mark.asyncio -async def add(redis, key, field, value): - ok = await redis.connection.execute(b"hset", key, field, value) - assert ok == 1 - - -@pytest.mark.asyncio -async def test_hdel(redis): - key, field, value = b"key:hdel", b"bar", b"zap" - await add(redis, key, field, value) - # delete value that exists, expected 1 - result = await redis.hdel(key, field) - assert result == 1 - # delete value that does not exist, expected 0 - result = await redis.hdel(key, field) - assert result == 0 - - with pytest.raises(TypeError): - await redis.hdel(None, field) - - -@pytest.mark.asyncio -async def test_hexists(redis): - key, field, value = b"key:hexists", b"bar", b"zap" - await add(redis, key, field, value) - # check value that exists, expected 1 - result = await redis.hexists(key, field) - assert result == 1 - # check value when key exists and field does not, expected 0 - result = await redis.hexists(key, b"not:" + field) - assert result == 0 - # check value when key does not exist, expected 0 - result = await redis.hexists(b"not:" + key, field) - assert result == 0 - - with pytest.raises(TypeError): - await redis.hexists(None, field) - - -@pytest.mark.asyncio -async def test_hget(redis): - - key, field, value = b"key:hget", b"bar", b"zap" - await add(redis, key, field, value) - # basic test: fetch value and compare it to the reference - test_value = await redis.hget(key, field) - assert test_value == value - # fetch value when field does not exist - test_value = await redis.hget(key, b"not" + field) - assert test_value is None - # fetch value when key does not exist - test_value = await redis.hget(b"not:" + key, b"baz") - assert test_value is None - - # check encoding - test_value = await redis.hget(key, field, encoding="utf-8") - assert test_value == "zap" - - with pytest.raises(TypeError): - await redis.hget(None, field) - - -@pytest.mark.asyncio -async def test_hgetall(redis): - await add(redis, "key:hgetall", "foo", "baz") - await add(redis, "key:hgetall", "bar", "zap") - - test_value = await redis.hgetall("key:hgetall") - assert isinstance(test_value, dict) - assert {b"foo": b"baz", b"bar": b"zap"} == test_value - # try to get all values from a key that does not exist - test_value = await redis.hgetall(b"not:key:hgetall") - assert test_value == {} - - # check encoding param - test_value = await redis.hgetall("key:hgetall", encoding="utf-8") - assert {"foo": "baz", "bar": "zap"} == test_value - - with pytest.raises(TypeError): - await redis.hgetall(None) - - -@pytest.mark.asyncio -async def test_hincrby(redis): - key, field, value = b"key:hincrby", b"bar", 1 - await add(redis, key, field, value) - # increment initial value by 2 - result = await redis.hincrby(key, field, 2) - assert result == 3 - - result = await redis.hincrby(key, field, -1) - assert result == 2 - - result = await redis.hincrby(key, field, -100) - assert result == -98 - - result = await redis.hincrby(key, field, -2) - assert result == -100 - - # increment value when the key or field does not exist - result = await redis.hincrby(b"not:" + key, field, 2) - assert result == 2 - result = await redis.hincrby(key, b"not:" + field, 2) - assert result == 2 - - with pytest.raises(ReplyError): - await redis.hincrby(key, b"not:" + field, 3.14) - - with pytest.raises(ReplyError): - # initial value is float, try to increment 1 - await add(redis, b"other:" + key, field, 3.14) - await redis.hincrby(b"other:" + key, field, 1) - - with pytest.raises(TypeError): - 
await redis.hincrby(None, field, 2) - - -@pytest.mark.asyncio -async def test_hincrbyfloat(redis): - key, field, value = b"key:hincrbyfloat", b"bar", 2.71 - await add(redis, key, field, value) - - result = await redis.hincrbyfloat(key, field, 3.14) - assert result == 5.85 - - result = await redis.hincrbyfloat(key, field, -2.71) - assert result == 3.14 - - result = await redis.hincrbyfloat(key, field, -100.1) - assert result == -96.96 - - # increment value when the key or field does not exist - result = await redis.hincrbyfloat(b"not:" + key, field, 3.14) - assert result == 3.14 - - result = await redis.hincrbyfloat(key, b"not:" + field, 3.14) - assert result == 3.14 - - with pytest.raises(TypeError): - await redis.hincrbyfloat(None, field, 2) - - -@pytest.mark.asyncio -async def test_hkeys(redis): - key = b"key:hkeys" - field1, field2 = b"foo", b"bar" - value1, value2 = b"baz", b"zap" - await add(redis, key, field1, value1) - await add(redis, key, field2, value2) - - test_value = await redis.hkeys(key) - assert set(test_value) == {field1, field2} - - test_value = await redis.hkeys(b"not:" + key) - assert test_value == [] - - test_value = await redis.hkeys(key, encoding="utf-8") - assert set(test_value) == {"foo", "bar"} - - with pytest.raises(TypeError): - await redis.hkeys(None) - - -@pytest.mark.asyncio -async def test_hlen(redis): - key = b"key:hlen" - field1, field2 = b"foo", b"bar" - value1, value2 = b"baz", b"zap" - await add(redis, key, field1, value1) - await add(redis, key, field2, value2) - - test_value = await redis.hlen(key) - assert test_value == 2 - - test_value = await redis.hlen(b"not:" + key) - assert test_value == 0 - - with pytest.raises(TypeError): - await redis.hlen(None) - - -@pytest.mark.asyncio -async def test_hmget(redis): - key = b"key:hmget" - field1, field2 = b"foo", b"bar" - value1, value2 = b"baz", b"zap" - await add(redis, key, field1, value1) - await add(redis, key, field2, value2) - - test_value = await redis.hmget(key, field1, field2) - assert set(test_value) == {value1, value2} - - test_value = await redis.hmget(key, b"not:" + field1, b"not:" + field2) - assert [None, None] == test_value - - val = await redis.hincrby(key, "numeric") - assert val == 1 - test_value = await redis.hmget(key, field1, field2, "numeric", encoding="utf-8") - assert ["baz", "zap", "1"] == test_value - - with pytest.raises(TypeError): - await redis.hmget(None, field1, field2) - - -@pytest.mark.asyncio -async def test_hmset(redis): - key, field, value = b"key:hmset", b"bar", b"zap" - await add(redis, key, field, value) - - # key and field exist - test_value = await redis.hmset(key, field, b"baz") - assert test_value is True - - result = await redis.hexists(key, field) - assert result == 1 - - # key and field do not exist - test_value = await redis.hmset(b"not:" + key, field, value) - assert test_value is True - result = await redis.hexists(b"not:" + key, field) - assert result == 1 - - # set multiple - pairs = [b"foo", b"baz", b"bar", b"paz"] - test_value = await redis.hmset(key, *pairs) - assert test_value is True - test_value = await redis.hmget(key, b"foo", b"bar") - assert set(test_value) == {b"baz", b"paz"} - - with pytest.raises(TypeError): - await redis.hmset(key, b"foo", b"bar", b"baz") - - with pytest.raises(TypeError): - await redis.hmset(None, *pairs) - - with pytest.raises(TypeError): - await redis.hmset(key, {"foo": "bar"}, {"baz": "bad"}) - - with pytest.raises(TypeError): - await redis.hmset(key) - - -@pytest.mark.asyncio -async def test_hmset_dict(redis): 
- key = "key:hmset" - - # dict - d1 = {b"foo": b"one dict"} - test_value = await redis.hmset_dict(key, d1) - assert test_value is True - test_value = await redis.hget(key, b"foo") - assert test_value == b"one dict" - - # kwdict - test_value = await redis.hmset_dict(key, foo=b"kw1", bar=b"kw2") - assert test_value is True - test_value = await redis.hmget(key, b"foo", b"bar") - assert set(test_value) == {b"kw1", b"kw2"} - - # dict & kwdict - d1 = {b"foo": b"dict"} - test_value = await redis.hmset_dict(key, d1, foo=b"kw") - assert test_value is True - test_value = await redis.hget(key, b"foo") - assert test_value == b"kw" - - # allow empty dict with kwargs - test_value = await redis.hmset_dict(key, {}, foo="kw") - assert test_value is True - test_value = await redis.hget(key, "foo") - assert test_value == b"kw" - - with pytest.raises(TypeError): - await redis.hmset_dict(key) - - with pytest.raises(ValueError): - await redis.hmset_dict(key, {}) - - with pytest.raises(TypeError): - await redis.hmset_dict(key, ("foo", "pairs")) - - with pytest.raises(TypeError): - await redis.hmset_dict(key, b"foo", "pairs") - - with pytest.raises(TypeError): - await redis.hmset_dict(key, b"foo", "pairs", foo=b"kw1") - - with pytest.raises(TypeError): - await redis.hmset_dict(key, {"a": 1}, {"b": 2}) - - with pytest.raises(TypeError): - await redis.hmset_dict(key, {"a": 1}, {"b": 2}, "c", 3, d=4) - - -@pytest.mark.asyncio -async def test_hset(redis): - key, field, value = b"key:hset", b"bar", b"zap" - test_value = await redis.hset(key, field, value) - assert test_value == 1 - - test_value = await redis.hset(key, field, value) - assert test_value == 0 - - test_value = await redis.hset(b"other:" + key, field, value) - assert test_value == 1 - - result = await redis.hexists(b"other:" + key, field) - assert result == 1 - - with pytest.raises(TypeError): - await redis.hset(None, field, value) - - -@redis_version(4, 0, 0, reason="HSET changed to variadic in redis 4.0.0") -@pytest.mark.asyncio -async def test_hset_multiple(redis): - test_key = b"key:hset_multiple" - - # raises ValueError if field/value pair and mapping both are missing - with pytest.raises(ValueError): - await redis.hset(test_key) - - # raises TypeError if mapping is not a dict - with pytest.raises(TypeError): - await redis.hset(test_key, mapping="not a dict") - - # test multiple fields and string values - mapping = {b"a": b"1", b"b": b"2", b"c": b"test1", b"d": b"test string"} - # insert mapping through hset - test_value = await redis.hset(test_key, mapping=mapping) - # check if 4 values were added - assert test_value == 4 - - for field, value in mapping.items(): - # check if each field exists in the hash - result = await redis.hexists(test_key, field) - assert result == 1 - - # check value of each field matches the correct value - test_value = await redis.hget(test_key, field) - assert test_value == value - - # test input of both a field/value pair and mapping at the same time - test_value = await redis.hset(test_key, "e", "5", mapping={"f": "6"}) - assert test_value == 2 - value1 = await redis.hget(test_key, "e") - value2 = await redis.hget(test_key, "f") - assert value1 == b"5" - assert value2 == b"6" - - -@pytest.mark.asyncio -async def test_hsetnx(redis): - key, field, value = b"key:hsetnx", b"bar", b"zap" - # field does not exist, operation should be successful - test_value = await redis.hsetnx(key, field, value) - assert test_value == 1 - # make sure that value was stored - result = await redis.hget(key, field) - assert result == value - # field 
exists, operation should not change any value - test_value = await redis.hsetnx(key, field, b"baz") - assert test_value == 0 - # make sure value was not changed - result = await redis.hget(key, field) - assert result == value - - with pytest.raises(TypeError): - await redis.hsetnx(None, field, value) - - -@pytest.mark.asyncio -async def test_hvals(redis): - key = b"key:hvals" - field1, field2 = b"foo", b"bar" - value1, value2 = b"baz", b"zap" - await add(redis, key, field1, value1) - await add(redis, key, field2, value2) - - test_value = await redis.hvals(key) - assert set(test_value) == {value1, value2} - - test_value = await redis.hvals(b"not:" + key) - assert test_value == [] - - test_value = await redis.hvals(key, encoding="utf-8") - assert set(test_value) == {"baz", "zap"} - with pytest.raises(TypeError): - await redis.hvals(None) - - -@redis_version(2, 8, 0, reason="HSCAN is available since redis>=2.8.0") -@pytest.mark.asyncio -async def test_hscan(redis): - key = b"key:hscan" - # setup initial values 3 "field:foo:*" items and 7 "field:bar:*" items - for i in range(1, 11): - foo_or_bar = "bar" if i % 3 else "foo" - f = f"field:{foo_or_bar}:{i}".encode("utf-8") - v = f"value:{i}".encode("utf-8") - await add(redis, key, f, v) - # fetch 'field:foo:*' items expected tuple with 3 fields and 3 values - cursor, values = await redis.hscan(key, match=b"field:foo:*") - assert len(values) == 3 - assert sorted(values) == [ - (b"field:foo:3", b"value:3"), - (b"field:foo:6", b"value:6"), - (b"field:foo:9", b"value:9"), - ] - # fetch 'field:bar:*' items expected tuple with 7 fields and 7 values - cursor, values = await redis.hscan(key, match=b"field:bar:*") - assert len(values) == 7 - assert sorted(values) == [ - (b"field:bar:1", b"value:1"), - (b"field:bar:10", b"value:10"), - (b"field:bar:2", b"value:2"), - (b"field:bar:4", b"value:4"), - (b"field:bar:5", b"value:5"), - (b"field:bar:7", b"value:7"), - (b"field:bar:8", b"value:8"), - ] - - # SCAN family functions do not guarantee that the number of - # elements returned per call are in a given range. 
So here - # just dummy test, that *count* argument does not break something - cursor = b"0" - test_values = [] - while cursor: - cursor, values = await redis.hscan(key, cursor, count=1) - test_values.extend(values) - assert len(test_values) == 10 - - with pytest.raises(TypeError): - await redis.hscan(None) - - -@pytest.mark.asyncio -async def test_hgetall_enc(create_redis, server): - redis = await create_redis(server.tcp_address, encoding="utf-8") - TEST_KEY = "my-key-nx" - await redis.hmset(TEST_KEY, "foo", "bar", "baz", "bad") - - tr = redis.multi_exec() - tr.hgetall(TEST_KEY) - res = await tr.execute() - assert res == [{"foo": "bar", "baz": "bad"}] - - -@redis_version(3, 2, 0, reason="HSTRLEN new in redis 3.2.0") -@pytest.mark.asyncio -async def test_hstrlen(redis): - ok = await redis.hset("myhash", "str_field", "some value") - assert ok == 1 - ok = await redis.hincrby("myhash", "uint_field", 1) - assert ok == 1 - - ok = await redis.hincrby("myhash", "int_field", -1) - assert ok == -1 - - l = await redis.hstrlen("myhash", "str_field") - assert l == 10 - l = await redis.hstrlen("myhash", "uint_field") - assert l == 1 - l = await redis.hstrlen("myhash", "int_field") - assert l == 2 - - l = await redis.hstrlen("myhash", "none_field") - assert l == 0 - - l = await redis.hstrlen("none_key", "none_field") - assert l == 0 - - -@redis_version(2, 8, 0, reason="HSCAN is available since redis>=2.8.0") -@pytest.mark.asyncio -async def test_ihscan(redis): - key = b"key:hscan" - # setup initial values 3 "field:foo:*" items and 7 "field:bar:*" items - for i in range(1, 11): - foo_or_bar = "bar" if i % 3 else "foo" - f = f"field:{foo_or_bar}:{i}".encode("utf-8") - v = f"value:{i}".encode("utf-8") - assert await redis.hset(key, f, v) == 1 - - async def coro(cmd): - lst = [] - async for i in cmd: - lst.append(i) - return lst - - # fetch 'field:foo:*' items expected tuple with 3 fields and 3 values - ret = await coro(redis.ihscan(key, match=b"field:foo:*")) - assert set(ret) == { - (b"field:foo:3", b"value:3"), - (b"field:foo:6", b"value:6"), - (b"field:foo:9", b"value:9"), - } - - # fetch 'field:bar:*' items expected tuple with 7 fields and 7 values - ret = await coro(redis.ihscan(key, match=b"field:bar:*")) - assert set(ret) == { - (b"field:bar:1", b"value:1"), - (b"field:bar:2", b"value:2"), - (b"field:bar:4", b"value:4"), - (b"field:bar:5", b"value:5"), - (b"field:bar:7", b"value:7"), - (b"field:bar:8", b"value:8"), - (b"field:bar:10", b"value:10"), - } - - # SCAN family functions do not guarantee that the number of - # elements returned per call are in a given range. 
So here - # just dummy test, that *count* argument does not break something - ret = await coro(redis.ihscan(key, count=1)) - assert set(ret) == { - (b"field:foo:3", b"value:3"), - (b"field:foo:6", b"value:6"), - (b"field:foo:9", b"value:9"), - (b"field:bar:1", b"value:1"), - (b"field:bar:2", b"value:2"), - (b"field:bar:4", b"value:4"), - (b"field:bar:5", b"value:5"), - (b"field:bar:7", b"value:7"), - (b"field:bar:8", b"value:8"), - (b"field:bar:10", b"value:10"), - } - - with pytest.raises(TypeError): - await redis.ihscan(None) diff --git a/tests/hyperloglog_commands_test.py b/tests/hyperloglog_commands_test.py deleted file mode 100644 index 78c01629f..000000000 --- a/tests/hyperloglog_commands_test.py +++ /dev/null @@ -1,102 +0,0 @@ -import pytest - -from tests.testutils import redis_version - -pytestmark = redis_version(2, 8, 9, reason="HyperLogLog works only with redis>=2.8.9") - - -@pytest.mark.asyncio -async def test_pfcount(redis): - key = "hll_pfcount" - other_key = "some-other-hll" - - # add initial data, cardinality changed so command returns 1 - is_changed = await redis.pfadd(key, "foo", "bar", "zap") - assert is_changed == 1 - - # add more data, cardinality not changed so command returns 0 - is_changed = await redis.pfadd(key, "zap", "zap", "zap") - assert is_changed == 0 - - # add even more data, cardinality not changed so command returns 0 - is_changed = await redis.pfadd(key, "foo", "bar") - assert is_changed == 0 - - # check cardinality of one key - cardinality = await redis.pfcount(key) - assert cardinality == 3 - - # create new key (variable) for cardinality estimation - is_changed = await redis.pfadd(other_key, 1, 2, 3) - assert is_changed == 1 - - # check cardinality of multiple keys - cardinality = await redis.pfcount(key, other_key) - assert cardinality == 6 - - with pytest.raises(TypeError): - await redis.pfcount(None) - with pytest.raises(TypeError): - await redis.pfcount(key, None) - with pytest.raises(TypeError): - await redis.pfcount(key, key, None) - - -@pytest.mark.asyncio -async def test_pfadd(redis): - key = "hll_pfadd" - values = ["a", "s", "y", "n", "c", "i", "o"] - # add initial data, cardinality changed so command returns 1 - is_changed = await redis.pfadd(key, *values) - assert is_changed == 1 - # add even more data, cardinality not changed so command returns 0 - is_changed = await redis.pfadd(key, "i", "o") - assert is_changed == 0 - - -@pytest.mark.asyncio -async def test_pfadd_wrong_input(redis): - with pytest.raises(TypeError): - await redis.pfadd(None, "value") - - -@pytest.mark.asyncio -async def test_pfmerge(redis): - key = "hll_asyncio" - key_other = "hll_aioredis" - - key_dest = "hll_aio" - - values = ["a", "s", "y", "n", "c", "i", "o"] - values_other = ["a", "i", "o", "r", "e", "d", "i", "s"] - - data_set = set(values + values_other) - cardinality_merged = len(data_set) - - # add initial data, cardinality changed so command returns 1 - await redis.pfadd(key, *values) - await redis.pfadd(key_other, *values_other) - - # check cardinality of one key - cardinality = await redis.pfcount(key) - assert cardinality == len(set(values)) - - cardinality_other = await redis.pfcount(key_other) - assert cardinality_other == len(set(values_other)) - - await redis.pfmerge(key_dest, key, key_other) - cardinality_dest = await redis.pfcount(key_dest) - assert cardinality_dest == cardinality_merged - - with pytest.raises(TypeError): - await redis.pfmerge(None, key) - with pytest.raises(TypeError): - await redis.pfmerge(key_dest, None) - with 
pytest.raises(TypeError): - await redis.pfmerge(key_dest, key, None) - - -@pytest.mark.asyncio -async def test_pfmerge_wrong_input(redis): - with pytest.raises(TypeError): - await redis.pfmerge(None, "value") diff --git a/tests/integration_test.py b/tests/integration_test.py deleted file mode 100644 index 5772f007a..000000000 --- a/tests/integration_test.py +++ /dev/null @@ -1,103 +0,0 @@ -import asyncio - -import pytest - -import aioredis - - -@pytest.fixture -def pool_or_redis(_closable, server): - version = tuple(map(int, aioredis.__version__.split(".")[:2])) - if version >= (1, 0): - factory = aioredis.create_redis_pool - else: - factory = aioredis.create_pool - - async def redis_factory(maxsize): - redis = await factory(server.tcp_address, minsize=1, maxsize=maxsize) - _closable(redis) - return redis - - return redis_factory - - -async def simple_get_set(pool, idx): - """A simple test to make sure Redis(pool) can be used as old Pool(Redis).""" - val = f"val:{idx}" - with await pool as redis: - assert await redis.set("key", val) - await redis.get("key", encoding="utf-8") - - -async def pipeline(pool, val): - val = f"val:{val}" - with await pool as redis: - f1 = redis.set("key", val) - f2 = redis.get("key", encoding="utf-8") - ok, res = await asyncio.gather(f1, f2) - - -async def transaction(pool, val): - val = f"val:{val}" - with await pool as redis: - tr = redis.multi_exec() - tr.set("key", val) - tr.get("key", encoding="utf-8") - ok, res = await tr.execute() - assert ok, ok - assert res == val - - -async def blocking_pop(pool, val): - async def lpush(): - with await pool as redis: - # here v0.3 has bound connection, v1.0 does not; - await asyncio.sleep(0.1) - await redis.lpush("list-key", "val") - - async def blpop(): - with await pool as redis: - # here v0.3 has bound connection, v1.0 does not; - res = await redis.blpop("list-key", timeout=2, encoding="utf-8") - assert res == ["list-key", "val"], res - - await asyncio.gather(blpop(), lpush()) - - -@pytest.mark.parametrize( - "test_case,pool_size", - [ - (simple_get_set, 1), - (pipeline, 1), - (transaction, 1), - pytest.param( - blocking_pop, - 1, - marks=pytest.mark.xfail(reason="blpop gets connection first and blocks"), - ), - (simple_get_set, 10), - (pipeline, 10), - (transaction, 10), - (blocking_pop, 10), - ], - ids=lambda o: getattr(o, "__name__", repr(o)), -) -@pytest.mark.asyncio -async def test_operations(pool_or_redis, test_case, pool_size): - repeat = 100 - redis = await pool_or_redis(pool_size) - done, pending = await asyncio.wait( - [asyncio.ensure_future(test_case(redis, i)) for i in range(repeat)] - ) - - assert not pending - success = 0 - failures = [] - for fut in done: - exc = fut.exception() - if exc is None: - success += 1 - else: - failures.append(exc) - assert repeat == success, failures - assert not failures diff --git a/tests/list_commands_test.py b/tests/list_commands_test.py deleted file mode 100644 index ca8aa0f8c..000000000 --- a/tests/list_commands_test.py +++ /dev/null @@ -1,545 +0,0 @@ -import asyncio - -import pytest - -from aioredis import ReplyError - - -async def push_data_with_sleep(redis, key, *values): - await asyncio.sleep(0.2) - result = await redis.lpush(key, *values) - return result - - -@pytest.mark.asyncio -async def test_blpop(redis): - key1, value1 = b"key:blpop:1", b"blpop:value:1" - key2, value2 = b"key:blpop:2", b"blpop:value:2" - - # setup list - result = await redis.rpush(key1, value1, value2) - assert result == 2 - # make sure that left value poped - test_value = await 
redis.blpop(key1) - assert test_value == [key1, value1] - # pop remaining value, so list should become empty - test_value = await redis.blpop(key1) - assert test_value == [key1, value2] - - with pytest.raises(TypeError): - await redis.blpop(None) - with pytest.raises(TypeError): - await redis.blpop(key1, None) - with pytest.raises(TypeError): - await redis.blpop(key1, timeout=b"one") - with pytest.raises(ValueError): - await redis.blpop(key2, timeout=-10) - - # test encoding param - await redis.rpush(key2, value1) - test_value = await redis.blpop(key2, encoding="utf-8") - assert test_value == ["key:blpop:2", "blpop:value:1"] - - -@pytest.mark.asyncio -async def test_blpop_blocking_features(redis, create_redis, server): - key1, key2 = b"key:blpop:1", b"key:blpop:2" - value = b"blpop:value:2" - - other_redis = await create_redis(server.tcp_address) - - # create blocking task in separate connection - consumer = other_redis.blpop(key1, key2) - - producer_task = asyncio.ensure_future(push_data_with_sleep(redis, key2, value)) - results = await asyncio.gather(consumer, producer_task) - - assert results[0] == [key2, value] - assert results[1] == 1 - - # wait for data with timeout, list is empty, so blpop should - # return None in 1 sec - waiter = redis.blpop(key1, key2, timeout=1) - test_value = await waiter - assert test_value is None - other_redis.close() - - -@pytest.mark.asyncio -async def test_brpop(redis): - key1, value1 = b"key:brpop:1", b"brpop:value:1" - key2, value2 = b"key:brpop:2", b"brpop:value:2" - - # setup list - result = await redis.rpush(key1, value1, value2) - assert result == 2 - # make sure that the right value was popped - test_value = await redis.brpop(key1) - assert test_value == [key1, value2] - # pop remaining value, so list should become empty - test_value = await redis.brpop(key1) - assert test_value == [key1, value1] - - with pytest.raises(TypeError): - await redis.brpop(None) - with pytest.raises(TypeError): - await redis.brpop(key1, None) - with pytest.raises(TypeError): - await redis.brpop(key1, timeout=b"one") - with pytest.raises(ValueError): - await redis.brpop(key2, timeout=-10) - - # test encoding param - await redis.rpush(key2, value1) - test_value = await redis.brpop(key2, encoding="utf-8") - assert test_value == ["key:brpop:2", "brpop:value:1"] - - -@pytest.mark.asyncio -async def test_brpop_blocking_features(redis, create_redis, server): - key1, key2 = b"key:brpop:1", b"key:brpop:2" - value = b"brpop:value:2" - - other_redis = await create_redis(server.tcp_address) - # create blocking task in separate connection - consumer_task = other_redis.brpop(key1, key2) - - producer_task = asyncio.ensure_future(push_data_with_sleep(redis, key2, value)) - - results = await asyncio.gather(consumer_task, producer_task) - - assert results[0] == [key2, value] - assert results[1] == 1 - - # wait for data with timeout, list is empty, so brpop should - # return None in 1 sec - waiter = redis.brpop(key1, key2, timeout=1) - test_value = await waiter - assert test_value is None - - -@pytest.mark.asyncio -async def test_brpoplpush(redis): - key = b"key:brpoplpush:1" - value1, value2 = b"brpoplpush:value:1", b"brpoplpush:value:2" - - destkey = b"destkey:brpoplpush:1" - - # setup list - await redis.rpush(key, value1, value2) - - # move the value into the head of the new list - result = await redis.brpoplpush(key, destkey) - assert result == value2 - # move last value - result = await redis.brpoplpush(key, destkey) - assert result == value1 - - # make sure that all values are stored in the new destkey list - 
test_value = await redis.lrange(destkey, 0, -1) - assert test_value == [value1, value2] - - with pytest.raises(TypeError): - await redis.brpoplpush(None, destkey) - - with pytest.raises(TypeError): - await redis.brpoplpush(key, None) - - with pytest.raises(TypeError): - await redis.brpoplpush(key, destkey, timeout=b"one") - - with pytest.raises(ValueError): - await redis.brpoplpush(key, destkey, timeout=-10) - - # test encoding param - result = await redis.brpoplpush(destkey, key, encoding="utf-8") - assert result == "brpoplpush:value:2" - - -@pytest.mark.asyncio -async def test_brpoplpush_blocking_features(redis, create_redis, server): - source = b"key:brpoplpush:12" - value = b"brpoplpush:value:2" - destkey = b"destkey:brpoplpush:2" - other_redis = await create_redis(server.tcp_address) - # create blocking task - consumer_task = other_redis.brpoplpush(source, destkey) - producer_task = asyncio.ensure_future(push_data_with_sleep(redis, source, value)) - results = await asyncio.gather(consumer_task, producer_task) - assert results[0] == value - assert results[1] == 1 - - # make sure that all values are stored in the new destkey list - test_value = await redis.lrange(destkey, 0, -1) - assert test_value == [value] - - # wait for data with timeout, list is empty, so brpoplpush should - # return None in 1 sec - waiter = redis.brpoplpush(source, destkey, timeout=1) - test_value = await waiter - assert test_value is None - other_redis.close() - - -@pytest.mark.asyncio -async def test_lindex(redis): - key, value = b"key:lindex:1", "value:{}" - # setup list - values = [value.format(i).encode("utf-8") for i in range(0, 10)] - await redis.rpush(key, *values) - # make sure that all indexes are correct - for i in range(0, 10): - test_value = await redis.lindex(key, i) - assert test_value == values[i] - - # get last element - test_value = await redis.lindex(key, -1) - assert test_value == b"value:9" - - # index of an element when the key does not exist - test_value = await redis.lindex(b"not:" + key, 5) - assert test_value is None - - # test encoding param - await redis.rpush(key, "one", "two") - test_value = await redis.lindex(key, 10, encoding="utf-8") - assert test_value == "one" - test_value = await redis.lindex(key, 11, encoding="utf-8") - assert test_value == "two" - - with pytest.raises(TypeError): - await redis.lindex(None, -1) - - with pytest.raises(TypeError): - await redis.lindex(key, b"one") - - -@pytest.mark.asyncio -async def test_linsert(redis): - key = b"key:linsert:1" - value1, value2, value3, value4 = b"Hello", b"World", b"foo", b"bar" - await redis.rpush(key, value1, value2) - - # insert element before pivot - test_value = await redis.linsert(key, value2, value3, before=True) - assert test_value == 3 - # insert element after pivot - test_value = await redis.linsert(key, value2, value4, before=False) - assert test_value == 4 - - # make sure that the values were actually inserted in the right places - test_value = await redis.lrange(key, 0, -1) - expected = [value1, value3, value2, value4] - assert test_value == expected - - # try to insert something when the pivot value does not exist - test_value = await redis.linsert(key, b"not:pivot", value3, before=True) - assert test_value == -1 - - with pytest.raises(TypeError): - await redis.linsert(None, value1, value3) - - -@pytest.mark.asyncio -async def test_llen(redis): - key = b"key:llen:1" - value1, value2 = b"Hello", b"World" - await redis.rpush(key, value1, value2) - - test_value = await redis.llen(key) - assert test_value == 2 - - test_value = await 
redis.llen(b"not:" + key) - assert test_value == 0 - - with pytest.raises(TypeError): - await redis.llen(None) - - -@pytest.mark.asyncio -async def test_lpop(redis): - key = b"key:lpop:1" - value1, value2 = b"lpop:value:1", b"lpop:value:2" - - # setup list - result = await redis.rpush(key, value1, value2) - assert result == 2 - # make sure that the left value was popped - test_value = await redis.lpop(key) - assert test_value == value1 - # pop remaining value, so list should become empty - test_value = await redis.lpop(key) - assert test_value == value2 - # pop from empty list - test_value = await redis.lpop(key) - assert test_value is None - - # test encoding param - await redis.rpush(key, "value") - test_value = await redis.lpop(key, encoding="utf-8") - assert test_value == "value" - - with pytest.raises(TypeError): - await redis.lpop(None) - - -@pytest.mark.asyncio -async def test_lpush(redis): - key = b"key:lpush" - value1, value2 = b"value:1", b"value:2" - - # add multiple values to the list, with a key that does not exist - result = await redis.lpush(key, value1, value2) - assert result == 2 - - # make sure that the values were actually inserted in the right place and order - test_value = await redis.lrange(key, 0, -1) - assert test_value == [value2, value1] - - # test encoding param - test_value = await redis.lrange(key, 0, -1, encoding="utf-8") - assert test_value == ["value:2", "value:1"] - - with pytest.raises(TypeError): - await redis.lpush(None, value1) - - -@pytest.mark.asyncio -async def test_lpushx(redis): - key = b"key:lpushx" - value1, value2 = b"value:1", b"value:2" - - # add multiple values to the list, with a key that does not exist, - # so the value should not be pushed - result = await redis.lpushx(key, value2) - assert result == 0 - # init key with list by using regular lpush - result = await redis.lpush(key, value1) - assert result == 1 - - result = await redis.lpushx(key, value2) - assert result == 2 - - # make sure that the values were actually inserted in the right place and order - test_value = await redis.lrange(key, 0, -1) - assert test_value == [value2, value1] - - with pytest.raises(TypeError): - await redis.lpushx(None, value1) - - -@pytest.mark.asyncio -async def test_lrange(redis): - key, value = b"key:lrange:1", "value:{}" - values = [value.format(i).encode("utf-8") for i in range(0, 10)] - await redis.rpush(key, *values) - - test_value = await redis.lrange(key, 0, 2) - assert test_value == values[0:3] - - test_value = await redis.lrange(key, 0, -1) - assert test_value == values - - test_value = await redis.lrange(key, -2, -1) - assert test_value == values[-2:] - - # range of elements when the key does not exist - test_value = await redis.lrange(b"not:" + key, 0, -1) - assert test_value == [] - - with pytest.raises(TypeError): - await redis.lrange(None, 0, -1) - - with pytest.raises(TypeError): - await redis.lrange(key, b"zero", -1) - - with pytest.raises(TypeError): - await redis.lrange(key, 0, b"one") - - -@pytest.mark.asyncio -async def test_lrem(redis): - key, value = b"key:lrem:1", "value:{}" - values = [value.format(i % 2).encode("utf-8") for i in range(0, 10)] - await redis.rpush(key, *values) - # remove elements from tail to head - test_value = await redis.lrem(key, -4, b"value:0") - assert test_value == 4 - # remove element from head to tail - test_value = await redis.lrem(key, 4, b"value:1") - assert test_value == 4 - - # remove values that are not in the list - test_value = await redis.lrem(key, 4, b"value:other") - assert test_value == 0 - - # make sure that only two values are left in the list - 
test_value = await redis.lrange(key, 0, -1) - assert test_value == [b"value:0", b"value:1"] - - # remove all instances of value:0 - test_value = await redis.lrem(key, 0, b"value:0") - assert test_value == 1 - - # make sure that only one value is left in the list - test_value = await redis.lrange(key, 0, -1) - assert test_value == [b"value:1"] - - with pytest.raises(TypeError): - await redis.lrem(None, 0, b"value:0") - - with pytest.raises(TypeError): - await redis.lrem(key, b"ten", b"value:0") - - -@pytest.mark.asyncio -async def test_lset(redis): - key, value = b"key:lset", "value:{}" - values = [value.format(i).encode("utf-8") for i in range(0, 3)] - await redis.rpush(key, *values) - - await redis.lset(key, 0, b"foo") - await redis.lset(key, -1, b"baz") - await redis.lset(key, -2, b"zap") - - test_value = await redis.lrange(key, 0, -1) - assert test_value == [b"foo", b"zap", b"baz"] - - with pytest.raises(TypeError): - await redis.lset(None, 0, b"value:0") - - with pytest.raises(ReplyError): - await redis.lset(key, 100, b"value:0") - - with pytest.raises(TypeError): - await redis.lset(key, b"one", b"value:0") - - -@pytest.mark.asyncio -async def test_ltrim(redis): - key, value = b"key:ltrim", "value:{}" - values = [value.format(i).encode("utf-8") for i in range(0, 10)] - await redis.rpush(key, *values) - - # trim with negative indexes - await redis.ltrim(key, 0, -5) - test_value = await redis.lrange(key, 0, -1) - assert test_value == values[:-4] - # trim with positive indexes - await redis.ltrim(key, 0, 2) - test_value = await redis.lrange(key, 0, -1) - assert test_value == values[:3] - - # try to trim out-of-range indexes - res = await redis.ltrim(key, 100, 110) - assert res is True - test_value = await redis.lrange(key, 0, -1) - assert test_value == [] - - with pytest.raises(TypeError): - await redis.ltrim(None, 0, -1) - - with pytest.raises(TypeError): - await redis.ltrim(key, b"zero", -1) - - with pytest.raises(TypeError): - await redis.ltrim(key, 0, b"one") - - -@pytest.mark.asyncio -async def test_rpop(redis): - key = b"key:rpop:1" - value1, value2 = b"rpop:value:1", b"rpop:value:2" - - # setup list - result = await redis.rpush(key, value1, value2) - assert result == 2 - # make sure that the right value was popped - test_value = await redis.rpop(key) - assert test_value == value2 - # pop remaining value, so list should become empty - test_value = await redis.rpop(key) - assert test_value == value1 - # pop from empty list - test_value = await redis.rpop(key) - assert test_value is None - - # test encoding param - await redis.rpush(key, "value") - test_value = await redis.rpop(key, encoding="utf-8") - assert test_value == "value" - - with pytest.raises(TypeError): - await redis.rpop(None) - - -@pytest.mark.asyncio -async def test_rpoplpush(redis): - key = b"key:rpoplpush:1" - value1, value2 = b"rpoplpush:value:1", b"rpoplpush:value:2" - destkey = b"destkey:rpoplpush:1" - - # setup list - await redis.rpush(key, value1, value2) - - # move the value into the head of the new list - result = await redis.rpoplpush(key, destkey) - assert result == value2 - # move last value - result = await redis.rpoplpush(key, destkey) - assert result == value1 - - # make sure that all values are stored in the new destkey list - result = await redis.lrange(destkey, 0, -1) - assert result == [value1, value2] - - # test encoding param - result = await redis.rpoplpush(destkey, key, encoding="utf-8") - assert result == "rpoplpush:value:2" - - with pytest.raises(TypeError): - await redis.rpoplpush(None, destkey) - - with 
pytest.raises(TypeError): - await redis.rpoplpush(key, None) - - -@pytest.mark.asyncio -async def test_rpush(redis): - key = b"key:rpush" - value1, value2 = b"value:1", b"value:2" - - # add multiple values to the list, with a key that does not exist - result = await redis.rpush(key, value1, value2) - assert result == 2 - - # make sure that the values were actually inserted in the right place and order - test_value = await redis.lrange(key, 0, -1) - assert test_value == [value1, value2] - - with pytest.raises(TypeError): - await redis.rpush(None, value1) - - -@pytest.mark.asyncio -async def test_rpushx(redis): - key = b"key:rpushx" - value1, value2 = b"value:1", b"value:2" - - # add multiple values to the list, with a key that does not exist, - # so the value should not be pushed - result = await redis.rpushx(key, value2) - assert result == 0 - # init key with list by using regular rpush - result = await redis.rpush(key, value1) - assert result == 1 - - result = await redis.rpushx(key, value2) - assert result == 2 - - # make sure that the values were actually inserted in the right place and order - test_value = await redis.lrange(key, 0, -1) - assert test_value == [value1, value2] - - with pytest.raises(TypeError): - await redis.rpushx(None, value1) diff --git a/tests/locks_test.py b/tests/locks_test.py deleted file mode 100644 index 23e27ff20..000000000 --- a/tests/locks_test.py +++ /dev/null @@ -1,33 +0,0 @@ -import asyncio - -import pytest - -from aioredis.locks import Lock - - -@pytest.mark.asyncio -async def test_finished_waiter_cancelled(): - lock = Lock() - - ta = asyncio.ensure_future(lock.acquire()) - await asyncio.sleep(0) - assert lock.locked() - - tb = asyncio.ensure_future(lock.acquire()) - await asyncio.sleep(0) - assert len(lock._waiters) == 1 - - # Create a second waiter, wake up the first, and cancel it. 
- # Without the fix, the second was not woken up and the lock - # will never be locked - asyncio.ensure_future(lock.acquire()) - await asyncio.sleep(0) - lock.release() - tb.cancel() - - await asyncio.sleep(0) - assert ta.done() - assert tb.cancelled() - - await asyncio.sleep(0) - assert lock.locked() diff --git a/tests/multi_exec_test.py b/tests/multi_exec_test.py deleted file mode 100644 index 11b11341a..000000000 --- a/tests/multi_exec_test.py +++ /dev/null @@ -1,46 +0,0 @@ -import asyncio -from contextlib import contextmanager -from unittest import mock - -from aioredis.commands import MultiExec, Redis - - -@contextmanager -def nullcontext(result): - yield result - - -def test_global_loop(): - conn = mock.Mock(spec=("execute closed _transaction_error _buffered".split())) - try: - old_loop = asyncio.get_event_loop() - except (AssertionError, RuntimeError): - old_loop = None - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - tr = MultiExec(conn, commands_factory=Redis) - # assert tr._loop is loop - - def make_fut(cmd, *args, **kw): - fut = asyncio.get_event_loop().create_future() - if cmd == "PING": - fut.set_result(b"QUEUED") - elif cmd == "EXEC": - fut.set_result([b"PONG"]) - else: - fut.set_result(b"OK") - return fut - - conn.execute.side_effect = make_fut - conn.closed = False - conn._transaction_error = None - conn._buffered.side_effect = lambda: nullcontext(conn) - - async def go(): - tr.ping() - res = await tr.execute() - assert res == [b"PONG"] - - loop.run_until_complete(go()) - asyncio.set_event_loop(old_loop) diff --git a/tests/parse_url_test.py b/tests/parse_url_test.py deleted file mode 100644 index 61c7cd991..000000000 --- a/tests/parse_url_test.py +++ /dev/null @@ -1,149 +0,0 @@ -import pytest - -from aioredis.util import parse_url - - -@pytest.mark.parametrize( - "url,expected_address,expected_options", - [ - # redis scheme - ("redis://", ("localhost", 6379), {}), - ("redis://localhost:6379", ("localhost", 6379), {}), - ("redis://localhost:6379/", ("localhost", 6379), {}), - ("redis://localhost:6379/0", ("localhost", 6379), {"db": 0}), - ("redis://localhost:6379/1", ("localhost", 6379), {"db": 1}), - ("redis://localhost:6379?db=1", ("localhost", 6379), {"db": 1}), - ("redis://localhost:6379/?db=1", ("localhost", 6379), {"db": 1}), - ("redis://redis-host", ("redis-host", 6379), {}), - ("redis://redis-host", ("redis-host", 6379), {}), - ("redis://host:1234", ("host", 1234), {}), - ("redis://user@localhost", ("localhost", 6379), {}), - ("redis://:secret@localhost", ("localhost", 6379), {"password": "secret"}), - ("redis://user:secret@localhost", ("localhost", 6379), {"password": "secret"}), - ( - "redis://localhost?password=secret", - ("localhost", 6379), - {"password": "secret"}, - ), - ( - "redis://localhost?encoding=utf-8", - ("localhost", 6379), - {"encoding": "utf-8"}, - ), - ("redis://localhost?ssl=true", ("localhost", 6379), {"ssl": True}), - ("redis://localhost?timeout=1.0", ("localhost", 6379), {"timeout": 1.0}), - ("redis://localhost?timeout=10", ("localhost", 6379), {"timeout": 10.0}), - # rediss scheme - ("rediss://", ("localhost", 6379), {"ssl": True}), - ("rediss://localhost:6379", ("localhost", 6379), {"ssl": True}), - ("rediss://localhost:6379/", ("localhost", 6379), {"ssl": True}), - ("rediss://localhost:6379/0", ("localhost", 6379), {"ssl": True, "db": 0}), - ("rediss://localhost:6379/1", ("localhost", 6379), {"ssl": True, "db": 1}), - ("rediss://localhost:6379?db=1", ("localhost", 6379), {"ssl": True, "db": 1}), - 
("rediss://localhost:6379/?db=1", ("localhost", 6379), {"ssl": True, "db": 1}), - ("rediss://redis-host", ("redis-host", 6379), {"ssl": True}), - ("rediss://redis-host", ("redis-host", 6379), {"ssl": True}), - ("rediss://host:1234", ("host", 1234), {"ssl": True}), - ("rediss://user@localhost", ("localhost", 6379), {"ssl": True}), - ( - "rediss://:secret@localhost", - ("localhost", 6379), - {"ssl": True, "password": "secret"}, - ), - ( - "rediss://user:secret@localhost", - ("localhost", 6379), - {"ssl": True, "password": "secret"}, - ), - ( - "rediss://localhost?password=secret", - ("localhost", 6379), - {"ssl": True, "password": "secret"}, - ), - ( - "rediss://localhost?encoding=utf-8", - ("localhost", 6379), - {"ssl": True, "encoding": "utf-8"}, - ), - ( - "rediss://localhost?timeout=1.0", - ("localhost", 6379), - {"ssl": True, "timeout": 1.0}, - ), - ( - "rediss://localhost?timeout=10", - ("localhost", 6379), - {"ssl": True, "timeout": 10.0}, - ), - # unix scheme - ("unix:///", "/", {}), - ("unix:///redis.sock?db=12", "/redis.sock", {"db": 12}), - ("unix:///redis.sock?encoding=utf-8", "/redis.sock", {"encoding": "utf-8"}), - ("unix:///redis.sock?ssl=true", "/redis.sock", {"ssl": True}), - ("unix:///redis.sock?timeout=12", "/redis.sock", {"timeout": 12}), - # no scheme - ("/some/path/to/socket", "/some/path/to/socket", {}), - ("/some/path/to/socket?db=1", "/some/path/to/socket?db=1", {}), - ], -) -def test_good_url(url, expected_address, expected_options): - address, options = parse_url(url) - assert address == expected_address - assert options == expected_options - - -@pytest.mark.parametrize( - "url,expected_error", - [ - ("bad-scheme://localhost:6379/", ("Unsupported URI scheme", "bad-scheme")), - ("redis:///?db=1&db=2", ("Multiple parameters are not allowed", "db", "2")), - ("redis:///?db=", ("Empty parameters are not allowed", "db", "")), - ("redis:///?foo=", ("Empty parameters are not allowed", "foo", "")), - ("unix://", ("Empty path is not allowed", "unix://")), - ( - "unix://host:123/", - ("Netlocation is not allowed for unix scheme", "host:123"), - ), - ( - "unix://user:pass@host:123/", - ("Netlocation is not allowed for unix scheme", "user:pass@host:123"), - ), - ( - "unix://user:pass@/", - ("Netlocation is not allowed for unix scheme", "user:pass@"), - ), - ("redis:///01", ("Expected integer without leading zeroes", "01")), - ("rediss:///01", ("Expected integer without leading zeroes", "01")), - ("redis:///?db=01", ("Expected integer without leading zeroes", "01")), - ("rediss:///?db=01", ("Expected integer without leading zeroes", "01")), - ("redis:///1?db=2", ("Single DB value expected, got path and query", 1, 2)), - ("rediss:///1?db=2", ("Single DB value expected, got path and query", 1, 2)), - ( - "redis://:passwd@localhost/?password=passwd", - ("Single password value is expected, got in net location and query"), - ), - ("redis:///?ssl=1", ("Expected 'ssl' param to be 'true' or 'false' only", "1")), - ( - "redis:///?ssl=True", - ("Expected 'ssl' param to be 'true' or 'false' only", "True"), - ), - ], -) -def test_url_assertions(url, expected_error): - with pytest.raises(AssertionError) as exc_info: - parse_url(url) - assert exc_info.value.args == (expected_error,) - - -@pytest.mark.parametrize( - "url", - [ - "redis:///bad-db-num", - "redis:///0/1", - "redis:///?db=bad-num", - "redis:///?db=-1", - ], -) -def test_db_num_assertions(url): - with pytest.raises(AssertionError, match="Invalid decimal integer"): - parse_url(url) diff --git a/tests/pool_test.py 
b/tests/pool_test.py deleted file mode 100644 index 36ac240ef..000000000 --- a/tests/pool_test.py +++ /dev/null @@ -1,622 +0,0 @@ -import asyncio -import logging -import sys -from contextlib import ExitStack -from unittest.mock import patch - -import async_timeout -import pytest - -from aioredis import ( - ConnectionClosedError, - ConnectionsPool, - MaxClientsError, - PoolClosedError, - ReplyError, -) -from tests.testutils import redis_version - -BPO_34638 = sys.version_info >= (3, 8) - - -def _assert_defaults(pool): - assert isinstance(pool, ConnectionsPool) - assert pool.minsize == 1 - assert pool.maxsize == 10 - assert pool.size == 1 - assert pool.freesize == 1 - assert not pool._close_state.is_set() - - -def test_connect(pool): - _assert_defaults(pool) - - -@pytest.mark.asyncio -async def test_clear(pool): - _assert_defaults(pool) - - await pool.clear() - assert pool.freesize == 0 - - -@pytest.mark.parametrize("minsize", [None, -100, 0.0, 100]) -@pytest.mark.asyncio -async def test_minsize(minsize, create_pool, server): - - with pytest.raises(AssertionError): - await create_pool(server.tcp_address, minsize=minsize, maxsize=10) - - -@pytest.mark.parametrize("maxsize", [None, -100, 0.0, 1]) -@pytest.mark.asyncio -async def test_maxsize(maxsize, create_pool, server): - - with pytest.raises(AssertionError): - await create_pool(server.tcp_address, minsize=2, maxsize=maxsize) - - -@pytest.mark.asyncio -async def test_create_connection_timeout(create_pool, server): - from aioredis.connection import open_connection - - with patch("aioredis.connection.open_connection") as open_conn_mock: - - async def open_conn(*args, **kwargs): - await asyncio.sleep(0.2) - return await open_connection(*args, **kwargs) - - open_conn_mock.side_effect = open_conn - with pytest.raises(asyncio.TimeoutError): - await create_pool(server.tcp_address, create_connection_timeout=0.1) - - -def test_no_yield_from(pool): - with pytest.raises(RuntimeError): - with pool: - pass # pragma: no cover - - -@pytest.mark.asyncio -async def test_simple_command(create_pool, server): - pool = await create_pool(server.tcp_address, minsize=10) - - with (await pool) as conn: - msg = await conn.execute("echo", "hello") - assert msg == b"hello" - assert pool.size == 10 - assert pool.freesize == 9 - assert pool.size == 10 - assert pool.freesize == 10 - - -@pytest.mark.asyncio -async def test_create_new(create_pool, server): - pool = await create_pool(server.tcp_address, minsize=1) - assert pool.size == 1 - assert pool.freesize == 1 - - with (await pool): - assert pool.size == 1 - assert pool.freesize == 0 - - with (await pool): - assert pool.size == 2 - assert pool.freesize == 0 - - assert pool.size == 2 - assert pool.freesize == 2 - - -@pytest.mark.asyncio -async def test_create_constraints(create_pool, server): - pool = await create_pool(server.tcp_address, minsize=1, maxsize=1) - assert pool.size == 1 - assert pool.freesize == 1 - - with (await pool): - assert pool.size == 1 - assert pool.freesize == 0 - - with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(pool.acquire(), timeout=0.2) - - -@pytest.mark.asyncio -async def test_create_no_minsize(create_pool, server): - pool = await create_pool(server.tcp_address, minsize=0, maxsize=1) - assert pool.size == 0 - assert pool.freesize == 0 - - with (await pool): - assert pool.size == 1 - assert pool.freesize == 0 - - with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(pool.acquire(), timeout=0.2) - assert pool.size == 1 - assert pool.freesize == 1 - - 
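The pool tests being deleted here exercise 1.x ConnectionsPool semantics: a minsize/maxsize pair, explicit acquire()/release(), and the `with (await pool) as conn` idiom. The redis-py-derived client this diff switches to has no minsize notion; a pool is bounded only by max_connections, and BlockingConnectionPool makes an exhausted pool wait for a free connection rather than raise. A minimal sketch of the equivalent setup, assuming the 2.x-style constructors exported by the new aioredis/__init__.py (the URL, pool size, and timeout are illustrative, and a locally running server is assumed):

import asyncio

import aioredis


async def main():
    # Bounded pool: at most 10 connections; acquiring beyond that waits
    # up to `timeout` seconds for a free connection instead of failing.
    pool = aioredis.BlockingConnectionPool.from_url(
        "redis://localhost:6379/0", max_connections=10, timeout=0.2
    )
    redis = aioredis.Redis(connection_pool=pool)
    assert await redis.set("key", "value")
    assert await redis.get("key") == b"value"
    # Connections are checked out and returned per command; only the
    # final teardown is explicit.
    await pool.disconnect()


asyncio.run(main())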
-@pytest.mark.asyncio -async def test_create_pool_cls(create_pool, server): - class MyPool(ConnectionsPool): - pass - - pool = await create_pool(server.tcp_address, pool_cls=MyPool) - - assert isinstance(pool, MyPool) - - -@pytest.mark.asyncio -async def test_create_pool_cls_invalid(create_pool, server): - with pytest.raises(AssertionError): - await create_pool(server.tcp_address, pool_cls=type) - - -@pytest.mark.asyncio -async def test_release_closed(create_pool, server): - pool = await create_pool(server.tcp_address, minsize=1) - assert pool.size == 1 - assert pool.freesize == 1 - - with (await pool) as conn: - conn.close() - await conn.wait_closed() - assert pool.size == 0 - assert pool.freesize == 0 - - -@pytest.mark.asyncio -async def test_release_pending(create_pool, server, caplog): - pool = await create_pool(server.tcp_address, minsize=1) - assert pool.size == 1 - assert pool.freesize == 1 - - caplog.clear() - with caplog.at_level("WARNING", "aioredis"): - with (await pool) as conn: - try: - await asyncio.wait_for( - conn.execute(b"blpop", b"somekey:not:exists", b"0"), - 0.05, - ) - except asyncio.TimeoutError: - pass - assert pool.size == 0 - assert pool.freesize == 0 - assert caplog.record_tuples == [ - ( - "aioredis", - logging.WARNING, - "Connection " " has pending commands, closing it.", - ), - ] - - -@pytest.mark.asyncio -async def test_release_bad_connection(create_pool, create_redis, server): - pool = await create_pool(server.tcp_address) - conn = await pool.acquire() - assert conn.address[0] in ("127.0.0.1", "::1") - assert conn.address[1] == server.tcp_address.port - other_conn = await create_redis(server.tcp_address) - with pytest.raises(AssertionError): - pool.release(other_conn) - - pool.release(conn) - other_conn.close() - await other_conn.wait_closed() - - -@pytest.mark.asyncio -async def test_release_pubsub_closed(create_pool, server): - """ - Check `_used` connections list cleanup after pubsub connection died - """ - pool = await create_pool(server.tcp_address, minsize=1, maxsize=2) - assert pool.size == 1 - assert pool.freesize == 1 - assert pool.maxsize == 2 - - await pool.execute("set", "key", "val") - await pool.execute_pubsub("subscribe", "channel:1") - assert pool.size == 1 - - res = await pool.execute("get", "key") - assert res == b"val" - assert pool.size == 2 - - conn, _ = pool.get_connection("subscribe") - assert conn and not conn.closed - assert conn.in_pubsub - conn.close() - await conn.wait_closed() - - await pool.execute_pubsub("subscribe", "channel:1") - - # Here we could get timeout if `_used` list was not cleaned properly - with async_timeout.timeout(5): - res = await pool.execute("get", "key") - assert res == b"val" - - -@pytest.mark.asyncio -async def test_select_db(create_pool, server): - pool = await create_pool(server.tcp_address) - - await pool.select(1) - with (await pool) as conn: - assert conn.db == 1 - - -@pytest.mark.asyncio -async def test_change_db(create_pool, server): - pool = await create_pool(server.tcp_address, minsize=1, db=0) - assert pool.size == 1 - assert pool.freesize == 1 - - with (await pool) as conn: - await conn.select(1) - assert pool.size == 0 - assert pool.freesize == 0 - - with (await pool): - assert pool.size == 1 - assert pool.freesize == 0 - - await pool.select(1) - assert pool.db == 1 - assert pool.size == 1 - assert pool.freesize == 0 - assert pool.size == 0 - assert pool.freesize == 0 - assert pool.db == 1 - - -@pytest.mark.asyncio -async def test_change_db_errors(create_pool, server): - pool = await 
create_pool(server.tcp_address, minsize=1, db=0) - - with pytest.raises(TypeError): - await pool.select(None) - assert pool.db == 0 - - with (await pool): - pass - assert pool.size == 1 - assert pool.freesize == 1 - - with pytest.raises(TypeError): - await pool.select(None) - assert pool.db == 0 - with pytest.raises(ValueError): - await pool.select(-1) - assert pool.db == 0 - with pytest.raises(ReplyError): - await pool.select(100000) - assert pool.db == 0 - - -@pytest.mark.xfail(reason="Need to refactor this test") -@pytest.mark.asyncio -async def test_select_and_create(create_pool, server): - # trying to model situation when select and acquire - # called simultaneously - # but acquire freezes on _wait_select and - # then continues with proper db - - # TODO: refactor this test as there's no _wait_select any more. - with async_timeout.timeout(10): - pool = await create_pool( - server.tcp_address, - minsize=1, - db=0, - ) - db = 0 - while True: - db = (db + 1) & 1 - _, conn = await asyncio.gather(pool.select(db), pool.acquire()) - assert pool.db == db - pool.release(conn) - if conn.db == db: - break - # await asyncio.wait_for(test(), 3, loop=loop) - - -@pytest.mark.asyncio -async def test_response_decoding(create_pool, server): - pool = await create_pool(server.tcp_address, encoding="utf-8") - - assert pool.encoding == "utf-8" - with (await pool) as conn: - await conn.execute("set", "key", "value") - with (await pool) as conn: - res = await conn.execute("get", "key") - assert res == "value" - - -@pytest.mark.asyncio -async def test_hgetall_response_decoding(create_pool, server): - pool = await create_pool(server.tcp_address, encoding="utf-8") - - assert pool.encoding == "utf-8" - with (await pool) as conn: - await conn.execute("del", "key1") - await conn.execute("hmset", "key1", "foo", "bar") - await conn.execute("hmset", "key1", "baz", "zap") - with (await pool) as conn: - res = await conn.execute("hgetall", "key1") - assert res == ["foo", "bar", "baz", "zap"] - - -@pytest.mark.asyncio -async def test_crappy_multiexec(create_pool, server): - pool = await create_pool(server.tcp_address, encoding="utf-8", minsize=1, maxsize=1) - - with (await pool) as conn: - await conn.execute("set", "abc", "def") - await conn.execute("multi") - await conn.execute("set", "abc", "fgh") - assert conn.closed is True - with (await pool) as conn: - value = await conn.execute("get", "abc") - assert value == "def" - - -@pytest.mark.asyncio -async def test_pool_size_growth(create_pool, server): - pool = await create_pool(server.tcp_address, minsize=1, maxsize=1) - - done = set() - tasks = [] - - async def task1(i): - with (await pool): - assert pool.size <= pool.maxsize - assert pool.freesize == 0 - await asyncio.sleep(0.2) - done.add(i) - - async def task2(): - with (await pool): - assert pool.size <= pool.maxsize - assert pool.freesize >= 0 - assert done == {0, 1} - - for _ in range(2): - tasks.append(asyncio.ensure_future(task1(_))) - tasks.append(asyncio.ensure_future(task2())) - await asyncio.gather(*tasks) - - -@pytest.mark.asyncio -async def test_pool_with_closed_connections(create_pool, server): - pool = await create_pool(server.tcp_address, minsize=1, maxsize=2) - assert 1 == pool.freesize - conn1 = pool._pool[0] - conn1.close() - assert conn1.closed is True - assert 1 == pool.freesize - with (await pool) as conn2: - assert conn2.closed is False - assert conn1 is not conn2 - - -@pytest.mark.asyncio -async def test_pool_close(create_pool, server): - pool = await create_pool(server.tcp_address) - - assert 
pool.closed is False - - with (await pool) as conn: - assert (await conn.execute("ping")) == b"PONG" - - pool.close() - await pool.wait_closed() - assert pool.closed is True - - with pytest.raises(PoolClosedError): - with (await pool) as conn: - assert (await conn.execute("ping")) == b"PONG" - - -@pytest.mark.asyncio -async def test_pool_close__used(create_pool, server): - pool = await create_pool(server.tcp_address) - - assert pool.closed is False - - with (await pool) as conn: - pool.close() - await pool.wait_closed() - assert pool.closed is True - - with pytest.raises(ConnectionClosedError): - await conn.execute("ping") - - -@redis_version(2, 8, 0, reason="maxclients config setting") -@pytest.mark.asyncio -async def test_pool_check_closed_when_exception( - create_pool, create_redis, start_server, caplog -): - server = start_server("server-small") - redis = await create_redis(server.tcp_address) - await redis.config_set("maxclients", 2) - - errors = (MaxClientsError, ConnectionClosedError, ConnectionError) - caplog.clear() - with caplog.at_level("DEBUG", "aioredis"): - with pytest.raises(errors): - await create_pool(address=tuple(server.tcp_address), minsize=3) - - assert len(caplog.record_tuples) >= 3 - connect_msg = "Creating tcp connection to ('localhost', {})".format( - server.tcp_address.port - ) - assert caplog.record_tuples[:2] == [ - ("aioredis", logging.DEBUG, connect_msg), - ("aioredis", logging.DEBUG, connect_msg), - ] - assert caplog.record_tuples[-1] == ( - "aioredis", - logging.DEBUG, - "Closed 1 connection(s)", - ) - - -@pytest.mark.asyncio -async def test_pool_get_connection(create_pool, server): - pool = await create_pool(server.tcp_address, minsize=1, maxsize=2) - res = await pool.execute("set", "key", "val") - assert res == b"OK" - - res = await pool.execute_pubsub("subscribe", "channel:1") - assert res == [[b"subscribe", b"channel:1", 1]] - - res = await pool.execute("getset", "key", "value") - assert res == b"val" - - res = await pool.execute_pubsub("subscribe", "channel:2") - assert res == [[b"subscribe", b"channel:2", 2]] - - res = await pool.execute("get", "key") - assert res == b"value" - - -@pytest.mark.asyncio -async def test_pool_get_connection_with_pipelining(create_pool, server): - pool = await create_pool(server.tcp_address, minsize=1, maxsize=2) - fut1 = pool.execute("set", "key", "val") - fut2 = pool.execute_pubsub("subscribe", "channel:1") - fut3 = pool.execute("getset", "key", "next") - fut4 = pool.execute_pubsub("subscribe", "channel:2") - fut5 = pool.execute("get", "key") - res = await fut1 - assert res == b"OK" - res = await fut2 - assert res == [[b"subscribe", b"channel:1", 1]] - res = await fut3 - assert res == b"val" - res = await fut4 - assert res == [[b"subscribe", b"channel:2", 2]] - res = await fut5 - assert res == b"next" - - -@pytest.mark.skipif(sys.platform == "win32", reason="flaky on windows") -@pytest.mark.asyncio -async def test_pool_idle_close(create_pool, start_server): - server = start_server("idle") - conn = await create_pool(server.tcp_address, minsize=2) - ok = await conn.execute("config", "set", "timeout", 1) - assert ok == b"OK" - closed = [] - while len(closed) < 2: - await asyncio.sleep(0.5) - closed = [c for c in conn._pool if c._closed] - assert closed == [*conn._pool] - # On CI this test fails from time to time. 
- # It is possible to pick 'unclosed' connection and send command, - # however on the same loop iteration it gets closed and exception is raised - assert (await conn.execute("ping")) == b"PONG" - - -@pytest.mark.asyncio -async def test_await(create_pool, server): - pool = await create_pool(server.tcp_address, minsize=10) - - with (await pool) as conn: - msg = await conn.execute("echo", "hello") - assert msg == b"hello" - - -@pytest.mark.asyncio -async def test_async_with(create_pool, server): - pool = await create_pool(server.tcp_address, minsize=10) - - async with pool.get() as conn: - msg = await conn.execute("echo", "hello") - assert msg == b"hello" - - -@pytest.mark.asyncio -async def test_pool__drop_closed(create_pool, server): - pool = await create_pool(server.tcp_address, minsize=3, maxsize=3) - assert pool.size == 3 - assert pool.freesize == 3 - assert not pool._pool[0].closed - assert not pool._pool[1].closed - assert not pool._pool[2].closed - - pool._pool[1].close() - pool._pool[2].close() - await pool._pool[1].wait_closed() - await pool._pool[2].wait_closed() - - assert not pool._pool[0].closed - assert pool._pool[1].closed - assert pool._pool[2].closed - - assert pool.size == 3 - assert pool.freesize == 3 - - pool._drop_closed() - assert pool.freesize == 1 - assert pool.size == 1 - - -@pytest.mark.asyncio -async def test_multiple_connection_acquire(create_pool, server): - # see https://bugs.python.org/issue32734 for explanation - - pool = await create_pool(server.tcp_address, minsize=10, maxsize=10) - - with ExitStack() as stack: - fill_free_event = asyncio.Event() - - async def fill_free_se(override_min): - await asyncio.sleep(0) - await fill_free_event.wait() - - mocked_fill_free = stack.enter_context( - patch.object(pool, "_fill_free", side_effect=fill_free_se) - ) - - conn_fut1 = asyncio.ensure_future(pool.acquire()) - conn_fut2 = asyncio.ensure_future(pool.acquire()) - conn_fut3 = asyncio.ensure_future(pool.acquire()) - conn_fut4 = asyncio.ensure_future(pool.acquire()) - - # acquire multiple Condition._lock - await asyncio.sleep(0) - conn_fut1.cancel() - conn_fut2.cancel() - fill_free_event.set() - - assert mocked_fill_free.call_count == 1 - - with pytest.raises(asyncio.CancelledError): - await conn_fut1 - with pytest.raises(asyncio.CancelledError): - await conn_fut2 - - await conn_fut3 - await conn_fut4 - - -@pytest.mark.asyncio -async def test_client_name(create_pool, server): - name = "test" - pool = await create_pool(server.tcp_address, name=name) - - with (await pool) as conn: - res = await conn.execute(b"CLIENT", b"GETNAME") - assert res == bytes(name, "utf-8") - - name = "test2" - await pool.setname(name) - with (await pool) as conn: - res = await conn.execute(b"CLIENT", b"GETNAME") - assert res == bytes(name, "utf-8") diff --git a/tests/pubsub_commands_test.py b/tests/pubsub_commands_test.py deleted file mode 100644 index da66b02da..000000000 --- a/tests/pubsub_commands_test.py +++ /dev/null @@ -1,356 +0,0 @@ -import asyncio - -import pytest - -import aioredis -from tests.testutils import redis_version - - -async def _reader(channel, output, waiter, conn): - await conn.execute("subscribe", channel) - ch = conn.pubsub_channels[channel] - waiter.set_result(conn) - while await ch.wait_message(): - msg = await ch.get() - await output.put(msg) - - -@pytest.mark.asyncio -async def test_publish(create_connection, redis, server, event_loop): - out = asyncio.Queue() - fut = event_loop.create_future() - conn = await create_connection(server.tcp_address) - sub = 
asyncio.ensure_future(_reader("chan:1", out, fut, conn)) - - await fut - await redis.publish("chan:1", "Hello") - msg = await out.get() - assert msg == b"Hello" - - sub.cancel() - - -@pytest.mark.asyncio -async def test_publish_json(create_connection, redis, server, event_loop): - out = asyncio.Queue() - fut = event_loop.create_future() - conn = await create_connection(server.tcp_address) - sub = asyncio.ensure_future(_reader("chan:1", out, fut, conn)) - - await fut - - res = await redis.publish_json("chan:1", {"Hello": "world"}) - assert res == 1 # receivers - - msg = await out.get() - assert msg == b'{"Hello": "world"}' - sub.cancel() - - -@pytest.mark.asyncio -async def test_subscribe(redis): - res = await redis.subscribe("chan:1", "chan:2") - assert redis.in_pubsub == 2 - - ch1 = redis.channels["chan:1"] - ch2 = redis.channels["chan:2"] - - assert res == [ch1, ch2] - assert ch1.is_pattern is False - assert ch2.is_pattern is False - - res = await redis.unsubscribe("chan:1", "chan:2") - assert res == [[b"unsubscribe", b"chan:1", 1], [b"unsubscribe", b"chan:2", 0]] - - -@pytest.mark.parametrize( - "create_redis", - [ - pytest.param(aioredis.create_redis_pool, id="pool"), - ], -) -@pytest.mark.asyncio -async def test_subscribe_empty_pool(create_redis, server, _closable): - redis = await create_redis(server.tcp_address) - _closable(redis) - await redis.connection.clear() - - res = await redis.subscribe("chan:1", "chan:2") - assert redis.in_pubsub == 2 - - ch1 = redis.channels["chan:1"] - ch2 = redis.channels["chan:2"] - - assert res == [ch1, ch2] - assert ch1.is_pattern is False - assert ch2.is_pattern is False - - res = await redis.unsubscribe("chan:1", "chan:2") - assert res == [[b"unsubscribe", b"chan:1", 1], [b"unsubscribe", b"chan:2", 0]] - - -@pytest.mark.asyncio -async def test_psubscribe(redis, create_redis, server): - sub = redis - res = await sub.psubscribe("patt:*", "chan:*") - assert sub.in_pubsub == 2 - - pat1 = sub.patterns["patt:*"] - pat2 = sub.patterns["chan:*"] - assert res == [pat1, pat2] - - pub = await create_redis(server.tcp_address) - await pub.publish_json("chan:123", {"Hello": "World"}) - res = await pat2.get_json() - assert res == (b"chan:123", {"Hello": "World"}) - - res = await sub.punsubscribe("patt:*", "patt:*", "chan:*") - assert res == [ - [b"punsubscribe", b"patt:*", 1], - [b"punsubscribe", b"patt:*", 1], - [b"punsubscribe", b"chan:*", 0], - ] - - -@pytest.mark.parametrize( - "create_redis", - [ - pytest.param(aioredis.create_redis_pool, id="pool"), - ], -) -@pytest.mark.asyncio -async def test_psubscribe_empty_pool(create_redis, server, _closable): - sub = await create_redis(server.tcp_address) - pub = await create_redis(server.tcp_address) - _closable(sub) - _closable(pub) - await sub.connection.clear() - res = await sub.psubscribe("patt:*", "chan:*") - assert sub.in_pubsub == 2 - - pat1 = sub.patterns["patt:*"] - pat2 = sub.patterns["chan:*"] - assert res == [pat1, pat2] - - await pub.publish_json("chan:123", {"Hello": "World"}) - res = await pat2.get_json() - assert res == (b"chan:123", {"Hello": "World"}) - - res = await sub.punsubscribe("patt:*", "patt:*", "chan:*") - assert res == [ - [b"punsubscribe", b"patt:*", 1], - [b"punsubscribe", b"patt:*", 1], - [b"punsubscribe", b"chan:*", 0], - ] - - -@redis_version(2, 8, 0, reason="PUBSUB CHANNELS is available since redis>=2.8.0") -@pytest.mark.asyncio -async def test_pubsub_channels(create_redis, server): - redis = await create_redis(server.tcp_address) - res = await redis.pubsub_channels() - assert res == 
[] - - res = await redis.pubsub_channels("chan:*") - assert res == [] - - sub = await create_redis(server.tcp_address) - await sub.subscribe("chan:1") - - res = await redis.pubsub_channels() - assert res == [b"chan:1"] - - res = await redis.pubsub_channels("ch*") - assert res == [b"chan:1"] - - await sub.unsubscribe("chan:1") - await sub.psubscribe("chan:*") - - res = await redis.pubsub_channels() - assert res == [] - - -@redis_version(2, 8, 0, reason="PUBSUB NUMSUB is available since redis>=2.8.0") -@pytest.mark.asyncio -async def test_pubsub_numsub(create_redis, server): - redis = await create_redis(server.tcp_address) - res = await redis.pubsub_numsub() - assert res == {} - - res = await redis.pubsub_numsub("chan:1") - assert res == {b"chan:1": 0} - - sub = await create_redis(server.tcp_address) - await sub.subscribe("chan:1") - - res = await redis.pubsub_numsub() - assert res == {} - - res = await redis.pubsub_numsub("chan:1") - assert res == {b"chan:1": 1} - - res = await redis.pubsub_numsub("chan:2") - assert res == {b"chan:2": 0} - - res = await redis.pubsub_numsub("chan:1", "chan:2") - assert res == {b"chan:1": 1, b"chan:2": 0} - - await sub.unsubscribe("chan:1") - await sub.psubscribe("chan:*") - - res = await redis.pubsub_numsub() - assert res == {} - - -@redis_version(2, 8, 0, reason="PUBSUB NUMPAT is available since redis>=2.8.0") -@pytest.mark.asyncio -async def test_pubsub_numpat(create_redis, server, redis): - sub = await create_redis(server.tcp_address) - - res = await redis.pubsub_numpat() - assert res == 0 - - await sub.subscribe("chan:1") - res = await redis.pubsub_numpat() - assert res == 0 - - await sub.psubscribe("chan:*") - res = await redis.pubsub_numpat() - assert res == 1 - - -@pytest.mark.asyncio -async def test_close_pubsub_channels(redis): - (ch,) = await redis.subscribe("chan:1") - - async def waiter(ch): - assert not await ch.wait_message() - - tsk = asyncio.ensure_future(waiter(ch)) - redis.close() - await redis.wait_closed() - await tsk - - -@pytest.mark.asyncio -async def test_close_pubsub_patterns(redis): - (ch,) = await redis.psubscribe("chan:*") - - async def waiter(ch): - assert not await ch.wait_message() - - tsk = asyncio.ensure_future(waiter(ch)) - redis.close() - await redis.wait_closed() - await tsk - - -@pytest.mark.asyncio -async def test_close_cancelled_pubsub_channel(redis): - (ch,) = await redis.subscribe("chan:1") - - async def waiter(ch): - with pytest.raises(asyncio.CancelledError): - await ch.wait_message() - - tsk = asyncio.ensure_future(waiter(ch)) - await asyncio.sleep(0) - tsk.cancel() - - -@pytest.mark.asyncio -async def test_channel_get_after_close(create_redis, event_loop, server): - sub = await create_redis(server.tcp_address) - pub = await create_redis(server.tcp_address) - (ch,) = await sub.subscribe("chan:1") - - await pub.publish("chan:1", "message") - assert await ch.get() == b"message" - event_loop.call_soon(sub.close) - assert await ch.get() is None - with pytest.raises(aioredis.ChannelClosedError): - assert await ch.get() - - -@pytest.mark.asyncio -async def test_subscribe_concurrency(create_redis, server): - sub = await create_redis(server.tcp_address) - pub = await create_redis(server.tcp_address) - - async def subscribe(*args): - return await sub.subscribe(*args) - - async def publish(*args): - await asyncio.sleep(0) - return await pub.publish(*args) - - res = await asyncio.gather( - subscribe("channel:0"), - publish("channel:0", "Hello"), - subscribe("channel:1"), - ) - (ch1,), subs, (ch2,) = res - - assert ch1.name == 
b"channel:0" - assert subs == 1 - assert ch2.name == b"channel:1" - - -@redis_version(3, 2, 0, reason="PUBSUB PING is available since redis>=3.2.0") -@pytest.mark.asyncio -async def test_pubsub_ping(redis): - await redis.subscribe("chan:1", "chan:2") - - res = await redis.ping() - assert res == b"PONG" - res = await redis.ping("Hello") - assert res == b"Hello" - res = await redis.ping("Hello", encoding="utf-8") - assert res == "Hello" - - await redis.unsubscribe("chan:1", "chan:2") - - -@pytest.mark.asyncio -async def test_pubsub_channel_iter(create_redis, server): - sub = await create_redis(server.tcp_address) - pub = await create_redis(server.tcp_address) - - (ch,) = await sub.subscribe("chan:1") - - async def coro(ch): - lst = [] - async for msg in ch.iter(): - lst.append(msg) - return lst - - tsk = asyncio.ensure_future(coro(ch)) - await pub.publish_json("chan:1", {"Hello": "World"}) - await pub.publish_json("chan:1", ["message"]) - await asyncio.sleep(0.1) - ch.close() - assert await tsk == [b'{"Hello": "World"}', b'["message"]'] - - -@redis_version(2, 8, 12, reason="extended `client kill` format required") -@pytest.mark.asyncio -async def test_pubsub_disconnection_notification(create_redis, server): - sub = await create_redis(server.tcp_address) - pub = await create_redis(server.tcp_address) - - async def coro(ch): - lst = [] - async for msg in ch.iter(): - assert ch.is_active - lst.append(msg) - return lst - - (ch,) = await sub.subscribe("chan:1") - tsk = asyncio.ensure_future(coro(ch)) - assert ch.is_active - await pub.publish_json("chan:1", {"Hello": "World"}) - assert ch.is_active - assert await pub.execute("client", "kill", "type", "pubsub") >= 1 - assert await pub.publish_json("chan:1", ["message"]) == 0 - assert await tsk == [b'{"Hello": "World"}'] - assert not ch.is_active diff --git a/tests/pubsub_receiver_test.py b/tests/pubsub_receiver_test.py deleted file mode 100644 index efb65f584..000000000 --- a/tests/pubsub_receiver_test.py +++ /dev/null @@ -1,374 +0,0 @@ -import asyncio -import json -import logging -import sys -from unittest import mock - -import pytest - -from aioredis import ChannelClosedError -from aioredis.abc import AbcChannel -from aioredis.pubsub import Receiver, _Sender - - -def test_listener_channel(): - mpsc = Receiver() - assert not mpsc.is_active - - ch_a = mpsc.channel("channel:1") - assert isinstance(ch_a, AbcChannel) - assert mpsc.is_active - - ch_b = mpsc.channel("channel:1") - assert ch_a is ch_b - assert ch_a.name == ch_b.name - assert ch_a.is_pattern == ch_b.is_pattern - assert mpsc.is_active - - # remember id; drop refs to objects and create new one; - ch_a.close() - assert not ch_a.is_active - - assert not mpsc.is_active - ch = mpsc.channel("channel:1") - assert ch is not ch_a - - assert dict(mpsc.channels) == {b"channel:1": ch} - assert dict(mpsc.patterns) == {} - - -def test_listener_pattern(): - mpsc = Receiver() - assert not mpsc.is_active - - ch_a = mpsc.pattern("*") - assert isinstance(ch_a, AbcChannel) - assert mpsc.is_active - - ch_b = mpsc.pattern("*") - assert ch_a is ch_b - assert ch_a.name == ch_b.name - assert ch_a.is_pattern == ch_b.is_pattern - assert mpsc.is_active - - # remember id; drop refs to objects and create new one; - ch_a.close() - assert not ch_a.is_active - - assert not mpsc.is_active - ch = mpsc.pattern("*") - assert ch is not ch_a - - assert dict(mpsc.channels) == {} - assert dict(mpsc.patterns) == {b"*": ch} - - -@pytest.mark.asyncio -async def test_sender(): - receiver = mock.Mock() - - sender = _Sender(receiver, 
"name", is_pattern=False) - assert isinstance(sender, AbcChannel) - assert sender.name == b"name" - assert sender.is_pattern is False - assert sender.is_active is True - - with pytest.raises(RuntimeError): - await sender.get() - assert receiver.mock_calls == [] - - sender.put_nowait(b"some data") - assert receiver.mock_calls == [ - mock.call._put_nowait(b"some data", sender=sender), - ] - - -def test_sender_close(): - receiver = mock.Mock() - sender = _Sender(receiver, "name", is_pattern=False) - sender.close() - assert receiver.mock_calls == [mock.call._close(sender, exc=None)] - sender.close() - assert receiver.mock_calls == [mock.call._close(sender, exc=None)] - receiver.reset_mock() - assert receiver.mock_calls == [] - sender.close() - assert receiver.mock_calls == [] - - -@pytest.mark.asyncio -async def test_subscriptions(create_connection, server): - sub = await create_connection(server.tcp_address) - pub = await create_connection(server.tcp_address) - - mpsc = Receiver() - await sub.execute_pubsub( - "subscribe", mpsc.channel("channel:1"), mpsc.channel("channel:3") - ) - res = await pub.execute("publish", "channel:3", "Hello world") - assert res == 1 - res = await pub.execute("publish", "channel:1", "Hello world") - assert res == 1 - assert mpsc.is_active - - ch, msg = await mpsc.get() - assert ch.name == b"channel:3" - assert not ch.is_pattern - assert msg == b"Hello world" - - ch, msg = await mpsc.get() - assert ch.name == b"channel:1" - assert not ch.is_pattern - assert msg == b"Hello world" - - -@pytest.mark.asyncio -async def test_unsubscribe(create_connection, server): - sub = await create_connection(server.tcp_address) - pub = await create_connection(server.tcp_address) - - mpsc = Receiver() - await sub.execute_pubsub( - "subscribe", mpsc.channel("channel:1"), mpsc.channel("channel:3") - ) - res = await pub.execute("publish", "channel:3", "Hello world") - assert res == 1 - res = await pub.execute("publish", "channel:1", "Hello world") - assert res == 1 - assert mpsc.is_active - - assert (await mpsc.wait_message()) is True - ch, msg = await mpsc.get() - assert ch.name == b"channel:3" - assert not ch.is_pattern - assert msg == b"Hello world" - - assert (await mpsc.wait_message()) is True - ch, msg = await mpsc.get() - assert ch.name == b"channel:1" - assert not ch.is_pattern - assert msg == b"Hello world" - - await sub.execute_pubsub("unsubscribe", "channel:1") - assert mpsc.is_active - - res = await pub.execute("publish", "channel:3", "message") - assert res == 1 - assert (await mpsc.wait_message()) is True - ch, msg = await mpsc.get() - assert ch.name == b"channel:3" - assert not ch.is_pattern - assert msg == b"message" - - waiter = asyncio.ensure_future(mpsc.get()) - await sub.execute_pubsub("unsubscribe", "channel:3") - assert not mpsc.is_active - assert await waiter is None - - -@pytest.mark.asyncio -async def test_stopped(create_connection, server, caplog): - sub = await create_connection(server.tcp_address) - pub = await create_connection(server.tcp_address) - - mpsc = Receiver() - await sub.execute_pubsub("subscribe", mpsc.channel("channel:1")) - assert mpsc.is_active - mpsc.stop() - - caplog.clear() - with caplog.at_level("DEBUG", "aioredis"): - await pub.execute("publish", "channel:1", b"Hello") - await asyncio.sleep(0) - - assert len(caplog.record_tuples) == 1 - # Receiver must have 1 EndOfStream message - message = ( - "Pub/Sub listener message after stop: " - "sender: <_Sender name:b'channel:1', is_pattern:False, receiver:" - ">, data: b'Hello'" - ) - assert 
caplog.record_tuples == [ - ("aioredis", logging.WARNING, message), - ] - - # assert (await mpsc.get()) is None - with pytest.raises(ChannelClosedError): - await mpsc.get() - res = await mpsc.wait_message() - assert res is False - - -@pytest.mark.asyncio -async def test_wait_message(create_connection, server): - sub = await create_connection(server.tcp_address) - pub = await create_connection(server.tcp_address) - - mpsc = Receiver() - await sub.execute_pubsub("subscribe", mpsc.channel("channel:1")) - fut = asyncio.ensure_future(mpsc.wait_message()) - assert not fut.done() - await asyncio.sleep(0) - assert not fut.done() - - await pub.execute("publish", "channel:1", "hello") - await asyncio.sleep(0) # read in connection - await asyncio.sleep(0) # call Future.set_result - assert fut.done() - res = await fut - assert res is True - - -@pytest.mark.asyncio -async def test_decode_message(): - mpsc = Receiver() - ch = mpsc.channel("channel:1") - ch.put_nowait(b"Some data") - - res = await mpsc.get(encoding="utf-8") - assert isinstance(res[0], _Sender) - assert res[1] == "Some data" - - ch.put_nowait('{"hello": "world"}') - res = await mpsc.get(decoder=json.loads) - assert isinstance(res[0], _Sender) - assert res[1] == {"hello": "world"} - - ch.put_nowait(b'{"hello": "world"}') - res = await mpsc.get(encoding="utf-8", decoder=json.loads) - assert isinstance(res[0], _Sender) - assert res[1] == {"hello": "world"} - - -@pytest.mark.skipif( - sys.version_info >= (3, 6), reason="json.loads accept bytes since Python 3.6" -) -@pytest.mark.asyncio -async def test_decode_message_error(): - mpsc = Receiver() - ch = mpsc.channel("channel:1") - - ch.put_nowait(b'{"hello": "world"}') - unexpected = (mock.ANY, {"hello": "world"}) - with pytest.raises(TypeError): - assert (await mpsc.get(decoder=json.loads)) == unexpected - - ch = mpsc.pattern("*") - ch.put_nowait((b"channel", b'{"hello": "world"}')) - unexpected = (mock.ANY, b"channel", {"hello": "world"}) - with pytest.raises(TypeError): - assert (await mpsc.get(decoder=json.loads)) == unexpected - - -@pytest.mark.asyncio -async def test_decode_message_for_pattern(): - mpsc = Receiver() - ch = mpsc.pattern("*") - ch.put_nowait((b"channel", b"Some data")) - - res = await mpsc.get(encoding="utf-8") - assert isinstance(res[0], _Sender) - assert res[1] == (b"channel", "Some data") - - ch.put_nowait((b"channel", '{"hello": "world"}')) - res = await mpsc.get(decoder=json.loads) - assert isinstance(res[0], _Sender) - assert res[1] == (b"channel", {"hello": "world"}) - - ch.put_nowait((b"channel", b'{"hello": "world"}')) - res = await mpsc.get(encoding="utf-8", decoder=json.loads) - assert isinstance(res[0], _Sender) - assert res[1] == (b"channel", {"hello": "world"}) - - -@pytest.mark.asyncio -async def test_pubsub_receiver_iter(create_redis, server, event_loop): - sub = await create_redis(server.tcp_address) - pub = await create_redis(server.tcp_address) - - mpsc = Receiver() - - async def coro(mpsc): - lst = [] - async for msg in mpsc.iter(): - lst.append(msg) - return lst - - tsk = asyncio.ensure_future(coro(mpsc)) - (snd1,) = await sub.subscribe(mpsc.channel("chan:1")) - (snd2,) = await sub.subscribe(mpsc.channel("chan:2")) - (snd3,) = await sub.psubscribe(mpsc.pattern("chan:*")) - - subscribers = await pub.publish_json("chan:1", {"Hello": "World"}) - assert subscribers > 1 - subscribers = await pub.publish_json("chan:2", ["message"]) - assert subscribers > 1 - event_loop.call_later(0, mpsc.stop) - await asyncio.sleep(0.01) - assert await tsk == [ - (snd1, 
b'{"Hello": "World"}'), - (snd3, (b"chan:1", b'{"Hello": "World"}')), - (snd2, b'["message"]'), - (snd3, (b"chan:2", b'["message"]')), - ] - assert not mpsc.is_active - - -@pytest.mark.timeout(5) -@pytest.mark.asyncio -async def test_pubsub_receiver_call_stop_with_empty_queue( - create_redis, server, event_loop -): - sub = await create_redis(server.tcp_address) - - mpsc = Receiver() - - # FIXME: currently at least one subscriber is needed - (snd1,) = await sub.subscribe(mpsc.channel("chan:1")) - - now = event_loop.time() - event_loop.call_later(0.5, mpsc.stop) - async for i in mpsc.iter(): # (flake8 bug with async for) - assert False, "StopAsyncIteration not raised" - dt = event_loop.time() - now - assert dt <= 1.5 - assert not mpsc.is_active - - -@pytest.mark.asyncio -async def test_pubsub_receiver_stop_on_disconnect(create_redis, server): - pub = await create_redis(server.tcp_address) - sub = await create_redis(server.tcp_address) - sub_name = "sub-{:X}".format(id(sub)) - await sub.client_setname(sub_name) - for sub_info in await pub.client_list(): - if sub_info.name == sub_name: - break - assert sub_info.name == sub_name - - mpsc = Receiver() - await sub.subscribe(mpsc.channel("channel:1")) - await sub.subscribe(mpsc.channel("channel:2")) - await sub.psubscribe(mpsc.pattern("channel:*")) - - q = asyncio.Queue() - EOF = object() - - async def reader(): - async for ch, msg in mpsc.iter(encoding="utf-8"): - await q.put((ch.name, msg)) - await q.put(EOF) - - tsk = asyncio.ensure_future(reader()) - await pub.publish_json("channel:1", ["hello"]) - await pub.publish_json("channel:2", ["hello"]) - # receive all messages - assert await q.get() == (b"channel:1", '["hello"]') - assert await q.get() == (b"channel:*", (b"channel:1", '["hello"]')) - assert await q.get() == (b"channel:2", '["hello"]') - assert await q.get() == (b"channel:*", (b"channel:2", '["hello"]')) - - # XXX: need to implement `client kill` - assert await pub.execute("client", "kill", sub_info.addr) in (b"OK", 1) - await asyncio.wait_for(tsk, timeout=1) - assert await q.get() is EOF diff --git a/tests/pyreader_test.py b/tests/pyreader_test.py deleted file mode 100644 index bbebee678..000000000 --- a/tests/pyreader_test.py +++ /dev/null @@ -1,285 +0,0 @@ -import pytest - -from aioredis.errors import AuthError, MaxClientsError, ProtocolError, ReplyError -from aioredis.parser import PyReader - - -@pytest.fixture -def reader(): - return PyReader() - - -def test_nothing(reader): - assert reader.gets() is False - - -def test_error_when_feeding_non_string(reader): - with pytest.raises(TypeError): - reader.feed(1) - - -@pytest.mark.parametrize( - "data", - [ - b"x", - b"$5\r\nHello world", - b":None\r\n", - b":1.2\r\n", - b":1,2\r\n", - ], - ids=[ - "Bad control char", - "Invalid bulk length", - "Invalid int - none", - "Invalid int - dot", - "Invalid int - comma", - ], -) -def test_protocol_error(reader, data): - reader.feed(data) - with pytest.raises(ProtocolError): - reader.gets() - # not functional any more - with pytest.raises(ProtocolError): - reader.gets() - - -class CustomExc(Exception): - pass - - -@pytest.mark.parametrize( - "exc,arg", - [ - (RuntimeError, RuntimeError), - (CustomExc, lambda e: CustomExc(e)), - ], - ids=["RuntimeError", "callable"], -) -def test_protocol_error_with_custom_class(exc, arg): - reader = PyReader(protocolError=arg) - reader.feed(b"x") - with pytest.raises(exc): - reader.gets() - - -@pytest.mark.parametrize( - "init", - [ - dict(protocolError="wrong"), - dict(replyError="wrong"), - ], - ids=["wrong 
protocolError", "wrong replyError"], -) -def test_fail_with_wrong_error_class(init): - with pytest.raises(TypeError): - PyReader(**init) - - -def test_error_string(reader): - reader.feed(b"-error\r\n") - error = reader.gets() - - assert isinstance(error, ReplyError) - assert error.args == ("error",) - - -@pytest.mark.parametrize( - "error_kind,data", - [ - (AuthError, b"-NOAUTH auth required\r\n"), - (AuthError, b"-ERR invalid password\r\n"), - (MaxClientsError, b"-ERR max number of clients reached\r\n"), - ], -) -def test_error_construction(reader, error_kind, data): - reader.feed(data) - error = reader.gets() - assert isinstance(error, ReplyError) - assert isinstance(error, error_kind) - - -@pytest.mark.parametrize( - "exc,arg", - [ - (RuntimeError, RuntimeError), - (CustomExc, lambda e: CustomExc(e)), - ], - ids=["RuntimeError", "callable"], -) -def test_error_string_with_custom_class(exc, arg): - reader = PyReader(replyError=arg) - reader.feed(b"-error\r\n") - error = reader.gets() - - assert isinstance(error, exc) - assert error.args == ("error",) - - -def test_errors_in_nested_multi_bulk(reader): - reader.feed(b"*2\r\n-err0\r\n-err1\r\n") - - for r, error in zip(("err0", "err1"), reader.gets()): - assert isinstance(error, ReplyError) - assert error.args == (r,) - - -def test_integer(reader): - value = 2 ** 63 - 1 # Largest 64-bit signed integer - reader.feed((":%d\r\n" % value).encode("ascii")) - assert reader.gets() == value - - -def test_status_string(reader): - reader.feed(b"+ok\r\n") - assert reader.gets() == b"ok" - - -@pytest.mark.parametrize( - "data,expected", - [ - (b"$0\r\n\r\n", b""), - (b"$-1\r\n", None), - (b"$5\r\nhello\r\n", b"hello"), - ], - ids=["Empty", "null", "hello"], -) -def test_bulk_string(reader, data, expected): - reader.feed(data) - assert reader.gets() == expected - - -def test_bulk_string_without_encoding(reader): - snowman = b"\xe2\x98\x83" - reader.feed(b"$3\r\n" + snowman + b"\r\n") - assert reader.gets() == snowman - - -@pytest.mark.parametrize( - "encoding,expected", - [ - ("utf-8", b"\xe2\x98\x83".decode("utf-8")), - ("utf-32", b"\xe2\x98\x83"), - ], - ids=["utf-8", "utf-32"], -) -def test_bulk_string_with_encoding(encoding, expected): - snowman = b"\xe2\x98\x83" - reader = PyReader(encoding=encoding) - reader.feed(b"$3\r\n" + snowman + b"\r\n") - assert reader.gets() == expected - - -def test_bulk_string_with_invalid_encoding(): - reader = PyReader(encoding="unknown") - reader.feed(b"$5\r\nhello\r\n") - with pytest.raises(LookupError): - reader.gets() - - -def test_bulk_string_wait_buffer(reader): - reader.feed(b"$5\r\nH") - assert not reader.gets() - reader.feed(b"ello") - assert not reader.gets() - reader.feed(b"\r\n") - assert reader.gets() == b"Hello" - - -@pytest.mark.parametrize( - "data,expected", - [ - (b"*-1\r\n", None), - (b"*0\r\n", []), - (b"*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n", [b"hello", b"world"]), - ], - ids=["Null", "Empty list", "hello world"], -) -def test_null_multi_bulk(reader, data, expected): - reader.feed(data) - assert reader.gets() == expected - - -@pytest.mark.parametrize( - "data", - [ - (b"*2\r\n$5\r\nhello\r\n", b":1"), - (b"*2\r\n:1\r\n*1\r\n", b"+hello"), - (b"*2\r\n+hello\r\n+world",), - (b"*2\r\n*1\r\n+hello\r\n*1\r\n+world",), - ], - ids=["First in bulk", "Error in nested", "Multiple errors", "Multiple nested"], -) -def test_multi_bulk_with_invalid_encoding_and_partial_reply(data): - reader = PyReader(encoding="unknown") - for chunk in data: - reader.feed(chunk) - assert reader.gets() is False - 
reader.feed(b"\r\n") - with pytest.raises(LookupError): - reader.gets() - - reader.feed(b":1\r\n") - assert reader.gets() == 1 - - -def test_nested_multi_bulk(reader): - reader.feed(b"*2\r\n*2\r\n$5\r\nhello\r\n$5\r\nworld\r\n$1\r\n!\r\n") - assert reader.gets() == [[b"hello", b"world"], b"!"] - - -def test_nested_multi_bulk_depth(reader): - reader.feed(b"*1\r\n*1\r\n*1\r\n*1\r\n$1\r\n!\r\n") - assert reader.gets() == [[[[b"!"]]]] - - -@pytest.mark.parametrize( - "encoding,expected", - [ - ("utf-8", b"\xe2\x98\x83".decode("utf-8")), - ("utf-32", b"\xe2\x98\x83"), - ], - ids=["utf-8", "utf-32"], -) -def test_simple_string_with_encoding(encoding, expected): - snowman = b"\xe2\x98\x83" - reader = PyReader(encoding=encoding) - reader.feed(b"+" + snowman + b"\r\n") - assert reader.gets() == expected - - -def test_invalid_offset(reader): - data = b"+ok\r\n" - with pytest.raises(ValueError): - reader.feed(data, 6) - - -def test_invalid_length(reader): - data = b"+ok\r\n" - with pytest.raises(ValueError): - reader.feed(data, 0, 6) - - -def test_ok_offset(reader): - data = b"blah+ok\r\n" - reader.feed(data, 4) - assert reader.gets() == b"ok" - - -def test_ok_length(reader): - data = b"blah+ok\r\n" - reader.feed(data, 4, len(data) - 4) - assert reader.gets() == b"ok" - - -@pytest.mark.xfail() -def test_maxbuf(reader): - defaultmaxbuf = reader.getmaxbuf() - reader.setmaxbuf(0) - assert 0 == reader.getmaxbuf() - reader.setmaxbuf(10000) - assert 10000 == reader.getmaxbuf() - reader.setmaxbuf(None) - assert defaultmaxbuf == reader.getmaxbuf() - with pytest.raises(ValueError): - reader.setmaxbuf(-4) diff --git a/tests/scripting_commands_test.py b/tests/scripting_commands_test.py deleted file mode 100644 index 263ee83ed..000000000 --- a/tests/scripting_commands_test.py +++ /dev/null @@ -1,126 +0,0 @@ -import asyncio - -import pytest - -from aioredis import ReplyError - - -@pytest.mark.asyncio -async def test_eval(redis): - await redis.delete("key:eval", "value:eval") - - script = "return 42" - res = await redis.eval(script) - assert res == 42 - - key, value = b"key:eval", b"value:eval" - script = """ - if redis.call('setnx', KEYS[1], ARGV[1]) == 1 - then - return 'foo' - else - return 'bar' - end - """ - res = await redis.eval(script, keys=[key], args=[value]) - assert res == b"foo" - res = await redis.eval(script, keys=[key], args=[value]) - assert res == b"bar" - - script = "return 42" - with pytest.raises(TypeError): - await redis.eval(script, keys="not:list") - - with pytest.raises(TypeError): - await redis.eval(script, keys=["valid", None]) - with pytest.raises(TypeError): - await redis.eval(script, args=["valid", None]) - with pytest.raises(TypeError): - await redis.eval(None) - - -@pytest.mark.asyncio -async def test_evalsha(redis): - script = b"return 42" - sha_hash = await redis.script_load(script) - assert len(sha_hash) == 40 - res = await redis.evalsha(sha_hash) - assert res == 42 - - key, arg1, arg2 = b"key:evalsha", b"1", b"2" - script = "return {KEYS[1], ARGV[1], ARGV[2]}" - sha_hash = await redis.script_load(script) - res = await redis.evalsha(sha_hash, [key], [arg1, arg2]) - assert res == [key, arg1, arg2] - - with pytest.raises(ReplyError): - await redis.evalsha(b"wrong sha hash") - with pytest.raises(TypeError): - await redis.evalsha(sha_hash, keys=["valid", None]) - with pytest.raises(TypeError): - await redis.evalsha(sha_hash, args=["valid", None]) - with pytest.raises(TypeError): - await redis.evalsha(None) - - -@pytest.mark.asyncio -async def test_script_exists(redis): - sha_hash1 = 
await redis.script_load(b"return 1") - sha_hash2 = await redis.script_load(b"return 2") - assert len(sha_hash1) == 40 - assert len(sha_hash2) == 40 - - res = await redis.script_exists(sha_hash1, sha_hash1) - assert res == [1, 1] - - no_sha = b"ffffffffffffffffffffffffffffffffffffffff" - res = await redis.script_exists(no_sha) - assert res == [0] - - with pytest.raises(TypeError): - await redis.script_exists(None) - with pytest.raises(TypeError): - await redis.script_exists("123", None) - - -@pytest.mark.asyncio -async def test_script_flush(redis): - sha_hash1 = await redis.script_load(b"return 1") - assert len(sha_hash1) == 40 - res = await redis.script_exists(sha_hash1) - assert res == [1] - res = await redis.script_flush() - assert res is True - res = await redis.script_exists(sha_hash1) - assert res == [0] - - -@pytest.mark.asyncio -async def test_script_load(redis): - sha_hash1 = await redis.script_load(b"return 1") - sha_hash2 = await redis.script_load(b"return 2") - assert len(sha_hash1) == 40 - assert len(sha_hash2) == 40 - res = await redis.script_exists(sha_hash1, sha_hash1) - assert res == [1, 1] - - -@pytest.mark.asyncio -async def test_script_kill(create_redis, server, redis): - script = "while (1) do redis.call('TIME') end" - - other_redis = await create_redis(server.tcp_address) - - ok = await redis.set("key1", "value") - assert ok is True - - fut = other_redis.eval(script, keys=["non-existent-key"], args=[10]) - await asyncio.sleep(0.1) - resp = await redis.script_kill() - assert resp is True - - with pytest.raises(ReplyError): - await fut - - with pytest.raises(ReplyError): - await redis.script_kill() diff --git a/tests/sentinel_commands_test.py b/tests/sentinel_commands_test.py deleted file mode 100644 index 8ad5e0bf6..000000000 --- a/tests/sentinel_commands_test.py +++ /dev/null @@ -1,305 +0,0 @@ -import asyncio -import logging -import sys - -import pytest - -from aioredis import PoolClosedError, RedisError, ReplyError -from aioredis.abc import AbcPool -from aioredis.errors import MasterReplyError -from aioredis.sentinel.commands import RedisSentinel -from tests.testutils import redis_version - -pytestmark = redis_version(2, 8, 12, reason="Sentinel v2 required") -if sys.platform == "win32": - pytestmark = pytest.mark.skip(reason="unstable on windows") - -BPO_30399 = sys.version_info >= (3, 7, 0, "alpha", 3) - - -@pytest.mark.asyncio -async def test_client_close(redis_sentinel): - assert isinstance(redis_sentinel, RedisSentinel) - assert not redis_sentinel.closed - - redis_sentinel.close() - assert redis_sentinel.closed - with pytest.raises(PoolClosedError): - assert (await redis_sentinel.ping()) != b"PONG" - - await redis_sentinel.wait_closed() - - -@pytest.mark.asyncio -async def test_ping(redis_sentinel): - assert b"PONG" == (await redis_sentinel.ping()) - - -@pytest.mark.asyncio -async def test_master_info(redis_sentinel, sentinel): - info = await redis_sentinel.master("main-no-fail") - assert isinstance(info, dict) - assert info["name"] == "main-no-fail" - assert "slave" not in info["flags"] - assert "s_down" not in info["flags"] - assert "o_down" not in info["flags"] - assert "sentinel" not in info["flags"] - assert "disconnected" not in info["flags"] - assert "master" in info["flags"] - - for key in [ - "num-other-sentinels", - "flags", - "quorum", - "ip", - "failover-timeout", - "runid", - "info-refresh", - "config-epoch", - "parallel-syncs", - "role-reported-time", - "last-ok-ping-reply", - "last-ping-reply", - "last-ping-sent", - "name", - 
"down-after-milliseconds", - "num-slaves", - "port", - "role-reported", - ]: - assert key in info - if sentinel.version < (3, 2, 0): - assert "pending-commands" in info - else: - assert "link-pending-commands" in info - assert "link-refcount" in info - - -@pytest.mark.asyncio -async def test_master__auth(create_sentinel, start_sentinel, start_server): - main = start_server("main_1", password="123") - start_server("replica_1", slaveof=main, password="123") - - sentinel = start_sentinel("auth_sentinel_1", main) - client1 = await create_sentinel([sentinel.tcp_address], password="123", timeout=1) - - client2 = await create_sentinel([sentinel.tcp_address], password="111", timeout=1) - - client3 = await create_sentinel([sentinel.tcp_address], timeout=1) - - m1 = client1.master_for(main.name) - await m1.set("mykey", "myval") - - with pytest.raises(MasterReplyError) as exc_info: - m2 = client2.master_for(main.name) - await m2.set("mykey", "myval") - if BPO_30399: - expected = "('Service main_1 error', AuthError('ERR invalid password'))" - else: - expected = "('Service main_1 error', AuthError('ERR invalid password',))" - assert str(exc_info.value) == expected - - with pytest.raises(MasterReplyError): - m3 = client3.master_for(main.name) - await m3.set("mykey", "myval") - - -@pytest.mark.asyncio -async def test_master__no_auth(create_sentinel, sentinel): - client = await create_sentinel([sentinel.tcp_address], password="123", timeout=1) - - main = client.master_for("mainA") - with pytest.raises(MasterReplyError): - await main.set("mykey", "myval") - - -@pytest.mark.asyncio -async def test_master__unknown(redis_sentinel): - with pytest.raises(ReplyError): - await redis_sentinel.master("unknown-main") - - -@pytest.mark.asyncio -async def test_master_address(redis_sentinel, sentinel): - _, port = await redis_sentinel.master_address("main-no-fail") - assert port == sentinel.masters["main-no-fail"].tcp_address.port - - -@pytest.mark.asyncio -async def test_master_address__unknown(redis_sentinel): - res = await redis_sentinel.master_address("unknown-main") - assert res is None - - -@pytest.mark.asyncio -async def test_masters(redis_sentinel): - masters = await redis_sentinel.masters() - assert isinstance(masters, dict) - assert len(masters) >= 1, "At least on masters expected" - assert "main-no-fail" in masters - assert isinstance(masters["main-no-fail"], dict) - - -@pytest.mark.asyncio -async def test_slave_info(sentinel, redis_sentinel): - info = await redis_sentinel.slaves("main-no-fail") - assert len(info) == 1 - info = info[0] - assert isinstance(info, dict) - assert "master" not in info["flags"] - assert "s_down" not in info["flags"] - assert "o_down" not in info["flags"] - assert "sentinel" not in info["flags"] - # assert 'disconnected' not in info['flags'] - assert "slave" in info["flags"] - - keys_set = { - "flags", - "master-host", - "master-link-down-time", - "master-link-status", - "master-port", - "name", - "slave-priority", - "ip", - "runid", - "info-refresh", - "role-reported-time", - "last-ok-ping-reply", - "last-ping-reply", - "last-ping-sent", - "down-after-milliseconds", - "port", - "role-reported", - } - if sentinel.version < (3, 2, 0): - keys_set.add("pending-commands") - else: - keys_set.add("link-pending-commands") - keys_set.add("link-refcount") - - missing = keys_set - set(info) - assert not missing - - -@pytest.mark.asyncio -async def test_slave__unknown(redis_sentinel): - with pytest.raises(ReplyError): - await redis_sentinel.slaves("unknown-main") - - -@pytest.mark.asyncio 
-async def test_sentinels_empty(redis_sentinel): - res = await redis_sentinel.sentinels("main-no-fail") - assert res == [] - - with pytest.raises(ReplyError): - await redis_sentinel.sentinels("unknown-main") - - -@pytest.mark.timeout(30) -@pytest.mark.asyncio -async def test_sentinels__exist(create_sentinel, start_sentinel, start_server): - m1 = start_server("main-two-sentinels") - s1 = start_sentinel("peer-sentinel-1", m1, quorum=2, noslaves=True) - s2 = start_sentinel("peer-sentinel-2", m1, quorum=2, noslaves=True) - - redis_sentinel = await create_sentinel([s1.tcp_address, s2.tcp_address], timeout=1) - - while True: - info = await redis_sentinel.master("main-two-sentinels") - if info["num-other-sentinels"] > 0: - break - await asyncio.sleep(0.2) - info = await redis_sentinel.sentinels("main-two-sentinels") - assert len(info) == 1 - assert "sentinel" in info[0]["flags"] - assert info[0]["port"] in (s1.tcp_address.port, s2.tcp_address.port) - - -@pytest.mark.asyncio -async def test_ckquorum(redis_sentinel): - assert await redis_sentinel.check_quorum("main-no-fail") - - # change quorum - - assert await redis_sentinel.set("main-no-fail", "quorum", 2) - - with pytest.raises(RedisError): - await redis_sentinel.check_quorum("main-no-fail") - - assert await redis_sentinel.set("main-no-fail", "quorum", 1) - assert await redis_sentinel.check_quorum("main-no-fail") - - -@pytest.mark.asyncio -async def test_set_option(redis_sentinel): - assert await redis_sentinel.set("main-no-fail", "quorum", 10) - main = await redis_sentinel.master("main-no-fail") - assert main["quorum"] == 10 - - assert await redis_sentinel.set("main-no-fail", "quorum", 1) - main = await redis_sentinel.master("main-no-fail") - assert main["quorum"] == 1 - - with pytest.raises(ReplyError): - await redis_sentinel.set("mainA", "foo", "bar") - - -@pytest.mark.asyncio -async def test_sentinel_role(sentinel, create_redis): - redis = await create_redis(sentinel.tcp_address) - info = await redis.role() - assert info.role == "sentinel" - assert isinstance(info.masters, list) - assert "main-no-fail" in info.masters - - -@pytest.mark.timeout(30) -@pytest.mark.asyncio -async def test_remove(redis_sentinel, start_server): - m1 = start_server("main-to-remove") - ok = await redis_sentinel.monitor(m1.name, "127.0.0.1", m1.tcp_address.port, 1) - assert ok - - ok = await redis_sentinel.remove(m1.name) - assert ok - - with pytest.raises(ReplyError): - await redis_sentinel.remove("unknown-main") - - -@pytest.mark.timeout(30) -@pytest.mark.asyncio -async def test_monitor(redis_sentinel, start_server, unused_tcp_port): - m1 = start_server("main-to-monitor") - ok = await redis_sentinel.monitor(m1.name, "127.0.0.1", m1.tcp_address.port, 1) - assert ok - - _, port = await redis_sentinel.master_address("main-to-monitor") - assert port == m1.tcp_address.port - - -@pytest.mark.timeout(5) -@pytest.mark.asyncio -async def test_sentinel_master_pool_size(sentinel, create_sentinel, caplog): - redis_s = await create_sentinel( - [sentinel.tcp_address], timeout=1, minsize=10, maxsize=10 - ) - main = redis_s.master_for("main-no-fail") - assert isinstance(main.connection, AbcPool) - assert main.connection.size == 0 - - caplog.clear() - with caplog.at_level("DEBUG", "aioredis.sentinel"): - assert await main.ping() - assert len(caplog.record_tuples) == 1 - assert caplog.record_tuples == [ - ( - "aioredis.sentinel", - logging.DEBUG, - f"Discoverred new address {main.address} for main-no-fail", - ), - ] - assert main.connection.size == 10 - assert 
main.connection.freesize == 10 diff --git a/tests/sentinel_failover_test.py b/tests/sentinel_failover_test.py deleted file mode 100644 index a6f86e5d7..000000000 --- a/tests/sentinel_failover_test.py +++ /dev/null @@ -1,217 +0,0 @@ -import asyncio -import sys - -import pytest - -from aioredis import ReadOnlyError, SlaveNotFoundError -from tests.testutils import redis_version - -pytestmark = redis_version(2, 8, 12, reason="Sentinel v2 required") -if sys.platform == "win32": - pytestmark = pytest.mark.skip(reason="unstable on windows") - - -@pytest.mark.timeout(40) -@pytest.mark.asyncio -async def test_auto_failover( - start_sentinel, start_server, create_sentinel, create_connection -): - server1 = start_server("master-failover", ["slave-read-only yes"]) - start_server("slave-failover1", ["slave-read-only yes"], slaveof=server1) - start_server("slave-failover2", ["slave-read-only yes"], slaveof=server1) - - sentinel1 = start_sentinel( - "sentinel-failover1", - server1, - quorum=2, - down_after_milliseconds=300, - failover_timeout=1000, - ) - sentinel2 = start_sentinel( - "sentinel-failover2", - server1, - quorum=2, - down_after_milliseconds=300, - failover_timeout=1000, - ) - # Wait a bit for sentinels to sync - await asyncio.sleep(3) - - sp = await create_sentinel( - [sentinel1.tcp_address, sentinel2.tcp_address], timeout=1 - ) - - _, old_port = await sp.master_address(server1.name) - # ignoring host - assert old_port == server1.tcp_address.port - master = sp.master_for(server1.name) - res = await master.role() - assert res.role == "master" - assert master.address is not None - assert master.address[1] == old_port - - # wait failover - conn = await create_connection(server1.tcp_address) - await conn.execute("debug", "sleep", 5) - - # _, new_port = await sp.master_address(server1.name) - # assert new_port != old_port - # assert new_port == server2.tcp_address.port - assert await master.set("key", "val") - assert master.address is not None - assert master.address[1] != old_port - - -@pytest.mark.asyncio -async def test_sentinel_normal(sentinel, create_sentinel): - redis_sentinel = await create_sentinel([sentinel.tcp_address], timeout=1) - redis = redis_sentinel.master_for("mainA") - - info = await redis.role() - assert info.role == "master" - - key, field, value = b"key:hset", b"bar", b"zap" - exists = await redis.hexists(key, field) - if exists: - ret = await redis.hdel(key, field) - assert ret != 1 - - ret = await redis.hset(key, field, value) - assert ret == 1 - ret = await redis.hset(key, field, value) - assert ret == 0 - - -@pytest.mark.xfail(reason="same sentinel; single master;") -@pytest.mark.asyncio -async def test_sentinel_slave(sentinel, create_sentinel): - redis_sentinel = await create_sentinel([sentinel.tcp_address], timeout=1) - redis = redis_sentinel.slave_for("mainA") - - info = await redis.role() - assert info.role == "slave" - - key, field, value = b"key:hset", b"bar", b"zap" - # redis = await get_slave_connection() - exists = await redis.hexists(key, field) - if exists: - with pytest.raises(ReadOnlyError): - await redis.hdel(key, field) - - with pytest.raises(ReadOnlyError): - await redis.hset(key, field, value) - - -@pytest.mark.xfail(reason="Need proper sentinel configuration") -@pytest.mark.asyncio -async def test_sentinel_slave_fail(sentinel, create_sentinel): - redis_sentinel = await create_sentinel([sentinel.tcp_address], timeout=1) - - key, field, value = b"key:hset", b"bar", b"zap" - - redis = redis_sentinel.slave_for("mainA") - exists = await redis.hexists(key, 
field) - if exists: - with pytest.raises(ReadOnlyError): - await redis.hdel(key, field) - - with pytest.raises(ReadOnlyError): - await redis.hset(key, field, value) - - ret = await redis_sentinel.failover("mainA") - assert ret is True - await asyncio.sleep(2) - - with pytest.raises(ReadOnlyError): - await redis.hset(key, field, value) - - ret = await redis_sentinel.failover("mainA") - assert ret is True - await asyncio.sleep(2) - while True: - try: - await asyncio.sleep(1) - await redis.hset(key, field, value) - except SlaveNotFoundError: - continue - except ReadOnlyError: - break - - -@pytest.mark.xfail(reason="Need proper sentinel configuration") -@pytest.mark.asyncio -async def test_sentinel_normal_fail(sentinel, create_sentinel): - redis_sentinel = await create_sentinel([sentinel.tcp_address], timeout=1) - - key, field, value = b"key:hset", b"bar", b"zap" - redis = redis_sentinel.master_for("mainA") - exists = await redis.hexists(key, field) - if exists: - ret = await redis.hdel(key, field) - assert ret == 1 - - ret = await redis.hset(key, field, value) - assert ret == 1 - ret = await redis_sentinel.failover("mainA") - assert ret is True - await asyncio.sleep(2) - ret = await redis.hset(key, field, value) - assert ret == 0 - ret = await redis_sentinel.failover("mainA") - assert ret is True - await asyncio.sleep(2) - redis = redis_sentinel.slave_for("mainA") - while True: - try: - await redis.hset(key, field, value) - await asyncio.sleep(1) - # redis = await get_slave_connection() - except ReadOnlyError: - break - - -@pytest.mark.timeout(30) -@pytest.mark.asyncio -async def test_failover_command(start_server, start_sentinel, create_sentinel): - server = start_server("master-failover-cmd", ["slave-read-only yes"]) - start_server("slave-failover-cmd", ["slave-read-only yes"], slaveof=server) - - sentinel = start_sentinel( - "sentinel-failover-cmd", - server, - quorum=1, - down_after_milliseconds=300, - failover_timeout=1000, - ) - - name = "master-failover-cmd" - redis_sentinel = await create_sentinel([sentinel.tcp_address], timeout=1) - # Wait a bit for sentinels to sync - await asyncio.sleep(3) - - orig_master = await redis_sentinel.master_address(name) - assert await redis_sentinel.failover(name) is True - await asyncio.sleep(2) - - new_master = await redis_sentinel.master_address(name) - assert orig_master != new_master - - ret = await redis_sentinel.failover(name) - assert ret is True - await asyncio.sleep(2) - - new_master = await redis_sentinel.master_address(name) - assert orig_master == new_master - - # This part takes almost 10 seconds (waiting for '+convert-to-slave'). - # Disabled for time being. 
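# The failover tests above wait with fixed sleep() calls, which is also why the polling block below stayed disabled. A timeout-bounded poll over the same 1.x sentinel API these tests use is sketched here; wait_for_new_master is a hypothetical helper, not part of aioredis:

import asyncio

async def wait_for_new_master(redis_sentinel, name, old_address, timeout=30.0):
    # Poll SENTINEL get-master-addr-by-name until a new address is advertised.
    loop = asyncio.get_event_loop()
    deadline = loop.time() + timeout
    while loop.time() < deadline:
        address = await redis_sentinel.master_address(name)
        if address is not None and address != old_address:
            return address
        await asyncio.sleep(0.2)
    raise TimeoutError(f"{name}: no new master advertised within {timeout} seconds")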
- - # redis = redis_sentinel.slave_for(name) - # while True: - # try: - # await asyncio.sleep(.2) - # await redis.set('foo', 'bar') - # except SlaveNotFoundError: - # pass - # except ReadOnlyError: - # break diff --git a/tests/server_commands_test.py b/tests/server_commands_test.py deleted file mode 100644 index 06e46c4cb..000000000 --- a/tests/server_commands_test.py +++ /dev/null @@ -1,314 +0,0 @@ -import sys -import time -from unittest import mock - -import pytest - -from aioredis import ReplyError -from tests.testutils import redis_version - - -@pytest.mark.asyncio -async def test_client_list(redis, server, request): - name = request.node.callspec.id - assert await redis.client_setname(name) - res = await redis.client_list() - assert isinstance(res, list) - res = [dict(i._asdict()) for i in res] - expected = { - "addr": mock.ANY, - "fd": mock.ANY, - "age": mock.ANY, - "idle": mock.ANY, - "flags": "N", - "db": "0", - "sub": "0", - "psub": "0", - "multi": "-1", - "qbuf": "0", - "qbuf_free": mock.ANY, - "obl": "0", - "oll": "0", - "omem": "0", - "events": "r", - "cmd": "client", - "name": name, - } - if server.version >= (2, 8, 12): - expected["id"] = mock.ANY - if server.version >= (5,): - expected["qbuf"] = "26" - assert expected in res - - -@pytest.mark.skipif(sys.platform == "win32", reason="No unixsocket on Windows") -@pytest.mark.asyncio -async def test_client_list__unixsocket(create_redis, server, request): - redis = await create_redis(server.unixsocket) - name = request.node.callspec.id - assert await redis.client_setname(name) - res = await redis.client_list() - info = [dict(i._asdict()) for i in res] - expected = { - "addr": f"{server.unixsocket}:0", - "fd": mock.ANY, - "age": mock.ANY, - "idle": mock.ANY, - "flags": "U", # Connected via unix socket - "db": "0", - "sub": "0", - "psub": "0", - "multi": "-1", - "qbuf": "0", - "qbuf_free": mock.ANY, - "obl": "0", - "oll": "0", - "omem": "0", - "events": "r", - "cmd": "client", - "name": name, - } - if server.version >= (2, 8, 12): - expected["id"] = mock.ANY - if server.version >= (5,): - expected["qbuf"] = "26" - assert expected in info - - -@redis_version(2, 9, 50, reason="CLIENT PAUSE is available since redis >= 2.9.50") -@pytest.mark.asyncio -async def test_client_pause(redis): - tr = redis.pipeline() - tr.time() - tr.client_pause(100) - tr.time() - t1, ok, t2 = await tr.execute() - assert ok - assert t2 - t1 >= 0.1 - - with pytest.raises(TypeError): - await redis.client_pause(2.0) - with pytest.raises(ValueError): - await redis.client_pause(-1) - - -@pytest.mark.asyncio -async def test_client_getname(redis): - res = await redis.client_getname() - assert res is None - ok = await redis.client_setname("TestClient") - assert ok is True - - res = await redis.client_getname() - assert res == b"TestClient" - res = await redis.client_getname(encoding="utf-8") - assert res == "TestClient" - - -@redis_version(2, 8, 13, reason="available since Redis 2.8.13") -@pytest.mark.asyncio -async def test_command(redis): - res = await redis.command() - assert isinstance(res, list) - assert len(res) > 0 - - -@redis_version(2, 8, 13, reason="available since Redis 2.8.13") -@pytest.mark.asyncio -async def test_command_count(redis): - res = await redis.command_count() - assert res > 0 - - -@redis_version(3, 0, 0, reason="available since Redis 3.0.0") -@pytest.mark.asyncio -async def test_command_getkeys(redis): - res = await redis.command_getkeys("get", "key") - assert res == ["key"] - res = await redis.command_getkeys("get", "key", encoding=None) -
assert res == [b"key"] - res = await redis.command_getkeys("mset", "k1", "v1", "k2", "v2") - assert res == ["k1", "k2"] - res = await redis.command_getkeys("mset", "k1", "v1", "k2") - assert res == ["k1", "k2"] - - with pytest.raises(ReplyError): - assert await redis.command_getkeys("get") - with pytest.raises(TypeError): - assert not (await redis.command_getkeys(None)) - - -@redis_version(2, 8, 13, reason="available since Redis 2.8.13") -@pytest.mark.asyncio -async def test_command_info(redis): - res = await redis.command_info("get") - assert res == [ - ["get", 2, ["readonly", "fast"], 1, 1, 1], - ] - - res = await redis.command_info("unknown-command") - assert res == [None] - res = await redis.command_info("unknown-command", "unknown-commnad") - assert res == [None, None] - - -@pytest.mark.asyncio -async def test_config_get(redis, server): - res = await redis.config_get("port") - assert res == {"port": str(server.tcp_address.port)} - - res = await redis.config_get() - assert len(res) > 0 - - res = await redis.config_get("unknown_parameter") - assert res == {} - - with pytest.raises(TypeError): - await redis.config_get(b"port") - - -@pytest.mark.asyncio -async def test_config_rewrite(redis): - with pytest.raises(ReplyError): - await redis.config_rewrite() - - -@pytest.mark.asyncio -async def test_config_set(redis): - cur_value = await redis.config_get("slave-read-only") - res = await redis.config_set("slave-read-only", "no") - assert res is True - res = await redis.config_set("slave-read-only", cur_value["slave-read-only"]) - assert res is True - - with pytest.raises(ReplyError, match="Unsupported CONFIG parameter"): - await redis.config_set("databases", 100) - with pytest.raises(TypeError): - await redis.config_set(100, "databases") - - -# @pytest.mark.skip("Not implemented") -# def test_config_resetstat(): -# pass - - -@pytest.mark.asyncio -async def test_debug_object(redis): - with pytest.raises(ReplyError): - assert (await redis.debug_object("key")) is None - - ok = await redis.set("key", "value") - assert ok - res = await redis.debug_object("key") - assert res is not None - - -@pytest.mark.asyncio -async def test_debug_sleep(redis): - t1 = await redis.time() - ok = await redis.debug_sleep(0.2) - assert ok - t2 = await redis.time() - assert t2 - t1 >= 0.2 - - -@pytest.mark.asyncio -async def test_dbsize(redis): - res = await redis.dbsize() - assert res == 0 - - await redis.set("key", "value") - - res = await redis.dbsize() - assert res > 0 - - await redis.flushdb() - res = await redis.dbsize() - assert res == 0 - await redis.set("key", "value") - res = await redis.dbsize() - assert res == 1 - - -@pytest.mark.asyncio -async def test_info(redis): - res = await redis.info() - assert isinstance(res, dict) - - res = await redis.info("all") - assert isinstance(res, dict) - - with pytest.raises(ValueError): - await redis.info("") - - -@pytest.mark.asyncio -async def test_lastsave(redis): - res = await redis.lastsave() - assert res > 0 - - -@redis_version(2, 8, 12, reason="ROLE is available since redis>=2.8.12") -@pytest.mark.asyncio -async def test_role(redis): - res = await redis.role() - assert dict(res._asdict()) == { - "role": "master", - "replication_offset": mock.ANY, - "slaves": [], - } - - -@pytest.mark.asyncio -async def test_save(redis): - res = await redis.dbsize() - assert res == 0 - t1 = await redis.lastsave() - ok = await redis.save() - assert ok - t2 = await redis.lastsave() - assert t2 >= t1 - - -@pytest.mark.parametrize( - "encoding", - [ - pytest.param(None, id="no 
decoding"), - pytest.param("utf-8", id="with decoding"), - ], -) -@pytest.mark.asyncio -async def test_time(create_redis, server, encoding): - redis = await create_redis(server.tcp_address, encoding="utf-8") - now = time.time() - res = await redis.time() - assert isinstance(res, float) - assert res == pytest.approx(now, abs=10) - - -@pytest.mark.asyncio -async def test_slowlog_len(redis): - res = await redis.slowlog_len() - assert res >= 0 - - -@pytest.mark.asyncio -async def test_slowlog_get(redis): - res = await redis.slowlog_get() - assert isinstance(res, list) - assert len(res) >= 0 - - res = await redis.slowlog_get(2) - assert isinstance(res, list) - assert 0 <= len(res) <= 2 - - with pytest.raises(TypeError): - assert not (await redis.slowlog_get(1.2)) - with pytest.raises(TypeError): - assert not (await redis.slowlog_get("1")) - - -@pytest.mark.asyncio -async def test_slowlog_reset(redis): - ok = await redis.slowlog_reset() - assert ok is True diff --git a/tests/set_commands_test.py b/tests/set_commands_test.py deleted file mode 100644 index f95059296..000000000 --- a/tests/set_commands_test.py +++ /dev/null @@ -1,517 +0,0 @@ -import pytest - -from aioredis import ReplyError -from tests.testutils import redis_version - - -async def add(redis, key, members): - ok = await redis.connection.execute(b"sadd", key, members) - assert ok == 1 - - -@pytest.mark.asyncio -async def test_sadd(redis): - key, member = b"key:sadd", b"hello" - # add member to the set, expected result: 1 - test_result = await redis.sadd(key, member) - assert test_result == 1 - - # add other value, expected result: 1 - test_result = await redis.sadd(key, b"world") - assert test_result == 1 - - # add existing member to the set, expected result: 0 - test_result = await redis.sadd(key, member) - assert test_result == 0 - - with pytest.raises(TypeError): - await redis.sadd(None, 10) - - -@pytest.mark.asyncio -async def test_scard(redis): - key, member = b"key:scard", b"hello" - - # check that our set is empty one - empty_size = await redis.scard(key) - assert empty_size == 0 - - # add more members to the set and check, set size on every step - for i in range(1, 11): - incr = str(i).encode("utf-8") - await add(redis, key, member + incr) - current_size = await redis.scard(key) - assert current_size == i - - with pytest.raises(TypeError): - await redis.scard(None) - - -@pytest.mark.asyncio -async def test_sdiff(redis): - key1 = b"key:sdiff:1" - key2 = b"key:sdiff:2" - key3 = b"key:sdiff:3" - - members1 = (b"a", b"b", b"c", b"d") - members2 = (b"c",) - members3 = (b"a", b"c", b"e") - - await redis.sadd(key1, *members1) - await redis.sadd(key2, *members2) - await redis.sadd(key3, *members3) - - # test multiple keys - test_result = await redis.sdiff(key1, key2, key3) - assert set(test_result) == {b"b", b"d"} - - # test single key - test_result = await redis.sdiff(key2) - assert set(test_result) == {b"c"} - - with pytest.raises(TypeError): - await redis.sdiff(None) - with pytest.raises(TypeError): - await redis.sdiff(key1, None) - - -@pytest.mark.asyncio -async def test_sdiffstore(redis): - key1 = b"key:sdiffstore:1" - key2 = b"key:sdiffstore:2" - destkey = b"key:sdiffstore:destkey" - members1 = (b"a", b"b", b"c") - members2 = (b"c", b"d", b"e") - - await redis.sadd(key1, *members1) - await redis.sadd(key2, *members2) - - # test basic use case, expected: since diff contains only two members - test_result = await redis.sdiffstore(destkey, key1, key2) - assert test_result == 2 - - # make sure that destkey contains 2 members - 
test_result = await redis.scard(destkey) - assert test_result == 2 - - # try sdiffstore in case none of the sets exists - test_result = await redis.sdiffstore( - b"not:" + destkey, b"not:" + key1, b"not:" + key2 - ) - assert test_result == 0 - - with pytest.raises(TypeError): - await redis.sdiffstore(None, key1) - with pytest.raises(TypeError): - await redis.sdiffstore(destkey, None) - with pytest.raises(TypeError): - await redis.sdiffstore(destkey, key1, None) - - -@pytest.mark.asyncio -async def test_sinter(redis): - key1 = b"key:sinter:1" - key2 = b"key:sinter:2" - key3 = b"key:sinter:3" - - members1 = (b"a", b"b", b"c", b"d") - members2 = (b"c",) - members3 = (b"a", b"c", b"e") - - await redis.sadd(key1, *members1) - await redis.sadd(key2, *members2) - await redis.sadd(key3, *members3) - - # test multiple keys - test_result = await redis.sinter(key1, key2, key3) - assert set(test_result) == {b"c"} - - # test single key - test_result = await redis.sinter(key2) - assert set(test_result) == {b"c"} - - with pytest.raises(TypeError): - await redis.sinter(None) - with pytest.raises(TypeError): - await redis.sinter(key1, None) - - -@pytest.mark.asyncio -async def test_sinterstore(redis): - key1 = b"key:sinterstore:1" - key2 = b"key:sinterstore:2" - destkey = b"key:sinterstore:destkey" - members1 = (b"a", b"b", b"c") - members2 = (b"c", b"d", b"e") - - await redis.sadd(key1, *members1) - await redis.sadd(key2, *members2) - - # test basic use case; the resulting intersection contains only one member - test_result = await redis.sinterstore(destkey, key1, key2) - assert test_result == 1 - - # make sure that destkey contains only one member - test_result = await redis.scard(destkey) - assert test_result == 1 - - # try sinterstore in case none of the sets exists - test_result = await redis.sinterstore( - b"not:" + destkey, b"not:" + key1, b"not:" + key2 - ) - assert test_result == 0 - - with pytest.raises(TypeError): - await redis.sinterstore(None, key1) - with pytest.raises(TypeError): - await redis.sinterstore(destkey, None) - with pytest.raises(TypeError): - await redis.sinterstore(destkey, key1, None) - - -@pytest.mark.asyncio -async def test_sismember(redis): - key, member = b"key:sismember", b"hello" - # add member to the set, expected result: 1 - test_result = await redis.sadd(key, member) - assert test_result == 1 - - # test that the value is in the set - test_result = await redis.sismember(key, member) - assert test_result == 1 - # test that the value is not in the set - test_result = await redis.sismember(key, b"world") - assert test_result == 0 - - with pytest.raises(TypeError): - await redis.sismember(None, b"world") - - -@pytest.mark.asyncio -async def test_smembers(redis): - key = b"key:smembers" - member1 = b"hello" - member2 = b"world" - - await redis.sadd(key, member1) - await redis.sadd(key, member2) - - # test a non-empty set - test_result = await redis.smembers(key) - assert set(test_result) == {member1, member2} - - # test an empty set - test_result = await redis.smembers(b"not:" + key) - assert test_result == [] - - # test encoding param - test_result = await redis.smembers(key, encoding="utf-8") - assert set(test_result) == {"hello", "world"} - - with pytest.raises(TypeError): - await redis.smembers(None) - - -@pytest.mark.asyncio -async def test_smove(redis): - key1 = b"key:smove:1" - key2 = b"key:smove:2" - member1 = b"one" - member2 = b"two" - member3 = b"three" - await redis.sadd(key1, member1, member2) - await redis.sadd(key2, member3) - # move member2 to second set - test_result = await redis.smove(key1,
key2, member2) - assert test_result == 1 - # check first set, member should be removed - test_result = await redis.smembers(key1) - assert test_result == [member1] - # check second set, member should be added - test_result = await redis.smembers(key2) - assert set(test_result) == {member2, member3} - - # move to empty set - test_result = await redis.smove(key1, b"not:" + key2, member1) - assert test_result == 1 - - # move from an empty set (the set under key1 is empty now) - test_result = await redis.smove(key1, b"not:" + key2, member1) - assert test_result == 0 - - # move from a set that does not exist to a set that does not exist either - test_result = await redis.smove(b"not:" + key1, b"other:not:" + key2, member1) - assert test_result == 0 - - with pytest.raises(TypeError): - await redis.smove(None, key1, member1) - with pytest.raises(TypeError): - await redis.smove(key1, None, member1) - - -@pytest.mark.asyncio -async def test_spop(redis): - key = b"key:spop:1" - members = b"one", b"two", b"three" - await redis.sadd(key, *members) - - for _ in members: - test_result = await redis.spop(key) - assert test_result in members - - # test with encoding - members = "four", "five", "six" - await redis.sadd(key, *members) - - for _ in members: - test_result = await redis.spop(key, encoding="utf-8") - assert test_result in members - - # make sure the set is empty after all values are popped - test_result = await redis.smembers(key) - assert test_result == [] - - # try to pop data from empty set - test_result = await redis.spop(b"not:" + key) - assert test_result is None - - with pytest.raises(TypeError): - await redis.spop(None) - - -@redis_version( - 3, 2, 0, reason="The count argument in SPOP is available since redis>=3.2.0" -) -@pytest.mark.asyncio -async def test_spop_count(redis): - key = b"key:spop:1" - members1 = b"one", b"two", b"three" - await redis.sadd(key, *members1) - - # fetch 3 random members - test_result1 = await redis.spop(key, 3) - assert len(test_result1) == 3 - assert set(test_result1).issubset(members1) is True - - members2 = "four", "five", "six" - await redis.sadd(key, *members2) - - # test with encoding, fetch 3 random members - test_result2 = await redis.spop(key, 3, encoding="utf-8") - assert len(test_result2) == 3 - assert set(test_result2).issubset(members2) is True - - # try to pop data from empty set - test_result = await redis.spop(b"not:" + key, 2) - assert len(test_result) == 0 - - # test with a negative count - with pytest.raises(ReplyError): - await redis.spop(key, -2) - - # test with a count of zero - test_result3 = await redis.spop(key, 0) - assert len(test_result3) == 0 - - -@pytest.mark.asyncio -async def test_srandmember(redis): - key = b"key:srandmember:1" - members = b"one", b"two", b"three", b"four", b"five", b"six", b"seven" - await redis.sadd(key, *members) - - for _ in members: - test_result = await redis.srandmember(key) - assert test_result in members - - # test with encoding - test_result = await redis.srandmember(key, encoding="utf-8") - strings = {"one", "two", "three", "four", "five", "six", "seven"} - assert test_result in strings - - # make sure the set contains all values and nothing is missing - test_result = await redis.smembers(key) - assert set(test_result) == set(members) - - # fetch 4 elements for the first time; as a result, 4 distinct values - test_result1 = await redis.srandmember(key, 4) - assert len(test_result1) == 4 - assert set(test_result1).issubset(members) is True - - # test negative count; the same element may be returned multiple times - test_result2
= await redis.srandmember(key, -10) - assert len(test_result2) == 10 - assert set(test_result2).issubset(members) is True - assert len(set(test_result2)) <= len(members) - - # pull member from empty set - test_result = await redis.srandmember(b"not" + key) - assert test_result is None - - with pytest.raises(TypeError): - await redis.srandmember(None) - - -@pytest.mark.asyncio -async def test_srem(redis): - key = b"key:srem:1" - members = b"one", b"two", b"three", b"four", b"five", b"six", b"seven" - await redis.sadd(key, *members) - - # remove one element from set - test_result = await redis.srem(key, members[-1]) - assert test_result == 1 - - # remove a non-existing element - test_result = await redis.srem(key, b"foo") - assert test_result == 0 - - # remove a non-existing element from a non-existing set - test_result = await redis.srem(b"not:" + key, b"foo") - assert test_result == 0 - - # remove multiple elements from set - test_result = await redis.srem(key, *members[:-1]) - assert test_result == 6 - with pytest.raises(TypeError): - await redis.srem(None, members) - - -@pytest.mark.asyncio -async def test_sunion(redis): - key1 = b"key:sunion:1" - key2 = b"key:sunion:2" - key3 = b"key:sunion:3" - - members1 = [b"a", b"b", b"c", b"d"] - members2 = [b"c"] - members3 = [b"a", b"c", b"e"] - - await redis.sadd(key1, *members1) - await redis.sadd(key2, *members2) - await redis.sadd(key3, *members3) - - # test multiple keys - test_result = await redis.sunion(key1, key2, key3) - assert set(test_result) == set(members1 + members2 + members3) - - # test single key - test_result = await redis.sunion(key2) - assert set(test_result) == {b"c"} - - with pytest.raises(TypeError): - await redis.sunion(None) - with pytest.raises(TypeError): - await redis.sunion(key1, None) - - -@pytest.mark.asyncio -async def test_sunionstore(redis): - key1 = b"key:sunionstore:1" - key2 = b"key:sunionstore:2" - destkey = b"key:sunionstore:destkey" - members1 = (b"a", b"b", b"c") - members2 = (b"c", b"d", b"e") - - await redis.sadd(key1, *members1) - await redis.sadd(key2, *members2) - - # test basic use case - test_result = await redis.sunionstore(destkey, key1, key2) - assert test_result == 5 - - # make sure that destkey contains 5 members - test_result = await redis.scard(destkey) - assert test_result == 5 - - # try sunionstore in case none of the sets exists - test_result = await redis.sunionstore( - b"not:" + destkey, b"not:" + key1, b"not:" + key2 - ) - assert test_result == 0 - - with pytest.raises(TypeError): - await redis.sunionstore(None, key1) - with pytest.raises(TypeError): - await redis.sunionstore(destkey, None) - with pytest.raises(TypeError): - await redis.sunionstore(destkey, key1, None) - - -@redis_version(2, 8, 0, reason="SSCAN is available since redis>=2.8.0") -@pytest.mark.asyncio -async def test_sscan(redis): - key = b"key:sscan" - for i in range(1, 11): - foo_or_bar = "bar" if i % 3 else "foo" - member = f"member:{foo_or_bar}:{i}".encode("utf-8") - await add(redis, key, member) - - cursor, values = await redis.sscan(key, match=b"member:foo:*") - assert len(values) == 3 - - cursor, values = await redis.sscan(key, match=b"member:bar:*") - assert len(values) == 7 - - # SCAN family functions do not guarantee that the number (count) of - # elements returned per call is in a given range.
So this - # is just a dummy test that the *count* argument does not break anything - cursor = b"0" - test_values = [] - while cursor: - cursor, values = await redis.sscan(key, cursor, count=2) - test_values.extend(values) - assert len(test_values) == 10 - - with pytest.raises(TypeError): - await redis.sscan(None) - - -@redis_version(2, 8, 0, reason="SSCAN is available since redis>=2.8.0") -@pytest.mark.asyncio -async def test_isscan(redis): - key = b"key:sscan" - for i in range(1, 11): - foo_or_bar = "bar" if i % 3 else "foo" - member = f"member:{foo_or_bar}:{i}".encode("utf-8") - assert await redis.sadd(key, member) == 1 - - async def coro(cmd): - lst = [] - async for i in cmd: - lst.append(i) - return lst - - ret = await coro(redis.isscan(key, match=b"member:foo:*")) - assert set(ret) == {b"member:foo:3", b"member:foo:6", b"member:foo:9"} - - ret = await coro(redis.isscan(key, match=b"member:bar:*")) - assert set(ret) == { - b"member:bar:1", - b"member:bar:2", - b"member:bar:4", - b"member:bar:5", - b"member:bar:7", - b"member:bar:8", - b"member:bar:10", - } - - # SCAN family functions do not guarantee that the number (count) of - # elements returned per call is in a given range. So this - # is just a dummy test that the *count* argument does not break anything - ret = await coro(redis.isscan(key, count=2)) - assert set(ret) == { - b"member:foo:3", - b"member:foo:6", - b"member:foo:9", - b"member:bar:1", - b"member:bar:2", - b"member:bar:4", - b"member:bar:5", - b"member:bar:7", - b"member:bar:8", - b"member:bar:10", - } - - with pytest.raises(TypeError): - await redis.isscan(None) diff --git a/tests/sorted_set_commands_test.py b/tests/sorted_set_commands_test.py deleted file mode 100644 index 162c6448f..000000000 --- a/tests/sorted_set_commands_test.py +++ /dev/null @@ -1,842 +0,0 @@ -import itertools - -import pytest - -from tests.testutils import redis_version - - -@redis_version(5, 0, 0, reason="BZPOPMAX is available since redis>=5.0.0") -@pytest.mark.asyncio -async def test_bzpopmax(redis): - key1 = b"key:zpopmax:1" - key2 = b"key:zpopmax:2" - - pairs = [(0, b"a"), (5, b"c"), (2, b"d"), (8, b"e"), (9, b"f"), (3, b"g")] - await redis.zadd(key1, *pairs[0]) - await redis.zadd(key2, *itertools.chain.from_iterable(pairs)) - - res = await redis.bzpopmax(key1, timeout=0) - assert res == [key1, b"a", b"0"] - res = await redis.bzpopmax(key1, key2, timeout=0) - assert res == [key2, b"f", b"9"] - - with pytest.raises(TypeError): - await redis.bzpopmax(key1, timeout=b"one") - with pytest.raises(ValueError): - await redis.bzpopmax(key2, timeout=-10) - - -@redis_version(5, 0, 0, reason="BZPOPMIN is available since redis>=5.0.0") -@pytest.mark.asyncio -async def test_bzpopmin(redis): - key1 = b"key:zpopmin:1" - key2 = b"key:zpopmin:2" - - pairs = [(0, b"a"), (5, b"c"), (2, b"d"), (8, b"e"), (9, b"f"), (3, b"g")] - await redis.zadd(key1, *pairs[0]) - await redis.zadd(key2, *itertools.chain.from_iterable(pairs)) - - res = await redis.bzpopmin(key1, timeout=0) - assert res == [key1, b"a", b"0"] - res = await redis.bzpopmin(key1, key2, timeout=0) - assert res == [key2, b"a", b"0"] - - with pytest.raises(TypeError): - await redis.bzpopmin(key1, timeout=b"one") - with pytest.raises(ValueError): - await redis.bzpopmin(key2, timeout=-10) - - -@pytest.mark.asyncio -async def test_zadd(redis): - key = b"key:zadd" - res = await redis.zadd(key, 1, b"one") - assert res == 1 - res = await redis.zadd(key, 1, b"one") - assert res == 0 - res = await redis.zadd(key, 1, b"uno") - assert res == 1 - res = await redis.zadd(key,
2.5, b"two") - assert res == 1 - res = await redis.zadd(key, 3, b"three", 4, b"four") - assert res == 2 - - res = await redis.zrange(key, 0, -1, withscores=False) - assert res == [b"one", b"uno", b"two", b"three", b"four"] - - with pytest.raises(TypeError): - await redis.zadd(None, 1, b"one") - with pytest.raises(TypeError): - await redis.zadd(key, b"two", b"one") - with pytest.raises(TypeError): - await redis.zadd(key, 3, b"three", 4) - with pytest.raises(TypeError): - await redis.zadd(key, 3, b"three", "four", 4) - - -@redis_version( - 3, - 0, - 2, - reason="ZADD options is available since redis>=3.0.2", -) -@pytest.mark.asyncio -async def test_zadd_options(redis): - key = b"key:zaddopt" - - res = await redis.zadd(key, 0, b"one") - assert res == 1 - - res = await redis.zadd( - key, - 1, - b"one", - 2, - b"two", - exist=redis.ZSET_IF_EXIST, - ) - assert res == 0 - - res = await redis.zscore(key, b"one") - assert res == 1 - - res = await redis.zscore(key, b"two") - assert res is None - - res = await redis.zadd( - key, - 1, - b"one", - 2, - b"two", - exist=redis.ZSET_IF_NOT_EXIST, - ) - assert res == 1 - - res = await redis.zscore(key, b"one") - assert res == 1 - - res = await redis.zscore(key, b"two") - assert res == 2 - - res = await redis.zrange(key, 0, -1, withscores=False) - assert res == [b"one", b"two"] - - res = await redis.zadd(key, 1, b"two", changed=True) - assert res == 1 - - res = await redis.zadd(key, 1, b"two", incr=True) - assert int(res) == 2 - - with pytest.raises(ValueError): - await redis.zadd(key, 1, b"one", 2, b"two", incr=True) - - -@pytest.mark.asyncio -async def test_zcard(redis): - key = b"key:zcard" - pairs = [1, b"one", 2, b"two", 3, b"three"] - res = await redis.zadd(key, *pairs) - assert res == 3 - res = await redis.zcard(key) - assert res == 3 - res = await redis.zadd(key, 1, b"ein") - assert res == 1 - res = await redis.zcard(key) - assert res == 4 - - with pytest.raises(TypeError): - await redis.zcard(None) - - -@pytest.mark.asyncio -async def test_zcount(redis): - key = b"key:zcount" - pairs = [1, b"one", 1, b"uno", 2.5, b"two", 3, b"three", 7, b"seven"] - res = await redis.zadd(key, *pairs) - assert res == 5 - - res_zcount = await redis.zcount(key) - res_zcard = await redis.zcard(key) - assert res_zcount == res_zcard - - res = await redis.zcount(key, 1, 3) - assert res == 4 - res = await redis.zcount(key, 3, 10) - assert res == 2 - res = await redis.zcount(key, 100, 200) - assert res == 0 - - res = await redis.zcount(key, 1, 3, exclude=redis.ZSET_EXCLUDE_BOTH) - assert res == 1 - res = await redis.zcount(key, 1, 3, exclude=redis.ZSET_EXCLUDE_MIN) - assert res == 2 - res = await redis.zcount(key, 1, 3, exclude=redis.ZSET_EXCLUDE_MAX) - assert res == 3 - res = await redis.zcount(key, 1, exclude=redis.ZSET_EXCLUDE_MAX) - assert res == 5 - res = await redis.zcount(key, float("-inf"), 3, exclude=redis.ZSET_EXCLUDE_MIN) - assert res == 4 - - with pytest.raises(TypeError): - await redis.zcount(None) - with pytest.raises(TypeError): - await redis.zcount(key, "one", 2) - with pytest.raises(TypeError): - await redis.zcount(key, 1.1, b"two") - with pytest.raises(ValueError): - await redis.zcount(key, 10, 1) - - -@pytest.mark.asyncio -async def test_zincrby(redis): - key = b"key:zincrby" - pairs = [1, b"one", 1, b"uno", 2.5, b"two", 3, b"three"] - res = await redis.zadd(key, *pairs) - res = await redis.zincrby(key, 1, b"one") - assert res == 2 - res = await redis.zincrby(key, -5, b"uno") - assert res == -4 - res = await redis.zincrby(key, 3.14, b"two") - assert 
abs(res - 5.64) <= 0.00001 - res = await redis.zincrby(key, -3.14, b"three") - assert abs(res - -0.14) <= 0.00001 - - with pytest.raises(TypeError): - await redis.zincrby(None, 5, "one") - with pytest.raises(TypeError): - await redis.zincrby(key, "one", 5) - - -@pytest.mark.asyncio -async def test_zinterstore(redis): - zset1 = [2, "one", 2, "two"] - zset2 = [3, "one", 3, "three"] - - await redis.zadd("zset1", *zset1) - await redis.zadd("zset2", *zset2) - - res = await redis.zinterstore("zout", "zset1", "zset2") - assert res == 1 - res = await redis.zrange("zout", withscores=True) - assert res == [(b"one", 5)] - - res = await redis.zinterstore( - "zout", "zset1", "zset2", aggregate=redis.ZSET_AGGREGATE_SUM - ) - assert res == 1 - res = await redis.zrange("zout", withscores=True) - assert res == [(b"one", 5)] - - res = await redis.zinterstore( - "zout", "zset1", "zset2", aggregate=redis.ZSET_AGGREGATE_MIN - ) - assert res == 1 - res = await redis.zrange("zout", withscores=True) - assert res == [(b"one", 2)] - - res = await redis.zinterstore( - "zout", "zset1", "zset2", aggregate=redis.ZSET_AGGREGATE_MAX - ) - assert res == 1 - res = await redis.zrange("zout", withscores=True) - assert res == [(b"one", 3)] - - # weights - - with pytest.raises(AssertionError): - await redis.zinterstore("zout", "zset1", "zset2", with_weights=True) - - res = await redis.zinterstore("zout", ("zset1", 2), ("zset2", 2), with_weights=True) - assert res == 1 - res = await redis.zrange("zout", withscores=True) - assert res == [(b"one", 10)] - - -@redis_version(2, 8, 9, reason="ZLEXCOUNT is available since redis>=2.8.9") -@pytest.mark.asyncio -async def test_zlexcount(redis): - key = b"key:zlexcount" - pairs = [0, b"a", 0, b"b", 0, b"c", 0, b"d", 0, b"e"] - res = await redis.zadd(key, *pairs) - assert res == 5 - res = await redis.zlexcount(key) - assert res == 5 - res = await redis.zlexcount(key, min=b"-", max=b"[e") - assert res == 5 - res = await redis.zlexcount(key, min=b"(a", max=b"(e") - assert res == 3 - - -@pytest.mark.parametrize("encoding", [None, "utf-8"]) -@pytest.mark.asyncio -async def test_zrange(redis, encoding): - key = b"key:zrange" - scores = [1, 1, 2.5, 3, 7] - if encoding: - members = ["one", "uno", "two", "three", "seven"] - else: - members = [b"one", b"uno", b"two", b"three", b"seven"] - pairs = list(itertools.chain(*zip(scores, members))) - rev_pairs = list(zip(members, scores)) - - res = await redis.zadd(key, *pairs) - assert res == 5 - - res = await redis.zrange(key, 0, -1, withscores=False, encoding=encoding) - assert res == members - res = await redis.zrange(key, 0, -1, withscores=True, encoding=encoding) - assert res == rev_pairs - res = await redis.zrange(key, -2, -1, withscores=False, encoding=encoding) - assert res == members[-2:] - res = await redis.zrange(key, 1, 2, withscores=False, encoding=encoding) - assert res == members[1:3] - - with pytest.raises(TypeError): - await redis.zrange(None, 1, b"one") - with pytest.raises(TypeError): - await redis.zrange(key, b"first", -1) - with pytest.raises(TypeError): - await redis.zrange(key, 0, "last") - - -@redis_version(2, 8, 9, reason="ZRANGEBYLEX is available since redis>=2.8.9") -@pytest.mark.asyncio -async def test_zrangebylex(redis): - key = b"key:zrangebylex" - scores = [0] * 5 - members = [b"a", b"b", b"c", b"d", b"e"] - strings = [x.decode("utf-8") for x in members] - pairs = list(itertools.chain(*zip(scores, members))) - - res = await redis.zadd(key, *pairs) - assert res == 5 - res = await redis.zrangebylex(key) - assert res == members 
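# In the zrangebylex calls above and below, the min/max byte strings encode inclusivity in their first character: "[" is inclusive, "(" is exclusive, and "-" / "+" mark the lexically smallest and largest possible values. A compact summary against the same 1.x call signature used here (assuming, as in this test, a zset whose members a..e all share score 0):
#
#   await redis.zrangebylex(key, min=b"-", max=b"+")    # every member
#   await redis.zrangebylex(key, min=b"[a", max=b"[c")  # [b"a", b"b", b"c"], both ends inclusive
#   await redis.zrangebylex(key, min=b"(a", max=b"(c")  # [b"b"], both ends exclusive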
- res = await redis.zrangebylex(key, encoding="utf-8") - assert res == strings - res = await redis.zrangebylex(key, min=b"-", max=b"[d") - assert res == members[:-1] - res = await redis.zrangebylex(key, min=b"(a", max=b"(e") - assert res == members[1:-1] - res = await redis.zrangebylex(key, min=b"[x", max=b"[z") - assert res == [] - res = await redis.zrangebylex(key, min=b"[e", max=b"[a") - assert res == [] - res = await redis.zrangebylex(key, offset=1, count=2) - assert res == members[1:3] - with pytest.raises(TypeError): - await redis.zrangebylex(None, b"[a", b"[e") - with pytest.raises(TypeError): - await redis.zrangebylex(key, None, b"[e") - with pytest.raises(TypeError): - await redis.zrangebylex(key, b"[a", None) - with pytest.raises(TypeError): - await redis.zrangebylex(key, b"a", b"e", offset=1) - with pytest.raises(TypeError): - await redis.zrangebylex(key, b"a", b"e", count=1) - with pytest.raises(TypeError): - await redis.zrangebylex(key, b"a", b"e", offset="one", count=1) - with pytest.raises(TypeError): - await redis.zrangebylex(key, b"a", b"e", offset=1, count="one") - - -@pytest.mark.asyncio -async def test_zrank(redis): - key = b"key:zrank" - scores = [1, 1, 2.5, 3, 7] - members = [b"one", b"uno", b"two", b"three", b"seven"] - pairs = list(itertools.chain(*zip(scores, members))) - - res = await redis.zadd(key, *pairs) - assert res == 5 - - for i, m in enumerate(members): - res = await redis.zrank(key, m) - assert res == i - - res = await redis.zrank(key, b"not:exists") - assert res is None - - with pytest.raises(TypeError): - await redis.zrank(None, b"one") - - -@pytest.mark.parametrize("encoding", [None, "utf-8"]) -@pytest.mark.asyncio -async def test_zrangebyscore(redis, encoding): - key = b"key:zrangebyscore" - scores = [1, 1, 2.5, 3, 7] - if encoding: - members = ["one", "uno", "two", "three", "seven"] - else: - members = [b"one", b"uno", b"two", b"three", b"seven"] - pairs = list(itertools.chain(*zip(scores, members))) - rev_pairs = list(zip(members, scores)) - res = await redis.zadd(key, *pairs) - assert res == 5 - - res = await redis.zrangebyscore(key, 1, 7, withscores=False, encoding=encoding) - assert res == members - res = await redis.zrangebyscore( - key, 1, 7, withscores=False, exclude=redis.ZSET_EXCLUDE_BOTH, encoding=encoding - ) - assert res == members[2:-1] - res = await redis.zrangebyscore(key, 1, 7, withscores=True, encoding=encoding) - assert res == rev_pairs - - res = await redis.zrangebyscore(key, 1, 10, offset=2, count=2, encoding=encoding) - assert res == members[2:4] - - with pytest.raises(TypeError): - await redis.zrangebyscore(None, 1, 7) - with pytest.raises(TypeError): - await redis.zrangebyscore(key, 10, b"e") - with pytest.raises(TypeError): - await redis.zrangebyscore(key, b"a", 20) - with pytest.raises(TypeError): - await redis.zrangebyscore(key, 1, 7, offset=1) - with pytest.raises(TypeError): - await redis.zrangebyscore(key, 1, 7, count=1) - with pytest.raises(TypeError): - await redis.zrangebyscore(key, 1, 7, offset="one", count=1) - with pytest.raises(TypeError): - await redis.zrangebyscore(key, 1, 7, offset=1, count="one") - - -@pytest.mark.asyncio -async def test_zrem(redis): - key = b"key:zrem" - scores = [1, 1, 2.5, 3, 7] - members = [b"one", b"uno", b"two", b"three", b"seven"] - pairs = list(itertools.chain(*zip(scores, members))) - - res = await redis.zadd(key, *pairs) - assert res == 5 - - res = await redis.zrem(key, b"uno", b"one") - assert res == 2 - - res = await redis.zrange(key, 0, -1) - assert res == members[2:] - - res = 
await redis.zrem(key, b"not:exists") - assert res == 0 - - res = await redis.zrem(b"not:" + key, b"not:exists") - assert res == 0 - - with pytest.raises(TypeError): - await redis.zrem(None, b"one") - - -@redis_version(2, 8, 9, reason="ZREMRANGEBYLEX is available since redis>=2.8.9") -@pytest.mark.asyncio -async def test_zremrangebylex(redis): - key = b"key:zremrangebylex" - members = [ - b"aaaa", - b"b", - b"c", - b"d", - b"e", - b"foo", - b"zap", - b"zip", - b"ALPHA", - b"alpha", - ] - scores = [0] * len(members) - - pairs = list(itertools.chain(*zip(scores, members))) - res = await redis.zadd(key, *pairs) - assert res == 10 - - res = await redis.zremrangebylex(key, b"[alpha", b"[omega") - assert res == 6 - res = await redis.zrange(key, 0, -1) - assert res == [b"ALPHA", b"aaaa", b"zap", b"zip"] - - res = await redis.zremrangebylex(key, b"(zap", b"(zip") - assert res == 0 - - res = await redis.zrange(key, 0, -1) - assert res == [b"ALPHA", b"aaaa", b"zap", b"zip"] - - res = await redis.zremrangebylex(key) - assert res == 4 - res = await redis.zrange(key, 0, -1) - assert res == [] - - with pytest.raises(TypeError): - await redis.zremrangebylex(None, b"a", b"e") - with pytest.raises(TypeError): - await redis.zremrangebylex(key, None, b"e") - with pytest.raises(TypeError): - await redis.zremrangebylex(key, b"a", None) - - -@pytest.mark.asyncio -async def test_zremrangebyrank(redis): - key = b"key:zremrangebyrank" - scores = [0, 1, 2, 3, 4, 5] - members = [b"zero", b"one", b"two", b"three", b"four", b"five"] - pairs = list(itertools.chain(*zip(scores, members))) - res = await redis.zadd(key, *pairs) - assert res == 6 - - res = await redis.zremrangebyrank(key, 0, 1) - assert res == 2 - res = await redis.zrange(key, 0, -1) - assert res == members[2:] - - res = await redis.zremrangebyrank(key, -2, -1) - assert res == 2 - res = await redis.zrange(key, 0, -1) - assert res == members[2:-2] - - with pytest.raises(TypeError): - await redis.zremrangebyrank(None, 1, 2) - with pytest.raises(TypeError): - await redis.zremrangebyrank(key, b"first", -1) - with pytest.raises(TypeError): - await redis.zremrangebyrank(key, 0, "last") - - -@pytest.mark.asyncio -async def test_zremrangebyscore(redis): - key = b"key:zremrangebyscore" - scores = [1, 1, 2.5, 3, 7] - members = [b"one", b"uno", b"two", b"three", b"seven"] - pairs = list(itertools.chain(*zip(scores, members))) - res = await redis.zadd(key, *pairs) - assert res == 5 - - res = await redis.zremrangebyscore(key, 3, 7.5, exclude=redis.ZSET_EXCLUDE_MIN) - assert res == 1 - res = await redis.zrange(key, 0, -1) - assert res == members[:-1] - - res = await redis.zremrangebyscore(key, 1, 3, exclude=redis.ZSET_EXCLUDE_BOTH) - assert res == 1 - res = await redis.zrange(key, 0, -1) - assert res == [b"one", b"uno", b"three"] - - res = await redis.zremrangebyscore(key) - assert res == 3 - res = await redis.zrange(key, 0, -1) - assert res == [] - - with pytest.raises(TypeError): - await redis.zremrangebyscore(None, 1, 2) - with pytest.raises(TypeError): - await redis.zremrangebyscore(key, b"first", -1) - with pytest.raises(TypeError): - await redis.zremrangebyscore(key, 0, "last") - - -@pytest.mark.parametrize("encoding", [None, "utf-8"]) -@pytest.mark.asyncio -async def test_zrevrange(redis, encoding): - key = b"key:zrevrange" - scores = [1, 1, 2.5, 3, 7] - if encoding: - members = ["one", "uno", "two", "three", "seven"] - else: - members = [b"one", b"uno", b"two", b"three", b"seven"] - pairs = list(itertools.chain(*zip(scores, members))) - rev_pairs = 
list(zip(members, scores)) - - res = await redis.zadd(key, *pairs) - assert res == 5 - - res = await redis.zrevrange(key, 0, -1, withscores=False, encoding=encoding) - assert res == members[::-1] - res = await redis.zrevrange(key, 0, -1, withscores=True, encoding=encoding) - assert res == rev_pairs[::-1] - res = await redis.zrevrange(key, -2, -1, withscores=False, encoding=encoding) - assert res == members[1::-1] - res = await redis.zrevrange(key, 1, 2, withscores=False, encoding=encoding) - assert res == members[3:1:-1] - - with pytest.raises(TypeError): - await redis.zrevrange(None, 1, b"one") - with pytest.raises(TypeError): - await redis.zrevrange(key, b"first", -1) - with pytest.raises(TypeError): - await redis.zrevrange(key, 0, "last") - - -@pytest.mark.asyncio -async def test_zrevrank(redis): - key = b"key:zrevrank" - scores = [1, 1, 2.5, 3, 7] - members = [b"one", b"uno", b"two", b"three", b"seven"] - pairs = list(itertools.chain(*zip(scores, members))) - - res = await redis.zadd(key, *pairs) - assert res == 5 - - for i, m in enumerate(members): - res = await redis.zrevrank(key, m) - assert res == len(members) - i - 1 - - res = await redis.zrevrank(key, b"not:exists") - assert res is None - - with pytest.raises(TypeError): - await redis.zrevrank(None, b"one") - - -@pytest.mark.asyncio -async def test_zscore(redis): - key = b"key:zscore" - scores = [1, 1, 2.5, 3, 7] - members = [b"one", b"uno", b"two", b"three", b"seven"] - pairs = list(itertools.chain(*zip(scores, members))) - - res = await redis.zadd(key, *pairs) - assert res == 5 - - for s, m in zip(scores, members): - res = await redis.zscore(key, m) - assert res == s - with pytest.raises(TypeError): - await redis.zscore(None, b"one") - # Check None on undefined members - res = await redis.zscore(key, "undefined") - assert res is None - - -@pytest.mark.asyncio -async def test_zunionstore(redis): - zset1 = [2, "one", 2, "two"] - zset2 = [3, "one", 3, "three"] - - await redis.zadd("zset1", *zset1) - await redis.zadd("zset2", *zset2) - - res = await redis.zunionstore("zout", "zset1", "zset2") - assert res == 3 - res = await redis.zrange("zout", withscores=True) - assert res == [(b"two", 2), (b"three", 3), (b"one", 5)] - - res = await redis.zunionstore( - "zout", "zset1", "zset2", aggregate=redis.ZSET_AGGREGATE_SUM - ) - assert res == 3 - res = await redis.zrange("zout", withscores=True) - assert res == [(b"two", 2), (b"three", 3), (b"one", 5)] - - res = await redis.zunionstore( - "zout", "zset1", "zset2", aggregate=redis.ZSET_AGGREGATE_MIN - ) - assert res == 3 - res = await redis.zrange("zout", withscores=True) - assert res == [(b"one", 2), (b"two", 2), (b"three", 3)] - - res = await redis.zunionstore( - "zout", "zset1", "zset2", aggregate=redis.ZSET_AGGREGATE_MAX - ) - assert res == 3 - res = await redis.zrange("zout", withscores=True) - assert res == [(b"two", 2), (b"one", 3), (b"three", 3)] - - # weights - - with pytest.raises(AssertionError): - await redis.zunionstore("zout", "zset1", "zset2", with_weights=True) - - res = await redis.zunionstore("zout", ("zset1", 2), ("zset2", 2), with_weights=True) - assert res == 3 - res = await redis.zrange("zout", withscores=True) - assert res == [(b"two", 4), (b"three", 6), (b"one", 10)] - - -@pytest.mark.parametrize("encoding", [None, "utf-8"]) -@pytest.mark.asyncio -async def test_zrevrangebyscore(redis, encoding): - key = b"key:zrevrangebyscore" - scores = [1, 1, 2.5, 3, 7] - if encoding: - members = ["one", "uno", "two", "three", "seven"] - else: - members = [b"one", b"uno", b"two", 
b"three", b"seven"] - pairs = list(itertools.chain(*zip(scores, members))) - rev_pairs = list(zip(members[::-1], scores[::-1])) - res = await redis.zadd(key, *pairs) - assert res == 5 - - res = await redis.zrevrangebyscore(key, 7, 1, withscores=False, encoding=encoding) - assert res == members[::-1] - res = await redis.zrevrangebyscore( - key, 7, 1, withscores=False, exclude=redis.ZSET_EXCLUDE_BOTH, encoding=encoding - ) - assert res == members[-2:1:-1] - res = await redis.zrevrangebyscore(key, 7, 1, withscores=True, encoding=encoding) - assert res == rev_pairs - - res = await redis.zrevrangebyscore(key, 10, 1, offset=2, count=2, encoding=encoding) - assert res == members[-3:-5:-1] - - with pytest.raises(TypeError): - await redis.zrevrangebyscore(None, 1, 7) - with pytest.raises(TypeError): - await redis.zrevrangebyscore(key, 10, b"e") - with pytest.raises(TypeError): - await redis.zrevrangebyscore(key, b"a", 20) - with pytest.raises(TypeError): - await redis.zrevrangebyscore(key, 1, 7, offset=1) - with pytest.raises(TypeError): - await redis.zrevrangebyscore(key, 1, 7, count=1) - with pytest.raises(TypeError): - await redis.zrevrangebyscore(key, 1, 7, offset="one", count=1) - with pytest.raises(TypeError): - await redis.zrevrangebyscore(key, 1, 7, offset=1, count="one") - - -@redis_version(2, 8, 9, reason="ZREVRANGEBYLEX is available since redis>=2.8.9") -@pytest.mark.asyncio -async def test_zrevrangebylex(redis): - key = b"key:zrevrangebylex" - scores = [0] * 5 - members = [b"a", b"b", b"c", b"d", b"e"] - strings = [x.decode("utf-8") for x in members] - rev_members = members[::-1] - rev_strings = strings[::-1] - pairs = list(itertools.chain(*zip(scores, members))) - - res = await redis.zadd(key, *pairs) - assert res == 5 - res = await redis.zrevrangebylex(key) - assert res == rev_members - res = await redis.zrevrangebylex(key, encoding="utf-8") - assert res == rev_strings - res = await redis.zrevrangebylex(key, min=b"-", max=b"[d") - assert res == rev_members[1:] - res = await redis.zrevrangebylex(key, min=b"(a", max=b"(e") - assert res == rev_members[1:-1] - res = await redis.zrevrangebylex(key, min=b"[x", max=b"[z") - assert res == [] - res = await redis.zrevrangebylex(key, min=b"[e", max=b"[a") - assert res == [] - res = await redis.zrevrangebylex(key, offset=1, count=2) - assert res == rev_members[1:3] - with pytest.raises(TypeError): - await redis.zrevrangebylex(None, b"a", b"e") - with pytest.raises(TypeError): - await redis.zrevrangebylex(key, None, b"e") - with pytest.raises(TypeError): - await redis.zrevrangebylex(key, b"a", None) - with pytest.raises(TypeError): - await redis.zrevrangebylex(key, b"a", b"e", offset=1) - with pytest.raises(TypeError): - await redis.zrevrangebylex(key, b"a", b"e", count=1) - with pytest.raises(TypeError): - await redis.zrevrangebylex(key, b"a", b"e", offset="one", count=1) - with pytest.raises(TypeError): - await redis.zrevrangebylex(key, b"a", b"e", offset=1, count="one") - - -@redis_version(2, 8, 0, reason="ZSCAN is available since redis>=2.8.0") -@pytest.mark.asyncio -async def test_zscan(redis): - key = b"key:zscan" - scores, members = [], [] - - for i in range(1, 11): - foo_or_bar = "bar" if i % 3 else "foo" - members.append(f"zmem:{foo_or_bar}:{i}".encode("utf-8")) - scores.append(i) - pairs = list(itertools.chain(*zip(scores, members))) - rev_pairs = set(zip(members, scores)) - await redis.zadd(key, *pairs) - - cursor, values = await redis.zscan(key, match=b"zmem:foo:*") - assert len(values) == 3 - - cursor, values = await redis.zscan(key, 
match=b"zmem:bar:*") - assert len(values) == 7 - - # SCAN family functions do not guarantee that the number (count) of - # elements returned per call are in a given range. So here - # just dummy test, that *count* argument does not break something - cursor = b"0" - test_values = set() - while cursor: - cursor, values = await redis.zscan(key, cursor, count=2) - test_values.update(values) - assert test_values == rev_pairs - - with pytest.raises(TypeError): - await redis.zscan(None) - - -@redis_version(2, 8, 0, reason="ZSCAN is available since redis>=2.8.0") -@pytest.mark.asyncio -async def test_izscan(redis): - key = b"key:zscan" - scores, members = [], [] - - for i in range(1, 11): - foo_or_bar = "bar" if i % 3 else "foo" - members.append(f"zmem:{foo_or_bar}:{i}".encode("utf-8")) - scores.append(i) - pairs = list(itertools.chain(*zip(scores, members))) - await redis.zadd(key, *pairs) - vals = set(zip(members, scores)) - - async def coro(cmd): - res = set() - async for key, score in cmd: - res.add((key, score)) - return res - - ret = await coro(redis.izscan(key)) - assert set(ret) == set(vals) - - ret = await coro(redis.izscan(key, match=b"zmem:foo:*")) - assert set(ret) == {v for v in vals if b"foo" in v[0]} - - ret = await coro(redis.izscan(key, match=b"zmem:bar:*")) - assert set(ret) == {v for v in vals if b"bar" in v[0]} - - # SCAN family functions do not guarantee that the number (count) of - # elements returned per call are in a given range. So here - # just dummy test, that *count* argument does not break something - - ret = await coro(redis.izscan(key, count=2)) - assert set(ret) == set(vals) - - with pytest.raises(TypeError): - await redis.izscan(None) - - -@redis_version(5, 0, 0, reason="ZPOPMAX is available since redis>=5.0.0") -@pytest.mark.asyncio -async def test_zpopmax(redis): - key = b"key:zpopmax" - - pairs = [(0, b"a"), (5, b"c"), (2, b"d"), (8, b"e"), (9, b"f"), (3, b"g")] - await redis.zadd(key, *itertools.chain.from_iterable(pairs)) - - assert await redis.zpopmax(key) == [b"f", b"9"] - assert await redis.zpopmax(key, 3) == [b"e", b"8", b"c", b"5", b"g", b"3"] - - with pytest.raises(TypeError): - await redis.zpopmax(key, b"b") - - -@redis_version(5, 0, 0, reason="ZPOPMIN is available since redis>=5.0.0") -@pytest.mark.asyncio -async def test_zpopmin(redis): - key = b"key:zpopmin" - - pairs = [(0, b"a"), (5, b"c"), (2, b"d"), (8, b"e"), (9, b"f"), (3, b"g")] - await redis.zadd(key, *itertools.chain.from_iterable(pairs)) - - assert await redis.zpopmin(key) == [b"a", b"0"] - assert await redis.zpopmin(key, 3) == [b"d", b"2", b"g", b"3", b"c", b"5"] - - with pytest.raises(TypeError): - await redis.zpopmin(key, b"b") diff --git a/tests/ssl/Makefile b/tests/ssl/Makefile deleted file mode 100644 index 97c7879df..000000000 --- a/tests/ssl/Makefile +++ /dev/null @@ -1,15 +0,0 @@ - -all: dhparam.pem cert.pem cafile.crt - -dhparam.pem: - openssl dhparam -out $@ 2048 -cafile.crt: - openssl req -newkey rsa:2048 -keyout private.key \ - -batch -nodes -x509 -out $@ -cert.pem: cafile.crt - cat private.key $^ > $@ - -clean: - rm private.key cert.pem cafile.crt dhparam.pem - -.PHONY: all diff --git a/tests/ssl/cafile.crt b/tests/ssl/cafile.crt deleted file mode 100644 index 64b8a35b0..000000000 --- a/tests/ssl/cafile.crt +++ /dev/null @@ -1,21 +0,0 @@ ------BEGIN CERTIFICATE----- -MIIDXTCCAkWgAwIBAgIJAOZQ4fBuw3mhMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV -BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX -aWRnaXRzIFB0eSBMdGQwHhcNMTcwNTAzMTAzMzE1WhcNMTcwNjAyMTAzMzE1WjBF 
-MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 -ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB -CgKCAQEAwTO5eaHSSw294oB/uiSxm8mHkGxD+Md+fKjwDIJVmCvIE7hAQ73JbyxL -TBDRsKlOenUN2sxyhswsAl8OGYlX5lnuhohWzBVvo8OFYN4D3T5k810LrNVelpLv -AJ0YUxS2wWL/iKz/bbH1eQ4irJU1D3t5ie0k0tZ+vocf2flLeHiqdDguUtCs4UQk -bMO9Pc/sZ1ydyTFdFRpifVTb7EtvwQ+BS1XwNG6PH0DJrVaCLf8YnSGPK7OKz/Qa -VEzkLfsleOvTrr61aHIM4TyEmrrfohT9sAz6UjQjQwukzfqArMHhk5P+hDsd8rn7 -NqONiXpypyjVfdUevSCVRRhny5nK5wIDAQABo1AwTjAdBgNVHQ4EFgQUaRQN5nt+ -PwlxGaDHTiZh+n7zGHkwHwYDVR0jBBgwFoAUaRQN5nt+PwlxGaDHTiZh+n7zGHkw -DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAmbPALAkVB2aWCj2rKgK8 -27OPJrds7ioTFkCTE++cb5id84zkM/fZQeak+8bF4Fg63Lo9EUkg+ZVPClyIXfWa -7GMI6eOHiQCaFTbKuh4RSVG7llRYdJe2XR5a5A5UXqqy0j05UNe4mA39lH1AHpvj -zlw5Q2ttUdzDARrIK7gdvqJxpcviY5U9op/MRt1CiblOv+dBKpOjUF+gkInw6SIT -Nq4zlCGUK6feBXwGuyWlOkIglhn7xPMWFNnhyAiDzQ3Au1dN0DWJ4ODAhNMHFyHl -gPNWfqu4uz5LbWaoPeb/NYw7KPPabMj6nKMAer1FGlYQr4YELZwImpvpQasU+6ES -zw== ------END CERTIFICATE----- diff --git a/tests/ssl/cert.pem b/tests/ssl/cert.pem deleted file mode 100644 index 452a7942f..000000000 --- a/tests/ssl/cert.pem +++ /dev/null @@ -1,49 +0,0 @@ ------BEGIN PRIVATE KEY----- -MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDBM7l5odJLDb3i -gH+6JLGbyYeQbEP4x358qPAMglWYK8gTuEBDvclvLEtMENGwqU56dQ3azHKGzCwC -Xw4ZiVfmWe6GiFbMFW+jw4Vg3gPdPmTzXQus1V6Wku8AnRhTFLbBYv+IrP9tsfV5 -DiKslTUPe3mJ7STS1n6+hx/Z+Ut4eKp0OC5S0KzhRCRsw709z+xnXJ3JMV0VGmJ9 -VNvsS2/BD4FLVfA0bo8fQMmtVoIt/xidIY8rs4rP9BpUTOQt+yV469OuvrVocgzh -PISaut+iFP2wDPpSNCNDC6TN+oCsweGTk/6EOx3yufs2o42JenKnKNV91R69IJVF -GGfLmcrnAgMBAAECggEAdFVW61x7KeI+YjKJtmX95BZ3YIkwbI6DJUD8OiwIqjZC -pU0etSuELUL4m/bMrJllverZytOsampqXYsrDElc+kFQlQVnbj/CF1PV7jwBC6lU -VA3Ex+86o2QaMb5mNTCV0uBvibbRnxW2/4t5aX32kzoANwLqV2H8s3DdxvyLayAA -cI0y26jYn791Jg9KDqynIS+6oyLObUDt+T4gQcgNVPwqBcrm1tFlC6GVcOwiKVx8 -PTjrT4DeTOKStxAnd1sMkLGOEy/fMe3t4vx57TQ7f9ugMlT557B4JiRj3UxP6KUN -Y6jYbwIzU8h/o1NQMQLNwB7UAsE+JuJuMS3ljhYe4QKBgQDzD6efkpz2DIMEOqOb -1P0HhmgZI562+7iTDNA+fkRxOBSkLLlYQB+D3vciKmmblYxSVFpx4XLlWfxgnO4M -OMXcU0xWaBntTup9ztvGWgsIg0oVre5Kc7yk0hanrLIwG99JUDCEa5XHELwyIjwp -GmBhzVGosOj8XX8yeDuZjp6sfwKBgQDLfJvZadvMDbLwVgKZimkCqAEku81B7C7d -S5DkX//3xMUNS1DbdmcS2Pe2n5oQgaXe6Wut+qB23vb2adRFvQMgdeiR3JckmKV+ -BHDgpnHlPGf2W716/lsOCgOFK3QOkke3xhD4sIIjKp+vZC2OGyMn8Id4xNgYgxlU -Caxn693NmQKBgDqzyTQM4MO7+diHoQP9yK6Mk4+evrJK0SUSryiorjb56GJOOuVJ -d1MOAnnJ7H+a+qzXmpBudqVVulJLFGL3QzIXHBSyR7C8on2H/CRHkuqXaskZnLd2 -hFT6OGZ+mvprgN3f9BfHNAFD1W/2PLlgmW6Fe/dV4q8wlYvG5f4MJ95ZAoGBAIbA -CtirsEDtZWygGHKi5K08oh04PFGGXPZwnw+Mvw3NgQwvrujV/KXPhiKqiDScFkKK -YqNAj9iICBTfuhFAfHyXeB53bKNwbk8IE9PAhTXfjZzn3Y9ANv4SBYi/YMhxeAqr -n/t/r34oMLzN4xjywZImRx/jgpKg5jnvefsLh8MpAoGAbxvvJJIOo9JtG8LzQmoU -pAu7Xvem6keJ9o2Z1ZnySo2Ky2FlyW6mXkd3B9uaNlnlMDUhfSYxa9czQHZKHosk -gFLpLVFliVDOoSTquKCUyh/egnWbTnVPhB9VKGoZOnpxiuWWXGTWNjctpRRzze9y -dRodUMhcBh6+m28Vyn4g24g= ------END PRIVATE KEY----- ------BEGIN CERTIFICATE----- -MIIDXTCCAkWgAwIBAgIJAOZQ4fBuw3mhMA0GCSqGSIb3DQEBCwUAMEUxCzAJBgNV -BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX -aWRnaXRzIFB0eSBMdGQwHhcNMTcwNTAzMTAzMzE1WhcNMTcwNjAyMTAzMzE1WjBF -MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 -ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIB -CgKCAQEAwTO5eaHSSw294oB/uiSxm8mHkGxD+Md+fKjwDIJVmCvIE7hAQ73JbyxL -TBDRsKlOenUN2sxyhswsAl8OGYlX5lnuhohWzBVvo8OFYN4D3T5k810LrNVelpLv -AJ0YUxS2wWL/iKz/bbH1eQ4irJU1D3t5ie0k0tZ+vocf2flLeHiqdDguUtCs4UQk -bMO9Pc/sZ1ydyTFdFRpifVTb7EtvwQ+BS1XwNG6PH0DJrVaCLf8YnSGPK7OKz/Qa 
-VEzkLfsleOvTrr61aHIM4TyEmrrfohT9sAz6UjQjQwukzfqArMHhk5P+hDsd8rn7 -NqONiXpypyjVfdUevSCVRRhny5nK5wIDAQABo1AwTjAdBgNVHQ4EFgQUaRQN5nt+ -PwlxGaDHTiZh+n7zGHkwHwYDVR0jBBgwFoAUaRQN5nt+PwlxGaDHTiZh+n7zGHkw -DAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAmbPALAkVB2aWCj2rKgK8 -27OPJrds7ioTFkCTE++cb5id84zkM/fZQeak+8bF4Fg63Lo9EUkg+ZVPClyIXfWa -7GMI6eOHiQCaFTbKuh4RSVG7llRYdJe2XR5a5A5UXqqy0j05UNe4mA39lH1AHpvj -zlw5Q2ttUdzDARrIK7gdvqJxpcviY5U9op/MRt1CiblOv+dBKpOjUF+gkInw6SIT -Nq4zlCGUK6feBXwGuyWlOkIglhn7xPMWFNnhyAiDzQ3Au1dN0DWJ4ODAhNMHFyHl -gPNWfqu4uz5LbWaoPeb/NYw7KPPabMj6nKMAer1FGlYQr4YELZwImpvpQasU+6ES -zw== ------END CERTIFICATE----- diff --git a/tests/ssl/dhparam.pem b/tests/ssl/dhparam.pem deleted file mode 100644 index c4236900e..000000000 --- a/tests/ssl/dhparam.pem +++ /dev/null @@ -1,8 +0,0 @@ ------BEGIN DH PARAMETERS----- -MIIBCAKCAQEA29izdn3rdcj9JmNqoy2hyzJXugcbvnDoR58kLHJEc0Lcn2otTij+ -n9AvYhejuNya/8olpL3hPH5oWfYo6HfZZlzyl/m6erlqJuzbGj9JHJ+4l8fybnT+ -rBMlijgffOANglcJBZJMQ9LC53BJsuwHj7l7anYSjWHJm/13mUA19gDRI3/OCmyn -L6qOq8Iso5Xg1BBZxberUq3cVqcsZWkGTuQCTZFel63M4Li37aGwcBnucRPE1eE7 -ZIFWOUEclpgSsyF9Uk2BUFymHhtwurkASoGgJ+NgqU3+6eJYyaKlOUVUgtyNusl9 -G+BZZknn0G+R1Kn+RRCU2pp2go9Dv6SdCwIBAg== ------END DH PARAMETERS----- diff --git a/tests/ssl_test.py b/tests/ssl_test.py deleted file mode 100644 index 59d9859d3..000000000 --- a/tests/ssl_test.py +++ /dev/null @@ -1,29 +0,0 @@ -import pytest - - -@pytest.mark.asyncio -async def test_ssl_connection(create_connection, server, ssl_proxy): - ssl_port, ssl_ctx = ssl_proxy(server.tcp_address.port) - - conn = await create_connection(("localhost", ssl_port), ssl=ssl_ctx) - res = await conn.execute("ping") - assert res == b"PONG" - - -@pytest.mark.asyncio -async def test_ssl_redis(create_redis, server, ssl_proxy): - ssl_port, ssl_ctx = ssl_proxy(server.tcp_address.port) - - redis = await create_redis(("localhost", ssl_port), ssl=ssl_ctx) - res = await redis.ping() - assert res == b"PONG" - - -@pytest.mark.asyncio -async def test_ssl_pool(create_pool, server, ssl_proxy): - ssl_port, ssl_ctx = ssl_proxy(server.tcp_address.port) - - pool = await create_pool(("localhost", ssl_port), ssl=ssl_ctx) - with (await pool) as conn: - res = await conn.execute("PING") - assert res == b"PONG" diff --git a/tests/stream_commands_test.py b/tests/stream_commands_test.py deleted file mode 100644 index 4894beb6f..000000000 --- a/tests/stream_commands_test.py +++ /dev/null @@ -1,619 +0,0 @@ -import asyncio -from collections import OrderedDict -from unittest import mock - -import pytest - -from aioredis.commands.streams import parse_messages -from aioredis.errors import BusyGroupError -from tests.testutils import redis_version - -pytestmark = redis_version(5, 0, 0, reason="Streams only available since Redis 5.0.0") - - -async def add_message_with_sleep(redis, stream, fields): - await asyncio.sleep(0.2) - result = await redis.xadd(stream, fields) - return result - - -@pytest.mark.asyncio -async def test_xadd(redis, server_bin): - fields = OrderedDict( - ( - (b"field1", b"value1"), - (b"field2", b"value2"), - ) - ) - message_id = await redis.xadd("test_stream", fields) - - # Check the result is in the expected format (i.e: 1507400517949-0) - assert b"-" in message_id - timestamp, sequence = message_id.split(b"-") - assert timestamp.isdigit() - assert sequence.isdigit() - - # Read it back - messages = await redis.xrange("test_stream") - assert len(messages) == 1 - message = messages[0] - assert message[0] == message_id - assert message[1] == OrderedDict([(b"field1", b"value1"), 
(b"field2", b"value2")]) - - -@pytest.mark.asyncio -async def test_xadd_maxlen_exact(redis, server_bin): - message_id1 = await redis.xadd("test_stream", {"f1": "v1"}) - - # Ensure the millisecond-based message ID increments - await asyncio.sleep(0.001) - message_id2 = await redis.xadd("test_stream", {"f2": "v2"}) - await asyncio.sleep(0.001) - message_id3 = await redis.xadd( - "test_stream", {"f3": "v3"}, max_len=2, exact_len=True - ) - - # Read it back - messages = await redis.xrange("test_stream") - assert len(messages) == 2 - - message2 = messages[0] - message3 = messages[1] - - # The first message should no longer exist, just messages - # 2 and 3 remain - assert message2[0] == message_id2 - assert message2[1] == OrderedDict([(b"f2", b"v2")]) - - assert message3[0] == message_id3 - assert message3[1] == OrderedDict([(b"f3", b"v3")]) - - -@pytest.mark.asyncio -async def test_xadd_manual_message_ids(redis, server_bin): - await redis.xadd("test_stream", {"f1": "v1"}, message_id="1515958771000-0") - await redis.xadd("test_stream", {"f1": "v1"}, message_id="1515958771000-1") - await redis.xadd("test_stream", {"f1": "v1"}, message_id="1515958772000-0") - - messages = await redis.xrange("test_stream") - message_ids = [message_id for message_id, _ in messages] - assert message_ids == [b"1515958771000-0", b"1515958771000-1", b"1515958772000-0"] - - -@pytest.mark.asyncio -async def test_xadd_maxlen_inexact(redis, server_bin): - await redis.xadd("test_stream", {"f1": "v1"}) - # Ensure the millisecond-based message ID increments - await asyncio.sleep(0.001) - await redis.xadd("test_stream", {"f2": "v2"}) - await asyncio.sleep(0.001) - await redis.xadd("test_stream", {"f3": "v3"}, max_len=2, exact_len=False) - - # Read it back - messages = await redis.xrange("test_stream") - # Redis will not have removed the whole node yet - assert len(messages) == 3 - - # Check the stream is eventually truncated - for x in range(0, 1000): - await redis.xadd("test_stream", {"f": "v"}, max_len=2) - - messages = await redis.xrange("test_stream") - assert len(messages) < 1000 - - -@pytest.mark.asyncio -async def test_xrange(redis, server_bin): - stream = "test_stream" - fields = OrderedDict( - ( - (b"field1", b"value1"), - (b"field2", b"value2"), - ) - ) - message_id1 = await redis.xadd(stream, fields) - message_id2 = await redis.xadd(stream, fields) - message_id3 = await redis.xadd(stream, fields) - - # Test no parameters - messages = await redis.xrange(stream) - assert len(messages) == 3 - message = messages[0] - assert message[0] == message_id1 - assert message[1] == OrderedDict([(b"field1", b"value1"), (b"field2", b"value2")]) - - # Test start - messages = await redis.xrange(stream, start=message_id2) - assert len(messages) == 2 - - messages = await redis.xrange(stream, start="9900000000000-0") - assert len(messages) == 0 - - # Test stop - messages = await redis.xrange(stream, stop="0000000000000-0") - assert len(messages) == 0 - - messages = await redis.xrange(stream, stop=message_id2) - assert len(messages) == 2 - - messages = await redis.xrange(stream, stop="9900000000000-0") - assert len(messages) == 3 - - # Test start & stop - messages = await redis.xrange(stream, start=message_id1, stop=message_id2) - assert len(messages) == 2 - - messages = await redis.xrange( - stream, start="0000000000000-0", stop="9900000000000-0" - ) - assert len(messages) == 3 - - # Test count - messages = await redis.xrange(stream, count=2) - assert len(messages) == 2 - - -@pytest.mark.asyncio -async def test_xrevrange(redis, 
server_bin): - stream = "test_stream" - fields = OrderedDict( - ( - (b"field1", b"value1"), - (b"field2", b"value2"), - ) - ) - message_id1 = await redis.xadd(stream, fields) - message_id2 = await redis.xadd(stream, fields) - message_id3 = await redis.xadd(stream, fields) - - # Test no parameters - messages = await redis.xrevrange(stream) - assert len(messages) == 3 - message = messages[0] - assert message[0] == message_id3 - assert message[1] == OrderedDict([(b"field1", b"value1"), (b"field2", b"value2")]) - - # Test start - messages = await redis.xrevrange(stream, start=message_id2) - assert len(messages) == 2 - - messages = await redis.xrevrange(stream, start="9900000000000-0") - assert len(messages) == 3 - - # Test stop - messages = await redis.xrevrange(stream, stop="0000000000000-0") - assert len(messages) == 3 - - messages = await redis.xrevrange(stream, stop=message_id2) - assert len(messages) == 2 - - messages = await redis.xrevrange(stream, stop="9900000000000-0") - assert len(messages) == 0 - - # Test start & stop - messages = await redis.xrevrange(stream, start=message_id2, stop=message_id1) - assert len(messages) == 2 - - messages = await redis.xrevrange( - stream, start="9900000000000-0", stop="0000000000000-0" - ) - assert len(messages) == 3 - - # Test count - messages = await redis.xrevrange(stream, count=2) - assert len(messages) == 2 - - -@pytest.mark.asyncio -async def test_xread_selection(redis, server_bin): - """Test use of counts and starting IDs""" - stream = "test_stream" - fields = OrderedDict( - ( - (b"field1", b"value1"), - (b"field2", b"value2"), - ) - ) - message_id1 = await redis.xadd(stream, fields) - message_id2 = await redis.xadd(stream, fields) - message_id3 = await redis.xadd(stream, fields) - - messages = await redis.xread([stream], timeout=1, latest_ids=["0000000000000-0"]) - assert len(messages) == 3 - - messages = await redis.xread([stream], timeout=1, latest_ids=[message_id1]) - assert len(messages) == 2 - - messages = await redis.xread([stream], timeout=1, latest_ids=[message_id3]) - assert len(messages) == 0 - - messages = await redis.xread( - [stream], timeout=1, latest_ids=["0000000000000-0"], count=2 - ) - assert len(messages) == 2 - - -@pytest.mark.asyncio -async def test_xread_blocking(redis, create_redis, server, server_bin): - """Test the blocking read features""" - fields = OrderedDict( - ( - (b"field1", b"value1"), - (b"field2", b"value2"), - ) - ) - other_redis = await create_redis(server.tcp_address) - - # create blocking task in separate connection - consumer = other_redis.xread(["test_stream"], timeout=1000) - - producer_task = asyncio.Task(add_message_with_sleep(redis, "test_stream", fields)) - results = await asyncio.gather(consumer, producer_task) - - received_messages, sent_message_id = results - assert len(received_messages) == 1 - assert sent_message_id - - received_stream, received_message_id, received_fields = received_messages[0] - - assert received_stream == b"test_stream" - assert sent_message_id == received_message_id - assert fields == received_fields - - # Test that we get nothing back from an empty stream - results = await redis.xread(["another_stream"], timeout=100) - assert results == [] - - other_redis.close() - - -@pytest.mark.asyncio -async def test_xgroup_create(redis, server_bin): - # Also tests xinfo_groups() - await redis.xadd("test_stream", {"a": 1}) - await redis.xgroup_create("test_stream", "test_group") - info = await redis.xinfo_groups("test_stream") - assert info == [ - { - b"name": b"test_group", - 
b"last-delivered-id": mock.ANY, - b"pending": 0, - b"consumers": 0, - } - ] - - -@pytest.mark.asyncio -async def test_xgroup_create_mkstream(redis, server_bin): - await redis.xgroup_create("test_stream", "test_group", mkstream=True) - info = await redis.xinfo_groups("test_stream") - assert info == [ - { - b"name": b"test_group", - b"last-delivered-id": mock.ANY, - b"pending": 0, - b"consumers": 0, - } - ] - - -@pytest.mark.asyncio -async def test_xgroup_create_already_exists(redis, server_bin): - await redis.xadd("test_stream", {"a": 1}) - await redis.xgroup_create("test_stream", "test_group") - with pytest.raises(BusyGroupError): - await redis.xgroup_create("test_stream", "test_group") - - -@pytest.mark.asyncio -async def test_xgroup_setid(redis, server_bin): - await redis.xadd("test_stream", {"a": 1}) - await redis.xgroup_create("test_stream", "test_group") - await redis.xgroup_setid("test_stream", "test_group", "$") - - -@pytest.mark.asyncio -async def test_xgroup_destroy(redis, server_bin): - await redis.xadd("test_stream", {"a": 1}) - await redis.xgroup_create("test_stream", "test_group") - await redis.xgroup_destroy("test_stream", "test_group") - info = await redis.xinfo_groups("test_stream") - assert not info - - -@pytest.mark.asyncio -async def test_xread_group(redis): - await redis.xadd("test_stream", {"a": 1}) - await redis.xgroup_create("test_stream", "test_group", latest_id="0") - - # read all pending messages - messages = await redis.xread_group( - "test_group", "test_consumer", ["test_stream"], timeout=1000, latest_ids=[">"] - ) - assert len(messages) == 1 - stream, message_id, fields = messages[0] - assert stream == b"test_stream" - assert message_id - assert fields == {b"a": b"1"} - - -@pytest.mark.asyncio -async def test_xread_group_with_no_ack(redis): - await redis.xadd("test_stream", {"a": 1}) - await redis.xgroup_create("test_stream", "test_group", latest_id="0") - - # read all pending messages - messages = await redis.xread_group( - "test_group", - "test_consumer", - ["test_stream"], - timeout=1000, - latest_ids=[">"], - no_ack=True, - ) - assert len(messages) == 1 - stream, message_id, fields = messages[0] - assert stream == b"test_stream" - assert message_id - assert fields == {b"a": b"1"} - - -@pytest.mark.asyncio -async def test_xack_and_xpending(redis): - # Test a full xread -> xack cycle, using xpending to check the status - message_id = await redis.xadd("test_stream", {"a": 1}) - await redis.xgroup_create("test_stream", "test_group", latest_id="0") - - # Nothing pending as we haven't claimed anything yet - pending_count, min_id, max_id, count = await redis.xpending( - "test_stream", "test_group" - ) - assert pending_count == 0 - - # Read the message - await redis.xread_group( - "test_group", "test_consumer", ["test_stream"], timeout=1000, latest_ids=[">"] - ) - - # It is now pending - pending_count, min_id, max_id, pel = await redis.xpending( - "test_stream", "test_group" - ) - assert pending_count == 1 - assert min_id == message_id - assert max_id == message_id - assert pel == [[b"test_consumer", b"1"]] - - # Acknowledge the message - await redis.xack("test_stream", "test_group", message_id) - - # It is no longer pending - pending_count, min_id, max_id, pel = await redis.xpending( - "test_stream", "test_group" - ) - assert pending_count == 0 - - -@pytest.mark.asyncio -async def test_xpending_get_messages(redis): - # Like test_xack_and_xpending(), but using the start/end xpending() - # params to get the messages - message_id = await redis.xadd("test_stream", 
{"a": 1}) - await redis.xgroup_create("test_stream", "test_group", latest_id="0") - await redis.xread_group( - "test_group", "test_consumer", ["test_stream"], timeout=1000, latest_ids=[">"] - ) - await asyncio.sleep(0.05) - - # It is now pending - response = await redis.xpending("test_stream", "test_group", "-", "+", 10) - assert len(response) == 1 - ( - message_id, - consumer_name, - milliseconds_since_last_delivery, - num_deliveries, - ) = response[0] - - assert message_id - assert consumer_name == b"test_consumer" - assert milliseconds_since_last_delivery >= 50 - assert num_deliveries == 1 - - -@pytest.mark.asyncio -async def test_xpending_start_of_zero(redis): - await redis.xadd("test_stream", {"a": 1}) - await redis.xgroup_create("test_stream", "test_group", latest_id="0") - # Doesn't raise a value error - await redis.xpending("test_stream", "test_group", 0, "+", 10) - - -@pytest.mark.asyncio -async def test_xclaim_simple(redis): - # Put a message in a pending state then reclaim it is XCLAIM - message_id = await redis.xadd("test_stream", {"a": 1}) - await redis.xgroup_create("test_stream", "test_group", latest_id="0") - await redis.xread_group( - "test_group", "test_consumer", ["test_stream"], timeout=1000, latest_ids=[">"] - ) - - # Message is now pending - pending_count, min_id, max_id, pel = await redis.xpending( - "test_stream", "test_group" - ) - assert pending_count == 1 - assert pel == [[b"test_consumer", b"1"]] - - # Now claim it for another consumer - result = await redis.xclaim( - "test_stream", "test_group", "new_consumer", min_idle_time=0, id=message_id - ) - assert result - claimed_message_id, fields = result[0] - assert claimed_message_id == message_id - assert fields == {b"a": b"1"} - - # Ok, no see how things look - pending_count, min_id, max_id, pel = await redis.xpending( - "test_stream", "test_group" - ) - assert pending_count == 1 - assert pel == [[b"new_consumer", b"1"]] - - -@pytest.mark.asyncio -async def test_xclaim_min_idle_time_includes_messages(redis): - message_id = await redis.xadd("test_stream", {"a": 1}) - await redis.xgroup_create("test_stream", "test_group", latest_id="0") - await redis.xread_group( - "test_group", "test_consumer", ["test_stream"], timeout=1000, latest_ids=[">"] - ) - - # Message is now pending. Wait 100ms - await asyncio.sleep(0.1) - - # Now reclaim any messages which have been idle for > 50ms - result = await redis.xclaim( - "test_stream", "test_group", "new_consumer", min_idle_time=50, id=message_id - ) - assert result - - -@pytest.mark.asyncio -async def test_xclaim_min_idle_time_excludes_messages(redis): - message_id = await redis.xadd("test_stream", {"a": 1}) - await redis.xgroup_create("test_stream", "test_group", latest_id="0") - await redis.xread_group( - "test_group", "test_consumer", ["test_stream"], timeout=1000, latest_ids=[">"] - ) - # Message is now pending. Wait no time at all - - # Now reclaim any messages which have been idle for > 50ms - result = await redis.xclaim( - "test_stream", "test_group", "new_consumer", min_idle_time=50, id=message_id - ) - # Nothing to claim - assert not result - - -@pytest.mark.asyncio -async def test_xgroup_delconsumer(redis, create_redis, server): - await redis.xadd("test_stream", {"a": 1}) - await redis.xgroup_create("test_stream", "test_group") - - # Note that consumers are only created once they read a message, - # not when they first connect. 
So make sure we consume from ID 0 - # so we get the messages we just XADDed (above) - await redis.xread_group( - "test_group", "test_consumer", streams=["test_stream"], latest_ids=[0] - ) - - response = await redis.xgroup_delconsumer( - "test_stream", "test_group", "test_consumer" - ) - assert response == 0 - info = await redis.xinfo_consumers("test_stream", "test_group") - assert not info - - -@pytest.mark.asyncio -async def test_xdel_stream(redis): - message_id = await redis.xadd("test_stream", {"a": 1}) - response = await redis.xdel("test_stream", id=message_id) - assert response >= 0 - - -@pytest.mark.asyncio -async def test_xdel_multiple_ids_stream(redis): - message_id_1 = await redis.xadd("test_stream", {"a": 1}) - message_id_2 = await redis.xadd("test_stream", {"a": 2}) - message_id_3 = await redis.xadd("test_stream", {"a": 3}) - response = await redis.xdel("test_stream", message_id_1, message_id_2, message_id_3) - assert response >= 0 - - -@pytest.mark.asyncio -async def test_xtrim_stream(redis): - await redis.xadd("test_stream", {"a": 1}) - await redis.xadd("test_stream", {"b": 1}) - await redis.xadd("test_stream", {"c": 1}) - response = await redis.xtrim("test_stream", max_len=1, exact_len=False) - assert response >= 0 - - -@pytest.mark.asyncio -async def test_xlen_stream(redis): - await redis.xadd("test_stream", {"a": 1}) - response = await redis.xlen("test_stream") - assert response >= 0 - - -@pytest.mark.asyncio -async def test_xinfo_consumers(redis): - await redis.xadd("test_stream", {"a": 1}) - await redis.xgroup_create("test_stream", "test_group") - - # Note that consumers are only created once they read a message, - # not when they first connect. So make sure we consume from ID 0 - # so we get the messages we just XADDed (above) - await redis.xread_group( - "test_group", "test_consumer", streams=["test_stream"], latest_ids=[0] - ) - - info = await redis.xinfo_consumers("test_stream", "test_group") - assert info - assert isinstance(info[0], dict) - - -@pytest.mark.asyncio -async def test_xinfo_stream(redis): - await redis.xadd("test_stream", {"a": 1}) - await redis.xgroup_create("test_stream", "test_group") - - # Note that consumers are only created once they read a message, - # not when they first connect. 
So make sure we consume from ID 0 - # so we get the messages we just XADDed (above) - await redis.xread_group( - "test_group", "test_consumer", streams=["test_stream"], latest_ids=[0] - ) - - info = await redis.xinfo_stream("test_stream") - assert info - assert isinstance(info, dict) - - info = await redis.xinfo("test_stream") - assert info - assert isinstance(info, dict) - - -@pytest.mark.asyncio -async def test_xinfo_help(redis): - info = await redis.xinfo_help() - assert info - - -@pytest.mark.parametrize("param", [0.1, "1"]) -@pytest.mark.asyncio -async def test_xread_param_types(redis, param): - with pytest.raises(TypeError): - await redis.xread(["system_event_stream"], timeout=param, latest_ids=[0]) - - -def test_parse_messages_ok(): - message = [(b"123", [b"f1", b"v1", b"f2", b"v2"])] - assert parse_messages(message) == [(b"123", {b"f1": b"v1", b"f2": b"v2"})] - - -def test_parse_messages_null_fields(): - # Redis can sometimes respond with a fields value of 'null', - # so ensure we handle that sensibly - message = [(b"123", None)] - assert parse_messages(message) == [(b"123", OrderedDict())] - - -def test_parse_messages_null_message(): - # Redis can sometimes respond with a fields value of 'null', - # so ensure we handle that sensibly - message = [None] - assert parse_messages(message) == [(None, OrderedDict())] diff --git a/tests/stream_test.py b/tests/stream_test.py deleted file mode 100644 index 0153367c1..000000000 --- a/tests/stream_test.py +++ /dev/null @@ -1,46 +0,0 @@ -import pytest - -from aioredis.errors import ProtocolError, ReplyError -from aioredis.parser import PyReader -from aioredis.stream import StreamReader - - -@pytest.fixture -def reader(event_loop): - reader = StreamReader(loop=event_loop) - reader.set_parser(PyReader(protocolError=ProtocolError, replyError=ReplyError)) - return reader - - -@pytest.mark.asyncio -async def test_feed_and_parse(reader): - reader.feed_data(b"+PONG\r\n") - assert (await reader.readobj()) == b"PONG" - - -@pytest.mark.asyncio -async def test_buffer_available_after_RST(reader): - reader.feed_data(b"+PONG\r\n") - reader.set_exception(Exception()) - assert (await reader.readobj()) == b"PONG" - with pytest.raises(Exception): - await reader.readobj() - - -def test_feed_with_eof(reader): - reader.feed_eof() - with pytest.raises(AssertionError): - reader.feed_data(b"+PONG\r\n") - - -def test_feed_no_data(reader): - assert not reader.feed_data(None) - - -@pytest.mark.parametrize( - "read_method", ["read", "readline", "readuntil", "readexactly"] -) -@pytest.mark.asyncio -async def test_read_flavors_not_supported(reader, read_method): - with pytest.raises(RuntimeError): - await getattr(reader, read_method)() diff --git a/tests/string_commands_test.py b/tests/string_commands_test.py deleted file mode 100644 index d68f0f5a8..000000000 --- a/tests/string_commands_test.py +++ /dev/null @@ -1,707 +0,0 @@ -import asyncio - -import pytest - -from aioredis import ReplyError -from tests.testutils import redis_version - - -async def add(redis, key, value): - ok = await redis.set(key, value) - assert ok is True - - -@pytest.mark.asyncio -async def test_append(redis): - len_ = await redis.append("my-key", "Hello") - assert len_ == 5 - len_ = await redis.append("my-key", ", world!") - assert len_ == 13 - - val = await redis.connection.execute("GET", "my-key") - assert val == b"Hello, world!" 
- - with pytest.raises(TypeError): - await redis.append(None, "value") - with pytest.raises(TypeError): - await redis.append("none-key", None) - - -@pytest.mark.asyncio -async def test_bitcount(redis): - await add(redis, "my-key", b"\x00\x10\x01") - - ret = await redis.bitcount("my-key") - assert ret == 2 - ret = await redis.bitcount("my-key", 0, 0) - assert ret == 0 - ret = await redis.bitcount("my-key", 1, 1) - assert ret == 1 - ret = await redis.bitcount("my-key", 2, 2) - assert ret == 1 - ret = await redis.bitcount("my-key", 0, 1) - assert ret == 1 - ret = await redis.bitcount("my-key", 0, 2) - assert ret == 2 - ret = await redis.bitcount("my-key", 1, 2) - assert ret == 2 - ret = await redis.bitcount("my-key", 2, 3) - assert ret == 1 - ret = await redis.bitcount("my-key", 0, -1) - assert ret == 2 - - with pytest.raises(TypeError): - await redis.bitcount(None, 2, 2) - with pytest.raises(TypeError): - await redis.bitcount("my-key", None, 2) - with pytest.raises(TypeError): - await redis.bitcount("my-key", 2, None) - - -@pytest.mark.asyncio -async def test_bitop_and(redis): - key1, value1 = b"key:bitop:and:1", 5 - key2, value2 = b"key:bitop:and:2", 7 - - await add(redis, key1, value1) - await add(redis, key2, value2) - - destkey = b"key:bitop:dest" - - await redis.bitop_and(destkey, key1, key2) - test_value = await redis.get(destkey) - assert test_value == b"5" - - with pytest.raises(TypeError): - await redis.bitop_and(None, key1, key2) - with pytest.raises(TypeError): - await redis.bitop_and(destkey, None) - with pytest.raises(TypeError): - await redis.bitop_and(destkey, key1, None) - - -@pytest.mark.asyncio -async def test_bitop_or(redis): - key1, value1 = b"key:bitop:or:1", 5 - key2, value2 = b"key:bitop:or:2", 7 - - await add(redis, key1, value1) - await add(redis, key2, value2) - - destkey = b"key:bitop:dest" - - await redis.bitop_or(destkey, key1, key2) - test_value = await redis.get(destkey) - assert test_value == b"7" - - with pytest.raises(TypeError): - await redis.bitop_or(None, key1, key2) - with pytest.raises(TypeError): - await redis.bitop_or(destkey, None) - with pytest.raises(TypeError): - await redis.bitop_or(destkey, key1, None) - - -@pytest.mark.asyncio -async def test_bitop_xor(redis): - key1, value1 = b"key:bitop:xor:1", 5 - key2, value2 = b"key:bitop:xor:2", 7 - - await add(redis, key1, value1) - await add(redis, key2, value2) - - destkey = b"key:bitop:dest" - - await redis.bitop_xor(destkey, key1, key2) - test_value = await redis.get(destkey) - assert test_value == b"\x02" - - with pytest.raises(TypeError): - await redis.bitop_xor(None, key1, key2) - with pytest.raises(TypeError): - await redis.bitop_xor(destkey, None) - with pytest.raises(TypeError): - await redis.bitop_xor(destkey, key1, None) - - -@pytest.mark.asyncio -async def test_bitop_not(redis): - key1, value1 = b"key:bitop:not:1", 5 - await add(redis, key1, value1) - - destkey = b"key:bitop:dest" - - await redis.bitop_not(destkey, key1) - res = await redis.get(destkey) - assert res == b"\xca" - - with pytest.raises(TypeError): - await redis.bitop_not(None, key1) - with pytest.raises(TypeError): - await redis.bitop_not(destkey, None) - - -@redis_version(2, 8, 0, reason="BITPOS is available since redis>=2.8.0") -@pytest.mark.asyncio -async def test_bitpos(redis): - key, value = b"key:bitop", b"\xff\xf0\x00" - await add(redis, key, value) - test_value = await redis.bitpos(key, 0, end=3) - assert test_value == 12 - - test_value = await redis.bitpos(key, 0, 2, 3) - assert test_value == 16 - - key, value = 
b"key:bitop", b"\x00\xff\xf0" - await add(redis, key, value) - test_value = await redis.bitpos(key, 1, 0) - assert test_value == 8 - - test_value = await redis.bitpos(key, 1, 1) - assert test_value == 8 - - key, value = b"key:bitop", b"\x00\x00\x00" - await add(redis, key, value) - test_value = await redis.bitpos(key, 1, 0) - assert test_value == -1 - - test_value = await redis.bitpos(b"not:" + key, 1) - assert test_value == -1 - - with pytest.raises(TypeError): - test_value = await redis.bitpos(None, 1) - - with pytest.raises(ValueError): - test_value = await redis.bitpos(key, 7) - - -@pytest.mark.asyncio -async def test_decr(redis): - await redis.delete("key") - - res = await redis.decr("key") - assert res == -1 - res = await redis.decr("key") - assert res == -2 - - with pytest.raises(ReplyError): - await add(redis, "key", "val") - await redis.decr("key") - with pytest.raises(ReplyError): - await add(redis, "key", 1.0) - await redis.decr("key") - with pytest.raises(TypeError): - await redis.decr(None) - - -@pytest.mark.asyncio -async def test_decrby(redis): - await redis.delete("key") - - res = await redis.decrby("key", 1) - assert res == -1 - res = await redis.decrby("key", 10) - assert res == -11 - res = await redis.decrby("key", -1) - assert res == -10 - - with pytest.raises(ReplyError): - await add(redis, "key", "val") - await redis.decrby("key", 1) - with pytest.raises(ReplyError): - await add(redis, "key", 1.0) - await redis.decrby("key", 1) - with pytest.raises(TypeError): - await redis.decrby(None, 1) - with pytest.raises(TypeError): - await redis.decrby("key", None) - - -@pytest.mark.asyncio -async def test_get(redis): - await add(redis, "my-key", "value") - ret = await redis.get("my-key") - assert ret == b"value" - - await add(redis, "my-key", 123) - ret = await redis.get("my-key") - assert ret == b"123" - - ret = await redis.get("bad-key") - assert ret is None - - with pytest.raises(TypeError): - await redis.get(None) - - -@pytest.mark.asyncio -async def test_getbit(redis): - key, value = b"key:getbit", 10 - await add(redis, key, value) - - result = await redis.setbit(key, 7, 1) - assert result == 1 - - test_value = await redis.getbit(key, 0) - assert test_value == 0 - - test_value = await redis.getbit(key, 7) - assert test_value == 1 - - test_value = await redis.getbit(b"not:" + key, 7) - assert test_value == 0 - - test_value = await redis.getbit(key, 100) - assert test_value == 0 - - with pytest.raises(TypeError): - await redis.getbit(None, 0) - with pytest.raises(TypeError): - await redis.getbit(key, b"one") - with pytest.raises(ValueError): - await redis.getbit(key, -7) - - -@pytest.mark.asyncio -async def test_getrange(redis): - key, value = b"key:getrange", b"This is a string" - await add(redis, key, value) - - test_value = await redis.getrange(key, 0, 3) - assert test_value == b"This" - - test_value = await redis.getrange(key, -3, -1) - assert test_value == b"ing" - - test_value = await redis.getrange(key, 0, -1) - assert test_value == b"This is a string" - test_value = await redis.getrange(key, 0, -1, encoding="utf-8") - assert test_value == "This is a string" - - test_value = await redis.getrange(key, 10, 100) - assert test_value == b"string" - test_value = await redis.getrange(key, 10, 100, encoding="utf-8") - assert test_value == "string" - - test_value = await redis.getrange(key, 50, 100) - assert test_value == b"" - - with pytest.raises(TypeError): - await redis.getrange(None, 0, 3) - with pytest.raises(TypeError): - await redis.getrange(key, b"one", 3) - with 
pytest.raises(TypeError): - await redis.getrange(key, 0, b"seven") - - -@pytest.mark.asyncio -async def test_getset(redis): - key, value = b"key:getset", b"hello" - await add(redis, key, value) - - test_value = await redis.getset(key, b"asyncio") - assert test_value == b"hello" - - test_value = await redis.get(key) - assert test_value == b"asyncio" - - test_value = await redis.getset(key, "world", encoding="utf-8") - assert test_value == "asyncio" - - test_value = await redis.getset(b"not:" + key, b"asyncio") - assert test_value is None - - test_value = await redis.get(b"not:" + key) - assert test_value == b"asyncio" - - with pytest.raises(TypeError): - await redis.getset(None, b"asyncio") - - -@pytest.mark.asyncio -async def test_incr(redis): - await redis.delete("key") - - res = await redis.incr("key") - assert res == 1 - res = await redis.incr("key") - assert res == 2 - - with pytest.raises(ReplyError): - await add(redis, "key", "val") - await redis.incr("key") - with pytest.raises(ReplyError): - await add(redis, "key", 1.0) - await redis.incr("key") - with pytest.raises(TypeError): - await redis.incr(None) - - -@pytest.mark.asyncio -async def test_incrby(redis): - await redis.delete("key") - - res = await redis.incrby("key", 1) - assert res == 1 - res = await redis.incrby("key", 10) - assert res == 11 - res = await redis.incrby("key", -1) - assert res == 10 - - with pytest.raises(ReplyError): - await add(redis, "key", "val") - await redis.incrby("key", 1) - with pytest.raises(ReplyError): - await add(redis, "key", 1.0) - await redis.incrby("key", 1) - with pytest.raises(TypeError): - await redis.incrby(None, 1) - with pytest.raises(TypeError): - await redis.incrby("key", None) - - -@pytest.mark.asyncio -async def test_incrbyfloat(redis): - await redis.delete("key") - - res = await redis.incrbyfloat("key", 1.0) - assert res == 1.0 - res = await redis.incrbyfloat("key", 10.5) - assert res == 11.5 - res = await redis.incrbyfloat("key", -1.0) - assert res == 10.5 - await add(redis, "key", 2) - res = await redis.incrbyfloat("key", 0.5) - assert res == 2.5 - - with pytest.raises(ReplyError): - await add(redis, "key", "val") - await redis.incrbyfloat("key", 1.0) - with pytest.raises(TypeError): - await redis.incrbyfloat(None, 1.0) - with pytest.raises(TypeError): - await redis.incrbyfloat("key", None) - with pytest.raises(TypeError): - await redis.incrbyfloat("key", 1) - with pytest.raises(TypeError): - await redis.incrbyfloat("key", "1.0") - - -@pytest.mark.asyncio -async def test_mget(redis): - key1, value1 = b"foo", b"bar" - key2, value2 = b"baz", b"bzz" - await add(redis, key1, value1) - await add(redis, key2, value2) - - res = await redis.mget("key") - assert res == [None] - res = await redis.mget("key", "key") - assert res == [None, None] - - res = await redis.mget(key1, key2) - assert res == [value1, value2] - - # test encoding param - res = await redis.mget(key1, key2, encoding="utf-8") - assert res == ["bar", "bzz"] - - with pytest.raises(TypeError): - await redis.mget(None, key2) - with pytest.raises(TypeError): - await redis.mget(key1, None) - - -@pytest.mark.asyncio -async def test_mset(redis): - key1, value1 = b"key:mset:1", b"hello" - key2, value2 = b"key:mset:2", b"world" - - await redis.mset(key1, value1, key2, value2) - - test_value = await redis.mget(key1, key2) - assert test_value == [value1, value2] - - await redis.mset(b"other:" + key1, b"other:" + value1) - test_value = await redis.get(b"other:" + key1) - assert test_value == b"other:" + value1 - - with 
pytest.raises(TypeError): - await redis.mset(None, value1) - with pytest.raises(TypeError): - await redis.mset(key1, value1, key1) - - -@pytest.mark.asyncio -async def test_mset_with_dict(redis): - array = [str(n) for n in range(10)] - _dict = dict.fromkeys( - array, - "default value", - ) - - await redis.mset(_dict) - - test_values = await redis.mget(*_dict.keys()) - assert test_values == [str.encode(val) for val in _dict.values()] - - with pytest.raises(TypeError): - await redis.mset( - "param", - ) - - -@pytest.mark.asyncio -async def test_msetnx(redis): - key1, value1 = b"key:msetnx:1", b"Hello" - key2, value2 = b"key:msetnx:2", b"there" - key3, value3 = b"key:msetnx:3", b"world" - - res = await redis.msetnx(key1, value1, key2, value2) - assert res == 1 - res = await redis.mget(key1, key2) - assert res == [value1, value2] - res = await redis.msetnx(key2, value2, key3, value3) - assert res == 0 - res = await redis.mget(key1, key2, key3) - assert res == [value1, value2, None] - - with pytest.raises(TypeError): - await redis.msetnx(None, value1) - with pytest.raises(TypeError): - await redis.msetnx(key1, value1, key2) - - -@pytest.mark.asyncio -async def test_psetex(redis): - key, value = b"key:psetex:1", b"Hello" - # test expiration in milliseconds - tr = redis.multi_exec() - fut1 = tr.psetex(key, 10, value) - fut2 = tr.get(key) - await tr.execute() - await fut1 - test_value = await fut2 - assert test_value == value - - await asyncio.sleep(0.050) - test_value = await redis.get(key) - assert test_value is None - - with pytest.raises(TypeError): - await redis.psetex(None, 10, value) - with pytest.raises(TypeError): - await redis.psetex(key, 7.5, value) - - -@pytest.mark.asyncio -async def test_set(redis): - ok = await redis.set("my-key", "value") - assert ok is True - - ok = await redis.set(b"my-key", b"value") - assert ok is True - - ok = await redis.set(bytearray(b"my-key"), bytearray(b"value")) - assert ok is True - - with pytest.raises(TypeError): - await redis.set(None, "value") - - -@pytest.mark.asyncio -async def test_set_expire(redis): - key, value = b"key:set:expire", b"foo" - # test expiration in milliseconds - tr = redis.multi_exec() - fut1 = tr.set(key, value, pexpire=10) - fut2 = tr.get(key) - await tr.execute() - await fut1 - result_1 = await fut2 - assert result_1 == value - await asyncio.sleep(0.050) - result_2 = await redis.get(key) - assert result_2 is None - - # same thing but timeout in seconds - tr = redis.multi_exec() - fut1 = tr.set(key, value, expire=1) - fut2 = tr.get(key) - await tr.execute() - await fut1 - result_3 = await fut2 - assert result_3 == value - await asyncio.sleep(1.050) - result_4 = await redis.get(key) - assert result_4 is None - - -@pytest.mark.asyncio -async def test_set_only_if_not_exists(redis): - key, value = b"key:set:only_if_not_exists", b"foo" - await redis.set(key, value, exist=redis.SET_IF_NOT_EXIST) - result_1 = await redis.get(key) - assert result_1 == value - - # the new value is not set because the key already exists - await redis.set(key, "foo2", exist=redis.SET_IF_NOT_EXIST) - result_2 = await redis.get(key) - # nothing changed; the result is still "foo" - assert result_2 == value - - -@pytest.mark.asyncio -async def test_set_only_if_exists(redis): - key, value = b"key:set:only_if_exists", b"only_if_exists:foo" - # ensure that the key does not exist, so the value is not set - await redis.delete(key) - await redis.set(key, value, exist=redis.SET_IF_EXIST) - result_1 = await redis.get(key) - assert result_1 is None - - # ensure the key exists, so the value is updated - 
await redis.set(key, value) - await redis.set(key, b"foo", exist=redis.SET_IF_EXIST) - result_2 = await redis.get(key) - assert result_2 == b"foo" - - -@pytest.mark.asyncio -async def test_set_wrong_input(redis): - key, value = b"key:set:", b"foo" - - with pytest.raises(TypeError): - await redis.set(None, value) - with pytest.raises(TypeError): - await redis.set(key, value, expire=7.8) - with pytest.raises(TypeError): - await redis.set(key, value, pexpire=7.8) - - -@pytest.mark.asyncio -async def test_setbit(redis): - key = b"key:setbit" - result = await redis.setbit(key, 7, 1) - assert result == 0 - test_value = await redis.getbit(key, 7) - assert test_value == 1 - - with pytest.raises(TypeError): - await redis.setbit(None, 7, 1) - with pytest.raises(TypeError): - await redis.setbit(key, 7.5, 1) - with pytest.raises(ValueError): - await redis.setbit(key, -1, 1) - with pytest.raises(ValueError): - await redis.setbit(key, 1, 7) - - -@pytest.mark.asyncio -async def test_setex(redis): - key, value = b"key:setex:1", b"Hello" - tr = redis.multi_exec() - fut1 = tr.setex(key, 1, value) - fut2 = tr.get(key) - await tr.execute() - await fut1 - test_value = await fut2 - assert test_value == value - await asyncio.sleep(1.050) - test_value = await redis.get(key) - assert test_value is None - - tr = redis.multi_exec() - fut1 = tr.setex(key, 0.1, value) - fut2 = tr.get(key) - await tr.execute() - await fut1 - test_value = await fut2 - assert test_value == value - await asyncio.sleep(0.50) - test_value = await redis.get(key) - assert test_value is None - - with pytest.raises(TypeError): - await redis.setex(None, 1, value) - with pytest.raises(TypeError): - await redis.setex(key, b"one", value) - - -@pytest.mark.asyncio -async def test_setnx(redis): - key, value = b"key:setnx:1", b"Hello" - # set fresh new value - test_value = await redis.setnx(key, value) - # 1 means value has been set - assert test_value == 1 - # fetch installed value just to be sure - test_value = await redis.get(key) - assert test_value == value - # try to set new value on same key - test_value = await redis.setnx(key, b"other:" + value) - # 0 means value has not been set - assert test_value == 0 - # make sure that value was not changed - test_value = await redis.get(key) - assert test_value == value - - with pytest.raises(TypeError): - await redis.setnx(None, value) - - -@pytest.mark.asyncio -async def test_setrange(redis): - key, value = b"key:setrange", b"Hello World" - await add(redis, key, value) - test_value = await redis.setrange(key, 6, b"Redis") - assert test_value == 11 - test_value = await redis.get(key) - assert test_value == b"Hello Redis" - - test_value = await redis.setrange(b"not:" + key, 6, b"Redis") - assert test_value == 11 - test_value = await redis.get(b"not:" + key) - assert test_value == b"\x00\x00\x00\x00\x00\x00Redis" - - with pytest.raises(TypeError): - await redis.setrange(None, 6, b"Redis") - with pytest.raises(TypeError): - await redis.setrange(key, 0.7, b"Redis") - with pytest.raises(ValueError): - await redis.setrange(key, -1, b"Redis") - - -@pytest.mark.asyncio -async def test_strlen(redis): - key, value = b"key:strlen", b"asyncio" - await add(redis, key, value) - test_value = await redis.strlen(key) - assert test_value == len(value) - - test_value = await redis.strlen(b"not:" + key) - assert test_value == 0 - - with pytest.raises(TypeError): - await redis.strlen(None) - - -@pytest.mark.asyncio -async def test_cancel_hang(redis): - exists_coro = redis.execute("EXISTS", b"key:test1") - 
exists_coro.cancel() - exists_check = await redis.exists(b"key:test2") - assert not exists_check - - -@pytest.mark.asyncio -async def test_set_enc(create_redis, server): - redis = await create_redis(server.tcp_address, encoding="utf-8") - TEST_KEY = "my-key" - ok = await redis.set(TEST_KEY, "value") - assert ok is True - - with pytest.raises(TypeError): - await redis.set(None, "value") - - await redis.delete(TEST_KEY) diff --git a/tests/task_cancellation_test.py b/tests/task_cancellation_test.py deleted file mode 100644 index 871b44fa5..000000000 --- a/tests/task_cancellation_test.py +++ /dev/null @@ -1,20 +0,0 @@ -import asyncio - -import pytest - - -@pytest.mark.asyncio -async def test_future_cancellation(create_connection, event_loop, server): - conn = await create_connection(server.tcp_address) - - ts = event_loop.time() - fut = conn.execute("BLPOP", "some-list", 5) - with pytest.raises(asyncio.TimeoutError): - await asyncio.wait_for(fut, 1) - assert fut.cancelled() - - # NOTE: Connection becomes available only after timeout expires - await conn.execute("TIME") - dt = int(event_loop.time() - ts) - assert dt in {4, 5, 6} - # self.assertAlmostEqual(dt, 5.0, delta=1) # this fails too often diff --git a/tests/test_commands.py b/tests/test_commands.py new file mode 100644 index 000000000..a5ab258e9 --- /dev/null +++ b/tests/test_commands.py @@ -0,0 +1,3048 @@ +import binascii +import datetime +import re +import time +from string import ascii_letters + +import pytest + +import aioredis +from aioredis import exceptions +from aioredis.client import parse_info + +from .conftest import ( + REDIS_6_VERSION, + _get_client, + skip_if_server_version_gte, + skip_if_server_version_lt, + skip_unless_arch_bits, +) + +pytestmark = pytest.mark.asyncio + + +@pytest.fixture() +async def slowlog(r: aioredis.Redis, event_loop): + current_config = await r.config_get() + old_slower_than_value = current_config["slowlog-log-slower-than"] + old_max_length_value = current_config["slowlog-max-len"] + + await r.config_set("slowlog-log-slower-than", 0) + await r.config_set("slowlog-max-len", 128) + + yield + + await r.config_set("slowlog-log-slower-than", old_slower_than_value) + await r.config_set("slowlog-max-len", old_max_length_value) + + +async def redis_server_time(client: aioredis.Redis): + seconds, milliseconds = await client.time() + timestamp = float(f"{seconds}.{milliseconds}") + return datetime.datetime.fromtimestamp(timestamp) + + +async def get_stream_message(client: aioredis.Redis, stream: str, message_id: str): + """Fetch a stream message and format it as a (message_id, fields) pair""" + response = await client.xrange(stream, min=message_id, max=message_id) + assert len(response) == 1 + return response[0] + + +# RESPONSE CALLBACKS +class TestResponseCallbacks: + """Tests for the response callback system""" + + async def test_response_callbacks(self, r: aioredis.Redis): + assert r.response_callbacks == aioredis.Redis.RESPONSE_CALLBACKS + assert id(r.response_callbacks) != id(aioredis.Redis.RESPONSE_CALLBACKS) + r.set_response_callback("GET", lambda x: "static") + await r.set("a", "foo") + assert await r.get("a") == "static" + + async def test_case_insensitive_command_names(self, r: aioredis.Redis): + assert r.response_callbacks["del"] == r.response_callbacks["DEL"] + + +class TestRedisCommands: + async def test_command_on_invalid_key_type(self, r: aioredis.Redis): + await r.lpush("a", "1") + with pytest.raises(aioredis.ResponseError): + await r.get("a") + + # SERVER INFORMATION + 
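# The tests in this section exercise the ACL command family, which first shipped in + # Redis 6.0; each one is therefore gated on REDIS_6_VERSION. Every test provisions a + # throwaway "redis-py-user" and registers a pytest finalizer that deletes it again, + # so a failing assertion cannot leak ACL state into later tests. + 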
@skip_if_server_version_lt(REDIS_6_VERSION) + async def test_acl_cat_no_category(self, r: aioredis.Redis): + categories = await r.acl_cat() + assert isinstance(categories, list) + assert "read" in categories + + @skip_if_server_version_lt(REDIS_6_VERSION) + async def test_acl_cat_with_category(self, r: aioredis.Redis): + commands = await r.acl_cat("read") + assert isinstance(commands, list) + assert "get" in commands + + @skip_if_server_version_lt(REDIS_6_VERSION) + async def test_acl_deluser(self, r: aioredis.Redis, request, event_loop): + username = "redis-py-user" + + def teardown(): + coro = r.acl_deluser(username) + if event_loop.is_running(): + event_loop.create_task(coro) + else: + event_loop.run_until_complete(coro) + + request.addfinalizer(teardown) + + assert await r.acl_deluser(username) == 0 + assert await r.acl_setuser(username, enabled=False, reset=True) + assert await r.acl_deluser(username) == 1 + + @skip_if_server_version_lt(REDIS_6_VERSION) + async def test_acl_genpass(self, r: aioredis.Redis): + password = await r.acl_genpass() + assert isinstance(password, str) + + @skip_if_server_version_lt(REDIS_6_VERSION) + async def test_acl_getuser_setuser(self, r: aioredis.Redis, request, event_loop): + username = "redis-py-user" + + def teardown(): + coro = r.acl_deluser(username) + if event_loop.is_running(): + event_loop.create_task(coro) + else: + event_loop.run_until_complete(coro) + + request.addfinalizer(teardown) + + # test enabled=False + assert await r.acl_setuser(username, enabled=False, reset=True) + assert await r.acl_getuser(username) == { + "categories": ["-@all"], + "commands": [], + "enabled": False, + "flags": ["off"], + "keys": [], + "passwords": [], + } + + # test nopass=True + assert await r.acl_setuser(username, enabled=True, reset=True, nopass=True) + assert await r.acl_getuser(username) == { + "categories": ["-@all"], + "commands": [], + "enabled": True, + "flags": ["on", "nopass"], + "keys": [], + "passwords": [], + } + + # test all args + assert await r.acl_setuser( + username, + enabled=True, + reset=True, + passwords=["+pass1", "+pass2"], + categories=["+set", "+@hash", "-geo"], + commands=["+get", "+mget", "-hset"], + keys=["cache:*", "objects:*"], + ) + acl = await r.acl_getuser(username) + assert set(acl["categories"]) == {"-@all", "+@set", "+@hash"} + assert set(acl["commands"]) == {"+get", "+mget", "-hset"} + assert acl["enabled"] is True + assert acl["flags"] == ["on"] + assert set(acl["keys"]) == {b"cache:*", b"objects:*"} + assert len(acl["passwords"]) == 2 + + # test reset=False keeps existing ACL and applies new ACL on top + assert await r.acl_setuser( + username, + enabled=True, + reset=True, + passwords=["+pass1"], + categories=["+@set"], + commands=["+get"], + keys=["cache:*"], + ) + assert await r.acl_setuser( + username, + enabled=True, + passwords=["+pass2"], + categories=["+@hash"], + commands=["+mget"], + keys=["objects:*"], + ) + acl = await r.acl_getuser(username) + assert set(acl["categories"]) == {"-@all", "+@set", "+@hash"} + assert set(acl["commands"]) == {"+get", "+mget"} + assert acl["enabled"] is True + assert acl["flags"] == ["on"] + assert set(acl["keys"]) == {b"cache:*", b"objects:*"} + assert len(acl["passwords"]) == 2 + + # test removal of passwords + assert await r.acl_setuser( + username, enabled=True, reset=True, passwords=["+pass1", "+pass2"] + ) + assert len((await r.acl_getuser(username))["passwords"]) == 2 + assert await r.acl_setuser(username, enabled=True, passwords=["-pass2"]) + assert len((await 
r.acl_getuser(username))["passwords"]) == 1 + + # Resets and tests that hashed passwords are set properly. + hashed_password = ( + "5e884898da28047151d0e56f8dc629" "2773603d0d6aabbdd62a11ef721d1542d8" + ) + assert await r.acl_setuser( + username, enabled=True, reset=True, hashed_passwords=["+" + hashed_password] + ) + acl = await r.acl_getuser(username) + assert acl["passwords"] == [hashed_password] + + # test removal of hashed passwords + assert await r.acl_setuser( + username, + enabled=True, + reset=True, + hashed_passwords=["+" + hashed_password], + passwords=["+pass1"], + ) + assert len((await r.acl_getuser(username))["passwords"]) == 2 + assert await r.acl_setuser( + username, enabled=True, hashed_passwords=["-" + hashed_password] + ) + assert len((await r.acl_getuser(username))["passwords"]) == 1 + + @skip_if_server_version_lt(REDIS_6_VERSION) + async def test_acl_list(self, r: aioredis.Redis, request, event_loop): + username = "redis-py-user" + + def teardown(): + coro = r.acl_deluser(username) + if event_loop.is_running(): + event_loop.create_task(coro) + else: + event_loop.run_until_complete(coro) + + request.addfinalizer(teardown) + + assert await r.acl_setuser(username, enabled=False, reset=True) + users = await r.acl_list() + assert "user %s off -@all" % username in users + + @skip_if_server_version_lt(REDIS_6_VERSION) + async def test_acl_log(self, r: aioredis.Redis, request, event_loop): + username = "redis-py-user" + + def teardown(): + coro = r.acl_deluser(username) + if event_loop.is_running(): + event_loop.create_task(coro) + else: + event_loop.run_until_complete(coro) + + request.addfinalizer(teardown) + await r.acl_setuser( + username, + enabled=True, + reset=True, + commands=["+get", "+set", "+select"], + keys=["cache:*"], + nopass=True, + ) + await r.acl_log_reset() + + user_client = await _get_client( + aioredis.Redis, request, event_loop, flushdb=False, username=username + ) + + # Valid operation and key + assert await user_client.set("cache:0", 1) + assert await user_client.get("cache:0") == b"1" + + # Invalid key + with pytest.raises(exceptions.NoPermissionError): + await user_client.get("violated_cache:0") + + # Invalid operation + with pytest.raises(exceptions.NoPermissionError): + await user_client.hset("cache:0", "hkey", "hval") + + assert isinstance(await r.acl_log(), list) + assert len(await r.acl_log()) == 2 + assert len(await r.acl_log(count=1)) == 1 + assert isinstance((await r.acl_log())[0], dict) + assert "client-info" in (await r.acl_log(count=1))[0] + assert await r.acl_log_reset() + + @skip_if_server_version_lt(REDIS_6_VERSION) + async def test_acl_setuser_categories_without_prefix_fails( + self, r: aioredis.Redis, request, event_loop + ): + username = "redis-py-user" + + def teardown(): + coro = r.acl_deluser(username) + if event_loop.is_running(): + event_loop.create_task(coro) + else: + event_loop.run_until_complete(coro) + + request.addfinalizer(teardown) + + with pytest.raises(exceptions.DataError): + await r.acl_setuser(username, categories=["list"]) + + @skip_if_server_version_lt(REDIS_6_VERSION) + async def test_acl_setuser_commands_without_prefix_fails( + self, r: aioredis.Redis, request, event_loop + ): + username = "redis-py-user" + + def teardown(): + coro = r.acl_deluser(username) + if event_loop.is_running(): + event_loop.create_task(coro) + else: + event_loop.run_until_complete(coro) + + request.addfinalizer(teardown) + + with pytest.raises(exceptions.DataError): + await r.acl_setuser(username, commands=["get"]) + + 
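# These option-validation failures are raised client-side as DataError, before any + # command reaches the server: category and command selectors must carry an explicit + # "+" or "-" prefix, and, as the next test shows, literal passwords cannot be + # combined with nopass. + 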
@skip_if_server_version_lt(REDIS_6_VERSION) + async def test_acl_setuser_add_passwords_and_nopass_fails( + self, r: aioredis.Redis, request, event_loop + ): + username = "redis-py-user" + + def teardown(): + coro = r.acl_deluser(username) + if event_loop.is_running(): + event_loop.create_task(coro) + else: + event_loop.run_until_complete(coro) + + request.addfinalizer(teardown) + + with pytest.raises(exceptions.DataError): + await r.acl_setuser(username, passwords="+mypass", nopass=True) + + @skip_if_server_version_lt(REDIS_6_VERSION) + async def test_acl_users(self, r: aioredis.Redis): + users = await r.acl_users() + assert isinstance(users, list) + assert len(users) > 0 + + @skip_if_server_version_lt(REDIS_6_VERSION) + async def test_acl_whoami(self, r: aioredis.Redis): + username = await r.acl_whoami() + assert isinstance(username, str) + + async def test_client_list(self, r: aioredis.Redis): + clients = await r.client_list() + assert isinstance(clients[0], dict) + assert "addr" in clients[0] + + @skip_if_server_version_lt("5.0.0") + async def test_client_list_type(self, r: aioredis.Redis): + with pytest.raises(exceptions.RedisError): + await r.client_list(_type="not a client type") + for client_type in ["normal", "master", "replica", "pubsub"]: + clients = await r.client_list(_type=client_type) + assert isinstance(clients, list) + + @skip_if_server_version_lt("5.0.0") + async def test_client_id(self, r: aioredis.Redis): + assert await r.client_id() > 0 + + @skip_if_server_version_lt("5.0.0") + async def test_client_unblock(self, r: aioredis.Redis): + myid = await r.client_id() + assert not await r.client_unblock(myid) + assert not await r.client_unblock(myid, error=True) + assert not await r.client_unblock(myid, error=False) + + @skip_if_server_version_lt("2.6.9") + async def test_client_getname(self, r: aioredis.Redis): + assert await r.client_getname() is None + + @skip_if_server_version_lt("2.6.9") + async def test_client_setname(self, r: aioredis.Redis): + assert await r.client_setname("redis_py_test") + assert await r.client_getname() == "redis_py_test" + + @skip_if_server_version_lt("2.6.9") + async def test_client_kill(self, r: aioredis.Redis, r2): + await r.client_setname("redis-py-c1") + await r2.client_setname("redis-py-c2") + clients = [ + client + for client in await r.client_list() + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] + assert len(clients) == 2 + + clients_by_name = {client.get("name"): client for client in clients} + + client_addr = clients_by_name["redis-py-c2"].get("addr") + assert await r.client_kill(client_addr) is True + + clients = [ + client + for client in await r.client_list() + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] + assert len(clients) == 1 + assert clients[0].get("name") == "redis-py-c1" + + @skip_if_server_version_lt("2.8.12") + async def test_client_kill_filter_invalid_params(self, r: aioredis.Redis): + # empty + with pytest.raises(exceptions.DataError): + await r.client_kill_filter() + + # invalid skipme + with pytest.raises(exceptions.DataError): + await r.client_kill_filter(skipme="yeah") + + # invalid type + with pytest.raises(exceptions.DataError): + await r.client_kill_filter(_type="caster") + + @skip_if_server_version_lt("2.8.12") + async def test_client_kill_filter_by_id(self, r: aioredis.Redis, r2): + await r.client_setname("redis-py-c1") + await r2.client_setname("redis-py-c2") + clients = [ + client + for client in await r.client_list() + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] + 
assert len(clients) == 2 + + clients_by_name = {client.get("name"): client for client in clients} + + client_2_id = clients_by_name["redis-py-c2"].get("id") + resp = await r.client_kill_filter(_id=client_2_id) + assert resp == 1 + + clients = [ + client + for client in await r.client_list() + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] + assert len(clients) == 1 + assert clients[0].get("name") == "redis-py-c1" + + @skip_if_server_version_lt("2.8.12") + async def test_client_kill_filter_by_addr(self, r: aioredis.Redis, r2): + await r.client_setname("redis-py-c1") + await r2.client_setname("redis-py-c2") + clients = [ + client + for client in await r.client_list() + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] + assert len(clients) == 2 + + clients_by_name = {client.get("name"): client for client in clients} + + client_2_addr = clients_by_name["redis-py-c2"].get("addr") + resp = await r.client_kill_filter(addr=client_2_addr) + assert resp == 1 + + clients = [ + client + for client in await r.client_list() + if client.get("name") in ["redis-py-c1", "redis-py-c2"] + ] + assert len(clients) == 1 + assert clients[0].get("name") == "redis-py-c1" + + @skip_if_server_version_lt("2.6.9") + async def test_client_list_after_client_setname(self, r: aioredis.Redis): + await r.client_setname("redis_py_test") + clients = await r.client_list() + # we don't know which client ours will be + assert "redis_py_test" in [c["name"] for c in clients] + + @skip_if_server_version_lt("2.9.50") + async def test_client_pause(self, r: aioredis.Redis): + assert await r.client_pause(1) + assert await r.client_pause(timeout=1) + with pytest.raises(exceptions.RedisError): + await r.client_pause(timeout="not an integer") + + async def test_config_get(self, r: aioredis.Redis): + data = await r.config_get() + assert "maxmemory" in data + assert data["maxmemory"].isdigit() + + async def test_config_resetstat(self, r: aioredis.Redis): + await r.ping() + prior_commands_processed = int((await r.info())["total_commands_processed"]) + assert prior_commands_processed >= 1 + await r.config_resetstat() + reset_commands_processed = int((await r.info())["total_commands_processed"]) + assert reset_commands_processed < prior_commands_processed + + async def test_config_set(self, r: aioredis.Redis): + data = await r.config_get() + rdbname = data["dbfilename"] + try: + assert await r.config_set("dbfilename", "redis_py_test.rdb") + assert (await r.config_get())["dbfilename"] == "redis_py_test.rdb" + finally: + assert await r.config_set("dbfilename", rdbname) + + async def test_dbsize(self, r: aioredis.Redis): + await r.set("a", "foo") + await r.set("b", "bar") + assert await r.dbsize() == 2 + + async def test_echo(self, r: aioredis.Redis): + assert await r.echo("foo bar") == b"foo bar" + + async def test_info(self, r: aioredis.Redis): + await r.set("a", "foo") + await r.set("b", "bar") + info = await r.info() + assert isinstance(info, dict) + assert info["db9"]["keys"] == 2 + + async def test_lastsave(self, r: aioredis.Redis): + assert isinstance(await r.lastsave(), datetime.datetime) + + async def test_object(self, r: aioredis.Redis): + await r.set("a", "foo") + assert isinstance(await r.object("refcount", "a"), int) + assert isinstance(await r.object("idletime", "a"), int) + assert await r.object("encoding", "a") in (b"raw", b"embstr") + assert await r.object("idletime", "invalid-key") is None + + async def test_ping(self, r: aioredis.Redis): + assert await r.ping() + + async def test_slowlog_get(self, r: 
aioredis.Redis, slowlog): + assert await r.slowlog_reset() + unicode_string = chr(3456) + "abcd" + chr(3421) + await r.get(unicode_string) + slowlog = await r.slowlog_get() + assert isinstance(slowlog, list) + commands = [log["command"] for log in slowlog] + + get_command = b" ".join((b"GET", unicode_string.encode("utf-8"))) + assert get_command in commands + assert b"SLOWLOG RESET" in commands + # the order should be ['GET ', 'SLOWLOG RESET'], + # but if other clients are executing commands at the same time, there + # could be commands before, between, or after, so just check that + # the two we care about are in the appropriate order. + assert commands.index(get_command) < commands.index(b"SLOWLOG RESET") + + # make sure other attributes are typed correctly + assert isinstance(slowlog[0]["start_time"], int) + assert isinstance(slowlog[0]["duration"], int) + + async def test_slowlog_get_limit(self, r: aioredis.Redis, slowlog): + assert await r.slowlog_reset() + await r.get("foo") + slowlog = await r.slowlog_get(1) + assert isinstance(slowlog, list) + # only one command, based on the number we passed to slowlog_get() + assert len(slowlog) == 1 + + async def test_slowlog_length(self, r: aioredis.Redis, slowlog): + await r.get("foo") + assert isinstance(await r.slowlog_len(), int) + + @skip_if_server_version_lt("2.6.0") + async def test_time(self, r: aioredis.Redis): + t = await r.time() + assert len(t) == 2 + assert isinstance(t[0], int) + assert isinstance(t[1], int) + + # BASIC KEY COMMANDS + async def test_append(self, r: aioredis.Redis): + assert await r.append("a", "a1") == 2 + assert await r.get("a") == b"a1" + assert await r.append("a", "a2") == 4 + assert await r.get("a") == b"a1a2" + + @skip_if_server_version_lt("2.6.0") + async def test_bitcount(self, r: aioredis.Redis): + await r.setbit("a", 5, True) + assert await r.bitcount("a") == 1 + await r.setbit("a", 6, True) + assert await r.bitcount("a") == 2 + await r.setbit("a", 5, False) + assert await r.bitcount("a") == 1 + await r.setbit("a", 9, True) + await r.setbit("a", 17, True) + await r.setbit("a", 25, True) + await r.setbit("a", 33, True) + assert await r.bitcount("a") == 5 + assert await r.bitcount("a", 0, -1) == 5 + assert await r.bitcount("a", 2, 3) == 2 + assert await r.bitcount("a", 2, -1) == 3 + assert await r.bitcount("a", -2, -1) == 2 + assert await r.bitcount("a", 1, 1) == 1 + + @skip_if_server_version_lt("2.6.0") + async def test_bitop_not_empty_string(self, r: aioredis.Redis): + await r.set("a", "") + await r.bitop("not", "r", "a") + assert await r.get("r") is None + + @skip_if_server_version_lt("2.6.0") + async def test_bitop_not(self, r: aioredis.Redis): + test_str = b"\xAA\x00\xFF\x55" + correct = ~0xAA00FF55 & 0xFFFFFFFF + await r.set("a", test_str) + await r.bitop("not", "r", "a") + assert int(binascii.hexlify(await r.get("r")), 16) == correct + + @skip_if_server_version_lt("2.6.0") + async def test_bitop_not_in_place(self, r: aioredis.Redis): + test_str = b"\xAA\x00\xFF\x55" + correct = ~0xAA00FF55 & 0xFFFFFFFF + await r.set("a", test_str) + await r.bitop("not", "a", "a") + assert int(binascii.hexlify(await r.get("a")), 16) == correct + + @skip_if_server_version_lt("2.6.0") + async def test_bitop_single_string(self, r: aioredis.Redis): + test_str = b"\x01\x02\xFF" + await r.set("a", test_str) + await r.bitop("and", "res1", "a") + await r.bitop("or", "res2", "a") + await r.bitop("xor", "res3", "a") + assert await r.get("res1") == test_str + assert await r.get("res2") == test_str + assert await r.get("res3") 
== test_str + + @skip_if_server_version_lt("2.6.0") + async def test_bitop_string_operands(self, r: aioredis.Redis): + await r.set("a", b"\x01\x02\xFF\xFF") + await r.set("b", b"\x01\x02\xFF") + await r.bitop("and", "res1", "a", "b") + await r.bitop("or", "res2", "a", "b") + await r.bitop("xor", "res3", "a", "b") + assert int(binascii.hexlify(await r.get("res1")), 16) == 0x0102FF00 + assert int(binascii.hexlify(await r.get("res2")), 16) == 0x0102FFFF + assert int(binascii.hexlify(await r.get("res3")), 16) == 0x000000FF + + @skip_if_server_version_lt("2.8.7") + async def test_bitpos(self, r: aioredis.Redis): + key = "key:bitpos" + await r.set(key, b"\xff\xf0\x00") + assert await r.bitpos(key, 0) == 12 + assert await r.bitpos(key, 0, 2, -1) == 16 + assert await r.bitpos(key, 0, -2, -1) == 12 + await r.set(key, b"\x00\xff\xf0") + assert await r.bitpos(key, 1, 0) == 8 + assert await r.bitpos(key, 1, 1) == 8 + await r.set(key, b"\x00\x00\x00") + assert await r.bitpos(key, 1) == -1 + + @skip_if_server_version_lt("2.8.7") + async def test_bitpos_wrong_arguments(self, r: aioredis.Redis): + key = "key:bitpos:wrong:args" + await r.set(key, b"\xff\xf0\x00") + with pytest.raises(exceptions.RedisError): + await r.bitpos(key, 0, end=1) == 12 + with pytest.raises(exceptions.RedisError): + await r.bitpos(key, 7) == 12 + + async def test_decr(self, r: aioredis.Redis): + assert await r.decr("a") == -1 + assert await r.get("a") == b"-1" + assert await r.decr("a") == -2 + assert await r.get("a") == b"-2" + assert await r.decr("a", amount=5) == -7 + assert await r.get("a") == b"-7" + + async def test_decrby(self, r: aioredis.Redis): + assert await r.decrby("a", amount=2) == -2 + assert await r.decrby("a", amount=3) == -5 + assert await r.get("a") == b"-5" + + async def test_delete(self, r: aioredis.Redis): + assert await r.delete("a") == 0 + await r.set("a", "foo") + assert await r.delete("a") == 1 + + async def test_delete_with_multiple_keys(self, r: aioredis.Redis): + await r.set("a", "foo") + await r.set("b", "bar") + assert await r.delete("a", "b") == 2 + assert await r.get("a") is None + assert await r.get("b") is None + + async def test_delitem(self, r: aioredis.Redis): + await r.set("a", "foo") + await r.delete("a") + assert await r.get("a") is None + + @skip_if_server_version_lt("4.0.0") + async def test_unlink(self, r: aioredis.Redis): + assert await r.unlink("a") == 0 + await r.set("a", "foo") + assert await r.unlink("a") == 1 + assert await r.get("a") is None + + @skip_if_server_version_lt("4.0.0") + async def test_unlink_with_multiple_keys(self, r: aioredis.Redis): + await r.set("a", "foo") + await r.set("b", "bar") + assert await r.unlink("a", "b") == 2 + assert await r.get("a") is None + assert await r.get("b") is None + + @skip_if_server_version_lt("2.6.0") + async def test_dump_and_restore(self, r: aioredis.Redis): + await r.set("a", "foo") + dumped = await r.dump("a") + await r.delete("a") + await r.restore("a", 0, dumped) + assert await r.get("a") == b"foo" + + @skip_if_server_version_lt("3.0.0") + async def test_dump_and_restore_and_replace(self, r: aioredis.Redis): + await r.set("a", "bar") + dumped = await r.dump("a") + with pytest.raises(aioredis.ResponseError): + await r.restore("a", 0, dumped) + + await r.restore("a", 0, dumped, replace=True) + assert await r.get("a") == b"bar" + + @skip_if_server_version_lt("5.0.0") + async def test_dump_and_restore_absttl(self, r: aioredis.Redis): + await r.set("a", "foo") + dumped = await r.dump("a") + await r.delete("a") + ttl = int( + (await 
redis_server_time(r) + datetime.timedelta(minutes=1)).timestamp() + * 1000 + ) + await r.restore("a", ttl, dumped, absttl=True) + assert await r.get("a") == b"foo" + assert 0 < await r.ttl("a") <= 61 + + async def test_exists(self, r: aioredis.Redis): + assert await r.exists("a") == 0 + await r.set("a", "foo") + await r.set("b", "bar") + assert await r.exists("a") == 1 + assert await r.exists("a", "b") == 2 + + async def test_exists_contains(self, r: aioredis.Redis): + assert not await r.exists("a") + await r.set("a", "foo") + assert await r.exists("a") + + async def test_expire(self, r: aioredis.Redis): + assert not await r.expire("a", 10) + await r.set("a", "foo") + assert await r.expire("a", 10) + assert 0 < await r.ttl("a") <= 10 + assert await r.persist("a") + assert await r.ttl("a") == -1 + + async def test_expireat_datetime(self, r: aioredis.Redis): + expire_at = await redis_server_time(r) + datetime.timedelta(minutes=1) + await r.set("a", "foo") + assert await r.expireat("a", expire_at) + assert 0 < await r.ttl("a") <= 61 + + async def test_expireat_no_key(self, r: aioredis.Redis): + expire_at = await redis_server_time(r) + datetime.timedelta(minutes=1) + assert not await r.expireat("a", expire_at) + + async def test_expireat_unixtime(self, r: aioredis.Redis): + expire_at = await redis_server_time(r) + datetime.timedelta(minutes=1) + await r.set("a", "foo") + expire_at_seconds = int(time.mktime(expire_at.timetuple())) + assert await r.expireat("a", expire_at_seconds) + assert 0 < await r.ttl("a") <= 61 + + async def test_get_and_set(self, r: aioredis.Redis): + # get and set can't be tested independently of each other + assert await r.get("a") is None + byte_string = b"value" + integer = 5 + unicode_string = chr(3456) + "abcd" + chr(3421) + assert await r.set("byte_string", byte_string) + assert await r.set("integer", 5) + assert await r.set("unicode_string", unicode_string) + assert await r.get("byte_string") == byte_string + assert await r.get("integer") == str(integer).encode() + assert (await r.get("unicode_string")).decode("utf-8") == unicode_string + + async def test_get_set_bit(self, r: aioredis.Redis): + # no value + assert not await r.getbit("a", 5) + # set bit 5 + assert not await r.setbit("a", 5, True) + assert await r.getbit("a", 5) + # unset bit 4 + assert not await r.setbit("a", 4, False) + assert not await r.getbit("a", 4) + # set bit 4 + assert not await r.setbit("a", 4, True) + assert await r.getbit("a", 4) + # set bit 5 again + assert await r.setbit("a", 5, True) + assert await r.getbit("a", 5) + + async def test_getrange(self, r: aioredis.Redis): + await r.set("a", "foo") + assert await r.getrange("a", 0, 0) == b"f" + assert await r.getrange("a", 0, 2) == b"foo" + assert await r.getrange("a", 3, 4) == b"" + + async def test_getset(self, r: aioredis.Redis): + assert await r.getset("a", "foo") is None + assert await r.getset("a", "bar") == b"foo" + assert await r.get("a") == b"bar" + + async def test_incr(self, r: aioredis.Redis): + assert await r.incr("a") == 1 + assert await r.get("a") == b"1" + assert await r.incr("a") == 2 + assert await r.get("a") == b"2" + assert await r.incr("a", amount=5) == 7 + assert await r.get("a") == b"7" + + async def test_incrby(self, r: aioredis.Redis): + assert await r.incrby("a") == 1 + assert await r.incrby("a", 4) == 5 + assert await r.get("a") == b"5" + + @skip_if_server_version_lt("2.6.0") + async def test_incrbyfloat(self, r: aioredis.Redis): + assert await r.incrbyfloat("a") == 1.0 + assert await r.get("a") == b"1" + assert 
await r.incrbyfloat("a", 1.1) == 2.1 + assert float(await r.get("a")) == float(2.1) + + async def test_keys(self, r: aioredis.Redis): + assert await r.keys() == [] + keys_with_underscores = {b"test_a", b"test_b"} + keys = keys_with_underscores.union({b"testc"}) + for key in keys: + await r.set(key, 1) + assert set(await r.keys(pattern="test_*")) == keys_with_underscores + assert set(await r.keys(pattern="test*")) == keys + + async def test_mget(self, r: aioredis.Redis): + assert await r.mget([]) == [] + assert await r.mget(["a", "b"]) == [None, None] + await r.set("a", "1") + await r.set("b", "2") + await r.set("c", "3") + assert await r.mget("a", "other", "b", "c") == [b"1", None, b"2", b"3"] + + async def test_mset(self, r: aioredis.Redis): + d = {"a": b"1", "b": b"2", "c": b"3"} + assert await r.mset(d) + for k, v in d.items(): + assert await r.get(k) == v + + async def test_msetnx(self, r: aioredis.Redis): + d = {"a": b"1", "b": b"2", "c": b"3"} + assert await r.msetnx(d) + d2 = {"a": b"x", "d": b"4"} + assert not await r.msetnx(d2) + for k, v in d.items(): + assert await r.get(k) == v + assert await r.get("d") is None + + @skip_if_server_version_lt("2.6.0") + async def test_pexpire(self, r: aioredis.Redis): + assert not await r.pexpire("a", 60000) + await r.set("a", "foo") + assert await r.pexpire("a", 60000) + assert 0 < await r.pttl("a") <= 60000 + assert await r.persist("a") + assert await r.pttl("a") == -1 + + @skip_if_server_version_lt("2.6.0") + async def test_pexpireat_datetime(self, r: aioredis.Redis): + expire_at = await redis_server_time(r) + datetime.timedelta(minutes=1) + await r.set("a", "foo") + assert await r.pexpireat("a", expire_at) + assert 0 < await r.pttl("a") <= 61000 + + @skip_if_server_version_lt("2.6.0") + async def test_pexpireat_no_key(self, r: aioredis.Redis): + expire_at = await redis_server_time(r) + datetime.timedelta(minutes=1) + assert not await r.pexpireat("a", expire_at) + + @skip_if_server_version_lt("2.6.0") + async def test_pexpireat_unixtime(self, r: aioredis.Redis): + expire_at = await redis_server_time(r) + datetime.timedelta(minutes=1) + await r.set("a", "foo") + expire_at_seconds = int(time.mktime(expire_at.timetuple())) * 1000 + assert await r.pexpireat("a", expire_at_seconds) + assert 0 < await r.pttl("a") <= 61000 + + @skip_if_server_version_lt("2.6.0") + async def test_psetex(self, r: aioredis.Redis): + assert await r.psetex("a", 1000, "value") + assert await r.get("a") == b"value" + assert 0 < await r.pttl("a") <= 1000 + + @skip_if_server_version_lt("2.6.0") + async def test_psetex_timedelta(self, r: aioredis.Redis): + expire_at = datetime.timedelta(milliseconds=1000) + assert await r.psetex("a", expire_at, "value") + assert await r.get("a") == b"value" + assert 0 < await r.pttl("a") <= 1000 + + @skip_if_server_version_lt("2.6.0") + async def test_pttl(self, r: aioredis.Redis): + assert not await r.pexpire("a", 10000) + await r.set("a", "1") + assert await r.pexpire("a", 10000) + assert 0 < await r.pttl("a") <= 10000 + assert await r.persist("a") + assert await r.pttl("a") == -1 + + @skip_if_server_version_lt("2.8.0") + async def test_pttl_no_key(self, r: aioredis.Redis): + """PTTL on servers 2.8 and after returns -2 when the key doesn't exist""" + assert await r.pttl("a") == -2 + + async def test_randomkey(self, r: aioredis.Redis): + assert await r.randomkey() is None + for key in ("a", "b", "c"): + await r.set(key, 1) + assert await r.randomkey() in (b"a", b"b", b"c") + + async def test_rename(self, r: aioredis.Redis): + await r.set("a", 
"1") + assert await r.rename("a", "b") + assert await r.get("a") is None + assert await r.get("b") == b"1" + + async def test_renamenx(self, r: aioredis.Redis): + await r.set("a", "1") + await r.set("b", "2") + assert not await r.renamenx("a", "b") + assert await r.get("a") == b"1" + assert await r.get("b") == b"2" + + @skip_if_server_version_lt("2.6.0") + async def test_set_nx(self, r: aioredis.Redis): + assert await r.set("a", "1", nx=True) + assert not await r.set("a", "2", nx=True) + assert await r.get("a") == b"1" + + @skip_if_server_version_lt("2.6.0") + async def test_set_xx(self, r: aioredis.Redis): + assert not await r.set("a", "1", xx=True) + assert await r.get("a") is None + await r.set("a", "bar") + assert await r.set("a", "2", xx=True) + assert await r.get("a") == b"2" + + @skip_if_server_version_lt("2.6.0") + async def test_set_px(self, r: aioredis.Redis): + assert await r.set("a", "1", px=10000) + assert await r.get("a") == b"1" + assert 0 < await r.pttl("a") <= 10000 + assert 0 < await r.ttl("a") <= 10 + + @skip_if_server_version_lt("2.6.0") + async def test_set_px_timedelta(self, r: aioredis.Redis): + expire_at = datetime.timedelta(milliseconds=1000) + assert await r.set("a", "1", px=expire_at) + assert 0 < await r.pttl("a") <= 1000 + assert 0 < await r.ttl("a") <= 1 + + @skip_if_server_version_lt("2.6.0") + async def test_set_ex(self, r: aioredis.Redis): + assert await r.set("a", "1", ex=10) + assert 0 < await r.ttl("a") <= 10 + + @skip_if_server_version_lt("2.6.0") + async def test_set_ex_timedelta(self, r: aioredis.Redis): + expire_at = datetime.timedelta(seconds=60) + assert await r.set("a", "1", ex=expire_at) + assert 0 < await r.ttl("a") <= 60 + + @skip_if_server_version_lt("2.6.0") + async def test_set_multipleoptions(self, r: aioredis.Redis): + await r.set("a", "val") + assert await r.set("a", "1", xx=True, px=10000) + assert 0 < await r.ttl("a") <= 10 + + @skip_if_server_version_lt(REDIS_6_VERSION) + async def test_set_keepttl(self, r: aioredis.Redis): + await r.set("a", "val") + assert await r.set("a", "1", xx=True, px=10000) + assert 0 < await r.ttl("a") <= 10 + await r.set("a", "2", keepttl=True) + assert await r.get("a") == b"2" + assert 0 < await r.ttl("a") <= 10 + + async def test_setex(self, r: aioredis.Redis): + assert await r.setex("a", 60, "1") + assert await r.get("a") == b"1" + assert 0 < await r.ttl("a") <= 60 + + async def test_setnx(self, r: aioredis.Redis): + assert await r.setnx("a", "1") + assert await r.get("a") == b"1" + assert not await r.setnx("a", "2") + assert await r.get("a") == b"1" + + async def test_setrange(self, r: aioredis.Redis): + assert await r.setrange("a", 5, "foo") == 8 + assert await r.get("a") == b"\0\0\0\0\0foo" + await r.set("a", "abcdefghijh") + assert await r.setrange("a", 6, "12345") == 11 + assert await r.get("a") == b"abcdef12345" + + async def test_strlen(self, r: aioredis.Redis): + await r.set("a", "foo") + assert await r.strlen("a") == 3 + + async def test_substr(self, r: aioredis.Redis): + await r.set("a", "0123456789") + assert await r.substr("a", 0) == b"0123456789" + assert await r.substr("a", 2) == b"23456789" + assert await r.substr("a", 3, 5) == b"345" + assert await r.substr("a", 3, -2) == b"345678" + + async def test_ttl(self, r: aioredis.Redis): + await r.set("a", "1") + assert await r.expire("a", 10) + assert 0 < await r.ttl("a") <= 10 + assert await r.persist("a") + assert await r.ttl("a") == -1 + + @skip_if_server_version_lt("2.8.0") + async def test_ttl_nokey(self, r: aioredis.Redis): + """TTL on 
servers 2.8 and after returns -2 when the key doesn't exist""" + assert await r.ttl("a") == -2 + + async def test_type(self, r: aioredis.Redis): + assert await r.type("a") == b"none" + await r.set("a", "1") + assert await r.type("a") == b"string" + await r.delete("a") + await r.lpush("a", "1") + assert await r.type("a") == b"list" + await r.delete("a") + await r.sadd("a", "1") + assert await r.type("a") == b"set" + await r.delete("a") + await r.zadd("a", {"1": 1}) + assert await r.type("a") == b"zset" + + # LIST COMMANDS + async def test_blpop(self, r: aioredis.Redis): + await r.rpush("a", "1", "2") + await r.rpush("b", "3", "4") + assert await r.blpop(["b", "a"], timeout=1) == (b"b", b"3") + assert await r.blpop(["b", "a"], timeout=1) == (b"b", b"4") + assert await r.blpop(["b", "a"], timeout=1) == (b"a", b"1") + assert await r.blpop(["b", "a"], timeout=1) == (b"a", b"2") + assert await r.blpop(["b", "a"], timeout=1) is None + await r.rpush("c", "1") + assert await r.blpop("c", timeout=1) == (b"c", b"1") + + async def test_brpop(self, r: aioredis.Redis): + await r.rpush("a", "1", "2") + await r.rpush("b", "3", "4") + assert await r.brpop(["b", "a"], timeout=1) == (b"b", b"4") + assert await r.brpop(["b", "a"], timeout=1) == (b"b", b"3") + assert await r.brpop(["b", "a"], timeout=1) == (b"a", b"2") + assert await r.brpop(["b", "a"], timeout=1) == (b"a", b"1") + assert await r.brpop(["b", "a"], timeout=1) is None + await r.rpush("c", "1") + assert await r.brpop("c", timeout=1) == (b"c", b"1") + + async def test_brpoplpush(self, r: aioredis.Redis): + await r.rpush("a", "1", "2") + await r.rpush("b", "3", "4") + assert await r.brpoplpush("a", "b") == b"2" + assert await r.brpoplpush("a", "b") == b"1" + assert await r.brpoplpush("a", "b", timeout=1) is None + assert await r.lrange("a", 0, -1) == [] + assert await r.lrange("b", 0, -1) == [b"1", b"2", b"3", b"4"] + + async def test_brpoplpush_empty_string(self, r: aioredis.Redis): + await r.rpush("a", "") + assert await r.brpoplpush("a", "b") == b"" + + async def test_lindex(self, r: aioredis.Redis): + await r.rpush("a", "1", "2", "3") + assert await r.lindex("a", "0") == b"1" + assert await r.lindex("a", "1") == b"2" + assert await r.lindex("a", "2") == b"3" + + async def test_linsert(self, r: aioredis.Redis): + await r.rpush("a", "1", "2", "3") + assert await r.linsert("a", "after", "2", "2.5") == 4 + assert await r.lrange("a", 0, -1) == [b"1", b"2", b"2.5", b"3"] + assert await r.linsert("a", "before", "2", "1.5") == 5 + assert await r.lrange("a", 0, -1) == [b"1", b"1.5", b"2", b"2.5", b"3"] + + async def test_llen(self, r: aioredis.Redis): + await r.rpush("a", "1", "2", "3") + assert await r.llen("a") == 3 + + async def test_lpop(self, r: aioredis.Redis): + await r.rpush("a", "1", "2", "3") + assert await r.lpop("a") == b"1" + assert await r.lpop("a") == b"2" + assert await r.lpop("a") == b"3" + assert await r.lpop("a") is None + + async def test_lpush(self, r: aioredis.Redis): + assert await r.lpush("a", "1") == 1 + assert await r.lpush("a", "2") == 2 + assert await r.lpush("a", "3", "4") == 4 + assert await r.lrange("a", 0, -1) == [b"4", b"3", b"2", b"1"] + + async def test_lpushx(self, r: aioredis.Redis): + assert await r.lpushx("a", "1") == 0 + assert await r.lrange("a", 0, -1) == [] + await r.rpush("a", "1", "2", "3") + assert await r.lpushx("a", "4") == 4 + assert await r.lrange("a", 0, -1) == [b"4", b"1", b"2", b"3"] + + async def test_lrange(self, r: aioredis.Redis): + await r.rpush("a", "1", "2", "3", "4", "5") + assert await 
r.lrange("a", 0, 2) == [b"1", b"2", b"3"] + assert await r.lrange("a", 2, 10) == [b"3", b"4", b"5"] + assert await r.lrange("a", 0, -1) == [b"1", b"2", b"3", b"4", b"5"] + + async def test_lrem(self, r: aioredis.Redis): + await r.rpush("a", "Z", "b", "Z", "Z", "c", "Z", "Z") + # remove the first 'Z' item + assert await r.lrem("a", 1, "Z") == 1 + assert await r.lrange("a", 0, -1) == [b"b", b"Z", b"Z", b"c", b"Z", b"Z"] + # remove the last 2 'Z' items + assert await r.lrem("a", -2, "Z") == 2 + assert await r.lrange("a", 0, -1) == [b"b", b"Z", b"Z", b"c"] + # remove all 'Z' items + assert await r.lrem("a", 0, "Z") == 2 + assert await r.lrange("a", 0, -1) == [b"b", b"c"] + + async def test_lset(self, r: aioredis.Redis): + await r.rpush("a", "1", "2", "3") + assert await r.lrange("a", 0, -1) == [b"1", b"2", b"3"] + assert await r.lset("a", 1, "4") + assert await r.lrange("a", 0, 2) == [b"1", b"4", b"3"] + + async def test_ltrim(self, r: aioredis.Redis): + await r.rpush("a", "1", "2", "3") + assert await r.ltrim("a", 0, 1) + assert await r.lrange("a", 0, -1) == [b"1", b"2"] + + async def test_rpop(self, r: aioredis.Redis): + await r.rpush("a", "1", "2", "3") + assert await r.rpop("a") == b"3" + assert await r.rpop("a") == b"2" + assert await r.rpop("a") == b"1" + assert await r.rpop("a") is None + + async def test_rpoplpush(self, r: aioredis.Redis): + await r.rpush("a", "a1", "a2", "a3") + await r.rpush("b", "b1", "b2", "b3") + assert await r.rpoplpush("a", "b") == b"a3" + assert await r.lrange("a", 0, -1) == [b"a1", b"a2"] + assert await r.lrange("b", 0, -1) == [b"a3", b"b1", b"b2", b"b3"] + + async def test_rpush(self, r: aioredis.Redis): + assert await r.rpush("a", "1") == 1 + assert await r.rpush("a", "2") == 2 + assert await r.rpush("a", "3", "4") == 4 + assert await r.lrange("a", 0, -1) == [b"1", b"2", b"3", b"4"] + + @skip_if_server_version_lt("6.0.6") + async def test_lpos(self, r: aioredis.Redis): + assert await r.rpush("a", "a", "b", "c", "1", "2", "3", "c", "c") == 8 + assert await r.lpos("a", "a") == 0 + assert await r.lpos("a", "c") == 2 + + assert await r.lpos("a", "c", rank=1) == 2 + assert await r.lpos("a", "c", rank=2) == 6 + assert await r.lpos("a", "c", rank=4) is None + assert await r.lpos("a", "c", rank=-1) == 7 + assert await r.lpos("a", "c", rank=-2) == 6 + + assert await r.lpos("a", "c", count=0) == [2, 6, 7] + assert await r.lpos("a", "c", count=1) == [2] + assert await r.lpos("a", "c", count=2) == [2, 6] + assert await r.lpos("a", "c", count=100) == [2, 6, 7] + + assert await r.lpos("a", "c", count=0, rank=2) == [6, 7] + assert await r.lpos("a", "c", count=2, rank=-1) == [7, 6] + + assert await r.lpos("axxx", "c", count=0, rank=2) == [] + assert await r.lpos("axxx", "c") is None + + assert await r.lpos("a", "x", count=2) == [] + assert await r.lpos("a", "x") is None + + assert await r.lpos("a", "a", count=0, maxlen=1) == [0] + assert await r.lpos("a", "c", count=0, maxlen=1) == [] + assert await r.lpos("a", "c", count=0, maxlen=3) == [2] + assert await r.lpos("a", "c", count=0, maxlen=3, rank=-1) == [7, 6] + assert await r.lpos("a", "c", count=0, maxlen=7, rank=2) == [6] + + async def test_rpushx(self, r: aioredis.Redis): + assert await r.rpushx("a", "b") == 0 + assert await r.lrange("a", 0, -1) == [] + await r.rpush("a", "1", "2", "3") + assert await r.rpushx("a", "4") == 4 + assert await r.lrange("a", 0, -1) == [b"1", b"2", b"3", b"4"] + + # SCAN COMMANDS + @skip_if_server_version_lt("2.8.0") + async def test_scan(self, r: aioredis.Redis): + await r.set("a", 1) + 
await r.set("b", 2) + await r.set("c", 3) + cursor, keys = await r.scan() + assert cursor == 0 + assert set(keys) == {b"a", b"b", b"c"} + _, keys = await r.scan(match="a") + assert set(keys) == {b"a"} + + @skip_if_server_version_lt(REDIS_6_VERSION) + async def test_scan_type(self, r: aioredis.Redis): + await r.sadd("a-set", 1) + await r.hset("a-hash", "foo", 2) + await r.lpush("a-list", "aux", 3) + _, keys = await r.scan(match="a*", _type="SET") + assert set(keys) == {b"a-set"} + + @skip_if_server_version_lt("2.8.0") + async def test_scan_iter(self, r: aioredis.Redis): + await r.set("a", 1) + await r.set("b", 2) + await r.set("c", 3) + keys = [k async for k in r.scan_iter()] + assert set(keys) == {b"a", b"b", b"c"} + keys = [k async for k in r.scan_iter(match="a")] + assert set(keys) == {b"a"} + + @skip_if_server_version_lt("2.8.0") + async def test_sscan(self, r: aioredis.Redis): + await r.sadd("a", 1, 2, 3) + cursor, members = await r.sscan("a") + assert cursor == 0 + assert set(members) == {b"1", b"2", b"3"} + _, members = await r.sscan("a", match=b"1") + assert set(members) == {b"1"} + + @skip_if_server_version_lt("2.8.0") + async def test_sscan_iter(self, r: aioredis.Redis): + await r.sadd("a", 1, 2, 3) + members = [k async for k in r.sscan_iter("a")] + assert set(members) == {b"1", b"2", b"3"} + members = [k async for k in r.sscan_iter("a", match=b"1")] + assert set(members) == {b"1"} + + @skip_if_server_version_lt("2.8.0") + async def test_hscan(self, r: aioredis.Redis): + await r.hset("a", mapping={"a": 1, "b": 2, "c": 3}) + cursor, dic = await r.hscan("a") + assert cursor == 0 + assert dic == {b"a": b"1", b"b": b"2", b"c": b"3"} + _, dic = await r.hscan("a", match="a") + assert dic == {b"a": b"1"} + + @skip_if_server_version_lt("2.8.0") + async def test_hscan_iter(self, r: aioredis.Redis): + await r.hset("a", mapping={"a": 1, "b": 2, "c": 3}) + dic = {k: v async for k, v in r.hscan_iter("a")} + assert dic == {b"a": b"1", b"b": b"2", b"c": b"3"} + dic = {k: v async for k, v in r.hscan_iter("a", match="a")} + assert dic == {b"a": b"1"} + + @skip_if_server_version_lt("2.8.0") + async def test_zscan(self, r: aioredis.Redis): + await r.zadd("a", {"a": 1, "b": 2, "c": 3}) + cursor, pairs = await r.zscan("a") + assert cursor == 0 + assert set(pairs) == {(b"a", 1), (b"b", 2), (b"c", 3)} + _, pairs = await r.zscan("a", match="a") + assert set(pairs) == {(b"a", 1)} + + @skip_if_server_version_lt("2.8.0") + async def test_zscan_iter(self, r: aioredis.Redis): + await r.zadd("a", {"a": 1, "b": 2, "c": 3}) + pairs = [k async for k in r.zscan_iter("a")] + assert set(pairs) == {(b"a", 1), (b"b", 2), (b"c", 3)} + pairs = [k async for k in r.zscan_iter("a", match="a")] + assert set(pairs) == {(b"a", 1)} + + # SET COMMANDS + async def test_sadd(self, r: aioredis.Redis): + members = {b"1", b"2", b"3"} + await r.sadd("a", *members) + assert await r.smembers("a") == members + + async def test_scard(self, r: aioredis.Redis): + await r.sadd("a", "1", "2", "3") + assert await r.scard("a") == 3 + + async def test_sdiff(self, r: aioredis.Redis): + await r.sadd("a", "1", "2", "3") + assert await r.sdiff("a", "b") == {b"1", b"2", b"3"} + await r.sadd("b", "2", "3") + assert await r.sdiff("a", "b") == {b"1"} + + async def test_sdiffstore(self, r: aioredis.Redis): + await r.sadd("a", "1", "2", "3") + assert await r.sdiffstore("c", "a", "b") == 3 + assert await r.smembers("c") == {b"1", b"2", b"3"} + await r.sadd("b", "2", "3") + assert await r.sdiffstore("c", "a", "b") == 1 + assert await r.smembers("c") == 
{b"1"} + + async def test_sinter(self, r: aioredis.Redis): + await r.sadd("a", "1", "2", "3") + assert await r.sinter("a", "b") == set() + await r.sadd("b", "2", "3") + assert await r.sinter("a", "b") == {b"2", b"3"} + + async def test_sinterstore(self, r: aioredis.Redis): + await r.sadd("a", "1", "2", "3") + assert await r.sinterstore("c", "a", "b") == 0 + assert await r.smembers("c") == set() + await r.sadd("b", "2", "3") + assert await r.sinterstore("c", "a", "b") == 2 + assert await r.smembers("c") == {b"2", b"3"} + + async def test_sismember(self, r: aioredis.Redis): + await r.sadd("a", "1", "2", "3") + assert await r.sismember("a", "1") + assert await r.sismember("a", "2") + assert await r.sismember("a", "3") + assert not await r.sismember("a", "4") + + async def test_smembers(self, r: aioredis.Redis): + await r.sadd("a", "1", "2", "3") + assert await r.smembers("a") == {b"1", b"2", b"3"} + + async def test_smove(self, r: aioredis.Redis): + await r.sadd("a", "a1", "a2") + await r.sadd("b", "b1", "b2") + assert await r.smove("a", "b", "a1") + assert await r.smembers("a") == {b"a2"} + assert await r.smembers("b") == {b"b1", b"b2", b"a1"} + + async def test_spop(self, r: aioredis.Redis): + s = [b"1", b"2", b"3"] + await r.sadd("a", *s) + value = await r.spop("a") + assert value in s + assert await r.smembers("a") == set(s) - {value} + + @skip_if_server_version_lt("3.2.0") + async def test_spop_multi_value(self, r: aioredis.Redis): + s = [b"1", b"2", b"3"] + await r.sadd("a", *s) + values = await r.spop("a", 2) + assert len(values) == 2 + + for value in values: + assert value in s + + assert await r.spop("a", 1) == list(set(s) - set(values)) + + async def test_srandmember(self, r: aioredis.Redis): + s = [b"1", b"2", b"3"] + await r.sadd("a", *s) + assert await r.srandmember("a") in s + + @skip_if_server_version_lt("2.6.0") + async def test_srandmember_multi_value(self, r: aioredis.Redis): + s = [b"1", b"2", b"3"] + await r.sadd("a", *s) + randoms = await r.srandmember("a", number=2) + assert len(randoms) == 2 + assert set(randoms).intersection(s) == set(randoms) + + async def test_srem(self, r: aioredis.Redis): + await r.sadd("a", "1", "2", "3", "4") + assert await r.srem("a", "5") == 0 + assert await r.srem("a", "2", "4") == 2 + assert await r.smembers("a") == {b"1", b"3"} + + async def test_sunion(self, r: aioredis.Redis): + await r.sadd("a", "1", "2") + await r.sadd("b", "2", "3") + assert await r.sunion("a", "b") == {b"1", b"2", b"3"} + + async def test_sunionstore(self, r: aioredis.Redis): + await r.sadd("a", "1", "2") + await r.sadd("b", "2", "3") + assert await r.sunionstore("c", "a", "b") == 3 + assert await r.smembers("c") == {b"1", b"2", b"3"} + + # SORTED SET COMMANDS + async def test_zadd(self, r: aioredis.Redis): + mapping = {"a1": 1.0, "a2": 2.0, "a3": 3.0} + await r.zadd("a", mapping) + assert await r.zrange("a", 0, -1, withscores=True) == [ + (b"a1", 1.0), + (b"a2", 2.0), + (b"a3", 3.0), + ] + + # error cases + with pytest.raises(exceptions.DataError): + await r.zadd("a", {}) + + # cannot use both nx and xx options + with pytest.raises(exceptions.DataError): + await r.zadd("a", mapping, nx=True, xx=True) + + # cannot use the incr options with more than one value + with pytest.raises(exceptions.DataError): + await r.zadd("a", mapping, incr=True) + + async def test_zadd_nx(self, r: aioredis.Redis): + assert await r.zadd("a", {"a1": 1}) == 1 + assert await r.zadd("a", {"a1": 99, "a2": 2}, nx=True) == 1 + assert await r.zrange("a", 0, -1, withscores=True) == [ + (b"a1", 1.0), 
+ (b"a2", 2.0), + ] + + async def test_zadd_xx(self, r: aioredis.Redis): + assert await r.zadd("a", {"a1": 1}) == 1 + assert await r.zadd("a", {"a1": 99, "a2": 2}, xx=True) == 0 + assert await r.zrange("a", 0, -1, withscores=True) == [(b"a1", 99.0)] + + async def test_zadd_ch(self, r: aioredis.Redis): + assert await r.zadd("a", {"a1": 1}) == 1 + assert await r.zadd("a", {"a1": 99, "a2": 2}, ch=True) == 2 + assert await r.zrange("a", 0, -1, withscores=True) == [ + (b"a2", 2.0), + (b"a1", 99.0), + ] + + async def test_zadd_incr(self, r: aioredis.Redis): + assert await r.zadd("a", {"a1": 1}) == 1 + assert await r.zadd("a", {"a1": 4.5}, incr=True) == 5.5 + + async def test_zadd_incr_with_xx(self, r: aioredis.Redis): + # this asks zadd to incr 'a1' only if it exists, but it clearly + # doesn't. Redis returns a null value in this case and so should + # redis-py + assert await r.zadd("a", {"a1": 1}, xx=True, incr=True) is None + + async def test_zcard(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert await r.zcard("a") == 3 + + async def test_zcount(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert await r.zcount("a", "-inf", "+inf") == 3 + assert await r.zcount("a", 1, 2) == 2 + assert await r.zcount("a", "(" + str(1), 2) == 1 + assert await r.zcount("a", 1, "(" + str(2)) == 1 + assert await r.zcount("a", 10, 20) == 0 + + async def test_zincrby(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert await r.zincrby("a", 1, "a2") == 3.0 + assert await r.zincrby("a", 5, "a3") == 8.0 + assert await r.zscore("a", "a2") == 3.0 + assert await r.zscore("a", "a3") == 8.0 + + @skip_if_server_version_lt("2.8.9") + async def test_zlexcount(self, r: aioredis.Redis): + await r.zadd("a", {"a": 0, "b": 0, "c": 0, "d": 0, "e": 0, "f": 0, "g": 0}) + assert await r.zlexcount("a", "-", "+") == 7 + assert await r.zlexcount("a", "[b", "[f") == 5 + + async def test_zinterstore_sum(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert await r.zinterstore("d", ["a", "b", "c"]) == 2 + assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 8), (b"a1", 9)] + + async def test_zinterstore_max(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert await r.zinterstore("d", ["a", "b", "c"], aggregate="MAX") == 2 + assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 5), (b"a1", 6)] + + async def test_zinterstore_min(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("b", {"a1": 2, "a2": 3, "a3": 5}) + await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert await r.zinterstore("d", ["a", "b", "c"], aggregate="MIN") == 2 + assert await r.zrange("d", 0, -1, withscores=True) == [(b"a1", 1), (b"a3", 3)] + + async def test_zinterstore_with_weight(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert await r.zinterstore("d", {"a": 1, "b": 2, "c": 3}) == 2 + assert await r.zrange("d", 0, -1, withscores=True) == [(b"a3", 20), (b"a1", 23)] + + @skip_if_server_version_lt("4.9.0") + async def test_zpopmax(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert await r.zpopmax("a") == 
[(b"a3", 3)] + + # with count + assert await r.zpopmax("a", count=2) == [(b"a2", 2), (b"a1", 1)] + + @skip_if_server_version_lt("4.9.0") + async def test_zpopmin(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert await r.zpopmin("a") == [(b"a1", 1)] + + # with count + assert await r.zpopmin("a", count=2) == [(b"a2", 2), (b"a3", 3)] + + @skip_if_server_version_lt("4.9.0") + async def test_bzpopmax(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2}) + await r.zadd("b", {"b1": 10, "b2": 20}) + assert await r.bzpopmax(["b", "a"], timeout=1) == (b"b", b"b2", 20) + assert await r.bzpopmax(["b", "a"], timeout=1) == (b"b", b"b1", 10) + assert await r.bzpopmax(["b", "a"], timeout=1) == (b"a", b"a2", 2) + assert await r.bzpopmax(["b", "a"], timeout=1) == (b"a", b"a1", 1) + assert await r.bzpopmax(["b", "a"], timeout=1) is None + await r.zadd("c", {"c1": 100}) + assert await r.bzpopmax("c", timeout=1) == (b"c", b"c1", 100) + + @skip_if_server_version_lt("4.9.0") + async def test_bzpopmin(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2}) + await r.zadd("b", {"b1": 10, "b2": 20}) + assert await r.bzpopmin(["b", "a"], timeout=1) == (b"b", b"b1", 10) + assert await r.bzpopmin(["b", "a"], timeout=1) == (b"b", b"b2", 20) + assert await r.bzpopmin(["b", "a"], timeout=1) == (b"a", b"a1", 1) + assert await r.bzpopmin(["b", "a"], timeout=1) == (b"a", b"a2", 2) + assert await r.bzpopmin(["b", "a"], timeout=1) is None + await r.zadd("c", {"c1": 100}) + assert await r.bzpopmin("c", timeout=1) == (b"c", b"c1", 100) + + async def test_zrange(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert await r.zrange("a", 0, 1) == [b"a1", b"a2"] + assert await r.zrange("a", 1, 2) == [b"a2", b"a3"] + + # withscores + assert await r.zrange("a", 0, 1, withscores=True) == [ + (b"a1", 1.0), + (b"a2", 2.0), + ] + assert await r.zrange("a", 1, 2, withscores=True) == [ + (b"a2", 2.0), + (b"a3", 3.0), + ] + + # custom score function + assert await r.zrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + (b"a1", 1), + (b"a2", 2), + ] + + @skip_if_server_version_lt("2.8.9") + async def test_zrangebylex(self, r: aioredis.Redis): + await r.zadd("a", {"a": 0, "b": 0, "c": 0, "d": 0, "e": 0, "f": 0, "g": 0}) + assert await r.zrangebylex("a", "-", "[c") == [b"a", b"b", b"c"] + assert await r.zrangebylex("a", "-", "(c") == [b"a", b"b"] + assert await r.zrangebylex("a", "[aaa", "(g") == [b"b", b"c", b"d", b"e", b"f"] + assert await r.zrangebylex("a", "[f", "+") == [b"f", b"g"] + assert await r.zrangebylex("a", "-", "+", start=3, num=2) == [b"d", b"e"] + + @skip_if_server_version_lt("2.9.9") + async def test_zrevrangebylex(self, r: aioredis.Redis): + await r.zadd("a", {"a": 0, "b": 0, "c": 0, "d": 0, "e": 0, "f": 0, "g": 0}) + assert await r.zrevrangebylex("a", "[c", "-") == [b"c", b"b", b"a"] + assert await r.zrevrangebylex("a", "(c", "-") == [b"b", b"a"] + assert await r.zrevrangebylex("a", "(g", "[aaa") == [ + b"f", + b"e", + b"d", + b"c", + b"b", + ] + assert await r.zrevrangebylex("a", "+", "[f") == [b"g", b"f"] + assert await r.zrevrangebylex("a", "+", "-", start=3, num=2) == [b"d", b"c"] + + async def test_zrangebyscore(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert await r.zrangebyscore("a", 2, 4) == [b"a2", b"a3", b"a4"] + + # slicing with start/num + assert await r.zrangebyscore("a", 2, 4, start=1, num=2) == [b"a3", b"a4"] + + # withscores + assert await r.zrangebyscore("a", 2, 
4, withscores=True) == [ + (b"a2", 2.0), + (b"a3", 3.0), + (b"a4", 4.0), + ] + + # custom score function + assert await r.zrangebyscore( + "a", 2, 4, withscores=True, score_cast_func=int + ) == [(b"a2", 2), (b"a3", 3), (b"a4", 4)] + + async def test_zrank(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert await r.zrank("a", "a1") == 0 + assert await r.zrank("a", "a2") == 1 + assert await r.zrank("a", "a6") is None + + async def test_zrem(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert await r.zrem("a", "a2") == 1 + assert await r.zrange("a", 0, -1) == [b"a1", b"a3"] + assert await r.zrem("a", "b") == 0 + assert await r.zrange("a", 0, -1) == [b"a1", b"a3"] + + async def test_zrem_multiple_keys(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert await r.zrem("a", "a1", "a2") == 2 + assert await r.zrange("a", 0, 5) == [b"a3"] + + @skip_if_server_version_lt("2.8.9") + async def test_zremrangebylex(self, r: aioredis.Redis): + await r.zadd("a", {"a": 0, "b": 0, "c": 0, "d": 0, "e": 0, "f": 0, "g": 0}) + assert await r.zremrangebylex("a", "-", "[c") == 3 + assert await r.zrange("a", 0, -1) == [b"d", b"e", b"f", b"g"] + assert await r.zremrangebylex("a", "[f", "+") == 2 + assert await r.zrange("a", 0, -1) == [b"d", b"e"] + assert await r.zremrangebylex("a", "[h", "+") == 0 + assert await r.zrange("a", 0, -1) == [b"d", b"e"] + + async def test_zremrangebyrank(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert await r.zremrangebyrank("a", 1, 3) == 3 + assert await r.zrange("a", 0, 5) == [b"a1", b"a5"] + + async def test_zremrangebyscore(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert await r.zremrangebyscore("a", 2, 4) == 3 + assert await r.zrange("a", 0, -1) == [b"a1", b"a5"] + assert await r.zremrangebyscore("a", 2, 4) == 0 + assert await r.zrange("a", 0, -1) == [b"a1", b"a5"] + + async def test_zrevrange(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert await r.zrevrange("a", 0, 1) == [b"a3", b"a2"] + assert await r.zrevrange("a", 1, 2) == [b"a2", b"a1"] + + # withscores + assert await r.zrevrange("a", 0, 1, withscores=True) == [ + (b"a3", 3.0), + (b"a2", 2.0), + ] + assert await r.zrevrange("a", 1, 2, withscores=True) == [ + (b"a2", 2.0), + (b"a1", 1.0), + ] + + # custom score function + assert await r.zrevrange("a", 0, 1, withscores=True, score_cast_func=int) == [ + (b"a3", 3.0), + (b"a2", 2.0), + ] + + async def test_zrevrangebyscore(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert await r.zrevrangebyscore("a", 4, 2) == [b"a4", b"a3", b"a2"] + + # slicing with start/num + assert await r.zrevrangebyscore("a", 4, 2, start=1, num=2) == [b"a3", b"a2"] + + # withscores + assert await r.zrevrangebyscore("a", 4, 2, withscores=True) == [ + (b"a4", 4.0), + (b"a3", 3.0), + (b"a2", 2.0), + ] + + # custom score function + assert await r.zrevrangebyscore( + "a", 4, 2, withscores=True, score_cast_func=int + ) == [(b"a4", 4), (b"a3", 3), (b"a2", 2)] + + async def test_zrevrank(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3, "a4": 4, "a5": 5}) + assert await r.zrevrank("a", "a1") == 4 + assert await r.zrevrank("a", "a2") == 3 + assert await r.zrevrank("a", "a6") is None + + async def test_zscore(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + assert 
await r.zscore("a", "a1") == 1.0 + assert await r.zscore("a", "a2") == 2.0 + assert await r.zscore("a", "a4") is None + + async def test_zunionstore_sum(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert await r.zunionstore("d", ["a", "b", "c"]) == 4 + assert await r.zrange("d", 0, -1, withscores=True) == [ + (b"a2", 3), + (b"a4", 4), + (b"a3", 8), + (b"a1", 9), + ] + + async def test_zunionstore_max(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert await r.zunionstore("d", ["a", "b", "c"], aggregate="MAX") == 4 + assert await r.zrange("d", 0, -1, withscores=True) == [ + (b"a2", 2), + (b"a4", 4), + (b"a3", 5), + (b"a1", 6), + ] + + async def test_zunionstore_min(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 2, "a3": 3}) + await r.zadd("b", {"a1": 2, "a2": 2, "a3": 4}) + await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert await r.zunionstore("d", ["a", "b", "c"], aggregate="MIN") == 4 + assert await r.zrange("d", 0, -1, withscores=True) == [ + (b"a1", 1), + (b"a2", 2), + (b"a3", 3), + (b"a4", 4), + ] + + async def test_zunionstore_with_weight(self, r: aioredis.Redis): + await r.zadd("a", {"a1": 1, "a2": 1, "a3": 1}) + await r.zadd("b", {"a1": 2, "a2": 2, "a3": 2}) + await r.zadd("c", {"a1": 6, "a3": 5, "a4": 4}) + assert await r.zunionstore("d", {"a": 1, "b": 2, "c": 3}) == 4 + assert await r.zrange("d", 0, -1, withscores=True) == [ + (b"a2", 5), + (b"a4", 12), + (b"a3", 20), + (b"a1", 23), + ] + + # HYPERLOGLOG TESTS + @skip_if_server_version_lt("2.8.9") + async def test_pfadd(self, r: aioredis.Redis): + members = {b"1", b"2", b"3"} + assert await r.pfadd("a", *members) == 1 + assert await r.pfadd("a", *members) == 0 + assert await r.pfcount("a") == len(members) + + @skip_if_server_version_lt("2.8.9") + async def test_pfcount(self, r: aioredis.Redis): + members = {b"1", b"2", b"3"} + await r.pfadd("a", *members) + assert await r.pfcount("a") == len(members) + members_b = {b"2", b"3", b"4"} + await r.pfadd("b", *members_b) + assert await r.pfcount("b") == len(members_b) + assert await r.pfcount("a", "b") == len(members_b.union(members)) + + @skip_if_server_version_lt("2.8.9") + async def test_pfmerge(self, r: aioredis.Redis): + mema = {b"1", b"2", b"3"} + memb = {b"2", b"3", b"4"} + memc = {b"5", b"6", b"7"} + await r.pfadd("a", *mema) + await r.pfadd("b", *memb) + await r.pfadd("c", *memc) + await r.pfmerge("d", "c", "a") + assert await r.pfcount("d") == 6 + await r.pfmerge("d", "b") + assert await r.pfcount("d") == 7 + + # HASH COMMANDS + async def test_hget_and_hset(self, r: aioredis.Redis): + await r.hset("a", mapping={"1": 1, "2": 2, "3": 3}) + assert await r.hget("a", "1") == b"1" + assert await r.hget("a", "2") == b"2" + assert await r.hget("a", "3") == b"3" + + # field was updated, redis returns 0 + assert await r.hset("a", "2", 5) == 0 + assert await r.hget("a", "2") == b"5" + + # field is new, redis returns 1 + assert await r.hset("a", "4", 4) == 1 + assert await r.hget("a", "4") == b"4" + + # key inside of hash that doesn't exist returns null value + assert await r.hget("a", "b") is None + + # keys with bool(key) == False + assert await r.hset("a", 0, 10) == 1 + assert await r.hset("a", "", 10) == 1 + + async def test_hset_with_multi_key_values(self, r: aioredis.Redis): + await r.hset("a", 
mapping={"1": 1, "2": 2, "3": 3}) + assert await r.hget("a", "1") == b"1" + assert await r.hget("a", "2") == b"2" + assert await r.hget("a", "3") == b"3" + + await r.hset("b", "foo", "bar", mapping={"1": 1, "2": 2}) + assert await r.hget("b", "1") == b"1" + assert await r.hget("b", "2") == b"2" + assert await r.hget("b", "foo") == b"bar" + + async def test_hset_without_data(self, r: aioredis.Redis): + with pytest.raises(exceptions.DataError): + await r.hset("x") + + async def test_hdel(self, r: aioredis.Redis): + await r.hset("a", mapping={"1": 1, "2": 2, "3": 3}) + assert await r.hdel("a", "2") == 1 + assert await r.hget("a", "2") is None + assert await r.hdel("a", "1", "3") == 2 + assert await r.hlen("a") == 0 + + async def test_hexists(self, r: aioredis.Redis): + await r.hset("a", mapping={"1": 1, "2": 2, "3": 3}) + assert await r.hexists("a", "1") + assert not await r.hexists("a", "4") + + async def test_hgetall(self, r: aioredis.Redis): + h = {b"a1": b"1", b"a2": b"2", b"a3": b"3"} + await r.hset("a", mapping=h) + assert await r.hgetall("a") == h + + async def test_hincrby(self, r: aioredis.Redis): + assert await r.hincrby("a", "1") == 1 + assert await r.hincrby("a", "1", amount=2) == 3 + assert await r.hincrby("a", "1", amount=-2) == 1 + + @skip_if_server_version_lt("2.6.0") + async def test_hincrbyfloat(self, r: aioredis.Redis): + assert await r.hincrbyfloat("a", "1") == 1.0 + assert await r.hincrbyfloat("a", "1") == 2.0 + assert await r.hincrbyfloat("a", "1", 1.2) == 3.2 + + async def test_hkeys(self, r: aioredis.Redis): + h = {b"a1": b"1", b"a2": b"2", b"a3": b"3"} + await r.hset("a", mapping=h) + local_keys = list(h.keys()) + remote_keys = await r.hkeys("a") + assert sorted(local_keys) == sorted(remote_keys) + + async def test_hlen(self, r: aioredis.Redis): + await r.hset("a", mapping={"1": 1, "2": 2, "3": 3}) + assert await r.hlen("a") == 3 + + async def test_hmget(self, r: aioredis.Redis): + assert await r.hset("a", mapping={"a": 1, "b": 2, "c": 3}) + assert await r.hmget("a", "a", "b", "c") == [b"1", b"2", b"3"] + + async def test_hmset(self, r: aioredis.Redis): + warning_message = ( + r"^Redis\.hmset\(\) is deprecated\. 
" r"Use Redis\.hset\(\) instead\.$" + ) + h = {b"a": b"1", b"b": b"2", b"c": b"3"} + with pytest.warns(DeprecationWarning, match=warning_message): + assert await r.hmset("a", h) + assert await r.hgetall("a") == h + + async def test_hsetnx(self, r: aioredis.Redis): + # Initially set the hash field + assert await r.hsetnx("a", "1", 1) + assert await r.hget("a", "1") == b"1" + assert not await r.hsetnx("a", "1", 2) + assert await r.hget("a", "1") == b"1" + + async def test_hvals(self, r: aioredis.Redis): + h = {b"a1": b"1", b"a2": b"2", b"a3": b"3"} + await r.hset("a", mapping=h) + local_vals = list(h.values()) + remote_vals = await r.hvals("a") + assert sorted(local_vals) == sorted(remote_vals) + + @skip_if_server_version_lt("3.2.0") + async def test_hstrlen(self, r: aioredis.Redis): + await r.hset("a", mapping={"1": "22", "2": "333"}) + assert await r.hstrlen("a", "1") == 2 + assert await r.hstrlen("a", "2") == 3 + + # SORT + async def test_sort_basic(self, r: aioredis.Redis): + await r.rpush("a", "3", "2", "1", "4") + assert await r.sort("a") == [b"1", b"2", b"3", b"4"] + + async def test_sort_limited(self, r: aioredis.Redis): + await r.rpush("a", "3", "2", "1", "4") + assert await r.sort("a", start=1, num=2) == [b"2", b"3"] + + async def test_sort_by(self, r: aioredis.Redis): + await r.set("score:1", 8) + await r.set("score:2", 3) + await r.set("score:3", 5) + await r.rpush("a", "3", "2", "1") + assert await r.sort("a", by="score:*") == [b"2", b"3", b"1"] + + async def test_sort_get(self, r: aioredis.Redis): + await r.set("user:1", "u1") + await r.set("user:2", "u2") + await r.set("user:3", "u3") + await r.rpush("a", "2", "3", "1") + assert await r.sort("a", get="user:*") == [b"u1", b"u2", b"u3"] + + async def test_sort_get_multi(self, r: aioredis.Redis): + await r.set("user:1", "u1") + await r.set("user:2", "u2") + await r.set("user:3", "u3") + await r.rpush("a", "2", "3", "1") + assert await r.sort("a", get=("user:*", "#")) == [ + b"u1", + b"1", + b"u2", + b"2", + b"u3", + b"3", + ] + + async def test_sort_get_groups_two(self, r: aioredis.Redis): + await r.set("user:1", "u1") + await r.set("user:2", "u2") + await r.set("user:3", "u3") + await r.rpush("a", "2", "3", "1") + assert await r.sort("a", get=("user:*", "#"), groups=True) == [ + (b"u1", b"1"), + (b"u2", b"2"), + (b"u3", b"3"), + ] + + async def test_sort_groups_string_get(self, r: aioredis.Redis): + await r.set("user:1", "u1") + await r.set("user:2", "u2") + await r.set("user:3", "u3") + await r.rpush("a", "2", "3", "1") + with pytest.raises(exceptions.DataError): + await r.sort("a", get="user:*", groups=True) + + async def test_sort_groups_just_one_get(self, r: aioredis.Redis): + await r.set("user:1", "u1") + await r.set("user:2", "u2") + await r.set("user:3", "u3") + await r.rpush("a", "2", "3", "1") + with pytest.raises(exceptions.DataError): + await r.sort("a", get=["user:*"], groups=True) + + async def test_sort_groups_no_get(self, r: aioredis.Redis): + await r.set("user:1", "u1") + await r.set("user:2", "u2") + await r.set("user:3", "u3") + await r.rpush("a", "2", "3", "1") + with pytest.raises(exceptions.DataError): + await r.sort("a", groups=True) + + async def test_sort_groups_three_gets(self, r: aioredis.Redis): + await r.set("user:1", "u1") + await r.set("user:2", "u2") + await r.set("user:3", "u3") + await r.set("door:1", "d1") + await r.set("door:2", "d2") + await r.set("door:3", "d3") + await r.rpush("a", "2", "3", "1") + assert await r.sort("a", get=("user:*", "door:*", "#"), groups=True) == [ + (b"u1", b"d1", 
b"1"), + (b"u2", b"d2", b"2"), + (b"u3", b"d3", b"3"), + ] + + async def test_sort_desc(self, r: aioredis.Redis): + await r.rpush("a", "2", "3", "1") + assert await r.sort("a", desc=True) == [b"3", b"2", b"1"] + + async def test_sort_alpha(self, r: aioredis.Redis): + await r.rpush("a", "e", "c", "b", "d", "a") + assert await r.sort("a", alpha=True) == [b"a", b"b", b"c", b"d", b"e"] + + async def test_sort_store(self, r: aioredis.Redis): + await r.rpush("a", "2", "3", "1") + assert await r.sort("a", store="sorted_values") == 3 + assert await r.lrange("sorted_values", 0, -1) == [b"1", b"2", b"3"] + + async def test_sort_all_options(self, r: aioredis.Redis): + await r.set("user:1:username", "zeus") + await r.set("user:2:username", "titan") + await r.set("user:3:username", "hermes") + await r.set("user:4:username", "hercules") + await r.set("user:5:username", "apollo") + await r.set("user:6:username", "athena") + await r.set("user:7:username", "hades") + await r.set("user:8:username", "dionysus") + + await r.set("user:1:favorite_drink", "yuengling") + await r.set("user:2:favorite_drink", "rum") + await r.set("user:3:favorite_drink", "vodka") + await r.set("user:4:favorite_drink", "milk") + await r.set("user:5:favorite_drink", "pinot noir") + await r.set("user:6:favorite_drink", "water") + await r.set("user:7:favorite_drink", "gin") + await r.set("user:8:favorite_drink", "apple juice") + + await r.rpush("gods", "5", "8", "3", "1", "2", "7", "6", "4") + num = await r.sort( + "gods", + start=2, + num=4, + by="user:*:username", + get="user:*:favorite_drink", + desc=True, + alpha=True, + store="sorted", + ) + assert num == 4 + assert await r.lrange("sorted", 0, 10) == [ + b"vodka", + b"milk", + b"gin", + b"apple juice", + ] + + async def test_sort_issue_924(self, r: aioredis.Redis): + # Tests for issue https://github.com/andymccurdy/redis-py/issues/924 + await r.execute_command("SADD", "issue#924", 1) + await r.execute_command("SORT", "issue#924") + + async def test_cluster_addslots(self, mock_cluster_resp_ok): + assert await mock_cluster_resp_ok.cluster("ADDSLOTS", 1) is True + + async def test_cluster_count_failure_reports(self, mock_cluster_resp_int): + assert isinstance( + await mock_cluster_resp_int.cluster("COUNT-FAILURE-REPORTS", "node"), int + ) + + async def test_cluster_countkeysinslot(self, mock_cluster_resp_int): + assert isinstance( + await mock_cluster_resp_int.cluster("COUNTKEYSINSLOT", 2), int + ) + + async def test_cluster_delslots(self, mock_cluster_resp_ok): + assert await mock_cluster_resp_ok.cluster("DELSLOTS", 1) is True + + async def test_cluster_failover(self, mock_cluster_resp_ok): + assert await mock_cluster_resp_ok.cluster("FAILOVER", 1) is True + + async def test_cluster_forget(self, mock_cluster_resp_ok): + assert await mock_cluster_resp_ok.cluster("FORGET", 1) is True + + async def test_cluster_info(self, mock_cluster_resp_info): + assert isinstance(await mock_cluster_resp_info.cluster("info"), dict) + + async def test_cluster_keyslot(self, mock_cluster_resp_int): + assert isinstance(await mock_cluster_resp_int.cluster("keyslot", "asdf"), int) + + async def test_cluster_meet(self, mock_cluster_resp_ok): + assert await mock_cluster_resp_ok.cluster("meet", "ip", "port", 1) is True + + async def test_cluster_nodes(self, mock_cluster_resp_nodes): + assert isinstance(await mock_cluster_resp_nodes.cluster("nodes"), dict) + + async def test_cluster_replicate(self, mock_cluster_resp_ok): + assert await mock_cluster_resp_ok.cluster("replicate", "nodeid") is True + + async def 
test_cluster_reset(self, mock_cluster_resp_ok): + assert await mock_cluster_resp_ok.cluster("reset", "hard") is True + + async def test_cluster_saveconfig(self, mock_cluster_resp_ok): + assert await mock_cluster_resp_ok.cluster("saveconfig") is True + + async def test_cluster_setslot(self, mock_cluster_resp_ok): + assert ( + await mock_cluster_resp_ok.cluster("setslot", 1, "IMPORTING", "nodeid") + is True + ) + + async def test_cluster_slaves(self, mock_cluster_resp_slaves): + assert isinstance( + await mock_cluster_resp_slaves.cluster("slaves", "nodeid"), dict + ) + + @skip_if_server_version_lt("3.0.0") + async def test_readwrite(self, r: aioredis.Redis): + assert await r.readwrite() + + @skip_if_server_version_lt("3.0.0") + async def test_readonly_invalid_cluster_state(self, r: aioredis.Redis): + with pytest.raises(exceptions.RedisError): + await r.readonly() + + @skip_if_server_version_lt("3.0.0") + async def test_readonly(self, mock_cluster_resp_ok): + assert await mock_cluster_resp_ok.readonly() is True + + # GEO COMMANDS + @skip_if_server_version_lt("3.2.0") + async def test_geoadd(self, r: aioredis.Redis): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + assert await r.geoadd("barcelona", *values) == 2 + assert await r.zcard("barcelona") == 2 + + @skip_if_server_version_lt("3.2.0") + async def test_geoadd_invalid_params(self, r: aioredis.Redis): + with pytest.raises(exceptions.RedisError): + await r.geoadd("barcelona", *(1, 2)) + + @skip_if_server_version_lt("3.2.0") + async def test_geodist(self, r: aioredis.Redis): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + assert await r.geoadd("barcelona", *values) == 2 + assert await r.geodist("barcelona", "place1", "place2") == 3067.4157 + + @skip_if_server_version_lt("3.2.0") + async def test_geodist_units(self, r: aioredis.Redis): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + await r.geoadd("barcelona", *values) + assert await r.geodist("barcelona", "place1", "place2", "km") == 3.0674 + + @skip_if_server_version_lt("3.2.0") + async def test_geodist_missing_one_member(self, r: aioredis.Redis): + values = (2.1909389952632, 41.433791470673, "place1") + await r.geoadd("barcelona", *values) + assert await r.geodist("barcelona", "place1", "missing_member", "km") is None + + @skip_if_server_version_lt("3.2.0") + async def test_geodist_invalid_units(self, r: aioredis.Redis): + with pytest.raises(exceptions.RedisError): + assert await r.geodist("x", "y", "z", "inches") + + @skip_if_server_version_lt("3.2.0") + async def test_geohash(self, r: aioredis.Redis): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + await r.geoadd("barcelona", *values) + assert await r.geohash("barcelona", "place1", "place2", "place3") == [ + "sp3e9yg3kd0", + "sp3e9cbc3t0", + None, + ] + + @skip_unless_arch_bits(64) + @skip_if_server_version_lt("3.2.0") + async def test_geopos(self, r: aioredis.Redis): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + await r.geoadd("barcelona", *values) + # redis uses 52-bit precision, so small rounding errors may be introduced. 
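+ # If exact 52-bit rounding were not the point under test, a tolerance
+ # check would be a possible alternative; an illustrative sketch using
+ # pytest.approx:
+ #     lon, lat = (await r.geopos("barcelona", "place1"))[0]
+ #     assert lon == pytest.approx(2.1909389952632, abs=1e-5)
+ # The exact literals asserted below are what Redis reproduces after its
+ # 52-bit geohash round-trip.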
+ assert await r.geopos("barcelona", "place1", "place2") == [ + (2.19093829393386841, 41.43379028184083523), + (2.18737632036209106, 41.40634178640635099), + ] + + @skip_if_server_version_lt("4.0.0") + async def test_geopos_no_value(self, r: aioredis.Redis): + assert await r.geopos("barcelona", "place1", "place2") == [None, None] + + @skip_if_server_version_lt("3.2.0") + @skip_if_server_version_gte("4.0.0") + async def test_old_geopos_no_value(self, r: aioredis.Redis): + assert await r.geopos("barcelona", "place1", "place2") == [] + + @skip_if_server_version_lt("3.2.0") + async def test_georadius(self, r: aioredis.Redis): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + b"\x80place2", + ) + + await r.geoadd("barcelona", *values) + assert await r.georadius("barcelona", 2.191, 41.433, 1000) == [b"place1"] + assert await r.georadius("barcelona", 2.187, 41.406, 1000) == [b"\x80place2"] + + @skip_if_server_version_lt("3.2.0") + async def test_georadius_no_values(self, r: aioredis.Redis): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + await r.geoadd("barcelona", *values) + assert await r.georadius("barcelona", 1, 2, 1000) == [] + + @skip_if_server_version_lt("3.2.0") + async def test_georadius_units(self, r: aioredis.Redis): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + await r.geoadd("barcelona", *values) + assert await r.georadius("barcelona", 2.191, 41.433, 1, unit="km") == [ + b"place1" + ] + + @skip_unless_arch_bits(64) + @skip_if_server_version_lt("3.2.0") + async def test_georadius_with(self, r: aioredis.Redis): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + await r.geoadd("barcelona", *values) + + # test a bunch of combinations to test the parse response + # function. + assert await r.georadius( + "barcelona", + 2.191, + 41.433, + 1, + unit="km", + withdist=True, + withcoord=True, + withhash=True, + ) == [ + [ + b"place1", + 0.0881, + 3471609698139488, + (2.19093829393386841, 41.43379028184083523), + ] + ] + + assert await r.georadius( + "barcelona", 2.191, 41.433, 1, unit="km", withdist=True, withcoord=True + ) == [[b"place1", 0.0881, (2.19093829393386841, 41.43379028184083523)]] + + assert await r.georadius( + "barcelona", 2.191, 41.433, 1, unit="km", withhash=True, withcoord=True + ) == [ + [b"place1", 3471609698139488, (2.19093829393386841, 41.43379028184083523)] + ] + + # test no values. 
+ assert ( + await r.georadius( + "barcelona", + 2, + 1, + 1, + unit="km", + withdist=True, + withcoord=True, + withhash=True, + ) + == [] + ) + + @skip_if_server_version_lt("3.2.0") + async def test_georadius_count(self, r: aioredis.Redis): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + await r.geoadd("barcelona", *values) + assert await r.georadius("barcelona", 2.191, 41.433, 3000, count=1) == [ + b"place1" + ] + + @skip_if_server_version_lt("3.2.0") + async def test_georadius_sort(self, r: aioredis.Redis): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + await r.geoadd("barcelona", *values) + assert await r.georadius("barcelona", 2.191, 41.433, 3000, sort="ASC") == [ + b"place1", + b"place2", + ] + assert await r.georadius("barcelona", 2.191, 41.433, 3000, sort="DESC") == [ + b"place2", + b"place1", + ] + + @skip_if_server_version_lt("3.2.0") + async def test_georadius_store(self, r: aioredis.Redis): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + await r.geoadd("barcelona", *values) + await r.georadius("barcelona", 2.191, 41.433, 1000, store="places_barcelona") + assert await r.zrange("places_barcelona", 0, -1) == [b"place1"] + + @skip_unless_arch_bits(64) + @skip_if_server_version_lt("3.2.0") + async def test_georadius_store_dist(self, r: aioredis.Redis): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + "place2", + ) + + await r.geoadd("barcelona", *values) + await r.georadius( + "barcelona", 2.191, 41.433, 1000, store_dist="places_barcelona" + ) + # instead of saving the geo score, the distance is saved. 
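+ # STORE saves the matched members with their raw geohash integer as the
+ # sorted-set score, while STOREDIST (store_dist) saves the computed
+ # distance from the query point in the query's unit (meters here), which
+ # is why the assertion below checks ZSCORE against roughly 88.05.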
+ assert await r.zscore("places_barcelona", "place1") == 88.05060698409301 + + @skip_unless_arch_bits(64) + @skip_if_server_version_lt("3.2.0") + async def test_georadiusmember(self, r: aioredis.Redis): + values = (2.1909389952632, 41.433791470673, "place1") + ( + 2.1873744593677, + 41.406342043777, + b"\x80place2", + ) + + await r.geoadd("barcelona", *values) + assert await r.georadiusbymember("barcelona", "place1", 4000) == [ + b"\x80place2", + b"place1", + ] + assert await r.georadiusbymember("barcelona", "place1", 10) == [b"place1"] + + assert await r.georadiusbymember( + "barcelona", "place1", 4000, withdist=True, withcoord=True, withhash=True + ) == [ + [ + b"\x80place2", + 3067.4157, + 3471609625421029, + (2.187376320362091, 41.40634178640635), + ], + [ + b"place1", + 0.0, + 3471609698139488, + (2.1909382939338684, 41.433790281840835), + ], + ] + + @skip_if_server_version_lt("5.0.0") + async def test_xack(self, r: aioredis.Redis): + stream = "stream" + group = "group" + consumer = "consumer" + # xack on a stream that doesn't exist + assert await r.xack(stream, group, "0-0") == 0 + + m1 = await r.xadd(stream, {"one": "one"}) + m2 = await r.xadd(stream, {"two": "two"}) + m3 = await r.xadd(stream, {"three": "three"}) + + # xack on a group that doesn't exist + assert await r.xack(stream, group, m1) == 0 + + await r.xgroup_create(stream, group, 0) + await r.xreadgroup(group, consumer, streams={stream: ">"}) + # xack returns the number of ack'd elements + assert await r.xack(stream, group, m1) == 1 + assert await r.xack(stream, group, m2, m3) == 2 + + @skip_if_server_version_lt("5.0.0") + async def test_xadd(self, r: aioredis.Redis): + stream = "stream" + message_id = await r.xadd(stream, {"foo": "bar"}) + assert re.match(br"[0-9]+\-[0-9]+", message_id) + + # explicit message id + message_id = b"9999999999999999999-0" + assert message_id == await r.xadd(stream, {"foo": "bar"}, id=message_id) + + # with maxlen, the list evicts the first message + await r.xadd(stream, {"foo": "bar"}, maxlen=2, approximate=False) + assert await r.xlen(stream) == 2 + + @skip_if_server_version_lt("5.0.0") + async def test_xclaim(self, r: aioredis.Redis): + stream = "stream" + group = "group" + consumer1 = "consumer1" + consumer2 = "consumer2" + + message_id = await r.xadd(stream, {"john": "wick"}) + message = await get_stream_message(r, stream, message_id) + await r.xgroup_create(stream, group, 0) + + # trying to claim a message that isn't already pending doesn't + # do anything + response = await r.xclaim( + stream, group, consumer2, min_idle_time=0, message_ids=(message_id,) + ) + assert response == [] + + # read the group as consumer1 to initially claim the messages + await r.xreadgroup(group, consumer1, streams={stream: ">"}) + + # claim the message as consumer2 + response = await r.xclaim( + stream, group, consumer2, min_idle_time=0, message_ids=(message_id,) + ) + assert response[0] == message + + # reclaim the message as consumer1, but use the justid argument + # which only returns message ids + assert ( + await r.xclaim( + stream, + group, + consumer1, + min_idle_time=0, + message_ids=(message_id,), + justid=True, + ) + == [message_id] + ) + + @skip_if_server_version_lt("5.0.0") + async def test_xclaim_trimmed(self, r: aioredis.Redis): + # xclaim should not raise an exception if the item is not there + stream = "stream" + group = "group" + + await r.xgroup_create(stream, group, id="$", mkstream=True) + + # add a couple of new items + sid1 = await r.xadd(stream, {"item": 0}) + sid2 = await 
r.xadd(stream, {"item": 0}) + + # read them from consumer1 + await r.xreadgroup(group, "consumer1", {stream: ">"}) + + # add a 3rd and trim the stream down to 2 items + await r.xadd(stream, {"item": 3}, maxlen=2, approximate=False) + + # xclaim them from consumer2 + # the item that is still in the stream should be returned + item = await r.xclaim(stream, group, "consumer2", 0, [sid1, sid2]) + assert len(item) == 2 + assert item[0] == (None, None) + assert item[1][0] == sid2 + + @skip_if_server_version_lt("5.0.0") + async def test_xdel(self, r: aioredis.Redis): + stream = "stream" + + # deleting from an empty stream doesn't do anything + assert await r.xdel(stream, 1) == 0 + + m1 = await r.xadd(stream, {"foo": "bar"}) + m2 = await r.xadd(stream, {"foo": "bar"}) + m3 = await r.xadd(stream, {"foo": "bar"}) + + # xdel returns the number of deleted elements + assert await r.xdel(stream, m1) == 1 + assert await r.xdel(stream, m2, m3) == 2 + + @skip_if_server_version_lt("5.0.0") + async def test_xgroup_create(self, r: aioredis.Redis): + # tests xgroup_create and xinfo_groups + stream = "stream" + group = "group" + await r.xadd(stream, {"foo": "bar"}) + + # no group is setup yet, no info to obtain + assert await r.xinfo_groups(stream) == [] + + assert await r.xgroup_create(stream, group, 0) + expected = [ + { + "name": group.encode(), + "consumers": 0, + "pending": 0, + "last-delivered-id": b"0-0", + } + ] + assert await r.xinfo_groups(stream) == expected + + @skip_if_server_version_lt("5.0.0") + async def test_xgroup_create_mkstream(self, r: aioredis.Redis): + # tests xgroup_create and xinfo_groups + stream = "stream" + group = "group" + + # an error is raised if a group is created on a stream that + # doesn't already exist + with pytest.raises(exceptions.ResponseError): + await r.xgroup_create(stream, group, 0) + + # however, with mkstream=True, the underlying stream is created + # automatically + assert await r.xgroup_create(stream, group, 0, mkstream=True) + expected = [ + { + "name": group.encode(), + "consumers": 0, + "pending": 0, + "last-delivered-id": b"0-0", + } + ] + assert await r.xinfo_groups(stream) == expected + + @skip_if_server_version_lt("5.0.0") + async def test_xgroup_delconsumer(self, r: aioredis.Redis): + stream = "stream" + group = "group" + consumer = "consumer" + await r.xadd(stream, {"foo": "bar"}) + await r.xadd(stream, {"foo": "bar"}) + await r.xgroup_create(stream, group, 0) + + # a consumer that hasn't yet read any messages doesn't do anything + assert await r.xgroup_delconsumer(stream, group, consumer) == 0 + + # read all messages from the group + await r.xreadgroup(group, consumer, streams={stream: ">"}) + + # deleting the consumer should return 2 pending messages + assert await r.xgroup_delconsumer(stream, group, consumer) == 2 + + @skip_if_server_version_lt("5.0.0") + async def test_xgroup_destroy(self, r: aioredis.Redis): + stream = "stream" + group = "group" + await r.xadd(stream, {"foo": "bar"}) + + # destroying a nonexistent group returns False + assert not await r.xgroup_destroy(stream, group) + + await r.xgroup_create(stream, group, 0) + assert await r.xgroup_destroy(stream, group) + + @skip_if_server_version_lt("5.0.0") + async def test_xgroup_setid(self, r: aioredis.Redis): + stream = "stream" + group = "group" + message_id = await r.xadd(stream, {"foo": "bar"}) + + await r.xgroup_create(stream, group, 0) + # advance the last_delivered_id to the message_id + await r.xgroup_setid(stream, group, message_id) + expected = [ + { + "name": group.encode(), + 
"consumers": 0, + "pending": 0, + "last-delivered-id": message_id, + } + ] + assert await r.xinfo_groups(stream) == expected + + @skip_if_server_version_lt("5.0.0") + async def test_xinfo_consumers(self, r: aioredis.Redis): + stream = "stream" + group = "group" + consumer1 = "consumer1" + consumer2 = "consumer2" + await r.xadd(stream, {"foo": "bar"}) + await r.xadd(stream, {"foo": "bar"}) + await r.xadd(stream, {"foo": "bar"}) + + await r.xgroup_create(stream, group, 0) + await r.xreadgroup(group, consumer1, streams={stream: ">"}, count=1) + await r.xreadgroup(group, consumer2, streams={stream: ">"}) + info = await r.xinfo_consumers(stream, group) + assert len(info) == 2 + expected = [ + {"name": consumer1.encode(), "pending": 1}, + {"name": consumer2.encode(), "pending": 2}, + ] + + # we can't determine the idle time, so just make sure it's an int + assert isinstance(info[0].pop("idle"), int) + assert isinstance(info[1].pop("idle"), int) + assert info == expected + + @skip_if_server_version_lt("5.0.0") + async def test_xinfo_stream(self, r: aioredis.Redis): + stream = "stream" + m1 = await r.xadd(stream, {"foo": "bar"}) + m2 = await r.xadd(stream, {"foo": "bar"}) + info = await r.xinfo_stream(stream) + + assert info["length"] == 2 + assert info["first-entry"] == await get_stream_message(r, stream, m1) + assert info["last-entry"] == await get_stream_message(r, stream, m2) + + @skip_if_server_version_lt("5.0.0") + async def test_xlen(self, r: aioredis.Redis): + stream = "stream" + assert await r.xlen(stream) == 0 + await r.xadd(stream, {"foo": "bar"}) + await r.xadd(stream, {"foo": "bar"}) + assert await r.xlen(stream) == 2 + + @skip_if_server_version_lt("5.0.0") + async def test_xpending(self, r: aioredis.Redis): + stream = "stream" + group = "group" + consumer1 = "consumer1" + consumer2 = "consumer2" + m1 = await r.xadd(stream, {"foo": "bar"}) + m2 = await r.xadd(stream, {"foo": "bar"}) + await r.xgroup_create(stream, group, 0) + + # xpending on a group that has no consumers yet + expected = {"pending": 0, "min": None, "max": None, "consumers": []} + assert await r.xpending(stream, group) == expected + + # read 1 message from the group with each consumer + await r.xreadgroup(group, consumer1, streams={stream: ">"}, count=1) + await r.xreadgroup(group, consumer2, streams={stream: ">"}, count=1) + + expected = { + "pending": 2, + "min": m1, + "max": m2, + "consumers": [ + {"name": consumer1.encode(), "pending": 1}, + {"name": consumer2.encode(), "pending": 1}, + ], + } + assert await r.xpending(stream, group) == expected + + @skip_if_server_version_lt("5.0.0") + async def test_xpending_range(self, r: aioredis.Redis): + stream = "stream" + group = "group" + consumer1 = "consumer1" + consumer2 = "consumer2" + m1 = await r.xadd(stream, {"foo": "bar"}) + m2 = await r.xadd(stream, {"foo": "bar"}) + await r.xgroup_create(stream, group, 0) + + # xpending range on a group that has no consumers yet + assert await r.xpending_range(stream, group, min="-", max="+", count=5) == [] + + # read 1 message from the group with each consumer + await r.xreadgroup(group, consumer1, streams={stream: ">"}, count=1) + await r.xreadgroup(group, consumer2, streams={stream: ">"}, count=1) + + response = await r.xpending_range(stream, group, min="-", max="+", count=5) + assert len(response) == 2 + assert response[0]["message_id"] == m1 + assert response[0]["consumer"] == consumer1.encode() + assert response[1]["message_id"] == m2 + assert response[1]["consumer"] == consumer2.encode() + + 
@skip_if_server_version_lt("5.0.0") + async def test_xrange(self, r: aioredis.Redis): + stream = "stream" + m1 = await r.xadd(stream, {"foo": "bar"}) + m2 = await r.xadd(stream, {"foo": "bar"}) + m3 = await r.xadd(stream, {"foo": "bar"}) + m4 = await r.xadd(stream, {"foo": "bar"}) + + def get_ids(results): + return [result[0] for result in results] + + results = await r.xrange(stream, min=m1) + assert get_ids(results) == [m1, m2, m3, m4] + + results = await r.xrange(stream, min=m2, max=m3) + assert get_ids(results) == [m2, m3] + + results = await r.xrange(stream, max=m3) + assert get_ids(results) == [m1, m2, m3] + + results = await r.xrange(stream, max=m2, count=1) + assert get_ids(results) == [m1] + + @skip_if_server_version_lt("5.0.0") + async def test_xread(self, r: aioredis.Redis): + stream = "stream" + m1 = await r.xadd(stream, {"foo": "bar"}) + m2 = await r.xadd(stream, {"bing": "baz"}) + + expected = [ + [ + stream.encode(), + [ + await get_stream_message(r, stream, m1), + await get_stream_message(r, stream, m2), + ], + ] + ] + # xread starting at 0 returns both messages + assert await r.xread(streams={stream: 0}) == expected + + expected = [ + [ + stream.encode(), + [ + await get_stream_message(r, stream, m1), + ], + ] + ] + # xread starting at 0 and count=1 returns only the first message + assert await r.xread(streams={stream: 0}, count=1) == expected + + expected = [ + [ + stream.encode(), + [ + await get_stream_message(r, stream, m2), + ], + ] + ] + # xread starting at m1 returns only the second message + assert await r.xread(streams={stream: m1}) == expected + + # xread starting at the last message returns an empty list + assert await r.xread(streams={stream: m2}) == [] + + @skip_if_server_version_lt("5.0.0") + async def test_xreadgroup(self, r: aioredis.Redis): + stream = "stream" + group = "group" + consumer = "consumer" + m1 = await r.xadd(stream, {"foo": "bar"}) + m2 = await r.xadd(stream, {"bing": "baz"}) + await r.xgroup_create(stream, group, 0) + + expected = [ + [ + stream.encode(), + [ + await get_stream_message(r, stream, m1), + await get_stream_message(r, stream, m2), + ], + ] + ] + # xread starting at 0 returns both messages + assert await r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + + await r.xgroup_destroy(stream, group) + await r.xgroup_create(stream, group, 0) + + expected = [ + [ + stream.encode(), + [ + await get_stream_message(r, stream, m1), + ], + ] + ] + # xread with count=1 returns only the first message + assert ( + await r.xreadgroup(group, consumer, streams={stream: ">"}, count=1) + == expected + ) + + await r.xgroup_destroy(stream, group) + + # create the group using $ as the last id meaning subsequent reads + # will only find messages added after this + await r.xgroup_create(stream, group, "$") + + expected = [] + # xread starting after the last message returns an empty message list + assert await r.xreadgroup(group, consumer, streams={stream: ">"}) == expected + + # xreadgroup with noack does not have any items in the PEL + await r.xgroup_destroy(stream, group) + await r.xgroup_create(stream, group, "0") + assert ( + len( + ( + await r.xreadgroup( + group, consumer, streams={stream: ">"}, noack=True + ) + )[0][1] + ) + == 2 + ) + # now there should be nothing pending + assert ( + len((await r.xreadgroup(group, consumer, streams={stream: "0"}))[0][1]) == 0 + ) + + await r.xgroup_destroy(stream, group) + await r.xgroup_create(stream, group, "0") + # delete all the messages in the stream + expected = [ + [ + stream.encode(), + [ + 
(m1, {}), + (m2, {}), + ], + ] + ] + await r.xreadgroup(group, consumer, streams={stream: ">"}) + await r.xtrim(stream, 0) + assert await r.xreadgroup(group, consumer, streams={stream: "0"}) == expected + + @skip_if_server_version_lt("5.0.0") + async def test_xrevrange(self, r: aioredis.Redis): + stream = "stream" + m1 = await r.xadd(stream, {"foo": "bar"}) + m2 = await r.xadd(stream, {"foo": "bar"}) + m3 = await r.xadd(stream, {"foo": "bar"}) + m4 = await r.xadd(stream, {"foo": "bar"}) + + def get_ids(results): + return [result[0] for result in results] + + results = await r.xrevrange(stream, max=m4) + assert get_ids(results) == [m4, m3, m2, m1] + + results = await r.xrevrange(stream, max=m3, min=m2) + assert get_ids(results) == [m3, m2] + + results = await r.xrevrange(stream, min=m3) + assert get_ids(results) == [m4, m3] + + results = await r.xrevrange(stream, min=m2, count=1) + assert get_ids(results) == [m4] + + @skip_if_server_version_lt("5.0.0") + async def test_xtrim(self, r: aioredis.Redis): + stream = "stream" + + # trimming an empty key doesn't do anything + assert await r.xtrim(stream, 1000) == 0 + + await r.xadd(stream, {"foo": "bar"}) + await r.xadd(stream, {"foo": "bar"}) + await r.xadd(stream, {"foo": "bar"}) + await r.xadd(stream, {"foo": "bar"}) + + # trimming an amount larger than the number of messages + # doesn't do anything + assert await r.xtrim(stream, 5, approximate=False) == 0 + + # 1 message is trimmed + assert await r.xtrim(stream, 3, approximate=False) == 1 + + async def test_bitfield_operations(self, r: aioredis.Redis): + # comments show affected bits + await r.execute_command("SELECT", 10) + bf = r.bitfield("a") + resp = await ( + bf.set("u8", 8, 255) # 00000000 11111111 + .get("u8", 0) # 00000000 + .get("u4", 8) # 1111 + .get("u4", 12) # 1111 + .get("u4", 13) # 111 0 + .execute() + ) + assert resp == [0, 0, 15, 15, 14] + + # .set() returns the previous value... 
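+ # Each .set()/.get()/.incrby() call only queues an operation on `bf`;
+ # awaiting .execute() sends them as a single BITFIELD command and returns
+ # one result per queued operation, in order.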
+ resp = await ( + bf.set("u8", 4, 1) # 0000 0001 + .get("u16", 0) # 00000000 00011111 + .set("u16", 0, 0) # 00000000 00000000 + .execute() + ) + assert resp == [15, 31, 31] + + # incrby adds to the value + resp = await ( + bf.incrby("u8", 8, 254) # 00000000 11111110 + .incrby("u8", 8, 1) # 00000000 11111111 + .get("u16", 0) # 00000000 11111111 + .execute() + ) + assert resp == [254, 255, 255] + + # Verify overflow protection works as a method: + await r.delete("a") + resp = await ( + bf.set("u8", 8, 254) # 00000000 11111110 + .overflow("fail") + .incrby("u8", 8, 2) # incrby 2 would overflow, None returned + .incrby("u8", 8, 1) # 00000000 11111111 + .incrby("u8", 8, 1) # incrby 1 would overflow, None returned + .get("u16", 0) # 00000000 11111111 + .execute() + ) + assert resp == [0, None, 255, None, 255] + + # Verify overflow protection works as arg to incrby: + await r.delete("a") + resp = await ( + bf.set("u8", 8, 255) # 00000000 11111111 + .incrby("u8", 8, 1) # 00000000 00000000 wrap default + .set("u8", 8, 255) # 00000000 11111111 + .incrby("u8", 8, 1, "FAIL") # 00000000 11111111 fail + .incrby("u8", 8, 1) # 00000000 11111111 still fail + .get("u16", 0) # 00000000 11111111 + .execute() + ) + assert resp == [0, 0, 0, None, None, 255] + + # test default default_overflow + await r.delete("a") + bf = r.bitfield("a", default_overflow="FAIL") + resp = await ( + bf.set("u8", 8, 255) # 00000000 11111111 + .incrby("u8", 8, 1) # 00000000 11111111 fail default + .get("u16", 0) # 00000000 11111111 + .execute() + ) + assert resp == [0, None, 255] + + @skip_if_server_version_lt("4.0.0") + async def test_memory_stats(self, r: aioredis.Redis): + # put a key into the current db to make sure that "db." + # has data + await r.set("foo", "bar") + stats = await r.memory_stats() + assert isinstance(stats, dict) + for key, value in stats.items(): + if key.startswith("db."): + assert isinstance(value, dict) + + @skip_if_server_version_lt("4.0.0") + async def test_memory_usage(self, r: aioredis.Redis): + await r.set("foo", "bar") + assert isinstance(await r.memory_usage("foo"), int) + + @skip_if_server_version_lt("4.0.0") + async def test_module_list(self, r: aioredis.Redis): + assert isinstance(await r.module_list(), list) + assert not await r.module_list() + + +class TestBinarySave: + async def test_binary_get_set(self, r: aioredis.Redis): + assert await r.set(" foo bar ", "123") + assert await r.get(" foo bar ") == b"123" + + assert await r.set(" foo\r\nbar\r\n ", "456") + assert await r.get(" foo\r\nbar\r\n ") == b"456" + + assert await r.set(" \r\n\t\x07\x13 ", "789") + assert await r.get(" \r\n\t\x07\x13 ") == b"789" + + assert sorted(await r.keys("*")) == [ + b" \r\n\t\x07\x13 ", + b" foo\r\nbar\r\n ", + b" foo bar ", + ] + + assert await r.delete(" foo bar ") + assert await r.delete(" foo\r\nbar\r\n ") + assert await r.delete(" \r\n\t\x07\x13 ") + + async def test_binary_lists(self, r: aioredis.Redis): + mapping = { + b"foo bar": [b"1", b"2", b"3"], + b"foo\r\nbar\r\n": [b"4", b"5", b"6"], + b"foo\tbar\x07": [b"7", b"8", b"9"], + } + # fill in lists + for key, value in mapping.items(): + await r.rpush(key, *value) + + # check that KEYS returns all the keys as they are + assert sorted(await r.keys("*")) == sorted(mapping.keys()) + + # check that it is possible to get list content by key name + for key, value in mapping.items(): + assert await r.lrange(key, 0, -1) == value + + async def test_22_info(self, r: aioredis.Redis): + """ + Older Redis versions contained 'allocation_stats' in INFO that + was the 
cause of a number of bugs when parsing. + """ + info = ( + "allocation_stats:6=1,7=1,8=7141,9=180,10=92,11=116,12=5330," + "13=123,14=3091,15=11048,16=225842,17=1784,18=814,19=12020," + "20=2530,21=645,22=15113,23=8695,24=142860,25=318,26=3303," + "27=20561,28=54042,29=37390,30=1884,31=18071,32=31367,33=160," + "34=169,35=201,36=10155,37=1045,38=15078,39=22985,40=12523," + "41=15588,42=265,43=1287,44=142,45=382,46=945,47=426,48=171," + "49=56,50=516,51=43,52=41,53=46,54=54,55=75,56=647,57=332," + "58=32,59=39,60=48,61=35,62=62,63=32,64=221,65=26,66=30," + "67=36,68=41,69=44,70=26,71=144,72=169,73=24,74=37,75=25," + "76=42,77=21,78=126,79=374,80=27,81=40,82=43,83=47,84=46," + "85=114,86=34,87=37,88=7240,89=34,90=38,91=18,92=99,93=20," + "94=18,95=17,96=15,97=22,98=18,99=69,100=17,101=22,102=15," + "103=29,104=39,105=30,106=70,107=22,108=21,109=26,110=52," + "111=45,112=33,113=67,114=41,115=44,116=48,117=53,118=54," + "119=51,120=75,121=44,122=57,123=44,124=66,125=56,126=52," + "127=81,128=108,129=70,130=50,131=51,132=53,133=45,134=62," + "135=12,136=13,137=7,138=15,139=21,140=11,141=20,142=6,143=7," + "144=11,145=6,146=16,147=19,148=1112,149=1,151=83,154=1," + "155=1,156=1,157=1,160=1,161=1,162=2,166=1,169=1,170=1,171=2," + "172=1,174=1,176=2,177=9,178=34,179=73,180=30,181=1,185=3," + "187=1,188=1,189=1,192=1,196=1,198=1,200=1,201=1,204=1,205=1," + "207=1,208=1,209=1,214=2,215=31,216=78,217=28,218=5,219=2," + "220=1,222=1,225=1,227=1,234=1,242=1,250=1,252=1,253=1," + ">=256=203" + ) + parsed = parse_info(info) + assert "allocation_stats" in parsed + assert "6" in parsed["allocation_stats"] + assert ">=256" in parsed["allocation_stats"] + + async def test_large_responses(self, r: aioredis.Redis): + """The PythonParser has some special cases for return values > 1MB""" + # load up 5MB of data into a key + data = "".join([ascii_letters] * (5000000 // len(ascii_letters))) + await r.set("a", data) + assert await r.get("a") == data.encode() + + async def test_floating_point_encoding(self, r: aioredis.Redis): + """ + High precision floating point values sent to the server should keep + precision. 
+ """ + timestamp = 1349673917.939762 + await r.zadd("a", {"a1": timestamp}) + assert await r.zscore("a", "a1") == timestamp diff --git a/tests/test_connection.py b/tests/test_connection.py new file mode 100644 index 000000000..bc7160df7 --- /dev/null +++ b/tests/test_connection.py @@ -0,0 +1,23 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING +from unittest import mock + +import pytest + +from aioredis.exceptions import InvalidResponse +from aioredis.utils import HIREDIS_AVAILABLE + +if TYPE_CHECKING: + from aioredis.connection import PythonParser + + +@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only") +@pytest.mark.asyncio +async def test_invalid_response(r): + raw = b"x" + parser: PythonParser = r.connection._parser + with mock.patch.object(parser._buffer, "readline", return_value=raw): + with pytest.raises(InvalidResponse) as cm: + await parser.read_response() + assert str(cm.value) == "Protocol Error: %r" % raw diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py new file mode 100644 index 000000000..f2cf35037 --- /dev/null +++ b/tests/test_connection_pool.py @@ -0,0 +1,771 @@ +import asyncio +import os +import re +import time +from unittest import mock + +import pytest + +import aioredis +from aioredis.connection import Connection, to_bool + +from .conftest import REDIS_6_VERSION, _get_client, skip_if_server_version_lt +from .test_pubsub import wait_for_message + +pytestmark = pytest.mark.asyncio + + +class DummyConnection(Connection): + description_format = "DummyConnection<>" + + def __init__(self, **kwargs): + self.kwargs = kwargs + self.pid = os.getpid() + + async def connect(self): + pass + + async def can_read(self, timeout: float = 0): + return False + + +class TestConnectionPool: + def get_pool( + self, + connection_kwargs=None, + max_connections=None, + connection_class=aioredis.Connection, + ): + connection_kwargs = connection_kwargs or {} + pool = aioredis.ConnectionPool( + connection_class=connection_class, + max_connections=max_connections, + **connection_kwargs + ) + return pool + + async def test_connection_creation(self): + connection_kwargs = {"foo": "bar", "biz": "baz"} + pool = self.get_pool( + connection_kwargs=connection_kwargs, connection_class=DummyConnection + ) + connection = await pool.get_connection("_") + assert isinstance(connection, DummyConnection) + assert connection.kwargs == connection_kwargs + + async def test_multiple_connections(self, master_host): + connection_kwargs = {"host": master_host} + pool = self.get_pool(connection_kwargs=connection_kwargs) + c1 = await pool.get_connection("_") + c2 = await pool.get_connection("_") + assert c1 != c2 + + async def test_max_connections(self, master_host): + connection_kwargs = {"host": master_host} + pool = self.get_pool(max_connections=2, connection_kwargs=connection_kwargs) + await pool.get_connection("_") + await pool.get_connection("_") + with pytest.raises(aioredis.ConnectionError): + await pool.get_connection("_") + + async def test_reuse_previously_released_connection(self, master_host): + connection_kwargs = {"host": master_host} + pool = self.get_pool(connection_kwargs=connection_kwargs) + c1 = await pool.get_connection("_") + await pool.release(c1) + c2 = await pool.get_connection("_") + assert c1 == c2 + + def test_repr_contains_db_info_tcp(self): + connection_kwargs = { + "host": "localhost", + "port": 6379, + "db": 1, + "client_name": "test-client", + } + pool = self.get_pool( + connection_kwargs=connection_kwargs, 
connection_class=aioredis.Connection + ) + expected = ( + "ConnectionPool<Connection<" + "host=localhost,port=6379,db=1,client_name=test-client>>" + ) + assert repr(pool) == expected + + def test_repr_contains_db_info_unix(self): + connection_kwargs = {"path": "/abc", "db": 1, "client_name": "test-client"} + pool = self.get_pool( + connection_kwargs=connection_kwargs, + connection_class=aioredis.UnixDomainSocketConnection, + ) + expected = ( + "ConnectionPool<UnixDomainSocketConnection<" + "path=/abc,db=1,client_name=test-client>>" + ) + assert repr(pool) == expected + + +class TestBlockingConnectionPool: + def get_pool(self, connection_kwargs=None, max_connections=10, timeout=20): + connection_kwargs = connection_kwargs or {} + pool = aioredis.BlockingConnectionPool( + connection_class=DummyConnection, + max_connections=max_connections, + timeout=timeout, + **connection_kwargs + ) + return pool + + async def test_connection_creation(self, master_host): + connection_kwargs = {"foo": "bar", "biz": "baz", "host": master_host} + pool = self.get_pool(connection_kwargs=connection_kwargs) + connection = await pool.get_connection("_") + assert isinstance(connection, DummyConnection) + assert connection.kwargs == connection_kwargs + + async def test_multiple_connections(self, master_host): + connection_kwargs = {"host": master_host} + pool = self.get_pool(connection_kwargs=connection_kwargs) + c1 = await pool.get_connection("_") + c2 = await pool.get_connection("_") + assert c1 != c2 + + async def test_connection_pool_blocks_until_timeout(self, master_host): + """When out of connections, block for timeout seconds, then raise""" + connection_kwargs = {"host": master_host} + pool = self.get_pool( + max_connections=1, timeout=0.1, connection_kwargs=connection_kwargs + ) + await pool.get_connection("_") + + start = time.time() + with pytest.raises(aioredis.ConnectionError): + await pool.get_connection("_") + # we should have waited at least 0.1 seconds + assert time.time() - start >= 0.1 + + async def test_connection_pool_blocks_until_conn_available(self, master_host): + """ + When out of connections, block until another connection is released + to the pool + """ + connection_kwargs = {"host": master_host} + pool = self.get_pool( + max_connections=1, timeout=2, connection_kwargs=connection_kwargs + ) + c1 = await pool.get_connection("_") + + async def target(): + await asyncio.sleep(0.1) + await pool.release(c1) + + start = time.time() + await asyncio.gather(target(), pool.get_connection("_")) + assert time.time() - start >= 0.1 + + async def test_reuse_previously_released_connection(self, master_host): + connection_kwargs = {"host": master_host} + pool = self.get_pool(connection_kwargs=connection_kwargs) + c1 = await pool.get_connection("_") + await pool.release(c1) + c2 = await pool.get_connection("_") + assert c1 == c2 + + def test_repr_contains_db_info_tcp(self): + pool = aioredis.ConnectionPool( + host="localhost", port=6379, client_name="test-client" + ) + expected = ( + "ConnectionPool<Connection<" + "host=localhost,port=6379,db=0,client_name=test-client>>" + ) + assert repr(pool) == expected + + def test_repr_contains_db_info_unix(self): + pool = aioredis.ConnectionPool( + connection_class=aioredis.UnixDomainSocketConnection, + path="abc", + client_name="test-client", + ) + expected = ( + "ConnectionPool<UnixDomainSocketConnection<" + "path=abc,db=0,client_name=test-client>>" + ) + assert repr(pool) == expected + + +class TestConnectionPoolURLParsing: + def test_hostname(self): + pool = aioredis.ConnectionPool.from_url("redis://my.host") + assert pool.connection_class == aioredis.Connection + assert pool.connection_kwargs == { + "host": "my.host", + } + + def test_quoted_hostname(self): + pool = aioredis.ConnectionPool.from_url("redis://my %2F host %2B%3D+") + assert 
pool.connection_class == aioredis.Connection + assert pool.connection_kwargs == { + "host": "my / host +=+", + } + + def test_port(self): + pool = aioredis.ConnectionPool.from_url("redis://localhost:6380") + assert pool.connection_class == aioredis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "port": 6380, + } + + @skip_if_server_version_lt(REDIS_6_VERSION) + def test_username(self): + pool = aioredis.ConnectionPool.from_url("redis://myuser:@localhost") + assert pool.connection_class == aioredis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "username": "myuser", + } + + @skip_if_server_version_lt(REDIS_6_VERSION) + def test_quoted_username(self): + pool = aioredis.ConnectionPool.from_url( + "redis://%2Fmyuser%2F%2B name%3D%24+:@localhost" + ) + assert pool.connection_class == aioredis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "username": "/myuser/+ name=$+", + } + + def test_password(self): + pool = aioredis.ConnectionPool.from_url("redis://:mypassword@localhost") + assert pool.connection_class == aioredis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "password": "mypassword", + } + + def test_quoted_password(self): + pool = aioredis.ConnectionPool.from_url( + "redis://:%2Fmypass%2F%2B word%3D%24+@localhost" + ) + assert pool.connection_class == aioredis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "password": "/mypass/+ word=$+", + } + + @skip_if_server_version_lt(REDIS_6_VERSION) + def test_username_and_password(self): + pool = aioredis.ConnectionPool.from_url("redis://myuser:mypass@localhost") + assert pool.connection_class == aioredis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "username": "myuser", + "password": "mypass", + } + + def test_db_as_argument(self): + pool = aioredis.ConnectionPool.from_url("redis://localhost", db=1) + assert pool.connection_class == aioredis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 1, + } + + def test_db_in_path(self): + pool = aioredis.ConnectionPool.from_url("redis://localhost/2", db=1) + assert pool.connection_class == aioredis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 2, + } + + def test_db_in_querystring(self): + pool = aioredis.ConnectionPool.from_url("redis://localhost/2?db=3", db=1) + assert pool.connection_class == aioredis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 3, + } + + def test_extra_typed_querystring_options(self): + pool = aioredis.ConnectionPool.from_url( + "redis://localhost/2?socket_timeout=20&socket_connect_timeout=10" + "&socket_keepalive=&retry_on_timeout=Yes&max_connections=10" + ) + + assert pool.connection_class == aioredis.Connection + assert pool.connection_kwargs == { + "host": "localhost", + "db": 2, + "socket_timeout": 20.0, + "socket_connect_timeout": 10.0, + "retry_on_timeout": True, + } + assert pool.max_connections == 10 + + def test_boolean_parsing(self): + for expected, value in ( + (None, None), + (None, ""), + (False, 0), + (False, "0"), + (False, "f"), + (False, "F"), + (False, "False"), + (False, "n"), + (False, "N"), + (False, "No"), + (True, 1), + (True, "1"), + (True, "y"), + (True, "Y"), + (True, "Yes"), + ): + assert expected is to_bool(value) + + def test_client_name_in_querystring(self): + pool = aioredis.ConnectionPool.from_url( + "redis://location?client_name=test-client" + ) + assert pool.connection_kwargs["client_name"] == 
"test-client" + + def test_invalid_extra_typed_querystring_options(self): + with pytest.raises(ValueError): + aioredis.ConnectionPool.from_url( + "redis://localhost/2?socket_timeout=_&" "socket_connect_timeout=abc" + ) + + def test_extra_querystring_options(self): + pool = aioredis.ConnectionPool.from_url("redis://localhost?a=1&b=2") + assert pool.connection_class == aioredis.Connection + assert pool.connection_kwargs == {"host": "localhost", "a": "1", "b": "2"} + + def test_calling_from_subclass_returns_correct_instance(self): + pool = aioredis.BlockingConnectionPool.from_url("redis://localhost") + assert isinstance(pool, aioredis.BlockingConnectionPool) + + def test_client_creates_connection_pool(self): + r = aioredis.Redis.from_url("redis://myhost") + assert r.connection_pool.connection_class == aioredis.Connection + assert r.connection_pool.connection_kwargs == { + "host": "myhost", + } + + def test_invalid_scheme_raises_error(self): + with pytest.raises(ValueError) as cm: + aioredis.ConnectionPool.from_url("localhost") + assert str(cm.value) == ( + "Redis URL must specify one of the following schemes " + "(redis://, rediss://, unix://)" + ) + + +class TestConnectionPoolUnixSocketURLParsing: + def test_defaults(self): + pool = aioredis.ConnectionPool.from_url("unix:///socket") + assert pool.connection_class == aioredis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + } + + @skip_if_server_version_lt(REDIS_6_VERSION) + def test_username(self): + pool = aioredis.ConnectionPool.from_url("unix://myuser:@/socket") + assert pool.connection_class == aioredis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "username": "myuser", + } + + @skip_if_server_version_lt(REDIS_6_VERSION) + def test_quoted_username(self): + pool = aioredis.ConnectionPool.from_url( + "unix://%2Fmyuser%2F%2B name%3D%24+:@/socket" + ) + assert pool.connection_class == aioredis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "username": "/myuser/+ name=$+", + } + + def test_password(self): + pool = aioredis.ConnectionPool.from_url("unix://:mypassword@/socket") + assert pool.connection_class == aioredis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "password": "mypassword", + } + + def test_quoted_password(self): + pool = aioredis.ConnectionPool.from_url( + "unix://:%2Fmypass%2F%2B word%3D%24+@/socket" + ) + assert pool.connection_class == aioredis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "password": "/mypass/+ word=$+", + } + + def test_quoted_path(self): + pool = aioredis.ConnectionPool.from_url( + "unix://:mypassword@/my%2Fpath%2Fto%2F..%2F+_%2B%3D%24ocket" + ) + assert pool.connection_class == aioredis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/my/path/to/../+_+=$ocket", + "password": "mypassword", + } + + def test_db_as_argument(self): + pool = aioredis.ConnectionPool.from_url("unix:///socket", db=1) + assert pool.connection_class == aioredis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "db": 1, + } + + def test_db_in_querystring(self): + pool = aioredis.ConnectionPool.from_url("unix:///socket?db=2", db=1) + assert pool.connection_class == aioredis.UnixDomainSocketConnection + assert pool.connection_kwargs == { + "path": "/socket", + "db": 2, + } + + def test_client_name_in_querystring(self): + pool = aioredis.ConnectionPool.from_url( + 
"redis://location?client_name=test-client" + ) + assert pool.connection_kwargs["client_name"] == "test-client" + + def test_extra_querystring_options(self): + pool = aioredis.ConnectionPool.from_url("unix:///socket?a=1&b=2") + assert pool.connection_class == aioredis.UnixDomainSocketConnection + assert pool.connection_kwargs == {"path": "/socket", "a": "1", "b": "2"} + + +class TestSSLConnectionURLParsing: + def test_host(self): + pool = aioredis.ConnectionPool.from_url("rediss://my.host") + assert pool.connection_class == aioredis.SSLConnection + assert pool.connection_kwargs == { + "host": "my.host", + } + + def test_cert_reqs_options(self): + import ssl + + class DummyConnectionPool(aioredis.ConnectionPool): + def get_connection(self, *args, **kwargs): + return self.make_connection() + + pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=none") + assert pool.get_connection("_").cert_reqs == ssl.CERT_NONE + + pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=optional") + assert pool.get_connection("_").cert_reqs == ssl.CERT_OPTIONAL + + pool = DummyConnectionPool.from_url("rediss://?ssl_cert_reqs=required") + assert pool.get_connection("_").cert_reqs == ssl.CERT_REQUIRED + + pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=False") + assert pool.get_connection("_").check_hostname is False + + pool = DummyConnectionPool.from_url("rediss://?ssl_check_hostname=True") + assert pool.get_connection("_").check_hostname is True + + +class TestConnection: + async def test_on_connect_error(self): + """ + An error in Connection.on_connect should disconnect from the server + see for details: https://github.com/andymccurdy/redis-py/issues/368 + """ + # this assumes the Redis server being tested against doesn't have + # 9999 databases ;) + bad_connection = aioredis.Redis(db=9999) + # an error should be raised on connect + with pytest.raises(aioredis.RedisError): + await bad_connection.info() + pool = bad_connection.connection_pool + assert len(pool._available_connections) == 1 + assert not pool._available_connections[0]._reader + + @skip_if_server_version_lt("2.8.8") + async def test_busy_loading_disconnects_socket(self, r): + """ + If Redis raises a LOADING error, the connection should be + disconnected and a BusyLoadingError raised + """ + with pytest.raises(aioredis.BusyLoadingError): + await r.execute_command("DEBUG", "ERROR", "LOADING fake message") + assert not r.connection._reader + + @skip_if_server_version_lt("2.8.8") + async def test_busy_loading_from_pipeline_immediate_command(self, r): + """ + BusyLoadingErrors should raise from Pipelines that execute a + command immediately, like WATCH does. + """ + pipe = r.pipeline() + with pytest.raises(aioredis.BusyLoadingError): + await pipe.immediate_execute_command( + "DEBUG", "ERROR", "LOADING fake message" + ) + pool = r.connection_pool + assert not pipe.connection + assert len(pool._available_connections) == 1 + assert not pool._available_connections[0]._reader + + @skip_if_server_version_lt("2.8.8") + async def test_busy_loading_from_pipeline(self, r): + """ + BusyLoadingErrors should be raised from a pipeline execution + regardless of the raise_on_error flag. 
+ """ + pipe = r.pipeline() + pipe.execute_command("DEBUG", "ERROR", "LOADING fake message") + with pytest.raises(aioredis.BusyLoadingError): + await pipe.execute() + pool = r.connection_pool + assert not pipe.connection + assert len(pool._available_connections) == 1 + assert not pool._available_connections[0]._reader + + @skip_if_server_version_lt("2.8.8") + async def test_read_only_error(self, r): + """READONLY errors get turned in ReadOnlyError exceptions""" + with pytest.raises(aioredis.ReadOnlyError): + await r.execute_command("DEBUG", "ERROR", "READONLY blah blah") + + def test_connect_from_url_tcp(self): + connection = aioredis.Redis.from_url("redis://localhost") + pool = connection.connection_pool + + assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == ( + "ConnectionPool", + "Connection", + "host=localhost,port=6379,db=0", + ) + + def test_connect_from_url_unix(self): + connection = aioredis.Redis.from_url("unix:///path/to/socket") + pool = connection.connection_pool + + assert re.match("(.*)<(.*)<(.*)>>", repr(pool)).groups() == ( + "ConnectionPool", + "UnixDomainSocketConnection", + "path=/path/to/socket,db=0", + ) + + async def test_connect_no_auth_supplied_when_required(self, r): + """ + AuthenticationError should be raised when the server requires a + password but one isn't supplied. + """ + with pytest.raises(aioredis.AuthenticationError): + await r.execute_command( + "DEBUG", "ERROR", "ERR Client sent AUTH, but no password is set" + ) + + async def test_connect_invalid_password_supplied(self, r): + """AuthenticationError should be raised when sending the wrong password""" + with pytest.raises(aioredis.AuthenticationError): + await r.execute_command("DEBUG", "ERROR", "ERR invalid password") + + +class TestMultiConnectionClient: + @pytest.fixture() + async def r(self, request, event_loop): + return await _get_client( + aioredis.Redis, request, event_loop, single_connection_client=False + ) + + async def test_multi_connection_command(self, r): + assert not r.connection + assert await r.set("a", "123") + assert (await r.get("a")) == b"123" + + +class TestHealthCheck: + interval = 60 + + @pytest.fixture() + async def r(self, request, event_loop): + return await _get_client( + aioredis.Redis, request, event_loop, health_check_interval=self.interval + ) + + def assert_interval_advanced(self, connection): + diff = connection.next_health_check - time.time() + assert self.interval > diff > (self.interval - 1) + + async def test_health_check_runs(self, r): + r.connection.next_health_check = time.time() - 1 + await r.connection.check_health() + self.assert_interval_advanced(r.connection) + + async def test_arbitrary_command_invokes_health_check(self, r): + # invoke a command to make sure the connection is entirely setup + await r.get("foo") + r.connection.next_health_check = time.time() + with mock.patch.object( + r.connection, "send_command", wraps=r.connection.send_command + ) as m: + await r.get("foo") + m.assert_called_with("PING", check_health=False) + + self.assert_interval_advanced(r.connection) + + async def test_arbitrary_command_advances_next_health_check(self, r): + await r.get("foo") + next_health_check = r.connection.next_health_check + await r.get("foo") + assert next_health_check < r.connection.next_health_check + + async def test_health_check_not_invoked_within_interval(self, r): + await r.get("foo") + with mock.patch.object( + r.connection, "send_command", wraps=r.connection.send_command + ) as m: + await r.get("foo") + ping_call_spec = (("PING",), 
{"check_health": False}) + assert ping_call_spec not in m.call_args_list + + async def test_health_check_in_pipeline(self, r): + async with r.pipeline(transaction=False) as pipe: + pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection.next_health_check = 0 + with mock.patch.object( + pipe.connection, "send_command", wraps=pipe.connection.send_command + ) as m: + responses = await pipe.set("foo", "bar").get("foo").execute() + m.assert_any_call("PING", check_health=False) + assert responses == [True, b"bar"] + + async def test_health_check_in_transaction(self, r): + async with r.pipeline(transaction=True) as pipe: + pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection.next_health_check = 0 + with mock.patch.object( + pipe.connection, "send_command", wraps=pipe.connection.send_command + ) as m: + responses = await pipe.set("foo", "bar").get("foo").execute() + m.assert_any_call("PING", check_health=False) + assert responses == [True, b"bar"] + + async def test_health_check_in_watched_pipeline(self, r): + await r.set("foo", "bar") + async with r.pipeline(transaction=False) as pipe: + pipe.connection = await pipe.connection_pool.get_connection("_") + pipe.connection.next_health_check = 0 + with mock.patch.object( + pipe.connection, "send_command", wraps=pipe.connection.send_command + ) as m: + await pipe.watch("foo") + # the health check should be called when watching + m.assert_called_with("PING", check_health=False) + self.assert_interval_advanced(pipe.connection) + assert await pipe.get("foo") == b"bar" + + # reset the mock to clear the call list and schedule another + # health check + m.reset_mock() + pipe.connection.next_health_check = 0 + + pipe.multi() + responses = await pipe.set("foo", "not-bar").get("foo").execute() + assert responses == [True, b"not-bar"] + m.assert_any_call("PING", check_health=False) + + async def test_health_check_in_pubsub_before_subscribe(self, r): + """A health check happens before the first [p]subscribe""" + p = r.pubsub() + p.connection = await p.connection_pool.get_connection("_") + p.connection.next_health_check = 0 + with mock.patch.object( + p.connection, "send_command", wraps=p.connection.send_command + ) as m: + assert not p.subscribed + await p.subscribe("foo") + # the connection is not yet in pubsub mode, so the normal + # ping/pong within connection.send_command should check + # the health of the connection + m.assert_any_call("PING", check_health=False) + self.assert_interval_advanced(p.connection) + + subscribe_message = await wait_for_message(p) + assert subscribe_message["type"] == "subscribe" + + async def test_health_check_in_pubsub_after_subscribed(self, r): + """ + Pubsub can handle a new subscribe when it's time to check the + connection health + """ + p = r.pubsub() + p.connection = await p.connection_pool.get_connection("_") + p.connection.next_health_check = 0 + with mock.patch.object( + p.connection, "send_command", wraps=p.connection.send_command + ) as m: + await p.subscribe("foo") + subscribe_message = await wait_for_message(p) + assert subscribe_message["type"] == "subscribe" + self.assert_interval_advanced(p.connection) + # because we weren't subscribed when sending the subscribe + # message to 'foo', the connection's standard check_health ran + # prior to subscribing. 
+ m.assert_any_call("PING", check_health=False) + + p.connection.next_health_check = 0 + m.reset_mock() + + await p.subscribe("bar") + # the second subscribe issues exactly only command (the subscribe) + # and the health check is not invoked + m.assert_called_once_with("SUBSCRIBE", "bar", check_health=False) + + # since no message has been read since the health check was + # reset, it should still be 0 + assert p.connection.next_health_check == 0 + + subscribe_message = await wait_for_message(p) + assert subscribe_message["type"] == "subscribe" + assert await wait_for_message(p) is None + # now that the connection is subscribed, the pubsub health + # check should have taken over and include the HEALTH_CHECK_MESSAGE + m.assert_any_call("PING", p.HEALTH_CHECK_MESSAGE, check_health=False) + self.assert_interval_advanced(p.connection) + + async def test_health_check_in_pubsub_poll(self, r): + """ + Polling a pubsub connection that's subscribed will regularly + check the connection's health. + """ + p = r.pubsub() + p.connection = await p.connection_pool.get_connection("_") + with mock.patch.object( + p.connection, "send_command", wraps=p.connection.send_command + ) as m: + await p.subscribe("foo") + subscribe_message = await wait_for_message(p) + assert subscribe_message["type"] == "subscribe" + self.assert_interval_advanced(p.connection) + + # polling the connection before the health check interval + # doesn't result in another health check + m.reset_mock() + next_health_check = p.connection.next_health_check + assert await wait_for_message(p) is None + assert p.connection.next_health_check == next_health_check + m.assert_not_called() + + # reset the health check and poll again + # we should not receive a pong message, but the next_health_check + # should be advanced + p.connection.next_health_check = 0 + assert await wait_for_message(p) is None + m.assert_called_with("PING", p.HEALTH_CHECK_MESSAGE, check_health=False) + self.assert_interval_advanced(p.connection) diff --git a/tests/test_encoding.py b/tests/test_encoding.py new file mode 100644 index 000000000..c7caa45ea --- /dev/null +++ b/tests/test_encoding.py @@ -0,0 +1,130 @@ +import pytest + +import aioredis +from aioredis.connection import Connection + +from .conftest import _get_client + +pytestmark = pytest.mark.asyncio + + +class TestEncoding: + @pytest.fixture() + async def r(self, request, event_loop): + return await _get_client( + aioredis.Redis, + request=request, + event_loop=event_loop, + decode_responses=True, + ) + + @pytest.fixture() + async def r_no_decode(self, request, event_loop): + return await _get_client( + aioredis.Redis, + request=request, + event_loop=event_loop, + decode_responses=False, + ) + + async def test_simple_encoding(self, r_no_decode: aioredis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + await r_no_decode.set("unicode-string", unicode_string.encode("utf-8")) + cached_val = await r_no_decode.get("unicode-string") + assert isinstance(cached_val, bytes) + assert unicode_string == cached_val.decode("utf-8") + + async def test_simple_encoding_and_decoding(self, r: aioredis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + await r.set("unicode-string", unicode_string) + cached_val = await r.get("unicode-string") + assert isinstance(cached_val, str) + assert unicode_string == cached_val + + async def test_memoryview_encoding(self, r_no_decode: aioredis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + unicode_string_view = memoryview(unicode_string.encode("utf-8")) + await 
r_no_decode.set("unicode-string-memoryview", unicode_string_view) + cached_val = await r_no_decode.get("unicode-string-memoryview") + # The cached value won't be a memoryview because it's a copy from Redis + assert isinstance(cached_val, bytes) + assert unicode_string == cached_val.decode("utf-8") + + async def test_memoryview_encoding_and_decoding(self, r: aioredis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + unicode_string_view = memoryview(unicode_string.encode("utf-8")) + await r.set("unicode-string-memoryview", unicode_string_view) + cached_val = await r.get("unicode-string-memoryview") + assert isinstance(cached_val, str) + assert unicode_string == cached_val + + async def test_list_encoding(self, r: aioredis.Redis): + unicode_string = chr(3456) + "abcd" + chr(3421) + result = [unicode_string, unicode_string, unicode_string] + await r.rpush("a", *result) + assert await r.lrange("a", 0, -1) == result + + +class TestEncodingErrors: + async def test_ignore(self, request, event_loop): + r = await _get_client( + aioredis.Redis, + request=request, + event_loop=event_loop, + decode_responses=True, + encoding_errors="ignore", + ) + await r.set("a", b"foo\xff") + assert await r.get("a") == "foo" + + async def test_replace(self, request, event_loop): + r = await _get_client( + aioredis.Redis, + request=request, + event_loop=event_loop, + decode_responses=True, + encoding_errors="replace", + ) + await r.set("a", b"foo\xff") + assert await r.get("a") == "foo\ufffd" + + +class TestMemoryviewsAreNotPacked: + def test_memoryviews_are_not_packed(self): + c = Connection() + arg = memoryview(b"some_arg") + arg_list = ["SOME_COMMAND", arg] + cmd = c.pack_command(*arg_list) + assert cmd[1] is arg + cmds = c.pack_commands([arg_list, arg_list]) + assert cmds[1] is arg + assert cmds[3] is arg + + +class TestCommandsAreNotEncoded: + @pytest.fixture() + async def r(self, request, event_loop): + return await _get_client( + aioredis.Redis, request=request, event_loop=event_loop, encoding="utf-16" + ) + + async def test_basic_command(self, r: aioredis.Redis): + await r.set("hello", "world") + + +class TestInvalidUserInput: + async def test_boolean_fails(self, r: aioredis.Redis): + with pytest.raises(aioredis.DataError): + await r.set("a", True) + + async def test_none_fails(self, r: aioredis.Redis): + with pytest.raises(aioredis.DataError): + await r.set("a", None) + + async def test_user_type_fails(self, r: aioredis.Redis): + class Foo: + def __str__(self): + return "Foo" + + with pytest.raises(aioredis.DataError): + await r.set("a", Foo()) diff --git a/tests/test_lock.py b/tests/test_lock.py new file mode 100644 index 000000000..e2d45d744 --- /dev/null +++ b/tests/test_lock.py @@ -0,0 +1,237 @@ +import time + +import pytest + +from aioredis.client import Redis +from aioredis.exceptions import LockError, LockNotOwnedError +from aioredis.lock import Lock + +from .conftest import _get_client + +pytestmark = pytest.mark.asyncio + + +class TestLock: + @pytest.fixture() + async def r_decoded(self, request, event_loop): + return await _get_client( + Redis, request=request, event_loop=event_loop, decode_responses=True + ) + + def get_lock(self, redis, *args, **kwargs): + kwargs["lock_class"] = Lock + return redis.lock(*args, **kwargs) + + async def test_lock(self, r): + lock = self.get_lock(r, "foo") + assert await lock.acquire(blocking=False) + assert await r.get("foo") == lock.local.token + assert await r.ttl("foo") == -1 + await lock.release() + assert await r.get("foo") is None + + async def 
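
The two fixtures above are the whole story of response decoding: with decode_responses=False (the default) replies stay bytes, while decode_responses=True decodes them using the configured encoding, with encoding_errors selecting the codec error handler. A sketch, inside an async function:

import aioredis

raw = aioredis.Redis()  # replies are bytes
text = aioredis.Redis(decode_responses=True, encoding="utf-8",
                      encoding_errors="replace")

await text.set("greeting", "héllo")
assert await raw.get("greeting") == "héllo".encode("utf-8")
assert await text.get("greeting") == "héllo"
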
test_lock_token(self, r): + lock = self.get_lock(r, "foo") + await self._test_lock_token(r, lock) + + async def test_lock_token_thread_local_false(self, r): + lock = self.get_lock(r, "foo", thread_local=False) + await self._test_lock_token(r, lock) + + async def _test_lock_token(self, r, lock): + assert await lock.acquire(blocking=False, token="test") + assert await r.get("foo") == b"test" + assert lock.local.token == b"test" + assert await r.ttl("foo") == -1 + await lock.release() + assert await r.get("foo") is None + assert lock.local.token is None + + async def test_locked(self, r): + lock = self.get_lock(r, "foo") + assert await lock.locked() is False + await lock.acquire(blocking=False) + assert await lock.locked() is True + await lock.release() + assert await lock.locked() is False + + async def _test_owned(self, client): + lock = self.get_lock(client, "foo") + assert await lock.owned() is False + await lock.acquire(blocking=False) + assert await lock.owned() is True + await lock.release() + assert await lock.owned() is False + + lock2 = self.get_lock(client, "foo") + assert await lock.owned() is False + assert await lock2.owned() is False + await lock2.acquire(blocking=False) + assert await lock.owned() is False + assert await lock2.owned() is True + await lock2.release() + assert await lock.owned() is False + assert await lock2.owned() is False + + async def test_owned(self, r): + await self._test_owned(r) + + async def test_owned_with_decoded_responses(self, r_decoded): + await self._test_owned(r_decoded) + + async def test_competing_locks(self, r): + lock1 = self.get_lock(r, "foo") + lock2 = self.get_lock(r, "foo") + assert await lock1.acquire(blocking=False) + assert not await lock2.acquire(blocking=False) + await lock1.release() + assert await lock2.acquire(blocking=False) + assert not await lock1.acquire(blocking=False) + await lock2.release() + + async def test_timeout(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + assert 8 < (await r.ttl("foo")) <= 10 + await lock.release() + + async def test_float_timeout(self, r): + lock = self.get_lock(r, "foo", timeout=9.5) + assert await lock.acquire(blocking=False) + assert 8000 < (await r.pttl("foo")) <= 9500 + await lock.release() + + async def test_blocking_timeout(self, r): + lock1 = self.get_lock(r, "foo") + assert await lock1.acquire(blocking=False) + bt = 0.2 + sleep = 0.05 + lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt) + start = time.monotonic() + assert not await lock2.acquire() + # The elapsed duration should be less than the total blocking_timeout + assert bt > (time.monotonic() - start) > bt - sleep + await lock1.release() + + async def test_context_manager(self, r): + # blocking_timeout prevents a deadlock if the lock can't be acquired + # for some reason + async with self.get_lock(r, "foo", blocking_timeout=0.2) as lock: + assert await r.get("foo") == lock.local.token + assert await r.get("foo") is None + + async def test_context_manager_raises_when_locked_not_acquired(self, r): + await r.set("foo", "bar") + with pytest.raises(LockError): + async with self.get_lock(r, "foo", blocking_timeout=0.1): + pass + + async def test_high_sleep_small_blocking_timeout(self, r): + lock1 = self.get_lock(r, "foo") + assert await lock1.acquire(blocking=False) + sleep = 60 + bt = 1 + lock2 = self.get_lock(r, "foo", sleep=sleep, blocking_timeout=bt) + start = time.monotonic() + assert not await lock2.acquire() + # the elapsed time is less than the blocking_timeout as
the lock is + # unattainable given the sleep/blocking_timeout configuration + assert bt > (time.monotonic() - start) + await lock1.release() + + async def test_releasing_unlocked_lock_raises_error(self, r): + lock = self.get_lock(r, "foo") + with pytest.raises(LockError): + await lock.release() + + async def test_releasing_lock_no_longer_owned_raises_error(self, r): + lock = self.get_lock(r, "foo") + await lock.acquire(blocking=False) + # manually change the token + await r.set("foo", "a") + with pytest.raises(LockNotOwnedError): + await lock.release() + # even though we errored, the token is still cleared + assert lock.local.token is None + + async def test_extend_lock(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + assert 8000 < (await r.pttl("foo")) <= 10000 + assert await lock.extend(10) + assert 16000 < (await r.pttl("foo")) <= 20000 + await lock.release() + + async def test_extend_lock_replace_ttl(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + assert 8000 < (await r.pttl("foo")) <= 10000 + assert await lock.extend(10, replace_ttl=True) + assert 8000 < (await r.pttl("foo")) <= 10000 + await lock.release() + + async def test_extend_lock_float(self, r): + lock = self.get_lock(r, "foo", timeout=10.0) + assert await lock.acquire(blocking=False) + assert 8000 < (await r.pttl("foo")) <= 10000 + assert await lock.extend(10.0) + assert 16000 < (await r.pttl("foo")) <= 20000 + await lock.release() + + async def test_extending_unlocked_lock_raises_error(self, r): + lock = self.get_lock(r, "foo", timeout=10) + with pytest.raises(LockError): + await lock.extend(10) + + async def test_extending_lock_with_no_timeout_raises_error(self, r): + lock = self.get_lock(r, "foo") + assert await lock.acquire(blocking=False) + with pytest.raises(LockError): + await lock.extend(10) + await lock.release() + + async def test_extending_lock_no_longer_owned_raises_error(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + await r.set("foo", "a") + with pytest.raises(LockNotOwnedError): + await lock.extend(10) + + async def test_reacquire_lock(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + assert await r.pexpire("foo", 5000) + assert await r.pttl("foo") <= 5000 + assert await lock.reacquire() + assert 8000 < (await r.pttl("foo")) <= 10000 + await lock.release() + + async def test_reacquiring_unlocked_lock_raises_error(self, r): + lock = self.get_lock(r, "foo", timeout=10) + with pytest.raises(LockError): + await lock.reacquire() + + async def test_reacquiring_lock_with_no_timeout_raises_error(self, r): + lock = self.get_lock(r, "foo") + assert await lock.acquire(blocking=False) + with pytest.raises(LockError): + await lock.reacquire() + await lock.release() + + async def test_reacquiring_lock_no_longer_owned_raises_error(self, r): + lock = self.get_lock(r, "foo", timeout=10) + assert await lock.acquire(blocking=False) + await r.set("foo", "a") + with pytest.raises(LockNotOwnedError): + await lock.reacquire() + + +class TestLockClassSelection: + def test_lock_class_argument(self, r): + class MyLock: + def __init__(self, *args, **kwargs): + + pass + + lock = r.lock("foo", lock_class=MyLock) + assert type(lock) == MyLock diff --git a/tests/test_monitor.py b/tests/test_monitor.py new file mode 100644 index 000000000..c266939fc --- /dev/null +++ b/tests/test_monitor.py @@ -0,0 +1,54 @@ +import pytest + +from 
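
For reference, the Lock API exercised by the preceding file composes as below. Key name and timeouts are arbitrary, r is a connected client, and the context-manager form releases automatically (a sketch, inside an async function):

lock = r.lock("resource", timeout=10, blocking_timeout=5)
if await lock.acquire(blocking=False):
    try:
        # critical section; the key expires after `timeout` seconds
        await lock.extend(10)   # add 10s on top of the remaining TTL
        await lock.reacquire()  # or reset the TTL back to the full timeout
    finally:
        await lock.release()

# equivalent, with automatic release on exit:
async with r.lock("resource", timeout=10, blocking_timeout=5):
    pass
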
.conftest import wait_for_command + +pytestmark = pytest.mark.asyncio + + +class TestMonitor: + async def test_wait_command_not_found(self, r): + """Make sure the wait_for_command func works when command is not found""" + async with r.monitor() as m: + response = await wait_for_command(r, m, "nothing") + assert response is None + + async def test_response_values(self, r): + async with r.monitor() as m: + await r.ping() + response = await wait_for_command(r, m, "PING") + assert isinstance(response["time"], float) + assert response["db"] == 9 + assert response["client_type"] in ("tcp", "unix") + assert isinstance(response["client_address"], str) + assert isinstance(response["client_port"], str) + assert response["command"] == "PING" + + async def test_command_with_quoted_key(self, r): + async with r.monitor() as m: + await r.get('foo"bar') + response = await wait_for_command(r, m, 'GET foo"bar') + assert response["command"] == 'GET foo"bar' + + async def test_command_with_binary_data(self, r): + async with r.monitor() as m: + byte_string = b"foo\x92" + await r.get(byte_string) + response = await wait_for_command(r, m, "GET foo\\x92") + assert response["command"] == "GET foo\\x92" + + async def test_command_with_escaped_data(self, r): + async with r.monitor() as m: + byte_string = b"foo\\x92" + await r.get(byte_string) + response = await wait_for_command(r, m, "GET foo\\\\x92") + assert response["command"] == "GET foo\\\\x92" + + async def test_lua_script(self, r): + async with r.monitor() as m: + script = 'return redis.call("GET", "foo")' + assert await r.eval(script, 0) is None + response = await wait_for_command(r, m, "GET foo") + assert response["command"] == "GET foo" + assert response["client_type"] == "lua" + assert response["client_address"] == "lua" + assert response["client_port"] == "" diff --git a/tests/test_multiprocessing.py b/tests/test_multiprocessing.py new file mode 100644 index 000000000..509d06a11 --- /dev/null +++ b/tests/test_multiprocessing.py @@ -0,0 +1,185 @@ +import asyncio +import contextlib +import multiprocessing + +import pytest + +import aioredis +from aioredis.connection import Connection, ConnectionPool +from aioredis.exceptions import ConnectionError + +from .conftest import _get_client + +pytestmark = pytest.mark.asyncio + + +@contextlib.asynccontextmanager +async def exit_callback(callback, *args): + try: + yield + finally: + await callback(*args) + + +@pytest.mark.xfail() +class TestMultiprocessing: + # Test connection sharing between forks. + # See issue #1085 for details. + + # use a multi-connection client as that's the only type that is + # actually fork/process-safe + @pytest.fixture() + async def r(self, request, event_loop): + return await _get_client( + aioredis.Redis, + event_loop=event_loop, + request=request, + single_connection_client=False, + ) + + async def test_close_connection_in_child(self, master_host): + """ + A connection owned by a parent and closed by a child doesn't + destroy the file descriptors, so the parent can still use it.
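
The wait_for_command helper above lives in the tests' conftest (not shown here); it drains Monitor output until a matching command appears. Assuming this port keeps redis-py's Monitor.next_command() coroutine, direct usage would look roughly like:

async with r.monitor() as m:        # sends MONITOR on a dedicated connection
    await r.ping()
    event = await m.next_command()  # assumed API, mirrored from redis-py
    # event carries "time", "db", "client_type", "client_address",
    # "client_port" and "command", matching the assertions above
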
+ """ + conn = Connection(host=master_host) + await conn.send_command("ping") + assert await conn.read_response() == b"PONG" + + def target(conn): + async def atarget(conn): + await conn.send_command("ping") + assert conn.read_response() == b"PONG" + await conn.disconnect() + + asyncio.get_event_loop().run_until_complete(atarget(conn)) + + proc = multiprocessing.Process(target=target, args=(conn,)) + proc.start() + proc.join(3) + assert proc.exitcode == 0 + + # The connection was created in the parent but disconnected in the + # child. The child called socket.close() but did not call + # socket.shutdown() because it wasn't the "owning" process. + # Therefore the connection still works in the parent. + await conn.send_command("ping") + assert await conn.read_response() == b"PONG" + + async def test_close_connection_in_parent(self, master_host): + """ + A connection owned by a parent is unusable by a child if the parent + (the owning process) closes the connection. + """ + conn = Connection(host=master_host) + await conn.send_command("ping") + assert await conn.read_response() == b"PONG" + + def target(conn, ev): + ev.wait() + # the parent closed the connection. because it also created the + # connection, the connection is shutdown and the child + # cannot use it. + with pytest.raises(ConnectionError): + asyncio.get_event_loop().run_until_complete(conn.send_command("ping")) + + ev = multiprocessing.Event() + proc = multiprocessing.Process(target=target, args=(conn, ev)) + proc.start() + + await conn.disconnect() + ev.set() + + proc.join(3) + assert proc.exitcode == 0 + + @pytest.mark.parametrize("max_connections", [1, 2, None]) + async def test_pool(self, max_connections, master_host): + """ + A child will create its own connections when using a pool created + by a parent. + """ + pool = ConnectionPool.from_url( + f"redis://{master_host}", max_connections=max_connections + ) + + conn = await pool.get_connection("ping") + main_conn_pid = conn.pid + async with exit_callback(pool.release, conn): + await conn.send_command("ping") + assert await conn.read_response() == b"PONG" + + def target(pool): + async def atarget(pool): + async with exit_callback(pool.disconnect): + conn = await pool.get_connection("ping") + assert conn.pid != main_conn_pid + async with exit_callback(pool.release, conn): + assert await conn.send_command("ping") is None + assert await conn.read_response() == b"PONG" + + asyncio.get_event_loop().run_until_complete(atarget(pool)) + + proc = multiprocessing.Process(target=target, args=(pool,)) + proc.start() + proc.join(3) + assert proc.exitcode == 0 + + # Check that connection is still alive after fork process has exited + # and disconnected the connections in its pool + conn = pool.get_connection("ping") + async with exit_callback(pool.release, conn): + assert await conn.send_command("ping") is None + assert await conn.read_response() == b"PONG" + + @pytest.mark.parametrize("max_connections", [1, 2, None]) + async def test_close_pool_in_main(self, max_connections, master_host): + """ + A child process that uses the same pool as its parent isn't affected + when the parent disconnects all connections within the pool. 
+ """ + pool = ConnectionPool.from_url( + f"redis://{master_host}", max_connections=max_connections + ) + + conn = await pool.get_connection("ping") + assert await conn.send_command("ping") is None + assert await conn.read_response() == b"PONG" + + def target(pool, disconnect_event): + async def atarget(pool, disconnect_event): + conn = await pool.get_connection("ping") + async with exit_callback(pool.release, conn): + assert await conn.send_command("ping") is None + assert await conn.read_response() == b"PONG" + disconnect_event.wait() + assert await conn.send_command("ping") is None + assert await conn.read_response() == b"PONG" + + asyncio.get_event_loop().run_until_complete(atarget(pool, disconnect_event)) + + ev = multiprocessing.Event() + + proc = multiprocessing.Process(target=target, args=(pool, ev)) + proc.start() + + await pool.disconnect() + ev.set() + proc.join(3) + assert proc.exitcode == 0 + + async def test_aioredis_client(self, r): + """A aioredis client created in a parent can also be used in a child""" + assert await r.ping() is True + + def target(client): + run = asyncio.get_event_loop().run_until_complete + assert run(client.ping()) is True + del client + + proc = multiprocessing.Process(target=target, args=(r,)) + proc.start() + proc.join(3) + assert proc.exitcode == 0 + + assert await r.ping() is True diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 000000000..69de795c4 --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,362 @@ +import pytest + +import aioredis + +from .conftest import wait_for_command + +pytestmark = pytest.mark.asyncio + + +class TestPipeline: + async def test_pipeline_is_true(self, r): + """Ensure pipeline instances are not false-y""" + async with r.pipeline() as pipe: + assert pipe + + async def test_pipeline(self, r): + async with r.pipeline() as pipe: + ( + pipe.set("a", "a1") + .get("a") + .zadd("z", {"z1": 1}) + .zadd("z", {"z2": 4}) + .zincrby("z", 1, "z1") + .zrange("z", 0, 5, withscores=True) + ) + assert await pipe.execute() == [ + True, + b"a1", + True, + True, + 2.0, + [(b"z1", 2.0), (b"z2", 4)], + ] + + async def test_pipeline_memoryview(self, r): + async with r.pipeline() as pipe: + (pipe.set("a", memoryview(b"a1")).get("a")) + assert await pipe.execute() == [ + True, + b"a1", + ] + + async def test_pipeline_length(self, r): + async with r.pipeline() as pipe: + # Initially empty. + assert len(pipe) == 0 + + # Fill 'er up! + pipe.set("a", "a1").set("b", "b1").set("c", "c1") + assert len(pipe) == 3 + + # Execute calls reset(), so empty once again. 
+ await pipe.execute() + assert len(pipe) == 0 + + async def test_pipeline_no_transaction(self, r): + async with r.pipeline(transaction=False) as pipe: + pipe.set("a", "a1").set("b", "b1").set("c", "c1") + assert await pipe.execute() == [True, True, True] + assert await r.get("a") == b"a1" + assert await r.get("b") == b"b1" + assert await r.get("c") == b"c1" + + async def test_pipeline_no_transaction_watch(self, r): + await r.set("a", 0) + + async with r.pipeline(transaction=False) as pipe: + await pipe.watch("a") + a = await pipe.get("a") + + pipe.multi() + pipe.set("a", int(a) + 1) + assert await pipe.execute() == [True] + + async def test_pipeline_no_transaction_watch_failure(self, r): + await r.set("a", 0) + + async with r.pipeline(transaction=False) as pipe: + await pipe.watch("a") + a = await pipe.get("a") + + await r.set("a", "bad") + + pipe.multi() + pipe.set("a", int(a) + 1) + + with pytest.raises(aioredis.WatchError): + await pipe.execute() + + assert await r.get("a") == b"bad" + + async def test_exec_error_in_response(self, r): + """ + an invalid pipeline command at exec time adds the exception instance + to the list of returned values + """ + await r.set("c", "a") + async with r.pipeline() as pipe: + pipe.set("a", 1).set("b", 2).lpush("c", 3).set("d", 4) + result = await pipe.execute(raise_on_error=False) + + assert result[0] + assert await r.get("a") == b"1" + assert result[1] + assert await r.get("b") == b"2" + + # we can't lpush to a key that's a string value, so this should + # be a ResponseError exception + assert isinstance(result[2], aioredis.ResponseError) + assert await r.get("c") == b"a" + + # since this isn't a transaction, the other commands after the + # error are still executed + assert result[3] + assert await r.get("d") == b"4" + + # make sure the pipe was restored to a working state + assert await pipe.set("z", "zzz").execute() == [True] + assert await r.get("z") == b"zzz" + + async def test_exec_error_raised(self, r): + await r.set("c", "a") + async with r.pipeline() as pipe: + pipe.set("a", 1).set("b", 2).lpush("c", 3).set("d", 4) + with pytest.raises(aioredis.ResponseError) as ex: + await pipe.execute() + assert str(ex.value).startswith( + "Command # 3 (LPUSH c 3) of " "pipeline caused error: " + ) + + # make sure the pipe was restored to a working state + assert await pipe.set("z", "zzz").execute() == [True] + assert await r.get("z") == b"zzz" + + async def test_transaction_with_empty_error_command(self, r): + """ + Commands with custom EMPTY_ERROR functionality return their default + values in the pipeline no matter the raise_on_error preference + """ + for error_switch in (True, False): + async with r.pipeline() as pipe: + pipe.set("a", 1).mget([]).set("c", 3) + result = await pipe.execute(raise_on_error=error_switch) + + assert result[0] + assert result[1] == [] + assert result[2] + + async def test_pipeline_with_empty_error_command(self, r): + """ + Commands with custom EMPTY_ERROR functionality return their default + values in the pipeline no matter the raise_on_error preference + """ + for error_switch in (True, False): + async with r.pipeline(transaction=False) as pipe: + pipe.set("a", 1).mget([]).set("c", 3) + result = await pipe.execute(raise_on_error=error_switch) + + assert result[0] + assert result[1] == [] + assert result[2] + + async def test_parse_error_raised(self, r): + async with r.pipeline() as pipe: + # the zrem is invalid because we don't pass any keys to it + pipe.set("a", 1).zrem("b").set("b", 2) + with 
pytest.raises(aioredis.ResponseError) as ex: + await pipe.execute() + + assert str(ex.value).startswith( + "Command # 2 (ZREM b) of " "pipeline caused error: " + ) + + # make sure the pipe was restored to a working state + assert await pipe.set("z", "zzz").execute() == [True] + assert await r.get("z") == b"zzz" + + async def test_parse_error_raised_transaction(self, r): + async with r.pipeline() as pipe: + pipe.multi() + # the zrem is invalid because we don't pass any keys to it + pipe.set("a", 1).zrem("b").set("b", 2) + with pytest.raises(aioredis.ResponseError) as ex: + await pipe.execute() + + assert str(ex.value).startswith( + "Command # 2 (ZREM b) of " "pipeline caused error: " + ) + + # make sure the pipe was restored to a working state + assert await pipe.set("z", "zzz").execute() == [True] + assert await r.get("z") == b"zzz" + + async def test_watch_succeed(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + assert pipe.watching + a_value = await pipe.get("a") + b_value = await pipe.get("b") + assert a_value == b"1" + assert b_value == b"2" + pipe.multi() + + pipe.set("c", 3) + assert await pipe.execute() == [True] + assert not pipe.watching + + async def test_watch_failure(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + await r.set("b", 3) + pipe.multi() + pipe.get("a") + with pytest.raises(aioredis.WatchError): + await pipe.execute() + + assert not pipe.watching + + async def test_watch_failure_in_empty_transaction(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + await r.set("b", 3) + pipe.multi() + with pytest.raises(aioredis.WatchError): + await pipe.execute() + + assert not pipe.watching + + async def test_unwatch(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + await r.set("b", 3) + await pipe.unwatch() + assert not pipe.watching + pipe.get("a") + assert await pipe.execute() == [b"1"] + + async def test_watch_exec_no_unwatch(self, r): + await r.set("a", 1) + await r.set("b", 2) + + async with r.monitor() as m: + async with r.pipeline() as pipe: + await pipe.watch("a", "b") + assert pipe.watching + a_value = await pipe.get("a") + b_value = await pipe.get("b") + assert a_value == b"1" + assert b_value == b"2" + pipe.multi() + pipe.set("c", 3) + assert await pipe.execute() == [True] + assert not pipe.watching + + unwatch_command = await wait_for_command(r, m, "UNWATCH") + assert unwatch_command is None, "should not send UNWATCH" + + async def test_watch_reset_unwatch(self, r): + await r.set("a", 1) + + async with r.monitor() as m: + async with r.pipeline() as pipe: + await pipe.watch("a") + assert pipe.watching + await pipe.reset() + assert not pipe.watching + + unwatch_command = await wait_for_command(r, m, "UNWATCH") + assert unwatch_command is not None + assert unwatch_command["command"] == "UNWATCH" + + async def test_transaction_callable(self, r): + await r.set("a", 1) + await r.set("b", 2) + has_run = [] + + async def my_transaction(pipe): + a_value = await pipe.get("a") + assert a_value in (b"1", b"2") + b_value = await pipe.get("b") + assert b_value == b"2" + + # silly run-once code... incr's "a" so WatchError should be raised + # forcing this all to run again. 
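
The watch tests above encode the standard optimistic-locking recipe: WATCH the keys, read them in immediate mode, switch to buffering with multi(), and retry from the top on WatchError if another client modified a watched key before EXEC. Spelled out as a loop (a sketch, inside an async function):

async with r.pipeline() as pipe:
    while True:
        try:
            await pipe.watch("counter")            # immediate-mode WATCH
            current = int(await pipe.get("counter") or 0)
            pipe.multi()                           # start buffering
            pipe.set("counter", current + 1)
            await pipe.execute()                   # raises if "counter" changed
            break
        except aioredis.WatchError:
            continue                               # lost the race; retry
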
this should incr "a" once to "2" + if not has_run: + await r.incr("a") + has_run.append("it has") + + pipe.multi() + pipe.set("c", int(a_value) + int(b_value)) + + result = await r.transaction(my_transaction, "a", "b") + assert result == [True] + assert await r.get("c") == b"4" + + async def test_transaction_callable_returns_value_from_callable(self, r): + async def callback(pipe): + # No need to do anything here since we only want the return value + return "a" + + res = await r.transaction(callback, "my-key", value_from_callable=True) + assert res == "a" + + async def test_exec_error_in_no_transaction_pipeline(self, r): + await r.set("a", 1) + async with r.pipeline(transaction=False) as pipe: + pipe.llen("a") + pipe.expire("a", 100) + + with pytest.raises(aioredis.ResponseError) as ex: + await pipe.execute() + + assert str(ex.value).startswith( + "Command # 1 (LLEN a) of " "pipeline caused error: " + ) + + assert await r.get("a") == b"1" + + async def test_exec_error_in_no_transaction_pipeline_unicode_command(self, r): + key = chr(3456) + "abcd" + chr(3421) + await r.set(key, 1) + async with r.pipeline(transaction=False) as pipe: + pipe.llen(key) + pipe.expire(key, 100) + + with pytest.raises(aioredis.ResponseError) as ex: + await pipe.execute() + + expected = "Command # 1 (LLEN %s) of pipeline caused error: " % key + assert str(ex.value).startswith(expected) + + assert await r.get(key) == b"1" + + async def test_pipeline_with_bitfield(self, r): + async with r.pipeline() as pipe: + pipe.set("a", "1") + bf = pipe.bitfield("b") + pipe2 = ( + bf.set("u8", 8, 255) + .get("u8", 0) + .get("u4", 8) # 1111 + .get("u4", 12) # 1111 + .get("u4", 13) # 1110 + .execute() + ) + pipe.get("a") + response = await pipe.execute() + + assert pipe == pipe2 + assert response == [True, [0, 0, 15, 15, 14], b"1"] diff --git a/tests/test_pubsub.py b/tests/test_pubsub.py new file mode 100644 index 000000000..fcb023f97 --- /dev/null +++ b/tests/test_pubsub.py @@ -0,0 +1,575 @@ +import asyncio +import threading +from unittest import mock + +import async_timeout +import pytest + +import aioredis +from aioredis.exceptions import ConnectionError + +from .conftest import _get_client, skip_if_server_version_lt + +pytestmark = pytest.mark.asyncio(forbid_global_loop=True) + + +async def wait_for_message(pubsub, timeout=0.1, ignore_subscribe_messages=False): + try: + async with async_timeout.timeout(timeout): + while True: + message = await pubsub.get_message( + ignore_subscribe_messages=ignore_subscribe_messages + ) + if message is not None: + return message + await asyncio.sleep(0.01) + except asyncio.TimeoutError: + return None + + +def make_message(type, channel, data, pattern=None): + return { + "type": type, + "pattern": pattern and pattern.encode("utf-8") or None, + "channel": channel and channel.encode("utf-8") or None, + "data": data.encode("utf-8") if isinstance(data, str) else data, + } + + +def make_subscribe_test_data(pubsub, type): + if type == "channel": + return { + "p": pubsub, + "sub_type": "subscribe", + "unsub_type": "unsubscribe", + "sub_func": pubsub.subscribe, + "unsub_func": pubsub.unsubscribe, + "keys": ["foo", "bar", "uni" + chr(4456) + "code"], + } + elif type == "pattern": + return { + "p": pubsub, + "sub_type": "psubscribe", + "unsub_type": "punsubscribe", + "sub_func": pubsub.psubscribe, + "unsub_func": pubsub.punsubscribe, + "keys": ["f*", "b*", "uni" + chr(4456) + "*"], + } + assert False, "invalid subscribe type: %s" % type + + +class TestPubSubSubscribeUnsubscribe: + async def 
_test_subscribe_unsubscribe( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + for key in keys: + assert await sub_func(key) is None + + # should be a message for each channel/pattern we just subscribed to + for i, key in enumerate(keys): + assert await wait_for_message(p) == make_message(sub_type, key, i + 1) + + for key in keys: + assert await unsub_func(key) is None + + # should be a message for each channel/pattern we just unsubscribed + # from + for i, key in enumerate(keys): + i = len(keys) - 1 - i + assert await wait_for_message(p) == make_message(unsub_type, key, i) + + async def test_channel_subscribe_unsubscribe(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_subscribe_unsubscribe(**kwargs) + + async def test_pattern_subscribe_unsubscribe(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_subscribe_unsubscribe(**kwargs) + + async def _test_resubscribe_on_reconnection( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + + for key in keys: + assert await sub_func(key) is None + + # should be a message for each channel/pattern we just subscribed to + for i, key in enumerate(keys): + assert await wait_for_message(p) == make_message(sub_type, key, i + 1) + + # manually disconnect + await p.connection.disconnect() + + # calling get_message again reconnects and resubscribes + # note, we may not re-subscribe to channels in exactly the same order + # so we have to do some extra checks to make sure we got them all + messages = [] + for i in range(len(keys)): + messages.append(await wait_for_message(p)) + + unique_channels = set() + assert len(messages) == len(keys) + for i, message in enumerate(messages): + assert message["type"] == sub_type + assert message["data"] == i + 1 + assert isinstance(message["channel"], bytes) + channel = message["channel"].decode("utf-8") + unique_channels.add(channel) + + assert len(unique_channels) == len(keys) + for channel in unique_channels: + assert channel in keys + + async def test_resubscribe_to_channels_on_reconnection(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_resubscribe_on_reconnection(**kwargs) + + async def test_resubscribe_to_patterns_on_reconnection(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_resubscribe_on_reconnection(**kwargs) + + async def _test_subscribed_property( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + + assert p.subscribed is False + await sub_func(keys[0]) + # we're now subscribed even though we haven't processed the + # reply from the server just yet + assert p.subscribed is True + assert await wait_for_message(p) == make_message(sub_type, keys[0], 1) + # we're still subscribed + assert p.subscribed is True + + # unsubscribe from all channels + await unsub_func() + # we're still technically subscribed until we process the + # response messages from the server + assert p.subscribed is True + assert await wait_for_message(p) == make_message(unsub_type, keys[0], 0) + # now we're no longer subscribed as no more messages can be delivered + # to any channels we were listening to + assert p.subscribed is False + + # subscribing again flips the flag back + await sub_func(keys[0]) + assert p.subscribed is True + assert await wait_for_message(p) == make_message(sub_type, keys[0], 1) + + # unsubscribe again + await unsub_func() + assert p.subscribed is True + # subscribe to another channel before reading the unsubscribe 
response + await sub_func(keys[1]) + assert p.subscribed is True + # read the unsubscribe for key1 + assert await wait_for_message(p) == make_message(unsub_type, keys[0], 0) + # we're still subscribed to key2, so subscribed should still be True + assert p.subscribed is True + # read the key2 subscribe message + assert await wait_for_message(p) == make_message(sub_type, keys[1], 1) + await unsub_func() + # haven't read the message yet, so we're still subscribed + assert p.subscribed is True + assert await wait_for_message(p) == make_message(unsub_type, keys[1], 0) + # now we're finally unsubscribed + assert p.subscribed is False + + async def test_subscribe_property_with_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_subscribed_property(**kwargs) + + async def test_subscribe_property_with_patterns(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_subscribed_property(**kwargs) + + async def test_ignore_all_subscribe_messages(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + + checks = ( + (p.subscribe, "foo"), + (p.unsubscribe, "foo"), + (p.psubscribe, "f*"), + (p.punsubscribe, "f*"), + ) + + assert p.subscribed is False + for func, channel in checks: + assert await func(channel) is None + assert p.subscribed is True + assert await wait_for_message(p) is None + assert p.subscribed is False + + async def test_ignore_individual_subscribe_messages(self, r): + p = r.pubsub() + + checks = ( + (p.subscribe, "foo"), + (p.unsubscribe, "foo"), + (p.psubscribe, "f*"), + (p.punsubscribe, "f*"), + ) + + assert p.subscribed is False + for func, channel in checks: + assert await func(channel) is None + assert p.subscribed is True + message = await wait_for_message(p, ignore_subscribe_messages=True) + assert message is None + assert p.subscribed is False + + async def test_sub_unsub_resub_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_sub_unsub_resub(**kwargs) + + async def test_sub_unsub_resub_patterns(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_sub_unsub_resub(**kwargs) + + async def _test_sub_unsub_resub( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + # https://github.com/andymccurdy/redis-py/issues/764 + key = keys[0] + await sub_func(key) + await unsub_func(key) + await sub_func(key) + assert p.subscribed is True + assert await wait_for_message(p) == make_message(sub_type, key, 1) + assert await wait_for_message(p) == make_message(unsub_type, key, 0) + assert await wait_for_message(p) == make_message(sub_type, key, 1) + assert p.subscribed is True + + async def test_sub_unsub_all_resub_channels(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "channel") + await self._test_sub_unsub_all_resub(**kwargs) + + async def test_sub_unsub_all_resub_patterns(self, r): + kwargs = make_subscribe_test_data(r.pubsub(), "pattern") + await self._test_sub_unsub_all_resub(**kwargs) + + async def _test_sub_unsub_all_resub( + self, p, sub_type, unsub_type, sub_func, unsub_func, keys + ): + # https://github.com/andymccurdy/redis-py/issues/764 + key = keys[0] + await sub_func(key) + await unsub_func() + await sub_func(key) + assert p.subscribed is True + assert await wait_for_message(p) == make_message(sub_type, key, 1) + assert await wait_for_message(p) == make_message(unsub_type, key, 0) + assert await wait_for_message(p) == make_message(sub_type, key, 1) + assert p.subscribed is True + + +class 
TestPubSubMessages: + def setup_method(self, method): + self.message = None + + def message_handler(self, message): + self.message = message + + async def test_published_message_to_channel(self, r): + p = r.pubsub() + await p.subscribe("foo") + assert await wait_for_message(p) == make_message("subscribe", "foo", 1) + assert await r.publish("foo", "test message") == 1 + + message = await wait_for_message(p) + assert isinstance(message, dict) + assert message == make_message("message", "foo", "test message") + + async def test_published_message_to_pattern(self, r): + p = r.pubsub() + await p.subscribe("foo") + await p.psubscribe("f*") + assert await wait_for_message(p) == make_message("subscribe", "foo", 1) + assert await wait_for_message(p) == make_message("psubscribe", "f*", 2) + # 1 to pattern, 1 to channel + assert await r.publish("foo", "test message") == 2 + + message1 = await wait_for_message(p) + message2 = await wait_for_message(p) + assert isinstance(message1, dict) + assert isinstance(message2, dict) + + expected = [ + make_message("message", "foo", "test message"), + make_message("pmessage", "foo", "test message", pattern="f*"), + ] + + assert message1 in expected + assert message2 in expected + assert message1 != message2 + + async def test_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(foo=self.message_handler) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message("message", "foo", "test message") + + async def test_pattern_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + await p.psubscribe(**{"f*": self.message_handler}) + assert await wait_for_message(p) is None + assert await r.publish("foo", "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message( + "pmessage", "foo", "test message", pattern="f*" + ) + + async def test_unicode_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + channel = "uni" + chr(4456) + "code" + channels = {channel: self.message_handler} + await p.subscribe(**channels) + assert await wait_for_message(p) is None + assert await r.publish(channel, "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message("message", channel, "test message") + + async def test_unicode_pattern_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + pattern = "uni" + chr(4456) + "*" + channel = "uni" + chr(4456) + "code" + await p.psubscribe(**{pattern: self.message_handler}) + assert await wait_for_message(p) is None + assert await r.publish(channel, "test message") == 1 + assert await wait_for_message(p) is None + assert self.message == make_message( + "pmessage", channel, "test message", pattern=pattern + ) + + async def test_get_message_without_subscribe(self, r): + p = r.pubsub() + with pytest.raises(RuntimeError) as info: + await p.get_message() + expect = ( + "connection not set: " "did you forget to call subscribe() or psubscribe()?" 
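
Every message dictionary asserted in this class has the shape produced by make_message above. Consuming without handler callbacks is just a polling loop over get_message(); a sketch, inside an async function:

p = r.pubsub(ignore_subscribe_messages=True)
await p.subscribe("news")
await r.publish("news", "hello")

# the subscribe confirmation is swallowed, so poll until the publish arrives
msg = await p.get_message(timeout=1.0)  # None if nothing readable in time
while msg is None:
    msg = await p.get_message(timeout=1.0)
# msg == {"type": "message", "channel": b"news", "pattern": None, "data": b"hello"}
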
+ ) + assert expect in info.exconly() + + +class TestPubSubAutoDecoding: + """These tests only validate that we get unicode values back""" + + channel = "uni" + chr(4456) + "code" + pattern = "uni" + chr(4456) + "*" + data = "abc" + chr(4458) + "123" + + def make_message(self, type, channel, data, pattern=None): + return {"type": type, "channel": channel, "pattern": pattern, "data": data} + + def setup_method(self, method): + self.message = None + + def message_handler(self, message): + self.message = message + + @pytest.fixture() + async def r(self, request, event_loop): + return await _get_client( + aioredis.Redis, + request=request, + event_loop=event_loop, + decode_responses=True, + ) + + async def test_channel_subscribe_unsubscribe(self, r): + p = r.pubsub() + await p.subscribe(self.channel) + assert await wait_for_message(p) == self.make_message( + "subscribe", self.channel, 1 + ) + + await p.unsubscribe(self.channel) + assert await wait_for_message(p) == self.make_message( + "unsubscribe", self.channel, 0 + ) + + async def test_pattern_subscribe_unsubscribe(self, r): + p = r.pubsub() + await p.psubscribe(self.pattern) + assert await wait_for_message(p) == self.make_message( + "psubscribe", self.pattern, 1 + ) + + await p.punsubscribe(self.pattern) + assert await wait_for_message(p) == self.make_message( + "punsubscribe", self.pattern, 0 + ) + + async def test_channel_publish(self, r): + p = r.pubsub() + await p.subscribe(self.channel) + assert await wait_for_message(p) == self.make_message( + "subscribe", self.channel, 1 + ) + await r.publish(self.channel, self.data) + assert await wait_for_message(p) == self.make_message( + "message", self.channel, self.data + ) + + async def test_pattern_publish(self, r): + p = r.pubsub() + await p.psubscribe(self.pattern) + assert await wait_for_message(p) == self.make_message( + "psubscribe", self.pattern, 1 + ) + await r.publish(self.channel, self.data) + assert await wait_for_message(p) == self.make_message( + "pmessage", self.channel, self.data, pattern=self.pattern + ) + + async def test_channel_message_handler(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe(**{self.channel: self.message_handler}) + assert await wait_for_message(p) is None + await r.publish(self.channel, self.data) + assert await wait_for_message(p) is None + assert self.message == self.make_message("message", self.channel, self.data) + + # test that we reconnected to the correct channel + self.message = None + await p.connection.disconnect() + assert await wait_for_message(p) is None # should reconnect + new_data = self.data + "new data" + await r.publish(self.channel, new_data) + assert await wait_for_message(p) is None + assert self.message == self.make_message("message", self.channel, new_data) + + async def test_pattern_message_handler(self, r: aioredis.Redis): + p = r.pubsub(ignore_subscribe_messages=True) + await p.psubscribe(**{self.pattern: self.message_handler}) + assert await wait_for_message(p) is None + await r.publish(self.channel, self.data) + assert await wait_for_message(p) is None + assert self.message == self.make_message( + "pmessage", self.channel, self.data, pattern=self.pattern + ) + + # test that we reconnected to the correct pattern + self.message = None + await p.connection.disconnect() + assert await wait_for_message(p) is None # should reconnect + new_data = self.data + "new data" + await r.publish(self.channel, new_data) + assert await wait_for_message(p) is None + assert self.message == self.make_message( + 
"pmessage", self.channel, new_data, pattern=self.pattern + ) + + async def test_context_manager(self, r: aioredis.Redis): + async with r.pubsub() as pubsub: + await pubsub.subscribe("foo") + assert pubsub.connection is not None + + assert pubsub.connection is None + assert pubsub.channels == {} + assert pubsub.patterns == {} + + +class TestPubSubRedisDown: + async def test_channel_subscribe(self, r): + r = aioredis.Redis(host="localhost", port=6390) + p = r.pubsub() + with pytest.raises(ConnectionError): + await p.subscribe("foo") + + +class TestPubSubSubcommands: + @skip_if_server_version_lt("2.8.0") + async def test_pubsub_channels(self, r): + p = r.pubsub() + await p.subscribe("foo", "bar", "baz", "quux") + for i in range(4): + assert (await wait_for_message(p))["type"] == "subscribe" + expected = [b"bar", b"baz", b"foo", b"quux"] + assert all([channel in await r.pubsub_channels() for channel in expected]) + + @skip_if_server_version_lt("2.8.0") + async def test_pubsub_numsub(self, r): + p1 = r.pubsub() + await p1.subscribe("foo", "bar", "baz") + for i in range(3): + assert (await wait_for_message(p1))["type"] == "subscribe" + p2 = r.pubsub() + await p2.subscribe("bar", "baz") + for i in range(2): + assert (await wait_for_message(p2))["type"] == "subscribe" + p3 = r.pubsub() + await p3.subscribe("baz") + assert (await wait_for_message(p3))["type"] == "subscribe" + + channels = [(b"foo", 1), (b"bar", 2), (b"baz", 3)] + assert channels == await r.pubsub_numsub("foo", "bar", "baz") + + @skip_if_server_version_lt("2.8.0") + async def test_pubsub_numpat(self, r): + p = r.pubsub() + await p.psubscribe("*oo", "*ar", "b*z") + for i in range(3): + assert (await wait_for_message(p))["type"] == "psubscribe" + assert await r.pubsub_numpat() == 3 + + +class TestPubSubPings: + @skip_if_server_version_lt("3.0.0") + async def test_send_pubsub_ping(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe("foo") + await p.ping() + assert await wait_for_message(p) == make_message( + type="pong", channel=None, data="", pattern=None + ) + + @skip_if_server_version_lt("3.0.0") + async def test_send_pubsub_ping_message(self, r): + p = r.pubsub(ignore_subscribe_messages=True) + await p.subscribe("foo") + await p.ping(message="hello world") + assert await wait_for_message(p) == make_message( + type="pong", channel=None, data="hello world", pattern=None + ) + + +class TestPubSubConnectionKilled: + @skip_if_server_version_lt("3.0.0") + async def test_connection_error_raised_when_connection_dies(self, r): + p = r.pubsub() + await p.subscribe("foo") + assert await wait_for_message(p) == make_message("subscribe", "foo", 1) + for client in await r.client_list(): + if client["cmd"] == "subscribe": + await r.client_kill_filter(_id=client["id"]) + with pytest.raises(ConnectionError): + await wait_for_message(p) + + +class TestPubSubTimeouts: + async def test_get_message_with_timeout_returns_none(self, r): + p = r.pubsub() + await p.subscribe("foo") + assert await wait_for_message(p) == make_message("subscribe", "foo", 1) + assert await p.get_message(timeout=0.01) is None + + +class TestPubSubWorkerThread: + async def test_pubsub_worker_thread_exception_handler(self, r): + event = threading.Event() + + def exception_handler(ex, pubsub, thread): + thread.stop() + event.set() + + p = r.pubsub() + await p.subscribe(**{"foo": lambda m: m}) + with mock.patch.object(p, "get_message", side_effect=Exception("error")): + pubsub_thread = p.run_in_thread(exception_handler=exception_handler) + assert 
diff --git a/tests/test_scripting.py b/tests/test_scripting.py
new file mode 100644
index 000000000..845aa22a0
--- /dev/null
+++ b/tests/test_scripting.py
@@ -0,0 +1,124 @@
+import pytest
+
+from aioredis import exceptions
+
+multiply_script = """
+local value = redis.call('GET', KEYS[1])
+value = tonumber(value)
+return value * ARGV[1]"""
+
+msgpack_hello_script = """
+local message = cmsgpack.unpack(ARGV[1])
+local name = message['name']
+return "hello " .. name
+"""
+msgpack_hello_script_broken = """
+local message = cmsgpack.unpack(ARGV[1])
+local names = message['name']
+return "hello " .. name
+"""
+
+
+class TestScripting:
+    @pytest.mark.asyncio(forbid_global_loop=True)
+    async def test_eval(self, r):
+        await r.flushdb()
+        await r.set("a", 2)
+        # 2 * 3 == 6
+        assert await r.eval(multiply_script, 1, "a", 3) == 6
+
+    @pytest.mark.asyncio(forbid_global_loop=True)
+    async def test_evalsha(self, r):
+        await r.set("a", 2)
+        sha = await r.script_load(multiply_script)
+        # 2 * 3 == 6
+        assert await r.evalsha(sha, 1, "a", 3) == 6
+
+    @pytest.mark.asyncio(forbid_global_loop=True)
+    async def test_evalsha_script_not_loaded(self, r):
+        await r.set("a", 2)
+        sha = await r.script_load(multiply_script)
+        # remove the script from Redis's cache
+        await r.script_flush()
+        with pytest.raises(exceptions.NoScriptError):
+            await r.evalsha(sha, 1, "a", 3)
+
+    @pytest.mark.asyncio(forbid_global_loop=True)
+    async def test_script_loading(self, r):
+        # get the sha, then clear the cache
+        sha = await r.script_load(multiply_script)
+        await r.script_flush()
+        assert await r.script_exists(sha) == [False]
+        await r.script_load(multiply_script)
+        assert await r.script_exists(sha) == [True]
+
+    @pytest.mark.asyncio(forbid_global_loop=True)
+    async def test_script_object(self, r):
+        await r.script_flush()
+        await r.set("a", 2)
+        multiply = r.register_script(multiply_script)
+        precalculated_sha = multiply.sha
+        assert precalculated_sha
+        assert await r.script_exists(multiply.sha) == [False]
+        # Test second evalsha block (after NoScriptError)
+        assert await multiply(keys=["a"], args=[3]) == 6
+        # At this point, the script should be loaded
+        assert await r.script_exists(multiply.sha) == [True]
+        # Test that the precalculated sha matches the one from redis
+        assert multiply.sha == precalculated_sha
+        # Test first evalsha block
+        assert await multiply(keys=["a"], args=[3]) == 6
+
+    @pytest.mark.asyncio(forbid_global_loop=True)
+    async def test_script_object_in_pipeline(self, r):
+        await r.script_flush()
+        multiply = r.register_script(multiply_script)
+        precalculated_sha = multiply.sha
+        assert precalculated_sha
+        pipe = r.pipeline()
+        pipe.set("a", 2)
+        pipe.get("a")
+        await multiply(keys=["a"], args=[3], client=pipe)
+        assert await r.script_exists(multiply.sha) == [False]
+        # [SET worked, GET 'a', result of multiply script]
+        assert await pipe.execute() == [True, b"2", 6]
+        # The script should have been loaded by pipe.execute()
+        assert await r.script_exists(multiply.sha) == [True]
+        # The precalculated sha should have been the correct one
+        assert multiply.sha == precalculated_sha
+
+        # purge the script from redis's cache and re-run the pipeline
+        # the multiply script should be reloaded by pipe.execute()
+        await r.script_flush()
+        pipe = r.pipeline()
+        pipe.set("a", 2)
+        pipe.get("a")
+        await multiply(keys=["a"], args=[3], client=pipe)
+        assert await r.script_exists(multiply.sha) == [False]
+        # [SET worked, GET 'a', result of multiply script]
+        assert await pipe.execute() == [True, b"2", 6]
+        assert await r.script_exists(multiply.sha) == [True]
+
+    @pytest.mark.asyncio(forbid_global_loop=True)
+    async def test_eval_msgpack_pipeline_error_in_lua(self, r):
+        msgpack_hello = r.register_script(msgpack_hello_script)
+        assert msgpack_hello.sha
+
+        pipe = r.pipeline()
+
+        # avoiding a dependency on msgpack, this is the output of
+        # msgpack.dumps({"name": "Joe"})
+        msgpack_message_1 = b"\x81\xa4name\xa3Joe"
+
+        await msgpack_hello(args=[msgpack_message_1], client=pipe)
+
+        assert await r.script_exists(msgpack_hello.sha) == [False]
+        assert (await pipe.execute())[0] == b"hello Joe"
+        assert await r.script_exists(msgpack_hello.sha) == [True]
+
+        msgpack_hello_broken = r.register_script(msgpack_hello_script_broken)
+
+        await msgpack_hello_broken(args=[msgpack_message_1], client=pipe)
+        with pytest.raises(exceptions.ResponseError) as excinfo:
+            await pipe.execute()
+        assert excinfo.type == exceptions.ResponseError
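The `Script` object these tests register wraps the EVALSHA-with-fallback pattern: the SHA1 is computed client-side at registration (which is why `precalculated_sha` is truthy before any server call), `__call__` tries `EVALSHA` first, and on `NoScriptError` it loads the script and retries. A minimal sketch, assuming a local server (the script is the same `multiply_script` used above):

```python
import asyncio

import aioredis

multiply_script = """
local value = redis.call('GET', KEYS[1])
value = tonumber(value)
return value * ARGV[1]"""


async def main():
    r = aioredis.Redis()
    await r.set("a", 2)

    multiply = r.register_script(multiply_script)
    # SHA1 of the script body, computed before any server round trip.
    print(multiply.sha)

    # If the script isn't in the server cache yet, EVALSHA fails with
    # NOSCRIPT, the script is loaded, and the call is retried.
    print(await multiply(keys=["a"], args=[3]))  # 6
    # Subsequent calls hit the server-side script cache directly.
    print(await multiply(keys=["a"], args=[4]))  # 8


asyncio.run(main())
```

Passing `client=pipe`, as the pipeline tests do, queues the call so the load-and-retry happens inside `pipe.execute()`.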
cluster.nodes_down.add(("foo", 26379)) + address = await sentinel.discover_master("mymaster") + assert address == (master_ip, 6379) + # 'bar' is now first sentinel + assert sentinel.sentinels[0].id == ("bar", 26379) + + +async def test_discover_master_sentinel_timeout(cluster, sentinel, master_ip): + # Put first sentinel 'foo' down + cluster.nodes_timeout.add(("foo", 26379)) + address = await sentinel.discover_master("mymaster") + assert address == (master_ip, 6379) + # 'bar' is now first sentinel + assert sentinel.sentinels[0].id == ("bar", 26379) + + +async def test_master_min_other_sentinels(cluster, master_ip): + sentinel = Sentinel([("foo", 26379)], min_other_sentinels=1) + # min_other_sentinels + with pytest.raises(MasterNotFoundError): + await sentinel.discover_master("mymaster") + cluster.master["num-other-sentinels"] = 2 + address = await sentinel.discover_master("mymaster") + assert address == (master_ip, 6379) + + +async def test_master_odown(cluster, sentinel): + cluster.master["is_odown"] = True + with pytest.raises(MasterNotFoundError): + await sentinel.discover_master("mymaster") + + +async def test_master_sdown(cluster, sentinel): + cluster.master["is_sdown"] = True + with pytest.raises(MasterNotFoundError): + await sentinel.discover_master("mymaster") + + +async def test_discover_slaves(cluster, sentinel): + assert await sentinel.discover_slaves("mymaster") == [] + + cluster.slaves = [ + {"ip": "slave0", "port": 1234, "is_odown": False, "is_sdown": False}, + {"ip": "slave1", "port": 1234, "is_odown": False, "is_sdown": False}, + ] + assert await sentinel.discover_slaves("mymaster") == [ + ("slave0", 1234), + ("slave1", 1234), + ] + + # slave0 -> ODOWN + cluster.slaves[0]["is_odown"] = True + assert await sentinel.discover_slaves("mymaster") == [("slave1", 1234)] + + # slave1 -> SDOWN + cluster.slaves[1]["is_sdown"] = True + assert await sentinel.discover_slaves("mymaster") == [] + + cluster.slaves[0]["is_odown"] = False + cluster.slaves[1]["is_sdown"] = False + + # node0 -> DOWN + cluster.nodes_down.add(("foo", 26379)) + assert await sentinel.discover_slaves("mymaster") == [ + ("slave0", 1234), + ("slave1", 1234), + ] + cluster.nodes_down.clear() + + # node0 -> TIMEOUT + cluster.nodes_timeout.add(("foo", 26379)) + assert await sentinel.discover_slaves("mymaster") == [ + ("slave0", 1234), + ("slave1", 1234), + ] + + +async def test_master_for(cluster, sentinel, master_ip): + master = sentinel.master_for("mymaster", db=9) + assert await master.ping() + assert master.connection_pool.master_address == (master_ip, 6379) + + # Use internal connection check + master = sentinel.master_for("mymaster", db=9, check_connection=True) + assert await master.ping() + + +async def test_slave_for(cluster, sentinel): + cluster.slaves = [ + {"ip": "127.0.0.1", "port": 6379, "is_odown": False, "is_sdown": False}, + ] + slave = sentinel.slave_for("mymaster", db=9) + assert await slave.ping() + + +async def test_slave_for_slave_not_found_error(cluster, sentinel): + cluster.master["is_odown"] = True + slave = sentinel.slave_for("mymaster", db=9) + with pytest.raises(SlaveNotFoundError): + await slave.ping() + + +async def test_slave_round_robin(cluster, sentinel, master_ip): + cluster.slaves = [ + {"ip": "slave0", "port": 6379, "is_odown": False, "is_sdown": False}, + {"ip": "slave1", "port": 6379, "is_odown": False, "is_sdown": False}, + ] + pool = SentinelConnectionPool("mymaster", sentinel) + rotator = pool.rotate_slaves() + assert await rotator.__anext__() in (("slave0", 6379), ("slave1", 
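The fixtures above monkeypatch `aioredis.sentinel.Redis` with `SentinelTestCluster.client`, so discovery logic can be tested without a live Sentinel deployment. Against a real deployment the same API is used directly; a sketch with placeholder sentinel addresses (not from this diff):

```python
import asyncio

from aioredis.sentinel import Sentinel


async def main():
    # Two sentinel addresses; discovery queries them in order and, as the
    # tests above verify, promotes a responsive sentinel to the front.
    sentinel = Sentinel([("sentinel-1", 26379), ("sentinel-2", 26379)])

    master_host, master_port = await sentinel.discover_master("mymaster")
    print(master_host, master_port)

    # master_for()/slave_for() return lazy Redis clients bound to the
    # discovered master or to a round-robin rotation of slaves.
    master = sentinel.master_for("mymaster")
    await master.set("key", "value")
    replica = sentinel.slave_for("mymaster")
    print(await replica.get("key"))


asyncio.run(main())
```

As `test_slave_round_robin` shows, the slave rotation falls back to the master when no healthy slave is known, and raises `SlaveNotFoundError` only after that fallback is exhausted.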
diff --git a/tests/testutils.py b/tests/testutils.py
deleted file mode 100644
index cb2240192..000000000
--- a/tests/testutils.py
+++ /dev/null
@@ -1,39 +0,0 @@
-import asyncio
-import functools
-from typing import Awaitable, Callable
-
-import pytest
-
-__all__ = ("redis_version", "delay_exc", "select_opener")
-
-from tests.conftest import RedisServer
-
-
-def redis_version(*version: int, reason: str):
-    assert 1 < len(version) <= 3, version
-    assert all(isinstance(v, int) for v in version), version
-    return pytest.mark.redis_version(version=version, reason=reason)
-
-
-def delay_exc(fn: Callable[..., Awaitable] = None, *, secs: float):
-    def _delayed_exc(func):
-        @functools.wraps(func)
-        async def _delayed_exc_wrapper(*args, **kwargs):
-            await asyncio.sleep(secs)
-            return await func(*args, **kwargs)
-
-        return _delayed_exc_wrapper
-
-    return _delayed_exc(fn) if fn else _delayed_exc
-
-
-def select_opener(connect_type: str, server: RedisServer):
-    if connect_type == "tcp":
-        from aioredis.connection import open_connection
-
-        return open_connection, server.tcp_address
-    if connect_type == "unix":
-        from aioredis.connection import open_unix_connection
-
-        return open_unix_connection, server.unixsocket
-    raise RuntimeError(f"Connect-type {connect_type!r} is not supported.")
diff --git a/tests/transaction_commands_test.py b/tests/transaction_commands_test.py
deleted file mode 100644
index 8bafaec96..000000000
--- a/tests/transaction_commands_test.py
+++ /dev/null
@@ -1,285 +0,0 @@
-import asyncio
-
-import pytest
-
-from aioredis import (
-    ConnectionClosedError,
-    MultiExecError,
-    ReplyError,
-    WatchVariableError,
-)
-
-
-@pytest.mark.asyncio
-async def test_multi_exec(redis):
-    await redis.delete("foo", "bar")
-
-    tr = redis.multi_exec()
-    f1 = tr.incr("foo")
-    f2 = tr.incr("bar")
-    res = await tr.execute()
-    assert res == [1, 1]
-    res2 = await asyncio.gather(f1, f2)
-    assert res == res2
-
-    tr = redis.multi_exec()
-    f1 = tr.incr("foo")
-    f2 = tr.incr("bar")
-    await tr.execute()
-    assert (await f1) == 2
-    assert (await f2) == 2
-
-    tr = redis.multi_exec()
-    f1 = tr.set("foo", 1.0)
-    f2 = tr.incrbyfloat("foo", 1.2)
-    res = await tr.execute()
-    assert res == [True, 2.2]
-    res2 = await asyncio.gather(f1, f2)
-    assert res == res2
-
-    tr = redis.multi_exec()
-    f1 = tr.incrby("foo", 1.0)
-    with pytest.raises(MultiExecError, match="increment must be .* int"):
-        await tr.execute()
-    with pytest.raises(TypeError):
-        await f1
-
-
-@pytest.mark.asyncio
-async def test_empty(redis):
-    tr = redis.multi_exec()
-    res = await tr.execute()
-    assert res == []
-
-
-@pytest.mark.asyncio
-async def test_double_execute(redis):
-    tr = redis.multi_exec()
-    await tr.execute()
-    with pytest.raises(AssertionError):
-        await tr.execute()
-    with pytest.raises(AssertionError):
-        await tr.incr("foo")
-
-
-@pytest.mark.asyncio
-async def test_connection_closed(redis):
-    tr = redis.multi_exec()
-    fut1 = tr.quit()
-    fut2 = tr.incrby("foo", 1.0)
-    fut3 = tr.incrby("foo", 1)
-    with pytest.raises(MultiExecError):
-        await tr.execute()
-
-    assert fut1.done() is True
-    assert fut2.done() is True
-    assert fut3.done() is True
-    assert fut1.exception() is not None
-    assert fut2.exception() is not None
-    assert fut3.exception() is not None
-    assert not fut1.cancelled()
-    assert not fut2.cancelled()
-    assert not fut3.cancelled()
-
-    try:
-        assert (await fut1) == b"OK"
-    except Exception as err:
-        assert isinstance(err, (ConnectionClosedError, ConnectionError))
-    assert fut2.cancelled() is False
-    assert isinstance(fut2.exception(), TypeError)
-
-    # assert fut3.cancelled() is True
-    assert fut3.done() and not fut3.cancelled()
-    assert isinstance(fut3.exception(), (ConnectionClosedError, ConnectionError))
-
-
-@pytest.mark.asyncio
-async def test_discard(redis):
-    await redis.delete("foo")
-    tr = redis.multi_exec()
-    fut1 = tr.incrby("foo", 1.0)
-    fut2 = tr.connection.execute("MULTI")
-    fut3 = tr.connection.execute("incr", "foo")
-
-    with pytest.raises(MultiExecError):
-        await tr.execute()
-    with pytest.raises(TypeError):
-        await fut1
-    with pytest.raises(ReplyError):
-        await fut2
-    # with pytest.raises(ReplyError):
-    res = await fut3
-    assert res == 1
-
-
-@pytest.mark.asyncio
-async def test_exec_error(redis):
-    tr = redis.multi_exec()
-    fut = tr.connection.execute("INCRBY", "key", "1.0")
-    with pytest.raises(MultiExecError):
-        await tr.execute()
-    with pytest.raises(ReplyError):
-        await fut
-
-    await redis.set("foo", "bar")
-    tr = redis.multi_exec()
-    fut = tr.incrbyfloat("foo", 1.1)
-    res = await tr.execute(return_exceptions=True)
-    assert isinstance(res[0], ReplyError)
-    with pytest.raises(ReplyError):
-        await fut
-
-
-@pytest.mark.asyncio
-async def test_command_errors(redis):
-    tr = redis.multi_exec()
-    fut = tr.incrby("key", 1.0)
-    with pytest.raises(MultiExecError):
-        await tr.execute()
-    with pytest.raises(TypeError):
-        await fut
-
-
-@pytest.mark.asyncio
-async def test_several_command_errors(redis):
-    tr = redis.multi_exec()
-    fut1 = tr.incrby("key", 1.0)
-    fut2 = tr.rename("bar", "bar")
-    with pytest.raises(MultiExecError):
-        await tr.execute()
-    with pytest.raises(TypeError):
-        await fut1
-    with pytest.raises(ValueError):
-        await fut2
-
-
-@pytest.mark.asyncio
-async def test_error_in_connection(redis):
-    await redis.set("foo", 1)
-    tr = redis.multi_exec()
-    fut1 = tr.mget("foo", None)
-    fut2 = tr.incr("foo")
-    with pytest.raises(MultiExecError):
-        await tr.execute()
-    with pytest.raises(TypeError):
-        await fut1
-    await fut2
-
-
-@pytest.mark.asyncio
-async def test_watch_unwatch(redis):
-    res = await redis.watch("key")
-    assert res is True
-    res = await redis.watch("key", "key")
-    assert res is True
-
-    with pytest.raises(TypeError):
-        await redis.watch(None)
-    with pytest.raises(TypeError):
-        await redis.watch("key", None)
-    with pytest.raises(TypeError):
-        await redis.watch("key", "key", None)
-
-    res = await redis.unwatch()
-    assert res is True
-
-
-@pytest.mark.asyncio
-async def test_encoding(redis):
-    res = await redis.set("key", "value")
-    assert res is True
-    res = await redis.hmset("hash-key", "foo", "val1", "bar", "val2")
-    assert res is True
-
-    tr = redis.multi_exec()
-    fut1 = tr.get("key")
-    fut2 = tr.get("key", encoding="utf-8")
-    fut3 = tr.hgetall("hash-key", encoding="utf-8")
-    await tr.execute()
-    res = await fut1
-    assert res == b"value"
-    res = await fut2
-    assert res == "value"
-    res = await fut3
-    assert res == {"foo": "val1", "bar": "val2"}
-
-
-@pytest.mark.asyncio
-async def test_global_encoding(redis, create_redis, server):
-    redis = await create_redis(server.tcp_address, encoding="utf-8")
-    res = await redis.set("key", "value")
-    assert res is True
-    res = await redis.hmset("hash-key", "foo", "val1", "bar", "val2")
-    assert res is True
-
-    tr = redis.multi_exec()
-    fut1 = tr.get("key")
-    fut2 = tr.get("key", encoding="utf-8")
-    fut3 = tr.get("key", encoding=None)
-    fut4 = tr.hgetall("hash-key", encoding="utf-8")
-    await tr.execute()
-    res = await fut1
-    assert res == "value"
-    res = await fut2
-    assert res == "value"
-    res = await fut3
-    assert res == b"value"
-    res = await fut4
-    assert res == {"foo": "val1", "bar": "val2"}
-
-
-@pytest.mark.asyncio
-async def test_transaction__watch_error(redis, create_redis, server):
-    other = await create_redis(server.tcp_address)
-
-    ok = await redis.set("foo", "bar")
-    assert ok is True
-
-    ok = await redis.watch("foo")
-    assert ok is True
-
-    ok = await other.set("foo", "baz")
-    assert ok is True
-
-    tr = redis.multi_exec()
-    fut1 = tr.set("foo", "foo")
-    fut2 = tr.get("bar")
-    with pytest.raises(MultiExecError):
-        await tr.execute()
-    with pytest.raises(WatchVariableError):
-        await fut1
-    with pytest.raises(WatchVariableError):
-        await fut2
-
-
-@pytest.mark.asyncio
-async def test_multi_exec_and_pool_release(redis):
-    # Test the case when pool connection is released before
-    # `exec` result is received.
-
-    slow_script = """
-    local a = tonumber(redis.call('time')[1])
-    local b = a + 1
-    while (a < b)
-    do
-        a = tonumber(redis.call('time')[1])
-    end
-    """
-
-    tr = redis.multi_exec()
-    fut1 = tr.eval(slow_script)
-    (ret,) = await tr.execute()
-    assert ret is None
-    assert (await fut1) is None
-
-
-@pytest.mark.asyncio
-async def test_multi_exec_db_select(redis):
-    await redis.set("foo", "bar")
-
-    tr = redis.multi_exec()
-    f1 = tr.get("foo", encoding="utf-8")
-    f2 = tr.get("foo")
-    await tr.execute()
-    assert await f1 == "bar"
-    assert await f2 == b"bar"
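The deleted transaction tests covered the removed 1.x `multi_exec()` API. In the ported client, the equivalent is the redis-py pipeline interface: commands queue client-side and run under MULTI/EXEC when `execute()` is awaited, and optimistic locking raises `WatchError` rather than 1.x's `WatchVariableError`. A rough sketch of the replacement pattern, assuming a local server; the retry-loop shape follows redis-py conventions rather than anything in this diff:

```python
import asyncio

import aioredis
from aioredis import WatchError


async def main():
    r = aioredis.Redis()
    await r.set("foo", 1)

    # Buffered MULTI/EXEC: calls queue client-side (note: not awaited
    # individually) and run atomically when execute() is awaited.
    pipe = r.pipeline(transaction=True)
    pipe.incr("foo")
    pipe.incr("bar")
    print(await pipe.execute())  # e.g. [2, 1]

    # Optimistic locking: WATCH a key, read it in immediate mode, switch
    # the pipeline to buffered mode with multi(), and retry if the key
    # changed underneath us before EXEC.
    pipe = r.pipeline(transaction=True)
    while True:
        try:
            await pipe.watch("foo")
            value = int(await pipe.get("foo"))  # immediate mode after watch()
            pipe.multi()
            pipe.set("foo", value + 1)
            await pipe.execute()
            break
        except WatchError:
            continue  # watched key changed; re-read and retry


asyncio.run(main())
```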