Skip to content

Commit

Permalink
feat: add support for auto IAM authentication to Connector (#191)
Browse files Browse the repository at this point in the history
  • Loading branch information
jackwotherspoon committed Jan 15, 2024
1 parent 35d417d commit c6c16e8
Show file tree
Hide file tree
Showing 17 changed files with 621 additions and 40 deletions.
13 changes: 13 additions & 0 deletions google/api/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
52 changes: 52 additions & 0 deletions google/api/field_behavior_pb2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# type: ignore
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: google/api/field_behavior.proto
# isort: skip_file
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder

# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


from google.protobuf import descriptor_pb2 as google_dot_protobuf_dot_descriptor__pb2


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(
b"\n\x1fgoogle/api/field_behavior.proto\x12\ngoogle.api\x1a google/protobuf/descriptor.proto*\xb6\x01\n\rFieldBehavior\x12\x1e\n\x1a\x46IELD_BEHAVIOR_UNSPECIFIED\x10\x00\x12\x0c\n\x08OPTIONAL\x10\x01\x12\x0c\n\x08REQUIRED\x10\x02\x12\x0f\n\x0bOUTPUT_ONLY\x10\x03\x12\x0e\n\nINPUT_ONLY\x10\x04\x12\r\n\tIMMUTABLE\x10\x05\x12\x12\n\x0eUNORDERED_LIST\x10\x06\x12\x15\n\x11NON_EMPTY_DEFAULT\x10\x07\x12\x0e\n\nIDENTIFIER\x10\x08:Q\n\x0e\x66ield_behavior\x12\x1d.google.protobuf.FieldOptions\x18\x9c\x08 \x03(\x0e\x32\x19.google.api.FieldBehaviorBp\n\x0e\x63om.google.apiB\x12\x46ieldBehaviorProtoP\x01ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\xa2\x02\x04GAPIb\x06proto3"
)

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(
DESCRIPTOR, "google.api.field_behavior_pb2", _globals
)
if _descriptor._USE_C_DESCRIPTORS == False:
google_dot_protobuf_dot_descriptor__pb2.FieldOptions.RegisterExtension(
field_behavior
)

DESCRIPTOR._options = None
DESCRIPTOR._serialized_options = b"\n\016com.google.apiB\022FieldBehaviorProtoP\001ZAgoogle.golang.org/genproto/googleapis/api/annotations;annotations\242\002\004GAPI"
_globals["_FIELDBEHAVIOR"]._serialized_start = 82
_globals["_FIELDBEHAVIOR"]._serialized_end = 264
# @@protoc_insertion_point(module_scope)
1 change: 1 addition & 0 deletions google/cloud/alloydb/connector/async_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ async def connect(
self._alloydb_api_endpoint,
self._quota_project,
self._credentials,
driver=driver,
)

# use existing connection info if possible
Expand Down
12 changes: 10 additions & 2 deletions google/cloud/alloydb/connector/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ def __init__(
quota_project: Optional[str],
credentials: Credentials,
client: Optional[aiohttp.ClientSession] = None,
driver: Optional[str] = None,
) -> None:
"""
Establish the client to be used for AlloyDB Admin API requests.
Expand All @@ -55,10 +56,12 @@ def __init__(
client (aiohttp.ClientSession): Async client used to make requests to
AlloyDB Admin APIs.
Optional, defaults to None and creates new client.
driver (str): Database driver to be used by the client.
"""
user_agent = f"{USER_AGENT}+{driver}" if driver else USER_AGENT
headers = {
"x-goog-api-client": USER_AGENT,
"User-Agent": USER_AGENT,
"x-goog-api-client": user_agent,
"User-Agent": user_agent,
"Content-Type": "application/json",
}
if quota_project:
Expand All @@ -67,6 +70,10 @@ def __init__(
self._client = client if client else aiohttp.ClientSession(headers=headers)
self._credentials = credentials
self._alloydb_api_endpoint = alloydb_api_endpoint
# asyncpg does not currently support using metadata exchange
# only use metadata exchange for pg8000 driver
self._use_metadata = True if driver == "pg8000" else False
self._user_agent = user_agent

async def _get_metadata(
self,
Expand Down Expand Up @@ -149,6 +156,7 @@ async def _get_client_certificate(
data = {
"publicKey": pub_key,
"certDuration": "3600s",
"useMetadataExchange": self._use_metadata,
}

resp = await self._client.post(
Expand Down
128 changes: 126 additions & 2 deletions google/cloud/alloydb/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@

import asyncio
from functools import partial
import socket
import struct
from threading import Thread
from types import TracebackType
from typing import Any, Dict, Optional, Type, TYPE_CHECKING
Expand All @@ -27,10 +29,18 @@
from google.cloud.alloydb.connector.instance import Instance
import google.cloud.alloydb.connector.pg8000 as pg8000
from google.cloud.alloydb.connector.utils import generate_keys
import google.cloud.alloydb_connectors_v1.proto.resources_pb2 as connectorspb

if TYPE_CHECKING:
import ssl

from google.auth.credentials import Credentials

# the port the AlloyDB server-side proxy receives connections on
SERVER_PROXY_PORT = 5433
# the maximum amount of time to wait before aborting a metadata exchange
IO_TIMEOUT = 30


class Connector:
"""A class to configure and create connections to Cloud SQL instances.
Expand All @@ -45,13 +55,15 @@ class Connector:
Defaults to None, picking up project from environment.
alloydb_api_endpoint (str): Base URL to use when calling
the AlloyDB API endpoint. Defaults to "https://alloydb.googleapis.com".
enable_iam_auth (bool): Enables automatic IAM database authentication.
"""

def __init__(
self,
credentials: Optional[Credentials] = None,
quota_project: Optional[str] = None,
alloydb_api_endpoint: str = "https://alloydb.googleapis.com",
enable_iam_auth: bool = False,
) -> None:
# create event loop and start it in background thread
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
Expand All @@ -61,6 +73,7 @@ def __init__(
# initialize default params
self._quota_project = quota_project
self._alloydb_api_endpoint = alloydb_api_endpoint
self._enable_iam_auth = enable_iam_auth
# initialize credentials
scopes = ["https://www.googleapis.com/auth/cloud-platform"]
if credentials:
Expand Down Expand Up @@ -121,8 +134,12 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->
if self._client is None:
# lazy init client as it has to be initialized in async context
self._client = AlloyDBClient(
self._alloydb_api_endpoint, self._quota_project, self._credentials
self._alloydb_api_endpoint,
self._quota_project,
self._credentials,
driver=driver,
)
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
# use existing connection info if possible
if instance_uri in self._instances:
instance = self._instances[instance_uri]
Expand Down Expand Up @@ -150,13 +167,120 @@ async def connect_async(self, instance_uri: str, driver: str, **kwargs: Any) ->

# synchronous drivers are blocking and run using executor
try:
connect_partial = partial(connector, ip_address, context, **kwargs)
metadata_partial = partial(
self.metadata_exchange, ip_address, context, enable_iam_auth, driver
)
sock = await self._loop.run_in_executor(None, metadata_partial)
connect_partial = partial(connector, sock, **kwargs)
return await self._loop.run_in_executor(None, connect_partial)
except Exception:
# we attempt a force refresh, then throw the error
await instance.force_refresh()
raise

def metadata_exchange(
self, ip_address: str, ctx: ssl.SSLContext, enable_iam_auth: bool, driver: str
) -> ssl.SSLSocket:
"""
Sends metadata about the connection prior to the database
protocol taking over.
The exchange consists of four steps:
1. Prepare a MetadataExchangeRequest including the IAM Principal's
OAuth2 token, the user agent, and the requested authentication type.
2. Write the size of the message as a big endian uint32 (4 bytes) to
the server followed by the serialized message. The length does not
include the initial four bytes.
3. Read a big endian uint32 (4 bytes) from the server. This is the
MetadataExchangeResponse message length and does not include the
initial four bytes.
4. Parse the response using the message length in step 3. If the
response is not OK, return the response's error. If there is no error,
the metadata exchange has succeeded and the connection is complete.
Args:
ip_address (str): IP address of AlloyDB instance to connect to.
ctx (ssl.SSLContext): Context used to create a TLS connection
with AlloyDB instance ssl certificates.
enable_iam_auth (bool): Flag to enable IAM database authentication.
driver (str): A string representing the database driver to connect with.
Supported drivers are pg8000.
Returns:
sock (ssl.SSLSocket): mTLS/SSL socket connected to AlloyDB Proxy server.
"""
# Create socket and wrap with SSL/TLS context
sock = ctx.wrap_socket(
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
server_hostname=ip_address,
)
# set auth type for metadata exchange
auth_type = connectorspb.MetadataExchangeRequest.DB_NATIVE
if enable_iam_auth:
auth_type = connectorspb.MetadataExchangeRequest.AUTO_IAM

# form metadata exchange request
req = connectorspb.MetadataExchangeRequest(
user_agent=f"{self._client._user_agent}", # type: ignore
auth_type=auth_type,
oauth2_token=self._credentials.token,
)

# set I/O timeout
sock.settimeout(IO_TIMEOUT)

# pack big-endian unsigned integer (4 bytes)
packed_len = struct.pack(">I", req.ByteSize())

# send metadata message length and request message
sock.sendall(packed_len + req.SerializeToString())

# form metadata exchange response
resp = connectorspb.MetadataExchangeResponse()

# read metadata message length (4 bytes)
message_len_buffer_size = struct.Struct(">I").size
message_len_buffer = b""
while message_len_buffer_size > 0:
chunk = sock.recv(message_len_buffer_size)
if not chunk:
raise RuntimeError(
"Connection closed while getting metadata exchange length!"
)
message_len_buffer += chunk
message_len_buffer_size -= len(chunk)

(message_len,) = struct.unpack(">I", message_len_buffer)

# read metadata exchange message
buffer = b""
while message_len > 0:
chunk = sock.recv(message_len)
if not chunk:
raise RuntimeError(
"Connection closed while performing metadata exchange!"
)
buffer += chunk
message_len -= len(chunk)

# parse metadata exchange response from buffer
resp.ParseFromString(buffer)

# reset socket back to blocking mode
sock.setblocking(True)

# validate metadata exchange response
if resp.response_code != connectorspb.MetadataExchangeResponse.OK:
raise ValueError(
f"Metadata Exchange request has failed with error: {resp.error}"
)

return sock

def __enter__(self) -> "Connector":
"""Enter context manager by returning Connector object"""
return self
Expand Down
26 changes: 6 additions & 20 deletions google/cloud/alloydb/connector/pg8000.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,49 +11,35 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import socket
import ssl
from typing import Any, TYPE_CHECKING

SERVER_PROXY_PORT = 5433

if TYPE_CHECKING:
import ssl

import pg8000


def connect(
ip_address: str, ctx: ssl.SSLContext, **kwargs: Any
) -> "pg8000.dbapi.Connection":
def connect(sock: "ssl.SSLSocket", **kwargs: Any) -> "pg8000.dbapi.Connection":
"""Create a pg8000 DBAPI connection object.
Args:
ip_address (str): IP address of AlloyDB instance to connect to.
ctx (ssl.SSLContext): Context used to create a TLS connection
with AlloyDB instance ssl certificates.
sock (ssl.SSLSocket): SSL/TLS secure socket stream connected to the
AlloyDB proxy server.
Returns:
pg8000.dbapi.Connection: A pg8000 Connection object for
the AlloyDB instance.
"""
# Connecting through pg8000 is done by passing in an SSL Context and setting the
# "request_ssl" attr to false. This works because when "request_ssl" is false,
# the driver skips the database level SSL/TLS exchange, but still uses the
# ssl_context (if it is not None) to create the connection.
try:
import pg8000
except ImportError:
raise ImportError(
'Unable to import module "pg8000." Please install and try again.'
)
# Create socket and wrap with context.
sock = ctx.wrap_socket(
socket.create_connection((ip_address, SERVER_PROXY_PORT)),
server_hostname=ip_address,
)

user = kwargs.pop("user")
db = kwargs.pop("db")
passwd = kwargs.pop("password")
passwd = kwargs.pop("password", None)
return pg8000.dbapi.connect(
user,
database=db,
Expand Down
13 changes: 13 additions & 0 deletions google/cloud/alloydb_connectors_v1/proto/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright 2024 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
Loading

0 comments on commit c6c16e8

Please sign in to comment.