diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index 414231fd..34dca360 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -56,6 +56,7 @@ def parse(cls, sslmode): 'direct_tls', 'server_settings', 'target_session_attrs', + 'krbsrvname', ]) @@ -261,7 +262,7 @@ def _dot_postgresql_path(filename) -> typing.Optional[pathlib.Path]: def _parse_connect_dsn_and_args(*, dsn, host, port, user, password, passfile, database, ssl, direct_tls, server_settings, - target_session_attrs): + target_session_attrs, krbsrvname): # `auth_hosts` is the version of host information for the purposes # of reading the pgpass file. auth_hosts = None @@ -383,6 +384,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if target_session_attrs is None: target_session_attrs = dsn_target_session_attrs + if 'krbsrvname' in query: + val = query.pop('krbsrvname') + if krbsrvname is None: + krbsrvname = val + if query: if server_settings is None: server_settings = query @@ -654,7 +660,8 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, user=user, password=password, database=database, ssl=ssl, sslmode=sslmode, direct_tls=direct_tls, server_settings=server_settings, - target_session_attrs=target_session_attrs) + target_session_attrs=target_session_attrs, + krbsrvname=krbsrvname) return addrs, params @@ -665,7 +672,7 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, max_cached_statement_lifetime, max_cacheable_statement_size, ssl, direct_tls, server_settings, - target_session_attrs): + target_session_attrs, krbsrvname): local_vars = locals() for var_name in {'max_cacheable_statement_size', 'max_cached_statement_lifetime', @@ -694,7 +701,8 @@ def _parse_connect_arguments(*, dsn, host, port, user, password, passfile, password=password, passfile=passfile, ssl=ssl, direct_tls=direct_tls, database=database, server_settings=server_settings, - target_session_attrs=target_session_attrs) + target_session_attrs=target_session_attrs, + krbsrvname=krbsrvname) config = _ClientConfiguration( command_timeout=command_timeout, diff --git a/asyncpg/connection.py b/asyncpg/connection.py index 0367e365..c8272897 100644 --- a/asyncpg/connection.py +++ b/asyncpg/connection.py @@ -2007,7 +2007,8 @@ async def connect(dsn=None, *, connection_class=Connection, record_class=protocol.Record, server_settings=None, - target_session_attrs=None): + target_session_attrs=None, + krbsrvname=None): r"""A coroutine to establish a connection to a PostgreSQL server. The connection parameters may be specified either as a connection @@ -2235,6 +2236,9 @@ async def connect(dsn=None, *, or the value of the ``PGTARGETSESSIONATTRS`` environment variable, or ``"any"`` if neither is specified. + :param str krbsrvname: + Kerberos service name to use when authenticating with GSSAPI. + :return: A :class:`~asyncpg.connection.Connection` instance. Example: @@ -2344,7 +2348,8 @@ async def connect(dsn=None, *, statement_cache_size=statement_cache_size, max_cached_statement_lifetime=max_cached_statement_lifetime, max_cacheable_statement_size=max_cacheable_statement_size, - target_session_attrs=target_session_attrs + target_session_attrs=target_session_attrs, + krbsrvname=krbsrvname, ) diff --git a/asyncpg/protocol/coreproto.pxd b/asyncpg/protocol/coreproto.pxd index 7ce4f574..612d8cae 100644 --- a/asyncpg/protocol/coreproto.pxd +++ b/asyncpg/protocol/coreproto.pxd @@ -51,16 +51,6 @@ cdef enum AuthenticationMessage: AUTH_SASL_FINAL = 12 -AUTH_METHOD_NAME = { - AUTH_REQUIRED_KERBEROS: 'kerberosv5', - AUTH_REQUIRED_PASSWORD: 'password', - AUTH_REQUIRED_PASSWORDMD5: 'md5', - AUTH_REQUIRED_GSS: 'gss', - AUTH_REQUIRED_SASL: 'scram-sha-256', - AUTH_REQUIRED_SSPI: 'sspi', -} - - cdef enum ResultType: RESULT_OK = 1 RESULT_FAILED = 2 @@ -96,10 +86,13 @@ cdef class CoreProtocol: object transport + object address # Instance of _ConnectionParameters object con_params # Instance of SCRAMAuthentication SCRAMAuthentication scram + # Instance of gssapi.SecurityContext + object gss_ctx readonly int32_t backend_pid readonly int32_t backend_secret @@ -145,6 +138,8 @@ cdef class CoreProtocol: cdef _auth_password_message_md5(self, bytes salt) cdef _auth_password_message_sasl_initial(self, list sasl_auth_methods) cdef _auth_password_message_sasl_continue(self, bytes server_response) + cdef _auth_gss_init(self) + cdef _auth_gss_step(self, bytes server_response) cdef _write(self, buf) cdef _writelines(self, list buffers) diff --git a/asyncpg/protocol/coreproto.pyx b/asyncpg/protocol/coreproto.pyx index 64afe934..26fee179 100644 --- a/asyncpg/protocol/coreproto.pyx +++ b/asyncpg/protocol/coreproto.pyx @@ -6,14 +6,26 @@ import hashlib +import socket include "scram.pyx" +cdef dict AUTH_METHOD_NAME = { + AUTH_REQUIRED_KERBEROS: 'kerberosv5', + AUTH_REQUIRED_PASSWORD: 'password', + AUTH_REQUIRED_PASSWORDMD5: 'md5', + AUTH_REQUIRED_GSS: 'gss', + AUTH_REQUIRED_SASL: 'scram-sha-256', + AUTH_REQUIRED_SSPI: 'sspi', +} + + cdef class CoreProtocol: - def __init__(self, con_params): + def __init__(self, addr, con_params): + self.address = addr # type of `con_params` is `_ConnectionParameters` self.buffer = ReadBuffer() self.user = con_params.user @@ -26,6 +38,8 @@ cdef class CoreProtocol: self.encoding = 'utf-8' # type of `scram` is `SCRAMAuthentcation` self.scram = None + # type of `gss_ctx` is `gssapi.SecurityContext` + self.gss_ctx = None self._reset_result() @@ -619,9 +633,17 @@ cdef class CoreProtocol: 'could not verify server signature for ' 'SCRAM authentciation: scram-sha-256', ) + self.scram = None + + elif status == AUTH_REQUIRED_GSS: + self._auth_gss_init() + self.auth_msg = self._auth_gss_step(None) + + elif status == AUTH_REQUIRED_GSS_CONTINUE: + server_response = self.buffer.consume_message() + self.auth_msg = self._auth_gss_step(server_response) elif status in (AUTH_REQUIRED_KERBEROS, AUTH_REQUIRED_SCMCRED, - AUTH_REQUIRED_GSS, AUTH_REQUIRED_GSS_CONTINUE, AUTH_REQUIRED_SSPI): self.result_type = RESULT_FAILED self.result = apg_exc.InterfaceError( @@ -634,7 +656,8 @@ cdef class CoreProtocol: 'unsupported authentication method requested by the ' 'server: {}'.format(status)) - if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL]: + if status not in [AUTH_SASL_CONTINUE, AUTH_SASL_FINAL, + AUTH_REQUIRED_GSS_CONTINUE]: self.buffer.discard_message() cdef _auth_password_message_cleartext(self): @@ -691,6 +714,39 @@ cdef class CoreProtocol: return msg + cdef _auth_gss_init(self): + try: + import gssapi + except ModuleNotFoundError: + raise RuntimeError( + 'gssapi module not found; please install asyncpg[gss] to use ' + 'asyncpg with Kerberos or GSSAPI authentication' + ) from None + + service_name = self.con_params.krbsrvname or 'postgres' + # find the canonical name of the server host + if isinstance(self.address, str): + host = socket.gethostname() + else: + host = self.address[0] + host_cname = socket.gethostbyname_ex(host)[0].rstrip('.') or host + gss_name = gssapi.Name(f'{service_name}/{host_cname}') + self.gss_ctx = gssapi.SecurityContext(name=gss_name, usage='initiate') + + cdef _auth_gss_step(self, bytes server_response): + cdef: + WriteBuffer msg + + token = self.gss_ctx.step(server_response) + if not token: + self.gss_ctx = None + return None + msg = WriteBuffer.new_message(b'p') + msg.write_bytes(token) + msg.end_message() + + return msg + cdef _parse_msg_ready_for_query(self): cdef char status = self.buffer.read_byte() diff --git a/asyncpg/protocol/protocol.pxd b/asyncpg/protocol/protocol.pxd index a9ac8d5f..cd221fbb 100644 --- a/asyncpg/protocol/protocol.pxd +++ b/asyncpg/protocol/protocol.pxd @@ -31,7 +31,6 @@ cdef class BaseProtocol(CoreProtocol): cdef: object loop - object address ConnectionSettings settings object cancel_sent_waiter object cancel_waiter diff --git a/asyncpg/protocol/protocol.pyx b/asyncpg/protocol/protocol.pyx index b43b0e9c..1459d908 100644 --- a/asyncpg/protocol/protocol.pyx +++ b/asyncpg/protocol/protocol.pyx @@ -75,7 +75,7 @@ NO_TIMEOUT = object() cdef class BaseProtocol(CoreProtocol): def __init__(self, addr, connected_fut, con_params, record_class: type, loop): # type of `con_params` is `_ConnectionParameters` - CoreProtocol.__init__(self, con_params) + CoreProtocol.__init__(self, addr, con_params) self.loop = loop self.transport = None @@ -83,8 +83,7 @@ cdef class BaseProtocol(CoreProtocol): self.cancel_waiter = None self.cancel_sent_waiter = None - self.address = addr - self.settings = ConnectionSettings((self.address, con_params.database)) + self.settings = ConnectionSettings((addr, con_params.database)) self.record_class = record_class self.statement = None diff --git a/pyproject.toml b/pyproject.toml index ed2340a7..f22821e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,9 @@ dependencies = [ github = "https://github.com/MagicStack/asyncpg" [project.optional-dependencies] +gss = [ + 'gssapi', +] test = [ 'flake8~=6.1', 'uvloop>=0.15.3; platform_system != "Windows" and python_version < "3.12.0"',