From d9557e6cf1b01a99c1f891017bb22f49527650fb Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 21 Dec 2023 16:26:37 +0000 Subject: [PATCH 01/28] chore: first attempt --- google/cloud/alloydb/connector/client.py | 1 + google/cloud/alloydb/connector/connector.py | 30 ++++++++++++ google/cloud/alloydb/connector/pg8000.py | 16 ++----- .../alloydb_connectors_v1/proto/__init__.py | 0 .../proto/resources_pb2.py | 38 +++++++++++++++ .../proto/resources_pb2.pyi | 48 +++++++++++++++++++ 6 files changed, 121 insertions(+), 12 deletions(-) create mode 100644 google/cloud/alloydb_connectors_v1/proto/__init__.py create mode 100755 google/cloud/alloydb_connectors_v1/proto/resources_pb2.py create mode 100644 google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi diff --git a/google/cloud/alloydb/connector/client.py b/google/cloud/alloydb/connector/client.py index a39447c..bb7e221 100644 --- a/google/cloud/alloydb/connector/client.py +++ b/google/cloud/alloydb/connector/client.py @@ -67,6 +67,7 @@ def __init__( self._client = client if client else aiohttp.ClientSession(headers=headers) self._credentials = credentials self._alloydb_api_endpoint = alloydb_api_endpoint + self._user_agent = USER_AGENT async def _get_metadata( self, diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index b77deee..fe009c5 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -16,6 +16,7 @@ import asyncio from functools import partial +import socket from threading import Thread from types import TracebackType from typing import Any, Dict, Optional, Type, TYPE_CHECKING @@ -26,10 +27,14 @@ from google.cloud.alloydb.connector.instance import Instance import google.cloud.alloydb.connector.pg8000 as pg8000 from google.cloud.alloydb.connector.utils import generate_keys +import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb if TYPE_CHECKING: + import ssl from google.auth.credentials import Credentials +SERVER_PROXY_PORT = 5433 + class Connector: """A class to configure and create connections to Cloud SQL instances. @@ -44,6 +49,7 @@ class Connector: Defaults to None, picking up project from environment. alloydb_api_endpoint (str): Base URL to use when calling the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com". + enable_iam_auth (bool): Enables automatic IAM database authentication. """ def __init__( @@ -51,6 +57,7 @@ def __init__( credentials: Optional[Credentials] = None, quota_project: Optional[str] = None, alloydb_api_endpoint: str = "https://alloydb.googleapis.com", + enable_iam_auth: bool = False, ) -> None: # create event loop and start it in background thread self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() @@ -60,6 +67,7 @@ def __init__( # initialize default params self._quota_project = quota_project self._alloydb_api_endpoint = alloydb_api_endpoint + self._enable_iam_auth = enable_iam_auth # initialize credentials scopes = ["https://www.googleapis.com/auth/cloud-platform"] if credentials: @@ -122,6 +130,7 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> self._client = AlloyDBClient( self._alloydb_api_endpoint, self._quota_project, self._credentials ) + enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) # use existing connection info if possible if instance_uri in self._instances: instance = self._instances[instance_uri] @@ -156,6 +165,27 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> await instance.force_refresh() raise + def metadata_exchange( + self, ip_address: str, ctx: ssl.SSLContext, enable_iam_auth: bool, driver: str + ): + # Create socket and wrap with SSL/TLS context + sock = ctx.wrap_socket( + socket.create_connection((ip_address, SERVER_PROXY_PORT)), + server_hostname=ip_address, + ) + # set auth type for metadata exchange + if enable_iam_auth: + auth_type = connectorspb.MetadataExchangeRequest.AUTO_IAM + else: + auth_type = connectorspb.MetadataExchangeRequest.DB_NATIVE + + # form metadata exchange request + req = connectorspb.MetadataExchangeRequest( + user_agent=self._client._user_agent + f"+{driver}", + auth_type=auth_type, + oauth2_token=self._credentials.token, + ) + def __enter__(self) -> "Connector": """Enter context manager by returning Connector object""" return self diff --git a/google/cloud/alloydb/connector/pg8000.py b/google/cloud/alloydb/connector/pg8000.py index c624363..88fc1ee 100644 --- a/google/cloud/alloydb/connector/pg8000.py +++ b/google/cloud/alloydb/connector/pg8000.py @@ -11,25 +11,22 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -import socket -import ssl from typing import Any, TYPE_CHECKING -SERVER_PROXY_PORT = 5433 - if TYPE_CHECKING: import pg8000 + import ssl def connect( - ip_address: str, ctx: ssl.SSLContext, **kwargs: Any + ip_address: str, sock: ssl.SSLSocket, **kwargs: Any ) -> "pg8000.dbapi.Connection": """Create a pg8000 DBAPI connection object. Args: ip_address (str): IP address of AlloyDB instance to connect to. - ctx (ssl.SSLContext): Context used to create a TLS connection - with AlloyDB instance ssl certificates. + sock (ssl.SSLSocket): SSL/TLS secure socket stream connected to the + AlloyDB proxy server. Returns: pg8000.dbapi.Connection: A pg8000 Connection object for @@ -45,11 +42,6 @@ def connect( raise ImportError( 'Unable to import module "pg8000." Please install and try again.' ) - # Create socket and wrap with context. - sock = ctx.wrap_socket( - socket.create_connection((ip_address, SERVER_PROXY_PORT)), - server_hostname=ip_address, - ) user = kwargs.pop("user") db = kwargs.pop("db") diff --git a/google/cloud/alloydb_connectors_v1/proto/__init__.py b/google/cloud/alloydb_connectors_v1/proto/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/google/cloud/alloydb_connectors_v1/proto/resources_pb2.py b/google/cloud/alloydb_connectors_v1/proto/resources_pb2.py new file mode 100755 index 0000000..87b813c --- /dev/null +++ b/google/cloud/alloydb_connectors_v1/proto/resources_pb2.py @@ -0,0 +1,38 @@ +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: google/cloud/alloydb_connectors_v1/proto/resources.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.api import field_behavior_pb2 as google_dot_api_dot_field__behavior__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n8google/cloud/alloydb_connectors_v1/proto/resources.proto\x12\"google.cloud.alloydb.connectors.v1\x1a\x1fgoogle/api/field_behavior.proto\"\xe6\x01\n\x17MetadataExchangeRequest\x12\x18\n\nuser_agent\x18\x01 \x01(\tB\x04\xe2\x41\x01\x01\x12W\n\tauth_type\x18\x02 \x01(\x0e\x32\x44.google.cloud.alloydb.connectors.v1.MetadataExchangeRequest.AuthType\x12\x14\n\x0coauth2_token\x18\x03 \x01(\t\"B\n\x08\x41uthType\x12\x19\n\x15\x41UTH_TYPE_UNSPECIFIED\x10\x00\x12\r\n\tDB_NATIVE\x10\x01\x12\x0c\n\x08\x41UTO_IAM\x10\x02\"\xd3\x01\n\x18MetadataExchangeResponse\x12`\n\rresponse_code\x18\x01 \x01(\x0e\x32I.google.cloud.alloydb.connectors.v1.MetadataExchangeResponse.ResponseCode\x12\x13\n\x05\x65rror\x18\x02 \x01(\tB\x04\xe2\x41\x01\x01\"@\n\x0cResponseCode\x12\x1d\n\x19RESPONSE_CODE_UNSPECIFIED\x10\x00\x12\x06\n\x02OK\x10\x01\x12\t\n\x05\x45RROR\x10\x02\x42\xf5\x01\n&com.google.cloud.alloydb.connectors.v1B\x0eResourcesProtoP\x01ZFcloud.google.com/go/alloydb/connectors/apiv1/connectorspb;connectorspb\xaa\x02\"Google.Cloud.AlloyDb.Connectors.V1\xca\x02\"Google\\Cloud\\AlloyDb\\Connectors\\V1\xea\x02&Google::Cloud::AlloyDb::Connectors::V1b\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.cloud.alloydb_connectors_v1.proto.resources_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n&com.google.cloud.alloydb.connectors.v1B\016ResourcesProtoP\001ZFcloud.google.com/go/alloydb/connectors/apiv1/connectorspb;connectorspb\252\002\"Google.Cloud.AlloyDb.Connectors.V1\312\002\"Google\\Cloud\\AlloyDb\\Connectors\\V1\352\002&Google::Cloud::AlloyDb::Connectors::V1' + _METADATAEXCHANGEREQUEST.fields_by_name['user_agent']._options = None + _METADATAEXCHANGEREQUEST.fields_by_name['user_agent']._serialized_options = b'\342A\001\001' + _METADATAEXCHANGERESPONSE.fields_by_name['error']._options = None + _METADATAEXCHANGERESPONSE.fields_by_name['error']._serialized_options = b'\342A\001\001' + _globals['_METADATAEXCHANGEREQUEST']._serialized_start=130 + _globals['_METADATAEXCHANGEREQUEST']._serialized_end=360 + _globals['_METADATAEXCHANGEREQUEST_AUTHTYPE']._serialized_start=294 + _globals['_METADATAEXCHANGEREQUEST_AUTHTYPE']._serialized_end=360 + _globals['_METADATAEXCHANGERESPONSE']._serialized_start=363 + _globals['_METADATAEXCHANGERESPONSE']._serialized_end=574 + _globals['_METADATAEXCHANGERESPONSE_RESPONSECODE']._serialized_start=510 + _globals['_METADATAEXCHANGERESPONSE_RESPONSECODE']._serialized_end=574 +# @@protoc_insertion_point(module_scope) diff --git a/google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi b/google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi new file mode 100644 index 0000000..650e339 --- /dev/null +++ b/google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi @@ -0,0 +1,48 @@ +from google.api import field_behavior_pb2 as _field_behavior_pb2 +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union + +DESCRIPTOR: _descriptor.FileDescriptor + +class MetadataExchangeRequest(_message.Message): + __slots__ = ["auth_type", "oauth2_token", "user_agent"] + + class AuthType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] + AUTH_TYPE_FIELD_NUMBER: _ClassVar[int] + AUTH_TYPE_UNSPECIFIED: MetadataExchangeRequest.AuthType + AUTO_IAM: MetadataExchangeRequest.AuthType + DB_NATIVE: MetadataExchangeRequest.AuthType + OAUTH2_TOKEN_FIELD_NUMBER: _ClassVar[int] + USER_AGENT_FIELD_NUMBER: _ClassVar[int] + auth_type: MetadataExchangeRequest.AuthType + oauth2_token: str + user_agent: str + def __init__( + self, + user_agent: _Optional[str] = ..., + auth_type: _Optional[_Union[MetadataExchangeRequest.AuthType, str]] = ..., + oauth2_token: _Optional[str] = ..., + ) -> None: ... + +class MetadataExchangeResponse(_message.Message): + __slots__ = ["error", "response_code"] + + class ResponseCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] + ERROR: MetadataExchangeResponse.ResponseCode + ERROR_FIELD_NUMBER: _ClassVar[int] + OK: MetadataExchangeResponse.ResponseCode + RESPONSE_CODE_FIELD_NUMBER: _ClassVar[int] + RESPONSE_CODE_UNSPECIFIED: MetadataExchangeResponse.ResponseCode + error: str + response_code: MetadataExchangeResponse.ResponseCode + def __init__( + self, + response_code: _Optional[ + _Union[MetadataExchangeResponse.ResponseCode, str] + ] = ..., + error: _Optional[str] = ..., + ) -> None: ... From 43fd2813a0c6f700661a0cbfe2a97a5613588652 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 27 Dec 2023 23:26:57 +0000 Subject: [PATCH 02/28] chore: add socket read and write --- google/cloud/alloydb/connector/connector.py | 38 ++++++++++++++++++++- google/cloud/alloydb/connector/pg8000.py | 7 +--- 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index fe009c5..6eaa769 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -17,6 +17,7 @@ import asyncio from functools import partial import socket +import struct from threading import Thread from types import TracebackType from typing import Any, Dict, Optional, Type, TYPE_CHECKING @@ -158,7 +159,9 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> # synchronous drivers are blocking and run using executor try: - connect_partial = partial(connector, ip_address, context, **kwargs) + metadata_partial = partial(self.metadata_exchange, ip_address, context, enable_iam_auth, driver) + sock = await self._loop.run_in_executor(None, metadata_partial) + connect_partial = partial(connector, sock, **kwargs) return await self._loop.run_in_executor(None, connect_partial) except Exception: # we attempt a force refresh, then throw the error @@ -186,6 +189,39 @@ def metadata_exchange( oauth2_token=self._credentials.token, ) + # set I/O timeout + sock.settimeout(30) + + # pack big-endian unsigned integer + packed_len = struct.pack('>I', req.ByteSize()) + + # send message length + sock.sendall(packed_len) + # send message + sock.sendall(req.SerializeToString()) + + # form metadata exchange response + resp = connectorspb.MetadataExchangeResponse() + + # read response + message_len_buffer_size = struct.Struct("I").size + buffer = b'' + while message_len_buffer_size > 0: + chunk = sock.recv(message_len_buffer_size) + if not chunk: + raise RuntimeError('connection closed before chunk was read') + buffer += chunk + message_len_buffer_size -= len(chunk) + (message_len,) = struct.unpack('>I', buffer) + + resp.ParseFromString(sock.recv(message_len)) + + if resp.response_code != connectorspb.MetadataExchangeResponse.OK: + raise ValueError("Metadata Exchange request has failed") + + return sock + + def __enter__(self) -> "Connector": """Enter context manager by returning Connector object""" return self diff --git a/google/cloud/alloydb/connector/pg8000.py b/google/cloud/alloydb/connector/pg8000.py index 88fc1ee..44ab939 100644 --- a/google/cloud/alloydb/connector/pg8000.py +++ b/google/cloud/alloydb/connector/pg8000.py @@ -19,12 +19,11 @@ def connect( - ip_address: str, sock: ssl.SSLSocket, **kwargs: Any + sock: ssl.SSLSocket, **kwargs: Any ) -> "pg8000.dbapi.Connection": """Create a pg8000 DBAPI connection object. Args: - ip_address (str): IP address of AlloyDB instance to connect to. sock (ssl.SSLSocket): SSL/TLS secure socket stream connected to the AlloyDB proxy server. @@ -32,10 +31,6 @@ def connect( pg8000.dbapi.Connection: A pg8000 Connection object for the AlloyDB instance. """ - # Connecting through pg8000 is done by passing in an SSL Context and setting the - # "request_ssl" attr to false. This works because when "request_ssl" is false, - # the driver skips the database level SSL/TLS exchange, but still uses the - # ssl_context (if it is not None) to create the connection. try: import pg8000 except ImportError: From 96669ebf7c764bf5f2ca2968fdb6e0f3a6723791 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 27 Dec 2023 23:31:26 +0000 Subject: [PATCH 03/28] chore: type hint ssl --- google/cloud/alloydb/connector/pg8000.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/alloydb/connector/pg8000.py b/google/cloud/alloydb/connector/pg8000.py index 44ab939..7edf9b4 100644 --- a/google/cloud/alloydb/connector/pg8000.py +++ b/google/cloud/alloydb/connector/pg8000.py @@ -19,7 +19,7 @@ def connect( - sock: ssl.SSLSocket, **kwargs: Any + sock: "ssl.SSLSocket", **kwargs: Any ) -> "pg8000.dbapi.Connection": """Create a pg8000 DBAPI connection object. From 65d276db8c58aa36967c6e732e8794d45cd8157a Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 27 Dec 2023 23:35:55 +0000 Subject: [PATCH 04/28] chore: add protofbuf type --- requirements.txt | 1 + setup.py | 1 + 2 files changed, 2 insertions(+) diff --git a/requirements.txt b/requirements.txt index b4ae376..c0e1ae8 100644 --- a/requirements.txt +++ b/requirements.txt @@ -2,3 +2,4 @@ aiohttp==3.9.1 cryptography==41.0.7 google-auth==2.25.2 requests==2.31.0 +protobuf==4.25.1 diff --git a/setup.py b/setup.py index 67bf80e..30c350f 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ "cryptography>=38.0.3", "requests", "google-auth", + "protobuf", ] package_root = os.path.abspath(os.path.dirname(__file__)) From 032f9e6411f38935a5eddb711323ad0000727599 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 28 Dec 2023 16:27:54 +0000 Subject: [PATCH 05/28] chore: add google/api dep --- google/api/__init__.py | 0 google/api/field_behavior_pb2.py | 43 +++++++++++++++++ .../proto/resources_pb2.py | 14 ++++++ .../proto/resources_pb2.pyi | 48 ------------------- 4 files changed, 57 insertions(+), 48 deletions(-) create mode 100644 google/api/__init__.py create mode 100755 google/api/field_behavior_pb2.py delete mode 100644 google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi diff --git a/google/api/__init__.py b/google/api/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/google/api/field_behavior_pb2.py b/google/api/field_behavior_pb2.py new file mode 100755 index 0000000..9451503 --- /dev/null +++ b/google/api/field_behavior_pb2.py @@ -0,0 +1,43 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# -*- coding: utf-8 -*- +# Generated by the protocol buffer compiler. DO NOT EDIT! +# source: google/api/field_behavior.proto +"""Generated protocol buffer code.""" +from google.protobuf import descriptor as _descriptor +from google.protobuf import descriptor_pool as _descriptor_pool +from google.protobuf import symbol_database as _symbol_database +from google.protobuf.internal import builder as _builder +# @@protoc_insertion_point(imports) + +_sym_db = _symbol_database.Default() + + +from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2 + + +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1fgoogle/api/field_behavior.proto\x12\ngoogle.api\x1a google/protobuf/descriptor.proto*\xb6\x01\n\rFieldBehavior\x12\x1e\n\x1a\x46IELD_BEHAVIOR_UNSPECIFIED\x10\x00\x12\x0c\n\x08OPTIONAL\x10\x01\x12\x0c\n\x08REQUIRED\x10\x02\x12\x0f\n\x0bOUTPUT_ONLY\x10\x03\x12\x0e\n\nINPUT_ONLY\x10\x04\x12\r\n\tIMMUTABLE\x10\x05\x12\x12\n\x0eUNORDERED_LIST\x10\x06\x12\x15\n\x11NON_EMPTY_DEFAULT\x10\x07\x12\x0e\n\nIDENTIFIER\x10\x08:Q\n\x0e\x66ield_behavior\x12\x1d.google.protobuf.FieldOptions\x18\x9c\x08 \x03(\x0e\x32\x19.google.api.FieldBehaviorBp\n\x0e\x63om.google.apiB\x12\x46ieldBehaviorProtoP\x01ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\xa2\x02\x04GAPIb\x06proto3') + +_globals = globals() +_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) +_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.api.field_behavior_pb2', _globals) +if _descriptor._USE_C_DESCRIPTORS == False: + google_dot_protobuf_dot_descriptor__pb2.FieldOptions.RegisterExtension(field_behavior) + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n\016com.google.apiB\022FieldBehaviorProtoP\001ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\242\002\004GAPI' + _globals['_FIELDBEHAVIOR']._serialized_start=82 + _globals['_FIELDBEHAVIOR']._serialized_end=264 +# @@protoc_insertion_point(module_scope) diff --git a/google/cloud/alloydb_connectors_v1/proto/resources_pb2.py b/google/cloud/alloydb_connectors_v1/proto/resources_pb2.py index 87b813c..36e5278 100755 --- a/google/cloud/alloydb_connectors_v1/proto/resources_pb2.py +++ b/google/cloud/alloydb_connectors_v1/proto/resources_pb2.py @@ -1,3 +1,17 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: google/cloud/alloydb_connectors_v1/proto/resources.proto diff --git a/google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi b/google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi deleted file mode 100644 index 650e339..0000000 --- a/google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi +++ /dev/null @@ -1,48 +0,0 @@ -from google.api import field_behavior_pb2 as _field_behavior_pb2 -from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper -from google.protobuf import descriptor as _descriptor -from google.protobuf import message as _message -from typing import ClassVar as _ClassVar, Optional as _Optional, Union as _Union - -DESCRIPTOR: _descriptor.FileDescriptor - -class MetadataExchangeRequest(_message.Message): - __slots__ = ["auth_type", "oauth2_token", "user_agent"] - - class AuthType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] - AUTH_TYPE_FIELD_NUMBER: _ClassVar[int] - AUTH_TYPE_UNSPECIFIED: MetadataExchangeRequest.AuthType - AUTO_IAM: MetadataExchangeRequest.AuthType - DB_NATIVE: MetadataExchangeRequest.AuthType - OAUTH2_TOKEN_FIELD_NUMBER: _ClassVar[int] - USER_AGENT_FIELD_NUMBER: _ClassVar[int] - auth_type: MetadataExchangeRequest.AuthType - oauth2_token: str - user_agent: str - def __init__( - self, - user_agent: _Optional[str] = ..., - auth_type: _Optional[_Union[MetadataExchangeRequest.AuthType, str]] = ..., - oauth2_token: _Optional[str] = ..., - ) -> None: ... - -class MetadataExchangeResponse(_message.Message): - __slots__ = ["error", "response_code"] - - class ResponseCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] - ERROR: MetadataExchangeResponse.ResponseCode - ERROR_FIELD_NUMBER: _ClassVar[int] - OK: MetadataExchangeResponse.ResponseCode - RESPONSE_CODE_FIELD_NUMBER: _ClassVar[int] - RESPONSE_CODE_UNSPECIFIED: MetadataExchangeResponse.ResponseCode - error: str - response_code: MetadataExchangeResponse.ResponseCode - def __init__( - self, - response_code: _Optional[ - _Union[MetadataExchangeResponse.ResponseCode, str] - ] = ..., - error: _Optional[str] = ..., - ) -> None: ... From 0700411a8b68f38e400ae11752eb4e3b6da1fcc4 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 28 Dec 2023 16:30:26 +0000 Subject: [PATCH 06/28] chore: add header files --- google/api/__init__.py | 13 +++++++++++++ .../cloud/alloydb_connectors_v1/proto/__init__.py | 13 +++++++++++++ 2 files changed, 26 insertions(+) diff --git a/google/api/__init__.py b/google/api/__init__.py index e69de29..1dc90d1 100644 --- a/google/api/__init__.py +++ b/google/api/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. diff --git a/google/cloud/alloydb_connectors_v1/proto/__init__.py b/google/cloud/alloydb_connectors_v1/proto/__init__.py index e69de29..1dc90d1 100644 --- a/google/cloud/alloydb_connectors_v1/proto/__init__.py +++ b/google/cloud/alloydb_connectors_v1/proto/__init__.py @@ -0,0 +1,13 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. From fc847fa749ae73493fd4d81f16f77250ef28a098 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 28 Dec 2023 17:58:44 +0000 Subject: [PATCH 07/28] chore: fix read of response --- google/cloud/alloydb/connector/connector.py | 31 ++++++++++++++------- 1 file changed, 21 insertions(+), 10 deletions(-) diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 6eaa769..0b4e8da 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -159,7 +159,9 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> # synchronous drivers are blocking and run using executor try: - metadata_partial = partial(self.metadata_exchange, ip_address, context, enable_iam_auth, driver) + metadata_partial = partial( + self.metadata_exchange, ip_address, context, enable_iam_auth, driver + ) sock = await self._loop.run_in_executor(None, metadata_partial) connect_partial = partial(connector, sock, **kwargs) return await self._loop.run_in_executor(None, connect_partial) @@ -193,7 +195,7 @@ def metadata_exchange( sock.settimeout(30) # pack big-endian unsigned integer - packed_len = struct.pack('>I', req.ByteSize()) + packed_len = struct.pack(">I", req.ByteSize()) # send message length sock.sendall(packed_len) @@ -203,24 +205,33 @@ def metadata_exchange( # form metadata exchange response resp = connectorspb.MetadataExchangeResponse() - # read response + # read message length message_len_buffer_size = struct.Struct("I").size - buffer = b'' + message_len_buffer = b"" while message_len_buffer_size > 0: chunk = sock.recv(message_len_buffer_size) if not chunk: - raise RuntimeError('connection closed before chunk was read') - buffer += chunk + raise RuntimeError("connection closed before chunk was read") + message_len_buffer += chunk message_len_buffer_size -= len(chunk) - (message_len,) = struct.unpack('>I', buffer) - resp.ParseFromString(sock.recv(message_len)) + (message_len,) = struct.unpack(">I", message_len_buffer) + + # read message + buffer = b"" + while message_len > 0: + chunk = sock.recv(message_len) + if not chunk: + raise RuntimeError("connection closed before chunk was read") + buffer += chunk + size -= len(chunk) + # parse mdx resp from buffer + resp.ParseFromString(buffer) if resp.response_code != connectorspb.MetadataExchangeResponse.OK: raise ValueError("Metadata Exchange request has failed") - - return sock + return sock def __enter__(self) -> "Connector": """Enter context manager by returning Connector object""" From 996fef6c06a3860e01117a4ad2d00857d723f502 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Thu, 28 Dec 2023 18:02:54 +0000 Subject: [PATCH 08/28] chore: fix message size --- google/cloud/alloydb/connector/connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 0b4e8da..8be8595 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -224,7 +224,7 @@ def metadata_exchange( if not chunk: raise RuntimeError("connection closed before chunk was read") buffer += chunk - size -= len(chunk) + message_len -= len(chunk) # parse mdx resp from buffer resp.ParseFromString(buffer) From 658ae5aa7ea27cd5b872550bc5562aa00e39573b Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 1 Jan 2024 21:31:11 +0000 Subject: [PATCH 09/28] chore: set useMetadataExchange to True --- google/cloud/alloydb/connector/client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/google/cloud/alloydb/connector/client.py b/google/cloud/alloydb/connector/client.py index bb7e221..2da7da9 100644 --- a/google/cloud/alloydb/connector/client.py +++ b/google/cloud/alloydb/connector/client.py @@ -150,6 +150,7 @@ async def _get_client_certificate( data = { "publicKey": pub_key, "certDuration": "3600s", + "useMetadataExchange": True, } resp = await self._client.post( From 1460ba049560827a311359b9a59bb3164b99b642 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 1 Jan 2024 21:57:32 +0000 Subject: [PATCH 10/28] chore: lint and headers --- google/api/__init__.py | 2 +- google/api/field_behavior_pb2.py | 25 +++++++---- google/cloud/alloydb/connector/connector.py | 24 ++++++---- google/cloud/alloydb/connector/pg8000.py | 4 +- .../alloydb_connectors_v1/proto/__init__.py | 2 +- .../proto/resources_pb2.py | 44 +++++++++++-------- 6 files changed, 60 insertions(+), 41 deletions(-) diff --git a/google/api/__init__.py b/google/api/__init__.py index 1dc90d1..6d5e14b 100644 --- a/google/api/__init__.py +++ b/google/api/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/api/field_behavior_pb2.py b/google/api/field_behavior_pb2.py index 9451503..ba7817c 100755 --- a/google/api/field_behavior_pb2.py +++ b/google/api/field_behavior_pb2.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -28,16 +29,22 @@ from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x1fgoogle/api/field_behavior.proto\x12\ngoogle.api\x1a google/protobuf/descriptor.proto*\xb6\x01\n\rFieldBehavior\x12\x1e\n\x1a\x46IELD_BEHAVIOR_UNSPECIFIED\x10\x00\x12\x0c\n\x08OPTIONAL\x10\x01\x12\x0c\n\x08REQUIRED\x10\x02\x12\x0f\n\x0bOUTPUT_ONLY\x10\x03\x12\x0e\n\nINPUT_ONLY\x10\x04\x12\r\n\tIMMUTABLE\x10\x05\x12\x12\n\x0eUNORDERED_LIST\x10\x06\x12\x15\n\x11NON_EMPTY_DEFAULT\x10\x07\x12\x0e\n\nIDENTIFIER\x10\x08:Q\n\x0e\x66ield_behavior\x12\x1d.google.protobuf.FieldOptions\x18\x9c\x08 \x03(\x0e\x32\x19.google.api.FieldBehaviorBp\n\x0e\x63om.google.apiB\x12\x46ieldBehaviorProtoP\x01ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\xa2\x02\x04GAPIb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b"\n\x1fgoogle/api/field_behavior.proto\x12\ngoogle.api\x1a google/protobuf/descriptor.proto*\xb6\x01\n\rFieldBehavior\x12\x1e\n\x1a\x46IELD_BEHAVIOR_UNSPECIFIED\x10\x00\x12\x0c\n\x08OPTIONAL\x10\x01\x12\x0c\n\x08REQUIRED\x10\x02\x12\x0f\n\x0bOUTPUT_ONLY\x10\x03\x12\x0e\n\nINPUT_ONLY\x10\x04\x12\r\n\tIMMUTABLE\x10\x05\x12\x12\n\x0eUNORDERED_LIST\x10\x06\x12\x15\n\x11NON_EMPTY_DEFAULT\x10\x07\x12\x0e\n\nIDENTIFIER\x10\x08:Q\n\x0e\x66ield_behavior\x12\x1d.google.protobuf.FieldOptions\x18\x9c\x08 \x03(\x0e\x32\x19.google.api.FieldBehaviorBp\n\x0e\x63om.google.apiB\x12\x46ieldBehaviorProtoP\x01ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\xa2\x02\x04GAPIb\x06proto3" +) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.api.field_behavior_pb2', _globals) +_builder.BuildTopDescriptorsAndMessages( + DESCRIPTOR, "google.api.field_behavior_pb2", _globals +) if _descriptor._USE_C_DESCRIPTORS == False: - google_dot_protobuf_dot_descriptor__pb2.FieldOptions.RegisterExtension(field_behavior) - - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'\n\016com.google.apiB\022FieldBehaviorProtoP\001ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\242\002\004GAPI' - _globals['_FIELDBEHAVIOR']._serialized_start=82 - _globals['_FIELDBEHAVIOR']._serialized_end=264 + google_dot_protobuf_dot_descriptor__pb2.FieldOptions.RegisterExtension( + field_behavior + ) + + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b"\n\016com.google.apiB\022FieldBehaviorProtoP\001ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\242\002\004GAPI" + _globals["_FIELDBEHAVIOR"]._serialized_start = 82 + _globals["_FIELDBEHAVIOR"]._serialized_end = 264 # @@protoc_insertion_point(module_scope) diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 8be8595..b67ef8c 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -194,42 +194,48 @@ def metadata_exchange( # set I/O timeout sock.settimeout(30) - # pack big-endian unsigned integer + # pack big-endian unsigned integer (4 bytes) packed_len = struct.pack(">I", req.ByteSize()) - # send message length + # send metadata message length sock.sendall(packed_len) - # send message + # send metadata request message sock.sendall(req.SerializeToString()) # form metadata exchange response resp = connectorspb.MetadataExchangeResponse() - # read message length + # read metadata message length (4 bytes) message_len_buffer_size = struct.Struct("I").size message_len_buffer = b"" while message_len_buffer_size > 0: chunk = sock.recv(message_len_buffer_size) if not chunk: - raise RuntimeError("connection closed before chunk was read") + raise RuntimeError( + "Connection closed while getting metadata exchange length!" + ) message_len_buffer += chunk message_len_buffer_size -= len(chunk) (message_len,) = struct.unpack(">I", message_len_buffer) - # read message + # read metadata exchange message buffer = b"" while message_len > 0: chunk = sock.recv(message_len) if not chunk: - raise RuntimeError("connection closed before chunk was read") + raise RuntimeError( + "Connection closed while performing metadata exchange!" + ) buffer += chunk message_len -= len(chunk) - # parse mdx resp from buffer + + # parse metadata exchange response from buffer resp.ParseFromString(buffer) + # validate metadata exchange response if resp.response_code != connectorspb.MetadataExchangeResponse.OK: - raise ValueError("Metadata Exchange request has failed") + raise ValueError("Metadata Exchange request has failed!") return sock diff --git a/google/cloud/alloydb/connector/pg8000.py b/google/cloud/alloydb/connector/pg8000.py index 7edf9b4..d140263 100644 --- a/google/cloud/alloydb/connector/pg8000.py +++ b/google/cloud/alloydb/connector/pg8000.py @@ -18,9 +18,7 @@ import ssl -def connect( - sock: "ssl.SSLSocket", **kwargs: Any -) -> "pg8000.dbapi.Connection": +def connect(sock: "ssl.SSLSocket", **kwargs: Any) -> "pg8000.dbapi.Connection": """Create a pg8000 DBAPI connection object. Args: diff --git a/google/cloud/alloydb_connectors_v1/proto/__init__.py b/google/cloud/alloydb_connectors_v1/proto/__init__.py index 1dc90d1..6d5e14b 100644 --- a/google/cloud/alloydb_connectors_v1/proto/__init__.py +++ b/google/cloud/alloydb_connectors_v1/proto/__init__.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/google/cloud/alloydb_connectors_v1/proto/resources_pb2.py b/google/cloud/alloydb_connectors_v1/proto/resources_pb2.py index 36e5278..33295d8 100755 --- a/google/cloud/alloydb_connectors_v1/proto/resources_pb2.py +++ b/google/cloud/alloydb_connectors_v1/proto/resources_pb2.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ from google.protobuf import descriptor_pool as _descriptor_pool from google.protobuf import symbol_database as _symbol_database from google.protobuf.internal import builder as _builder + # @@protoc_insertion_point(imports) _sym_db = _symbol_database.Default() @@ -28,25 +29,32 @@ from google.api import field_behavior_pb2 as google_dot_api_dot_field__behavior__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n8google/cloud/alloydb_connectors_v1/proto/resources.proto\x12\"google.cloud.alloydb.connectors.v1\x1a\x1fgoogle/api/field_behavior.proto\"\xe6\x01\n\x17MetadataExchangeRequest\x12\x18\n\nuser_agent\x18\x01 \x01(\tB\x04\xe2\x41\x01\x01\x12W\n\tauth_type\x18\x02 \x01(\x0e\x32\x44.google.cloud.alloydb.connectors.v1.MetadataExchangeRequest.AuthType\x12\x14\n\x0coauth2_token\x18\x03 \x01(\t\"B\n\x08\x41uthType\x12\x19\n\x15\x41UTH_TYPE_UNSPECIFIED\x10\x00\x12\r\n\tDB_NATIVE\x10\x01\x12\x0c\n\x08\x41UTO_IAM\x10\x02\"\xd3\x01\n\x18MetadataExchangeResponse\x12`\n\rresponse_code\x18\x01 \x01(\x0e\x32I.google.cloud.alloydb.connectors.v1.MetadataExchangeResponse.ResponseCode\x12\x13\n\x05\x65rror\x18\x02 \x01(\tB\x04\xe2\x41\x01\x01\"@\n\x0cResponseCode\x12\x1d\n\x19RESPONSE_CODE_UNSPECIFIED\x10\x00\x12\x06\n\x02OK\x10\x01\x12\t\n\x05\x45RROR\x10\x02\x42\xf5\x01\n&com.google.cloud.alloydb.connectors.v1B\x0eResourcesProtoP\x01ZFcloud.google.com/go/alloydb/connectors/apiv1/connectorspb;connectorspb\xaa\x02\"Google.Cloud.AlloyDb.Connectors.V1\xca\x02\"Google\\Cloud\\AlloyDb\\Connectors\\V1\xea\x02&Google::Cloud::AlloyDb::Connectors::V1b\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile( + b'\n8google/cloud/alloydb_connectors_v1/proto/resources.proto\x12"google.cloud.alloydb.connectors.v1\x1a\x1fgoogle/api/field_behavior.proto"\xe6\x01\n\x17MetadataExchangeRequest\x12\x18\n\nuser_agent\x18\x01 \x01(\tB\x04\xe2\x41\x01\x01\x12W\n\tauth_type\x18\x02 \x01(\x0e\x32\x44.google.cloud.alloydb.connectors.v1.MetadataExchangeRequest.AuthType\x12\x14\n\x0coauth2_token\x18\x03 \x01(\t"B\n\x08\x41uthType\x12\x19\n\x15\x41UTH_TYPE_UNSPECIFIED\x10\x00\x12\r\n\tDB_NATIVE\x10\x01\x12\x0c\n\x08\x41UTO_IAM\x10\x02"\xd3\x01\n\x18MetadataExchangeResponse\x12`\n\rresponse_code\x18\x01 \x01(\x0e\x32I.google.cloud.alloydb.connectors.v1.MetadataExchangeResponse.ResponseCode\x12\x13\n\x05\x65rror\x18\x02 \x01(\tB\x04\xe2\x41\x01\x01"@\n\x0cResponseCode\x12\x1d\n\x19RESPONSE_CODE_UNSPECIFIED\x10\x00\x12\x06\n\x02OK\x10\x01\x12\t\n\x05\x45RROR\x10\x02\x42\xf5\x01\n&com.google.cloud.alloydb.connectors.v1B\x0eResourcesProtoP\x01ZFcloud.google.com/go/alloydb/connectors/apiv1/connectorspb;connectorspb\xaa\x02"Google.Cloud.AlloyDb.Connectors.V1\xca\x02"Google\\Cloud\\AlloyDb\\Connectors\\V1\xea\x02&Google::Cloud::AlloyDb::Connectors::V1b\x06proto3' +) _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) -_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'google.cloud.alloydb_connectors_v1.proto.resources_pb2', _globals) +_builder.BuildTopDescriptorsAndMessages( + DESCRIPTOR, "google.cloud.alloydb_connectors_v1.proto.resources_pb2", _globals +) if _descriptor._USE_C_DESCRIPTORS == False: - - DESCRIPTOR._options = None - DESCRIPTOR._serialized_options = b'\n&com.google.cloud.alloydb.connectors.v1B\016ResourcesProtoP\001ZFcloud.google.com/go/alloydb/connectors/apiv1/connectorspb;connectorspb\252\002\"Google.Cloud.AlloyDb.Connectors.V1\312\002\"Google\\Cloud\\AlloyDb\\Connectors\\V1\352\002&Google::Cloud::AlloyDb::Connectors::V1' - _METADATAEXCHANGEREQUEST.fields_by_name['user_agent']._options = None - _METADATAEXCHANGEREQUEST.fields_by_name['user_agent']._serialized_options = b'\342A\001\001' - _METADATAEXCHANGERESPONSE.fields_by_name['error']._options = None - _METADATAEXCHANGERESPONSE.fields_by_name['error']._serialized_options = b'\342A\001\001' - _globals['_METADATAEXCHANGEREQUEST']._serialized_start=130 - _globals['_METADATAEXCHANGEREQUEST']._serialized_end=360 - _globals['_METADATAEXCHANGEREQUEST_AUTHTYPE']._serialized_start=294 - _globals['_METADATAEXCHANGEREQUEST_AUTHTYPE']._serialized_end=360 - _globals['_METADATAEXCHANGERESPONSE']._serialized_start=363 - _globals['_METADATAEXCHANGERESPONSE']._serialized_end=574 - _globals['_METADATAEXCHANGERESPONSE_RESPONSECODE']._serialized_start=510 - _globals['_METADATAEXCHANGERESPONSE_RESPONSECODE']._serialized_end=574 + DESCRIPTOR._options = None + DESCRIPTOR._serialized_options = b'\n&com.google.cloud.alloydb.connectors.v1B\016ResourcesProtoP\001ZFcloud.google.com/go/alloydb/connectors/apiv1/connectorspb;connectorspb\252\002"Google.Cloud.AlloyDb.Connectors.V1\312\002"Google\\Cloud\\AlloyDb\\Connectors\\V1\352\002&Google::Cloud::AlloyDb::Connectors::V1' + _METADATAEXCHANGEREQUEST.fields_by_name["user_agent"]._options = None + _METADATAEXCHANGEREQUEST.fields_by_name[ + "user_agent" + ]._serialized_options = b"\342A\001\001" + _METADATAEXCHANGERESPONSE.fields_by_name["error"]._options = None + _METADATAEXCHANGERESPONSE.fields_by_name[ + "error" + ]._serialized_options = b"\342A\001\001" + _globals["_METADATAEXCHANGEREQUEST"]._serialized_start = 130 + _globals["_METADATAEXCHANGEREQUEST"]._serialized_end = 360 + _globals["_METADATAEXCHANGEREQUEST_AUTHTYPE"]._serialized_start = 294 + _globals["_METADATAEXCHANGEREQUEST_AUTHTYPE"]._serialized_end = 360 + _globals["_METADATAEXCHANGERESPONSE"]._serialized_start = 363 + _globals["_METADATAEXCHANGERESPONSE"]._serialized_end = 574 + _globals["_METADATAEXCHANGERESPONSE_RESPONSECODE"]._serialized_start = 510 + _globals["_METADATAEXCHANGERESPONSE_RESPONSECODE"]._serialized_end = 574 # @@protoc_insertion_point(module_scope) From 2b44528caf078b7b0f83efc9e6ddad67656ac1a1 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 1 Jan 2024 22:21:39 +0000 Subject: [PATCH 11/28] chore: lint --- google/api/field_behavior_pb2.py | 1 + google/cloud/alloydb/connector/connector.py | 4 ++-- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/google/api/field_behavior_pb2.py b/google/api/field_behavior_pb2.py index ba7817c..15538ab 100755 --- a/google/api/field_behavior_pb2.py +++ b/google/api/field_behavior_pb2.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. +# type: ignore # -*- coding: utf-8 -*- # Generated by the protocol buffer compiler. DO NOT EDIT! # source: google/api/field_behavior.proto diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index b67ef8c..de131da 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -172,7 +172,7 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> def metadata_exchange( self, ip_address: str, ctx: ssl.SSLContext, enable_iam_auth: bool, driver: str - ): + ) -> ssl.SSLSocket: # Create socket and wrap with SSL/TLS context sock = ctx.wrap_socket( socket.create_connection((ip_address, SERVER_PROXY_PORT)), @@ -186,7 +186,7 @@ def metadata_exchange( # form metadata exchange request req = connectorspb.MetadataExchangeRequest( - user_agent=self._client._user_agent + f"+{driver}", + user_agent=self._client._user_agent + f"+{driver}", # type: ignore auth_type=auth_type, oauth2_token=self._credentials.token, ) From 9e03e773c341d8abf1b5536c824ec0ed5cb5f2dd Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 1 Jan 2024 22:30:11 +0000 Subject: [PATCH 12/28] chore: add IAM authn test --- google/cloud/alloydb/connector/pg8000.py | 2 +- tests/system/test_pg8000_iam_authn.py | 96 ++++++++++++++++++++++++ 2 files changed, 97 insertions(+), 1 deletion(-) create mode 100644 tests/system/test_pg8000_iam_authn.py diff --git a/google/cloud/alloydb/connector/pg8000.py b/google/cloud/alloydb/connector/pg8000.py index d140263..745f342 100644 --- a/google/cloud/alloydb/connector/pg8000.py +++ b/google/cloud/alloydb/connector/pg8000.py @@ -38,7 +38,7 @@ def connect(sock: "ssl.SSLSocket", **kwargs: Any) -> "pg8000.dbapi.Connection": user = kwargs.pop("user") db = kwargs.pop("db") - passwd = kwargs.pop("password") + passwd = kwargs.pop("password", None) return pg8000.dbapi.connect( user, database=db, diff --git a/tests/system/test_pg8000_iam_authn.py b/tests/system/test_pg8000_iam_authn.py new file mode 100644 index 0000000..8f4a7b4 --- /dev/null +++ b/tests/system/test_pg8000_iam_authn.py @@ -0,0 +1,96 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from datetime import datetime +import os + +# [START alloydb_sqlalchemy_connect_connector] +import pg8000 +import sqlalchemy + +from google.cloud.alloydb.connector import Connector + + +def create_sqlalchemy_engine( + inst_uri: str, + user: str, + db: str, +) -> (sqlalchemy.engine.Engine, Connector): + """Creates a connection pool for an AlloyDB instance and returns the pool + and the connector. Callers are responsible for closing the pool and the + connector. + + A sample invocation looks like: + + engine, connector = create_sqlalchemy_engine( + inst_uri, + user, + db, + ) + with engine.connect() as conn: + time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + conn.commit() + curr_time = time[0] + # do something with query result + connector.close() + + Args: + instance_uri (str): + The instance URI specifies the instance relative to the project, + region, and cluster. For example: + "projects/my-project/locations/us-central1/clusters/my-cluster/instances/my-instance" + user (str): + The database user name, e.g., postgres + password (str): + The database user's password, e.g., secret-password + db_name (str): + The name of the database, e.g., mydb + """ + connector = Connector() + + def getconn() -> pg8000.dbapi.Connection: + conn: pg8000.dbapi.Connection = connector.connect( + inst_uri, + "pg8000", + user=user, + db=db, + enable_iam_authn=True, + ) + return conn + + # create SQLAlchemy connection pool + engine = sqlalchemy.create_engine( + "postgresql+pg8000://", + creator=getconn, + ) + engine.dialect.description_encoding = None + return engine, connector + + +# [END alloydb_sqlalchemy_connect_connector] + + +def test_pg8000_time() -> None: + """Basic test to get time from database.""" + inst_uri = os.environ["ALLOYDB_INSTANCE_URI"] + user = os.environ["ALLOYDB_IAM_USER"] + db = os.environ["ALLOYDB_DB"] + + engine, connector = create_sqlalchemy_engine(inst_uri, user, db) + with engine.connect() as conn: + time = conn.execute(sqlalchemy.text("SELECT NOW()")).fetchone() + conn.commit() + curr_time = time[0] + assert type(curr_time) is datetime + connector.close() From ffbed6989774a9ed9527481289b14c77a4406ae8 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 1 Jan 2024 22:32:59 +0000 Subject: [PATCH 13/28] chore: update doc strings --- tests/system/test_pg8000_iam_authn.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/system/test_pg8000_iam_authn.py b/tests/system/test_pg8000_iam_authn.py index 8f4a7b4..0725ee2 100644 --- a/tests/system/test_pg8000_iam_authn.py +++ b/tests/system/test_pg8000_iam_authn.py @@ -1,4 +1,4 @@ -# Copyright 2023 Google LLC +# Copyright 2024 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -52,8 +52,6 @@ def create_sqlalchemy_engine( "projects/my-project/locations/us-central1/clusters/my-cluster/instances/my-instance" user (str): The database user name, e.g., postgres - password (str): - The database user's password, e.g., secret-password db_name (str): The name of the database, e.g., mydb """ @@ -81,7 +79,7 @@ def getconn() -> pg8000.dbapi.Connection: # [END alloydb_sqlalchemy_connect_connector] -def test_pg8000_time() -> None: +def test_pg8000_iam_authn_time() -> None: """Basic test to get time from database.""" inst_uri = os.environ["ALLOYDB_INSTANCE_URI"] user = os.environ["ALLOYDB_IAM_USER"] From d064f941f79cbddd67e916c4ff71ab070a403573 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 1 Jan 2024 22:36:28 +0000 Subject: [PATCH 14/28] chore: fix iam authn flag --- tests/system/test_pg8000_iam_authn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/system/test_pg8000_iam_authn.py b/tests/system/test_pg8000_iam_authn.py index 0725ee2..62400fe 100644 --- a/tests/system/test_pg8000_iam_authn.py +++ b/tests/system/test_pg8000_iam_authn.py @@ -63,7 +63,7 @@ def getconn() -> pg8000.dbapi.Connection: "pg8000", user=user, db=db, - enable_iam_authn=True, + enable_iam_auth=True, ) return conn From 46d693806ebd67d5fc7c5920c25f70e64a85913f Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Tue, 2 Jan 2024 01:22:55 +0000 Subject: [PATCH 15/28] chore: add docstring for metdata_exchange --- google/cloud/alloydb/connector/connector.py | 35 +++++++++++++++++++-- 1 file changed, 33 insertions(+), 2 deletions(-) diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index de131da..ea66c70 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -173,16 +173,47 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> def metadata_exchange( self, ip_address: str, ctx: ssl.SSLContext, enable_iam_auth: bool, driver: str ) -> ssl.SSLSocket: + """ + Sends metadata about the connection prior to the database + protocol taking over. + + The exchange consists of four steps: + + 1. Prepare a MetadataExchangeRequest including the IAM Principal's + OAuth2 token, the user agent, and the requested authentication type. + + 2. Write the size of the message as a big endian uint32 (4 bytes) to + the server followed by the serialized message. The length does not + include the initial four bytes. + + 3. Read a big endian uint32 (4 bytes) from the server. This is the + MetadataExchangeResponse message length and does not include the + initial four bytes. + + 4. Parse the response using the message length in step 3. If the + response is not OK, return the response's error. If there is no error, + the metadata exchange has succeeded and the connection is complete. + + Args: + ip_address (str): IP address of AlloyDB instance to connect to. + ctx (ssl.SSLContext): Context used to create a TLS connection + with AlloyDB instance ssl certificates. + enable_iam_auth (bool): Flag to enable IAM database authentication. + driver (str): A string representing the database driver to connect with. + Supported drivers are pg8000. + + Returns: + sock (ssl.SSLSocket): mTLS/SSL socket connected to AlloyDB Proxy server. + """ # Create socket and wrap with SSL/TLS context sock = ctx.wrap_socket( socket.create_connection((ip_address, SERVER_PROXY_PORT)), server_hostname=ip_address, ) # set auth type for metadata exchange + auth_type = connectorspb.MetadataExchangeRequest.DB_NATIVE if enable_iam_auth: auth_type = connectorspb.MetadataExchangeRequest.AUTO_IAM - else: - auth_type = connectorspb.MetadataExchangeRequest.DB_NATIVE # form metadata exchange request req = connectorspb.MetadataExchangeRequest( From f0e205e68e1de25bc125065986fb46fda59d0150 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Wed, 10 Jan 2024 21:04:55 +0000 Subject: [PATCH 16/28] chore: first pass at local proxy server for testing --- google/cloud/alloydb/connector/connector.py | 2 +- tests/unit/conftest.py | 45 +++++++++++++- tests/unit/mocks.py | 68 +++++++++++++++++++++ 3 files changed, 112 insertions(+), 3 deletions(-) diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index ea66c70..446ae20 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -217,7 +217,7 @@ def metadata_exchange( # form metadata exchange request req = connectorspb.MetadataExchangeRequest( - user_agent=self._client._user_agent + f"+{driver}", # type: ignore + user_agent=f"{self._client._user_agent}+{driver}", # type: ignore auth_type=auth_type, oauth2_token=self._credentials.token, ) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index c47035c..4e171ee 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -15,8 +15,14 @@ from mocks import ( FakeCredentials, FakeInstance, + metadata_exchange, ) import pytest +import socket +import ssl +from tempfile import TemporaryDirectory + +from google.cloud.alloydb.connector.utils import _write_to_file @pytest.fixture @@ -24,6 +30,41 @@ def credentials() -> FakeCredentials: return FakeCredentials() -@pytest.fixture +@pytest.fixture(scope="session") def fake_instance() -> FakeInstance: - return FakeInstance() + instance = FakeInstance() + instance.generate_certs() + return instance + + +@pytest.fixture(scope="session") +def proxy_server(fake_instance: FakeInstance) -> None: + """Run local proxy server capable of performing metadata exchange""" + ip_address = "127.0.0.1" + port = 5433 + # create socket + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + # create SSL/TLS context + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + + root, intermediate, server = fake_instance.get_pem_certs() + # tmpdir and its contents are automatically deleted after the CA cert + # and cert chain are loaded into the SSLcontext. The values + # need to be written to files in order to be loaded by the SSLContext + with TemporaryDirectory() as tmpdir: + ca_filename, _, key_filename = _write_to_file( + tmpdir, server, [root, intermediate], fake_instance.server_key + ) + ssock: ssl.SSLSocket = context.wrap_socket( + sock, server_side=True, certfile=ca_filename, keyfile=key_filename + ) + # bind socket to AlloyDB proxy server port on localhost + ssock.bind((ip_address, port)) + # listen for incoming connections + ssock.listen(5) + + while True: + conn, _ = ssock.accept() + metadata_exchange(conn) + conn.sendall(fake_instance.name.encode("utf-8")) + conn.close() diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index d100305..ec500b1 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -14,12 +14,16 @@ from datetime import datetime, timedelta, timezone from typing import Any, Callable, List, Optional, Tuple +import ssl +import struct from cryptography import x509 from cryptography.hazmat.primitives import hashes, serialization from cryptography.hazmat.primitives.asymmetric import rsa from cryptography.x509.oid import NameOID +import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb + class FakeCredentials: def __init__(self) -> None: @@ -189,3 +193,67 @@ async def _get_client_certificate( async def close(self) -> None: pass + + +def metadata_exchange(sock: ssl.SSLSocket) -> None: + """ + Mimics server side metadata exchange behavior in four steps: + + 1. Read a big endian uint32 (4 bytes) from the client. This is the number of + bytes the message consumes. The length does not include the initial four + bytes. + + 2. Read the message from the client using the message length and serialize + it into a MetadataExchangeResponse message. + + The real server implementation will then validate the client has connection + permissions using the provided OAuth2 token based on the auth type. Here in + the test implementation, the server does nothing. + + 3. Prepare a response and write the size of the response as a big endian + uint32 (4 bytes) + + 4. Parse the response to bytes and write those to the client as well. + + Subsequent interactions with the test server use the database protocol. + """ + # read metadata message length (4 bytes) + message_len_buffer_size = struct.Struct("I").size + message_len_buffer = b"" + while message_len_buffer_size > 0: + chunk = sock.recv(message_len_buffer_size) + if not chunk: + raise RuntimeError( + "Connection closed while getting metadata exchange length!" + ) + message_len_buffer += chunk + message_len_buffer_size -= len(chunk) + + (message_len,) = struct.unpack(">I", message_len_buffer) + + # read metadata exchange message + buffer = b"" + while message_len > 0: + chunk = sock.recv(message_len) + if not chunk: + raise RuntimeError("Connection closed while performing metadata exchange!") + buffer += chunk + message_len -= len(chunk) + + # form metadata exchange request to be received from client + message = connectorspb.MetadataExchangeRequest() + # parse metadata exchange request from buffer + message.ParseFromString(buffer) + + # form metadata exchange response to send to client + resp = connectorspb.MetadataExchangeResponse( + response_code=connectorspb.MetadataExchangeResponse.OK + ) + + # pack big-endian unsigned integer (4 bytes) + resp_len = struct.pack(">I", resp.ByteSize()) + + # send metadata response message length + sock.sendall(resp_len) + # send metadata request response message + sock.sendall(resp.SerializeToString()) From 856d0ab17485f2f7e00f902d6d5e092c88e4a5dc Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 12 Jan 2024 14:29:00 +0000 Subject: [PATCH 17/28] chore: first pass at local proxy server --- tests/unit/conftest.py | 70 +++++++++++++++++++++--------------- tests/unit/mocks.py | 11 ++---- tests/unit/test_connector.py | 2 +- tests/unit/test_refresh.py | 1 - 4 files changed, 46 insertions(+), 38 deletions(-) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 4e171ee..f268e5f 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -13,6 +13,7 @@ # limitations under the License. from mocks import ( + FakeAlloyDBClient, FakeCredentials, FakeInstance, metadata_exchange, @@ -21,6 +22,8 @@ import socket import ssl from tempfile import TemporaryDirectory +from threading import Thread +from typing import Generator from google.cloud.alloydb.connector.utils import _write_to_file @@ -32,39 +35,50 @@ def credentials() -> FakeCredentials: @pytest.fixture(scope="session") def fake_instance() -> FakeInstance: - instance = FakeInstance() - instance.generate_certs() - return instance + return FakeInstance() -@pytest.fixture(scope="session") -def proxy_server(fake_instance: FakeInstance) -> None: +@pytest.fixture +def fake_client(fake_instance: FakeInstance) -> FakeAlloyDBClient: + return FakeAlloyDBClient(fake_instance) + + +def start_proxy_server(instance: FakeInstance) -> None: """Run local proxy server capable of performing metadata exchange""" ip_address = "127.0.0.1" port = 5433 # create socket - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - # create SSL/TLS context - context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: + # create SSL/TLS context + context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + root, intermediate, server = instance.get_pem_certs() + # tmpdir and its contents are automatically deleted after the CA cert + # and cert chain are loaded into the SSLcontext. The values + # need to be written to files in order to be loaded by the SSLContext + with TemporaryDirectory() as tmpdir: + ca_filename, cert_chain_filename, key_filename = _write_to_file( + tmpdir, server, [server, root], instance.server_key + ) + context.load_cert_chain(cert_chain_filename, key_filename) + # bind socket to AlloyDB proxy server port on localhost + sock.bind((ip_address, port)) + # listen for incoming connections + sock.listen(5) + + while True: + print("WAITING!!!!!") + with context.wrap_socket(sock, server_side=True) as ssock: + conn, _ = ssock.accept() + print("GOT CONNECTION!!!!!!!!") + metadata_exchange(conn) + conn.sendall(instance.name.encode("utf-8")) + conn.close() - root, intermediate, server = fake_instance.get_pem_certs() - # tmpdir and its contents are automatically deleted after the CA cert - # and cert chain are loaded into the SSLcontext. The values - # need to be written to files in order to be loaded by the SSLContext - with TemporaryDirectory() as tmpdir: - ca_filename, _, key_filename = _write_to_file( - tmpdir, server, [root, intermediate], fake_instance.server_key - ) - ssock: ssl.SSLSocket = context.wrap_socket( - sock, server_side=True, certfile=ca_filename, keyfile=key_filename - ) - # bind socket to AlloyDB proxy server port on localhost - ssock.bind((ip_address, port)) - # listen for incoming connections - ssock.listen(5) - while True: - conn, _ = ssock.accept() - metadata_exchange(conn) - conn.sendall(fake_instance.name.encode("utf-8")) - conn.close() +@pytest.fixture(autouse=True, scope="session") +def proxy_server(fake_instance: FakeInstance) -> Generator: + """Run local proxy server capable of performing metadata exchange""" + thread = Thread(target=start_proxy_server, args=(fake_instance,), daemon=True) + thread.start() + yield thread + thread.join() diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index ec500b1..ce0cbfd 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -119,10 +119,6 @@ def __init__( self.cert_before = cert_before self.cert_expiry = cert_expiry - def generate_certs(self) -> None: - """ - Build certs required for chain of trust with testing server. - """ # build root cert self.root_cert, self.root_key = generate_cert("root.alloydb") # create self signed root cert @@ -155,8 +151,8 @@ def get_pem_certs(self) -> Tuple[str, str, str]: class FakeAlloyDBClient: """Fake class for testing AlloyDBClient""" - def __init__(self) -> None: - self.instance = FakeInstance() + def __init__(self, instance: Optional[FakeInstance] = None) -> None: + self.instance = FakeInstance() if instance is None else instance async def _get_metadata(*args: Any, **kwargs: Any) -> str: return "127.0.0.1" @@ -168,7 +164,6 @@ async def _get_client_certificate( cluster: str, pub_key: str, ) -> Tuple[str, List[str]]: - self.instance.generate_certs() root_cert, intermediate_cert, ca_cert = self.instance.get_pem_certs() # encode public key to bytes pub_key_bytes: rsa.RSAPublicKey = serialization.load_pem_public_key( @@ -189,7 +184,7 @@ async def _get_client_certificate( client_cert = client_cert.public_bytes( encoding=serialization.Encoding.PEM ).decode("UTF-8") - return (ca_cert, [client_cert, intermediate_cert, root_cert]) + return (root_cert, [client_cert, intermediate_cert, root_cert]) async def close(self) -> None: pass diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 7d77670..1b30f4d 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -61,7 +61,7 @@ def test_Connector_close(credentials: FakeCredentials) -> None: assert thread.is_alive() is False -def test_connect(credentials: FakeCredentials) -> None: +def test_connect(credentials: FakeCredentials, proxy_server) -> None: """ Test that connector.connect returns connection object. """ diff --git a/tests/unit/test_refresh.py b/tests/unit/test_refresh.py index 90d6891..bfd75f9 100644 --- a/tests/unit/test_refresh.py +++ b/tests/unit/test_refresh.py @@ -61,7 +61,6 @@ def test_RefreshResult_init_(fake_instance: FakeInstance) -> None: can correctly initialize TLS context. """ key = rsa.generate_private_key(public_exponent=65537, key_size=2048) - fake_instance.generate_certs() root_cert, intermediate_cert, ca_cert = fake_instance.get_pem_certs() # build client cert client_cert = ( From df031a8ae49d1e7466ed8a302d76efddaa021cb9 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 12 Jan 2024 14:49:41 +0000 Subject: [PATCH 18/28] chore: remove prints --- tests/unit/conftest.py | 2 -- tests/unit/test_connector.py | 2 +- 2 files changed, 1 insertion(+), 3 deletions(-) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index f268e5f..7713b14 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -66,10 +66,8 @@ def start_proxy_server(instance: FakeInstance) -> None: sock.listen(5) while True: - print("WAITING!!!!!") with context.wrap_socket(sock, server_side=True) as ssock: conn, _ = ssock.accept() - print("GOT CONNECTION!!!!!!!!") metadata_exchange(conn) conn.sendall(instance.name.encode("utf-8")) conn.close() diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 1b30f4d..7d77670 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -61,7 +61,7 @@ def test_Connector_close(credentials: FakeCredentials) -> None: assert thread.is_alive() is False -def test_connect(credentials: FakeCredentials, proxy_server) -> None: +def test_connect(credentials: FakeCredentials) -> None: """ Test that connector.connect returns connection object. """ From 930d8943807c586b4993dc17787621855deb3905 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 12 Jan 2024 16:23:56 +0000 Subject: [PATCH 19/28] chore: add fixture to be used by test --- tests/unit/test_connector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 7d77670..55350aa 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -16,7 +16,7 @@ from threading import Thread from mock import patch -from mocks import FakeAlloyDBClient, FakeCredentials +from mocks import FakeAlloyDBClient, FakeCredentials, FakeInstance import pytest from google.cloud.alloydb.connector import Connector @@ -61,11 +61,11 @@ def test_Connector_close(credentials: FakeCredentials) -> None: assert thread.is_alive() is False -def test_connect(credentials: FakeCredentials) -> None: +def test_connect(credentials: FakeCredentials, fake_instance: FakeInstance) -> None: """ Test that connector.connect returns connection object. """ - client = FakeAlloyDBClient() + client = FakeAlloyDBClient(instance=fake_instance) with Connector(credentials) as connector: connector._client = client # patch db connection creation From ebd7e7a8e672e154946cd9a02df62e1956681a68 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 12 Jan 2024 19:19:04 +0000 Subject: [PATCH 20/28] chore: get proxy server working --- tests/unit/mocks.py | 7 ++++--- tests/unit/test_connector.py | 6 +++--- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index ce0cbfd..fc67b20 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -153,9 +153,10 @@ class FakeAlloyDBClient: def __init__(self, instance: Optional[FakeInstance] = None) -> None: self.instance = FakeInstance() if instance is None else instance + self._user_agent = "test-user-agent" - async def _get_metadata(*args: Any, **kwargs: Any) -> str: - return "127.0.0.1" + async def _get_metadata(self, *args: Any, **kwargs: Any) -> str: + return self.instance.ip_address async def _get_client_certificate( self, @@ -184,7 +185,7 @@ async def _get_client_certificate( client_cert = client_cert.public_bytes( encoding=serialization.Encoding.PEM ).decode("UTF-8") - return (root_cert, [client_cert, intermediate_cert, root_cert]) + return (ca_cert, [client_cert, intermediate_cert, root_cert]) async def close(self) -> None: pass diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 55350aa..8f53776 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -16,7 +16,7 @@ from threading import Thread from mock import patch -from mocks import FakeAlloyDBClient, FakeCredentials, FakeInstance +from mocks import FakeAlloyDBClient, FakeCredentials import pytest from google.cloud.alloydb.connector import Connector @@ -61,11 +61,11 @@ def test_Connector_close(credentials: FakeCredentials) -> None: assert thread.is_alive() is False -def test_connect(credentials: FakeCredentials, fake_instance: FakeInstance) -> None: +def test_connect(credentials: FakeCredentials, fake_client: FakeAlloyDBClient) -> None: """ Test that connector.connect returns connection object. """ - client = FakeAlloyDBClient(instance=fake_instance) + client = fake_client with Connector(credentials) as connector: connector._client = client # patch db connection creation From d49316bd8e7cd69e569c0543549c3c1538abef79 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 12 Jan 2024 19:43:09 +0000 Subject: [PATCH 21/28] chore: add pyi file for resources_pb2.py --- .../proto/resources_pb2.pyi | 52 +++++++++++++++++++ tests/unit/conftest.py | 1 + 2 files changed, 53 insertions(+) create mode 100644 google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi diff --git a/google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi b/google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi new file mode 100644 index 0000000..d2b15a6 --- /dev/null +++ b/google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi @@ -0,0 +1,52 @@ +from typing import ClassVar as _ClassVar +from typing import Optional as _Optional +from typing import Union as _Union + +from google.protobuf import descriptor as _descriptor +from google.protobuf import message as _message +from google.protobuf.internal import enum_type_wrapper as _enum_type_wrapper + +from google.api import field_behavior_pb2 as _field_behavior_pb2 + +DESCRIPTOR: _descriptor.FileDescriptor + +class MetadataExchangeRequest(_message.Message): + __slots__ = ["auth_type", "oauth2_token", "user_agent"] + + class AuthType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] + AUTH_TYPE_FIELD_NUMBER: _ClassVar[int] + AUTH_TYPE_UNSPECIFIED: MetadataExchangeRequest.AuthType + AUTO_IAM: MetadataExchangeRequest.AuthType + DB_NATIVE: MetadataExchangeRequest.AuthType + OAUTH2_TOKEN_FIELD_NUMBER: _ClassVar[int] + USER_AGENT_FIELD_NUMBER: _ClassVar[int] + auth_type: MetadataExchangeRequest.AuthType + oauth2_token: str + user_agent: str + def __init__( + self, + user_agent: _Optional[str] = ..., + auth_type: _Optional[_Union[MetadataExchangeRequest.AuthType, str]] = ..., + oauth2_token: _Optional[str] = ..., + ) -> None: ... + +class MetadataExchangeResponse(_message.Message): + __slots__ = ["error", "response_code"] + + class ResponseCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): + __slots__ = [] + ERROR: MetadataExchangeResponse.ResponseCode + ERROR_FIELD_NUMBER: _ClassVar[int] + OK: MetadataExchangeResponse.ResponseCode + RESPONSE_CODE_FIELD_NUMBER: _ClassVar[int] + RESPONSE_CODE_UNSPECIFIED: MetadataExchangeResponse.ResponseCode + error: str + response_code: MetadataExchangeResponse.ResponseCode + def __init__( + self, + response_code: _Optional[ + _Union[MetadataExchangeResponse.ResponseCode, str] + ] = ..., + error: _Optional[str] = ..., + ) -> None: ... diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 11faf95..2b1cf8d 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -50,6 +50,7 @@ def start_proxy_server(instance: FakeInstance) -> None: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: # create SSL/TLS context context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) + context.minimum_version = ssl.TLSVersion.TLSv1_3 root, intermediate, server = instance.get_pem_certs() # tmpdir and its contents are automatically deleted after the CA cert # and cert chain are loaded into the SSLcontext. The values From 8de1691ef723c967021ae1436871f3f050dec9be Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 12 Jan 2024 20:28:40 +0000 Subject: [PATCH 22/28] chore: don't use metadata exchange for asyncpg --- .../cloud/alloydb/connector/async_connector.py | 1 + google/cloud/alloydb/connector/client.py | 14 ++++++++++---- google/cloud/alloydb/connector/connector.py | 5 ++++- .../proto/resources_pb2.pyi | 18 ++++++++++++++++-- tests/unit/mocks.py | 6 ++++-- 5 files changed, 35 insertions(+), 9 deletions(-) diff --git a/google/cloud/alloydb/connector/async_connector.py b/google/cloud/alloydb/connector/async_connector.py index 1c36801..38a292b 100644 --- a/google/cloud/alloydb/connector/async_connector.py +++ b/google/cloud/alloydb/connector/async_connector.py @@ -96,6 +96,7 @@ async def connect( self._alloydb_api_endpoint, self._quota_project, self._credentials, + driver=driver, ) # use existing connection info if possible diff --git a/google/cloud/alloydb/connector/client.py b/google/cloud/alloydb/connector/client.py index e324227..98a4dcb 100644 --- a/google/cloud/alloydb/connector/client.py +++ b/google/cloud/alloydb/connector/client.py @@ -38,6 +38,7 @@ def __init__( quota_project: Optional[str], credentials: Credentials, client: Optional[aiohttp.ClientSession] = None, + driver: Optional[str] = None, ) -> None: """ Establish the client to be used for AlloyDB Admin API requests. @@ -55,10 +56,12 @@ def __init__( client (aiohttp.ClientSession): Async client used to make requests to AlloyDB Admin APIs. Optional, defaults to None and creates new client. + driver (str): Database driver to be used by the client. """ + user_agent = f"{USER_AGENT}+{driver}" if driver else USER_AGENT headers = { - "x-goog-api-client": USER_AGENT, - "User-Agent": USER_AGENT, + "x-goog-api-client": user_agent, + "User-Agent": user_agent, "Content-Type": "application/json", } if quota_project: @@ -67,7 +70,7 @@ def __init__( self._client = client if client else aiohttp.ClientSession(headers=headers) self._credentials = credentials self._alloydb_api_endpoint = alloydb_api_endpoint - self._user_agent = USER_AGENT + self._user_agent = user_agent async def _get_metadata( self, @@ -147,10 +150,13 @@ async def _get_client_certificate( url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}:generateClientCertificate" + # asyncpg does not currently support using metadata exchange + # only use metadata exchangefor pg8000 driver + use_metadata = self._user_agent.endswith("pg8000") data = { "publicKey": pub_key, "certDuration": "3600s", - "useMetadataExchange": True, + "useMetadataExchange": use_metadata, } resp = await self._client.post( diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index f1d5e55..e79037e 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -131,7 +131,10 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) -> if self._client is None: # lazy init client as it has to be initialized in async context self._client = AlloyDBClient( - self._alloydb_api_endpoint, self._quota_project, self._credentials + self._alloydb_api_endpoint, + self._quota_project, + self._credentials, + driver=driver, ) enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) # use existing connection info if possible diff --git a/google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi b/google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi index d2b15a6..3f9db56 100644 --- a/google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi +++ b/google/cloud/alloydb_connectors_v1/proto/resources_pb2.pyi @@ -1,3 +1,17 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from typing import ClassVar as _ClassVar from typing import Optional as _Optional from typing import Union as _Union @@ -14,7 +28,7 @@ class MetadataExchangeRequest(_message.Message): __slots__ = ["auth_type", "oauth2_token", "user_agent"] class AuthType(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] + __slots__ = [] # type: ignore AUTH_TYPE_FIELD_NUMBER: _ClassVar[int] AUTH_TYPE_UNSPECIFIED: MetadataExchangeRequest.AuthType AUTO_IAM: MetadataExchangeRequest.AuthType @@ -35,7 +49,7 @@ class MetadataExchangeResponse(_message.Message): __slots__ = ["error", "response_code"] class ResponseCode(int, metaclass=_enum_type_wrapper.EnumTypeWrapper): - __slots__ = [] + __slots__ = [] # type: ignore ERROR: MetadataExchangeResponse.ResponseCode ERROR_FIELD_NUMBER: _ClassVar[int] OK: MetadataExchangeResponse.ResponseCode diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index e7fecd3..f5dd502 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -155,10 +155,12 @@ def get_pem_certs(self) -> Tuple[str, str, str]: class FakeAlloyDBClient: """Fake class for testing AlloyDBClient""" - def __init__(self, instance: Optional[FakeInstance] = None) -> None: + def __init__( + self, instance: Optional[FakeInstance] = None, driver: str = "pg8000" + ) -> None: self.instance = FakeInstance() if instance is None else instance self.closed = False - self._user_agent = "test-user-agent" + self._user_agent = f"test-user-agent+{driver}" async def _get_metadata(self, *args: Any, **kwargs: Any) -> str: return self.instance.ip_address From edd97152ca78f820fdce703ae01f3c357641daa6 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 12 Jan 2024 20:50:39 +0000 Subject: [PATCH 23/28] chore: address comments --- google/cloud/alloydb/connector/connector.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index e79037e..2b2581f 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -242,7 +242,7 @@ def metadata_exchange( resp = connectorspb.MetadataExchangeResponse() # read metadata message length (4 bytes) - message_len_buffer_size = struct.Struct("I").size + message_len_buffer_size = struct.Struct(">I").size message_len_buffer = b"" while message_len_buffer_size > 0: chunk = sock.recv(message_len_buffer_size) @@ -271,7 +271,9 @@ def metadata_exchange( # validate metadata exchange response if resp.response_code != connectorspb.MetadataExchangeResponse.OK: - raise ValueError("Metadata Exchange request has failed!") + raise ValueError( + f"Metadata Exchange request has failed with error: {resp.error}" + ) return sock From 38760fac3be92078a107f9f6583a0c1a27731574 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 12 Jan 2024 20:55:53 +0000 Subject: [PATCH 24/28] chore: send bytes all at once --- google/cloud/alloydb/connector/connector.py | 6 ++---- tests/unit/mocks.py | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 2b2581f..6fb7010 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -233,10 +233,8 @@ def metadata_exchange( # pack big-endian unsigned integer (4 bytes) packed_len = struct.pack(">I", req.ByteSize()) - # send metadata message length - sock.sendall(packed_len) - # send metadata request message - sock.sendall(req.SerializeToString()) + # send metadata message length and request message + sock.sendall(packed_len + req.SerializeToString()) # form metadata exchange response resp = connectorspb.MetadataExchangeResponse() diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index f5dd502..7ad11ac 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -256,10 +256,8 @@ def metadata_exchange(sock: ssl.SSLSocket) -> None: # pack big-endian unsigned integer (4 bytes) resp_len = struct.pack(">I", resp.ByteSize()) - # send metadata response message length - sock.sendall(resp_len) - # send metadata request response message - sock.sendall(resp.SerializeToString()) + # send metadata response length and response message + sock.sendall(resp_len + resp.SerializeToString()) class FakeConnectionInfo: From ac1b8fdbd4ad48e5e5944fa82048738e9bd40142 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Fri, 12 Jan 2024 21:25:51 +0000 Subject: [PATCH 25/28] chore: set socket back to blocking after metadata exchange --- google/cloud/alloydb/connector/connector.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 6fb7010..1c33bce 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -267,6 +267,9 @@ def metadata_exchange( # parse metadata exchange response from buffer resp.ParseFromString(buffer) + # reset socket back to blocking mode + sock.setblocking(True) + # validate metadata exchange response if resp.response_code != connectorspb.MetadataExchangeResponse.OK: raise ValueError( From 50ad0bc0c4e0cbba2fb9b2a0e2693760cb369151 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 15 Jan 2024 15:03:30 +0000 Subject: [PATCH 26/28] chore: review nits --- google/cloud/alloydb/connector/client.py | 2 +- google/cloud/alloydb/connector/connector.py | 5 ++++- tests/unit/conftest.py | 4 ++-- tests/unit/mocks.py | 4 ++-- 4 files changed, 9 insertions(+), 6 deletions(-) diff --git a/google/cloud/alloydb/connector/client.py b/google/cloud/alloydb/connector/client.py index 98a4dcb..d63552c 100644 --- a/google/cloud/alloydb/connector/client.py +++ b/google/cloud/alloydb/connector/client.py @@ -151,7 +151,7 @@ async def _get_client_certificate( url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}:generateClientCertificate" # asyncpg does not currently support using metadata exchange - # only use metadata exchangefor pg8000 driver + # only use metadata exchange for pg8000 driver use_metadata = self._user_agent.endswith("pg8000") data = { "publicKey": pub_key, diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 1c33bce..504a2e6 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -36,7 +36,10 @@ from google.auth.credentials import Credentials +# the port the AlloyDB server-side proxy receives connections on SERVER_PROXY_PORT = 5433 +# the maximum amount of time to wait before aborting a metadata exchange +IO_TIMEOUT = 30 class Connector: @@ -228,7 +231,7 @@ def metadata_exchange( ) # set I/O timeout - sock.settimeout(30) + sock.settimeout(IO_TIMEOUT) # pack big-endian unsigned integer (4 bytes) packed_len = struct.pack(">I", req.ByteSize()) diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 2b1cf8d..8c6f097 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -51,12 +51,12 @@ def start_proxy_server(instance: FakeInstance) -> None: # create SSL/TLS context context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER) context.minimum_version = ssl.TLSVersion.TLSv1_3 - root, intermediate, server = instance.get_pem_certs() + root, _, server = instance.get_pem_certs() # tmpdir and its contents are automatically deleted after the CA cert # and cert chain are loaded into the SSLcontext. The values # need to be written to files in order to be loaded by the SSLContext with TemporaryDirectory() as tmpdir: - ca_filename, cert_chain_filename, key_filename = _write_to_file( + _, cert_chain_filename, key_filename = _write_to_file( tmpdir, server, [server, root], instance.server_key ) context.load_cert_chain(cert_chain_filename, key_filename) diff --git a/tests/unit/mocks.py b/tests/unit/mocks.py index 7ad11ac..1acd851 100644 --- a/tests/unit/mocks.py +++ b/tests/unit/mocks.py @@ -172,7 +172,7 @@ async def _get_client_certificate( cluster: str, pub_key: str, ) -> Tuple[str, List[str]]: - root_cert, intermediate_cert, ca_cert = self.instance.get_pem_certs() + root_cert, intermediate_cert, server_cert = self.instance.get_pem_certs() # encode public key to bytes pub_key_bytes: rsa.RSAPublicKey = serialization.load_pem_public_key( pub_key.encode("UTF-8"), @@ -192,7 +192,7 @@ async def _get_client_certificate( client_cert = client_cert.public_bytes( encoding=serialization.Encoding.PEM ).decode("UTF-8") - return (ca_cert, [client_cert, intermediate_cert, root_cert]) + return (server_cert, [client_cert, intermediate_cert, root_cert]) async def close(self) -> None: self.closed = True From 60c145edeff45909f8e9ff9a706a09245fe6812e Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 15 Jan 2024 16:35:44 +0000 Subject: [PATCH 27/28] chore: add tests --- google/cloud/alloydb/connector/client.py | 9 ++--- tests/unit/test_client.py | 51 +++++++++++++++++++++++- 2 files changed, 54 insertions(+), 6 deletions(-) diff --git a/google/cloud/alloydb/connector/client.py b/google/cloud/alloydb/connector/client.py index d63552c..48b92d8 100644 --- a/google/cloud/alloydb/connector/client.py +++ b/google/cloud/alloydb/connector/client.py @@ -70,7 +70,9 @@ def __init__( self._client = client if client else aiohttp.ClientSession(headers=headers) self._credentials = credentials self._alloydb_api_endpoint = alloydb_api_endpoint - self._user_agent = user_agent + # asyncpg does not currently support using metadata exchange + # only use metadata exchange for pg8000 driver + self._use_metadata = True if driver == "pg8000" else False async def _get_metadata( self, @@ -150,13 +152,10 @@ async def _get_client_certificate( url = f"{self._alloydb_api_endpoint}/{API_VERSION}/projects/{project}/locations/{region}/clusters/{cluster}:generateClientCertificate" - # asyncpg does not currently support using metadata exchange - # only use metadata exchange for pg8000 driver - use_metadata = self._user_agent.endswith("pg8000") data = { "publicKey": pub_key, "certDuration": "3600s", - "useMetadataExchange": use_metadata, + "useMetadataExchange": self._use_metadata, } resp = await self._client.post( diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index fa686ed..2cc68b1 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -13,7 +13,7 @@ # limitations under the License. import json -from typing import Any +from typing import Any, Optional from aiohttp import web from mocks import FakeCredentials @@ -102,3 +102,52 @@ async def test_AlloyDBClient_init_(credentials: FakeCredentials) -> None: assert client._client.headers["x-goog-user-project"] == "my-quota-project" # close client await client.close() + + +@pytest.mark.parametrize( + "driver", + [None, "pg8000", "asyncpg"], +) +@pytest.mark.asyncio +async def test_AlloyDBClient_user_agent( + driver: Optional[str], credentials: FakeCredentials +) -> None: + """ + Test to check whether the __init__ method of AlloyDBClient + properly sets user agent when passed a database driver. + """ + client = AlloyDBClient( + "www.test-endpoint.com", "my-quota-project", credentials, driver=driver + ) + if driver is None: + assert ( + client._client.headers["User-Agent"] + == f"alloydb-python-connector/{version}" + ) + else: + assert ( + client._client.headers["User-Agent"] + == f"alloydb-python-connector/{version}+{driver}" + ) + # close client + await client.close() + + +@pytest.mark.parametrize( + "driver, expected", + [(None, False), ("pg8000", True), ("asyncpg", False)], +) +@pytest.mark.asyncio +async def test_AlloyDBClient_use_metadata( + driver: Optional[str], expected: bool, credentials: FakeCredentials +) -> None: + """ + Test to check whether the __init__ method of AlloyDBClient + properly sets use_metadata. + """ + client = AlloyDBClient( + "www.test-endpoint.com", "my-quota-project", credentials, driver=driver + ) + assert client._use_metadata == expected + # close client + await client.close() From 9e444b105fcbcdb9bc303e18358cf0667bbd7185 Mon Sep 17 00:00:00 2001 From: jackwotherspoon Date: Mon, 15 Jan 2024 16:42:37 +0000 Subject: [PATCH 28/28] chore: re-add user agent to client --- google/cloud/alloydb/connector/client.py | 1 + google/cloud/alloydb/connector/connector.py | 2 +- tests/unit/test_client.py | 10 ++-------- 3 files changed, 4 insertions(+), 9 deletions(-) diff --git a/google/cloud/alloydb/connector/client.py b/google/cloud/alloydb/connector/client.py index 48b92d8..f6a8d3f 100644 --- a/google/cloud/alloydb/connector/client.py +++ b/google/cloud/alloydb/connector/client.py @@ -73,6 +73,7 @@ def __init__( # asyncpg does not currently support using metadata exchange # only use metadata exchange for pg8000 driver self._use_metadata = True if driver == "pg8000" else False + self._user_agent = user_agent async def _get_metadata( self, diff --git a/google/cloud/alloydb/connector/connector.py b/google/cloud/alloydb/connector/connector.py index 504a2e6..1a6d6da 100644 --- a/google/cloud/alloydb/connector/connector.py +++ b/google/cloud/alloydb/connector/connector.py @@ -225,7 +225,7 @@ def metadata_exchange( # form metadata exchange request req = connectorspb.MetadataExchangeRequest( - user_agent=f"{self._client._user_agent}+{driver}", # type: ignore + user_agent=f"{self._client._user_agent}", # type: ignore auth_type=auth_type, oauth2_token=self._credentials.token, ) diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index 2cc68b1..fada5e4 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -120,15 +120,9 @@ async def test_AlloyDBClient_user_agent( "www.test-endpoint.com", "my-quota-project", credentials, driver=driver ) if driver is None: - assert ( - client._client.headers["User-Agent"] - == f"alloydb-python-connector/{version}" - ) + assert client._user_agent == f"alloydb-python-connector/{version}" else: - assert ( - client._client.headers["User-Agent"] - == f"alloydb-python-connector/{version}+{driver}" - ) + assert client._user_agent == f"alloydb-python-connector/{version}+{driver}" # close client await client.close()