diff --git a/aws_advanced_python_wrapper/allowed_and_blocked_hosts.py b/aws_advanced_python_wrapper/allowed_and_blocked_hosts.py new file mode 100644 index 00000000..4668e627 --- /dev/null +++ b/aws_advanced_python_wrapper/allowed_and_blocked_hosts.py @@ -0,0 +1,29 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional, Set + + +class AllowedAndBlockedHosts: + def __init__(self, allowed_host_ids: Optional[Set[str]], blocked_host_ids: Optional[Set[str]]): + self._allowed_host_ids = None if not allowed_host_ids else allowed_host_ids + self._blocked_host_ids = None if not blocked_host_ids else blocked_host_ids + + @property + def allowed_host_ids(self) -> Optional[Set[str]]: + return self._allowed_host_ids + + @property + def blocked_host_ids(self) -> Optional[Set[str]]: + return self._blocked_host_ids diff --git a/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py b/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py index dfb97744..699f0dcf 100644 --- a/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py +++ b/aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py @@ -201,7 +201,7 @@ def _connect(self, host_info: HostInfo, connect_func: Callable): def execute(self, target: object, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any: if self._current_writer is None or self._need_update_current_writer: - 
self._current_writer = self._get_writer(self._plugin_service.hosts) + self._current_writer = self._get_writer(self._plugin_service.all_hosts) self._need_update_current_writer = False try: @@ -209,7 +209,7 @@ def execute(self, target: object, method_name: str, execute_func: Callable, *arg except Exception as e: # Check that e is a FailoverError and that the writer has changed - if isinstance(e, FailoverError) and self._get_writer(self._plugin_service.hosts) != self._current_writer: + if isinstance(e, FailoverError) and self._get_writer(self._plugin_service.all_hosts) != self._current_writer: self._tracker.invalidate_all_connections(host_info=self._current_writer) self._tracker.log_opened_connections() self._need_update_current_writer = True diff --git a/aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py b/aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py index 5c5a7557..0d1aad50 100644 --- a/aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py +++ b/aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py @@ -198,7 +198,7 @@ def _delay(self, delay_ms: int): sleep(delay_ms / 1000) def _get_writer(self) -> Optional[HostInfo]: - for host in self._plugin_service.hosts: + for host in self._plugin_service.all_hosts: if host.role == HostRole.WRITER: return host @@ -225,10 +225,10 @@ def init_host_provider(self, props: Properties, host_list_provider_service: Host init_host_provider_func(props) def _has_no_readers(self) -> bool: - if len(self._plugin_service.hosts) == 0: + if len(self._plugin_service.all_hosts) == 0: return False - for host in self._plugin_service.hosts: + for host in self._plugin_service.all_hosts: if host.role == HostRole.READER: return False diff --git a/aws_advanced_python_wrapper/aws_secrets_manager_plugin.py b/aws_advanced_python_wrapper/aws_secrets_manager_plugin.py index 0ee98add..2f9fef80 100644 --- a/aws_advanced_python_wrapper/aws_secrets_manager_plugin.py +++ 
b/aws_advanced_python_wrapper/aws_secrets_manager_plugin.py @@ -35,6 +35,7 @@ from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.region_utils import RegionUtils from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ TelemetryTraceLevel @@ -63,6 +64,7 @@ def __init__(self, plugin_service: PluginService, props: Properties, session: Op Messages.get_formatted("AwsSecretsManagerPlugin.MissingRequiredConfigParameter", WrapperProperties.SECRETS_MANAGER_SECRET_ID.name)) + self._region_utils = RegionUtils() region: str = self._get_rds_region(secret_id, props) secrets_endpoint = WrapperProperties.SECRETS_MANAGER_ENDPOINT.get(props) @@ -194,23 +196,22 @@ def _apply_secret_to_properties(self, properties: Properties): WrapperProperties.PASSWORD.set(properties, self._secret.password) def _get_rds_region(self, secret_id: str, props: Properties) -> str: - region: Optional[str] = props.get(WrapperProperties.SECRETS_MANAGER_REGION.name) - if not region: - match = search(self._SECRETS_ARN_PATTERN, secret_id) - if match: - region = match.group("region") - else: - raise AwsWrapperError( - Messages.get_formatted("AwsSecretsManagerPlugin.MissingRequiredConfigParameter", - WrapperProperties.SECRETS_MANAGER_REGION.name)) - session = self._session if self._session else boto3.Session() - if region not in session.get_available_regions("rds"): - exception_message = "AwsSdk.UnsupportedRegion" - logger.debug(exception_message, region) - raise AwsWrapperError(Messages.get_formatted(exception_message, region)) + region = self._region_utils.get_region(props, WrapperProperties.SECRETS_MANAGER_REGION.name, session=session) + + if region: + return region - return region + match = search(self._SECRETS_ARN_PATTERN, secret_id) + if match: + region = match.group("region") + + if region: + return self._region_utils.verify_region(region) + else: + raise 
AwsWrapperError( + Messages.get_formatted("AwsSecretsManagerPlugin.MissingRequiredConfigParameter", + WrapperProperties.SECRETS_MANAGER_REGION.name)) class AwsSecretsManagerPluginFactory(PluginFactory): diff --git a/aws_advanced_python_wrapper/custom_endpoint_plugin.py b/aws_advanced_python_wrapper/custom_endpoint_plugin.py new file mode 100644 index 00000000..03b672a6 --- /dev/null +++ b/aws_advanced_python_wrapper/custom_endpoint_plugin.py @@ -0,0 +1,344 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import annotations + +from threading import Event, Thread +from time import perf_counter_ns, sleep +from typing import (TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, + Optional, Set, Union, cast) + +from aws_advanced_python_wrapper.allowed_and_blocked_hosts import \ + AllowedAndBlockedHosts +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.utils.cache_map import CacheMap +from aws_advanced_python_wrapper.utils.messages import Messages +from aws_advanced_python_wrapper.utils.region_utils import RegionUtils + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.driver_dialect import DriverDialect + from aws_advanced_python_wrapper.hostinfo import HostInfo + from aws_advanced_python_wrapper.pep249 import Connection + from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.utils.properties import Properties + +from enum import Enum + +from boto3 import Session + +from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.properties import WrapperProperties +from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils +from aws_advanced_python_wrapper.utils.sliding_expiration_cache import \ + SlidingExpirationCacheWithCleanupThread +from aws_advanced_python_wrapper.utils.telemetry.telemetry import ( + TelemetryCounter, TelemetryFactory) + +logger = Logger(__name__) + + +class CustomEndpointRoleType(Enum): + """ + Enum representing the possible roles of instances specified by a custom endpoint. Note that, currently, it is not + possible to create a WRITER custom endpoint. 
+    """ +    ANY = "ANY" +    READER = "READER" + +    @classmethod +    def from_string(cls, value): +        return CustomEndpointRoleType(value) + + +class CustomEndpointInfo: +    def __init__(self, +                 endpoint_id: str, +                 cluster_id: str, +                 endpoint: str, +                 role_type: CustomEndpointRoleType, +                 static_members: Optional[Set[str]], +                 excluded_members: Optional[Set[str]]): +        self.endpoint_id = endpoint_id +        self.cluster_id = cluster_id +        self.endpoint = endpoint +        self.role_type = role_type +        self.static_members = None if not static_members else static_members +        self.excluded_members = None if not excluded_members else excluded_members + +    @classmethod +    def from_db_cluster_endpoint(cls, endpoint_response_info: Dict[str, Union[str, List[str]]]): +        return CustomEndpointInfo( +            str(endpoint_response_info.get("DBClusterEndpointIdentifier")), +            str(endpoint_response_info.get("DBClusterIdentifier")), +            str(endpoint_response_info.get("Endpoint")), +            CustomEndpointRoleType.from_string(str(endpoint_response_info.get("CustomEndpointType"))), +            set(cast('List[str]', endpoint_response_info.get("StaticMembers"))), +            set(cast('List[str]', endpoint_response_info.get("ExcludedMembers"))) +        ) + +    def __eq__(self, other: object): +        if self is other: +            return True +        if not isinstance(other, CustomEndpointInfo): +            return False + +        return self.endpoint_id == other.endpoint_id \ +            and self.cluster_id == other.cluster_id \ +            and self.endpoint == other.endpoint \ +            and self.role_type == other.role_type \ +            and self.static_members == other.static_members \ +            and self.excluded_members == other.excluded_members + +    def __hash__(self): +        return hash((self.endpoint_id, self.cluster_id, self.endpoint, self.role_type)) + +    def __str__(self): +        return (f"CustomEndpointInfo[endpoint={self.endpoint}, cluster_id={self.cluster_id}, " +                f"role_type={self.role_type}, endpoint_id={self.endpoint_id}, static_members={self.static_members}, " +                f"excluded_members={self.excluded_members}]") + + +class CustomEndpointMonitor: +    """ 
+ A custom endpoint monitor. This class uses a background thread to monitor a given custom endpoint for custom + endpoint information and future changes to the custom endpoint. + """ + _CUSTOM_ENDPOINT_INFO_EXPIRATION_NS: ClassVar[int] = 5 * 60_000_000_000 # 5 minutes + # Keys are custom endpoint URLs, values are information objects for the associated custom endpoint. + _custom_endpoint_info_cache: ClassVar[CacheMap[str, CustomEndpointInfo]] = CacheMap() + + def __init__(self, + plugin_service: PluginService, + custom_endpoint_host_info: HostInfo, + endpoint_id: str, + region: str, + refresh_rate_ns: int, + session: Optional[Session] = None): + self._plugin_service = plugin_service + self._custom_endpoint_host_info = custom_endpoint_host_info + self._endpoint_id = endpoint_id + self._region = region + self._refresh_rate_ns = refresh_rate_ns + self._session = session if session else Session() + self._client = self._session.client('rds', region_name=region) + + self._stop_event = Event() + telemetry_factory = self._plugin_service.get_telemetry_factory() + self._info_changed_counter = telemetry_factory.create_counter("customEndpoint.infoChanged.counter") + + self._thread = Thread(daemon=True, name="CustomEndpointMonitorThread", target=self._run) + self._thread.start() + + def _run(self): + logger.debug("CustomEndpointMonitor.StartingMonitor", self._custom_endpoint_host_info.host) + + try: + while not self._stop_event.is_set(): + try: + start_ns = perf_counter_ns() + + response = self._client.describe_db_cluster_endpoints( + DBClusterEndpointIdentifier=self._endpoint_id, + Filters=[ + { + "Name": "db-cluster-endpoint-type", + "Values": ["custom"] + } + ] + ) + + endpoints = response["DBClusterEndpoints"] + if len(endpoints) != 1: + endpoint_hostnames = [endpoint["Endpoint"] for endpoint in endpoints] + logger.warning( + "CustomEndpointMonitor.UnexpectedNumberOfEndpoints", + self._endpoint_id, + self._region, + len(endpoints), + endpoint_hostnames) + + 
sleep(self._refresh_rate_ns / 1_000_000_000) + continue + + endpoint_info = CustomEndpointInfo.from_db_cluster_endpoint(endpoints[0]) + cached_info = \ + CustomEndpointMonitor._custom_endpoint_info_cache.get(self._custom_endpoint_host_info.host) + if cached_info is not None and cached_info == endpoint_info: + elapsed_time = perf_counter_ns() - start_ns + sleep_duration = max(0, self._refresh_rate_ns - elapsed_time) + sleep(sleep_duration / 1_000_000_000) + continue + + logger.debug( + "CustomEndpointMonitor.DetectedChangeInCustomEndpointInfo", + self._custom_endpoint_host_info.host, endpoint_info) + + # The custom endpoint info has changed, so we need to update the set of allowed/blocked hosts. + hosts = AllowedAndBlockedHosts(endpoint_info.static_members, endpoint_info.excluded_members) + self._plugin_service.allowed_and_blocked_hosts = hosts + CustomEndpointMonitor._custom_endpoint_info_cache.put( + self._custom_endpoint_host_info.host, + endpoint_info, + CustomEndpointMonitor._CUSTOM_ENDPOINT_INFO_EXPIRATION_NS) + self._info_changed_counter.inc() + + elapsed_time = perf_counter_ns() - start_ns + sleep_duration = max(0, self._refresh_rate_ns - elapsed_time) + sleep(sleep_duration / 1_000_000_000) + continue + except InterruptedError as e: + raise e + except Exception as e: + # If the exception is not an InterruptedError, log it and continue monitoring. 
+ logger.error("CustomEndpointMonitor.Exception", self._custom_endpoint_host_info.host, e) + except InterruptedError: + logger.info("CustomEndpointMonitor.Interrupted", self._custom_endpoint_host_info.host) + finally: + CustomEndpointMonitor._custom_endpoint_info_cache.remove(self._custom_endpoint_host_info.host) + self._stop_event.set() + self._client.close() + logger.debug("CustomEndpointMonitor.StoppedMonitor", self._custom_endpoint_host_info.host) + + def has_custom_endpoint_info(self): + return CustomEndpointMonitor._custom_endpoint_info_cache.get(self._custom_endpoint_host_info.host) is not None + + def close(self): + logger.debug("CustomEndpointMonitor.StoppingMonitor", self._custom_endpoint_host_info.host) + CustomEndpointMonitor._custom_endpoint_info_cache.remove(self._custom_endpoint_host_info.host) + self._stop_event.set() + + +class CustomEndpointPlugin(Plugin): + """ + A plugin that analyzes custom endpoints for custom endpoint information and custom endpoint changes, such as adding + or removing an instance in the custom endpoint. 
+    """ +    _SUBSCRIBED_METHODS: ClassVar[Set[str]] = {"connect"} +    _CACHE_CLEANUP_RATE_NS: ClassVar[int] = 6 * 10**10  # 1 minute +    _monitors: ClassVar[SlidingExpirationCacheWithCleanupThread[str, CustomEndpointMonitor]] = \ +        SlidingExpirationCacheWithCleanupThread(_CACHE_CLEANUP_RATE_NS, +                                                should_dispose_func=lambda _: True, +                                                item_disposal_func=lambda monitor: monitor.close()) + +    def __init__(self, plugin_service: PluginService, props: Properties): +        self._plugin_service = plugin_service +        self._props = props + +        self._should_wait_for_info: bool = WrapperProperties.WAIT_FOR_CUSTOM_ENDPOINT_INFO.get_bool(self._props) +        self._wait_for_info_timeout_ms: int = WrapperProperties.WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS.get_int(self._props) +        self._idle_monitor_expiration_ms: int = \ +            WrapperProperties.CUSTOM_ENDPOINT_IDLE_MONITOR_EXPIRATION_MS.get_int(self._props) + +        self._rds_utils = RdsUtils() +        self._region_utils = RegionUtils() +        self._region: Optional[str] = None +        self._custom_endpoint_host_info: Optional[HostInfo] = None +        self._custom_endpoint_id: Optional[str] = None +        telemetry_factory: TelemetryFactory = self._plugin_service.get_telemetry_factory() +        self._wait_for_info_counter: TelemetryCounter = telemetry_factory.create_counter("customEndpoint.waitForInfo.counter") + +        CustomEndpointPlugin._SUBSCRIBED_METHODS.update(self._plugin_service.network_bound_methods) + +    @property +    def subscribed_methods(self) -> Set[str]: +        return CustomEndpointPlugin._SUBSCRIBED_METHODS + +    def connect( +            self, +            target_driver_func: Callable, +            driver_dialect: DriverDialect, +            host_info: HostInfo, +            props: Properties, +            is_initial_connection: bool, +            connect_func: Callable) -> Connection: +        if not self._rds_utils.is_rds_custom_cluster_dns(host_info.host): +            return connect_func() + +        self._custom_endpoint_host_info = host_info +        logger.debug("CustomEndpointPlugin.ConnectionRequestToCustomEndpoint", host_info.host) + +        self._custom_endpoint_id = 
self._rds_utils.get_cluster_id(host_info.host) + if not self._custom_endpoint_id: + raise AwsWrapperError( + Messages.get_formatted( + "CustomEndpointPlugin.ErrorParsingEndpointIdentifier", self._custom_endpoint_host_info.host)) + + hostname = self._custom_endpoint_host_info.host + self._region = self._region_utils.get_region_from_hostname(hostname) + if not self._region: + error_message = "RdsUtils.UnsupportedHostname" + logger.debug(error_message, hostname) + raise AwsWrapperError(Messages.get_formatted(error_message, hostname)) + + monitor = self._create_monitor_if_absent(props) + if self._should_wait_for_info: + self._wait_for_info(monitor) + + return connect_func() + + def _create_monitor_if_absent(self, props: Properties) -> CustomEndpointMonitor: + host_info = cast('HostInfo', self._custom_endpoint_host_info) + endpoint_id = cast('str', self._custom_endpoint_id) + region = cast('str', self._region) + monitor = CustomEndpointPlugin._monitors.compute_if_absent( + host_info.host, + lambda key: CustomEndpointMonitor( + self._plugin_service, + host_info, + endpoint_id, + region, + WrapperProperties.CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS.get_int(props) * 1_000_000), + self._idle_monitor_expiration_ms * 1_000_000) + + return cast('CustomEndpointMonitor', monitor) + + def _wait_for_info(self, monitor: CustomEndpointMonitor): + has_info = monitor.has_custom_endpoint_info() + if has_info: + return + + self._wait_for_info_counter.inc() + host_info = cast('HostInfo', self._custom_endpoint_host_info) + hostname = host_info.host + logger.debug("CustomEndpointPlugin.WaitingForCustomEndpointInfo", hostname, self._wait_for_info_timeout_ms) + wait_for_info_timeout_ns = perf_counter_ns() + self._wait_for_info_timeout_ms * 1_000_000 + + try: + while not has_info and perf_counter_ns() < wait_for_info_timeout_ns: + sleep(0.1) + has_info = monitor.has_custom_endpoint_info() + except InterruptedError: + raise 
AwsWrapperError(Messages.get_formatted("CustomEndpointPlugin.InterruptedThread", hostname)) + + if not has_info: + raise AwsWrapperError( + Messages.get_formatted( + "CustomEndpointPlugin.TimedOutWaitingForCustomEndpointInfo", + self._wait_for_info_timeout_ms, hostname)) + + def execute(self, target: type, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any: + if self._custom_endpoint_host_info is None: + return execute_func() + + monitor = self._create_monitor_if_absent(self._props) + if self._should_wait_for_info: + self._wait_for_info(monitor) + + return execute_func() + + +class CustomEndpointPluginFactory(PluginFactory): + def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + return CustomEndpointPlugin(plugin_service, props) diff --git a/aws_advanced_python_wrapper/failover_plugin.py b/aws_advanced_python_wrapper/failover_plugin.py index 4076397a..497a50bc 100644 --- a/aws_advanced_python_wrapper/failover_plugin.py +++ b/aws_advanced_python_wrapper/failover_plugin.py @@ -25,6 +25,7 @@ from typing import Any, Callable, Dict, Optional, Set +from aws_advanced_python_wrapper import LogUtils from aws_advanced_python_wrapper.errors import ( AwsWrapperError, FailoverFailedError, FailoverSuccessError, TransactionResolutionUnknownError) @@ -328,17 +329,23 @@ def _failover_writer(self): try: logger.info("FailoverPlugin.StartWriterFailover") - result: WriterFailoverResult = self._writer_failover_handler.failover(self._plugin_service.hosts) - + result: WriterFailoverResult = self._writer_failover_handler.failover(self._plugin_service.all_hosts) if result is not None and result.exception is not None: raise result.exception elif result is None or not result.is_connected: raise FailoverFailedError(Messages.get("FailoverPlugin.UnableToConnectToWriter")) writer_host = self._get_writer(result.topology) + allowed_hosts = self._plugin_service.hosts + allowed_hostnames = [host.host for host in allowed_hosts] + if 
writer_host is None or writer_host.host not in allowed_hostnames: +                raise FailoverFailedError( +                    Messages.get_formatted( +                        "FailoverPlugin.NewWriterNotAllowed", +                        "" if writer_host is None else writer_host.host, +                        LogUtils.log_topology(allowed_hosts))) self._plugin_service.set_current_connection(result.new_connection, writer_host) - logger.info("FailoverPlugin.EstablishedConnection", self._plugin_service.current_host_info) self._plugin_service.refresh_host_list() @@ -438,11 +445,11 @@ def _should_attempt_reader_connection(self) -> bool: def _is_failover_enabled(self) -> bool: return self._enable_failover_setting and \ self._rds_url_type != RdsUrlType.RDS_PROXY and \ - self._plugin_service.hosts is not None and \ - len(self._plugin_service.hosts) > 0 + self._plugin_service.all_hosts is not None and \ + len(self._plugin_service.all_hosts) > 0 def _get_current_writer(self) -> Optional[HostInfo]: - topology = self._plugin_service.hosts + topology = self._plugin_service.all_hosts if topology is None: return None diff --git a/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py b/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py index 6df20e25..74bcf1f5 100644 --- a/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py +++ b/aws_advanced_python_wrapper/fastest_response_strategy_plugin.py @@ -111,7 +111,7 @@ def get_host_info_by_strategy(self, role: HostRole, strategy: str) -> HostInfo: fastest_response_host: Optional[HostInfo] = self._cached_fastest_response_host_by_role.get(role.name) if fastest_response_host is not None: - # Found a fastest host. Let's find it in the the latest topology. + # Found a fastest host. Let's find it in the latest topology. 
for host in self._plugin_service.hosts: if host == fastest_response_host: # found the fastest host in the topology diff --git a/aws_advanced_python_wrapper/federated_plugin.py b/aws_advanced_python_wrapper/federated_plugin.py index 8a25c461..4ab0fbf0 100644 --- a/aws_advanced_python_wrapper/federated_plugin.py +++ b/aws_advanced_python_wrapper/federated_plugin.py @@ -22,6 +22,7 @@ from aws_advanced_python_wrapper.credentials_provider_factory import ( CredentialsProviderFactory, SamlCredentialsProviderFactory) from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo +from aws_advanced_python_wrapper.utils.region_utils import RegionUtils from aws_advanced_python_wrapper.utils.saml_utils import SamlUtils if TYPE_CHECKING: @@ -59,6 +60,7 @@ def __init__(self, plugin_service: PluginService, credentials_provider_factory: self._credentials_provider_factory = credentials_provider_factory self._session = session + self._region_utils = RegionUtils() telemetry_factory = self._plugin_service.get_telemetry_factory() self._fetch_token_counter = telemetry_factory.create_counter("federated.fetch_token.count") self._cache_size_gauge = telemetry_factory.create_gauge("federated.token_cache.size", lambda: len(FederatedAuthPlugin._token_cache)) @@ -82,7 +84,11 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl host = IamAuthUtils.get_iam_host(props, host_info) port = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port) - region: str = IamAuthUtils.get_rds_region(self._rds_utils, host, props, self._session) + region = self._region_utils.get_region(props, WrapperProperties.IAM_REGION.name, host, self._session) + if not region: + error_message = "RdsUtils.UnsupportedHostname" + logger.debug(error_message, host) + raise AwsWrapperError(Messages.get_formatted(error_message, host)) user = WrapperProperties.DB_USER.get(props) cache_key: str = IamAuthUtils.get_cache_key( diff --git 
a/aws_advanced_python_wrapper/iam_plugin.py b/aws_advanced_python_wrapper/iam_plugin.py index c4210d8f..1a26c58a 100644 --- a/aws_advanced_python_wrapper/iam_plugin.py +++ b/aws_advanced_python_wrapper/iam_plugin.py @@ -17,6 +17,7 @@ from typing import TYPE_CHECKING from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo +from aws_advanced_python_wrapper.utils.region_utils import RegionUtils if TYPE_CHECKING: from boto3 import Session @@ -51,6 +52,7 @@ def __init__(self, plugin_service: PluginService, session: Optional[Session] = N self._plugin_service = plugin_service self._session = session + self._region_utils = RegionUtils() telemetry_factory = self._plugin_service.get_telemetry_factory() self._fetch_token_counter = telemetry_factory.create_counter("iam.fetch_token.count") self._cache_size_gauge = telemetry_factory.create_gauge( @@ -76,8 +78,12 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl raise AwsWrapperError(Messages.get_formatted("IamAuthPlugin.IsNoneOrEmpty", WrapperProperties.USER.name)) host = IamAuthUtils.get_iam_host(props, host_info) - region = WrapperProperties.IAM_REGION.get(props) \ - if WrapperProperties.IAM_REGION.get(props) else IamAuthUtils.get_rds_region(self._rds_utils, host, props, self._session) + region = self._region_utils.get_region(props, WrapperProperties.IAM_REGION.name, host, self._session) + if not region: + error_message = "RdsUtils.UnsupportedHostname" + logger.debug(error_message, host) + raise AwsWrapperError(Messages.get_formatted(error_message, host)) + port = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port) token_expiration_sec: int = WrapperProperties.IAM_EXPIRATION.get_int(props) diff --git a/aws_advanced_python_wrapper/okta_plugin.py b/aws_advanced_python_wrapper/okta_plugin.py index 7e6b26fd..88b92f13 100644 --- a/aws_advanced_python_wrapper/okta_plugin.py +++ b/aws_advanced_python_wrapper/okta_plugin.py @@ -22,6 +22,7 
@@ from aws_advanced_python_wrapper.credentials_provider_factory import ( CredentialsProviderFactory, SamlCredentialsProviderFactory) from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo +from aws_advanced_python_wrapper.utils.region_utils import RegionUtils from aws_advanced_python_wrapper.utils.saml_utils import SamlUtils if TYPE_CHECKING: @@ -55,6 +56,7 @@ def __init__(self, plugin_service: PluginService, credentials_provider_factory: self._credentials_provider_factory = credentials_provider_factory self._session = session + self._region_utils = RegionUtils() telemetry_factory = self._plugin_service.get_telemetry_factory() self._fetch_token_counter = telemetry_factory.create_counter("okta.fetch_token.count") self._cache_size_gauge = telemetry_factory.create_gauge("okta.token_cache.size", lambda: len(OktaAuthPlugin._token_cache)) @@ -78,7 +80,11 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl host = IamAuthUtils.get_iam_host(props, host_info) port = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port) - region: str = IamAuthUtils.get_rds_region(self._rds_utils, host, props, self._session) + region = self._region_utils.get_region(props, WrapperProperties.IAM_REGION.name, host, self._session) + if not region: + error_message = "RdsUtils.UnsupportedHostname" + logger.debug(error_message, host) + raise AwsWrapperError(Messages.get_formatted(error_message, host)) user = WrapperProperties.DB_USER.get(props) cache_key: str = IamAuthUtils.get_cache_key( diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index 69a458c7..3c108284 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -18,6 +18,8 @@ from aws_advanced_python_wrapper.aurora_initial_connection_strategy_plugin import \ AuroraInitialConnectionStrategyPluginFactory +from 
aws_advanced_python_wrapper.custom_endpoint_plugin import \ + CustomEndpointPluginFactory from aws_advanced_python_wrapper.fastest_response_strategy_plugin import \ FastestResponseStrategyPluginFactory from aws_advanced_python_wrapper.federated_plugin import \ @@ -27,6 +29,7 @@ SessionStateService, SessionStateServiceImpl) if TYPE_CHECKING: + from aws_advanced_python_wrapper.allowed_and_blocked_hosts import AllowedAndBlockedHosts from aws_advanced_python_wrapper.driver_dialect import DriverDialect from aws_advanced_python_wrapper.driver_dialect_manager import DriverDialectManager from aws_advanced_python_wrapper.pep249 import Connection @@ -109,11 +112,25 @@ def plugin_manager(self, value): class PluginService(ExceptionHandler, Protocol): + @property + @abstractmethod + def all_hosts(self) -> Tuple[HostInfo, ...]: + ... + @property @abstractmethod def hosts(self) -> Tuple[HostInfo, ...]: ... + @property + @abstractmethod + def allowed_and_blocked_hosts(self) -> Optional[AllowedAndBlockedHosts]: + ... + + @allowed_and_blocked_hosts.setter + def allowed_and_blocked_hosts(self, allowed_and_blocked_hosts: Optional[AllowedAndBlockedHosts]): + ... + @property @abstractmethod def current_connection(self) -> Optional[Connection]: @@ -279,7 +296,8 @@ def __init__( self._original_url = PropertiesUtils.get_url(props) self._host_list_provider: HostListProvider = ConnectionStringHostListProvider(self, props) - self._hosts: Tuple[HostInfo, ...] = () + self._all_hosts: Tuple[HostInfo, ...] 
= () + self._allowed_and_blocked_hosts: Optional[AllowedAndBlockedHosts] = None self._current_connection: Optional[Connection] = None self._current_host_info: Optional[HostInfo] = None self._initial_connection_host_info: Optional[HostInfo] = None @@ -292,13 +310,35 @@ def __init__( self._database_dialect = self._dialect_provider.get_dialect(driver_dialect.dialect_code, props) self._session_state_service = session_state_service if session_state_service is not None else SessionStateServiceImpl(self, props) + @property + def all_hosts(self) -> Tuple[HostInfo, ...]: + return self._all_hosts + @property def hosts(self) -> Tuple[HostInfo, ...]: - return self._hosts + host_permissions = self.allowed_and_blocked_hosts + if host_permissions is None: + return self._all_hosts + + hosts = self._all_hosts + allowed_ids = host_permissions.allowed_host_ids + blocked_ids = host_permissions.blocked_host_ids + + if allowed_ids is not None: + hosts = tuple(host for host in hosts if host.host_id in allowed_ids) + + if blocked_ids is not None: + hosts = tuple(host for host in hosts if host.host_id not in blocked_ids) + + return hosts + + @property + def allowed_and_blocked_hosts(self) -> Optional[AllowedAndBlockedHosts]: + return self._allowed_and_blocked_hosts - @hosts.setter - def hosts(self, new_hosts: Tuple[HostInfo, ...]): - self._hosts = new_hosts + @allowed_and_blocked_hosts.setter + def allowed_and_blocked_hosts(self, allowed_and_blocked_hosts: Optional[AllowedAndBlockedHosts]): + self._allowed_and_blocked_hosts = allowed_and_blocked_hosts @property def current_connection(self) -> Optional[Connection]: @@ -453,14 +493,14 @@ def get_host_role(self, connection: Optional[Connection] = None) -> HostRole: def refresh_host_list(self, connection: Optional[Connection] = None): connection = self.current_connection if connection is None else connection updated_host_list: Tuple[HostInfo, ...] 
= self.host_list_provider.refresh(connection) - if updated_host_list != self.hosts: + if updated_host_list != self._all_hosts: self._update_host_availability(updated_host_list) self._update_hosts(updated_host_list) def force_refresh_host_list(self, connection: Optional[Connection] = None): connection = self.current_connection if connection is None else connection updated_host_list: Tuple[HostInfo, ...] = self.host_list_provider.force_refresh(connection) - if updated_host_list != self.hosts: + if updated_host_list != self._all_hosts: self._update_host_availability(updated_host_list) self._update_hosts(updated_host_list) @@ -546,12 +586,12 @@ def _update_host_availability(self, hosts: Tuple[HostInfo, ...]): host.set_availability(availability) def _update_hosts(self, new_hosts: Tuple[HostInfo, ...]): - old_hosts_dict = {x.url: x for x in self.hosts} + old_hosts_dict = {x.url: x for x in self._all_hosts} new_hosts_dict = {x.url: x for x in new_hosts} changes: Dict[str, Set[HostEvent]] = {} - for host in self.hosts: + for host in self._all_hosts: corresponding_new_host = new_hosts_dict.get(host.url) if corresponding_new_host is None: changes[host.url] = {HostEvent.HOST_DELETED} @@ -565,7 +605,7 @@ def _update_hosts(self, new_hosts: Tuple[HostInfo, ...]): changes[key] = {HostEvent.HOST_ADDED} if len(changes) > 0: - self.hosts = tuple(new_hosts) if new_hosts is not None else () + self._all_hosts = tuple(new_hosts) if new_hosts is not None else () self._container.plugin_manager.notify_host_list_changed(changes) def _compare(self, host_a: HostInfo, host_b: HostInfo) -> Set[HostEvent]: @@ -622,6 +662,7 @@ class PluginManager(CanReleaseResources): "read_write_splitting": ReadWriteSplittingPluginFactory, "fastest_response_strategy": FastestResponseStrategyPluginFactory, "stale_dns": StaleDnsPluginFactory, + "custom_endpoint": CustomEndpointPluginFactory, "connect_time": ConnectTimePluginFactory, "execute_time": ExecuteTimePluginFactory, "dev": DeveloperPluginFactory, @@ -636,6 
+677,7 @@ class PluginManager(CanReleaseResources): # the highest values. The first plugin of the list will have the lowest weight, and the # last one will have the highest weight. PLUGIN_FACTORY_WEIGHTS: Dict[Type[PluginFactory], int] = { + CustomEndpointPluginFactory: 40, AuroraInitialConnectionStrategyPluginFactory: 50, AuroraConnectionTrackerPluginFactory: 100, StaleDnsPluginFactory: 200, diff --git a/aws_advanced_python_wrapper/read_write_splitting_plugin.py b/aws_advanced_python_wrapper/read_write_splitting_plugin.py index 60e7c3c8..7b00255f 100644 --- a/aws_advanced_python_wrapper/read_write_splitting_plugin.py +++ b/aws_advanced_python_wrapper/read_write_splitting_plugin.py @@ -263,6 +263,11 @@ def _switch_to_reader_connection(self, hosts: Tuple[HostInfo, ...]): self._is_connection_usable(current_conn, driver_dialect)): return + hostnames = [host_info.host for host_info in hosts] + if self._reader_host_info is not None and self._reader_host_info.host not in hostnames: + # The old reader cannot be used anymore because it is no longer in the list of allowed hosts. 
+ self._close_connection_if_idle(self._reader_connection) + self._in_read_write_split = True if not self._is_connection_usable(self._reader_connection, driver_dialect): self._initialize_reader_connection(hosts) diff --git a/aws_advanced_python_wrapper/reader_failover_handler.py b/aws_advanced_python_wrapper/reader_failover_handler.py index 4b9be223..bd10f0f6 100644 --- a/aws_advanced_python_wrapper/reader_failover_handler.py +++ b/aws_advanced_python_wrapper/reader_failover_handler.py @@ -126,7 +126,7 @@ def _internal_failover_task( self._plugin_service.force_refresh_host_list(result.connection) if result.new_host is not None: - topology = self._plugin_service.hosts + topology = self._plugin_service.all_hosts for host in topology: # found new connection host in the latest topology if host.url == result.new_host.url and host.role == HostRole.READER: diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index 5fbca7b5..3c0b341e 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -48,6 +48,20 @@ ConnectionProvider.UnsupportedHostSelectorStrategy=[ConnectionProvider] Unsuppor ConnectionStringHostListProvider.UnsupportedMethod = [ConnectionStringHostListProvider] ConnectionStringHostListProvider does not support {}. +CustomEndpointMonitor.DetectedChangeInCustomEndpointInfo=[CustomEndpointMonitor] Detected change in custom endpoint info for '{}':\n{} +CustomEndpointMonitor.Exception=[CustomEndpointMonitor] Encountered an exception while monitoring custom endpoint '{}': {}. +CustomEndpointMonitor.Interrupted=[CustomEndpointMonitor] Custom endpoint monitor for '{}' was interrupted. +CustomEndpointMonitor.StartingMonitor=[CustomEndpointMonitor] Starting custom endpoint monitor for '{}'. 
+CustomEndpointMonitor.StoppedMonitor=[CustomEndpointMonitor] Stopped custom endpoint monitor for '{}'. +CustomEndpointMonitor.StoppingMonitor=[CustomEndpointMonitor] Stopping custom endpoint monitor for '{}'. +CustomEndpointMonitor.UnexpectedNumberOfEndpoints=[CustomEndpointMonitor] Unexpected number of custom endpoints with endpoint identifier '{}' in region '{}'. Expected 1, but found {}. Endpoints:\n{}. + +CustomEndpointPlugin.TimedOutWaitingForCustomEndpointInfo=[CustomEndpointPlugin] The custom endpoint plugin timed out after {}ms while waiting for custom endpoint info for host '{}'. +CustomEndpointPlugin.ConnectionRequestToCustomEndpoint=[CustomEndpointPlugin] Detected a connection request to a custom endpoint URL: '{}'. +CustomEndpointPlugin.ErrorParsingEndpointIdentifier=[CustomEndpointPlugin] Unable to parse custom endpoint identifier from URL: '{}'. +CustomEndpointPlugin.InterruptedThread=[CustomEndpointPlugin] The custom endpoint plugin was interrupted while waiting for custom endpoint info for host '{}'. +CustomEndpointPlugin.WaitingForCustomEndpointInfo=[CustomEndpointPlugin] Custom endpoint info for '{}' was not found. Waiting {}ms for the endpoint monitor to fetch info... + DefaultPlugin.EmptyHosts=[DefaultPlugin] The default connection plugin received an empty host list from the plugin service. DefaultPlugin.UnknownHosts=[DefaultPlugin] A HostInfo with the role of HostRole.UNKNOWN was requested via get_host_info_by_strategy. The requested role must be either HostRole.WRITER or HostRole.READER. @@ -88,6 +102,7 @@ FailoverPlugin.DetectedException=[Failover] Detected an exception while executin FailoverPlugin.EstablishedConnection=[Failover] Connected to: {} FailoverPlugin.FailoverDisabled=[Failover] Cluster-aware failover is disabled. 
FailoverPlugin.InvalidHost=[Failover] Host is no longer available in the topology: {} +FailoverPlugin.NewWriterNotAllowed=[Failover] The failover process identified the new writer but the host is not in the list of allowed hosts. New writer host: '{}'. Allowed hosts {} FailoverPlugin.NoOperationsAfterConnectionClosed=[Failover] No operations allowed after connection closed. FailoverPlugin.ParameterValue=[Failover] {}={} FailoverPlugin.StartReaderFailover=[Failover] Starting reader failover procedure. @@ -195,7 +210,7 @@ PluginServiceImpl.UpdateDialectConnectionNone=[PluginServiceImpl] The plugin ser PropertiesUtils.ErrorParsingConnectionString=[PropertiesUtils] An error occurred while parsing the connection string: '{}'. Please ensure the format of your connection string is valid. PropertiesUtils.InvalidPgSchemeUrl=[PropertiesUtils] PropertiesUtils.parse_pg_scheme_url was called, but the passed in string did not begin with 'postgresql://' or 'postgres://'. Detected connection string: '{}'. -PropertiesUtils.MultipleHostsNotSupported=[PropertiesUtils] Connection strings containing multiple hosts are not supported by the wrapper driver. If you are using an Aurora database, please specify only the initial instance that you would like to connect to. The cluster topology will be automatically discovered. Detected connection string: `{}` +PropertiesUtils.MultipleHostsNotSupported=[PropertiesUtils] Connection strings containing multiple hosts are not supported by the wrapper driver. If you are using an Aurora database, please specify only the initial instance that you would like to connect to. The cluster topology will be automatically discovered. Detected connection string: '{}' PropertiesUtils.NoHostDefined=[PropertiesUtils] PropertiesUtils.get_url was called but no host was defined in the properties. Please ensure you pass in a 'host' parameter when connecting. 
RdsHostListProvider.ClusterInstanceHostPatternNotSupportedForRDSCustom=[RdsHostListProvider] An RDS Custom url can't be used as the 'cluster_instance_host_pattern' configuration setting. @@ -220,8 +235,10 @@ RdsHostListProvider.UninitializedInitialHostInfo=[RdsHostListProvider] The drive RdsPgDialect.RdsToolsAuroraUtils=[RdsPgDialect] rds_tools: {}, aurora_utils: {} RdsTestUtility.ClusterMemberNotFound=[RdsTestUtility] Cannot find cluster member whose db instance identifier is '{}'. -RdsTestUtility.CreateDBInstanceFailed=[RdsTestUtility] Could not create database instance `{}`. +RdsTestUtility.CreateDBInstanceFailed=[RdsTestUtility] Could not create database instance '{}'. +RdsTestUtility.FailoverClusterFailed=[RdsTestUtility] Failed to request a cluster failover for cluster '{}'. RdsTestUtility.FailoverRequestNotSuccessful=[RdsTestUtility] Failover cluster request was not successful. +RdsTestUtility.FailoverToTargetNotSupported=[RdsTestUtility] Failover to target instance '{}' was requested, but failover to a target is not supported for {} deployments. RdsTestUtility.InstanceDescriptionTimeout=[RdsTestUtility] Instance description timeout for {}. The instance did not reach status '{}' within {} minutes. RdsTestUtility.InvalidDatabaseEngine=[RdsTestUtility] The detected database engine is not valid: {} RdsTestUtility.MethodNotSupportedForDeployment=[RdsTestUtility] Method '{}' is not supported for the current database engine deployment: '{}' @@ -247,7 +264,7 @@ ReadWriteSplittingPlugin.FailoverExceptionWhileExecutingCommand=[ReadWriteSplitt ReadWriteSplittingPlugin.FallbackToWriter=[ReadWriteSplittingPlugin] Failed to switch to a reader; the current writer will be used as a fallback: '{}' ReadWriteSplittingPlugin.NoReadersAvailable=[ReadWriteSplittingPlugin] The plugin was unable to establish a reader connection to any reader instance. 
ReadWriteSplittingPlugin.NoReadersFound=[ReadWriteSplittingPlugin] A reader instance was requested via set_read_only, but there are no readers in the host list. The current writer will be used as a fallback: '{}' -ReadWriteSplittingPlugin.NoWriterFound=[ReadWriteSplittingPlugin] No writer was found in the current host list. +ReadWriteSplittingPlugin.NoWriterFound=[ReadWriteSplittingPlugin] No writer was found in the current host list. This may occur if the writer is not in the list of allowed hosts. ReadWriteSplittingPlugin.SetReaderConnection=[ReadWriteSplittingPlugin] Reader connection set to '{}' ReadWriteSplittingPlugin.SetReadOnlyFalseInTransaction=[ReadWriteSplittingPlugin] set_read_only(false) was called on a read-only connection inside a transaction. Please complete the transaction before calling set_read_only(false). ReadWriteSplittingPlugin.SetReadOnlyOnClosedConnection=[ReadWriteSplittingPlugin] set_read_only cannot be called on a closed connection. @@ -271,6 +288,7 @@ SqlAlchemyPooledConnectionProvider.UnableToCreateDefaultKey=[SqlAlchemyPooledCon SqlAlchemyDriverDialect.SetValueOnNoneConnection=[SqlAlchemyDriverDialect] Attempted to set the '{}' value on a pooled connection, but no underlying driver connection was found. This can happen if the pooled connection has previously been closed. StaleDnsHelper.ClusterEndpointDns=[StaleDnsPlugin] Cluster endpoint {} resolves to {}. +StaleDnsHelper.CurrentWriterNotAllowed=[StaleDnsPlugin] The current writer is not in the list of allowed hosts. Current host: '{}'. Allowed hosts: {} StaleDnsHelper.Reset=[StaleDnsPlugin] Reset stored writer host. StaleDnsHelper.StaleDnsDetected=[StaleDnsPlugin] Stale DNS data detected. Opening a connection to '{}'. 
StaleDnsHelper.WriterHostSpec=[StaleDnsPlugin] Writer host: {} diff --git a/aws_advanced_python_wrapper/stale_dns_plugin.py b/aws_advanced_python_wrapper/stale_dns_plugin.py index 785aeb90..6a3e815e 100644 --- a/aws_advanced_python_wrapper/stale_dns_plugin.py +++ b/aws_advanced_python_wrapper/stale_dns_plugin.py @@ -25,6 +25,7 @@ from aws_advanced_python_wrapper.plugin_service import PluginService from aws_advanced_python_wrapper.utils.properties import Properties +from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.hostinfo import HostRole from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory from aws_advanced_python_wrapper.utils.log import Logger @@ -82,7 +83,7 @@ def get_verified_connection(self, is_initial_connection: bool, host_list_provide else: self._plugin_service.refresh_host_list(conn) - logger.debug("LogUtils.Topology", LogUtils.log_topology(self._plugin_service.hosts)) + logger.debug("LogUtils.Topology", LogUtils.log_topology(self._plugin_service.all_hosts)) if self._writer_host_info is None: writer_candidate: Optional[HostInfo] = self._get_writer() @@ -110,6 +111,15 @@ def get_verified_connection(self, is_initial_connection: bool, host_list_provide if self._writer_host_address != cluster_inet_address: logger.debug("StaleDnsHelper.StaleDnsDetected", self._writer_host_info) + allowed_hosts = self._plugin_service.hosts + allowed_hostnames = [host.host for host in allowed_hosts] + if self._writer_host_info.host not in allowed_hostnames: + raise AwsWrapperError( + Messages.get_formatted( + "StaleDnsHelper.CurrentWriterNotAllowed", + "" if self._writer_host_info is None else self._writer_host_info.host, + LogUtils.log_topology(allowed_hosts))) + writer_conn: Connection = self._plugin_service.connect(self._writer_host_info, props) if is_initial_connection: host_list_provider_service.initial_connection_host_info = self._writer_host_info @@ -134,7 +144,7 @@ def notify_host_list_changed(self, changes: 
Dict[str, Set[HostEvent]]) -> None: self._writer_host_address = None def _get_writer(self) -> Optional[HostInfo]: - for host in self._plugin_service.hosts: + for host in self._plugin_service.all_hosts: if host.role == HostRole.WRITER: return host return None diff --git a/aws_advanced_python_wrapper/utils/iam_utils.py b/aws_advanced_python_wrapper/utils/iam_utils.py index ba5a6924..ecb5868f 100644 --- a/aws_advanced_python_wrapper/utils/iam_utils.py +++ b/aws_advanced_python_wrapper/utils/iam_utils.py @@ -70,25 +70,6 @@ def get_port(props: Properties, host_info: HostInfo, dialect_default_port: int) def get_cache_key(user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str: return f"{region}:{hostname}:{port}:{user}" - @staticmethod - def get_rds_region(rds_utils: RdsUtils, hostname: Optional[str], props: Properties, client_session: Optional[Session] = None) -> str: - rds_region = WrapperProperties.IAM_REGION.get(props) - if rds_region is None or rds_region == "": - rds_region = rds_utils.get_rds_region(hostname) - - if not rds_region: - error_message = "RdsUtils.UnsupportedHostname" - logger.debug(error_message, hostname) - raise AwsWrapperError(Messages.get_formatted(error_message, hostname)) - - session = client_session if client_session else boto3.Session() - if rds_region not in session.get_available_regions("rds"): - error_message = "AwsSdk.UnsupportedRegion" - logger.debug(error_message, rds_region) - raise AwsWrapperError(Messages.get_formatted(error_message, rds_region)) - - return rds_region - @staticmethod def generate_authentication_token( plugin_service: PluginService, diff --git a/aws_advanced_python_wrapper/utils/properties.py b/aws_advanced_python_wrapper/utils/properties.py index 93825d56..93712659 100644 --- a/aws_advanced_python_wrapper/utils/properties.py +++ b/aws_advanced_python_wrapper/utils/properties.py @@ -195,6 +195,28 @@ class WrapperProperties: "Reader connection attempt timeout in seconds during a reader 
failover process.", 30) + # CustomEndpointPlugin + CUSTOM_ENDPOINT_INFO_REFRESH_RATE_MS = WrapperProperty( + "custom_endpoint_info_refresh_rate_ms", + "Controls how frequently custom endpoint monitors fetch custom endpoint info, in milliseconds.", + 30_000) + CUSTOM_ENDPOINT_IDLE_MONITOR_EXPIRATION_MS = WrapperProperty( + "custom_endpoint_idle_monitor_expiration_ms", + "Controls how long a monitor should run without use before expiring and being removed, in milliseconds.", + 900_000) # 15 minutes + WAIT_FOR_CUSTOM_ENDPOINT_INFO = WrapperProperty( + "wait_for_custom_endpoint_info", + """Controls whether to wait for custom endpoint info to become available before connecting or executing a + method. Waiting is only necessary if a connection to a given custom endpoint has not been opened or used + recently. Note that disabling this may result in occasional connections to instances outside of the custom + endpoint.""", + True) + WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS = WrapperProperty( + "wait_for_custom_endpoint_info_timeout_ms", + """Controls the maximum amount of time that the plugin will wait for custom endpoint info to be made + available by the custom endpoint monitor, in milliseconds.""", + 5_000) + # Host Availability Strategy DEFAULT_HOST_AVAILABILITY_STRATEGY = WrapperProperty( "default_host_availability_strategy", diff --git a/aws_advanced_python_wrapper/utils/rdsutils.py b/aws_advanced_python_wrapper/utils/rdsutils.py index 60d340ec..d1f0f812 100644 --- a/aws_advanced_python_wrapper/utils/rdsutils.py +++ b/aws_advanced_python_wrapper/utils/rdsutils.py @@ -195,6 +195,15 @@ def get_rds_cluster_host_url(self, host: str): return None + def get_cluster_id(self, host: str) -> Optional[str]: + if host is None or not host.strip(): + return None + + if self._get_dns_group(host) is not None: + return self._get_group(host, self.INSTANCE_GROUP) + + return None + def get_instance_id(self, host: str) -> Optional[str]: if self._get_dns_group(host) is None: return 
self._get_group(host, self.INSTANCE_GROUP) diff --git a/aws_advanced_python_wrapper/utils/region_utils.py b/aws_advanced_python_wrapper/utils/region_utils.py new file mode 100644 index 00000000..7a0675be --- /dev/null +++ b/aws_advanced_python_wrapper/utils/region_utils.py @@ -0,0 +1,58 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Optional + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.utils.properties import Properties + +from boto3 import Session + +from aws_advanced_python_wrapper.errors import AwsWrapperError +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.messages import Messages +from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils + +logger = Logger(__name__) + + +class RegionUtils: + def __init__(self): + self._rds_utils = RdsUtils() + + def get_region(self, + props: Properties, + prop_key: str, + hostname: Optional[str] = None, + session: Optional[Session] = None) -> Optional[str]: + region = props.get(prop_key) + if region: + return self.verify_region(region, session) + + return self.get_region_from_hostname(hostname, session) + + def get_region_from_hostname(self, hostname: Optional[str], session: Optional[Session] = None) -> Optional[str]: + region = self._rds_utils.get_rds_region(hostname) + return self.verify_region(region, session) if 
region else None + + def verify_region(self, region: str, session: Optional[Session] = None) -> str: + session = session if session is not None else Session() + if region not in session.get_available_regions("rds"): + error_message = "AwsSdk.UnsupportedRegion" + logger.debug(error_message, region) + raise AwsWrapperError(Messages.get_formatted(error_message, region)) + + return region diff --git a/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py b/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py index f109ff8d..4e5e97f0 100644 --- a/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py +++ b/aws_advanced_python_wrapper/utils/sliding_expiration_cache.py @@ -16,8 +16,7 @@ from concurrent.futures import Executor, ThreadPoolExecutor from time import perf_counter_ns, sleep -from typing import (Callable, ClassVar, Generic, ItemsView, KeysView, Optional, - TypeVar) +from typing import Callable, Generic, ItemsView, KeysView, Optional, TypeVar from aws_advanced_python_wrapper.utils.atomic import AtomicInt from aws_advanced_python_wrapper.utils.concurrent import ConcurrentDict @@ -74,9 +73,22 @@ def _remove_and_dispose(self, key: K): self._item_disposal_func(cache_item.item) def _remove_if_expired(self, key: K): - cache_item = self._cdict.get(key) - if cache_item is None or self._should_cleanup_item(cache_item): - self._remove_and_dispose(key) + item = None + + def _remove_if_expired_internal(_, cache_item): + if self._should_cleanup_item(cache_item): + nonlocal item + item = cache_item.item + return None + + return cache_item + + self._cdict.compute_if_present(key, _remove_if_expired_internal) + + if item is None or self._item_disposal_func is None: + return + + self._item_disposal_func(item) def _should_cleanup_item(self, cache_item: CacheItem) -> bool: if self._should_dispose_func is not None: @@ -101,19 +113,17 @@ def _cleanup(self): class SlidingExpirationCacheWithCleanupThread(SlidingExpirationCache, Generic[K, V]): - - _executor: 
ClassVar[Executor] = ThreadPoolExecutor(thread_name_prefix="SlidingExpirationCacheWithCleanupThreadExecutor") - def __init__( self, cleanup_interval_ns: int = 10 * 60_000_000_000, # 10 minutes should_dispose_func: Optional[Callable] = None, item_disposal_func: Optional[Callable] = None): super().__init__(cleanup_interval_ns, should_dispose_func, item_disposal_func) + self._executor: Executor = ThreadPoolExecutor(thread_name_prefix="SlidingExpirationCacheWithCleanupThreadExecutor") self.init_cleanup_thread() def init_cleanup_thread(self) -> None: - SlidingExpirationCacheWithCleanupThread._executor.submit(self._cleanup_thread_internal) + self._executor.submit(self._cleanup_thread_internal) def _cleanup_thread_internal(self): logger.debug("SlidingExpirationCache.CleaningUp") @@ -127,7 +137,7 @@ def _cleanup_thread_internal(self): except Exception: pass # ignore - SlidingExpirationCacheWithCleanupThread._executor.shutdown() + self._executor.shutdown() def _cleanup(self): pass # do nothing, cleanup thread does the job diff --git a/aws_advanced_python_wrapper/writer_failover_handler.py b/aws_advanced_python_wrapper/writer_failover_handler.py index 600c24d2..db1fe24d 100644 --- a/aws_advanced_python_wrapper/writer_failover_handler.py +++ b/aws_advanced_python_wrapper/writer_failover_handler.py @@ -174,7 +174,7 @@ def reconnect_to_writer(self, initial_writer_host: HostInfo): conn = self._plugin_service.force_connect(initial_writer_host, self._initial_connection_properties, self._timeout_event) self._plugin_service.force_refresh_host_list(conn) - latest_topology = self._plugin_service.hosts + latest_topology = self._plugin_service.all_hosts except Exception as ex: if not self._plugin_service.is_network_exception(ex): @@ -267,7 +267,7 @@ def refresh_topology_and_connect_to_new_writer(self, initial_writer_host: HostIn while not self._timeout_event.is_set(): try: self._plugin_service.force_refresh_host_list(self._current_reader_connection) - current_topology: Tuple[HostInfo, 
...] = self._plugin_service.hosts + current_topology: Tuple[HostInfo, ...] = self._plugin_service.all_hosts if len(current_topology) > 0: if len(current_topology) == 1: diff --git a/docs/development-guide/LoadablePlugins.md b/docs/development-guide/LoadablePlugins.md index a4dd3550..21181d70 100644 --- a/docs/development-guide/LoadablePlugins.md +++ b/docs/development-guide/LoadablePlugins.md @@ -171,12 +171,12 @@ class GoodPlugin(Plugin): def subscribed_methods(self) -> Set[str]: return {"*"} - + def execute(self, target: type, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any: if len(self._plugin_service.hosts) == 0: # Re-fetch host info if it is empty. self._plugin_service.force_refresh_host_list() - + return execute_func() def connect( diff --git a/docs/using-the-python-driver/using-plugins/UsingTheCustomEndpointPlugin.md b/docs/using-the-python-driver/using-plugins/UsingTheCustomEndpointPlugin.md new file mode 100644 index 00000000..2b656e73 --- /dev/null +++ b/docs/using-the-python-driver/using-plugins/UsingTheCustomEndpointPlugin.md @@ -0,0 +1,25 @@ +# Custom Endpoint Plugin + +The Custom Endpoint Plugin adds support for RDS custom endpoints. When the Custom Endpoint Plugin is in use, the driver will analyse custom endpoint information to ensure instances used in connections are part of the custom endpoint being used. This includes connections used in failover and read-write splitting. + +## Prerequisites +- This plugin requires the AWS SDK for Python, [Boto3](https://pypi.org/project/boto3/). Boto3 is a runtime dependency and must be resolved. It can be installed via pip like so: `pip install boto3`. + +## How to use the Custom Endpoint Plugin with the AWS Advanced Python Driver + +### Enabling the Custom Endpoint Plugin + +1. 
If needed, create a custom endpoint using the AWS RDS Console: + - If needed, review the documentation about [creating a custom endpoint](https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/aurora-custom-endpoint-creating.html). +2. Add the plugin code `custom_endpoint` to the [`plugins`](../UsingThePythonDriver.md#connection-plugin-manager-parameters) parameter value, or to the current [driver profile](../UsingThePythonDriver.md#connection-plugin-manager-parameters). +3. If you are using the failover plugin, set the failover parameter `failover_mode` according to the custom endpoint type. For example, if the custom endpoint you are using is of type `READER`, you can set `failover_mode` to `strict_reader`, or if it is of type `ANY`, you can set `failover_mode` to `reader_or_writer`. +4. Specify parameters that are required or specific to your case. + +### Custom Endpoint Plugin Parameters + +| Parameter | Value | Required | Description | Default Value | Example Value | +|----------------------------------------------|:-------:|:--------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------|---------------| +| `custom_endpoint_info_refresh_rate_ms` | Integer | No | Controls how frequently custom endpoint monitors fetch custom endpoint info, in milliseconds. | `30000` | `20000` | +| `custom_endpoint_idle_monitor_expiration_ms` | Integer | No | Controls how long a monitor should run without use before expiring and being removed, in milliseconds. | `900000` (15 minutes) | `600000` | +| `wait_for_custom_endpoint_info` | Boolean | No | Controls whether to wait for custom endpoint info to become available before connecting or executing a method. 
Waiting is only necessary if a connection to a given custom endpoint has not been opened or used recently. Note that disabling this may result in occasional connections to instances outside of the custom endpoint. | `true` | `true` | +| `wait_for_custom_endpoint_info_timeout_ms` | Integer | No | Controls the maximum amount of time that the plugin will wait for custom endpoint info to be made available by the custom endpoint monitor, in milliseconds. | `5000` | `7000` | diff --git a/tests/integration/container/conftest.py b/tests/integration/container/conftest.py index 3a8358ae..eefabcaf 100644 --- a/tests/integration/container/conftest.py +++ b/tests/integration/container/conftest.py @@ -21,6 +21,8 @@ from aws_advanced_python_wrapper.connection_provider import \ ConnectionProviderManager +from aws_advanced_python_wrapper.custom_endpoint_plugin import ( + CustomEndpointMonitor, CustomEndpointPlugin) from aws_advanced_python_wrapper.database_dialect import DatabaseDialectManager from aws_advanced_python_wrapper.driver_dialect_manager import \ DriverDialectManager @@ -131,6 +133,8 @@ def pytest_runtest_setup(item): RdsHostListProvider._cluster_ids_to_update.clear() PluginServiceImpl._host_availability_expiring_cache.clear() DatabaseDialectManager._known_endpoint_dialects.clear() + CustomEndpointPlugin._monitors.clear() + CustomEndpointMonitor._custom_endpoint_info_cache.clear() ConnectionProviderManager.reset_provider() DatabaseDialectManager.reset_custom_dialect() diff --git a/tests/integration/container/test_custom_endpoint.py b/tests/integration/container/test_custom_endpoint.py new file mode 100644 index 00000000..37fd3a99 --- /dev/null +++ b/tests/integration/container/test_custom_endpoint.py @@ -0,0 +1,267 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar, Dict, Set + +if TYPE_CHECKING: + from tests.integration.container.utils.test_driver import TestDriver + +from time import perf_counter_ns, sleep +from uuid import uuid4 + +import pytest +from boto3 import client +from botocore.exceptions import ClientError + +from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper.errors import (FailoverSuccessError, + ReadWriteSplittingError) +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) +from tests.integration.container.utils.conditions import ( + disable_on_features, enable_on_deployments, enable_on_num_instances) +from tests.integration.container.utils.database_engine_deployment import \ + DatabaseEngineDeployment +from tests.integration.container.utils.driver_helper import DriverHelper +from tests.integration.container.utils.rds_test_utility import RdsTestUtility +from tests.integration.container.utils.test_environment import TestEnvironment +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures + + +@enable_on_num_instances(min_instances=3) +@enable_on_deployments([DatabaseEngineDeployment.AURORA]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.PERFORMANCE]) +class TestCustomEndpoint: + logger: ClassVar[Logger] = Logger(__name__) + endpoint_id: ClassVar[str] = f"test-endpoint-1-{uuid4()}" + 
endpoint_info: ClassVar[Dict[str, Any]] = {} + reuse_existing_endpoint: ClassVar[bool] = False + + @pytest.fixture(scope='class') + def rds_utils(self): + region: str = TestEnvironment.get_current().get_info().get_region() + return RdsTestUtility(region) + + @pytest.fixture(scope='class') + def props(self): + p: Properties = Properties( + {"plugins": "custom_endpoint,read_write_splitting,failover", "connect_timeout": 10_000, "autocommit": True}) + + features = TestEnvironment.get_current().get_features() + if TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED in features \ + or TestEnvironmentFeatures.TELEMETRY_METRICS_ENABLED in features: + WrapperProperties.ENABLE_TELEMETRY.set(p, True) + WrapperProperties.TELEMETRY_SUBMIT_TOPLEVEL.set(p, True) + if TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED in features: + WrapperProperties.TELEMETRY_TRACES_BACKEND.set(p, "XRAY") + if TestEnvironmentFeatures.TELEMETRY_METRICS_ENABLED in features: + WrapperProperties.TELEMETRY_METRICS_BACKEND.set(p, "OTLP") + + return p + + @pytest.fixture(scope='class', autouse=True) + def setup_and_teardown(self): + env_info = TestEnvironment.get_current().get_info() + region = env_info.get_region() + + rds_client = client('rds', region_name=region) + if not self.reuse_existing_endpoint: + instances = env_info.get_database_info().get_instances() + self._create_endpoint(rds_client, instances[0:1]) + + self.wait_until_endpoint_available(rds_client) + + yield + + if not self.reuse_existing_endpoint: + self.delete_endpoint(rds_client) + + rds_client.close() + + def wait_until_endpoint_available(self, rds_client): + end_ns = perf_counter_ns() + 5 * 60 * 1_000_000_000 # 5 minutes + available = False + + while perf_counter_ns() < end_ns: + response = rds_client.describe_db_cluster_endpoints( + DBClusterEndpointIdentifier=self.endpoint_id, + Filters=[ + { + "Name": "db-cluster-endpoint-type", + "Values": ["custom"] + } + ] + ) + + response_endpoints = response["DBClusterEndpoints"] + if 
len(response_endpoints) != 1: + sleep(3) # Endpoint needs more time to get created. + continue + + response_endpoint = response_endpoints[0] + TestCustomEndpoint.endpoint_info = response_endpoint + available = "available" == response_endpoint["Status"] + if available: + break + + sleep(3) + + if not available: + pytest.fail(f"The test setup step timed out while waiting for the test custom endpoint to become available: " + f"'{TestCustomEndpoint.endpoint_id}'.") + + def _create_endpoint(self, rds_client, instances): + instance_ids = [instance.get_instance_id() for instance in instances] + rds_client.create_db_cluster_endpoint( + DBClusterEndpointIdentifier=self.endpoint_id, + DBClusterIdentifier=TestEnvironment.get_current().get_cluster_name(), + EndpointType="ANY", + StaticMembers=instance_ids + ) + + def delete_endpoint(self, rds_client): + try: + rds_client.delete_db_cluster_endpoint(DBClusterEndpointIdentifier=self.endpoint_id) + except ClientError as e: + # If the custom endpoint already does not exist, we can continue. Otherwise, fail the test. + if e.response['Error']['Code'] != 'DBClusterEndpointNotFoundFault': + pytest.fail(e) + + def wait_until_endpoint_has_members(self, rds_client, expected_members: Set[str]): + start_ns = perf_counter_ns() + end_ns = perf_counter_ns() + 20 * 60 * 1_000_000_000 # 20 minutes + has_correct_state = False + while perf_counter_ns() < end_ns: + response = rds_client.describe_db_cluster_endpoints(DBClusterEndpointIdentifier=self.endpoint_id) + response_endpoints = response["DBClusterEndpoints"] + if len(response_endpoints) != 1: + response_ids = [endpoint["DBClusterEndpointIdentifier"] for endpoint in response_endpoints] + pytest.fail("Unexpected number of endpoints returned while waiting for custom endpoint to have the " + f"specified list of members. Expected 1, got {len(response_endpoints)}. 
" + f"Endpoint IDs: {response_ids}.") + + endpoint = response_endpoints[0] + response_members = set(endpoint["StaticMembers"]) + has_correct_state = response_members == expected_members and "available" == endpoint["Status"] + if has_correct_state: + break + + sleep(3) + + if not has_correct_state: + pytest.fail(f"Timed out while waiting for the custom endpoint to stabilize: " + f"'{TestCustomEndpoint.endpoint_id}'.") + + duration_sec = (perf_counter_ns() - start_ns) / 1_000_000_000 + self.logger.debug(f"wait_until_endpoint_has_specified_members took {duration_sec} seconds.") + + def test_custom_endpoint_failover(self, test_driver: TestDriver, conn_utils, props, rds_utils): + props["failover_mode"] = "reader_or_writer" + + target_driver_connect = DriverHelper.get_connect_func(test_driver) + kwargs = conn_utils.get_connect_params() + kwargs["host"] = self.endpoint_info["Endpoint"] + conn = AwsWrapperConnection.connect(target_driver_connect, **kwargs, **props) + + endpoint_members = self.endpoint_info["StaticMembers"] + instance_id = rds_utils.query_instance_id(conn) + assert instance_id in endpoint_members + + # Use failover API to break connection. 
+ target_id = None if instance_id == rds_utils.get_cluster_writer_instance_id() else instance_id + rds_utils.failover_cluster_and_wait_until_writer_changed(target_id=target_id) + + rds_utils.assert_first_query_throws(conn, FailoverSuccessError) + + instance_id = rds_utils.query_instance_id(conn) + assert instance_id in endpoint_members + + conn.close() + + def test_custom_endpoint_read_write_splitting__with_custom_endpoint_changes( + self, test_driver: TestDriver, conn_utils, props, rds_utils): + target_driver_connect = DriverHelper.get_connect_func(test_driver) + kwargs = conn_utils.get_connect_params() + kwargs["host"] = self.endpoint_info["Endpoint"] + # This setting is not required for the test, but it allows us to also test re-creation of expired monitors since + # it takes more than 30 seconds to modify the cluster endpoint (usually around 140s). + props["custom_endpoint_idle_monitor_expiration_ms"] = 30_000 + conn = AwsWrapperConnection.connect(target_driver_connect, **kwargs, **props) + + endpoint_members = self.endpoint_info["StaticMembers"] + original_instance_id = rds_utils.query_instance_id(conn) + assert original_instance_id in endpoint_members + + # Attempt to switch to an instance of the opposite role. This should fail since the custom endpoint consists + # only of the current host. + new_read_only_value = original_instance_id == rds_utils.get_cluster_writer_instance_id() + if new_read_only_value: + # We are connected to the writer. Attempting to switch to the reader will not work but will intentionally + # not throw an exception. In this scenario we log a warning and purposefully stick with the writer. + self.logger.debug("Initial connection is to the writer. Attempting to switch to reader...") + conn.read_only = new_read_only_value + new_instance_id = rds_utils.query_instance_id(conn) + assert new_instance_id == original_instance_id + else: + # We are connected to the reader. Attempting to switch to the writer will throw an exception. 
+ self.logger.debug("Initial connection is to a reader. Attempting to switch to writer...") + with pytest.raises(ReadWriteSplittingError): + conn.read_only = new_read_only_value + + instances = TestEnvironment.get_current().get_instances() + writer_id = rds_utils.get_cluster_writer_instance_id() + if original_instance_id == writer_id: + new_member = instances[1].get_instance_id() + else: + new_member = writer_id + + rds_client = client('rds', region_name=TestEnvironment.get_current().get_aurora_region()) + rds_client.modify_db_cluster_endpoint( + DBClusterEndpointIdentifier=self.endpoint_id, + StaticMembers=[original_instance_id, new_member] + ) + + try: + self.wait_until_endpoint_has_members(rds_client, {original_instance_id, new_member}) + + # We should now be able to switch to new_member. + conn.read_only = new_read_only_value + new_instance_id = rds_utils.query_instance_id(conn) + assert new_instance_id == new_member + + # Switch back to original instance + conn.read_only = not new_read_only_value + finally: + rds_client.modify_db_cluster_endpoint( + DBClusterEndpointIdentifier=self.endpoint_id, + StaticMembers=[original_instance_id]) + self.wait_until_endpoint_has_members(rds_client, {original_instance_id}) + + # We should not be able to switch again because new_member was removed from the custom endpoint. + if new_read_only_value: + # We are connected to the writer. Attempting to switch to the reader will not work but will intentionally + # not throw an exception. In this scenario we log a warning and purposefully stick with the writer. + conn.read_only = new_read_only_value + new_instance_id = rds_utils.query_instance_id(conn) + assert new_instance_id == original_instance_id + else: + # We are connected to the reader. Attempting to switch to the writer will throw an exception. 
+ with pytest.raises(ReadWriteSplittingError): + conn.read_only = new_read_only_value + + conn.close() diff --git a/tests/integration/container/utils/rds_test_utility.py b/tests/integration/container/utils/rds_test_utility.py index 9eabbbc1..bfdd96f6 100644 --- a/tests/integration/container/utils/rds_test_utility.py +++ b/tests/integration/container/utils/rds_test_utility.py @@ -128,7 +128,14 @@ def get_db_cluster(self, cluster_id: str) -> Any: return clusters[0] def failover_cluster_and_wait_until_writer_changed( - self, initial_writer_id: Optional[str] = None, cluster_id: Optional[str] = None) -> None: + self, + initial_writer_id: Optional[str] = None, + cluster_id: Optional[str] = None, + target_id: Optional[str] = None) -> None: + deployment = TestEnvironment.get_current().get_deployment() + if DatabaseEngineDeployment.RDS_MULTI_AZ == deployment and target_id is not None: + raise Exception(Messages.get_formatted("RdsTestUtility.FailoverToTargetNotSupported", target_id, deployment)) + start = perf_counter_ns() if cluster_id is None: cluster_id = TestEnvironment.get_current().get_info().get_cluster_name() @@ -140,14 +147,14 @@ def failover_cluster_and_wait_until_writer_changed( cluster_endpoint = database_info.get_cluster_endpoint() initial_cluster_address = socket.gethostbyname(cluster_endpoint) - self.failover_cluster(cluster_id) + self.failover_cluster(cluster_id, target_id) remaining_attempts = 5 while not self.writer_changed(initial_writer_id, cluster_id, 300): # if writer is not changed, try triggering failover again remaining_attempts -= 1 if remaining_attempts == 0: raise Exception(Messages.get("RdsTestUtility.FailoverRequestNotSuccessful")) - self.failover_cluster(cluster_id) + self.failover_cluster(cluster_id, target_id) # Failover has finished, wait for DNS to be updated so cluster endpoint resolves to the new writer instance. 
cluster_address = socket.gethostbyname(cluster_endpoint) @@ -158,7 +165,7 @@ def failover_cluster_and_wait_until_writer_changed( self.logger.debug("Testing.FinishedFailover", initial_writer_id, str((perf_counter_ns() - start) / 1_000_000)) - def failover_cluster(self, cluster_id: Optional[str] = None) -> None: + def failover_cluster(self, cluster_id: Optional[str] = None, target_id: Optional[str] = None) -> None: if cluster_id is None: cluster_id = TestEnvironment.get_current().get_info().get_cluster_name() @@ -168,7 +175,11 @@ def failover_cluster(self, cluster_id: Optional[str] = None) -> None: while remaining_attempts > 0: remaining_attempts -= 1 try: - result = self._client.failover_db_cluster(DBClusterIdentifier=cluster_id) + if not target_id: + result = self._client.failover_db_cluster(DBClusterIdentifier=cluster_id) + else: + result = self._client.failover_db_cluster( + DBClusterIdentifier=cluster_id, TargetDBInstanceIdentifier=target_id) http_status_code = result.get("ResponseMetadata").get("HTTPStatusCode") if result.get("DBCluster") is not None and http_status_code == 200: return @@ -176,6 +187,8 @@ def failover_cluster(self, cluster_id: Optional[str] = None) -> None: except Exception: sleep(1) + raise Exception(Messages.get_formatted("RdsTestUtility.FailoverClusterFailed", cluster_id)) + def writer_changed(self, initial_writer_id: str, cluster_id: str, timeout: int) -> bool: wait_until = timeit.default_timer() + timeout diff --git a/tests/integration/container/utils/test_environment.py b/tests/integration/container/utils/test_environment.py index 28184d34..c1c9d749 100644 --- a/tests/integration/container/utils/test_environment.py +++ b/tests/integration/container/utils/test_environment.py @@ -202,6 +202,9 @@ def get_instances(self) -> List[TestInstanceInfo]: def get_writer(self) -> TestInstanceInfo: return self.get_instances()[0] + def get_cluster_name(self) -> str: + return self.get_info().get_cluster_name() + def get_proxy_database_info(self) -> 
TestProxyDatabaseInfo: return self.get_info().get_proxy_database_info() diff --git a/tests/unit/test_aurora_connection_tracker.py b/tests/unit/test_aurora_connection_tracker.py index c0847f43..f9f30544 100644 --- a/tests/unit/test_aurora_connection_tracker.py +++ b/tests/unit/test_aurora_connection_tracker.py @@ -79,7 +79,7 @@ def props(): def test_track_new_instance_connection( mocker, mock_plugin_service, mock_rds_utils, mock_tracker, mock_cursor, mock_callable): host_info: HostInfo = HostInfo("instance1") - mock_plugin_service.hosts = [host_info] + mock_plugin_service.all_hosts = [host_info] mock_plugin_service.current_host_info = host_info mock_rds_utils.is_rds_instance.return_value = True mock_callable.return_value = mock_conn @@ -107,7 +107,7 @@ def test_invalidate_opened_connections( new_host = HostInfo("new-host") mock_callable.side_effect = expected_exception mock_hosts_prop = mocker.PropertyMock(side_effect=[(original_host,), (new_host,)]) - type(mock_plugin_service).hosts = mock_hosts_prop + type(mock_plugin_service).all_hosts = mock_hosts_prop plugin: AuroraConnectionTrackerPlugin = AuroraConnectionTrackerPlugin( mock_plugin_service, Properties(), mock_rds_utils, mock_tracker) diff --git a/tests/unit/test_failover_plugin.py b/tests/unit/test_failover_plugin.py index 2b09def7..4610b316 100644 --- a/tests/unit/test_failover_plugin.py +++ b/tests/unit/test_failover_plugin.py @@ -198,7 +198,7 @@ def test_update_topology( refresh_mock.assert_not_called() force_refresh_mock.assert_not_called() - type(plugin_service_mock).hosts = PropertyMock(return_value=[HostInfo("host")]) + type(plugin_service_mock).all_hosts = PropertyMock(return_value=[HostInfo("host")]) driver_dialect_mock.is_closed.return_value = False with mock.patch.object(plugin_service_mock, "force_refresh_host_list") as force_refresh_mock: @@ -288,7 +288,7 @@ def test_failover_writer_failed_failover_raises_error(plugin_service_mock, host_ host: HostInfo = HostInfo("host") host._aliases = ["alias1", 
"alias2"] hosts: Tuple[HostInfo, ...] = (host, ) - type(plugin_service_mock).hosts = PropertyMock(return_value=hosts) + type(plugin_service_mock).all_hosts = PropertyMock(return_value=hosts) properties = Properties() WrapperProperties.ENABLE_FAILOVER.set(properties, "True") @@ -309,7 +309,7 @@ def test_failover_writer_failed_failover_with_no_result(plugin_service_mock, hos host: HostInfo = HostInfo("host") host._aliases = ["alias1", "alias2"] hosts: Tuple[HostInfo, ...] = (host, ) - type(plugin_service_mock).hosts = PropertyMock(return_value=hosts) + type(plugin_service_mock).all_hosts = PropertyMock(return_value=hosts) writer_result_mock: WriterFailoverResult = MagicMock() get_connection_mock = PropertyMock() @@ -340,7 +340,7 @@ def test_failover_writer_success(plugin_service_mock, host_list_provider_service host: HostInfo = HostInfo("host") host._aliases = ["alias1", "alias2"] hosts: Tuple[HostInfo, ...] = (host, ) - type(plugin_service_mock).hosts = PropertyMock(return_value=hosts) + type(plugin_service_mock).all_hosts = PropertyMock(return_value=hosts) properties = Properties() WrapperProperties.ENABLE_FAILOVER.set(properties, "True") diff --git a/tests/unit/test_rds_utils.py b/tests/unit/test_rds_utils.py index dacd100f..d3d03bca 100644 --- a/tests/unit/test_rds_utils.py +++ b/tests/unit/test_rds_utils.py @@ -284,3 +284,33 @@ def test_get_rds_cluster_host_url(): def test_get_instance_id(host: str, expected_id: str): target = RdsUtils() assert target.get_instance_id(host) == expected_id + + +@pytest.mark.parametrize("expected, test_value", [ + ("database-test-name", us_east_region_cluster), + ("database-test-name", us_east_region_cluster_read_only), + (None, us_east_region_instance), + ("proxy-test-name", us_east_region_proxy), + ("custom-test-name", us_east_region_custom_domain), + ("database-test-name", china_region_cluster), + ("database-test-name", china_region_cluster_read_only), + (None, china_region_instance), + ("proxy-test-name", china_region_proxy), + 
("custom-test-name", china_region_custom_domain), + ("database-test-name", china_alt_region_cluster), + ("database-test-name", china_alt_region_cluster_read_only), + (None, china_alt_region_instance), + ("proxy-test-name", china_alt_region_proxy), + ("custom-test-name", china_alt_region_custom_domain), + ("database-test-name", china_alt_region_limitless_db_shard_group), + ("database-test-name", us_isob_east_region_cluster), + ("database-test-name", us_isob_east_region_cluster_read_only), + (None, us_isob_east_region_instance), + ("proxy-test-name", us_isob_east_region_proxy), + ("custom-test-name", us_isob_east_region_custom_domain), + ("database-test-name", us_isob_east_region_limitless_db_shard_group), + ("database-test-name", us_gov_east_region_cluster), +]) +def test_get_cluster_id(expected, test_value): + target = RdsUtils() + assert target.get_cluster_id(test_value) == expected diff --git a/tests/unit/test_stale_dns_helper.py b/tests/unit/test_stale_dns_helper.py index 59fa1a47..b805b361 100644 --- a/tests/unit/test_stale_dns_helper.py +++ b/tests/unit/test_stale_dns_helper.py @@ -134,7 +134,7 @@ def test_get_verified_connection__cluster_inet_address_none(mocker, plugin_servi def test_get_verified_connection__no_writer_hostinfo(mocker, plugin_service_mock, host_list_provider_mock, default_properties, initial_conn_mock, connect_func_mock, reader_host_list, writer_cluster): target = StaleDnsHelper(plugin_service_mock) - plugin_service_mock.hosts = reader_host_list + plugin_service_mock.all_hosts = reader_host_list plugin_service_mock.get_host_role.return_value = HostRole.READER connect_func_mock.return_value = initial_conn_mock socket.gethostbyname = mocker.MagicMock(return_value='2.2.2.2') @@ -153,7 +153,7 @@ def test_get_verified_connection__writer_rds_cluster_dns_true(mocker, plugin_ser connect_func_mock.return_value = initial_conn_mock socket.gethostbyname = mocker.MagicMock(return_value='5.5.5.5') - plugin_service_mock.hosts = cluster_host_list + 
plugin_service_mock.all_hosts = cluster_host_list target = StaleDnsHelper(plugin_service_mock) return_conn = target.get_verified_connection(True, host_list_provider_mock, writer_cluster, default_properties, connect_func_mock) @@ -166,7 +166,7 @@ def test_get_verified_connection__writer_rds_cluster_dns_true(mocker, plugin_ser def test_get_verified_connection__writer_host_address_none(mocker, plugin_service_mock, host_list_provider_mock, default_properties, initial_conn_mock, connect_func_mock, writer_cluster, instance_host_list): target = StaleDnsHelper(plugin_service_mock) - plugin_service_mock.hosts = instance_host_list + plugin_service_mock.all_hosts = instance_host_list socket.gethostbyname = mocker.MagicMock(side_effect=['5.5.5.5', None]) connect_func_mock.return_value = initial_conn_mock @@ -179,7 +179,7 @@ def test_get_verified_connection__writer_host_address_none(mocker, plugin_servic def test_get_verified_connection__writer_host_info_none(mocker, plugin_service_mock, host_list_provider_mock, default_properties, initial_conn_mock, connect_func_mock, writer_cluster, reader_host_list): target = StaleDnsHelper(plugin_service_mock) - plugin_service_mock.hosts = reader_host_list + plugin_service_mock.all_hosts = reader_host_list socket.gethostbyname = mocker.MagicMock(side_effect=['5.5.5.5', None]) connect_func_mock.return_value = initial_conn_mock @@ -194,7 +194,7 @@ def test_get_verified_connection__writer_host_address_equals_cluster_inet_addres default_properties, initial_conn_mock, connect_func_mock, writer_cluster, instance_host_list): target = StaleDnsHelper(plugin_service_mock) - plugin_service_mock.hosts = instance_host_list + plugin_service_mock.all_hosts = instance_host_list socket.gethostbyname = mocker.MagicMock(side_effect=['5.5.5.5', '5.5.5.5']) connect_func_mock.return_value = initial_conn_mock @@ -209,6 +209,7 @@ def test_get_verified_connection__writer_host_address_not_equals_cluster_inet_ad writer_cluster, cluster_host_list): target = 
StaleDnsHelper(plugin_service_mock) target._writer_host_info = writer_cluster + plugin_service_mock.all_hosts = cluster_host_list plugin_service_mock.hosts = cluster_host_list socket.gethostbyname = mocker.MagicMock(side_effect=['5.5.5.5', '8.8.8.8']) connect_func_mock.return_value = initial_conn_mock @@ -228,6 +229,7 @@ def test_get_verified_connection__initial_connection_writer_host_address_not_equ writer_cluster, cluster_host_list): target = StaleDnsHelper(plugin_service_mock) target._writer_host_info = writer_cluster + plugin_service_mock.all_hosts = cluster_host_list plugin_service_mock.hosts = cluster_host_list socket.gethostbyname = mocker.MagicMock(side_effect=['5.5.5.5', '8.8.8.8']) connect_func_mock.return_value = initial_conn_mock diff --git a/tests/unit/test_writer_failover_handler.py b/tests/unit/test_writer_failover_handler.py index ec28eb70..71c8c150 100644 --- a/tests/unit/test_writer_failover_handler.py +++ b/tests/unit/test_writer_failover_handler.py @@ -125,7 +125,7 @@ def force_connect_side_effect(host_info, _, __) -> Connection: plugin_service_mock.force_connect.side_effect = force_connect_side_effect - plugin_service_mock.hosts = topology + plugin_service_mock.all_hosts = topology reader_failover_mock.get_reader_connection.side_effect = FailoverError("error") target: WriterFailoverHandler = WriterFailoverHandlerImpl( @@ -156,7 +156,7 @@ def test_reconnect_to_writer_slow_task_b( call(writer.as_aliases(), HostAvailability.AVAILABLE)] mock_hosts_property = mocker.PropertyMock(side_effect=chain([topology], cycle([new_topology]))) - type(plugin_service_mock).hosts = mock_hosts_property + type(plugin_service_mock).all_hosts = mock_hosts_property def force_connect_side_effect(host_info, _, __) -> Connection: if host_info == writer: @@ -208,7 +208,7 @@ def force_connect_side_effect(host_info, _, __) -> Connection: def get_reader_connection_side_effect(_): return ReaderFailoverResult(reader_a_connection_mock, True, reader_a, None) - 
plugin_service_mock.hosts = topology + plugin_service_mock.all_hosts = topology reader_failover_mock.get_reader_connection.side_effect = get_reader_connection_side_effect target: WriterFailoverHandler = WriterFailoverHandlerImpl( @@ -254,7 +254,7 @@ def force_connect_side_effect(host_info, _, __) -> Connection: def get_reader_connection_side_effect(_): return ReaderFailoverResult(reader_a_connection_mock, True, reader_a, None) - plugin_service_mock.hosts = new_topology + plugin_service_mock.all_hosts = new_topology reader_failover_mock.get_reader_connection.side_effect = get_reader_connection_side_effect target: WriterFailoverHandler = WriterFailoverHandlerImpl( @@ -302,7 +302,7 @@ def force_connect_side_effect(host_info, _, __) -> Connection: def get_reader_connection_side_effect(_): return ReaderFailoverResult(reader_a_connection_mock, True, reader_a, None) - plugin_service_mock.hosts = updated_topology + plugin_service_mock.all_hosts = updated_topology reader_failover_mock.get_reader_connection.side_effect = get_reader_connection_side_effect target: WriterFailoverHandler = WriterFailoverHandlerImpl( @@ -358,7 +358,7 @@ def force_connect_side_effect(host_info, _, timeout_event) -> Connection: def get_reader_connection_side_effect(_): return ReaderFailoverResult(reader_a_connection_mock, True, reader_a, None) - plugin_service_mock.hosts = new_topology + plugin_service_mock.all_hosts = new_topology reader_failover_mock.get_reader_connection.side_effect = get_reader_connection_side_effect target: WriterFailoverHandler = WriterFailoverHandlerImpl( @@ -404,7 +404,7 @@ def force_connect_side_effect(host_info, _, __) -> Connection: def get_reader_connection_side_effect(_): return ReaderFailoverResult(reader_a_connection_mock, True, reader_a, None) - plugin_service_mock.hosts = new_topology + plugin_service_mock.all_hosts = new_topology reader_failover_mock.get_reader_connection.side_effect = get_reader_connection_side_effect target: WriterFailoverHandler = 
WriterFailoverHandlerImpl(