diff --git a/google/cloud/sql/connector/connector.py b/google/cloud/sql/connector/connector.py index 677372369..06168b292 100755 --- a/google/cloud/sql/connector/connector.py +++ b/google/cloud/sql/connector/connector.py @@ -23,43 +23,118 @@ from google.cloud.sql.connector.utils import generate_keys from threading import Thread -from typing import Any, Dict, Optional +from typing import Any, Dict -# This thread is used to background processing -_thread: Optional[Thread] = None -_loop: Optional[asyncio.AbstractEventLoop] = None -_keys: Optional[concurrent.futures.Future] = None +logger = logging.getLogger(name=__name__) -_instances: Dict[str, InstanceConnectionManager] = {} +_default_connector = None -logger = logging.getLogger(name=__name__) +class Connector: + """A class to configure and create connections to Cloud SQL instances. -def _get_loop() -> asyncio.AbstractEventLoop: - global _loop - if _loop is None: - _loop = asyncio.new_event_loop() - _thread = Thread(target=_loop.run_forever, daemon=True) - _thread.start() - return _loop + :type ip_types: IPTypes + :param ip_types + The IP type (public or private) used to connect. IP types + can be either IPTypes.PUBLIC or IPTypes.PRIVATE. + :type enable_iam_auth: bool + :param enable_iam_auth + Enables IAM based authentication (Postgres only). -def _get_keys(loop: asyncio.AbstractEventLoop) -> concurrent.futures.Future: - global _keys - if _keys is None: - _keys = asyncio.run_coroutine_threadsafe(generate_keys(), loop) - return _keys + :type timeout: int + :param timeout: + The time limit for a connection before raising a TimeoutError. + """ -def connect( - instance_connection_string: str, - driver: str, - ip_types: IPTypes = IPTypes.PUBLIC, - enable_iam_auth: bool = False, - **kwargs: Any -) -> Any: - """Prepares and returns a database connection object and starts a - background thread to refresh the certificates and metadata. + def __init__( + self, + ip_types: IPTypes = IPTypes.PUBLIC, + enable_iam_auth: bool = False, + timeout: int = 30, + ) -> None: + self._loop: asyncio.AbstractEventLoop = asyncio.new_event_loop() + self._thread: Thread = Thread(target=self._loop.run_forever, daemon=True) + self._thread.start() + self._keys: concurrent.futures.Future = asyncio.run_coroutine_threadsafe( + generate_keys(), self._loop + ) + self._instances: Dict[str, InstanceConnectionManager] = {} + + # set default params for connections + self._timeout = timeout + self._enable_iam_auth = enable_iam_auth + self._ip_types = ip_types + + def connect( + self, instance_connection_string: str, driver: str, **kwargs: Any + ) -> Any: + """Prepares and returns a database connection object and starts a + background thread to refresh the certificates and metadata. + + :type instance_connection_string: str + :param instance_connection_string: + A string containing the GCP project name, region name, and instance + name separated by colons. + + Example: example-proj:example-region-us6:example-instance + + :type driver: str + :param: driver: + A string representing the driver to connect with. Supported drivers are + pymysql, pg8000, and pytds. + + :param kwargs: + Pass in any driver-specific arguments needed to connect to the Cloud + SQL instance. + + :rtype: Connection + :returns: + A DB-API connection to the specified Cloud SQL instance. + """ + + # Initiate event loop and run in background thread. + # + # Create an InstanceConnectionManager object from the connection string. + # The InstanceConnectionManager should verify arguments. + # + # Use the InstanceConnectionManager to establish an SSL Connection. + # + # Return a DBAPI connection + + if instance_connection_string in self._instances: + icm = self._instances[instance_connection_string] + else: + enable_iam_auth = kwargs.pop("enable_iam_auth", self._enable_iam_auth) + icm = InstanceConnectionManager( + instance_connection_string, + driver, + self._keys, + self._loop, + enable_iam_auth, + ) + self._instances[instance_connection_string] = icm + + ip_types = kwargs.pop("ip_types", self._ip_types) + if "timeout" in kwargs: + return icm.connect(driver, ip_types, **kwargs) + elif "connect_timeout" in kwargs: + timeout = kwargs["connect_timeout"] + else: + timeout = self._timeout + try: + return icm.connect(driver, ip_types, timeout, **kwargs) + except Exception as e: + # with any other exception, we attempt a force refresh, then throw the error + icm.force_refresh() + raise (e) + + +def connect(instance_connection_string: str, driver: str, **kwargs: Any) -> Any: + """Uses a Connector object with default settings and returns a database + connection object with a background thread to refresh the certificates and metadata. + For more advanced configurations, callers should instantiate Connector on their own. :type instance_connection_string: str :param instance_connection_string: @@ -73,14 +148,6 @@ def connect( A string representing the driver to connect with. Supported drivers are pymysql, pg8000, and pytds. - :type ip_types: IPTypes - The IP type (public or private) used to connect. IP types - can be either IPTypes.PUBLIC or IPTypes.PRIVATE. - - :param enable_iam_auth - Enables IAM based authentication (Postgres only). - :type enable_iam_auth: bool - :param kwargs: Pass in any driver-specific arguments needed to connect to the Cloud SQL instance. @@ -89,35 +156,7 @@ def connect( :returns: A DB-API connection to the specified Cloud SQL instance. """ - - # Initiate event loop and run in background thread. - # - # Create an InstanceConnectionManager object from the connection string. - # The InstanceConnectionManager should verify arguments. - # - # Use the InstanceConnectionManager to establish an SSL Connection. - # - # Return a DBAPI connection - - loop = _get_loop() - if instance_connection_string in _instances: - icm = _instances[instance_connection_string] - else: - keys = _get_keys(loop) - icm = InstanceConnectionManager( - instance_connection_string, driver, keys, loop, enable_iam_auth - ) - _instances[instance_connection_string] = icm - - if "timeout" in kwargs: - return icm.connect(driver, ip_types, **kwargs) - elif "connect_timeout" in kwargs: - timeout = kwargs["connect_timeout"] - else: - timeout = 30 # 30s - try: - return icm.connect(driver, ip_types, timeout, **kwargs) - except Exception as e: - # with any other exception, we attempt a force refresh, then throw the error - icm.force_refresh() - raise (e) + global _default_connector + if _default_connector is None: + _default_connector = Connector() + return _default_connector.connect(instance_connection_string, driver, **kwargs) diff --git a/tests/unit/test_connector.py b/tests/unit/test_connector.py index 3ebbf8068..e85a1ae11 100644 --- a/tests/unit/test_connector.py +++ b/tests/unit/test_connector.py @@ -43,7 +43,9 @@ async def timeout_stub(*args: Any, **kwargs: Any) -> None: mock_instances = {} mock_instances[connect_string] = icm - with patch.dict(connector._instances, mock_instances): + mock_connector = connector.Connector() + connector._default_connector = mock_connector + with patch.dict(mock_connector._instances, mock_instances): pytest.raises( TimeoutError, connector.connect,