diff --git a/README.md b/README.md index 6eb50a2..3d88523 100644 --- a/README.md +++ b/README.md @@ -88,9 +88,11 @@ This package provides several functions for authorizing and encrypting connections. These functions are used with your database driver to connect to your AlloyDB instance. -AlloyDB supports network connectivity through private, internal IP addresses only. -This package must be run in an environment that is connected to the -[VPC Network][vpc] that hosts your AlloyDB private IP address. +AlloyDB supports network connectivity through public IP addresses and private, +internal IP addresses. By default this package will attempt to connect over a +private IP connection. When doing so, this package must be run in an +environment that is connected to the [VPC Network][vpc] that hosts your +AlloyDB private IP address. Please see [Configuring AlloyDB Connectivity][alloydb-connectivity] for more details. @@ -366,6 +368,27 @@ connector.connect( [configure-iam-authn]: https://cloud.google.com/alloydb/docs/manage-iam-authn#enable [add-iam-user]: https://cloud.google.com/alloydb/docs/manage-iam-authn#create-user +### Specifying IP Address Type + +The AlloyDB Python Connector by default will attempt to establish connections +to your instance's private IP. To change this, such as connecting to AlloyDB +over a public IP address, set the `ip_type` keyword argument when initializing +a `Connector()` or when calling `connector.connect()`. + +Possible values for `ip_type` are `IPTypes.PRIVATE` (default value), and +`IPTypes.PUBLIC`. +Example: + +```python +from google.cloud.alloydb.connector import Connector, IPTypes + +conn = connector.connect( + "projects//locations//clusters//instances/", + "pg8000", + ip_type=IPTypes.PUBLIC, # use public IP +) +``` + ## Support policy ### Major version lifecycle diff --git a/google/cloud/alloydb/connector/__init__.py b/google/cloud/alloydb/connector/__init__.py index 7ee1498..16af0b6 100644 --- a/google/cloud/alloydb/connector/__init__.py +++ b/google/cloud/alloydb/connector/__init__.py @@ -13,6 +13,7 @@ # limitations under the License. from google.cloud.alloydb.connector.async_connector import AsyncConnector from google.cloud.alloydb.connector.connector import Connector +from google.cloud.alloydb.connector.instance import IPTypes from google.cloud.alloydb.connector.version import __version__ -__all__ = ["__version__", "Connector", "AsyncConnector"] +__all__ = ["__version__", "Connector", "AsyncConnector", "IPTypes"] diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py index 3276fd7..2fb5f50 100644 --- a/google/cloud/alloydb/connector/async_connector.py +++ b/google/cloud/alloydb/connector/async_connector.py @@ -25,6 +25,7 @@ import google.cloud.alloydb.connector.asyncpg as asyncpg from google.cloud.alloydb.connector.client import AlloyDBClient from google.cloud.alloydb.connector.instance import Instance +from google.cloud.alloydb.connector.instance import IPTypes from google.cloud.alloydb.connector.utils import generate_keys if TYPE_CHECKING: @@ -46,6 +47,8 @@ class AsyncConnector: alloydb_api_endpoint (str): Base URL to use when calling the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com". enable_iam_auth (bool): Enables automatic IAM database authentication. + ip_type (IPTypes): Default IP type for all AlloyDB connections. + Defaults to IPTypes.PRIVATE for private IP connections. """ def __init__( @@ -54,6 +57,7 @@ def __init__( quota_project: Optional[str] = None, alloydb_api_endpoint: str = "https://alloydb.googleapis.com", enable_iam_auth: bool = False, + ip_type: IPTypes = IPTypes.PRIVATE, user_agent: Optional[str] = None, ) -> None: self._instances: Dict[str, Instance] = {} @@ -61,6 +65,7 @@ def __init__( self._quota_project = quota_project self._alloydb_api_endpoint = alloydb_api_endpoint self._enable_iam_auth = enable_iam_auth + self._ip_type = ip_type self._user_agent = user_agent # initialize credentials scopes = ["https://www.googleapis.com/auth/cloud-platform"] @@ -139,7 +144,8 @@ async def connect( kwargs.pop("port", None) # get connection info for AlloyDB instance - ip_address, context = await instance.connection_info() + ip_type: IPTypes = kwargs.pop("ip_type", self._ip_type) + ip_address, context = await instance.connection_info(ip_type) # callable to be used for auto IAM authn def get_authentication_token() -> str: diff --git a/google/cloud/alloydb/connector/client.py b/google/cloud/alloydb/connector/client.py index 8b646cd..0da9ebb 100644 --- a/google/cloud/alloydb/connector/client.py +++ b/google/cloud/alloydb/connector/client.py @@ -15,7 +15,7 @@ from __future__ import annotations import logging -from typing import List, Optional, Tuple, TYPE_CHECKING +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING import aiohttp from google.auth.transport.requests import Request @@ -96,7 +96,7 @@ async def _get_metadata( region: str, cluster: str, name: str, - ) -> str: + ) -> Dict[str, Optional[str]]: """ Fetch the metadata for a given AlloyDB instance. @@ -112,7 +112,7 @@ async def _get_metadata( name (str): The name of the AlloyDB instance. Returns: - str: IP address of the AlloyDB instance. + dict: IP addresses of the AlloyDB instance. """ logger.debug(f"['{project}/{region}/{cluster}/{name}']: Requesting metadata") @@ -129,7 +129,10 @@ async def _get_metadata( resp = await self._client.get(url, headers=headers, raise_for_status=True) resp_dict = await resp.json() - return resp_dict["ipAddress"] + return { + "PRIVATE": resp_dict.get("ipAddress"), + "PUBLIC": resp_dict.get("publicIpAddress"), + } async def _get_client_certificate( self, diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 1abfc6d..696e5c5 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -27,6 +27,7 @@ from google.cloud.alloydb.connector.client import AlloyDBClient from google.cloud.alloydb.connector.instance import Instance +from google.cloud.alloydb.connector.instance import IPTypes import google.cloud.alloydb.connector.pg8000 as pg8000 from google.cloud.alloydb.connector.utils import generate_keys import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb @@ -56,6 +57,8 @@ class Connector: alloydb_api_endpoint (str): Base URL to use when calling the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com". enable_iam_auth (bool): Enables automatic IAM database authentication. + ip_type (IPTypes): Default IP type for all AlloyDB connections. + Defaults to IPTypes.PRIVATE for private IP connections. """ def __init__( @@ -64,6 +67,7 @@ def __init__( quota_project: Optional[str] = None, alloydb_api_endpoint: str = "https://alloydb.googleapis.com", enable_iam_auth: bool = False, + ip_type: IPTypes = IPTypes.PRIVATE, user_agent: Optional[str] = None, ) -> None: # create event loop and start it in background thread @@ -75,6 +79,7 @@ def __init__( self._quota_project = quota_project self._alloydb_api_endpoint = alloydb_api_endpoint self._enable_iam_auth = enable_iam_auth + self._ip_type = ip_type self._user_agent = user_agent # initialize credentials scopes = ["https://www.googleapis.com/auth/cloud-platform"] @@ -166,7 +171,8 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> kwargs.pop("port", None) # get connection info for AlloyDB instance - ip_address, context = await instance.connection_info() + ip_type: IPTypes = kwargs.pop("ip_type", self._ip_type) + ip_address, context = await instance.connection_info(ip_type) # synchronous drivers are blocking and run using executor try: diff --git a/google/cloud/alloydb/connector/exceptions.py b/google/cloud/alloydb/connector/exceptions.py index 660a2cf..777f17a 100644 --- a/google/cloud/alloydb/connector/exceptions.py +++ b/google/cloud/alloydb/connector/exceptions.py @@ -15,3 +15,7 @@ class RefreshError(Exception): pass + + +class IPTypeNotFoundError(Exception): + pass diff --git a/google/cloud/alloydb/connector/instance.py b/google/cloud/alloydb/connector/instance.py index 464fce4..0bd7770 100644 --- a/google/cloud/alloydb/connector/instance.py +++ b/google/cloud/alloydb/connector/instance.py @@ -15,10 +15,12 @@ from __future__ import annotations import asyncio +from enum import Enum import logging import re from typing import Tuple, TYPE_CHECKING +from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError from google.cloud.alloydb.connector.exceptions import RefreshError from google.cloud.alloydb.connector.rate_limiter import AsyncRateLimiter from google.cloud.alloydb.connector.refresh import _is_valid @@ -39,6 +41,15 @@ ) +class IPTypes(Enum): + """ + Enum for specifying IP type to connect to AlloyDB with. + """ + + PUBLIC: str = "PUBLIC" + PRIVATE: str = "PRIVATE" + + def _parse_instance_uri(instance_uri: str) -> Tuple[str, str, str, str]: # should take form "projects//locations//clusters//instances/" if INSTANCE_URI_REGEX.fullmatch(instance_uri) is None: @@ -214,16 +225,24 @@ async def force_refresh(self) -> None: if not await _is_valid(self._current): self._current = self._next - async def connection_info(self) -> Tuple[str, ssl.SSLContext]: + async def connection_info(self, ip_type: IPTypes) -> Tuple[str, ssl.SSLContext]: """ Return connection info for current refresh result. + Args: + ip_type (IpTypes): Type of AlloyDB instance IP to connect over. Returns: Tuple[str, ssl.SSLContext]: AlloyDB instance IP address and configured TLS connection. """ refresh: RefreshResult = await self._current - return refresh.instance_ip, refresh.context + ip_address = refresh.ip_addrs.get(ip_type.value) + if ip_address is None: + raise IPTypeNotFoundError( + "AlloyDB instance does not have an IP addresses matching " + f"type: '{ip_type.value}'" + ) + return ip_address, refresh.context async def close(self) -> None: """ diff --git a/google/cloud/alloydb/connector/refresh.py b/google/cloud/alloydb/connector/refresh.py index 77734b3..fe0ef48 100644 --- a/google/cloud/alloydb/connector/refresh.py +++ b/google/cloud/alloydb/connector/refresh.py @@ -20,7 +20,7 @@ import logging import ssl from tempfile import TemporaryDirectory -from typing import List, Tuple, TYPE_CHECKING +from typing import Dict, List, Optional, Tuple, TYPE_CHECKING from cryptography import x509 @@ -71,16 +71,19 @@ class RefreshResult: Builds the TLS context required to connect to AlloyDB database. Args: - instance_ip (str): The IP address of the AlloyDB instance. + ip_addrs (Dict[str, str]): The IP addresses of the AlloyDB instance. key (rsa.RSAPrivateKey): Private key for the client connection. certs (Tuple[str, List(str)]): Client cert and CA certs for establishing the chain of trust used in building the TLS context. """ def __init__( - self, instance_ip: str, key: rsa.RSAPrivateKey, certs: Tuple[str, List[str]] + self, + ip_addrs: Dict[str, Optional[str]], + key: rsa.RSAPrivateKey, + certs: Tuple[str, List[str]], ) -> None: - self.instance_ip = instance_ip + self.ip_addrs = ip_addrs # create TLS context self.context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT) # update ssl.PROTOCOL_TLS_CLIENT default diff --git a/tests/system/test_asyncpg_iam_authn.py b/tests/system/test_asyncpg_iam_authn.py index 9ec8c79..3733348 100644 --- a/tests/system/test_asyncpg_iam_authn.py +++ b/tests/system/test_asyncpg_iam_authn.py @@ -16,6 +16,7 @@ import os from typing import Tuple +# [START alloydb_sqlalchemy_connect_async_connector_iam_authn] import asyncpg import sqlalchemy import sqlalchemy.ext.asyncio @@ -78,6 +79,9 @@ async def getconn() -> asyncpg.Connection: return engine, connector +# [END alloydb_sqlalchemy_connect_async_connector_iam_authn] + + async def test_asyncpg_iam_authn_time() -> None: """Basic test to get time from database.""" inst_uri = os.environ["ALLOYDB_INSTANCE_URI"] diff --git a/tests/system/test_asyncpg_public_ip.py b/tests/system/test_asyncpg_public_ip.py new file mode 100644 index 0000000..f4fe93c --- /dev/null +++ b/tests/system/test_asyncpg_public_ip.py @@ -0,0 +1,103 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Tuple + +# [START alloydb_sqlalchemy_connect_async_connector_public_ip] +import asyncpg +import pytest +import sqlalchemy +import sqlalchemy.ext.asyncio + +from google.cloud.alloydb.connector import AsyncConnector +from google.cloud.alloydb.connector import IPTypes + + +async def create_sqlalchemy_engine( + inst_uri: str, + user: str, + password: str, + db: str, +) -> Tuple[sqlalchemy.ext.asyncio.engine.AsyncEngine, AsyncConnector]: + """Creates a connection pool for an AlloyDB instance and returns the pool + and the connector. Callers are responsible for closing the pool and the + connector. + + A sample invocation looks like: + + engine, connector = await create_sqlalchemy_engine( + inst_uri, + user, + password, + db, + ) + async with engine.connect() as conn: + time = await conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + curr_time = time[0] + # do something with query result + await connector.close() + + Args: + instance_uri (str): + The instance URI specifies the instance relative to the project, + region, and cluster. For example: + "projects/my-project/locations/us-central1/clusters/my-cluster/instances/my-instance" + user (str): + The database user name, e.g., postgres + password (str): + The database user's password, e.g., secret-password + db_name (str): + The name of the database, e.g., mydb + """ + connector = AsyncConnector() + + async def getconn() -> asyncpg.Connection: + conn: asyncpg.Connection = await connector.connect( + inst_uri, + "asyncpg", + user=user, + password=password, + db=db, + ip_type=IPTypes.PUBLIC, + ) + return conn + + # create SQLAlchemy connection pool + engine = sqlalchemy.ext.asyncio.create_async_engine( + "postgresql+asyncpg://", + async_creator=getconn, + execution_options={"isolation_level": "AUTOCOMMIT"}, + ) + return engine, connector + + +# [END alloydb_sqlalchemy_connect_async_connector_public_ip] + + +@pytest.mark.asyncio +async def test_connection_with_asyncpg() -> None: + """Basic test to get time from database.""" + inst_uri = os.environ["ALLOYDB_INSTANCE_URI"] + user = os.environ["ALLOYDB_USER"] + password = os.environ["ALLOYDB_PASS"] + db = os.environ["ALLOYDB_DB"] + + pool, connector = await create_sqlalchemy_engine(inst_uri, user, password, db) + + async with pool.connect() as conn: + res = (await conn.execute(sqlalchemy.text("SELECT 1"))).fetchone() + assert res[0] == 1 + + await connector.close() diff --git a/tests/system/test_pg8000_connection.py b/tests/system/test_pg8000_connection.py index f6cd90c..0e6a6cd 100644 --- a/tests/system/test_pg8000_connection.py +++ b/tests/system/test_pg8000_connection.py @@ -14,6 +14,7 @@ from datetime import datetime import os +from typing import Tuple # [START alloydb_sqlalchemy_connect_connector] import pg8000 @@ -27,7 +28,7 @@ def create_sqlalchemy_engine( user: str, password: str, db: str, -) -> (sqlalchemy.engine.Engine, Connector): +) -> Tuple[sqlalchemy.engine.Engine, Connector]: """Creates a connection pool for an AlloyDB instance and returns the pool and the connector. Callers are responsible for closing the pool and the connector. diff --git a/tests/system/test_pg8000_iam_authn.py b/tests/system/test_pg8000_iam_authn.py index 0ed356e..ab600cd 100644 --- a/tests/system/test_pg8000_iam_authn.py +++ b/tests/system/test_pg8000_iam_authn.py @@ -16,6 +16,7 @@ import os from typing import Tuple +# [START alloydb_sqlalchemy_connect_connector_iam_authn] import pg8000 import sqlalchemy @@ -73,10 +74,12 @@ def getconn() -> pg8000.dbapi.Connection: "postgresql+pg8000://", creator=getconn, ) - engine.dialect.description_encoding = None return engine, connector +# [END alloydb_sqlalchemy_connect_connector_iam_authn] + + def test_pg8000_iam_authn_time() -> None: """Basic test to get time from database.""" inst_uri = os.environ["ALLOYDB_INSTANCE_URI"] diff --git a/tests/system/test_pg8000_public_ip.py b/tests/system/test_pg8000_public_ip.py new file mode 100644 index 0000000..a9782e1 --- /dev/null +++ b/tests/system/test_pg8000_public_ip.py @@ -0,0 +1,101 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +import os +from typing import Tuple + +# [START alloydb_sqlalchemy_connect_connector_public_ip] +import pg8000 +import sqlalchemy + +from google.cloud.alloydb.connector import Connector +from google.cloud.alloydb.connector import IPTypes + + +def create_sqlalchemy_engine( + inst_uri: str, + user: str, + password: str, + db: str, +) -> Tuple[sqlalchemy.engine.Engine, Connector]: + """Creates a connection pool for an AlloyDB instance and returns the pool + and the connector. Callers are responsible for closing the pool and the + connector. + + A sample invocation looks like: + + engine, connector = create_sqlalchemy_engine( + inst_uri, + user, + password, + db, + ) + with engine.connect() as conn: + time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + conn.commit() + curr_time = time[0] + # do something with query result + connector.close() + + Args: + instance_uri (str): + The instance URI specifies the instance relative to the project, + region, and cluster. For example: + "projects/my-project/locations/us-central1/clusters/my-cluster/instances/my-instance" + user (str): + The database user name, e.g., postgres + password (str): + The database user's password, e.g., secret-password + db_name (str): + The name of the database, e.g., mydb + """ + connector = Connector() + + def getconn() -> pg8000.dbapi.Connection: + conn: pg8000.dbapi.Connection = connector.connect( + inst_uri, + "pg8000", + user=user, + password=password, + db=db, + ip_type=IPTypes.PUBLIC, + ) + return conn + + # create SQLAlchemy connection pool + engine = sqlalchemy.create_engine( + "postgresql+pg8000://", + creator=getconn, + ) + return engine, connector + + +# [END alloydb_sqlalchemy_connect_connector_public_ip] + + +def test_pg8000_time() -> None: + """Basic test to get time from database.""" + inst_uri = os.environ["ALLOYDB_INSTANCE_URI"] + user = os.environ["ALLOYDB_USER"] + password = os.environ["ALLOYDB_PASS"] + db = os.environ["ALLOYDB_DB"] + + engine, connector = create_sqlalchemy_engine(inst_uri, user, password, db) + with engine.connect() as conn: + time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + conn.commit() + curr_time = time[0] + assert type(curr_time) is datetime + connector.close() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 8c6f097..d001a59 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -73,7 +73,7 @@ def start_proxy_server(instance: FakeInstance) -> None: conn.close() -@pytest.fixture(autouse=True, scope="session") +@pytest.fixture(scope="session") def proxy_server(fake_instance: FakeInstance) -> Generator: """Run local proxy server capable of performing metadata exchange""" thread = Thread(target=start_proxy_server, args=(fake_instance,), daemon=True) diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 1acd851..51a0bd2 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -18,7 +18,7 @@ from datetime import timezone import ssl import struct -from typing import Any, Callable, List, Optional, Tuple +from typing import Any, Callable, Dict, List, Optional, Tuple from cryptography import x509 from cryptography.hazmat.primitives import hashes @@ -109,7 +109,10 @@ def __init__( region: str = "test-region", cluster: str = "test-cluster", name: str = "test-instance", - ip_address: str = "127.0.0.1", + ip_addrs: Dict = { + "PRIVATE": "127.0.0.1", + "PUBLIC": "0.0.0.0", + }, server_name: str = "00000000-0000-0000-0000-000000000000.server.alloydb", cert_before: datetime = datetime.now(timezone.utc), cert_expiry: datetime = datetime.now(timezone.utc) + timedelta(hours=1), @@ -118,7 +121,7 @@ def __init__( self.region = region self.cluster = cluster self.name = name - self.ip_address = ip_address + self.ip_addrs = ip_addrs self.server_name = server_name self.cert_before = cert_before self.cert_expiry = cert_expiry @@ -163,7 +166,7 @@ def __init__( self._user_agent = f"test-user-agent+{driver}" async def _get_metadata(self, *args: Any, **kwargs: Any) -> str: - return self.instance.ip_address + return self.instance.ip_addrs async def _get_client_certificate( self, @@ -267,7 +270,7 @@ def __init__(self) -> None: self._close_called = False self._force_refresh_called = False - def connection_info(self) -> Tuple[str, Any]: + def connection_info(self, ip_type: Any) -> Tuple[str, Any]: f = asyncio.Future() f.set_result(("10.0.0.1", None)) return f diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 172f804..a47da1b 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -26,7 +26,16 @@ async def connectionInfo(request: Any) -> web.Response: response = { - "ipAddress": "127.0.0.1", + "ipAddress": "10.0.0.1", + "instanceUid": "123456789", + } + return web.Response(content_type="application/json", body=json.dumps(response)) + + +async def connectionInfoPublicIP(request: Any) -> web.Response: + response = { + "ipAddress": "10.0.0.1", + "publicIpAddress": "127.0.0.1", "instanceUid": "123456789", } return web.Response(content_type="application/json", body=json.dumps(response)) @@ -49,6 +58,8 @@ async def client(aiohttp_client: Any) -> Any: app = web.Application() metadata_uri = "/v1beta/projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance/connectionInfo" app.router.add_get(metadata_uri, connectionInfo) + metadata_public_ip_uri = "/v1beta/projects/test-project/locations/test-region/clusters/test-cluster/instances/public-instance/connectionInfo" + app.router.add_get(metadata_public_ip_uri, connectionInfoPublicIP) client_cert_uri = "/v1beta/projects/test-project/locations/test-region/clusters/test-cluster:generateClientCertificate" app.router.add_post(client_cert_uri, generateClientCertificate) return await aiohttp_client(app) @@ -60,13 +71,36 @@ async def test__get_metadata(client: Any, credentials: FakeCredentials) -> None: Test _get_metadata returns successfully. """ test_client = AlloyDBClient("", "", credentials, client) - ip_address = await test_client._get_metadata( + ip_addrs = await test_client._get_metadata( "test-project", "test-region", "test-cluster", "test-instance", ) - assert ip_address == "127.0.0.1" + assert ip_addrs == { + "PRIVATE": "10.0.0.1", + "PUBLIC": None, + } + + +@pytest.mark.asyncio +async def test__get_metadata_with_public_ip( + client: Any, credentials: FakeCredentials +) -> None: + """ + Test _get_metadata returns successfully with Public IP. + """ + test_client = AlloyDBClient("", "", credentials, client) + ip_addrs = await test_client._get_metadata( + "test-project", + "test-region", + "test-cluster", + "public-instance", + ) + assert ip_addrs == { + "PRIVATE": "10.0.0.1", + "PUBLIC": "127.0.0.1", + } @pytest.mark.asyncio diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 2705050..bf8ebe4 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -62,6 +62,7 @@ def test_Connector_close(credentials: FakeCredentials) -> None: assert thread.is_alive() is False +@pytest.mark.usefixtures("proxy_server") def test_connect(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -> None: """ Test that connector.connect returns connection object. diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 5ea3b9d..686675e 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -21,9 +21,11 @@ from mocks import FakeAlloyDBClient import pytest +from google.cloud.alloydb.connector.exceptions import IPTypeNotFoundError from google.cloud.alloydb.connector.exceptions import RefreshError from google.cloud.alloydb.connector.instance import _parse_instance_uri from google.cloud.alloydb.connector.instance import Instance +from google.cloud.alloydb.connector.instance import IPTypes from google.cloud.alloydb.connector.refresh import _is_valid from google.cloud.alloydb.connector.refresh import RefreshResult from google.cloud.alloydb.connector.utils import generate_keys @@ -132,12 +134,63 @@ async def test_perform_refresh() -> None: keys, ) refresh = await instance._perform_refresh() - assert refresh.instance_ip == "127.0.0.1" + assert refresh.ip_addrs == { + "PRIVATE": "127.0.0.1", + "PUBLIC": "0.0.0.0", + } assert refresh.expiration == client.instance.cert_expiry.replace(microsecond=0) # close instance await instance.close() +@pytest.mark.parametrize( + "ip_type, expected", + [ + ( + IPTypes.PRIVATE, + "127.0.0.1", + ), + ( + IPTypes.PUBLIC, + "0.0.0.0", + ), + ], +) +@pytest.mark.asyncio +async def test_connection_info(ip_type: IPTypes, expected: str) -> None: + """Test that connection_info returns proper ip address.""" + keys = asyncio.create_task(generate_keys()) + client = FakeAlloyDBClient() + instance = Instance( + "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", + client, + keys, + ) + ip_address, _ = await instance.connection_info(ip_type=ip_type) + assert ip_address == expected + # close instance + await instance.close() + + +@pytest.mark.asyncio +async def test_connection_info_IPTypeNotFoundError() -> None: + """Test that connection_info throws IPTypeNotFoundError""" + keys = asyncio.create_task(generate_keys()) + client = FakeAlloyDBClient() + # set ip_addrs to have no public IP + client.instance.ip_addrs = {"PRIVATE": "10.0.0.1"} + instance = Instance( + "projects/test-project/locations/test-region/clusters/test-cluster/instances/test-instance", + client, + keys, + ) + # check RefreshError is thrown + with pytest.raises(IPTypeNotFoundError): + await instance.connection_info(ip_type=IPTypes.PUBLIC) + # close instance + await instance.close() + + @pytest.mark.asyncio async def test_schedule_refresh_replaces_result() -> None: """ diff --git a/tests/unit/test_refresh.py b/tests/unit/test_refresh.py index 9c4ceee..2015218 100644 --- a/tests/unit/test_refresh.py +++ b/tests/unit/test_refresh.py @@ -79,7 +79,7 @@ def test_RefreshResult_init_(fake_instance: FakeInstance) -> None: "UTF-8" ) certs = (ca_cert, [client_cert, intermediate_cert, root_cert]) - refresh = RefreshResult(fake_instance.ip_address, key, certs) + refresh = RefreshResult(fake_instance.ip_addrs, key, certs) # verify TLS requirements assert refresh.context.minimum_version == ssl.TLSVersion.TLSv1_3 assert refresh.context.request_ssl is False