diff --git a/google/api/__init__.py b/google/api/__init__.py new file mode 100644 index 0000000..6d5e14b --- /dev/null +++ b/google/api/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/google/api/field_behavior_pb2.py b/google/api/field_behavior_pb2.py new file mode 100755 index 0000000..93cdf45 --- /dev/null +++ b/google/api/field_behavior_pb2.py @@ -0,0 +1,52 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# type: ignore +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: google/api/field_behavior.proto +# isort: skip_file +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b"\n\x1fgoogle/api/field_behavior.proto\x12\ngoogle.api\x1a google/protobuf/descriptor.proto*\xb6\x01\n\rFieldBehavior\x12\x1e\n\x1a\x46IELD_BEHAVIOR_UNSPECIFIED\x10\x00\x12\x0c\n\x08OPTIONAL\x10\x01\x12\x0c\n\x08REQUIRED\x10\x02\x12\x0f\n\x0bOUTPUT_ONLY\x10\x03\x12\x0e\n\nINPUT_ONLY\x10\x04\x12\r\n\tIMMUTABLE\x10\x05\x12\x12\n\x0eUNORDERED_LIST\x10\x06\x12\x15\n\x11NON_EMPTY_DEFAULT\x10\x07\x12\x0e\n\nIDENTIFIER\x10\x08:Q\n\x0e\x66ield_behavior\x12\x1d.google.protobuf.FieldOptions\x18\x9c\x08 \x03(\x0e\x32\x19.google.api.FieldBehaviorBp\n\x0e\x63om.google.apiB\x12\x46ieldBehaviorProtoP\x01ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\xa2\x02\x04GAPIb\x06proto3" +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages( + DESCRIPTOR, "google.api.field_behavior_pb2", _globals +) +if _descriptor._USE_C_DESCRIPTORS == False: + google_dot_protobuf_dot_descriptor__pb2.FieldOptions.RegisterExtension( + field_behavior + ) + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b"\n\016com.google.apiB\022FieldBehaviorProtoP\001ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\242\002\004GAPI" + _globals["_FIELDBEHAVIOR"]._serialized_start = 82 + _globals["_FIELDBEHAVIOR"]._serialized_end = 264 +# @@protoc_insertion_point(module_scope) diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py index 1c36801..38a292b 100644 --- a/google/cloud/alloydb/connector/async_connector.py +++ b/google/cloud/alloydb/connector/async_connector.py @@ -96,6 +96,7 @@ async def connect( self._alloydb_api_endpoint, self._quota_project, self._credentials, + driver=driver, ) # use existing connection info if possible diff --git a/google/cloud/alloydb/connector/client.py b/google/cloud/alloydb/connector/client.py index b4cf664..f6a8d3f 100644 --- a/google/cloud/alloydb/connector/client.py +++ b/google/cloud/alloydb/connector/client.py @@ -38,6 +38,7 @@ def __init__( quota_project: Optional[str], credentials: Credentials, client: Optional[aiohttp.ClientSession] = None, + driver: Optional[str] = None, ) -> None: """ Establish the client to be used for AlloyDB Admin API requests. @@ -55,10 +56,12 @@ def __init__( client (aiohttp.ClientSession): Async client used to make requests to AlloyDB Admin APIs. Optional, defaults to None and creates new client. + driver (str): Database driver to be used by the client. """ + user_agent = f"{USER_AGENT}+{driver}" if driver else USER_AGENT headers = { - "x-goog-api-client": USER_AGENT, - "User-Agent": USER_AGENT, + "x-goog-api-client": user_agent, + "User-Agent": user_agent, "Content-Type": "application/json", } if quota_project: @@ -67,6 +70,10 @@ def __init__( self._client = client if client else aiohttp.ClientSession(headers=headers) self._credentials = credentials self._alloydb_api_endpoint = alloydb_api_endpoint + # asyncpg does not currently support using metadata exchange + # only use metadata exchange for pg8000 driver + self._use_metadata = True if driver == "pg8000" else False + self._user_agent = user_agent async def _get_metadata( self, @@ -149,6 +156,7 @@ async def _get_client_certificate( data = { "publicKey": pub_key, "certDuration": "3600s", + "useMetadataExchange": self._use_metadata, } resp = await self._client.post( diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index d88750a..1a6d6da 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -16,6 +16,8 @@ import asyncio from functools import partial +import socket +import struct from threading import Thread from types import TracebackType from typing import Any, Dict, Optional, Type, TYPE_CHECKING @@ -27,10 +29,18 @@ from google.cloud.alloydb.connector.instance import Instance import google.cloud.alloydb.connector.pg8000 as pg8000 from google.cloud.alloydb.connector.utils import generate_keys +import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb if TYPE_CHECKING: + import ssl + from google.auth.credentials import Credentials +# the port the AlloyDB server-side proxy receives connections on +SERVER_PROXY_PORT = 5433 +# the maximum amount of time to wait before aborting a metadata exchange +IO_TIMEOUT = 30 + class Connector: """A class to configure and create connections to Cloud SQL instances. @@ -45,6 +55,7 @@ class Connector: Defaults to None, picking up project from environment. alloydb_api_endpoint (str): Base URL to use when calling the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com". + enable_iam_auth (bool): Enables automatic IAM database authentication. """ def __init__( @@ -52,6 +63,7 @@ def __init__( credentials: Optional[Credentials] = None, quota_project: Optional[str] = None, alloydb_api_endpoint: str = "https://alloydb.googleapis.com", + enable_iam_auth: bool = False, ) -> None: # create event loop and start it in background thread self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() @@ -61,6 +73,7 @@ def __init__( # initialize default params self._quota_project = quota_project self._alloydb_api_endpoint = alloydb_api_endpoint + self._enable_iam_auth = enable_iam_auth # initialize credentials scopes = ["https://www.googleapis.com/auth/cloud-platform"] if credentials: @@ -121,8 +134,12 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> if self._client is None: # lazy init client as it has to be initialized in async context self._client = AlloyDBClient( - self._alloydb_api_endpoint, self._quota_project, self._credentials + self._alloydb_api_endpoint, + self._quota_project, + self._credentials, + driver=driver, ) + enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) # use existing connection info if possible if instance_uri in self._instances: instance = self._instances[instance_uri] @@ -150,13 +167,120 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> # synchronous drivers are blocking and run using executor try: - connect_partial = partial(connector, ip_address, context, **kwargs) + metadata_partial = partial( + self.metadata_exchange, ip_address, context, enable_iam_auth, driver + ) + sock = await self._loop.run_in_executor(None, metadata_partial) + connect_partial = partial(connector, sock, **kwargs) return await self._loop.run_in_executor(None, connect_partial) except Exception: # we attempt a force refresh, then throw the error await instance.force_refresh() raise + def metadata_exchange( + self, ip_address: str, ctx: ssl.SSLContext, enable_iam_auth: bool, driver: str + ) -> ssl.SSLSocket: + """ + Sends metadata about the connection prior to the database + protocol taking over. + + The exchange consists of four steps: + + 1. Prepare a MetadataExchangeRequest including the IAM Principal's + OAuth2 token, the user agent, and the requested authentication type. + + 2. Write the size of the message as a big endian uint32 (4 bytes) to + the server followed by the serialized message. The length does not + include the initial four bytes. + + 3. Read a big endian uint32 (4 bytes) from the server. This is the + MetadataExchangeResponse message length and does not include the + initial four bytes. + + 4. Parse the response using the message length in step 3. If the + response is not OK, return the response's error. If there is no error, + the metadata exchange has succeeded and the connection is complete. + + Args: + ip_address (str): IP address of AlloyDB instance to connect to. + ctx (ssl.SSLContext): Context used to create a TLS connection + with AlloyDB instance ssl certificates. + enable_iam_auth (bool): Flag to enable IAM database authentication. + driver (str): A string representing the database driver to connect with. + Supported drivers are pg8000. + + Returns: + sock (ssl.SSLSocket): mTLS/SSL socket connected to AlloyDB Proxy server. + """ + # Create socket and wrap with SSL/TLS context + sock = ctx.wrap_socket( + socket.create_connection((ip_address, SERVER_PROXY_PORT)), + server_hostname=ip_address, + ) + # set auth type for metadata exchange + auth_type = connectorspb.MetadataExchangeRequest.DB_NATIVE + if enable_iam_auth: + auth_type = connectorspb.MetadataExchangeRequest.AUTO_IAM + + # form metadata exchange request + req = connectorspb.MetadataExchangeRequest( + user_agent=f"{self._client._user_agent}", # type: ignore + auth_type=auth_type, + oauth2_token=self._credentials.token, + ) + + # set I/O timeout + sock.settimeout(IO_TIMEOUT) + + # pack big-endian unsigned integer (4 bytes) + packed_len = struct.pack(">I", req.ByteSize()) + + # send metadata message length and request message + sock.sendall(packed_len + req.SerializeToString()) + + # form metadata exchange response + resp = connectorspb.MetadataExchangeResponse() + + # read metadata message length (4 bytes) + message_len_buffer_size = struct.Struct(">I").size + message_len_buffer = b"" + while message_len_buffer_size > 0: + chunk = sock.recv(message_len_buffer_size) + if not chunk: + raise RuntimeError( + "Connection closed while getting metadata exchange length!" + ) + message_len_buffer += chunk + message_len_buffer_size -= len(chunk) + + (message_len,) = struct.unpack(">I", message_len_buffer) + + # read metadata exchange message + buffer = b"" + while message_len > 0: + chunk = sock.recv(message_len) + if not chunk: + raise RuntimeError( + "Connection closed while performing metadata exchange!" + ) + buffer += chunk + message_len -= len(chunk) + + # parse metadata exchange response from buffer + resp.ParseFromString(buffer) + + # reset socket back to blocking mode + sock.setblocking(True) + + # validate metadata exchange response + if resp.response_code != connectorspb.MetadataExchangeResponse.OK: + raise ValueError( + f"Metadata Exchange request has failed with error: {resp.error}" + ) + + return sock + def __enter__(self) -> "Connector": """Enter context manager by returning Connector object""" return self diff --git a/google/cloud/alloydb/connector/pg8000.py b/google/cloud/alloydb/connector/pg8000.py index c624363..f17766c 100644 --- a/google/cloud/alloydb/connector/pg8000.py +++ b/google/cloud/alloydb/connector/pg8000.py @@ -11,49 +11,35 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import socket -import ssl from typing import Any, TYPE_CHECKING -SERVER_PROXY_PORT = 5433 - if TYPE_CHECKING: + import ssl + import pg8000 -def connect( - ip_address: str, ctx: ssl.SSLContext, **kwargs: Any -) -> "pg8000.dbapi.Connection": +def connect(sock: "ssl.SSLSocket", **kwargs: Any) -> "pg8000.dbapi.Connection": """Create a pg8000 DBAPI connection object. Args: - ip_address (str): IP address of AlloyDB instance to connect to. - ctx (ssl.SSLContext): Context used to create a TLS connection - with AlloyDB instance ssl certificates. + sock (ssl.SSLSocket): SSL/TLS secure socket stream connected to the + AlloyDB proxy server. Returns: pg8000.dbapi.Connection: A pg8000 Connection object for the AlloyDB instance. """ - # Connecting through pg8000 is done by passing in an SSL Context and setting the - # "request_ssl" attr to false. This works because when "request_ssl" is false, - # the driver skips the database level SSL/TLS exchange, but still uses the - # ssl_context (if it is not None) to create the connection. try: import pg8000 except ImportError: raise ImportError( 'Unable to import module "pg8000." Please install and try again.' ) - # Create socket and wrap with context. - sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, - ) user = kwargs.pop("user") db = kwargs.pop("db") - passwd = kwargs.pop("password") + passwd = kwargs.pop("password", None) return pg8000.dbapi.connect( user, database=db, diff --git a/google/cloud/alloydb_connectors_v1/proto/__init__.py b/google/cloud/alloydb_connectors_v1/proto/__init__.py new file mode 100644 index 0000000..6d5e14b --- /dev/null +++ b/google/cloud/alloydb_connectors_v1/proto/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/google/cloud/alloydb_connectors_v1/proto/resources_pb2.py b/google/cloud/alloydb_connectors_v1/proto/resources_pb2.py new file mode 100755 index 0000000..e1a111b --- /dev/null +++ b/google/cloud/alloydb_connectors_v1/proto/resources_pb2.py @@ -0,0 +1,61 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: google/cloud/alloydb_connectors_v1/proto/resources.proto +# isort: skip_file +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder + +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.api import field_behavior_pb2 as google_dot_api_dot_field__behavior__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n8google/cloud/alloydb_connectors_v1/proto/resources.proto\x12"google.cloud.alloydb.connectors.v1\x1a\x1fgoogle/api/field_behavior.proto"\xe6\x01\n\x17MetadataExchangeRequest\x12\x18\n\nuser_agent\x18\x01 \x01(\tB\x04\xe2\x41\x01\x01\x12W\n\tauth_type\x18\x02 \x01(\x0e\x32\x44.google.cloud.alloydb.connectors.v1.MetadataExchangeRequest.AuthType\x12\x14\n\x0coauth2_token\x18\x03 \x01(\t"B\n\x08\x41uthType\x12\x19\n\x15\x41UTH_TYPE_UNSPECIFIED\x10\x00\x12\r\n\tDB_NATIVE\x10\x01\x12\x0c\n\x08\x41UTO_IAM\x10\x02"\xd3\x01\n\x18MetadataExchangeResponse\x12`\n\rresponse_code\x18\x01 \x01(\x0e\x32I.google.cloud.alloydb.connectors.v1.MetadataExchangeResponse.ResponseCode\x12\x13\n\x05\x65rror\x18\x02 \x01(\tB\x04\xe2\x41\x01\x01"@\n\x0cResponseCode\x12\x1d\n\x19RESPONSE_CODE_UNSPECIFIED\x10\x00\x12\x06\n\x02OK\x10\x01\x12\t\n\x05\x45RROR\x10\x02\x42\xf5\x01\n&com.google.cloud.alloydb.connectors.v1B\x0eResourcesProtoP\x01ZFcloud.google.com/go/alloydb/connectors/apiv1/connectorspb;connectorspb\xaa\x02"Google.Cloud.AlloyDb.Connectors.V1\xca\x02"Google\\Cloud\\AlloyDb\\Connectors\\V1\xea\x02&Google::Cloud::AlloyDb::Connectors::V1b\x06proto3' +) + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages( + DESCRIPTOR, "google.cloud.alloydb_connectors_v1.proto.resources_pb2", _globals +) +if _descriptor._USE_C_DESCRIPTORS == False: + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n&com.google.cloud.alloydb.connectors.v1B\016ResourcesProtoP\001ZFcloud.google.com/go/alloydb/connectors/apiv1/connectorspb;connectorspb\252\002"Google.Cloud.AlloyDb.Connectors.V1\312\002"Google\\Cloud\\AlloyDb\\Connectors\\V1\352\002&Google::Cloud::AlloyDb::Connectors::V1' + _METADATAEXCHANGEREQUEST.fields_by_name["user_agent"]._options = None + _METADATAEXCHANGEREQUEST.fields_by_name[ + "user_agent" + ]._serialized_options = b"\342A\001\001" + _METADATAEXCHANGERESPONSE.fields_by_name["error"]._options = None + _METADATAEXCHANGERESPONSE.fields_by_name[ + "error" + ]._serialized_options = b"\342A\001\001" + _globals["_METADATAEXCHANGEREQUEST"]._serialized_start = 130 + _globals["_METADATAEXCHANGEREQUEST"]._serialized_end = 360 + _globals["_METADATAEXCHANGEREQUEST_AUTHTYPE"]._serialized_start = 294 + _globals["_METADATAEXCHANGEREQUEST_AUTHTYPE"]._serialized_end = 360 + _globals["_METADATAEXCHANGERESPONSE"]._serialized_start = 363 + _globals["_METADATAEXCHANGERESPONSE"]._serialized_end = 574 + _globals["_METADATAEXCHANGERESPONSE_RESPONSECODE"]._serialized_start = 510 + _globals["_METADATAEXCHANGERESPONSE_RESPONSECODE"]._serialized_end = 574 +# @@protoc_insertion_point(module_scope) diff --git a/google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi b/google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi new file mode 100644 index 0000000..3f9db56 --- /dev/null +++ b/google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi @@ -0,0 +1,66 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import ClassVar as _ClassVar +from typing import Optional as _Optional +from typing import Union as _Union + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper + +from google.api import field_behavior_pb2 as _field_behavior_pb2 + +DESCRIPTOR: _descriptor.FileDescriptor + +class MetadataExchangeRequest(_message.Message): + __slots__ = ["auth_type", "oauth2_token", "user_agent"] + + class AuthType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] # type: ignore + AUTH_TYPE_FIELD_NUMBER: _ClassVar[int] + AUTH_TYPE_UNSPECIFIED: MetadataExchangeRequest.AuthType + AUTO_IAM: MetadataExchangeRequest.AuthType + DB_NATIVE: MetadataExchangeRequest.AuthType + OAUTH2_TOKEN_FIELD_NUMBER: _ClassVar[int] + USER_AGENT_FIELD_NUMBER: _ClassVar[int] + auth_type: MetadataExchangeRequest.AuthType + oauth2_token: str + user_agent: str + def __init__( + self, + user_agent: _Optional[str] = ..., + auth_type: _Optional[_Union[MetadataExchangeRequest.AuthType, str]] = ..., + oauth2_token: _Optional[str] = ..., + ) -> None: ... + +class MetadataExchangeResponse(_message.Message): + __slots__ = ["error", "response_code"] + + class ResponseCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] # type: ignore + ERROR: MetadataExchangeResponse.ResponseCode + ERROR_FIELD_NUMBER: _ClassVar[int] + OK: MetadataExchangeResponse.ResponseCode + RESPONSE_CODE_FIELD_NUMBER: _ClassVar[int] + RESPONSE_CODE_UNSPECIFIED: MetadataExchangeResponse.ResponseCode + error: str + response_code: MetadataExchangeResponse.ResponseCode + def __init__( + self, + response_code: _Optional[ + _Union[MetadataExchangeResponse.ResponseCode, str] + ] = ..., + error: _Optional[str] = ..., + ) -> None: ... diff --git a/requirements.txt b/requirements.txt index 60e2307..665fe0f 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ aiohttp==3.9.1 cryptography==41.0.7 google-auth==2.26.2 requests==2.31.0 +protobuf==4.25.1 diff --git a/setup.py b/setup.py index 6e80a30..d6dcc82 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ "cryptography>=38.0.3", "requests", "google-auth", + "protobuf", ] package_root = os.path.abspath(os.path.dirname(__file__)) diff --git a/tests/system/test_pg8000_iam_authn.py b/tests/system/test_pg8000_iam_authn.py new file mode 100644 index 0000000..62400fe --- /dev/null +++ b/tests/system/test_pg8000_iam_authn.py @@ -0,0 +1,94 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +import os + +# [START alloydb_sqlalchemy_connect_connector] +import pg8000 +import sqlalchemy + +from google.cloud.alloydb.connector import Connector + + +def create_sqlalchemy_engine( + inst_uri: str, + user: str, + db: str, +) -> (sqlalchemy.engine.Engine, Connector): + """Creates a connection pool for an AlloyDB instance and returns the pool + and the connector. Callers are responsible for closing the pool and the + connector. + + A sample invocation looks like: + + engine, connector = create_sqlalchemy_engine( + inst_uri, + user, + db, + ) + with engine.connect() as conn: + time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + conn.commit() + curr_time = time[0] + # do something with query result + connector.close() + + Args: + instance_uri (str): + The instance URI specifies the instance relative to the project, + region, and cluster. For example: + "projects/my-project/locations/us-central1/clusters/my-cluster/instances/my-instance" + user (str): + The database user name, e.g., postgres + db_name (str): + The name of the database, e.g., mydb + """ + connector = Connector() + + def getconn() -> pg8000.dbapi.Connection: + conn: pg8000.dbapi.Connection = connector.connect( + inst_uri, + "pg8000", + user=user, + db=db, + enable_iam_auth=True, + ) + return conn + + # create SQLAlchemy connection pool + engine = sqlalchemy.create_engine( + "postgresql+pg8000://", + creator=getconn, + ) + engine.dialect.description_encoding = None + return engine, connector + + +# [END alloydb_sqlalchemy_connect_connector] + + +def test_pg8000_iam_authn_time() -> None: + """Basic test to get time from database.""" + inst_uri = os.environ["ALLOYDB_INSTANCE_URI"] + user = os.environ["ALLOYDB_IAM_USER"] + db = os.environ["ALLOYDB_DB"] + + engine, connector = create_sqlalchemy_engine(inst_uri, user, db) + with engine.connect() as conn: + time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + conn.commit() + curr_time = time[0] + assert type(curr_time) is datetime + connector.close() diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index d0c3c50..8c6f097 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -12,16 +12,71 @@ # See the License for the specific language governing permissions and # limitations under the License. +import socket +import ssl +from tempfile import TemporaryDirectory +from threading import Thread +from typing import Generator + +from mocks import FakeAlloyDBClient from mocks import FakeCredentials from mocks import FakeInstance +from mocks import metadata_exchange import pytest +from google.cloud.alloydb.connector.utils import _write_to_file + @pytest.fixture def credentials() -> FakeCredentials: return FakeCredentials() -@pytest.fixture +@pytest.fixture(scope="session") def fake_instance() -> FakeInstance: return FakeInstance() + + +@pytest.fixture +def fake_client(fake_instance: FakeInstance) -> FakeAlloyDBClient: + return FakeAlloyDBClient(fake_instance) + + +def start_proxy_server(instance: FakeInstance) -> None: + """Run local proxy server capable of performing metadata exchange""" + ip_address = "127.0.0.1" + port = 5433 + # create socket + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + # create SSL/TLS context + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.minimum_version = ssl.TLSVersion.TLSv1_3 + root, _, server = instance.get_pem_certs() + # tmpdir and its contents are automatically deleted after the CA cert + # and cert chain are loaded into the SSLcontext. The values + # need to be written to files in order to be loaded by the SSLContext + with TemporaryDirectory() as tmpdir: + _, cert_chain_filename, key_filename = _write_to_file( + tmpdir, server, [server, root], instance.server_key + ) + context.load_cert_chain(cert_chain_filename, key_filename) + # bind socket to AlloyDB proxy server port on localhost + sock.bind((ip_address, port)) + # listen for incoming connections + sock.listen(5) + + while True: + with context.wrap_socket(sock, server_side=True) as ssock: + conn, _ = ssock.accept() + metadata_exchange(conn) + conn.sendall(instance.name.encode("utf-8")) + conn.close() + + +@pytest.fixture(autouse=True, scope="session") +def proxy_server(fake_instance: FakeInstance) -> Generator: + """Run local proxy server capable of performing metadata exchange""" + thread = Thread(target=start_proxy_server, args=(fake_instance,), daemon=True) + thread.start() + yield thread + thread.join() diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index b7ff722..1acd851 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -16,6 +16,8 @@ from datetime import datetime from datetime import timedelta from datetime import timezone +import ssl +import struct from typing import Any, Callable, List, Optional, Tuple from cryptography import x509 @@ -24,6 +26,8 @@ from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID +import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb + class FakeCredentials: def __init__(self) -> None: @@ -119,10 +123,6 @@ def __init__( self.cert_before = cert_before self.cert_expiry = cert_expiry - def generate_certs(self) -> None: - """ - Build certs required for chain of trust with testing server. - """ # build root cert self.root_cert, self.root_key = generate_cert("root.alloydb") # create self signed root cert @@ -155,12 +155,15 @@ def get_pem_certs(self) -> Tuple[str, str, str]: class FakeAlloyDBClient: """Fake class for testing AlloyDBClient""" - def __init__(self) -> None: - self.instance = FakeInstance() + def __init__( + self, instance: Optional[FakeInstance] = None, driver: str = "pg8000" + ) -> None: + self.instance = FakeInstance() if instance is None else instance self.closed = False + self._user_agent = f"test-user-agent+{driver}" - async def _get_metadata(*args: Any, **kwargs: Any) -> str: - return "127.0.0.1" + async def _get_metadata(self, *args: Any, **kwargs: Any) -> str: + return self.instance.ip_address async def _get_client_certificate( self, @@ -169,8 +172,7 @@ async def _get_client_certificate( cluster: str, pub_key: str, ) -> Tuple[str, List[str]]: - self.instance.generate_certs() - root_cert, intermediate_cert, ca_cert = self.instance.get_pem_certs() + root_cert, intermediate_cert, server_cert = self.instance.get_pem_certs() # encode public key to bytes pub_key_bytes: rsa.RSAPublicKey = serialization.load_pem_public_key( pub_key.encode("UTF-8"), @@ -190,12 +192,74 @@ async def _get_client_certificate( client_cert = client_cert.public_bytes( encoding=serialization.Encoding.PEM ).decode("UTF-8") - return (ca_cert, [client_cert, intermediate_cert, root_cert]) + return (server_cert, [client_cert, intermediate_cert, root_cert]) async def close(self) -> None: self.closed = True +def metadata_exchange(sock: ssl.SSLSocket) -> None: + """ + Mimics server side metadata exchange behavior in four steps: + + 1. Read a big endian uint32 (4 bytes) from the client. This is the number of + bytes the message consumes. The length does not include the initial four + bytes. + + 2. Read the message from the client using the message length and serialize + it into a MetadataExchangeResponse message. + + The real server implementation will then validate the client has connection + permissions using the provided OAuth2 token based on the auth type. Here in + the test implementation, the server does nothing. + + 3. Prepare a response and write the size of the response as a big endian + uint32 (4 bytes) + + 4. Parse the response to bytes and write those to the client as well. + + Subsequent interactions with the test server use the database protocol. + """ + # read metadata message length (4 bytes) + message_len_buffer_size = struct.Struct("I").size + message_len_buffer = b"" + while message_len_buffer_size > 0: + chunk = sock.recv(message_len_buffer_size) + if not chunk: + raise RuntimeError( + "Connection closed while getting metadata exchange length!" + ) + message_len_buffer += chunk + message_len_buffer_size -= len(chunk) + + (message_len,) = struct.unpack(">I", message_len_buffer) + + # read metadata exchange message + buffer = b"" + while message_len > 0: + chunk = sock.recv(message_len) + if not chunk: + raise RuntimeError("Connection closed while performing metadata exchange!") + buffer += chunk + message_len -= len(chunk) + + # form metadata exchange request to be received from client + message = connectorspb.MetadataExchangeRequest() + # parse metadata exchange request from buffer + message.ParseFromString(buffer) + + # form metadata exchange response to send to client + resp = connectorspb.MetadataExchangeResponse( + response_code=connectorspb.MetadataExchangeResponse.OK + ) + + # pack big-endian unsigned integer (4 bytes) + resp_len = struct.pack(">I", resp.ByteSize()) + + # send metadata response length and response message + sock.sendall(resp_len + resp.SerializeToString()) + + class FakeConnectionInfo: """Fake connection info class that doesn't perform a refresh""" diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index fa686ed..fada5e4 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -13,7 +13,7 @@ # limitations under the License. import json -from typing import Any +from typing import Any, Optional from aiohttp import web from mocks import FakeCredentials @@ -102,3 +102,46 @@ async def test_AlloyDBClient_init_(credentials: FakeCredentials) -> None: assert client._client.headers["x-goog-user-project"] == "my-quota-project" # close client await client.close() + + +@pytest.mark.parametrize( + "driver", + [None, "pg8000", "asyncpg"], +) +@pytest.mark.asyncio +async def test_AlloyDBClient_user_agent( + driver: Optional[str], credentials: FakeCredentials +) -> None: + """ + Test to check whether the __init__ method of AlloyDBClient + properly sets user agent when passed a database driver. + """ + client = AlloyDBClient( + "www.test-endpoint.com", "my-quota-project", credentials, driver=driver + ) + if driver is None: + assert client._user_agent == f"alloydb-python-connector/{version}" + else: + assert client._user_agent == f"alloydb-python-connector/{version}+{driver}" + # close client + await client.close() + + +@pytest.mark.parametrize( + "driver, expected", + [(None, False), ("pg8000", True), ("asyncpg", False)], +) +@pytest.mark.asyncio +async def test_AlloyDBClient_use_metadata( + driver: Optional[str], expected: bool, credentials: FakeCredentials +) -> None: + """ + Test to check whether the __init__ method of AlloyDBClient + properly sets use_metadata. + """ + client = AlloyDBClient( + "www.test-endpoint.com", "my-quota-project", credentials, driver=driver + ) + assert client._use_metadata == expected + # close client + await client.close() diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 293917e..2705050 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -62,11 +62,11 @@ def test_Connector_close(credentials: FakeCredentials) -> None: assert thread.is_alive() is False -def test_connect(credentials: FakeCredentials) -> None: +def test_connect(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -> None: """ Test that connector.connect returns connection object. """ - client = FakeAlloyDBClient() + client = fake_client with Connector(credentials) as connector: connector._client = client # patch db connection creation diff --git a/tests/unit/test_refresh.py b/tests/unit/test_refresh.py index 55761ce..9c4ceee 100644 --- a/tests/unit/test_refresh.py +++ b/tests/unit/test_refresh.py @@ -62,7 +62,6 @@ def test_RefreshResult_init_(fake_instance: FakeInstance) -> None: can correctly initialize TLS context. """ key = rsa.generate_private_key(public_exponent=65537, key_size=2048) - fake_instance.generate_certs() root_cert, intermediate_cert, ca_cert = fake_instance.get_pem_certs() # build client cert client_cert = (