Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 55 additions & 0 deletions aws_advanced_python_wrapper/credentials_provider_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Dict, Optional, Protocol

import boto3

if TYPE_CHECKING:
from aws_advanced_python_wrapper.utils.properties import Properties

from abc import abstractmethod

from aws_advanced_python_wrapper.utils.properties import WrapperProperties


class CredentialsProviderFactory(Protocol):
@abstractmethod
def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]:
...


class SamlCredentialsProviderFactory(CredentialsProviderFactory):

def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]:
saml_assertion: str = self.get_saml_assertion(props)
session = boto3.Session()

sts_client = session.client(
'sts',
region_name=region
)

response: Dict[str, Dict[str, str]] = sts_client.assume_role_with_saml(
RoleArn=WrapperProperties.IAM_ROLE_ARN.get(props),
PrincipalArn=WrapperProperties.IAM_IDP_ARN.get(props),
SAMLAssertion=saml_assertion,
)

return response.get('Credentials')

def get_saml_assertion(self, props: Properties):
...
164 changes: 36 additions & 128 deletions aws_advanced_python_wrapper/federated_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,15 @@

from __future__ import annotations

from abc import abstractmethod
from html import unescape
from re import DOTALL, findall, search
from typing import TYPE_CHECKING, List, Protocol
from urllib.parse import urlencode, urlparse
from typing import TYPE_CHECKING, List
from urllib.parse import urlencode

from aws_advanced_python_wrapper.utils.iamutils import IamAuthUtils, TokenInfo
from aws_advanced_python_wrapper.credentials_provider_factory import (
CredentialsProviderFactory, SamlCredentialsProviderFactory)
from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo
from aws_advanced_python_wrapper.utils.saml_utils import SamlUtils

if TYPE_CHECKING:
from boto3 import Session
Expand All @@ -32,7 +34,6 @@
from datetime import datetime, timedelta
from typing import Callable, Dict, Optional, Set

import boto3
import requests

from aws_advanced_python_wrapper.errors import AwsWrapperError
Expand All @@ -58,6 +59,10 @@ def __init__(self, plugin_service: PluginService, credentials_provider_factory:
self._credentials_provider_factory = credentials_provider_factory
self._session = session

telemetry_factory = self._plugin_service.get_telemetry_factory()
self._fetch_token_counter = telemetry_factory.create_counter("federated.fetch_token.count")
self._cache_size_gauge = telemetry_factory.create_gauge("federated.token_cache.size", lambda: len(FederatedAuthPlugin._token_cache))

@property
def subscribed_methods(self) -> Set[str]:
return self._SUBSCRIBED_METHODS
Expand All @@ -73,14 +78,15 @@ def connect(
return self._connect(host_info, props, connect_func)

def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callable) -> Connection:
self._check_idp_credentials_with_fallback(props)
SamlUtils.check_idp_credentials_with_fallback(props)

host = IamAuthUtils.get_iam_host(props, host_info)
port = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port)
region: str = self._get_rds_region(host, props)
region: str = IamAuthUtils.get_rds_region(self._rds_utils, host, props, self._session)

cache_key: str = self._get_cache_key(
WrapperProperties.DB_USER.get(props),
user = WrapperProperties.DB_USER.get(props)
cache_key: str = IamAuthUtils.get_cache_key(
user,
host,
port,
region
Expand All @@ -89,17 +95,17 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl
token_info = FederatedAuthPlugin._token_cache.get(cache_key)

if token_info is not None and not token_info.is_expired():
logger.debug("IamAuthPlugin.UseCachedIamToken", token_info.token)
logger.debug("FederatedAuthPlugin.UseCachedToken", token_info.token)
self._plugin_service.driver_dialect.set_password(props, token_info.token)
else:
self._update_authentication_token(host_info, props, region, cache_key)
self._update_authentication_token(host_info, props, user, region, cache_key)

WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props))
WrapperProperties.USER.set(props, WrapperProperties.DB_USER.get(props))

try:
return connect_func()
except Exception:
self._update_authentication_token(host_info, props, region, cache_key)
self._update_authentication_token(host_info, props, user, region, cache_key)

try:
return connect_func()
Expand All @@ -121,77 +127,25 @@ def force_connect(
def _update_authentication_token(self,
host_info: HostInfo,
props: Properties,
user: Optional[str],
region: str,
cache_key: str) -> None:
token_expiration_sec: int = WrapperProperties.IAM_TOKEN_EXPIRATION.get_int(props)
token_expiry: datetime = datetime.now() + timedelta(seconds=token_expiration_sec)
port: int = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port)
credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props)

token: str = self._generate_authentication_token(props, host_info.host, port, region, credentials)
logger.debug("IamAuthPlugin.GeneratedNewIamToken", token)
self._fetch_token_counter.inc()
token: str = IamAuthUtils.generate_authentication_token(
self._plugin_service,
user,
host_info.host,
port,
region,
credentials,
self._session)
WrapperProperties.PASSWORD.set(props, token)
FederatedAuthPlugin._token_cache[token] = TokenInfo(token, token_expiry)

def _get_rds_region(self, hostname: Optional[str], props: Properties) -> str:
rds_region = WrapperProperties.IAM_REGION.get(props)
if rds_region is None or rds_region == "":
rds_region = self._rds_utils.get_rds_region(hostname)

if not rds_region:
error_message = "RdsUtils.UnsupportedHostname"
logger.debug(error_message, hostname)
raise AwsWrapperError(Messages.get_formatted(error_message, hostname))

session = self._session if self._session else boto3.Session()
if rds_region not in session.get_available_regions("rds"):
error_message = "AwsSdk.UnsupportedRegion"
logger.debug(error_message, rds_region)
raise AwsWrapperError(Messages.get_formatted(error_message, rds_region))

return rds_region

def _generate_authentication_token(self,
props: Properties,
host_name: Optional[str],
port: Optional[int],
region: Optional[str],
credentials: Optional[Dict[str, str]]) -> str:
session = self._session if self._session else boto3.Session()

if credentials is not None:
client = session.client(
'rds',
region_name=region,
aws_access_key_id=credentials.get('AccessKeyId'),
aws_secret_access_key=credentials.get('SecretAccessKey'),
aws_session_token=credentials.get('SessionToken')
)
else:
client = session.client(
'rds',
region_name=region
)

user = WrapperProperties.USER.get(props)
token = client.generate_db_auth_token(
DBHostname=host_name,
Port=port,
DBUsername=user
)

client.close()

return token

def _get_cache_key(self, user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str:
return f"{region}:{hostname}:{port}:{user}"

def _check_idp_credentials_with_fallback(self, props: Properties) -> None:
if WrapperProperties.IDP_USERNAME.get(props) is None:
WrapperProperties.IDP_USERNAME.set(props, WrapperProperties.USER.name)
if WrapperProperties.IDP_PASSWORD.get(props) is None:
WrapperProperties.IDP_PASSWORD.set(props, WrapperProperties.PASSWORD.name)
FederatedAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry)


class FederatedAuthPluginFactory(PluginFactory):
Expand All @@ -205,35 +159,6 @@ def get_credentials_provider_factory(self, plugin_service: PluginService, props:
raise AwsWrapperError(Messages.get_formatted("FederatedAuthPluginFactory.UnsupportedIdp", idp_name))


class CredentialsProviderFactory(Protocol):
@abstractmethod
def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]:
...


class SamlCredentialsProviderFactory(CredentialsProviderFactory):

def get_aws_credentials(self, region: str, props: Properties) -> Optional[Dict[str, str]]:
saml_assertion: str = self.get_saml_assertion(props)
session = boto3.Session()

sts_client = session.client(
'sts',
region_name=region
)

response: Dict[str, Dict[str, str]] = sts_client.assume_role_with_saml(
RoleArn=WrapperProperties.IAM_ROLE_ARN.get(props),
PrincipalArn=WrapperProperties.IAM_IDP_ARN.get(props),
SAMLAssertion=saml_assertion
)

return response.get('Credentials')

def get_saml_assertion(self, props: Properties):
...


class AdfsCredentialsProviderFactory(SamlCredentialsProviderFactory):
_INPUT_TAG_PATTERN = r"<input(.+?)/>"
_FORM_ACTION_PATTERN = r"<form.*?action=\"([^\"]+)\""
Expand Down Expand Up @@ -274,32 +199,22 @@ def get_saml_assertion(self, props: Properties):

def _get_sign_in_page_body(self, url: str, props: Properties) -> str:
logger.debug("AdfsCredentialsProviderFactory.SignOnPageUrl", url)
self._validate_url(url)
SamlUtils.validate_url(url)
r = requests.get(url,
verify=WrapperProperties.SSL_SECURE.get_bool(props),
timeout=WrapperProperties.HTTP_REQUEST_TIMEOUT.get_int(props))

# Check HTTP Status Code is 2xx Success
if r.status_code / 100 != 2:
error_message = "AdfsCredentialsProviderFactory.SignOnPageRequestFailed"
logger.debug(error_message, r.status_code, r.reason, r.text)
raise AwsWrapperError(Messages.get_formatted(error_message, r.status_code, r.reason, r.text))

SamlUtils.validate_response(r)
return r.text

def _post_form_action_body(self, uri: str, parameters: Dict[str, str], props: Properties) -> str:
logger.debug("AdfsCredentialsProviderFactory.SignOnPagePostActionUrl", uri)
self._validate_url(uri)
SamlUtils.validate_url(uri)
r = requests.post(uri, data=urlencode(parameters),
verify=WrapperProperties.SSL_SECURE.get_bool(props),
timeout=WrapperProperties.HTTP_REQUEST_TIMEOUT.get_int(props))
# Check HTTP Status Code is 2xx Success
if r.status_code / 100 != 2:
error_message = "AdfsCredentialsProviderFactory.SignOnPagePostActionRequestFailed"
logger.debug(error_message, r.status_code, r.reason, r.text)
raise AwsWrapperError(
Messages.get_formatted(error_message, r.status_code, r.reason, r.text))

SamlUtils.validate_response(r)
return r.text

def _get_sign_in_page_url(self, props) -> str:
Expand All @@ -308,7 +223,7 @@ def _get_sign_in_page_url(self, props) -> str:
relaying_party_id = WrapperProperties.RELAYING_PARTY_ID.get(props)
url = f"https://{idp_endpoint}:{idp_port}/adfs/ls/IdpInitiatedSignOn.aspx?loginToRp={relaying_party_id}"
if idp_endpoint is None or relaying_party_id is None:
error_message = "AdfsCredentialsProviderFactory.InvalidHttpsUrl"
error_message = "SamlUtils.InvalidHttpsUrl"
logger.debug(error_message, url)
raise AwsWrapperError(Messages.get_formatted(error_message, url))

Expand All @@ -319,7 +234,7 @@ def _get_form_action_url(self, props: Properties, action: str) -> str:
idp_port = WrapperProperties.IDP_PORT.get(props)
url = f"https://{idp_endpoint}:{idp_port}{action}"
if idp_endpoint is None:
error_message = "AdfsCredentialsProviderFactory.InvalidHttpsUrl"
error_message = "SamlUtils.InvalidHttpsUrl"
logger.debug(error_message, url)
raise AwsWrapperError(
Messages.get_formatted(error_message, url))
Expand Down Expand Up @@ -373,10 +288,3 @@ def _get_form_action_from_html_body(self, body: str) -> str:
return unescape(match.group(1))

return ""

def _validate_url(self, url: str) -> None:
result = urlparse(url)
if not result.scheme or not search(self._HTTPS_URL_PATTERN, url):
error_message = "AdfsCredentialsProviderFactory.InvalidHttpsUrl"
logger.debug(error_message, url)
raise AwsWrapperError(Messages.get_formatted(error_message, url))
Loading