From cd1954075bc724a8d7c6387bce1620b82c900220 Mon Sep 17 00:00:00 2001 From: kukushking <> Date: Thu, 6 May 2021 14:34:07 +0100 Subject: [PATCH 1/2] Pass SSL properties from Glue Connection to PyMySQL --- awswrangler/_databases.py | 18 ++++++++++++++++++ awswrangler/mysql.py | 1 + cloudformation/databases.yaml | 26 ++++++++++++++++++++++++++ tests/test_mysql.py | 4 ++++ 4 files changed, 49 insertions(+) diff --git a/awswrangler/_databases.py b/awswrangler/_databases.py index 6a223d63e..8ae459744 100644 --- a/awswrangler/_databases.py +++ b/awswrangler/_databases.py @@ -1,6 +1,7 @@ """Databases Utilities.""" import logging +import ssl from typing import Any, Dict, Generator, Iterator, List, NamedTuple, Optional, Tuple, Union, cast import boto3 @@ -22,6 +23,7 @@ class ConnectionAttributes(NamedTuple): host: str port: int database: str + ssl_context: Optional[ssl.SSLContext] def _get_dbname(cluster_id: str, boto3_session: Optional[boto3.Session] = None) -> str: @@ -41,6 +43,20 @@ def _get_connection_attributes_from_catalog( else: database_sep = "/" port, database = details["JDBC_CONNECTION_URL"].split(":")[3].split(database_sep) + ssl_context: Optional[ssl.SSLContext] = None + if details.get("JDBC_ENFORCE_SSL") == "true": + ssl_cert_path: Optional[str] = details.get("CUSTOM_JDBC_CERT") + ssl_cadata: Optional[str] = None + if ssl_cert_path: + bucket_name, key_path = _utils.parse_path(ssl_cert_path) + client_s3: boto3.client = _utils.client(service_name="s3", session=boto3_session) + try: + ssl_cadata = client_s3.get_object(Bucket=bucket_name, Key=key_path)["Body"].read().decode("utf-8") + except client_s3.exception.NoSuchKey: + raise exceptions.NoFilesFound( # pylint: disable=raise-missing-from + f"No CA certificate found at {ssl_cert_path}." 
+ ) + ssl_context = ssl.create_default_context(cadata=ssl_cadata) return ConnectionAttributes( kind=details["JDBC_CONNECTION_URL"].split(":")[1].lower(), user=details["USERNAME"], @@ -48,6 +64,7 @@ def _get_connection_attributes_from_catalog( host=details["JDBC_CONNECTION_URL"].split(":")[2].replace("/", ""), port=int(port), database=dbname if dbname is not None else database, + ssl_context=ssl_context, ) @@ -71,6 +88,7 @@ def _get_connection_attributes_from_secrets_manager( host=secret_value["host"], port=secret_value["port"], database=_dbname, + ssl_context=None, ) diff --git a/awswrangler/mysql.py b/awswrangler/mysql.py index f6e11f494..763e012c2 100644 --- a/awswrangler/mysql.py +++ b/awswrangler/mysql.py @@ -139,6 +139,7 @@ def connect( read_timeout=read_timeout, write_timeout=write_timeout, connect_timeout=connect_timeout, + ssl=attrs.ssl_context, ) diff --git a/cloudformation/databases.yaml b/cloudformation/databases.yaml index d84412645..ecd2810f7 100644 --- a/cloudformation/databases.yaml +++ b/cloudformation/databases.yaml @@ -384,6 +384,32 @@ Resources: PASSWORD: Ref: DatabasesPassword Name: aws-data-wrangler-mysql + MysqlGlueConnectionSSL: + Type: AWS::Glue::Connection + Properties: + CatalogId: + Ref: AWS::AccountId + ConnectionInput: + Description: Connect to Aurora (MySQL). 
+ ConnectionType: JDBC + PhysicalConnectionRequirements: + AvailabilityZone: + Fn::Select: + - 0 + - Fn::GetAZs: '' + SecurityGroupIdList: + - Ref: DatabaseSecurityGroup + SubnetId: + Fn::ImportValue: aws-data-wrangler-base-PrivateSubnet + ConnectionProperties: + JDBC_CONNECTION_URL: + Fn::Sub: jdbc:mysql://${AuroraInstanceMysql.Endpoint.Address}:${AuroraInstanceMysql.Endpoint.Port}/test + JDBC_ENFORCE_SSL: true + CUSTOM_JDBC_CERT: s3://rds-downloads/rds-combined-ca-bundle.pem + USERNAME: test + PASSWORD: + Ref: DatabasesPassword + Name: aws-data-wrangler-mysql-ssl SqlServerGlueConnection: Type: AWS::Glue::Connection Properties: diff --git a/tests/test_mysql.py b/tests/test_mysql.py index ea04ff45b..8dab62b50 100644 --- a/tests/test_mysql.py +++ b/tests/test_mysql.py @@ -24,6 +24,10 @@ def test_connection(): wr.mysql.connect("aws-data-wrangler-mysql", connect_timeout=10).close() +def test_connection_ssl(): + wr.mysql.connect("aws-data-wrangler-mysql-ssl", connect_timeout=10).close() + + def test_read_sql_query_simple(databases_parameters): con = pymysql.connect( host=databases_parameters["mysql"]["host"], From 0dcccbc6c9cf4244f372e717e5b9c2e31571c074 Mon Sep 17 00:00:00 2001 From: kukushking <3997468+kukushking@users.noreply.github.com> Date: Mon, 10 May 2021 13:29:22 +0100 Subject: [PATCH 2/2] Tests + some housekeeping --- awswrangler/mysql.py | 9 +++++++-- cloudformation/databases.yaml | 2 +- tests/test_mysql.py | 17 +++++++++++++---- 3 files changed, 21 insertions(+), 7 deletions(-) diff --git a/awswrangler/mysql.py b/awswrangler/mysql.py index 763e012c2..941af8044 100644 --- a/awswrangler/mysql.py +++ b/awswrangler/mysql.py @@ -78,10 +78,15 @@ def connect( write_timeout: Optional[int] = None, connect_timeout: int = 10, ) -> pymysql.connections.Connection: - """Return a pymysql connection from a Glue Catalog Connection. + """Return a pymysql connection from a Glue Catalog Connection or Secrets Manager. 
https://pymysql.readthedocs.io + Note + ---- + It is only possible to configure SSL using Glue Catalog Connection. More at: + https://docs.aws.amazon.com/glue/latest/dg/connection-defining.html + Parameters ---------- connection : str @@ -136,10 +141,10 @@ def connect( password=attrs.password, port=attrs.port, host=attrs.host, + ssl=attrs.ssl_context, read_timeout=read_timeout, write_timeout=write_timeout, connect_timeout=connect_timeout, - ssl=attrs.ssl_context, ) diff --git a/cloudformation/databases.yaml b/cloudformation/databases.yaml index ecd2810f7..ca698e5df 100644 --- a/cloudformation/databases.yaml +++ b/cloudformation/databases.yaml @@ -390,7 +390,7 @@ Resources: CatalogId: Ref: AWS::AccountId ConnectionInput: - Description: Connect to Aurora (MySQL). + Description: Connect to Aurora (MySQL) SSL enabled. ConnectionType: JDBC PhysicalConnectionRequirements: AvailabilityZone: diff --git a/tests/test_mysql.py b/tests/test_mysql.py index 8dab62b50..07f21a25a 100644 --- a/tests/test_mysql.py +++ b/tests/test_mysql.py @@ -20,12 +20,16 @@ def mysql_con(): con.close() -def test_connection(): - wr.mysql.connect("aws-data-wrangler-mysql", connect_timeout=10).close() +@pytest.fixture(scope="function") +def mysql_con_ssl(): + con = wr.mysql.connect("aws-data-wrangler-mysql-ssl") + yield con + con.close() -def test_connection_ssl(): - wr.mysql.connect("aws-data-wrangler-mysql-ssl", connect_timeout=10).close() +@pytest.mark.parametrize("connection", ["aws-data-wrangler-mysql", "aws-data-wrangler-mysql-ssl"]) +def test_connection(connection): + wr.mysql.connect(connection, connect_timeout=10).close() def test_read_sql_query_simple(databases_parameters): @@ -46,6 +50,11 @@ def test_to_sql_simple(mysql_table, mysql_con): wr.mysql.to_sql(df, mysql_con, mysql_table, "test", "overwrite", True) +def test_to_sql_simple_ssl(mysql_table, mysql_con_ssl): + df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]}) + wr.mysql.to_sql(df, mysql_con_ssl, mysql_table, "test", 
"overwrite", True) + + def test_sql_types(mysql_table, mysql_con): table = mysql_table df = get_df()