From 7044d4152c382d55e83694b4529b415eea5f3bc2 Mon Sep 17 00:00:00 2001 From: Elvis Pranskevichus Date: Tue, 15 Aug 2023 15:18:28 -0700 Subject: [PATCH] Handle environments with HOME set to a not-a-directory If `HOME` points to a regular file (or `/dev/null`), make sure we don't crash unnecessarily, and if we do need to crash, so so informatively. Fixes: #1014 --- asyncpg/connect_utils.py | 64 ++++++++++++++++++++++++------------- asyncpg/exceptions/_base.py | 7 +++- tests/test_connect.py | 33 +++++++++++++++++-- 3 files changed, 79 insertions(+), 25 deletions(-) diff --git a/asyncpg/connect_utils.py b/asyncpg/connect_utils.py index b5beb4e8..b91da671 100644 --- a/asyncpg/connect_utils.py +++ b/asyncpg/connect_utils.py @@ -165,7 +165,7 @@ def _validate_port_spec(hosts, port): # If there is a list of ports, its length must # match that of the host list. if len(port) != len(hosts): - raise exceptions.InterfaceError( + raise exceptions.ClientConfigurationError( 'could not match {} port numbers to {} hosts'.format( len(port), len(hosts))) else: @@ -211,7 +211,7 @@ def _parse_hostlist(hostlist, port, *, unquote=False): addr = m.group(1) hostspec_port = m.group(2) else: - raise ValueError( + raise exceptions.ClientConfigurationError( 'invalid IPv6 address in the connection URI: {!r}'.format( hostspec ) @@ -240,13 +240,13 @@ def _parse_hostlist(hostlist, port, *, unquote=False): def _parse_tls_version(tls_version): if tls_version.startswith('SSL'): - raise ValueError( + raise exceptions.ClientConfigurationError( f"Unsupported TLS version: {tls_version}" ) try: return ssl_module.TLSVersion[tls_version.replace('.', '_')] except KeyError: - raise ValueError( + raise exceptions.ClientConfigurationError( f"No such TLS version: {tls_version}" ) @@ -274,7 +274,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, parsed = urllib.parse.urlparse(dsn) if parsed.scheme not in {'postgresql', 'postgres'}: - raise ValueError( + raise exceptions.ClientConfigurationError( 'invalid DSN: scheme is expected to be either ' '"postgresql" or "postgres", got {!r}'.format(parsed.scheme)) @@ -437,11 +437,11 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, database = user if user is None: - raise exceptions.InterfaceError( + raise exceptions.ClientConfigurationError( 'could not determine user name to connect with') if database is None: - raise exceptions.InterfaceError( + raise exceptions.ClientConfigurationError( 'could not determine database name to connect to') if password is None: @@ -477,7 +477,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, have_tcp_addrs = True if not addrs: - raise ValueError( + raise exceptions.InternalClientError( 'could not determine the database address to connect to') if ssl is None: @@ -491,7 +491,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, sslmode = SSLMode.parse(ssl) except AttributeError: modes = ', '.join(m.name.replace('_', '-') for m in SSLMode) - raise exceptions.InterfaceError( + raise exceptions.ClientConfigurationError( '`sslmode` parameter must be one of: {}'.format(modes)) # docs at https://www.postgresql.org/docs/10/static/libpq-connect.html @@ -511,19 +511,36 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, else: try: sslrootcert = _dot_postgresql_path('root.crt') - assert sslrootcert is not None - ssl.load_verify_locations(cafile=sslrootcert) - except (AssertionError, FileNotFoundError): + if sslrootcert is not None: + ssl.load_verify_locations(cafile=sslrootcert) + else: + raise exceptions.ClientConfigurationError( + 'cannot determine location of user ' + 'PostgreSQL configuration directory' + ) + except ( + exceptions.ClientConfigurationError, + FileNotFoundError, + NotADirectoryError, + ): if sslmode > SSLMode.require: if sslrootcert is None: - raise RuntimeError( - 'Cannot determine home directory' + sslrootcert = '~/.postgresql/root.crt' + detail = ( + 'Could not determine location of user ' + 'home directory (HOME is either unset, ' + 'inaccessible, or does not point to a ' + 'valid directory)' ) - raise ValueError( + else: + detail = None + raise exceptions.ClientConfigurationError( f'root certificate file "{sslrootcert}" does ' - f'not exist\nEither provide the file or ' - f'change sslmode to disable server ' - f'certificate verification.' + f'not exist or cannot be accessed', + hint='Provide the certificate file directly ' + f'or make sure "{sslrootcert}" ' + 'exists and is readable.', + detail=detail, ) elif sslmode == SSLMode.require: ssl.verify_mode = ssl_module.CERT_NONE @@ -542,7 +559,10 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, if sslcrl is not None: try: ssl.load_verify_locations(cafile=sslcrl) - except FileNotFoundError: + except ( + FileNotFoundError, + NotADirectoryError, + ): pass else: ssl.verify_flags |= \ @@ -571,7 +591,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, keyfile=sslkey, password=lambda: sslpassword ) - except FileNotFoundError: + except (FileNotFoundError, NotADirectoryError): pass # OpenSSL 1.1.1 keylog file, copied from create_default_context() @@ -606,7 +626,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, not isinstance(server_settings, dict) or not all(isinstance(k, str) for k in server_settings) or not all(isinstance(v, str) for v in server_settings.values())): - raise ValueError( + raise exceptions.ClientConfigurationError( 'server_settings is expected to be None or ' 'a Dict[str, str]') @@ -617,7 +637,7 @@ def _parse_connect_dsn_and_args(*, dsn, host, port, user, try: target_session_attrs = SessionAttribute(target_session_attrs) except ValueError: - raise exceptions.InterfaceError( + raise exceptions.ClientConfigurationError( "target_session_attrs is expected to be one of " "{!r}" ", got {!r}".format( diff --git a/asyncpg/exceptions/_base.py b/asyncpg/exceptions/_base.py index de981d25..e2da6bd8 100644 --- a/asyncpg/exceptions/_base.py +++ b/asyncpg/exceptions/_base.py @@ -13,7 +13,8 @@ __all__ = ('PostgresError', 'FatalPostgresError', 'UnknownPostgresError', 'InterfaceError', 'InterfaceWarning', 'PostgresLogMessage', 'InternalClientError', 'OutdatedSchemaCacheError', 'ProtocolError', - 'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched') + 'UnsupportedClientFeatureError', 'TargetServerAttributeNotMatched', + 'ClientConfigurationError') def _is_asyncpg_class(cls): @@ -220,6 +221,10 @@ def with_msg(self, msg): ) +class ClientConfigurationError(InterfaceError, ValueError): + """An error caused by improper client configuration.""" + + class DataError(InterfaceError, ValueError): """An error caused by invalid query input.""" diff --git a/tests/test_connect.py b/tests/test_connect.py index e3cfb372..e487ee61 100644 --- a/tests/test_connect.py +++ b/tests/test_connect.py @@ -79,6 +79,15 @@ def mock_no_home_dir(): yield +@contextlib.contextmanager +def mock_dev_null_home_dir(): + with unittest.mock.patch( + 'pathlib.Path.home', + unittest.mock.Mock(return_value=pathlib.Path('/dev/null')), + ): + yield + + class TestSettings(tb.ConnectedTestCase): async def test_get_settings_01(self): @@ -1318,9 +1327,18 @@ async def test_connection_no_home_dir(self): await con.fetchval('SELECT 42') await con.close() + with mock_dev_null_home_dir(): + con = await self.connect( + dsn='postgresql://foo/', + user='postgres', + database='postgres', + host='localhost') + await con.fetchval('SELECT 42') + await con.close() + with self.assertRaisesRegex( - RuntimeError, - 'Cannot determine home directory' + exceptions.ClientConfigurationError, + r'root certificate file "~/\.postgresql/root\.crt" does not exist' ): with mock_no_home_dir(): await self.connect( @@ -1328,6 +1346,17 @@ async def test_connection_no_home_dir(self): user='ssl_user', ssl='verify-full') + with self.assertRaisesRegex( + exceptions.ClientConfigurationError, + r'root certificate file "/dev/null/\.postgresql/root\.crt" ' + r'does not exist' + ): + with mock_dev_null_home_dir(): + await self.connect( + host='localhost', + user='ssl_user', + ssl='verify-full') + class BaseTestSSLConnection(tb.ConnectedTestCase): @classmethod