Skip to content
Merged
175 changes: 107 additions & 68 deletions google/cloud/sql/connector/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,43 +23,118 @@
from google.cloud.sql.connector.utils import generate_keys

from threading import Thread
from typing import Any, Dict, Optional
from typing import Any, Dict

# This thread is used to background processing
_thread: Optional[Thread] = None
_loop: Optional[asyncio.AbstractEventLoop] = None
_keys: Optional[concurrent.futures.Future] = None
logger = logging.getLogger(name=__name__)

_instances: Dict[str, InstanceConnectionManager] = {}
_default_connector = None

logger = logging.getLogger(name=__name__)

class Connector:
"""A class to configure and create connections to Cloud SQL instances.

def _get_loop() -> asyncio.AbstractEventLoop:
global _loop
if _loop is None:
_loop = asyncio.new_event_loop()
_thread = Thread(target=_loop.run_forever, daemon=True)
_thread.start()
return _loop
:type ip_types: IPTypes
:param ip_types
The IP type (public or private) used to connect. IP types
can be either IPTypes.PUBLIC or IPTypes.PRIVATE.

:type enable_iam_auth: bool
:param enable_iam_auth
Enables IAM based authentication (Postgres only).

def _get_keys(loop: asyncio.AbstractEventLoop) -> concurrent.futures.Future:
global _keys
if _keys is None:
_keys = asyncio.run_coroutine_threadsafe(generate_keys(), loop)
return _keys
:type timeout: int
:param timeout:
The time limit for a connection before raising a TimeoutError.

"""

def connect(
instance_connection_string: str,
driver: str,
ip_types: IPTypes = IPTypes.PUBLIC,
enable_iam_auth: bool = False,
**kwargs: Any
) -> Any:
"""Prepares and returns a database connection object and starts a
background thread to refresh the certificates and metadata.
def __init__(
self,
ip_types: IPTypes = IPTypes.PUBLIC,
enable_iam_auth: bool = False,
timeout: int = 30,
) -> None:
self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop()
self._thread: Thread = Thread(target=self._loop.run_forever, daemon=True)
self._thread.start()
self._keys: concurrent.futures.Future = asyncio.run_coroutine_threadsafe(
generate_keys(), self._loop
)
self._instances: Dict[str, InstanceConnectionManager] = {}

# set default params for connections
self._timeout = timeout
self._enable_iam_auth = enable_iam_auth
self._ip_types = ip_types

def connect(
self, instance_connection_string: str, driver: str, **kwargs: Any
) -> Any:
"""Prepares and returns a database connection object and starts a
background thread to refresh the certificates and metadata.

:type instance_connection_string: str
:param instance_connection_string:
A string containing the GCP project name, region name, and instance
name separated by colons.

Example: example-proj:example-region-us6:example-instance

:type driver: str
:param: driver:
A string representing the driver to connect with. Supported drivers are
pymysql, pg8000, and pytds.

:param kwargs:
Pass in any driver-specific arguments needed to connect to the Cloud
SQL instance.

:rtype: Connection
:returns:
A DB-API connection to the specified Cloud SQL instance.
"""

# Initiate event loop and run in background thread.
#
# Create an InstanceConnectionManager object from the connection string.
# The InstanceConnectionManager should verify arguments.
#
# Use the InstanceConnectionManager to establish an SSL Connection.
#
# Return a DBAPI connection

if instance_connection_string in self._instances:
icm = self._instances[instance_connection_string]
else:
enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth)
icm = InstanceConnectionManager(
instance_connection_string,
driver,
self._keys,
self._loop,
enable_iam_auth,
)
self._instances[instance_connection_string] = icm

ip_types = kwargs.pop("ip_types", self._ip_types)
if "timeout" in kwargs:
return icm.connect(driver, ip_types, **kwargs)
elif "connect_timeout" in kwargs:
timeout = kwargs["connect_timeout"]
else:
timeout = self._timeout
try:
return icm.connect(driver, ip_types, timeout, **kwargs)
except Exception as e:
# with any other exception, we attempt a force refresh, then throw the error
icm.force_refresh()
raise (e)


def connect(instance_connection_string: str, driver: str, **kwargs: Any) -> Any:
"""Uses a Connector object with default settings and returns a database
connection object with a background thread to refresh the certificates and metadata.
For more advanced configurations, callers should instantiate Connector on their own.

:type instance_connection_string: str
:param instance_connection_string:
Expand All @@ -73,14 +148,6 @@ def connect(
A string representing the driver to connect with. Supported drivers are
pymysql, pg8000, and pytds.

:type ip_types: IPTypes
The IP type (public or private) used to connect. IP types
can be either IPTypes.PUBLIC or IPTypes.PRIVATE.

:param enable_iam_auth
Enables IAM based authentication (Postgres only).
:type enable_iam_auth: bool

:param kwargs:
Pass in any driver-specific arguments needed to connect to the Cloud
SQL instance.
Expand All @@ -89,35 +156,7 @@ def connect(
:returns:
A DB-API connection to the specified Cloud SQL instance.
"""

# Initiate event loop and run in background thread.
#
# Create an InstanceConnectionManager object from the connection string.
# The InstanceConnectionManager should verify arguments.
#
# Use the InstanceConnectionManager to establish an SSL Connection.
#
# Return a DBAPI connection

loop = _get_loop()
if instance_connection_string in _instances:
icm = _instances[instance_connection_string]
else:
keys = _get_keys(loop)
icm = InstanceConnectionManager(
instance_connection_string, driver, keys, loop, enable_iam_auth
)
_instances[instance_connection_string] = icm

if "timeout" in kwargs:
return icm.connect(driver, ip_types, **kwargs)
elif "connect_timeout" in kwargs:
timeout = kwargs["connect_timeout"]
else:
timeout = 30 # 30s
try:
return icm.connect(driver, ip_types, timeout, **kwargs)
except Exception as e:
# with any other exception, we attempt a force refresh, then throw the error
icm.force_refresh()
raise (e)
global _default_connector
if _default_connector is None:
_default_connector = Connector()
return _default_connector.connect(instance_connection_string, driver, **kwargs)
4 changes: 3 additions & 1 deletion tests/unit/test_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,9 @@ async def timeout_stub(*args: Any, **kwargs: Any) -> None:

mock_instances = {}
mock_instances[connect_string] = icm
with patch.dict(connector._instances, mock_instances):
mock_connector = connector.Connector()
connector._default_connector = mock_connector
with patch.dict(mock_connector._instances, mock_instances):
pytest.raises(
TimeoutError,
connector.connect,
Expand Down