feat: add support for auto IAM authentication to Connector #191

Merged on Jan 15, 2024 (30 commits)

Changes from all commits

Commits
d9557e6
chore: first attempt
jackwotherspoon Dec 21, 2023
43fd281
chore: add socket read and write
jackwotherspoon Dec 27, 2023
979cc3d
chore: Merge branch 'main' into metadata-exchange
jackwotherspoon Dec 27, 2023
96669eb
chore: type hint ssl
jackwotherspoon Dec 27, 2023
65d276d
chore: add protofbuf type
jackwotherspoon Dec 27, 2023
032f9e6
chore: add google/api dep
jackwotherspoon Dec 28, 2023
0700411
chore: add header files
jackwotherspoon Dec 28, 2023
fc847fa
chore: fix read of response
jackwotherspoon Dec 28, 2023
996fef6
chore: fix message size
jackwotherspoon Dec 28, 2023
658ae5a
chore: set useMetadataExchange to True
jackwotherspoon Jan 1, 2024
1460ba0
chore: lint and headers
jackwotherspoon Jan 1, 2024
2b44528
chore: lint
jackwotherspoon Jan 1, 2024
9e03e77
chore: add IAM authn test
jackwotherspoon Jan 1, 2024
ffbed69
chore: update doc strings
jackwotherspoon Jan 1, 2024
d064f94
chore: fix iam authn flag
jackwotherspoon Jan 1, 2024
46d6938
chore: add docstring for metdata_exchange
jackwotherspoon Jan 2, 2024
f0e205e
chore: first pass at local proxy server for testing
jackwotherspoon Jan 10, 2024
856d0ab
chore: first pass at local proxy server
jackwotherspoon Jan 12, 2024
df031a8
chore: remove prints
jackwotherspoon Jan 12, 2024
930d894
chore: add fixture to be used by test
jackwotherspoon Jan 12, 2024
ebd7e7a
chore: get proxy server working
jackwotherspoon Jan 12, 2024
68776f3
chore: merge main
jackwotherspoon Jan 12, 2024
d49316b
chore: add pyi file for resources_pb2.py
jackwotherspoon Jan 12, 2024
8de1691
chore: don't use metadata exchange for asyncpg
jackwotherspoon Jan 12, 2024
edd9715
chore: address comments
jackwotherspoon Jan 12, 2024
38760fa
chore: send bytes all at once
jackwotherspoon Jan 12, 2024
ac1b8fd
chore: set socket back to blocking after metadata exchange
jackwotherspoon Jan 12, 2024
50ad0bc
chore: review nits
jackwotherspoon Jan 15, 2024
60c145e
chore: add tests
jackwotherspoon Jan 15, 2024
9e444b1
chore: re-add user agent to client
jackwotherspoon Jan 15, 2024
13 changes: 13 additions & 0 deletions google/api/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
52 changes: 52 additions & 0 deletions google/api/field_behavior_pb2.py
@@ -0,0 +1,52 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# type: ignore
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: google/api/field_behavior.proto
# isort: skip_file
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder

# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b"\n\x1fgoogle/api/field_behavior.proto\x12\ngoogle.api\x1a google/protobuf/descriptor.proto*\xb6\x01\n\rFieldBehavior\x12\x1e\n\x1a\x46IELD_BEHAVIOR_UNSPECIFIED\x10\x00\x12\x0c\n\x08OPTIONAL\x10\x01\x12\x0c\n\x08REQUIRED\x10\x02\x12\x0f\n\x0bOUTPUT_ONLY\x10\x03\x12\x0e\n\nINPUT_ONLY\x10\x04\x12\r\n\tIMMUTABLE\x10\x05\x12\x12\n\x0eUNORDERED_LIST\x10\x06\x12\x15\n\x11NON_EMPTY_DEFAULT\x10\x07\x12\x0e\n\nIDENTIFIER\x10\x08:Q\n\x0e\x66ield_behavior\x12\x1d.google.protobuf.FieldOptions\x18\x9c\x08 \x03(\x0e\x32\x19.google.api.FieldBehaviorBp\n\x0e\x63om.google.apiB\x12\x46ieldBehaviorProtoP\x01ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\xa2\x02\x04GAPIb\x06proto3"
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(
DESCRIPTOR, "google.api.field_behavior_pb2", _globals
)
if _descriptor._USE_C_DESCRIPTORS == False:
google_dot_protobuf_dot_descriptor__pb2.FieldOptions.RegisterExtension(
field_behavior
)

DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b"\n\016com.google.apiB\022FieldBehaviorProtoP\001ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\242\002\004GAPI"
_globals["_FIELDBEHAVIOR"]._serialized_start = 82
_globals["_FIELDBEHAVIOR"]._serialized_end = 264
# @@protoc_insertion_point(module_scope)
1 change: 1 addition & 0 deletions google/cloud/alloydb/connector/async_connector.py
@@ -96,6 +96,7 @@ async def connect(
self._alloydb_api_endpoint,
self._quota_project,
self._credentials,
driver=driver,
)

# use existing connection info if possible
12 changes: 10 additions & 2 deletions google/cloud/alloydb/connector/client.py
@@ -38,6 +38,7 @@ def __init__(
quota_project: Optional[str],
credentials: Credentials,
client: Optional[aiohttp.ClientSession] = None,
driver: Optional[str] = None,
) -> None:
"""
Establish the client to be used for AlloyDB Admin API requests.
@@ -55,10 +56,12 @@
client (aiohttp.ClientSession): Async client used to make requests to
AlloyDB Admin APIs.
Optional, defaults to None and creates new client.
driver (str): Database driver to be used by the client.
"""
user_agent = f"{USER_AGENT}+{driver}" if driver else USER_AGENT
headers = {
"x-goog-api-client": USER_AGENT,
"User-Agent": USER_AGENT,
"x-goog-api-client": user_agent,
"User-Agent": user_agent,
"Content-Type": "application/json",
}
if quota_project:
@@ -67,6 +70,10 @@ def __init__(
self._client = client if client else aiohttp.ClientSession(headers=headers)
self._credentials = credentials
self._alloydb_api_endpoint = alloydb_api_endpoint
# asyncpg does not currently support using metadata exchange
# only use metadata exchange for pg8000 driver
self._use_metadata = True if driver == "pg8000" else False
self._user_agent = user_agent

async def _get_metadata(
self,
@@ -149,6 +156,7 @@ async def _get_client_certificate(
data = {
"publicKey": pub_key,
"certDuration": "3600s",
"useMetadataExchange": self._use_metadata,
}

resp = await self._client.post(
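Taken together, the client.py changes thread the driver name into the request headers and gate metadata exchange on the driver. The following condensed sketch pulls that logic into one place; the build_client_settings helper and the placeholder USER_AGENT value are illustrative, not part of this PR.

# Illustrative sketch only: condenses the new client.py logic into one helper.
# build_client_settings and the USER_AGENT value are hypothetical.
from typing import Dict, Optional, Tuple

USER_AGENT = "alloydb-python-connector/<version>"  # placeholder base user agent


def build_client_settings(driver: Optional[str]) -> Tuple[Dict[str, str], bool]:
    # Append the driver name to the user agent when a driver is given.
    user_agent = f"{USER_AGENT}+{driver}" if driver else USER_AGENT
    headers = {
        "x-goog-api-client": user_agent,
        "User-Agent": user_agent,
        "Content-Type": "application/json",
    }
    # Metadata exchange is only used with pg8000; asyncpg does not support it
    # yet, so the flag stays False for any other driver.
    use_metadata_exchange = driver == "pg8000"
    return headers, use_metadata_exchange

For driver="pg8000" this produces a user agent ending in "+pg8000" and a True flag, which is what gets sent as "useMetadataExchange" in the client-certificate request body shown above.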
128 changes: 126 additions & 2 deletions google/cloud/alloydb/connector/connector.py
@@ -16,6 +16,8 @@

import asyncio
from functools import partial
import socket
import struct
from threading import Thread
from types import TracebackType
from typing import Any, Dict, Optional, Type, TYPE_CHECKING
@@ -27,10 +29,18 @@
from google.cloud.alloydb.connector.instance import Instance
import google.cloud.alloydb.connector.pg8000 as pg8000
from google.cloud.alloydb.connector.utils import generate_keys
import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb

if TYPE_CHECKING:
import ssl

from google.auth.credentials import Credentials

# the port the AlloyDB server-side proxy receives connections on
SERVER_PROXY_PORT = 5433
# the maximum amount of time to wait before aborting a metadata exchange
IO_TIMEOUT = 30


class Connector:
"""A class to configure and create connections to AlloyDB instances.
@@ -45,13 +55,15 @@ class Connector:
Defaults to None, picking up project from environment.
alloydb_api_endpoint (str): Base URL to use when calling
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
enable_iam_auth (bool): Enables automatic IAM database authentication.
"""

def __init__(
self,
credentials: Optional[Credentials] = None,
quota_project: Optional[str] = None,
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
enable_iam_auth: bool = False,
) -> None:
# create event loop and start it in background thread
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
@@ -61,6 +73,7 @@ def __init__(
# initialize default params
self._quota_project = quota_project
self._alloydb_api_endpoint = alloydb_api_endpoint
self._enable_iam_auth = enable_iam_auth
# initialize credentials
scopes = ["https://www.googleapis.com/auth/cloud-platform"]
if credentials:
@@ -121,8 +134,12 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
if self._client is None:
# lazy init client as it has to be initialized in async context
self._client = AlloyDBClient(
self._alloydb_api_endpoint, self._quota_project, self._credentials
self._alloydb_api_endpoint,
self._quota_project,
self._credentials,
driver=driver,
)
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
# use existing connection info if possible
if instance_uri in self._instances:
instance = self._instances[instance_uri]
@@ -150,13 +167,120 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->

# synchronous drivers are blocking and run using executor
try:
connect_partial = partial(connector, ip_address, context, **kwargs)
metadata_partial = partial(
self.metadata_exchange, ip_address, context, enable_iam_auth, driver
)
sock = await self._loop.run_in_executor(None, metadata_partial)
connect_partial = partial(connector, sock, **kwargs)
return await self._loop.run_in_executor(None, connect_partial)
except Exception:
# we attempt a force refresh, then throw the error
await instance.force_refresh()
raise

def metadata_exchange(
self, ip_address: str, ctx: ssl.SSLContext, enable_iam_auth: bool, driver: str
) -> ssl.SSLSocket:
"""
Sends metadata about the connection prior to the database
protocol taking over.

The exchange consists of four steps:

1. Prepare a MetadataExchangeRequest including the IAM Principal's
OAuth2 token, the user agent, and the requested authentication type.

2. Write the size of the message as a big endian uint32 (4 bytes) to
the server followed by the serialized message. The length does not
include the initial four bytes.

3. Read a big endian uint32 (4 bytes) from the server. This is the
MetadataExchangeResponse message length and does not include the
initial four bytes.

4. Parse the response using the message length in step 3. If the
response is not OK, return the response's error. If there is no error,
the metadata exchange has succeeded and the connection is complete.

Args:
ip_address (str): IP address of AlloyDB instance to connect to.
ctx (ssl.SSLContext): Context used to create a TLS connection
with AlloyDB instance ssl certificates.
enable_iam_auth (bool): Flag to enable IAM database authentication.
driver (str): A string representing the database driver to connect with.
Supported drivers are pg8000.

Returns:
sock (ssl.SSLSocket): mTLS/SSL socket connected to AlloyDB Proxy server.
"""
# Create socket and wrap with SSL/TLS context
sock = ctx.wrap_socket(
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
server_hostname=ip_address,
)
# set auth type for metadata exchange
auth_type = connectorspb.MetadataExchangeRequest.DB_NATIVE
if enable_iam_auth:
auth_type = connectorspb.MetadataExchangeRequest.AUTO_IAM

# form metadata exchange request
req = connectorspb.MetadataExchangeRequest(
user_agent=f"{self._client._user_agent}", # type: ignore
auth_type=auth_type,
oauth2_token=self._credentials.token,
)

# set I/O timeout
sock.settimeout(IO_TIMEOUT)

# pack big-endian unsigned integer (4 bytes)
packed_len = struct.pack(">I", req.ByteSize())

# send metadata message length and request message
sock.sendall(packed_len + req.SerializeToString())

# form metadata exchange response
resp = connectorspb.MetadataExchangeResponse()

# read metadata message length (4 bytes)
message_len_buffer_size = struct.Struct(">I").size
message_len_buffer = b""
while message_len_buffer_size > 0:
chunk = sock.recv(message_len_buffer_size)
if not chunk:
raise RuntimeError(
"Connection closed while getting metadata exchange length!"
)
message_len_buffer += chunk
message_len_buffer_size -= len(chunk)

(message_len,) = struct.unpack(">I", message_len_buffer)

# read metadata exchange message
buffer = b""
while message_len > 0:
chunk = sock.recv(message_len)
if not chunk:
raise RuntimeError(
"Connection closed while performing metadata exchange!"
)
buffer += chunk
message_len -= len(chunk)

# parse metadata exchange response from buffer
resp.ParseFromString(buffer)

# reset socket back to blocking mode
sock.setblocking(True)

# validate metadata exchange response
if resp.response_code != connectorspb.MetadataExchangeResponse.OK:
raise ValueError(
f"Metadata Exchange request has failed with error: {resp.error}"
)

return sock

def __enter__(self) -> "Connector":
"""Enter context manager by returning Connector object"""
return self
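The four-step exchange described in the metadata_exchange docstring reduces to length-prefixed protobuf messages over the TLS socket: every message is preceded by a 4-byte big-endian length that excludes the prefix itself. Below is a minimal, driver-agnostic sketch of that framing; the helper names are illustrative and the payload is treated as opaque bytes rather than the generated MetadataExchangeRequest/Response classes.

# Minimal sketch of the 4-byte big-endian length framing used by the metadata
# exchange. Helper names are hypothetical and not part of this PR.
import socket
import struct

def send_framed(sock: socket.socket, payload: bytes) -> None:
    # Write the payload length (excluding the 4-byte prefix), then the payload.
    sock.sendall(struct.pack(">I", len(payload)) + payload)

def recv_exact(sock: socket.socket, n: int) -> bytes:
    # Loop because recv() may return fewer bytes than requested.
    buf = b""
    while len(buf) < n:
        chunk = sock.recv(n - len(buf))
        if not chunk:
            raise RuntimeError("connection closed mid-message")
        buf += chunk
    return buf

def recv_framed(sock: socket.socket) -> bytes:
    # Read the 4-byte length, then exactly that many payload bytes.
    (length,) = struct.unpack(">I", recv_exact(sock, 4))
    return recv_exact(sock, length)

In the actual exchange the outgoing payload is MetadataExchangeRequest.SerializeToString() and the incoming bytes are parsed with MetadataExchangeResponse.ParseFromString(); once the response code is OK, the socket is handed straight to the database driver.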
26 changes: 6 additions & 20 deletions google/cloud/alloydb/connector/pg8000.py
@@ -11,49 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
import ssl
from typing import Any, TYPE_CHECKING

SERVER_PROXY_PORT = 5433

if TYPE_CHECKING:
import ssl

import pg8000


def connect(
ip_address: str, ctx: ssl.SSLContext, **kwargs: Any
) -> "pg8000.dbapi.Connection":
def connect(sock: "ssl.SSLSocket", **kwargs: Any) -> "pg8000.dbapi.Connection":
"""Create a pg8000 DBAPI connection object.

Args:
ip_address (str): IP address of AlloyDB instance to connect to.
ctx (ssl.SSLContext): Context used to create a TLS connection
with AlloyDB instance ssl certificates.
sock (ssl.SSLSocket): SSL/TLS secure socket stream connected to the
AlloyDB proxy server.

Returns:
pg8000.dbapi.Connection: A pg8000 Connection object for
the AlloyDB instance.
"""
# Connecting through pg8000 is done by passing in an SSL Context and setting the
# "request_ssl" attr to false. This works because when "request_ssl" is false,
# the driver skips the database level SSL/TLS exchange, but still uses the
# ssl_context (if it is not None) to create the connection.
try:
import pg8000
except ImportError:
raise ImportError(
'Unable to import module "pg8000." Please install and try again.'
)
# Create socket and wrap with context.
sock = ctx.wrap_socket(
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
server_hostname=ip_address,
)

user = kwargs.pop("user")
db = kwargs.pop("db")
passwd = kwargs.pop("password")
passwd = kwargs.pop("password", None)
return pg8000.dbapi.connect(
user,
database=db,
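Since the password argument is now optional, a pg8000 connection can rely entirely on IAM credentials. Below is a hypothetical end-to-end sketch of the new enable_iam_auth flag; the instance URI, database name, and IAM database user are placeholders.

# Hypothetical usage sketch of auto IAM authentication (all values are placeholders).
from google.cloud.alloydb.connector import Connector

with Connector(enable_iam_auth=True) as connector:
    conn = connector.connect(
        "projects/<PROJECT>/locations/<REGION>/clusters/<CLUSTER>/instances/<INSTANCE>",
        "pg8000",
        user="my-service-account@<PROJECT>.iam",  # IAM database user, no password
        db="my-database",
    )
    cursor = conn.cursor()
    cursor.execute("SELECT 1")
    print(cursor.fetchone())
    conn.close()

The OAuth2 token attached during the metadata exchange takes the place of a database password, so in this PR only drivers that go through the exchange (pg8000 for now) can use auto IAM authentication.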
13 changes: 13 additions & 0 deletions google/cloud/alloydb_connectors_v1/proto/__init__.py
@@ -0,0 +1,13 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.