Skip to content

Commit

Permalink
Allow setting client_name during connection construction.
Browse files Browse the repository at this point in the history
Client instances and Connection pools now accept "client_name" as an optional
argument. If supplied, all connections created will be named via
CLIENT SETNAME once the connection to the server is established.
  • Loading branch information
Habbie authored and andymccurdy committed Dec 29, 2019
1 parent a41465e commit dca7bd4
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 23 deletions.
4 changes: 4 additions & 0 deletions CHANGES
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@
pipeline instances relied on __len__ for boolean evaluation which
meant that pipelines with no commands on the stack would be considered
False. #994
* Client instances and Connection pools now support a 'client_name'
argument. If supplied, all connections created will call CLIENT SETNAME
as soon as the connection is opened. Thanks to @Habbie for supplying
the basis of this chanfge. #802
* 3.3.11
* Further fix for the SSLError -> TimeoutError mapping to work
on obscure releases of Python 2.7.
Expand Down
3 changes: 2 additions & 1 deletion redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -684,7 +684,7 @@ def __init__(self, host='localhost', port=6379,
ssl=False, ssl_keyfile=None, ssl_certfile=None,
ssl_cert_reqs='required', ssl_ca_certs=None,
max_connections=None, single_connection_client=False,
health_check_interval=0):
health_check_interval=0, client_name=None):
if not connection_pool:
if charset is not None:
warnings.warn(DeprecationWarning(
Expand All @@ -706,6 +706,7 @@ def __init__(self, host='localhost', port=6379,
'retry_on_timeout': retry_on_timeout,
'max_connections': max_connections,
'health_check_interval': health_check_interval,
'client_name': client_name
}
# based on input, setup appropriate connection args
if unix_socket_path is not None:
Expand Down
46 changes: 31 additions & 15 deletions redis/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -485,7 +485,6 @@ def read_response(self):

class Connection(object):
"Manages TCP communication to and from a Redis server"
description_format = "Connection<host=%(host)s,port=%(port)s,db=%(db)s>"

def __init__(self, host='localhost', port=6379, db=0, username=None,
password=None, socket_timeout=None,
Expand All @@ -494,12 +493,13 @@ def __init__(self, host='localhost', port=6379, db=0, username=None,
retry_on_timeout=False, encoding='utf-8',
encoding_errors='strict', decode_responses=False,
parser_class=DefaultParser, socket_read_size=65536,
health_check_interval=0):
health_check_interval=0, client_name=None):
self.pid = os.getpid()
self.host = host
self.port = int(port)
self.db = db
self.username = username
self.client_name = client_name
self.password = password
self.socket_timeout = socket_timeout
self.socket_connect_timeout = socket_connect_timeout or socket_timeout
Expand All @@ -512,16 +512,22 @@ def __init__(self, host='localhost', port=6379, db=0, username=None,
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self._sock = None
self._parser = parser_class(socket_read_size=socket_read_size)
self._description_args = {
'host': self.host,
'port': self.port,
'db': self.db,
}
self._connect_callbacks = []
self._buffer_cutoff = 6000

def __repr__(self):
return self.description_format % self._description_args
repr_args = ','.join(['%s=%s' % (k, v) for k, v in self.repr_pieces()])
return '%s<%s>' % (self.__class__.__name__, repr_args)

def repr_pieces(self):
pieces = [
('host', self.host),
('port', self.port),
('db', self.db)
]
if self.client_name:
pieces.append(('client_name', self.client_name))
return pieces

def __del__(self):
try:
Expand Down Expand Up @@ -626,6 +632,12 @@ def on_connect(self):
if nativestr(self.read_response()) != 'OK':
raise AuthenticationError('Invalid Username or Password')

# if a client_name is given, set it
if self.client_name:
self.send_command('CLIENT', 'SETNAME', self.client_name)
if nativestr(self.read_response()) != 'OK':
raise ConnectionError('Error setting client name')

# if a database is specified, switch to it
if self.db:
self.send_command('SELECT', self.db)
Expand Down Expand Up @@ -785,7 +797,6 @@ def pack_commands(self, commands):


class SSLConnection(Connection):
description_format = "SSLConnection<host=%(host)s,port=%(port)s,db=%(db)s>"

def __init__(self, ssl_keyfile=None, ssl_certfile=None,
ssl_cert_reqs='required', ssl_ca_certs=None, **kwargs):
Expand Down Expand Up @@ -838,18 +849,18 @@ def _connect(self):


class UnixDomainSocketConnection(Connection):
description_format = "UnixDomainSocketConnection<path=%(path)s,db=%(db)s>"

def __init__(self, path='', db=0, username=None, password=None,
socket_timeout=None, encoding='utf-8',
encoding_errors='strict', decode_responses=False,
retry_on_timeout=False,
parser_class=DefaultParser, socket_read_size=65536,
health_check_interval=0):
health_check_interval=0, client_name=None):
self.pid = os.getpid()
self.path = path
self.db = db
self.username = username
self.client_name = client_name
self.password = password
self.socket_timeout = socket_timeout
self.retry_on_timeout = retry_on_timeout
Expand All @@ -858,13 +869,18 @@ def __init__(self, path='', db=0, username=None, password=None,
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self._sock = None
self._parser = parser_class(socket_read_size=socket_read_size)
self._description_args = {
'path': self.path,
'db': self.db,
}
self._connect_callbacks = []
self._buffer_cutoff = 6000

def repr_pieces(self):
pieces = [
('path', self.path),
('db', self.db),
]
if self.client_name:
pieces.append(('client_name', self.client_name))
return pieces

def _connect(self):
"Create a Unix domain socket connection"
sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
Expand Down
45 changes: 38 additions & 7 deletions tests/test_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,17 +64,28 @@ def test_reuse_previously_released_connection(self):
assert c1 == c2

def test_repr_contains_db_info_tcp(self):
connection_kwargs = {'host': 'localhost', 'port': 6379, 'db': 1}
connection_kwargs = {
'host': 'localhost',
'port': 6379,
'db': 1,
'client_name': 'test-client'
}
pool = self.get_pool(connection_kwargs=connection_kwargs,
connection_class=redis.Connection)
expected = 'ConnectionPool<Connection<host=localhost,port=6379,db=1>>'
expected = ('ConnectionPool<Connection<'
'host=localhost,port=6379,db=1,client_name=test-client>>')
assert repr(pool) == expected

def test_repr_contains_db_info_unix(self):
connection_kwargs = {'path': '/abc', 'db': 1}
connection_kwargs = {
'path': '/abc',
'db': 1,
'client_name': 'test-client'
}
pool = self.get_pool(connection_kwargs=connection_kwargs,
connection_class=redis.UnixDomainSocketConnection)
expected = 'ConnectionPool<UnixDomainSocketConnection<path=/abc,db=1>>'
expected = ('ConnectionPool<UnixDomainSocketConnection<'
'path=/abc,db=1,client_name=test-client>>')
assert repr(pool) == expected

def test_pool_equality(self):
Expand Down Expand Up @@ -177,17 +188,25 @@ def test_reuse_previously_released_connection(self):
assert c1 == c2

def test_repr_contains_db_info_tcp(self):
pool = redis.ConnectionPool(host='localhost', port=6379, db=0)
expected = 'ConnectionPool<Connection<host=localhost,port=6379,db=0>>'
pool = redis.ConnectionPool(
host='localhost',
port=6379,
db=0,
client_name='test-client'
)
expected = ('ConnectionPool<Connection<'
'host=localhost,port=6379,db=0,client_name=test-client>>')
assert repr(pool) == expected

def test_repr_contains_db_info_unix(self):
pool = redis.ConnectionPool(
connection_class=redis.UnixDomainSocketConnection,
path='abc',
db=0,
client_name='test-client'
)
expected = 'ConnectionPool<UnixDomainSocketConnection<path=abc,db=0>>'
expected = ('ConnectionPool<UnixDomainSocketConnection<'
'path=abc,db=0,client_name=test-client>>')
assert repr(pool) == expected


Expand Down Expand Up @@ -364,6 +383,12 @@ def test_boolean_parsing(self):
):
assert expected is to_bool(value)

def test_client_name_in_querystring(self):
pool = redis.ConnectionPool.from_url(
'redis://location?client_name=test-client'
)
assert pool.connection_kwargs['client_name'] == 'test-client'

def test_invalid_extra_typed_querystring_options(self):
import warnings
with warnings.catch_warnings(record=True) as warning_log:
Expand Down Expand Up @@ -502,6 +527,12 @@ def test_db_in_querystring(self):
'password': None,
}

def test_client_name_in_querystring(self):
pool = redis.ConnectionPool.from_url(
'redis://location?client_name=test-client'
)
assert pool.connection_kwargs['client_name'] == 'test-client'

def test_extra_querystring_options(self):
pool = redis.ConnectionPool.from_url('unix:///socket?a=1&b=2')
assert pool.connection_class == redis.UnixDomainSocketConnection
Expand Down

0 comments on commit dca7bd4

Please sign in to comment.