From 5b4afe990ec74729bd5e6f455fe13e900ada9ce5 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Tue, 10 Sep 2024 20:41:14 +0000 Subject: [PATCH 1/2] fix: only set PSC ip type if PSC is enabled --- google/cloud/sql/connector/client.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/google/cloud/sql/connector/client.py b/google/cloud/sql/connector/client.py index e5af54b2b..1c805814e 100644 --- a/google/cloud/sql/connector/client.py +++ b/google/cloud/sql/connector/client.py @@ -141,10 +141,13 @@ async def _get_metadata( if "ipAddresses" in ret_dict else {} ) - # Remove trailing period from PSC DNS name. - psc_dns = ret_dict.get("dnsName") - if psc_dns: - ip_addresses["PSC"] = psc_dns.rstrip(".") + # resolve dnsName into IP address for PSC + # Note that we have to check for PSC enablement also because CAS + # instances also set the dnsName field. + # Remove trailing period from DNS name. Required for SSL in Python + dns_name = ret_dict.get("dnsName", "").rstrip(".") + if dns_name and ret_dict.get("pscEnabled"): + ip_addresses["PSC"] = dns_name return { "ip_addresses": ip_addresses, From 2cd09a368e160ed3ec5609c335a575dbe8d5f4b3 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 12 Sep 2024 14:59:50 +0000 Subject: [PATCH 2/2] chore: add unit test --- tests/unit/mocks.py | 2 ++ tests/unit/test_client.py | 24 ++++++++++++++++++++++-- 2 files changed, 24 insertions(+), 2 deletions(-) diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 03ad5a6e0..0f25f1c14 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -234,6 +234,7 @@ def __init__( self.name = name self.db_version = db_version self.ip_addrs = ip_addrs + self.psc_enabled = False self.cert_before = cert_before self.cert_expiration = cert_expiration # create self signed CA cert @@ -255,6 +256,7 @@ async def connect_settings(self, request: Any) -> web.Response: "expirationTime": str(self.cert_expiration), }, "dnsName": "abcde.12345.us-central1.sql.goog", + "pscEnabled": self.psc_enabled, "ipAddresses": ip_addrs, "region": self.region, "databaseVersion": self.db_version, diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index da1c97ea9..046e8e51b 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -24,9 +24,9 @@ @pytest.mark.asyncio -async def test_get_metadata(fake_client: CloudSQLClient) -> None: +async def test_get_metadata_no_psc(fake_client: CloudSQLClient) -> None: """ - Test _get_metadata returns successfully. + Test _get_metadata returns successfully and does not include PSC IP type. """ resp = await fake_client._get_metadata( "test-project", @@ -34,6 +34,26 @@ async def test_get_metadata(fake_client: CloudSQLClient) -> None: "test-instance", ) assert resp["database_version"] == "POSTGRES_15" + assert resp["ip_addresses"] == { + "PRIMARY": "127.0.0.1", + "PRIVATE": "10.0.0.1", + } + assert isinstance(resp["server_ca_cert"], str) + + +@pytest.mark.asyncio +async def test_get_metadata_with_psc(fake_client: CloudSQLClient) -> None: + """ + Test _get_metadata returns successfully with PSC IP type. + """ + # set PSC to enabled on test instance + fake_client.instance.psc_enabled = True + resp = await fake_client._get_metadata( + "test-project", + "test-region", + "test-instance", + ) + assert resp["database_version"] == "POSTGRES_15" assert resp["ip_addresses"] == { "PRIMARY": "127.0.0.1", "PRIVATE": "10.0.0.1",