From 838634c004922909fc7b44c60b442964b431d82e Mon Sep 17 00:00:00 2001
From: Maximilian Speicher
Date: Thu, 31 Dec 2020 02:57:27 +0100
Subject: [PATCH 01/19] WIP: Add support for SQL Server

---
 awswrangler/__init__.py | 2 +
 awswrangler/_data_types.py | 35 +++++++++
 awswrangler/_databases.py | 85 +++++++++++++-------
 awswrangler/sqlserver.py | 146 +++++++++++++++++++++++++++++++++++
 requirements.txt | 1 +
 tests/_utils.py | 10 +--
 tests/conftest.py | 53 ++++++++-----
 tests/test_athena.py | 2 +-
 tests/test_athena_parquet.py | 4 +-
 tests/test_mysql.py | 2 +-
 tests/test_postgresql.py | 2 +-
 tests/test_redshift.py | 2 +-
 tests/test_sqlserver.py | 75 ++++++++++++++++++
 13 files changed, 364 insertions(+), 55 deletions(-)
 create mode 100644 awswrangler/sqlserver.py
 create mode 100644 tests/test_sqlserver.py

diff --git a/awswrangler/__init__.py b/awswrangler/__init__.py
index b4603f01a..b0563e94b 100644
--- a/awswrangler/__init__.py
+++ b/awswrangler/__init__.py
@@ -20,6 +20,7 @@
     redshift,
     s3,
     secretsmanager,
+    sqlserver,
     sts,
     timestream,
 )
@@ -40,6 +41,7 @@
     "mysql",
     "postgresql",
     "secretsmanager",
+    "sqlserver",
     "config",
     "timestream",
     "__description__",
diff --git a/awswrangler/_data_types.py b/awswrangler/_data_types.py
index e170df47c..46aa15c5a 100644
--- a/awswrangler/_data_types.py
+++ b/awswrangler/_data_types.py
@@ -166,6 +166,41 @@ def pyarrow2postgresql(  # pylint: disable=too-many-branches,too-many-return-sta
     raise exceptions.UnsupportedType(f"Unsupported PostgreSQL type: {dtype}")


+def pyarrow2sqlserver(  # pylint: disable=too-many-branches,too-many-return-statements
+    dtype: pa.DataType, string_type: str
+) -> str:
+    """Pyarrow to Microsoft SQL Server data types conversion."""
+    if pa.types.is_int8(dtype):
+        return "SMALLINT"
+    if pa.types.is_int16(dtype) or pa.types.is_uint8(dtype):
+        return "SMALLINT"
+    if pa.types.is_int32(dtype) or pa.types.is_uint16(dtype):
+        return "INT"
+    if pa.types.is_int64(dtype) or pa.types.is_uint32(dtype):
+        return "BIGINT"
+    if pa.types.is_uint64(dtype):
+        raise exceptions.UnsupportedType("There is no support for uint64, please consider int64 or uint32.")
+    if pa.types.is_float32(dtype):
+        return "FLOAT(24)"
+    if pa.types.is_float64(dtype):
+        return "FLOAT"
+    if pa.types.is_boolean(dtype):
+        return "BIT"
+    if pa.types.is_string(dtype):
+        return string_type
+    if pa.types.is_timestamp(dtype):
+        return "DATETIME2"
+    if pa.types.is_date(dtype):
+        return "DATE"
+    if pa.types.is_decimal(dtype):
+        return f"DECIMAL({dtype.precision},{dtype.scale})"
+    if pa.types.is_dictionary(dtype):
+        return pyarrow2sqlserver(dtype=dtype.value_type, string_type=string_type)
+    if pa.types.is_binary(dtype):
+        return "VARBINARY"
+    raise exceptions.UnsupportedType(f"Unsupported SQL Server type: {dtype}")
+
+
 def pyarrow2timestream(dtype: pa.DataType) -> str:  # pylint: disable=too-many-branches,too-many-return-statements
     """Pyarrow to Amazon Timestream data types conversion."""
     if pa.types.is_int8(dtype):
diff --git a/awswrangler/_databases.py b/awswrangler/_databases.py
index 5a1bbb1da..cfef30fdf 100644
--- a/awswrangler/_databases.py
+++ b/awswrangler/_databases.py
@@ -36,7 +36,11 @@ def _get_connection_attributes_from_catalog(
     details: Dict[str, Any] = get_connection(name=connection, catalog_id=catalog_id, boto3_session=boto3_session)[
         "ConnectionProperties"
     ]
-    port, database = details["JDBC_CONNECTION_URL"].split(":")[3].split("/")
+    if "databaseName=" in details["JDBC_CONNECTION_URL"]:
+        database_sep = ";databaseName="
+    else:
+        database_sep = "/"
+    port, database = 
details["JDBC_CONNECTION_URL"].split(":")[3].split(database_sep) return ConnectionAttributes( kind=details["JDBC_CONNECTION_URL"].split(":")[1].lower(), user=details["USERNAME"], @@ -136,19 +140,48 @@ def _records2df( return df -def _iterate_cursor( - cursor: Any, +def _get_cols_names(cursor_description: Any) -> List[str]: + cols_names = [col[0].decode("utf-8") if isinstance(col[0], bytes) else col[0] for col in cursor_description] + _logger.debug("cols_names: %s", cols_names) + + return cols_names + + +def _iterate_results( + con: Any, + cursor_args: List[Any], chunksize: int, - cols_names: List[str], - index: Optional[Union[str, List[str]]], + index_col: Optional[Union[str, List[str]]], safe: bool, dtype: Optional[Dict[str, pa.DataType]], ) -> Iterator[pd.DataFrame]: - while True: - records = cursor.fetchmany(chunksize) - if not records: - break - yield _records2df(records=records, cols_names=cols_names, index=index, safe=safe, dtype=dtype) + with con.cursor() as cursor: + cursor.execute(*cursor_args) + cols_names = _get_cols_names(cursor.description) + while True: + records = cursor.fetchmany(chunksize) + if not records: + break + yield _records2df(records=records, cols_names=cols_names, index=index_col, safe=safe, dtype=dtype) + + +def _fetch_all_results( + con: Any, + cursor_args: List[Any], + index_col: Optional[Union[str, List[str]]] = None, + dtype: Optional[Dict[str, pa.DataType]] = None, + safe: bool = True, +) -> pd.DataFrame: + with con.cursor() as cursor: + cursor.execute(*cursor_args) + cols_names = _get_cols_names(cursor.description) + return _records2df( + records=cast(List[Tuple[Any]], cursor.fetchall()), + cols_names=cols_names, + index=index_col, + dtype=dtype, + safe=safe, + ) def read_sql_query( @@ -163,22 +196,22 @@ def read_sql_query( """Read SQL Query (generic).""" args = _convert_params(sql, params) try: - with con.cursor() as cursor: - cursor.execute(*args) - cols_names: List[str] = [ - col[0].decode("utf-8") if isinstance(col[0], bytes) else col[0] for col in cursor.description - ] - _logger.debug("cols_names: %s", cols_names) - if chunksize is None: - return _records2df( - records=cast(List[Tuple[Any]], cursor.fetchall()), - cols_names=cols_names, - index=index_col, - dtype=dtype, - safe=safe, - ) - return _iterate_cursor( - cursor=cursor, chunksize=chunksize, cols_names=cols_names, index=index_col, dtype=dtype, safe=safe + if chunksize is None: + return _fetch_all_results( + con=con, + cursor_args=args, + index_col=index_col, + dtype=dtype, + safe=safe, + ) + else: + return _iterate_results( + con=con, + cursor_args=args, + chunksize=chunksize, + index_col=index_col, + dtype=dtype, + safe=safe, ) except Exception as ex: con.rollback() diff --git a/awswrangler/sqlserver.py b/awswrangler/sqlserver.py new file mode 100644 index 000000000..d3710f2f8 --- /dev/null +++ b/awswrangler/sqlserver.py @@ -0,0 +1,146 @@ +"""Amazon Microsoft SQL Server Module.""" + + +import logging +from typing import Any, Dict, Iterator, List, Optional, Tuple, Union + +import boto3 +import pandas as pd +import pyarrow as pa +import pymssql + +from awswrangler import _data_types +from awswrangler import _databases as _db_utils +from awswrangler import exceptions + +_logger: logging.Logger = logging.getLogger(__name__) + + +def _validate_connection(con: pymssql.Connection) -> None: + if not isinstance(con, pymssql.Connection): + raise exceptions.InvalidConnection( + "Invalid 'conn' argument, please pass a " + "pymssql.Connection object. 
Use pymssql.connect() to use " + "credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog." + ) + + +def _drop_table(cursor: pymssql.Cursor, schema: Optional[str], table: str) -> None: + schema_str = f"{schema}." if schema else "" + sql = f"IF OBJECT_ID(N'{schema_str}{table}', N'U') IS NOT NULL DROP TABLE {schema_str}{table}" + _logger.debug("Drop table query:\n%s", sql) + cursor.execute(sql) + + +def _does_table_exist(cursor: pymssql.Cursor, schema: Optional[str], table: str) -> bool: + schema_str = f"TABLE_SCHEMA = '{schema}' AND" if schema else "" + cursor.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE " f"{schema_str} TABLE_NAME = '{table}'") + return len(cursor.fetchall()) > 0 + + +def _create_table( + df: pd.DataFrame, + cursor: pymssql.Cursor, + table: str, + schema: str, + mode: str, + index: bool, + dtype: Optional[Dict[str, str]], + varchar_lengths: Optional[Dict[str, int]], +) -> None: + if mode == "overwrite": + _drop_table(cursor=cursor, schema=schema, table=table) + elif _does_table_exist(cursor=cursor, schema=schema, table=table): + return + sqlserver_types: Dict[str, str] = _data_types.database_types_from_pandas( + df=df, + index=index, + dtype=dtype, + varchar_lengths_default="VARCHAR(MAX)", + varchar_lengths=varchar_lengths, + converter_func=_data_types.pyarrow2sqlserver, + ) + cols_str: str = "".join([f"{k} {v},\n" for k, v in sqlserver_types.items()])[:-2] + sql = f"IF OBJECT_ID(N'{schema}.{table}', N'U') IS NULL BEGIN CREATE TABLE {schema}.{table} (\n{cols_str}); END;" + _logger.debug("Create table query:\n%s", sql) + cursor.execute(sql) + + +def connect( + connection: Optional[str] = None, + secret_id: Optional[str] = None, + catalog_id: Optional[str] = None, + dbname: Optional[str] = None, + boto3_session: Optional[boto3.Session] = None, + timeout: Optional[int] = 0, + login_timeout: Optional[int] = 60, +) -> pymssql.Connection: + attrs: _db_utils.ConnectionAttributes = _db_utils.get_connection_attributes( + connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session + ) + if attrs.kind != "sqlserver": + exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. 
It must be a sqlserver connection.)") + return pymssql.connect( + user=attrs.user, + database=attrs.database, + password=attrs.password, + port=attrs.port, + host=attrs.host, + timeout=timeout, + login_timeout=login_timeout, + ) + + +def read_sql_query( + sql: str, + con: pymssql.Connection, + index_col: Optional[Union[str, List[str]]] = None, + params: Optional[Union[List[Any], Tuple[Any, ...], Dict[Any, Any]]] = None, + chunksize: Optional[int] = None, + dtype: Optional[Dict[str, pa.DataType]] = None, + safe: bool = True, +) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: + _validate_connection(con=con) + return _db_utils.read_sql_query( + sql=sql, con=con, index_col=index_col, params=params, chunksize=chunksize, dtype=dtype, safe=safe + ) + + +def to_sql( + df: pd.DataFrame, + con: pymssql.Connection, + table: str, + schema: str, + mode: str = "append", + index: bool = False, + dtype: Optional[Dict[str, str]] = None, + varchar_lengths: Optional[Dict[str, int]] = None, +) -> None: + if df.empty is True: + raise exceptions.EmptyDataFrame() + _validate_connection(con=con) + try: + with con.cursor() as cursor: + _create_table( + df=df, + cursor=cursor, + table=table, + schema=schema, + mode=mode, + index=index, + dtype=dtype, + varchar_lengths=varchar_lengths, + ) + if index: + df.reset_index(level=df.index.names, inplace=True) + placeholders: str = ", ".join(["%s"] * len(df.columns)) + sql: str = f'INSERT INTO "{schema}"."{table}" VALUES ({placeholders})' + _logger.debug("sql: %s", sql) + parameters: List[List[Any]] = _db_utils.extract_parameters(df=df) + parameter_tuples: List[Tuple[Any]] = [tuple(parameter_set) for parameter_set in parameters] + cursor.executemany(sql, parameter_tuples) + con.commit() + except Exception as ex: + con.rollback() + _logger.error(ex) + raise diff --git a/requirements.txt b/requirements.txt index 6417ea8c0..9bd3b99b0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ pyarrow~=2.0.0 redshift-connector~=2.0.0 pymysql>=0.9.0,<0.11.0 pg8000~=1.16.0 +pymssql~=2.1.5 diff --git a/tests/_utils.py b/tests/_utils.py index b51dbcbd4..85df69484 100644 --- a/tests/_utils.py +++ b/tests/_utils.py @@ -25,7 +25,7 @@ def get_df(): "iint32": [1, None, 2], "iint64": [1, None, 2], "float": [0.0, None, 1.1], - "double": [0.0, None, 1.1], + "ddouble": [0.0, None, 1.1], "decimal": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))], "string_object": ["foo", None, "boo"], "string": ["foo", None, "boo"], @@ -56,7 +56,7 @@ def get_df_list(): "iint32": [1, None, 2], "iint64": [1, None, 2], "float": [0.0, None, 1.1], - "double": [0.0, None, 1.1], + "ddouble": [0.0, None, 1.1], "decimal": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))], "string_object": ["foo", None, "boo"], "string": ["foo", None, "boo"], @@ -90,7 +90,7 @@ def get_df_cast(): "iint32": [None, None, None], "iint64": [None, None, None], "float": [None, None, None], - "double": [None, None, None], + "ddouble": [None, None, None], "decimal": [None, None, None], "string": [None, None, None], "date": [None, None, dt("2020-01-02")], @@ -201,7 +201,7 @@ def get_df_quicksight(): "iint32": [1, None, 2], "iint64": [1, None, 2], "float": [0.0, None, 1.1], - "double": [0.0, None, 1.1], + "ddouble": [0.0, None, 1.1], "decimal": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))], "string_object": ["foo", None, "boo"], "string": ["foo", None, "boo"], @@ -425,7 +425,7 @@ def ensure_data_types(df, has_list=False): assert str(df["iint32"].dtype).startswith("Int") assert 
str(df["iint64"].dtype) == "Int64" assert str(df["float"].dtype).startswith("float") - assert str(df["double"].dtype) == "float64" + assert str(df["ddouble"].dtype) == "float64" assert str(df["decimal"].dtype) == "object" if "string_object" in df.columns: assert str(df["string_object"].dtype) == "string" diff --git a/tests/conftest.py b/tests/conftest.py index 39eaef776..a1ac8f9a7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -128,24 +128,29 @@ def workgroup3(bucket, kms_key): @pytest.fixture(scope="session") -def databases_parameters(cloudformation_outputs): - parameters = dict(postgresql={}, mysql={}, redshift={}) - parameters["postgresql"]["host"] = cloudformation_outputs["PostgresqlAddress"] - parameters["postgresql"]["port"] = 3306 - parameters["postgresql"]["schema"] = "public" - parameters["postgresql"]["database"] = "postgres" - parameters["mysql"]["host"] = cloudformation_outputs["MysqlAddress"] - parameters["mysql"]["port"] = 3306 - parameters["mysql"]["schema"] = "test" - parameters["mysql"]["database"] = "test" - parameters["redshift"]["host"] = cloudformation_outputs["RedshiftAddress"] - parameters["redshift"]["port"] = cloudformation_outputs["RedshiftPort"] - parameters["redshift"]["identifier"] = cloudformation_outputs["RedshiftIdentifier"] - parameters["redshift"]["schema"] = "public" - parameters["redshift"]["database"] = "test" - parameters["redshift"]["role"] = cloudformation_outputs["RedshiftRole"] - parameters["password"] = cloudformation_outputs["DatabasesPassword"] - parameters["user"] = "test" +def databases_parameters(): + parameters = dict(postgresql={}, mysql={}, redshift={}, sqlserver={}) + # parameters["postgresql"]["host"] = cloudformation_outputs["PostgresqlAddress"] + # parameters["postgresql"]["port"] = 3306 + # parameters["postgresql"]["schema"] = "public" + # parameters["postgresql"]["database"] = "postgres" + # parameters["mysql"]["host"] = cloudformation_outputs["MysqlAddress"] + # parameters["mysql"]["port"] = 3306 + # parameters["mysql"]["schema"] = "test" + # parameters["mysql"]["database"] = "test" + # parameters["redshift"]["host"] = cloudformation_outputs["RedshiftAddress"] + # parameters["redshift"]["port"] = cloudformation_outputs["RedshiftPort"] + # parameters["redshift"]["identifier"] = cloudformation_outputs["RedshiftIdentifier"] + # parameters["redshift"]["schema"] = "public" + # parameters["redshift"]["database"] = "test" + # parameters["redshift"]["role"] = cloudformation_outputs["RedshiftRole"] + # parameters["password"] = cloudformation_outputs["DatabasesPassword"] + # parameters["user"] = "test" + parameters["sqlserver"]["host"] = "my-sql-server.cv9de6ia0cf2.us-east-1.rds.amazonaws.com" + parameters["sqlserver"]["port"] = "1433" + parameters["sqlserver"]["database"] = "TestDb" + parameters["user"] = "admin" + parameters["password"] = "123456Ab" return parameters @@ -236,6 +241,18 @@ def mysql_table(): con.close() +@pytest.fixture(scope="function") +def sqlserver_table(): + name = f"tbl_{get_time_str_with_random_suffix()}" + print(f"Table name: {name}") + yield name + con = wr.sqlserver.connect("aws-data-wrangler-sqlserver") + with con.cursor() as cursor: + cursor.execute(f"IF OBJECT_ID(N'dbo.{name}', N'U') IS NOT NULL DROP TABLE dbo.{name}") + con.commit() + con.close() + + @pytest.fixture(scope="function") def timestream_database_and_table(): name = f"tbl_{get_time_str_with_random_suffix()}" diff --git a/tests/test_athena.py b/tests/test_athena.py index b2fa849ca..87fd72561 100644 --- a/tests/test_athena.py +++ 
b/tests/test_athena.py @@ -153,7 +153,7 @@ def test_athena(path, glue_database, glue_table, kms_key, workgroup0, workgroup1 f" `iint32` int," f" `iint64` bigint," f" `float` float," - f" `double` double," + f" `ddouble` double," f" `decimal` decimal(3,2)," f" `string_object` string," f" `string` string," diff --git a/tests/test_athena_parquet.py b/tests/test_athena_parquet.py index 348492b34..cfde1c760 100644 --- a/tests/test_athena_parquet.py +++ b/tests/test_athena_parquet.py @@ -112,7 +112,7 @@ def test_parquet_catalog_casting(path, glue_database): "iint32": "int", "iint64": "bigint", "float": "float", - "double": "double", + "ddouble": "double", "decimal": "decimal(3,2)", "string": "string", "date": "date", @@ -408,7 +408,7 @@ def test_parquet_catalog_casting_to_string(path, glue_table, glue_database): "iint32": "string", "iint64": "string", "float": "string", - "double": "string", + "ddouble": "string", "decimal": "string", "string": "string", "date": "string", diff --git a/tests/test_mysql.py b/tests/test_mysql.py index b7746e354..6cccbcfb6 100644 --- a/tests/test_mysql.py +++ b/tests/test_mysql.py @@ -63,7 +63,7 @@ def test_sql_types(mysql_table): "iint32": pa.int32(), "iint64": pa.int64(), "float": pa.float32(), - "double": pa.float64(), + "ddouble": pa.float64(), "decimal": pa.decimal128(3, 2), "string_object": pa.string(), "string": pa.string(), diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 0d00e0a4e..c697a3685 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -63,7 +63,7 @@ def test_sql_types(postgresql_table): "iint32": pa.int32(), "iint64": pa.int64(), "float": pa.float32(), - "double": pa.float64(), + "ddouble": pa.float64(), "decimal": pa.decimal128(3, 2), "string_object": pa.string(), "string": pa.string(), diff --git a/tests/test_redshift.py b/tests/test_redshift.py index 528ac38ad..160fa42dc 100644 --- a/tests/test_redshift.py +++ b/tests/test_redshift.py @@ -67,7 +67,7 @@ def test_sql_types(redshift_table): "iint32": pa.int32(), "iint64": pa.int64(), "float": pa.float32(), - "double": pa.float64(), + "ddouble": pa.float64(), "decimal": pa.decimal128(3, 2), "string_object": pa.string(), "string": pa.string(), diff --git a/tests/test_sqlserver.py b/tests/test_sqlserver.py new file mode 100644 index 000000000..aff147c41 --- /dev/null +++ b/tests/test_sqlserver.py @@ -0,0 +1,75 @@ +import logging + +import pandas as pd +import pyarrow as pa +import pymssql + +import awswrangler as wr + +from ._utils import ensure_data_types, get_df + +logging.getLogger("awswrangler").setLevel(logging.DEBUG) + + +def test_connection(): + wr.sqlserver.connect("aws-data-wrangler-sqlserver", timeout=10).close() + + +def test_read_sql_query_simple(databases_parameters): + con = pymssql.connect( + host=databases_parameters["sqlserver"]["host"], + port=int(databases_parameters["sqlserver"]["port"]), + database=databases_parameters["sqlserver"]["database"], + user=databases_parameters["user"], + password=databases_parameters["password"], + ) + df = wr.sqlserver.read_sql_query("SELECT 1", con=con) + con.close() + assert df.shape == (1, 1) + + +def test_to_sql_simple(sqlserver_table): + con = wr.sqlserver.connect("aws-data-wrangler-sqlserver") + df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]}) + wr.sqlserver.to_sql(df, con, sqlserver_table, "dbo", "overwrite", True) + con.close() + + +def test_sql_types(sqlserver_table): + table = sqlserver_table + df = get_df() + df.drop(["binary"], axis=1, inplace=True) + con = 
wr.sqlserver.connect("aws-data-wrangler-sqlserver") + wr.sqlserver.to_sql( + df=df, + con=con, + table=table, + schema="dbo", + mode="overwrite", + index=True, + dtype={"iint32": "INTEGER"}, + ) + df = wr.sqlserver.read_sql_query(f"SELECT * FROM dbo.{table}", con) + ensure_data_types(df, has_list=False) + dfs = wr.sqlserver.read_sql_query( + sql=f"SELECT * FROM dbo.{table}", + con=con, + chunksize=1, + dtype={ + "iint8": pa.int8(), + "iint16": pa.int16(), + "iint32": pa.int32(), + "iint64": pa.int64(), + "float": pa.float32(), + "ddouble": pa.float64(), + "decimal": pa.decimal128(3, 2), + "string_object": pa.string(), + "string": pa.string(), + "date": pa.date32(), + "timestamp": pa.timestamp(unit="ns"), + "binary": pa.binary(), + "category": pa.float64(), + }, + ) + for df in dfs: + ensure_data_types(df, has_list=False) From 30c1f032a8ccb8861e05687efc7151441ad060d5 Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Thu, 31 Dec 2020 11:46:36 +0100 Subject: [PATCH 02/19] WIP: SQL Server feature complete --- awswrangler/sqlserver.py | 33 +++++++++-- tests/test_sqlserver.py | 115 ++++++++++++++++++++++++++++++++++++++- 2 files changed, 141 insertions(+), 7 deletions(-) diff --git a/awswrangler/sqlserver.py b/awswrangler/sqlserver.py index d3710f2f8..2e283dd78 100644 --- a/awswrangler/sqlserver.py +++ b/awswrangler/sqlserver.py @@ -25,9 +25,15 @@ def _validate_connection(con: pymssql.Connection) -> None: ) +def _get_table_identifier(schema: Optional[str], table: str) -> str: + schema_str = f'"{schema}".' if schema else '' + table_identifier = f'{schema_str}"{table}"' + return table_identifier + + def _drop_table(cursor: pymssql.Cursor, schema: Optional[str], table: str) -> None: - schema_str = f"{schema}." if schema else "" - sql = f"IF OBJECT_ID(N'{schema_str}{table}', N'U') IS NOT NULL DROP TABLE {schema_str}{table}" + table_identifier = _get_table_identifier(schema, table) + sql = f"IF OBJECT_ID(N'{table_identifier}', N'U') IS NOT NULL DROP TABLE {table_identifier}" _logger.debug("Drop table query:\n%s", sql) cursor.execute(sql) @@ -61,7 +67,8 @@ def _create_table( converter_func=_data_types.pyarrow2sqlserver, ) cols_str: str = "".join([f"{k} {v},\n" for k, v in sqlserver_types.items()])[:-2] - sql = f"IF OBJECT_ID(N'{schema}.{table}', N'U') IS NULL BEGIN CREATE TABLE {schema}.{table} (\n{cols_str}); END;" + table_identifier = _get_table_identifier(schema, table) + sql = f"IF OBJECT_ID(N'{table_identifier}', N'U') IS NULL BEGIN CREATE TABLE {table_identifier} (\n{cols_str}); END;" _logger.debug("Create table query:\n%s", sql) cursor.execute(sql) @@ -106,6 +113,23 @@ def read_sql_query( ) +def read_sql_table( + table: str, + con: pymssql.Connection, + schema: Optional[str] = None, + index_col: Optional[Union[str, List[str]]] = None, + params: Optional[Union[List[Any], Tuple[Any, ...], Dict[Any, Any]]] = None, + chunksize: Optional[int] = None, + dtype: Optional[Dict[str, pa.DataType]] = None, + safe: bool = True, +) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: + table_identifier = _get_table_identifier(schema, table) + sql: str = f"SELECT * FROM {table_identifier}" + return read_sql_query( + sql=sql, con=con, index_col=index_col, params=params, chunksize=chunksize, dtype=dtype, safe=safe + ) + + def to_sql( df: pd.DataFrame, con: pymssql.Connection, @@ -134,7 +158,8 @@ def to_sql( if index: df.reset_index(level=df.index.names, inplace=True) placeholders: str = ", ".join(["%s"] * len(df.columns)) - sql: str = f'INSERT INTO "{schema}"."{table}" VALUES ({placeholders})' + 
table_identifier = _get_table_identifier(schema, table) + sql: str = f'INSERT INTO {table_identifier} VALUES ({placeholders})' _logger.debug("sql: %s", sql) parameters: List[List[Any]] = _db_utils.extract_parameters(df=df) parameter_tuples: List[Tuple[Any]] = [tuple(parameter_set) for parameter_set in parameters] diff --git a/tests/test_sqlserver.py b/tests/test_sqlserver.py index aff147c41..226054c40 100644 --- a/tests/test_sqlserver.py +++ b/tests/test_sqlserver.py @@ -1,18 +1,21 @@ import logging +from decimal import Decimal import pandas as pd import pyarrow as pa import pymssql +import pytest import awswrangler as wr from ._utils import ensure_data_types, get_df logging.getLogger("awswrangler").setLevel(logging.DEBUG) +_GLUE_CONNECTION_NAME = "aws-data-wrangler-sqlserver" def test_connection(): - wr.sqlserver.connect("aws-data-wrangler-sqlserver", timeout=10).close() + wr.sqlserver.connect(_GLUE_CONNECTION_NAME, timeout=10).close() def test_read_sql_query_simple(databases_parameters): @@ -29,7 +32,7 @@ def test_read_sql_query_simple(databases_parameters): def test_to_sql_simple(sqlserver_table): - con = wr.sqlserver.connect("aws-data-wrangler-sqlserver") + con = wr.sqlserver.connect(_GLUE_CONNECTION_NAME) df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]}) wr.sqlserver.to_sql(df, con, sqlserver_table, "dbo", "overwrite", True) con.close() @@ -39,7 +42,7 @@ def test_sql_types(sqlserver_table): table = sqlserver_table df = get_df() df.drop(["binary"], axis=1, inplace=True) - con = wr.sqlserver.connect("aws-data-wrangler-sqlserver") + con = wr.sqlserver.connect(_GLUE_CONNECTION_NAME) wr.sqlserver.to_sql( df=df, con=con, @@ -73,3 +76,109 @@ def test_sql_types(sqlserver_table): ) for df in dfs: ensure_data_types(df, has_list=False) + + +def test_to_sql_cast(sqlserver_table): + table = sqlserver_table + df = pd.DataFrame( + { + "col": [ + "".join([str(i)[-1] for i in range(1_024)]), + "".join([str(i)[-1] for i in range(1_024)]), + "".join([str(i)[-1] for i in range(1_024)]), + ] + }, + dtype="string", + ) + con = wr.sqlserver.connect(connection=_GLUE_CONNECTION_NAME) + wr.sqlserver.to_sql( + df=df, + con=con, + table=table, + schema="dbo", + mode="overwrite", + index=False, + dtype={"col": "VARCHAR(1024)"}, + ) + df2 = wr.sqlserver.read_sql_query(sql=f"SELECT * FROM dbo.{table}", con=con) + assert df.equals(df2) + con.close() + + +def test_null(sqlserver_table): + table = sqlserver_table + con = wr.sqlserver.connect(connection=_GLUE_CONNECTION_NAME) + df = pd.DataFrame({"id": [1, 2, 3], "nothing": [None, None, None]}) + wr.sqlserver.to_sql( + df=df, + con=con, + table=table, + schema="dbo", + mode="overwrite", + index=False, + dtype={"nothing": "INTEGER"}, + ) + wr.sqlserver.to_sql( + df=df, + con=con, + table=table, + schema="dbo", + mode="append", + index=False, + ) + df2 = wr.sqlserver.read_sql_table(table=table, schema="dbo", con=con) + df["id"] = df["id"].astype("Int64") + assert pd.concat(objs=[df, df], ignore_index=True).equals(df2) + con.close() + + +def test_decimal_cast(sqlserver_table): + table = sqlserver_table + df = pd.DataFrame( + { + "col0": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))], + "col1": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))], + "col2": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))], + } + ) + con = wr.sqlserver.connect(connection=_GLUE_CONNECTION_NAME) + wr.sqlserver.to_sql(df, con, table, "dbo") + df2 = wr.sqlserver.read_sql_table( + schema="dbo", table=table, con=con, dtype={"col0": 
"float32", "col1": "float64", "col2": "Int64"} + ) + assert df2.dtypes.to_list() == ["float32", "float64", "Int64"] + assert 3.88 <= df2.col0.sum() <= 3.89 + assert 3.88 <= df2.col1.sum() <= 3.89 + assert df2.col2.sum() == 2 + con.close() + + +def test_read_retry(): + con = wr.sqlserver.connect(connection=_GLUE_CONNECTION_NAME) + try: + wr.sqlserver.read_sql_query("ERROR", con) + except: # noqa + pass + df = wr.sqlserver.read_sql_query("SELECT 1", con) + assert df.shape == (1, 1) + con.close() + + +def test_table_name(): + df = pd.DataFrame({"col0": [1]}) + con = wr.sqlserver.connect(connection="aws-data-wrangler-sqlserver") + wr.sqlserver.to_sql(df, con, "Test Name", "dbo", mode="overwrite") + df = wr.sqlserver.read_sql_table(schema="dbo", con=con, table="Test Name") + assert df.shape == (1, 1) + with con.cursor() as cursor: + cursor.execute('DROP TABLE "Test Name"') + con.commit() + con.close() + + +@pytest.mark.parametrize("dbname", [None, "sqlserver"]) +def test_connect_secret_manager(dbname): + con = wr.sqlserver.connect(secret_id="aws-data-wrangler/sqlserver", dbname=dbname) + df = wr.sqlserver.read_sql_query("SELECT 1", con=con) + con.close() + assert df.shape == (1, 1) From 0ecadbc2a4f82f16c4234eb3b3f83d537e8f4983 Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Thu, 31 Dec 2020 11:46:53 +0100 Subject: [PATCH 03/19] WIP: Adapt databases cfn template --- cloudformation/databases.yaml | 80 ++++++++++++++++++++++++++++++++--- 1 file changed, 74 insertions(+), 6 deletions(-) diff --git a/cloudformation/databases.yaml b/cloudformation/databases.yaml index 0590c622b..b3a08f285 100644 --- a/cloudformation/databases.yaml +++ b/cloudformation/databases.yaml @@ -136,7 +136,7 @@ Resources: SubnetIds: - Fn::ImportValue: aws-data-wrangler-base-PublicSubnet1 - Fn::ImportValue: aws-data-wrangler-base-PublicSubnet2 - AuroraRole: + RdsRole: Type: AWS::IAM::Role Properties: Tags: @@ -205,7 +205,7 @@ Resources: - FeatureName: s3Import RoleArn: Fn::GetAtt: - - AuroraRole + - RdsRole - Arn AuroraInstancePostgresql: Type: AWS::RDS::DBInstance @@ -234,15 +234,15 @@ Resources: Parameters: aurora_load_from_s3_role: Fn::GetAtt: - - AuroraRole + - RdsRole - Arn aws_default_s3_role: Fn::GetAtt: - - AuroraRole + - RdsRole - Arn aurora_select_into_s3_role: Fn::GetAtt: - - AuroraRole + - RdsRole - Arn AuroraClusterMysql: Type: AWS::RDS::DBCluster @@ -268,7 +268,7 @@ Resources: AssociatedRoles: - RoleArn: Fn::GetAtt: - - AuroraRole + - RdsRole - Arn AuroraInstanceMysql: Type: AWS::RDS::DBInstance @@ -286,6 +286,25 @@ Resources: DBSubnetGroupName: Ref: RdsSubnetGroup PubliclyAccessible: true + SqlServerInstance: + Type: AWS::RDS::DBInstance + DeletionPolicy: Delete + Properties: + Tags: + - Key: Env + Value: aws-data-wrangler + Engine: sqlserver-ex + EngineVersion: '15.00' + DBInstanceIdentifier: sqlserver-instance-wrangler + DBInstanceClass: db.t3.small + DBSubnetGroupName: + Ref: RdsSubnetGroup + PubliclyAccessible: true + AssociatedRoles: + - RoleArn: + Fn::GetAtt: + - RdsRole + - Arn RedshiftGlueConnection: Type: AWS::Glue::Connection Properties: @@ -358,6 +377,30 @@ Resources: PASSWORD: Ref: DatabasesPassword Name: aws-data-wrangler-mysql + SqlServerGlueConnection: + Type: AWS::Glue::Connection + Properties: + CatalogId: + Ref: AWS::AccountId + ConnectionInput: + Description: Connect to SQL Server. 
+ ConnectionType: JDBC + PhysicalConnectionRequirements: + AvailabilityZone: + Fn::Select: + - 0 + - Fn::GetAZs: '' + SecurityGroupIdList: + - Ref: DatabaseSecurityGroup + SubnetId: + Fn::ImportValue: aws-data-wrangler-base-PrivateSubnet + ConnectionProperties: + JDBC_CONNECTION_URL: + Fn::Sub: jdbc:mysql://${SqlServerInstance.Endpoint.Address}:${SqlServerInstance.Endpoint.Port};databaseName=test + USERNAME: test + PASSWORD: + Ref: DatabasesPassword + Name: aws-data-wrangler-sqlserver GlueCatalogSettings: Type: AWS::Glue::DataCatalogEncryptionSettings Properties: @@ -426,6 +469,25 @@ Resources: Tags: - Key: Env Value: aws-data-wrangler + SqlServerSecret: + Type: AWS::SecretsManager::Secret + Properties: + Name: aws-data-wrangler/sqlserver + Description: SQL Server credentials + SecretString: + Fn::Sub: | + { + "username": "test", + "password": "${DatabasesPassword}", + "engine": "sqlserver", + "host": "${SqlServerInstance.Endpoint.Address}", + "port": ${SqlServerInstance.Endpoint.Port}, + "dbClusterIdentifier": "${SqlServerInstance}", + "dbname": "test" + } + Tags: + - Key: Env + Value: aws-data-wrangler Outputs: DatabasesPassword: Value: @@ -464,6 +526,12 @@ Outputs: - AuroraInstanceMysql - Endpoint.Address Description: Mysql Address + SqlServerAddress: + Value: + Fn::GetAtt: + - SqlServerInstance + - Endpoint.Address + Description: SQL Server Address DatabaseSecurityGroupId: Value: Fn::GetAtt: From 0e029731d0ed27f3576dd002b26f46b48218199f Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Thu, 31 Dec 2020 12:33:27 +0100 Subject: [PATCH 04/19] WIP: Add docstrings and formatting --- .pylintrc | 2 +- awswrangler/_databases.py | 18 ++-- awswrangler/postgresql.py | 2 +- awswrangler/sqlserver.py | 184 +++++++++++++++++++++++++++++++++- cloudformation/databases.yaml | 7 +- tests/conftest.py | 29 +++--- 6 files changed, 211 insertions(+), 31 deletions(-) diff --git a/.pylintrc b/.pylintrc index daa1c3241..9585354aa 100644 --- a/.pylintrc +++ b/.pylintrc @@ -3,7 +3,7 @@ # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code. -extension-pkg-whitelist=pyarrow.lib +extension-pkg-whitelist=pyarrow.lib,pymssql # Specify a score threshold to be exceeded before program exits with error. fail-under=10 diff --git a/awswrangler/_databases.py b/awswrangler/_databases.py index cfef30fdf..7f327388e 100644 --- a/awswrangler/_databases.py +++ b/awswrangler/_databases.py @@ -204,15 +204,15 @@ def read_sql_query( dtype=dtype, safe=safe, ) - else: - return _iterate_results( - con=con, - cursor_args=args, - chunksize=chunksize, - index_col=index_col, - dtype=dtype, - safe=safe, - ) + + return _iterate_results( + con=con, + cursor_args=args, + chunksize=chunksize, + index_col=index_col, + dtype=dtype, + safe=safe, + ) except Exception as ex: con.rollback() _logger.error(ex) diff --git a/awswrangler/postgresql.py b/awswrangler/postgresql.py index d84dde3aa..1dd294f47 100644 --- a/awswrangler/postgresql.py +++ b/awswrangler/postgresql.py @@ -310,7 +310,7 @@ def to_sql( >>> import awswrangler as wr >>> con = wr.postgresql.connect("MY_GLUE_CONNECTION") >>> wr.postgresql.to_sql( - ... df=df + ... df=df, ... table="my_table", ... schema="public", ... 
con=con
diff --git a/awswrangler/sqlserver.py b/awswrangler/sqlserver.py
index 2e283dd78..77c5e81e4 100644
--- a/awswrangler/sqlserver.py
+++ b/awswrangler/sqlserver.py
@@ -26,7 +26,7 @@ def _validate_connection(con: pymssql.Connection) -> None:


 def _get_table_identifier(schema: Optional[str], table: str) -> str:
-    schema_str = f'"{schema}".' if schema else ''
+    schema_str = f'"{schema}".' if schema else ""
     table_identifier = f'{schema_str}"{table}"'
     return table_identifier

@@ -68,7 +68,9 @@ def _create_table(
     )
     cols_str: str = "".join([f"{k} {v},\n" for k, v in sqlserver_types.items()])[:-2]
     table_identifier = _get_table_identifier(schema, table)
-    sql = f"IF OBJECT_ID(N'{table_identifier}', N'U') IS NULL BEGIN CREATE TABLE {table_identifier} (\n{cols_str}); END;"
+    sql = (
+        f"IF OBJECT_ID(N'{table_identifier}', N'U') IS NULL BEGIN CREATE TABLE {table_identifier} (\n{cols_str}); END;"
+    )
     _logger.debug("Create table query:\n%s", sql)
     cursor.execute(sql)

@@ -84,6 +84,50 @@ def connect(
     timeout: Optional[int] = 0,
     login_timeout: Optional[int] = 60,
 ) -> pymssql.Connection:
+    """Return a pymssql connection from a Glue Catalog Connection.
+
+    https://github.com/pymssql/pymssql
+
+    Parameters
+    ----------
+    connection : Optional[str]
+        Glue Catalog Connection name.
+    secret_id : Optional[str]
+        Specifies the secret containing the version that you want to retrieve.
+        You can specify either the Amazon Resource Name (ARN) or the friendly name of the secret.
+    catalog_id : str, optional
+        The ID of the Data Catalog.
+        If none is provided, the AWS account ID is used by default.
+    dbname : Optional[str]
+        Optional database name to overwrite the stored one.
+    boto3_session : boto3.Session(), optional
+        Boto3 Session. The default boto3 session will be used if boto3_session receives None.
+    timeout : Optional[int]
+        This is the time in seconds before the connection to the server will time out.
+        The default is 0, which means no timeout.
+        This parameter is forwarded to pymssql.
+        https://pymssql.readthedocs.io/en/latest/ref/pymssql.html
+    login_timeout : Optional[int]
+        This is the time in seconds that the connection and login may take before it times out.
+        The default is 60 seconds.
+        This parameter is forwarded to pymssql.
+        https://pymssql.readthedocs.io/en/latest/ref/pymssql.html
+
+    Returns
+    -------
+    pymssql.Connection
+        pymssql connection.
+
+    Examples
+    --------
+    >>> import awswrangler as wr
+    >>> con = wr.sqlserver.connect("MY_GLUE_CONNECTION")
+    >>> with con.cursor() as cursor:
+    >>>     cursor.execute("SELECT 1")
+    >>>     print(cursor.fetchall())
+    >>> con.close()
+
+    """
     attrs: _db_utils.ConnectionAttributes = _db_utils.get_connection_attributes(
         connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session
     )
@@ -107,6 +153,46 @@ def read_sql_query(
     dtype: Optional[Dict[str, pa.DataType]] = None,
     safe: bool = True,
 ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
+    """Return a DataFrame corresponding to the result set of the query string.
+
+    Parameters
+    ----------
+    sql : str
+        SQL query.
+    con : pymssql.Connection
+        Use pymssql.connect() to use credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog.
+    index_col : Union[str, List[str]], optional
+        Column(s) to set as index(MultiIndex).
+    params : Union[List, Tuple, Dict], optional
+        List of parameters to pass to execute method.
+        The syntax used to pass parameters is database driver dependent.
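+        For pymssql this is the ``%s`` placeholder style (paramstyle ``pyformat``),
+        the same style used by this module's INSERT statements, e.g.
+        ``sql="SELECT * FROM dbo.my_table WHERE my_col = %s"`` with ``params=["foo"]``.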
+        Check your database driver documentation for which of the five syntax styles,
+        described in PEP 249’s paramstyle, is supported.
+    chunksize : int, optional
+        If specified, return an iterator where chunksize is the number of rows to include in each chunk.
+    dtype : Dict[str, pyarrow.DataType], optional
+        Specifying the datatype for columns.
+        The keys should be the column names and the values should be the PyArrow types.
+    safe : bool
+        Check for overflows or other unsafe data type conversions.
+
+    Returns
+    -------
+    Union[pandas.DataFrame, Iterator[pandas.DataFrame]]
+        Result as Pandas DataFrame(s).
+
+    Examples
+    --------
+    Reading from Microsoft SQL Server using a Glue Catalog Connection
+
+    >>> import awswrangler as wr
+    >>> con = wr.sqlserver.connect("MY_GLUE_CONNECTION")
+    >>> df = wr.sqlserver.read_sql_query(
+    ...     sql="SELECT * FROM dbo.my_table",
+    ...     con=con
+    ... )
+    >>> con.close()
+    """
     _validate_connection(con=con)
     return _db_utils.read_sql_query(
         sql=sql, con=con, index_col=index_col, params=params, chunksize=chunksize, dtype=dtype, safe=safe
     )
@@ -123,6 +209,51 @@ def read_sql_table(
     dtype: Optional[Dict[str, pa.DataType]] = None,
     safe: bool = True,
 ) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]:
+    """Return a DataFrame corresponding to the table.
+
+    Parameters
+    ----------
+    table : str
+        Table name.
+    con : pymssql.Connection
+        Use pymssql.connect() to use credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog.
+    schema : str, optional
+        Name of SQL schema in database to query (if database flavor supports this).
+        Uses default schema if None (default).
+    index_col : Union[str, List[str]], optional
+        Column(s) to set as index(MultiIndex).
+    params : Union[List, Tuple, Dict], optional
+        List of parameters to pass to execute method.
+        The syntax used to pass parameters is database driver dependent.
+        Check your database driver documentation for which of the five syntax styles,
+        described in PEP 249’s paramstyle, is supported.
+    chunksize : int, optional
+        If specified, return an iterator where chunksize is the number of rows to include in each chunk.
+    dtype : Dict[str, pyarrow.DataType], optional
+        Specifying the datatype for columns.
+        The keys should be the column names and the values should be the PyArrow types.
+    safe : bool
+        Check for overflows or other unsafe data type conversions.
+
+    Returns
+    -------
+    Union[pandas.DataFrame, Iterator[pandas.DataFrame]]
+        Result as Pandas DataFrame(s).
+
+    Examples
+    --------
+    Reading from Microsoft SQL Server using a Glue Catalog Connection
+
+    >>> import awswrangler as wr
+    >>> con = wr.sqlserver.connect("MY_GLUE_CONNECTION")
+    >>> df = wr.sqlserver.read_sql_table(
+    ...     table="my_table",
+    ...     schema="dbo",
+    ...     con=con
+    ... )
+    >>> con.close()
+    """
     table_identifier = _get_table_identifier(schema, table)
     sql: str = f"SELECT * FROM {table_identifier}"
     return read_sql_query(
         sql=sql, con=con, index_col=index_col, params=params, chunksize=chunksize, dtype=dtype, safe=safe
     )
@@ -140,6 +271,51 @@ def to_sql(
     dtype: Optional[Dict[str, str]] = None,
     varchar_lengths: Optional[Dict[str, int]] = None,
 ) -> None:
+    """Write records stored in a DataFrame into Microsoft SQL Server.
+
+    Parameters
+    ----------
+    df : pandas.DataFrame
+        Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
+    con : pymssql.Connection
+        Use pymssql.connect() to use credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog.
+    table : str
+        Table name.
+    schema : str
+        Schema name.
+    mode : str
+        Append or overwrite.
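+        "overwrite" drops and recreates the target table before inserting,
+        while "append" only creates the table if it does not exist yet
+        (see _create_table above).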
+    index : bool
+        True to store the DataFrame index as a column in the table,
+        otherwise False to ignore it.
+    dtype : Dict[str, str], optional
+        Dictionary of column names and Microsoft SQL Server types to be cast.
+        Useful when you have columns with undetermined or mixed data types.
+        (e.g. {'col name': 'TEXT', 'col2 name': 'FLOAT'})
+    varchar_lengths : Dict[str, int], optional
+        Dict of VARCHAR lengths by column. (e.g. {"col1": 10, "col5": 200}).
+
+    Returns
+    -------
+    None
+        None.
+
+    Examples
+    --------
+    Writing to Microsoft SQL Server using a Glue Catalog Connection
+
+    >>> import awswrangler as wr
+    >>> con = wr.sqlserver.connect("MY_GLUE_CONNECTION")
+    >>> wr.sqlserver.to_sql(
+    ...     df=df,
+    ...     table="table",
+    ...     schema="dbo",
+    ...     con=con
+    ... )
+    >>> con.close()
+
+    """
     if df.empty is True:
         raise exceptions.EmptyDataFrame()
     _validate_connection(con=con)
     try:
         with con.cursor() as cursor:
             _create_table(
                 df=df,
                 cursor=cursor,
                 table=table,
                 schema=schema,
                 mode=mode,
                 index=index,
                 dtype=dtype,
                 varchar_lengths=varchar_lengths,
             )
             if index:
                 df.reset_index(level=df.index.names, inplace=True)
             placeholders: str = ", ".join(["%s"] * len(df.columns))
             table_identifier = _get_table_identifier(schema, table)
-            sql: str = f'INSERT INTO {table_identifier} VALUES ({placeholders})'
+            sql: str = f"INSERT INTO {table_identifier} VALUES ({placeholders})"
             _logger.debug("sql: %s", sql)
             parameters: List[List[Any]] = _db_utils.extract_parameters(df=df)
-            parameter_tuples: List[Tuple[Any]] = [tuple(parameter_set) for parameter_set in parameters]
+            parameter_tuples: List[Tuple[Any, ...]] = [tuple(parameter_set) for parameter_set in parameters]
             cursor.executemany(sql, parameter_tuples)
             con.commit()
     except Exception as ex:
         con.rollback()
         _logger.error(ex)
         raise
diff --git a/cloudformation/databases.yaml b/cloudformation/databases.yaml
index b3a08f285..63e2bc423 100644
--- a/cloudformation/databases.yaml
+++ b/cloudformation/databases.yaml
@@ -297,14 +297,19 @@ Resources:
       EngineVersion: '15.00'
       DBInstanceIdentifier: sqlserver-instance-wrangler
       DBInstanceClass: db.t3.small
+      AllocatedStorage: '20'
+      MasterUsername: test
+      MasterUserPassword:
+        Ref: DatabasesPassword
       DBSubnetGroupName:
         Ref: RdsSubnetGroup
-      PubliclyAccessible: true
+      PubliclyAccessible: true
       AssociatedRoles:
         - RoleArn:
             Fn::GetAtt:
               - RdsRole
              - Arn
+          FeatureName: S3_INTEGRATION
diff --git a/tests/conftest.py b/tests/conftest.py
index a1ac8f9a7..ae01f97e9 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -130,27 +130,26 @@ def workgroup3(bucket, kms_key):
 @pytest.fixture(scope="session")
 def databases_parameters():
     parameters = dict(postgresql={}, mysql={}, redshift={}, sqlserver={})
-    # parameters["postgresql"]["host"] = cloudformation_outputs["PostgresqlAddress"]
-    # parameters["postgresql"]["port"] = 3306
-    # parameters["postgresql"]["schema"] = "public"
-    # parameters["postgresql"]["database"] = "postgres"
-    # parameters["mysql"]["host"] = cloudformation_outputs["MysqlAddress"]
-    # parameters["mysql"]["port"] = 3306
-    # parameters["mysql"]["schema"] = "test"
-    # parameters["mysql"]["database"] = "test"
+    parameters["postgresql"]["host"] = cloudformation_outputs["PostgresqlAddress"]
+    parameters["postgresql"]["port"] = 3306
+    parameters["postgresql"]["schema"] = "public"
+    parameters["postgresql"]["database"] = "postgres"
+    parameters["mysql"]["host"] = cloudformation_outputs["MysqlAddress"]
+    parameters["mysql"]["port"] = 3306
+    parameters["mysql"]["schema"] = "test"
+    parameters["mysql"]["database"] = "test"
     # parameters["redshift"]["host"] = cloudformation_outputs["RedshiftAddress"]
     # parameters["redshift"]["port"] = 
cloudformation_outputs["RedshiftPort"] # parameters["redshift"]["identifier"] = cloudformation_outputs["RedshiftIdentifier"] # parameters["redshift"]["schema"] = "public" # parameters["redshift"]["database"] = "test" # parameters["redshift"]["role"] = cloudformation_outputs["RedshiftRole"] - # parameters["password"] = cloudformation_outputs["DatabasesPassword"] - # parameters["user"] = "test" - parameters["sqlserver"]["host"] = "my-sql-server.cv9de6ia0cf2.us-east-1.rds.amazonaws.com" - parameters["sqlserver"]["port"] = "1433" - parameters["sqlserver"]["database"] = "TestDb" - parameters["user"] = "admin" - parameters["password"] = "123456Ab" + parameters["password"] = cloudformation_outputs["DatabasesPassword"] + parameters["user"] = "test" + parameters["sqlserver"]["host"] = cloudformation_outputs["SqlServerAddress"] + parameters["sqlserver"]["port"] = 1433 + parameters["sqlserver"]["schema"] = "dbo" + parameters["sqlserver"]["database"] = "test" return parameters From 2f749a9c21ad74669099c2d432fd0414c95015cf Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Thu, 31 Dec 2020 13:23:53 +0100 Subject: [PATCH 05/19] Fix raising of exceptions --- awswrangler/_config.py | 4 +++- awswrangler/mysql.py | 2 +- awswrangler/postgresql.py | 4 +++- awswrangler/redshift.py | 4 +++- awswrangler/sqlserver.py | 4 +++- cloudformation/databases.yaml | 2 +- tests/conftest.py | 2 +- 7 files changed, 15 insertions(+), 7 deletions(-) diff --git a/awswrangler/_config.py b/awswrangler/_config.py index 23703cdde..859d60e0a 100644 --- a/awswrangler/_config.py +++ b/awswrangler/_config.py @@ -154,7 +154,9 @@ def _apply_type(name: str, value: Any, dtype: Type[Union[str, bool, int]], nulla if _Config._is_null(value=value): if nullable is True: return None - exceptions.InvalidArgumentValue(f"{name} configuration does not accept a null value. Please pass {dtype}.") + raise exceptions.InvalidArgumentValue( + f"{name} configuration does not accept a null value. Please pass {dtype}." + ) try: return dtype(value) if isinstance(value, dtype) is False else value except ValueError as ex: diff --git a/awswrangler/mysql.py b/awswrangler/mysql.py index 154aa762b..a3ff4c91b 100644 --- a/awswrangler/mysql.py +++ b/awswrangler/mysql.py @@ -127,7 +127,7 @@ def connect( connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session ) if attrs.kind != "mysql": - exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a MySQL connection.)") + raise exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a MySQL connection.)") return pymysql.connect( user=attrs.user, database=attrs.database, diff --git a/awswrangler/postgresql.py b/awswrangler/postgresql.py index 1dd294f47..ffeb893f5 100644 --- a/awswrangler/postgresql.py +++ b/awswrangler/postgresql.py @@ -131,7 +131,9 @@ def connect( connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session ) if attrs.kind != "postgresql": - exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a postgresql connection.)") + raise exceptions.InvalidDatabaseType( + f"Invalid connection type ({attrs.kind}. 
It must be a postgresql connection.)" + ) return pg8000.connect( user=attrs.user, database=attrs.database, diff --git a/awswrangler/redshift.py b/awswrangler/redshift.py index 323fb6216..71d6f3336 100644 --- a/awswrangler/redshift.py +++ b/awswrangler/redshift.py @@ -386,7 +386,9 @@ def connect( connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session ) if attrs.kind != "redshift": - exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a redshift connection.)") + raise exceptions.InvalidDatabaseType( + f"Invalid connection type ({attrs.kind}. It must be a redshift connection.)" + ) return redshift_connector.connect( user=attrs.user, database=attrs.database, diff --git a/awswrangler/sqlserver.py b/awswrangler/sqlserver.py index 77c5e81e4..43aa53ad7 100644 --- a/awswrangler/sqlserver.py +++ b/awswrangler/sqlserver.py @@ -132,7 +132,9 @@ def connect( connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session ) if attrs.kind != "sqlserver": - exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a sqlserver connection.)") + raise exceptions.InvalidDatabaseType( + f"Invalid connection type ({attrs.kind}. It must be a sqlserver connection.)" + ) return pymssql.connect( user=attrs.user, database=attrs.database, diff --git a/cloudformation/databases.yaml b/cloudformation/databases.yaml index 63e2bc423..95f35a398 100644 --- a/cloudformation/databases.yaml +++ b/cloudformation/databases.yaml @@ -401,7 +401,7 @@ Resources: Fn::ImportValue: aws-data-wrangler-base-PrivateSubnet ConnectionProperties: JDBC_CONNECTION_URL: - Fn::Sub: jdbc:mysql://${SqlServerInstance.Endpoint.Address}:${SqlServerInstance.Endpoint.Port};databaseName=test + Fn::Sub: jdbc:sqlserver://${SqlServerInstance.Endpoint.Address}:${SqlServerInstance.Endpoint.Port};databaseName=test USERNAME: test PASSWORD: Ref: DatabasesPassword diff --git a/tests/conftest.py b/tests/conftest.py index ae01f97e9..26b1ad9ad 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -128,7 +128,7 @@ def workgroup3(bucket, kms_key): @pytest.fixture(scope="session") -def databases_parameters(): +def databases_parameters(cloudformation_outputs): parameters = dict(postgresql={}, mysql={}, redshift={}, sqlserver={}) parameters["postgresql"]["host"] = cloudformation_outputs["PostgresqlAddress"] parameters["postgresql"]["port"] = 3306 From e58b00699a3eab70d1b3fd101d5cf8613ba44c6b Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Thu, 31 Dec 2020 14:15:02 +0100 Subject: [PATCH 06/19] Adapt README and documentation --- README.md | 3 +- docs/source/api.rst | 14 ++++++++ ...hift, MySQL, PostgreSQL, SQL Server.ipynb} | 32 +++++++++---------- 3 files changed, 31 insertions(+), 18 deletions(-) rename tutorials/{007 - Redshift, MySQL, PostgreSQL.ipynb => 007 - Redshift, MySQL, PostgreSQL, SQL Server.ipynb} (85%) diff --git a/README.md b/README.md index 083c03dd7..fd09fda20 100644 --- a/README.md +++ b/README.md @@ -105,7 +105,7 @@ FROM "sampleDB"."sampleTable" ORDER BY time DESC LIMIT 3 - [004 - Parquet Datasets](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/004%20-%20Parquet%20Datasets.ipynb) - [005 - Glue Catalog](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/005%20-%20Glue%20Catalog.ipynb) - [006 - Amazon Athena](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/006%20-%20Amazon%20Athena.ipynb) - - [007 - Databases (Redshift, MySQL and 
PostgreSQL)](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/007%20-%20Redshift%2C%20MySQL%2C%20PostgreSQL.ipynb) + - [007 - Databases (Redshift, MySQL, PostgreSQL and SQL Server)](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/007%20-%20Redshift%2C%20MySQL%2C%20PostgreSQL%2C%20SQL%20Server.ipynb) - [008 - Redshift - Copy & Unload.ipynb](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/008%20-%20Redshift%20-%20Copy%20%26%20Unload.ipynb) - [009 - Redshift - Append, Overwrite and Upsert](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/009%20-%20Redshift%20-%20Append%2C%20Overwrite%2C%20Upsert.ipynb) - [010 - Parquet Crawler](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/010%20-%20Parquet%20Crawler.ipynb) @@ -134,6 +134,7 @@ FROM "sampleDB"."sampleTable" ORDER BY time DESC LIMIT 3 - [Amazon Redshift](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#amazon-redshift) - [PostgreSQL](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#postgresql) - [MySQL](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#mysql) + - [SQL Server](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#sqlserver) - [DynamoDB](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#dynamodb) - [Amazon Timestream](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#amazon-timestream) - [Amazon EMR](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#amazon-emr) diff --git a/docs/source/api.rst b/docs/source/api.rst index d3bdb6df6..83f42d789 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -7,6 +7,7 @@ API Reference * `Amazon Redshift`_ * `PostgreSQL`_ * `MySQL`_ +* `Microsoft SQL Server`_ * `DynamoDB`_ * `Amazon Timestream`_ * `Amazon EMR`_ @@ -145,6 +146,19 @@ MySQL .. currentmodule:: awswrangler.mysql +.. autosummary:: + :toctree: stubs + + connect + read_sql_query + read_sql_table + to_sql + +Microsoft SQL Server +____________________ + +.. currentmodule:: awswrangler.sqlserver + .. 
autosummary:: :toctree: stubs diff --git a/tutorials/007 - Redshift, MySQL, PostgreSQL.ipynb b/tutorials/007 - Redshift, MySQL, PostgreSQL, SQL Server.ipynb similarity index 85% rename from tutorials/007 - Redshift, MySQL, PostgreSQL.ipynb rename to tutorials/007 - Redshift, MySQL, PostgreSQL, SQL Server.ipynb index 6da12b68f..8a5e36255 100644 --- a/tutorials/007 - Redshift, MySQL, PostgreSQL.ipynb +++ b/tutorials/007 - Redshift, MySQL, PostgreSQL, SQL Server.ipynb @@ -6,7 +6,7 @@ "source": [ "[![AWS Data Wrangler](_static/logo.png \"AWS Data Wrangler\")](https://github.com/awslabs/aws-data-wrangler)\n", "\n", - "# 7 - Redshift, MySQL and PostgreSQL\n", + "# 7 - Redshift, MySQL, PostgreSQL and SQL Server\n", "\n", "[Wrangler](https://github.com/awslabs/aws-data-wrangler)'s Redshift, MySQL and PostgreSQL have two basic function in common that tries to follow the Pandas conventions, but add more data type consistency.\n", "\n", @@ -15,7 +15,9 @@ "- [wr.mysql.to_sql()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.mysql.to_sql.html)\n", "- [wr.mysql.read_sql_query()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.mysql.read_sql_query.html)\n", "- [wr.postgresql.to_sql()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.postgresql.to_sql.html)\n", - "- [wr.postgresql.read_sql_query()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.postgresql.read_sql_query.html)" + "- [wr.postgresql.read_sql_query()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.postgresql.read_sql_query.html)\n", + "- [wr.sqlserver.to_sql()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.sqlserver.to_sql.html)\n", + "- [wr.sqlserver.read_sql_query()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.sqlserver.read_sql_query.html)" ] }, { @@ -41,7 +43,8 @@ "\n", "- [wr.redshift.connect()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.redshift.connect.html)\n", "- [wr.mysql.connect()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.mysql.connect.html)\n", - "- [wr.postgresql.connect()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.postgresql.connect.html)" + "- [wr.postgresql.connect()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.postgresql.connect.html)\n", + "- [wr.sqlserver.connect()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.sqlserver.connect.html)" ] }, { @@ -52,7 +55,8 @@ "source": [ "con_redshift = wr.redshift.connect(\"aws-data-wrangler-redshift\")\n", "con_mysql = wr.mysql.connect(\"aws-data-wrangler-mysql\")\n", - "con_postgresql = wr.postgresql.connect(\"aws-data-wrangler-postgresql\")" + "con_postgresql = wr.postgresql.connect(\"aws-data-wrangler-postgresql\")\n", + "con_sqlserver = wr.sqlserver.connect(\"aws-data-wrangler-sqlserver\")" ] }, { @@ -96,7 +100,8 @@ "source": [ "wr.redshift.to_sql(df, con_redshift, schema=\"public\", table=\"tutorial\", mode=\"overwrite\")\n", "wr.mysql.to_sql(df, con_mysql, schema=\"test\", table=\"tutorial\", mode=\"overwrite\")\n", - "wr.postgresql.to_sql(df, con_postgresql, schema=\"public\", table=\"tutorial\", mode=\"overwrite\")" + "wr.postgresql.to_sql(df, con_postgresql, schema=\"public\", table=\"tutorial\", mode=\"overwrite\")\n", + "wr.sqlserver.to_sql(df, con_sqlserver, schema=\"dbo\", table=\"tutorial\", mode=\"overwrite\")" ] }, { @@ -165,7 +170,8 @@ "source": [ "wr.redshift.read_sql_query(\"SELECT * FROM 
public.tutorial\", con=con_redshift)\n", "wr.mysql.read_sql_query(\"SELECT * FROM test.tutorial\", con=con_mysql)\n", - "wr.postgresql.read_sql_query(\"SELECT * FROM public.tutorial\", con=con_postgresql)" + "wr.postgresql.read_sql_query(\"SELECT * FROM public.tutorial\", con=con_postgresql)\n", + "wr.sqlserver.read_sql_query(\"SELECT * FROM dbo.tutorial\", con=con_sqlserver)" ] }, { @@ -176,7 +182,8 @@ "source": [ "con_redshift.close()\n", "con_mysql.close()\n", - "con_postgresql.close()" + "con_postgresql.close()\n", + "con_sqlserver.close()" ] } ], @@ -197,17 +204,8 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" - }, - "pycharm": { - "stem_cell": { - "cell_type": "raw", - "metadata": { - "collapsed": false - }, - "source": [] - } } }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file From ec4b7b1b006a9caf325ddf9e6ad5992ab26902bc Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Thu, 31 Dec 2020 14:15:16 +0100 Subject: [PATCH 07/19] Decode password to string --- awswrangler/catalog/_get.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/awswrangler/catalog/_get.py b/awswrangler/catalog/_get.py index fa5471811..f6aab5132 100644 --- a/awswrangler/catalog/_get.py +++ b/awswrangler/catalog/_get.py @@ -524,7 +524,7 @@ def get_connection( client_kms = _utils.client(service_name="kms", session=boto3_session) pwd = client_kms.decrypt(CiphertextBlob=base64.b64decode(res["ConnectionProperties"]["ENCRYPTED_PASSWORD"]))[ "Plaintext" - ] + ].decode("utf-8") res["ConnectionProperties"]["PASSWORD"] = pwd return cast(Dict[str, Any], res) From 80afc70cbec48ae4ece53e545030afa7399bc096 Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Thu, 31 Dec 2020 15:07:09 +0100 Subject: [PATCH 08/19] WIP: Fix SQLServer tests --- cloudformation/databases.yaml | 10 ++++++---- tests/conftest.py | 21 +++++++++++++++++++++ tests/test_sqlserver.py | 2 +- 3 files changed, 28 insertions(+), 5 deletions(-) diff --git a/cloudformation/databases.yaml b/cloudformation/databases.yaml index 95f35a398..e8d32ef57 100644 --- a/cloudformation/databases.yaml +++ b/cloudformation/databases.yaml @@ -1,6 +1,6 @@ AWSTemplateFormatVersion: 2010-09-09 Description: | - AWS Data Wrangler Development Databases Infrastructure Redshift, Aurora PostgreSQL, Aurora MySQL + AWS Data Wrangler Development Databases Infrastructure Redshift, Aurora PostgreSQL, Aurora MySQL, Microsoft SQL Server Parameters: DatabasesPassword: Type: String @@ -285,6 +285,8 @@ Resources: DBInstanceClass: db.t3.small DBSubnetGroupName: Ref: RdsSubnetGroup + VPCSecurityGroups: + - Ref: DatabaseSecurityGroup PubliclyAccessible: true SqlServerInstance: Type: AWS::RDS::DBInstance @@ -436,7 +438,7 @@ Resources: Tags: - Key: Env Value: aws-data-wrangler - postgresqlSecret: + PostgresqlSecret: Type: AWS::SecretsManager::Secret Properties: Name: aws-data-wrangler/postgresql @@ -446,7 +448,7 @@ Resources: { "username": "test", "password": "${DatabasesPassword}", - "engine": "postgres", + "engine": "postgresql", "host": "${AuroraInstancePostgresql.Endpoint.Address}", "port": ${AuroraInstancePostgresql.Endpoint.Port}, "dbClusterIdentifier": "${AuroraInstancePostgresql}", @@ -465,7 +467,7 @@ Resources: { "username": "test", "password": "${DatabasesPassword}", - "engine": "postgres", + "engine": "mysql", "host": "${AuroraInstanceMysql.Endpoint.Address}", "port": ${AuroraInstanceMysql.Endpoint.Port}, "dbClusterIdentifier": "${AuroraInstanceMysql}", diff --git a/tests/conftest.py b/tests/conftest.py 
index 26b1ad9ad..d2f544780 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,7 @@ from datetime import datetime import boto3 # type: ignore +import pymssql import pytest # type: ignore import awswrangler as wr @@ -153,6 +154,26 @@ def databases_parameters(cloudformation_outputs): return parameters +@pytest.fixture(scope="session", autouse=True) +def create_sql_server_database(databases_parameters): + con = pymssql.connect( + host=databases_parameters["sqlserver"]["host"], + port=int(databases_parameters["sqlserver"]["port"]), + user=databases_parameters["user"], + password=databases_parameters["password"], + ) + sql = ( + f"IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = '{databases_parameters['sql_server']['database']}') " + "BEGIN " + "CREATE DATABASE {databases_parameters['sql_server']['database']} " + "END" + ) + with con.cursor() as cursor: + cursor.execute(sql) + con.commit() + con.close() + + @pytest.fixture(scope="session") def redshift_external_schema(cloudformation_outputs, databases_parameters, glue_database): region = cloudformation_outputs.get("Region") diff --git a/tests/test_sqlserver.py b/tests/test_sqlserver.py index 226054c40..154824078 100644 --- a/tests/test_sqlserver.py +++ b/tests/test_sqlserver.py @@ -176,7 +176,7 @@ def test_table_name(): con.close() -@pytest.mark.parametrize("dbname", [None, "sqlserver"]) +@pytest.mark.parametrize("dbname", [None, "test"]) def test_connect_secret_manager(dbname): con = wr.sqlserver.connect(secret_id="aws-data-wrangler/sqlserver", dbname=dbname) df = wr.sqlserver.read_sql_query("SELECT 1", con=con) From 1ce7c3163f9499471eee32ba1d6ba0d0599aa820 Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Thu, 31 Dec 2020 15:17:10 +0100 Subject: [PATCH 09/19] WIP: Fix cfn template --- cloudformation/databases.yaml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cloudformation/databases.yaml b/cloudformation/databases.yaml index e8d32ef57..c6351fe3a 100644 --- a/cloudformation/databases.yaml +++ b/cloudformation/databases.yaml @@ -285,8 +285,6 @@ Resources: DBInstanceClass: db.t3.small DBSubnetGroupName: Ref: RdsSubnetGroup - VPCSecurityGroups: - - Ref: DatabaseSecurityGroup PubliclyAccessible: true SqlServerInstance: Type: AWS::RDS::DBInstance @@ -305,6 +303,8 @@ Resources: Ref: DatabasesPassword DBSubnetGroupName: Ref: RdsSubnetGroup + VPCSecurityGroups: + - Ref: DatabaseSecurityGroup PubliclyAccessible: true AssociatedRoles: - RoleArn: From 256e8c09b11c50289bf3b221198fafb59d0a90f8 Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Fri, 1 Jan 2021 15:57:06 +0100 Subject: [PATCH 10/19] Fix tests for Linux --- awswrangler/sqlserver.py | 4 ++ tests/conftest.py | 33 +++--------- tests/test_sqlserver.py | 109 +++++++++++++++++++++++---------------- 3 files changed, 74 insertions(+), 72 deletions(-) diff --git a/awswrangler/sqlserver.py b/awswrangler/sqlserver.py index 43aa53ad7..5b5db139a 100644 --- a/awswrangler/sqlserver.py +++ b/awswrangler/sqlserver.py @@ -135,6 +135,9 @@ def connect( raise exceptions.InvalidDatabaseType( f"Invalid connection type ({attrs.kind}. 
It must be a sqlserver connection.)" ) + # Fix TDS version to 7.3 for enabling correct casting of DATE and TIME columns + # See: https://pymssql.readthedocs.io/en/latest/faq.html + # #pymssql-does-not-unserialize-date-and-time-columns-to-datetime-date-and-datetime-time-instances return pymssql.connect( user=attrs.user, database=attrs.database, @@ -143,6 +146,7 @@ def connect( host=attrs.host, timeout=timeout, login_timeout=login_timeout, + tds_version="7.3", ) diff --git a/tests/conftest.py b/tests/conftest.py index d2f544780..011fccfca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,7 +1,6 @@ from datetime import datetime import boto3 # type: ignore -import pymssql import pytest # type: ignore import awswrangler as wr @@ -139,12 +138,12 @@ def databases_parameters(cloudformation_outputs): parameters["mysql"]["port"] = 3306 parameters["mysql"]["schema"] = "test" parameters["mysql"]["database"] = "test" - # parameters["redshift"]["host"] = cloudformation_outputs["RedshiftAddress"] - # parameters["redshift"]["port"] = cloudformation_outputs["RedshiftPort"] - # parameters["redshift"]["identifier"] = cloudformation_outputs["RedshiftIdentifier"] - # parameters["redshift"]["schema"] = "public" - # parameters["redshift"]["database"] = "test" - # parameters["redshift"]["role"] = cloudformation_outputs["RedshiftRole"] + parameters["redshift"]["host"] = cloudformation_outputs["RedshiftAddress"] + parameters["redshift"]["port"] = cloudformation_outputs["RedshiftPort"] + parameters["redshift"]["identifier"] = cloudformation_outputs["RedshiftIdentifier"] + parameters["redshift"]["schema"] = "public" + parameters["redshift"]["database"] = "test" + parameters["redshift"]["role"] = cloudformation_outputs["RedshiftRole"] parameters["password"] = cloudformation_outputs["DatabasesPassword"] parameters["user"] = "test" parameters["sqlserver"]["host"] = cloudformation_outputs["SqlServerAddress"] @@ -154,26 +153,6 @@ def databases_parameters(cloudformation_outputs): return parameters -@pytest.fixture(scope="session", autouse=True) -def create_sql_server_database(databases_parameters): - con = pymssql.connect( - host=databases_parameters["sqlserver"]["host"], - port=int(databases_parameters["sqlserver"]["port"]), - user=databases_parameters["user"], - password=databases_parameters["password"], - ) - sql = ( - f"IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = '{databases_parameters['sql_server']['database']}') " - "BEGIN " - "CREATE DATABASE {databases_parameters['sql_server']['database']} " - "END" - ) - with con.cursor() as cursor: - cursor.execute(sql) - con.commit() - con.close() - - @pytest.fixture(scope="session") def redshift_external_schema(cloudformation_outputs, databases_parameters, glue_database): region = cloudformation_outputs.get("Region") diff --git a/tests/test_sqlserver.py b/tests/test_sqlserver.py index 154824078..536f4e203 100644 --- a/tests/test_sqlserver.py +++ b/tests/test_sqlserver.py @@ -11,52 +11,81 @@ from ._utils import ensure_data_types, get_df logging.getLogger("awswrangler").setLevel(logging.DEBUG) -_GLUE_CONNECTION_NAME = "aws-data-wrangler-sqlserver" -def test_connection(): - wr.sqlserver.connect(_GLUE_CONNECTION_NAME, timeout=10).close() - - -def test_read_sql_query_simple(databases_parameters): +@pytest.fixture(scope="module", autouse=True) +def create_sql_server_database(databases_parameters): con = pymssql.connect( host=databases_parameters["sqlserver"]["host"], port=int(databases_parameters["sqlserver"]["port"]), - 
database=databases_parameters["sqlserver"]["database"], user=databases_parameters["user"], password=databases_parameters["password"], + autocommit=True, ) - df = wr.sqlserver.read_sql_query("SELECT 1", con=con) + sql_create_db = ( + f"IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = '{databases_parameters['sqlserver']['database']}') " + "BEGIN " + f"CREATE DATABASE {databases_parameters['sqlserver']['database']} " + "END" + ) + with con.cursor() as cursor: + cursor.execute(sql_create_db) + con.commit() + + yield + + sql_drop_db = ( + f"IF EXISTS (SELECT * FROM sys.databases WHERE name = '{databases_parameters['sqlserver']['database']}') " + "BEGIN " + f"DROP DATABASE {databases_parameters['sqlserver']['database']} " + "END" + ) + with con.cursor() as cursor: + cursor.execute(sql_drop_db) + con.commit() + + con.close() + + +@pytest.fixture(scope="function") +def sqlserver_con(): + con = wr.sqlserver.connect("aws-data-wrangler-sqlserver") + yield con con.close() + + +def test_connection(): + wr.sqlserver.connect("aws-data-wrangler-sqlserver", timeout=10).close() + + +def test_read_sql_query_simple(databases_parameters, sqlserver_con): + df = wr.sqlserver.read_sql_query("SELECT 1", con=sqlserver_con) assert df.shape == (1, 1) -def test_to_sql_simple(sqlserver_table): - con = wr.sqlserver.connect(_GLUE_CONNECTION_NAME) +def test_to_sql_simple(sqlserver_table, sqlserver_con): df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]}) - wr.sqlserver.to_sql(df, con, sqlserver_table, "dbo", "overwrite", True) - con.close() + wr.sqlserver.to_sql(df, sqlserver_con, sqlserver_table, "dbo", "overwrite", True) -def test_sql_types(sqlserver_table): +def test_sql_types(sqlserver_table, sqlserver_con): table = sqlserver_table df = get_df() df.drop(["binary"], axis=1, inplace=True) - con = wr.sqlserver.connect(_GLUE_CONNECTION_NAME) wr.sqlserver.to_sql( df=df, - con=con, + con=sqlserver_con, table=table, schema="dbo", mode="overwrite", index=True, dtype={"iint32": "INTEGER"}, ) - df = wr.sqlserver.read_sql_query(f"SELECT * FROM dbo.{table}", con) + df = wr.sqlserver.read_sql_query(f"SELECT * FROM dbo.{table}", sqlserver_con) ensure_data_types(df, has_list=False) dfs = wr.sqlserver.read_sql_query( sql=f"SELECT * FROM dbo.{table}", - con=con, + con=sqlserver_con, chunksize=1, dtype={ "iint8": pa.int8(), @@ -78,7 +107,7 @@ def test_sql_types(sqlserver_table): ensure_data_types(df, has_list=False) -def test_to_sql_cast(sqlserver_table): +def test_to_sql_cast(sqlserver_table, sqlserver_con): table = sqlserver_table df = pd.DataFrame( { @@ -90,28 +119,25 @@ def test_to_sql_cast(sqlserver_table): }, dtype="string", ) - con = wr.sqlserver.connect(connection=_GLUE_CONNECTION_NAME) wr.sqlserver.to_sql( df=df, - con=con, + con=sqlserver_con, table=table, schema="dbo", mode="overwrite", index=False, dtype={"col": "VARCHAR(1024)"}, ) - df2 = wr.sqlserver.read_sql_query(sql=f"SELECT * FROM dbo.{table}", con=con) + df2 = wr.sqlserver.read_sql_query(sql=f"SELECT * FROM dbo.{table}", con=sqlserver_con) assert df.equals(df2) - con.close() -def test_null(sqlserver_table): +def test_null(sqlserver_table, sqlserver_con): table = sqlserver_table - con = wr.sqlserver.connect(connection=_GLUE_CONNECTION_NAME) df = pd.DataFrame({"id": [1, 2, 3], "nothing": [None, None, None]}) wr.sqlserver.to_sql( df=df, - con=con, + con=sqlserver_con, table=table, schema="dbo", mode="overwrite", @@ -120,19 +146,18 @@ def test_null(sqlserver_table): ) wr.sqlserver.to_sql( df=df, - con=con, + con=sqlserver_con, table=table, 
schema="dbo", mode="append", index=False, ) - df2 = wr.sqlserver.read_sql_table(table=table, schema="dbo", con=con) + df2 = wr.sqlserver.read_sql_table(table=table, schema="dbo", con=sqlserver_con) df["id"] = df["id"].astype("Int64") assert pd.concat(objs=[df, df], ignore_index=True).equals(df2) - con.close() -def test_decimal_cast(sqlserver_table): +def test_decimal_cast(sqlserver_table, sqlserver_con): table = sqlserver_table df = pd.DataFrame( { @@ -141,39 +166,33 @@ def test_decimal_cast(sqlserver_table): "col2": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))], } ) - con = wr.sqlserver.connect(connection=_GLUE_CONNECTION_NAME) - wr.sqlserver.to_sql(df, con, table, "dbo") + wr.sqlserver.to_sql(df, sqlserver_con, table, "dbo") df2 = wr.sqlserver.read_sql_table( - schema="dbo", table=table, con=con, dtype={"col0": "float32", "col1": "float64", "col2": "Int64"} + schema="dbo", table=table, con=sqlserver_con, dtype={"col0": "float32", "col1": "float64", "col2": "Int64"} ) assert df2.dtypes.to_list() == ["float32", "float64", "Int64"] assert 3.88 <= df2.col0.sum() <= 3.89 assert 3.88 <= df2.col1.sum() <= 3.89 assert df2.col2.sum() == 2 - con.close() -def test_read_retry(): - con = wr.sqlserver.connect(connection=_GLUE_CONNECTION_NAME) +def test_read_retry(sqlserver_con): try: - wr.sqlserver.read_sql_query("ERROR", con) + wr.sqlserver.read_sql_query("ERROR", sqlserver_con) except: # noqa pass - df = wr.sqlserver.read_sql_query("SELECT 1", con) + df = wr.sqlserver.read_sql_query("SELECT 1", sqlserver_con) assert df.shape == (1, 1) - con.close() -def test_table_name(): +def test_table_name(sqlserver_con): df = pd.DataFrame({"col0": [1]}) - con = wr.sqlserver.connect(connection="aws-data-wrangler-sqlserver") - wr.sqlserver.to_sql(df, con, "Test Name", "dbo", mode="overwrite") - df = wr.sqlserver.read_sql_table(schema="dbo", con=con, table="Test Name") + wr.sqlserver.to_sql(df, sqlserver_con, "Test Name", "dbo", mode="overwrite") + df = wr.sqlserver.read_sql_table(schema="dbo", con=sqlserver_con, table="Test Name") assert df.shape == (1, 1) - with con.cursor() as cursor: + with sqlserver_con.cursor() as cursor: cursor.execute('DROP TABLE "Test Name"') - con.commit() - con.close() + sqlserver_con.commit() @pytest.mark.parametrize("dbname", [None, "test"]) From eb18b55ebff1ed54e86291c053df0e716f6b52bf Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Fri, 1 Jan 2021 16:01:00 +0100 Subject: [PATCH 11/19] Add missing ; --- awswrangler/_databases.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/awswrangler/_databases.py b/awswrangler/_databases.py index 7f327388e..119474835 100644 --- a/awswrangler/_databases.py +++ b/awswrangler/_databases.py @@ -36,7 +36,7 @@ def _get_connection_attributes_from_catalog( details: Dict[str, Any] = get_connection(name=connection, catalog_id=catalog_id, boto3_session=boto3_session)[ "ConnectionProperties" ] - if "databaseName=" in details["JDBC_CONNECTION_URL"]: + if ";databaseName=" in details["JDBC_CONNECTION_URL"]: database_sep = ";databaseName=" else: database_sep = "/" From 721baecd50f77d696c574d62b45fc501f7585b9e Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Sat, 2 Jan 2021 17:49:14 +0100 Subject: [PATCH 12/19] Swap from pymssql to pyodbc --- .pylintrc | 2 +- awswrangler/sqlserver.py | 88 ++++++++++++++++++---------------------- requirements-dev.txt | 1 + requirements.txt | 1 - setup.py | 1 + tests/test_sqlserver.py | 18 ++++---- 6 files changed, 53 insertions(+), 58 deletions(-) diff --git a/.pylintrc 
b/.pylintrc index 9585354aa..eaa732f57 100644 --- a/.pylintrc +++ b/.pylintrc @@ -3,7 +3,7 @@ # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code. -extension-pkg-whitelist=pyarrow.lib,pymssql +extension-pkg-whitelist=pyarrow.lib,pyodbc # Specify a score threshold to be exceeded before program exits with error. fail-under=10 diff --git a/awswrangler/sqlserver.py b/awswrangler/sqlserver.py index 5b5db139a..e02ceb962 100644 --- a/awswrangler/sqlserver.py +++ b/awswrangler/sqlserver.py @@ -7,7 +7,7 @@ import boto3 import pandas as pd import pyarrow as pa -import pymssql +import pyodbc from awswrangler import _data_types from awswrangler import _databases as _db_utils @@ -16,11 +16,11 @@ _logger: logging.Logger = logging.getLogger(__name__) -def _validate_connection(con: pymssql.Connection) -> None: - if not isinstance(con, pymssql.Connection): +def _validate_connection(con: pyodbc.Connection) -> None: + if not isinstance(con, pyodbc.Connection): raise exceptions.InvalidConnection( "Invalid 'conn' argument, please pass a " - "pymssql.Connection object. Use pymssql.connect() to use " + "pyodbc.Connection object. Use pyodbc.connect() to use " "credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog." ) @@ -31,14 +31,14 @@ def _get_table_identifier(schema: Optional[str], table: str) -> str: return table_identifier -def _drop_table(cursor: pymssql.Cursor, schema: Optional[str], table: str) -> None: +def _drop_table(cursor: pyodbc.Cursor, schema: Optional[str], table: str) -> None: table_identifier = _get_table_identifier(schema, table) sql = f"IF OBJECT_ID(N'{table_identifier}', N'U') IS NOT NULL DROP TABLE {table_identifier}" _logger.debug("Drop table query:\n%s", sql) cursor.execute(sql) -def _does_table_exist(cursor: pymssql.Cursor, schema: Optional[str], table: str) -> bool: +def _does_table_exist(cursor: pyodbc.Cursor, schema: Optional[str], table: str) -> bool: schema_str = f"TABLE_SCHEMA = '{schema}' AND" if schema else "" cursor.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE " f"{schema_str} TABLE_NAME = '{table}'") return len(cursor.fetchall()) > 0 @@ -46,7 +46,7 @@ def _does_table_exist(cursor: pymssql.Cursor, schema: Optional[str], table: str) def _create_table( df: pd.DataFrame, - cursor: pymssql.Cursor, + cursor: pyodbc.Cursor, table: str, schema: str, mode: str, @@ -80,13 +80,13 @@ def connect( secret_id: Optional[str] = None, catalog_id: Optional[str] = None, dbname: Optional[str] = None, + odbc_driver_version: int = 17, boto3_session: Optional[boto3.Session] = None, timeout: Optional[int] = 0, - login_timeout: Optional[int] = 60, -) -> pymssql.Connection: - """Return a pymssql connection from a Glue Catalog Connection. +) -> pyodbc.Connection: + """Return a pyodbc connection from a Glue Catalog Connection. - https://github.com/pymssql/pymssql + https://github.com/mkleehammer/pyodbc Parameters ---------- @@ -100,28 +100,25 @@ def connect( If none is provided, the AWS account ID is used by default. dbname: Optional[str] Optional database name to overwrite the stored one. + odbc_driver_version : int + Major version of the ODBC Driver for SQL Server that is installed and should be used. boto3_session : boto3.Session(), optional Boto3 Session. The default boto3 session will be used if boto3_session receives None. timeout: Optional[int] This is the time in seconds before the connection to the server will time out. The default is 0, which means no timeout. 
The default is None which means no timeout. - This parameter is forwarded to pymssql. - https://pymssql.readthedocs.io/en/latest/ref/pymssql.html - login_timeout: Optional[int] - This is the time in seconds that the connection and login may take before it times out. - The default is 60 seconds. - This parameter is forwarded to pymssql. - https://pymssql.readthedocs.io/en/latest/ref/pymssql.html + This parameter is forwarded to pyodbc. + https://github.com/mkleehammer/pyodbc/wiki/The-pyodbc-Module#connect Returns ------- - pymssql.Connection - pymssql connection. + pyodbc.Connection + pyodbc connection. Examples -------- >>> import awswrangler as wr - >>> con = wr.sqlserver.connect("MY_GLUE_CONNECTION") + >>> con = wr.sqlserver.connect(connection="MY_GLUE_CONNECTION", odbc_driver_version=17) >>> with con.cursor() as cursor: >>> cursor.execute("SELECT 1") >>> print(cursor.fetchall()) @@ -135,24 +132,20 @@ def connect( raise exceptions.InvalidDatabaseType( f"Invalid connection type ({attrs.kind}. It must be a sqlserver connection.)" ) - # Fix TDS version to 7.3 for enabling correct casting of DATE and TIME columns - # See: https://pymssql.readthedocs.io/en/latest/faq.html - # #pymssql-does-not-unserialize-date-and-time-columns-to-datetime-date-and-datetime-time-instances - return pymssql.connect( - user=attrs.user, - database=attrs.database, - password=attrs.password, - port=attrs.port, - host=attrs.host, - timeout=timeout, - login_timeout=login_timeout, - tds_version="7.3", + connection_str = ( + f"DRIVER={{ODBC Driver {odbc_driver_version} for SQL Server}};" + f"SERVER={attrs.host},{attrs.port};" + f"DATABASE={attrs.database};" + f"UID={attrs.user};" + f"PWD={attrs.password}" ) + return pyodbc.connect(connection_str, timeout=timeout) + def read_sql_query( sql: str, - con: pymssql.Connection, + con: pyodbc.Connection, index_col: Optional[Union[str, List[str]]] = None, params: Optional[Union[List[Any], Tuple[Any, ...], Dict[Any, Any]]] = None, chunksize: Optional[int] = None, @@ -165,8 +158,8 @@ def read_sql_query( ---------- sql : str SQL query. - con : pymssql.Connection - Use pymssql.connect() to use " + con : pyodbc.Connection + Use pyodbc.connect() to use " "credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog. index_col : Union[str, List[str]], optional Column(s) to set as index(MultiIndex). @@ -192,7 +185,7 @@ def read_sql_query( -------- Reading from Microsoft SQL Server using a Glue Catalog Connections >>> import awswrangler as wr - >>> con = wr.sqlserver.connect("MY_GLUE_CONNECTION") + >>> con = wr.sqlserver.connect(connection="MY_GLUE_CONNECTION", odbc_driver_version=17) >>> df = wr.sqlserver.read_sql_query( ... sql="SELECT * FROM dbo.my_table", ... con=con @@ -207,7 +200,7 @@ def read_sql_query( def read_sql_table( table: str, - con: pymssql.Connection, + con: pyodbc.Connection, schema: Optional[str] = None, index_col: Optional[Union[str, List[str]]] = None, params: Optional[Union[List[Any], Tuple[Any, ...], Dict[Any, Any]]] = None, @@ -221,8 +214,8 @@ def read_sql_table( ---------- table : str Table name. - con : pymssql.Connection - Use pymssql.connect() to use " + con : pyodbc.Connection + Use pyodbc.connect() to use " "credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog. schema : str, optional Name of SQL schema in database to query (if database flavor supports this). 
@@ -252,7 +245,7 @@ def read_sql_table( Reading from Microsoft SQL Server using a Glue Catalog Connections >>> import awswrangler as wr - >>> con = wr.sqlserver.connect("MY_GLUE_CONNECTION") + >>> con = wr.sqlserver.connect(connection="MY_GLUE_CONNECTION", odbc_driver_version=17) >>> df = wr.sqlserver.read_sql_table( ... table="my_table", ... schema="dbo", @@ -269,7 +262,7 @@ def read_sql_table( def to_sql( df: pd.DataFrame, - con: pymssql.Connection, + con: pyodbc.Connection, table: str, schema: str, mode: str = "append", @@ -283,8 +276,8 @@ def to_sql( ---------- df : pandas.DataFrame Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html - con : pymssql.Connection - Use pymssql.connect() to use " + con : pyodbc.Connection + Use pyodbc.connect() to use " "credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog. table : str Table name @@ -312,7 +305,7 @@ def to_sql( Writing to Microsoft SQL Server using a Glue Catalog Connections >>> import awswrangler as wr - >>> con = wr.sqlserver.connect("MY_GLUE_CONNECTION") + >>> con = wr.sqlserver.connect(connection="MY_GLUE_CONNECTION", odbc_driver_version=17) >>> wr.sqlserver.to_sql( ... df=df, ... table="table", @@ -339,13 +332,12 @@ def to_sql( ) if index: df.reset_index(level=df.index.names, inplace=True) - placeholders: str = ", ".join(["%s"] * len(df.columns)) + placeholders: str = ", ".join(["?"] * len(df.columns)) table_identifier = _get_table_identifier(schema, table) sql: str = f"INSERT INTO {table_identifier} VALUES ({placeholders})" _logger.debug("sql: %s", sql) parameters: List[List[Any]] = _db_utils.extract_parameters(df=df) - parameter_tuples: List[Tuple[Any, ...]] = [tuple(parameter_set) for parameter_set in parameters] - cursor.executemany(sql, parameter_tuples) + cursor.executemany(sql, parameters) con.commit() except Exception as ex: con.rollback() diff --git a/requirements-dev.txt b/requirements-dev.txt index 871204015..3af37aee9 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -20,4 +20,5 @@ moto==1.3.16 jupyterlab==3.0.0 jupyter==1.0.0 s3fs==0.4.2 +pyodbc~=4.0.30 -e . 
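For reference, the pyodbc-based connect() introduced in this patch boils down to assembling a standard SQL Server ODBC connection string and handing it to pyodbc. A minimal standalone sketch of the same pattern (host, port, database and credentials below are placeholder values, not taken from the patch):

    import pyodbc

    # Placeholder endpoint and credentials; mirrors the string built by wr.sqlserver.connect().
    connection_str = (
        "DRIVER={ODBC Driver 17 for SQL Server};"  # must match the installed driver's major version
        "SERVER=my-sqlserver-host,1433;"  # pyodbc separates host and port with a comma
        "DATABASE=test;"
        "UID=test;"
        "PWD=my-secret-password"
    )
    con = pyodbc.connect(connection_str, timeout=0)  # 0 means no timeout, as in the connect() default above
    with con.cursor() as cursor:
        cursor.execute("SELECT 1")
        print(cursor.fetchall())
    con.close()
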
diff --git a/requirements.txt b/requirements.txt index 9bd3b99b0..6417ea8c0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,4 +6,3 @@ pyarrow~=2.0.0 redshift-connector~=2.0.0 pymysql>=0.9.0,<0.11.0 pg8000~=1.16.0 -pymssql~=2.1.5 diff --git a/setup.py b/setup.py index 004a746c1..911a5675e 100644 --- a/setup.py +++ b/setup.py @@ -38,4 +38,5 @@ "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", ], + extras_require={"sqlserver": ["pyodbc~=4.0.30"]}, ) diff --git a/tests/test_sqlserver.py b/tests/test_sqlserver.py index 536f4e203..26ee846c9 100644 --- a/tests/test_sqlserver.py +++ b/tests/test_sqlserver.py @@ -3,7 +3,7 @@ import pandas as pd import pyarrow as pa -import pymssql +import pyodbc import pytest import awswrangler as wr @@ -15,13 +15,13 @@ @pytest.fixture(scope="module", autouse=True) def create_sql_server_database(databases_parameters): - con = pymssql.connect( - host=databases_parameters["sqlserver"]["host"], - port=int(databases_parameters["sqlserver"]["port"]), - user=databases_parameters["user"], - password=databases_parameters["password"], - autocommit=True, + connection_str = ( + f"DRIVER={{ODBC Driver 17 for SQL Server}};" + f"SERVER={databases_parameters['sqlserver']['host']},{databases_parameters['sqlserver']['port']};" + f"UID={databases_parameters['user']};" + f"PWD={databases_parameters['password']}" ) + con = pyodbc.connect(connection_str, autocommit=True) sql_create_db = ( f"IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = '{databases_parameters['sqlserver']['database']}') " "BEGIN " f"CREATE DATABASE {databases_parameters['sqlserver']['database']} " "END" ) with con.cursor() as cursor: cursor.execute(sql_create_db) con.commit() yield sql_drop_db = ( f"IF EXISTS (SELECT * FROM sys.databases WHERE name = '{databases_parameters['sqlserver']['database']}') " "BEGIN " - f"DROP DATABASE {databases_parameters['sqlserver']['database']} " + "USE master; " + f"ALTER DATABASE {databases_parameters['sqlserver']['database']} SET SINGLE_USER WITH ROLLBACK IMMEDIATE; " + f"DROP DATABASE {databases_parameters['sqlserver']['database']}; " "END" ) with con.cursor() as cursor: cursor.execute(sql_drop_db) con.commit() con.close() From f1f36944ea07eb4124c0acbf58baada29dd3743c Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Sat, 2 Jan 2021 21:01:31 +0100 Subject: [PATCH 13/19] Dynamically import pyodbc --- awswrangler/sqlserver.py | 43 ++++++++++++++++++++++++++++++---------- 1 file changed, 33 insertions(+), 10 deletions(-) diff --git a/awswrangler/sqlserver.py b/awswrangler/sqlserver.py index e02ceb962..a88517bd9 100644 --- a/awswrangler/sqlserver.py +++ b/awswrangler/sqlserver.py @@ -1,22 +1,41 @@ """Amazon Microsoft SQL Server Module.""" +import importlib import logging -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union import boto3 import pandas as pd import pyarrow as pa from awswrangler import _data_types from awswrangler import _databases as _db_utils from awswrangler import exceptions +__all__ = ["connect", "read_sql_query", "read_sql_table", "to_sql"] + +_pyodbc_found = importlib.util.find_spec("pyodbc") +if _pyodbc_found: + import pyodbc + _logger: logging.Logger = logging.getLogger(__name__) +FuncT = TypeVar("FuncT", bound=Callable[..., Any]) + + +def _check_for_pyodbc(func: FuncT) -> FuncT: + def inner(*args: Any, **kwargs: Any) -> Any: + if not _pyodbc_found: + raise ModuleNotFoundError( + "You need to install either pyodbc directly or the " + "AWS Data Wrangler package with the `sqlserver` extra to use the sqlserver module" + ) + 
return func(*args, **kwargs) + + return inner # type: ignore -def _validate_connection(con: pyodbc.Connection) -> None: +def _validate_connection(con: "pyodbc.Connection") -> None: if not isinstance(con, pyodbc.Connection): raise exceptions.InvalidConnection( "Invalid 'conn' argument, please pass a " @@ -31,14 +50,14 @@ def _get_table_identifier(schema: Optional[str], table: str) -> str: return table_identifier -def _drop_table(cursor: pyodbc.Cursor, schema: Optional[str], table: str) -> None: +def _drop_table(cursor: "pyodbc.Cursor", schema: Optional[str], table: str) -> None: table_identifier = _get_table_identifier(schema, table) sql = f"IF OBJECT_ID(N'{table_identifier}', N'U') IS NOT NULL DROP TABLE {table_identifier}" _logger.debug("Drop table query:\n%s", sql) cursor.execute(sql) -def _does_table_exist(cursor: pyodbc.Cursor, schema: Optional[str], table: str) -> bool: +def _does_table_exist(cursor: "pyodbc.Cursor", schema: Optional[str], table: str) -> bool: schema_str = f"TABLE_SCHEMA = '{schema}' AND" if schema else "" cursor.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE " f"{schema_str} TABLE_NAME = '{table}'") return len(cursor.fetchall()) > 0 @@ -46,7 +65,7 @@ def _does_table_exist(cursor: pyodbc.Cursor, schema: Optional[str], table: str) def _create_table( df: pd.DataFrame, - cursor: pyodbc.Cursor, + cursor: "pyodbc.Cursor", table: str, schema: str, mode: str, @@ -75,6 +94,7 @@ def _create_table( cursor.execute(sql) +@_check_for_pyodbc def connect( connection: Optional[str] = None, secret_id: Optional[str] = None, @@ -83,7 +103,7 @@ def connect( odbc_driver_version: int = 17, boto3_session: Optional[boto3.Session] = None, timeout: Optional[int] = 0, -) -> pyodbc.Connection: +) -> "pyodbc.Connection": """Return a pyodbc connection from a Glue Catalog Connection. 
https://github.com/mkleehammer/pyodbc @@ -143,9 +163,10 @@ def connect( return pyodbc.connect(connection_str, timeout=timeout) +@_check_for_pyodbc def read_sql_query( sql: str, - con: pyodbc.Connection, + con: "pyodbc.Connection", index_col: Optional[Union[str, List[str]]] = None, params: Optional[Union[List[Any], Tuple[Any, ...], Dict[Any, Any]]] = None, chunksize: Optional[int] = None, @@ -198,9 +219,10 @@ def read_sql_query( ) +@_check_for_pyodbc def read_sql_table( table: str, - con: pyodbc.Connection, + con: "pyodbc.Connection", schema: Optional[str] = None, index_col: Optional[Union[str, List[str]]] = None, params: Optional[Union[List[Any], Tuple[Any, ...], Dict[Any, Any]]] = None, @@ -260,9 +282,10 @@ def read_sql_table( ) +@_check_for_pyodbc def to_sql( df: pd.DataFrame, - con: pyodbc.Connection, + con: "pyodbc.Connection", table: str, schema: str, mode: str = "append", From 048099216ac9aa72bc45c56ec3da212ba558ad7f Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Sun, 3 Jan 2021 12:09:56 +0100 Subject: [PATCH 14/19] Add pyodbc to Lambda layer --- building/lambda/Dockerfile | 23 +++++++++++++++++++++++ building/lambda/build-lambda-layer.sh | 9 ++++++--- 2 files changed, 29 insertions(+), 3 deletions(-) diff --git a/building/lambda/Dockerfile b/building/lambda/Dockerfile index 4ce0b4d47..f02fa7c5e 100644 --- a/building/lambda/Dockerfile +++ b/building/lambda/Dockerfile @@ -12,6 +12,29 @@ RUN yum install -y \ ninja-build \ ${py_dev} +# Based on https://gist.github.com/diriver63/b72a954fa0da4851d89e5086aa13c6e8 +RUN curl ftp://ftp.unixodbc.org/pub/unixODBC/unixODBC-2.3.9.tar.gz -O && \ + tar xvzf unixODBC-2.3.9.tar.gz && \ + cd unixODBC-2.3.9 && \ + mkdir /odbc-build && \ + ./configure --sysconfdir=/opt --disable-gui --disable-drivers --enable-iconv \ + --with-iconv-char-enc=UTF8 --with-iconv-ucode-enc=UTF16LE --prefix=/opt && \ + make && \ + make install && \ + cd .. && \ + rm -r unixODBC-2.3.9 && \ + rm unixODBC-2.3.9.tar.gz + +RUN curl https://packages.microsoft.com/config/rhel/7/prod.repo > /etc/yum.repos.d/mssql-release.repo && \ + ACCEPT_EULA=Y yum -y install msodbcsql17 unixODBC-devel && \ + export CFLAGS="-I/opt/include" && \ + export LDFLAGS="-L/opt/lib" && \ + cd /opt && \ + cp -r /opt/microsoft/msodbcsql17/ . && \ + rm -rf /opt/microsoft/ && \ + printf "[ODBC Driver 17 for SQL Server]\n\nDriver=/opt/msodbcsql17/lib64/libmsodbcsql-17.6.so.1.1\n\nUsageCount=1\n" > /opt/odbcinst.ini && \ + printf "[ODBC Driver 17 for SQL Server]\n\nDriver = ODBC Driver 17 for SQL Server\n\n Trace = No\n" > /opt/odbc.ini + RUN pip3 install --upgrade pip six cython cmake hypothesis ADD requirements.txt /root/ diff --git a/building/lambda/build-lambda-layer.sh b/building/lambda/build-lambda-layer.sh index f44f0ca5c..7d0f5dada 100644 --- a/building/lambda/build-lambda-layer.sh +++ b/building/lambda/build-lambda-layer.sh @@ -69,7 +69,7 @@ popd pushd /aws-data-wrangler -pip install . -t ./python +pip install .[sqlserver] -t ./python rm -rf python/pyarrow* rm -rf python/boto* @@ -80,12 +80,15 @@ rm -f /aws-data-wrangler/dist/pyarrow_files/pyarrow/libarrow_python.so cp -r /aws-data-wrangler/dist/pyarrow_files/pyarrow* python/ + find python -wholename "*/tests/*" -type f -delete -zip -r9 "${FILENAME}" ./python +cp -r /opt . +cd opt; zip -r9 ../"${FILENAME}" *; cd .. 
+zip -ur9 "${FILENAME}" ./python mv "${FILENAME}" dist/ -rm -rf python dist/pyarrow_files "${FILENAME}" +rm -rf python dist/pyarrow_files "${FILENAME}" opt popd From e892fc7199d66351c6cb21efcf279ae109430ddf Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Sun, 3 Jan 2021 21:23:42 +0100 Subject: [PATCH 15/19] Fix for 3.6 and 3.7 --- building/lambda/Dockerfile | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/building/lambda/Dockerfile b/building/lambda/Dockerfile index f02fa7c5e..92b1daedf 100644 --- a/building/lambda/Dockerfile +++ b/building/lambda/Dockerfile @@ -26,14 +26,15 @@ RUN curl ftp://ftp.unixodbc.org/pub/unixODBC/unixODBC-2.3.9.tar.gz -O && \ rm unixODBC-2.3.9.tar.gz RUN curl https://packages.microsoft.com/config/rhel/7/prod.repo > /etc/yum.repos.d/mssql-release.repo && \ - ACCEPT_EULA=Y yum -y install msodbcsql17 unixODBC-devel && \ + yum install -y e2fsprogs.x86_64 0:1.43.5-2.43.amzn1 fuse-libs.x86_64 0:2.9.4-1.18.amzn1 libss.x86_64 0:1.43.5-2.43.amzn1 && \ + ACCEPT_EULA=Y yum -y install msodbcsql17 unixODBC-devel --disablerepo=amzn* && \ export CFLAGS="-I/opt/include" && \ export LDFLAGS="-L/opt/lib" && \ cd /opt && \ cp -r /opt/microsoft/msodbcsql17/ . && \ rm -rf /opt/microsoft/ && \ - printf "[ODBC Driver 17 for SQL Server]\n\nDriver=/opt/msodbcsql17/lib64/libmsodbcsql-17.6.so.1.1\n\nUsageCount=1\n" > /opt/odbcinst.ini && \ - printf "[ODBC Driver 17 for SQL Server]\n\nDriver = ODBC Driver 17 for SQL Server\n\n Trace = No\n" > /opt/odbc.ini + printf "[ODBC Driver 17 for SQL Server]\nDriver=/opt/msodbcsql17/lib64/libmsodbcsql-17.6.so.1.1\nUsageCount=1\n" > /opt/odbcinst.ini && \ + printf "[ODBC Driver 17 for SQL Server]\nDriver = ODBC Driver 17 for SQL Server\nTrace = No\n" > /opt/odbc.ini RUN pip3 install --upgrade pip six cython cmake hypothesis From 6955b54c3b9f3770b18e217f5cd1225d1bbfbff6 Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Sun, 3 Jan 2021 23:33:47 +0100 Subject: [PATCH 16/19] Fix for 3.8 --- building/lambda/Dockerfile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/building/lambda/Dockerfile b/building/lambda/Dockerfile index 92b1daedf..dc53d4a37 100644 --- a/building/lambda/Dockerfile +++ b/building/lambda/Dockerfile @@ -26,7 +26,7 @@ RUN curl ftp://ftp.unixodbc.org/pub/unixODBC/unixODBC-2.3.9.tar.gz -O && \ rm unixODBC-2.3.9.tar.gz RUN curl https://packages.microsoft.com/config/rhel/7/prod.repo > /etc/yum.repos.d/mssql-release.repo && \ - yum install -y e2fsprogs.x86_64 0:1.43.5-2.43.amzn1 fuse-libs.x86_64 0:2.9.4-1.18.amzn1 libss.x86_64 0:1.43.5-2.43.amzn1 && \ + yum install -y e2fsprogs.x86_64 0:1.43.5-2.43.amzn1 fuse-libs.x86_64 0:2.9.4-1.18.amzn1 libss.x86_64 0:1.43.5-2.43.amzn1 openssl.x86_64 1:1.0.2k-19.amzn2.0.4 && \ ACCEPT_EULA=Y yum -y install msodbcsql17 unixODBC-devel --disablerepo=amzn* && \ export CFLAGS="-I/opt/include" && \ export LDFLAGS="-L/opt/lib" && \ From 4f7edd331d968ef3e098eedf10133b2ac5e13e94 Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Mon, 4 Jan 2021 20:50:37 +0100 Subject: [PATCH 17/19] Update documentation --- docs/source/install.rst | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/docs/source/install.rst b/docs/source/install.rst index dc54b5f0a..25674dab8 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -9,6 +9,8 @@ Some good practices for most of the methods bellow are: - Use new and individual Virtual Environments for each project (`venv `_). 
- On Notebooks, always restart your kernel after installations. +.. note:: If you want to use ``awswrangler`` for connecting to Microsoft SQL Server, some additional configuration is needed. Please have a look at the corresponding section below. + PyPI (pip) ---------- @@ -150,3 +152,34 @@ From Source >>> git clone https://github.com/awslabs/aws-data-wrangler.git >>> cd aws-data-wrangler >>> pip install . + + +Notes for Microsoft SQL Server +------------------------------ + +``awswrangler`` uses the `pyodbc `_ +package for interacting with Microsoft SQL Server. To install this package you need the ODBC header files, +which can be installed, for example, with the following commands: + + >>> sudo apt install unixodbc-dev + >>> yum install unixODBC-devel + +After installing these header files you can either just install ``pyodbc`` or +``awswrangler`` with the ``sqlserver`` extra, which will also install ``pyodbc``: + + >>> pip install pyodbc + >>> pip install awswrangler[sqlserver] + +Finally, you also need the correct ODBC Driver for SQL Server. You can have a look at the +`documentation from Microsoft `_ +to see how it can be installed in your environment. + +If you want to connect to Microsoft SQL Server from AWS Lambda, you can build a separate Layer including the +needed ODBC drivers and ``pyodbc`` by following +`these instructions `_. + +If you maintain your own environment, you need to take care of the above steps. +Because of this, usage in combination with Glue jobs is limited and you need to rely on the +provided `functionality inside Glue itself `_. From 2bc1c8f910f21c4bac0526c28978fd039728fa8f Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Mon, 4 Jan 2021 22:49:48 +0100 Subject: [PATCH 18/19] Revert changes to Lambda layer build --- building/lambda/Dockerfile | 24 ------------------------ building/lambda/build-lambda-layer.sh | 9 +++------ 2 files changed, 3 insertions(+), 30 deletions(-) diff --git a/building/lambda/Dockerfile b/building/lambda/Dockerfile index dc53d4a37..4ce0b4d47 100644 --- a/building/lambda/Dockerfile +++ b/building/lambda/Dockerfile @@ -12,30 +12,6 @@ RUN yum install -y \ ninja-build \ ${py_dev} -# Based on https://gist.github.com/diriver63/b72a954fa0da4851d89e5086aa13c6e8 -RUN curl ftp://ftp.unixodbc.org/pub/unixODBC/unixODBC-2.3.9.tar.gz -O && \ - tar xvzf unixODBC-2.3.9.tar.gz && \ - cd unixODBC-2.3.9 && \ - mkdir /odbc-build && \ - ./configure --sysconfdir=/opt --disable-gui --disable-drivers --enable-iconv \ - --with-iconv-char-enc=UTF8 --with-iconv-ucode-enc=UTF16LE --prefix=/opt && \ - make && \ - make install && \ - cd .. && \ - rm -r unixODBC-2.3.9 && \ - rm unixODBC-2.3.9.tar.gz - -RUN curl https://packages.microsoft.com/config/rhel/7/prod.repo > /etc/yum.repos.d/mssql-release.repo && \ - yum install -y e2fsprogs.x86_64 0:1.43.5-2.43.amzn1 fuse-libs.x86_64 0:2.9.4-1.18.amzn1 libss.x86_64 0:1.43.5-2.43.amzn1 openssl.x86_64 1:1.0.2k-19.amzn2.0.4 && \ - ACCEPT_EULA=Y yum -y install msodbcsql17 unixODBC-devel --disablerepo=amzn* && \ - export CFLAGS="-I/opt/include" && \ - export LDFLAGS="-L/opt/lib" && \ - cd /opt && \ - cp -r /opt/microsoft/msodbcsql17/ . 
&& \ - rm -rf /opt/microsoft/ && \ - printf "[ODBC Driver 17 for SQL Server]\nDriver=/opt/msodbcsql17/lib64/libmsodbcsql-17.6.so.1.1\nUsageCount=1\n" > /opt/odbcinst.ini && \ - printf "[ODBC Driver 17 for SQL Server]\nDriver = ODBC Driver 17 for SQL Server\nTrace = No\n" > /opt/odbc.ini - RUN pip3 install --upgrade pip six cython cmake hypothesis ADD requirements.txt /root/ diff --git a/building/lambda/build-lambda-layer.sh b/building/lambda/build-lambda-layer.sh index 7d0f5dada..f44f0ca5c 100644 --- a/building/lambda/build-lambda-layer.sh +++ b/building/lambda/build-lambda-layer.sh @@ -69,7 +69,7 @@ popd pushd /aws-data-wrangler -pip install .[sqlserver] -t ./python +pip install . -t ./python rm -rf python/pyarrow* rm -rf python/boto* @@ -80,15 +80,12 @@ rm -f /aws-data-wrangler/dist/pyarrow_files/pyarrow/libarrow_python.so cp -r /aws-data-wrangler/dist/pyarrow_files/pyarrow* python/ - find python -wholename "*/tests/*" -type f -delete -cp -r /opt . -cd opt; zip -r9 ../"${FILENAME}" *; cd .. -zip -ur9 "${FILENAME}" ./python +zip -r9 "${FILENAME}" ./python mv "${FILENAME}" dist/ -rm -rf python dist/pyarrow_files "${FILENAME}" opt +rm -rf python dist/pyarrow_files "${FILENAME}" popd From d869eeb1beb55e44e3e56f0e810d9835bf8fc13b Mon Sep 17 00:00:00 2001 From: Maximilian Speicher Date: Tue, 5 Jan 2021 00:36:36 +0100 Subject: [PATCH 19/19] Fix formatting in docstrings --- awswrangler/mysql.py | 9 +++------ awswrangler/postgresql.py | 9 +++------ awswrangler/sqlserver.py | 9 +++------ 3 files changed, 9 insertions(+), 18 deletions(-) diff --git a/awswrangler/mysql.py b/awswrangler/mysql.py index a3ff4c91b..1dc3d5be9 100644 --- a/awswrangler/mysql.py +++ b/awswrangler/mysql.py @@ -156,8 +156,7 @@ def read_sql_query( sql : str SQL query. con : pymysql.connections.Connection - Use pymysql.connect() to use " - "credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog. + Use pymysql.connect() to use credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog. index_col : Union[str, List[str]], optional Column(s) to set as index(MultiIndex). params : Union[List, Tuple, Dict], optional @@ -214,8 +213,7 @@ def read_sql_table( table : str Table name. con : pymysql.connections.Connection - Use pymysql.connect() to use " - "credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog. + Use pymysql.connect() to use credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog. schema : str, optional Name of SQL schema in database to query. Uses default schema if None. @@ -276,8 +274,7 @@ def to_sql( df : pandas.DataFrame Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html con : pymysql.connections.Connection - Use pymysql.connect() to use " - "credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog. + Use pymysql.connect() to use credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog. table : str Table name schema : str diff --git a/awswrangler/postgresql.py b/awswrangler/postgresql.py index ffeb893f5..6a1461079 100644 --- a/awswrangler/postgresql.py +++ b/awswrangler/postgresql.py @@ -162,8 +162,7 @@ def read_sql_query( sql : str SQL query. con : pg8000.Connection - Use pg8000.connect() to use " - "credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog. + Use pg8000.connect() to use credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog. 
index_col : Union[str, List[str]], optional Column(s) to set as index(MultiIndex). params : Union[List, Tuple, Dict], optional @@ -220,8 +219,7 @@ def read_sql_table( table : str Table name. con : pg8000.Connection - Use pg8000.connect() to use " - "credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog. + Use pg8000.connect() to use credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog. schema : str, optional Name of SQL schema in database to query (if database flavor supports this). Uses default schema if None (default). @@ -282,8 +280,7 @@ def to_sql( df : pandas.DataFrame Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html con : pg8000.Connection - Use pg8000.connect() to use " - "credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog. + Use pg8000.connect() to use credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog. table : str Table name schema : str diff --git a/awswrangler/sqlserver.py b/awswrangler/sqlserver.py index a88517bd9..35de8ed8b 100644 --- a/awswrangler/sqlserver.py +++ b/awswrangler/sqlserver.py @@ -180,8 +180,7 @@ def read_sql_query( sql : str SQL query. con : pyodbc.Connection - Use pyodbc.connect() to use " - "credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog. + Use pyodbc.connect() to use credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog. index_col : Union[str, List[str]], optional Column(s) to set as index(MultiIndex). params : Union[List, Tuple, Dict], optional @@ -237,8 +236,7 @@ def read_sql_table( table : str Table name. con : pyodbc.Connection - Use pyodbc.connect() to use " - "credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog. + Use pyodbc.connect() to use credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog. schema : str, optional Name of SQL schema in database to query (if database flavor supports this). Uses default schema if None (default). @@ -300,8 +298,7 @@ def to_sql( df : pandas.DataFrame Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html con : pyodbc.Connection - Use pyodbc.connect() to use " - "credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog. + Use pyodbc.connect() to use credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog. table : str Table name schema : str
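Taken together, the series leaves the sqlserver module usable end to end roughly as follows. This is a sketch, assuming a Glue Catalog connection named "aws-data-wrangler-sqlserver", ODBC Driver 17 installed on the host, and an existing dbo schema; it only combines the connect(), to_sql() and read_sql_table() calls shown in the patches above:

    import pandas as pd

    import awswrangler as wr

    # Assumes a Glue Catalog connection named "aws-data-wrangler-sqlserver"
    # and ODBC Driver 17 for SQL Server installed locally.
    con = wr.sqlserver.connect(connection="aws-data-wrangler-sqlserver", odbc_driver_version=17)

    # Write a small DataFrame and read it back.
    df = pd.DataFrame({"id": [1, 2], "name": ["foo", "boo"]})
    wr.sqlserver.to_sql(df=df, con=con, table="tutorial", schema="dbo", mode="overwrite")

    df2 = wr.sqlserver.read_sql_table(table="tutorial", schema="dbo", con=con)
    print(df2)

    con.close()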