Skip to content

Commit 61f696e

Browse files
authored
Implement the Custom Endpoint Plugin (#727)
1 parent 48ed88a commit 61f696e

32 files changed

+1001
-103
lines changed
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License").
4+
# You may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from typing import Optional, Set
16+
17+
18+
class AllowedAndBlockedHosts:
19+
def __init__(self, allowed_host_ids: Optional[Set[str]], blocked_host_ids: Optional[Set[str]]):
20+
self._allowed_host_ids = None if not allowed_host_ids else allowed_host_ids
21+
self._blocked_host_ids = None if not blocked_host_ids else blocked_host_ids
22+
23+
@property
24+
def allowed_host_ids(self) -> Optional[Set[str]]:
25+
return self._allowed_host_ids
26+
27+
@property
28+
def blocked_host_ids(self) -> Optional[Set[str]]:
29+
return self._blocked_host_ids

aws_advanced_python_wrapper/aurora_connection_tracker_plugin.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -201,15 +201,15 @@ def _connect(self, host_info: HostInfo, connect_func: Callable):
201201

202202
def execute(self, target: object, method_name: str, execute_func: Callable, *args: Any, **kwargs: Any) -> Any:
203203
if self._current_writer is None or self._need_update_current_writer:
204-
self._current_writer = self._get_writer(self._plugin_service.hosts)
204+
self._current_writer = self._get_writer(self._plugin_service.all_hosts)
205205
self._need_update_current_writer = False
206206

207207
try:
208208
return execute_func()
209209

210210
except Exception as e:
211211
# Check that e is a FailoverError and that the writer has changed
212-
if isinstance(e, FailoverError) and self._get_writer(self._plugin_service.hosts) != self._current_writer:
212+
if isinstance(e, FailoverError) and self._get_writer(self._plugin_service.all_hosts) != self._current_writer:
213213
self._tracker.invalidate_all_connections(host_info=self._current_writer)
214214
self._tracker.log_opened_connections()
215215
self._need_update_current_writer = True

aws_advanced_python_wrapper/aurora_initial_connection_strategy_plugin.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,7 @@ def _delay(self, delay_ms: int):
198198
sleep(delay_ms / 1000)
199199

200200
def _get_writer(self) -> Optional[HostInfo]:
201-
for host in self._plugin_service.hosts:
201+
for host in self._plugin_service.all_hosts:
202202
if host.role == HostRole.WRITER:
203203
return host
204204

@@ -225,10 +225,10 @@ def init_host_provider(self, props: Properties, host_list_provider_service: Host
225225
init_host_provider_func(props)
226226

227227
def _has_no_readers(self) -> bool:
228-
if len(self._plugin_service.hosts) == 0:
228+
if len(self._plugin_service.all_hosts) == 0:
229229
return False
230230

231-
for host in self._plugin_service.hosts:
231+
for host in self._plugin_service.all_hosts:
232232
if host.role == HostRole.READER:
233233
return False
234234

aws_advanced_python_wrapper/aws_secrets_manager_plugin.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from aws_advanced_python_wrapper.utils.messages import Messages
3636
from aws_advanced_python_wrapper.utils.properties import (Properties,
3737
WrapperProperties)
38+
from aws_advanced_python_wrapper.utils.region_utils import RegionUtils
3839
from aws_advanced_python_wrapper.utils.telemetry.telemetry import \
3940
TelemetryTraceLevel
4041

@@ -63,6 +64,7 @@ def __init__(self, plugin_service: PluginService, props: Properties, session: Op
6364
Messages.get_formatted("AwsSecretsManagerPlugin.MissingRequiredConfigParameter",
6465
WrapperProperties.SECRETS_MANAGER_SECRET_ID.name))
6566

67+
self._region_utils = RegionUtils()
6668
region: str = self._get_rds_region(secret_id, props)
6769

6870
secrets_endpoint = WrapperProperties.SECRETS_MANAGER_ENDPOINT.get(props)
@@ -194,23 +196,22 @@ def _apply_secret_to_properties(self, properties: Properties):
194196
WrapperProperties.PASSWORD.set(properties, self._secret.password)
195197

196198
def _get_rds_region(self, secret_id: str, props: Properties) -> str:
197-
region: Optional[str] = props.get(WrapperProperties.SECRETS_MANAGER_REGION.name)
198-
if not region:
199-
match = search(self._SECRETS_ARN_PATTERN, secret_id)
200-
if match:
201-
region = match.group("region")
202-
else:
203-
raise AwsWrapperError(
204-
Messages.get_formatted("AwsSecretsManagerPlugin.MissingRequiredConfigParameter",
205-
WrapperProperties.SECRETS_MANAGER_REGION.name))
206-
207199
session = self._session if self._session else boto3.Session()
208-
if region not in session.get_available_regions("rds"):
209-
exception_message = "AwsSdk.UnsupportedRegion"
210-
logger.debug(exception_message, region)
211-
raise AwsWrapperError(Messages.get_formatted(exception_message, region))
200+
region = self._region_utils.get_region(props, WrapperProperties.SECRETS_MANAGER_REGION.name, session=session)
201+
202+
if region:
203+
return region
212204

213-
return region
205+
match = search(self._SECRETS_ARN_PATTERN, secret_id)
206+
if match:
207+
region = match.group("region")
208+
209+
if region:
210+
return self._region_utils.verify_region(region)
211+
else:
212+
raise AwsWrapperError(
213+
Messages.get_formatted("AwsSecretsManagerPlugin.MissingRequiredConfigParameter",
214+
WrapperProperties.SECRETS_MANAGER_REGION.name))
214215

215216

216217
class AwsSecretsManagerPluginFactory(PluginFactory):

0 commit comments

Comments
 (0)