diff --git a/.pylintrc b/.pylintrc index daa1c3241..eaa732f57 100644 --- a/.pylintrc +++ b/.pylintrc @@ -3,7 +3,7 @@ # A comma-separated list of package or module names from where C extensions may # be loaded. Extensions are loading into the active Python interpreter and may # run arbitrary code. -extension-pkg-whitelist=pyarrow.lib +extension-pkg-whitelist=pyarrow.lib,pyodbc # Specify a score threshold to be exceeded before program exits with error. fail-under=10 diff --git a/README.md b/README.md index b3897b11d..1bdb4a636 100644 --- a/README.md +++ b/README.md @@ -105,7 +105,7 @@ FROM "sampleDB"."sampleTable" ORDER BY time DESC LIMIT 3 - [004 - Parquet Datasets](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/004%20-%20Parquet%20Datasets.ipynb) - [005 - Glue Catalog](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/005%20-%20Glue%20Catalog.ipynb) - [006 - Amazon Athena](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/006%20-%20Amazon%20Athena.ipynb) - - [007 - Databases (Redshift, MySQL and PostgreSQL)](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/007%20-%20Redshift%2C%20MySQL%2C%20PostgreSQL.ipynb) + - [007 - Databases (Redshift, MySQL, PostgreSQL and SQL Server)](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/007%20-%20Redshift%2C%20MySQL%2C%20PostgreSQL%2C%20SQL%20Server.ipynb) - [008 - Redshift - Copy & Unload.ipynb](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/008%20-%20Redshift%20-%20Copy%20%26%20Unload.ipynb) - [009 - Redshift - Append, Overwrite and Upsert](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/009%20-%20Redshift%20-%20Append%2C%20Overwrite%2C%20Upsert.ipynb) - [010 - Parquet Crawler](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/010%20-%20Parquet%20Crawler.ipynb) @@ -134,6 +134,7 @@ FROM "sampleDB"."sampleTable" ORDER BY time DESC LIMIT 3 - [Amazon Redshift](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#amazon-redshift) - [PostgreSQL](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#postgresql) - [MySQL](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#mysql) + - [SQL Server](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#sqlserver) - [DynamoDB](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#dynamodb) - [Amazon Timestream](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#amazon-timestream) - [Amazon EMR](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#amazon-emr) diff --git a/awswrangler/__init__.py b/awswrangler/__init__.py index 00e6f7315..25785e433 100644 --- a/awswrangler/__init__.py +++ b/awswrangler/__init__.py @@ -21,6 +21,7 @@ redshift, s3, secretsmanager, + sqlserver, sts, timestream, ) @@ -42,6 +43,7 @@ "mysql", "postgresql", "secretsmanager", + "sqlserver", "config", "timestream", "__description__", diff --git a/awswrangler/_config.py b/awswrangler/_config.py index 23703cdde..859d60e0a 100644 --- a/awswrangler/_config.py +++ b/awswrangler/_config.py @@ -154,7 +154,9 @@ def _apply_type(name: str, value: Any, dtype: Type[Union[str, bool, int]], nulla if _Config._is_null(value=value): if nullable is True: return None - exceptions.InvalidArgumentValue(f"{name} configuration does not accept a null value. Please pass {dtype}.") + raise exceptions.InvalidArgumentValue( + f"{name} configuration does not accept a null value. Please pass {dtype}." 
+ ) try: return dtype(value) if isinstance(value, dtype) is False else value except ValueError as ex: diff --git a/awswrangler/_data_types.py b/awswrangler/_data_types.py index 9596f33bc..728926478 100644 --- a/awswrangler/_data_types.py +++ b/awswrangler/_data_types.py @@ -166,6 +166,41 @@ def pyarrow2postgresql( # pylint: disable=too-many-branches,too-many-return-sta raise exceptions.UnsupportedType(f"Unsupported PostgreSQL type: {dtype}") +def pyarrow2sqlserver( # pylint: disable=too-many-branches,too-many-return-statements + dtype: pa.DataType, string_type: str +) -> str: + """Pyarrow to Microsoft SQL Server data types conversion.""" + if pa.types.is_int8(dtype): + return "SMALLINT" + if pa.types.is_int16(dtype) or pa.types.is_uint8(dtype): + return "SMALLINT" + if pa.types.is_int32(dtype) or pa.types.is_uint16(dtype): + return "INT" + if pa.types.is_int64(dtype) or pa.types.is_uint32(dtype): + return "BIGINT" + if pa.types.is_uint64(dtype): + raise exceptions.UnsupportedType("There is no support for uint64, please consider int64 or uint32.") + if pa.types.is_float32(dtype): + return "FLOAT(24)" + if pa.types.is_float64(dtype): + return "FLOAT" + if pa.types.is_boolean(dtype): + return "BIT" + if pa.types.is_string(dtype): + return string_type + if pa.types.is_timestamp(dtype): + return "DATETIME2" + if pa.types.is_date(dtype): + return "DATE" + if pa.types.is_decimal(dtype): + return f"DECIMAL({dtype.precision},{dtype.scale})" + if pa.types.is_dictionary(dtype): + return pyarrow2sqlserver(dtype=dtype.value_type, string_type=string_type) + if pa.types.is_binary(dtype): + return "VARBINARY" + raise exceptions.UnsupportedType(f"Unsupported PostgreSQL type: {dtype}") + + def pyarrow2timestream(dtype: pa.DataType) -> str: # pylint: disable=too-many-branches,too-many-return-statements """Pyarrow to Amazon Timestream data types conversion.""" if pa.types.is_int8(dtype): diff --git a/awswrangler/_databases.py b/awswrangler/_databases.py index 5a1bbb1da..119474835 100644 --- a/awswrangler/_databases.py +++ b/awswrangler/_databases.py @@ -36,7 +36,11 @@ def _get_connection_attributes_from_catalog( details: Dict[str, Any] = get_connection(name=connection, catalog_id=catalog_id, boto3_session=boto3_session)[ "ConnectionProperties" ] - port, database = details["JDBC_CONNECTION_URL"].split(":")[3].split("/") + if ";databaseName=" in details["JDBC_CONNECTION_URL"]: + database_sep = ";databaseName=" + else: + database_sep = "/" + port, database = details["JDBC_CONNECTION_URL"].split(":")[3].split(database_sep) return ConnectionAttributes( kind=details["JDBC_CONNECTION_URL"].split(":")[1].lower(), user=details["USERNAME"], @@ -136,19 +140,48 @@ def _records2df( return df -def _iterate_cursor( - cursor: Any, +def _get_cols_names(cursor_description: Any) -> List[str]: + cols_names = [col[0].decode("utf-8") if isinstance(col[0], bytes) else col[0] for col in cursor_description] + _logger.debug("cols_names: %s", cols_names) + + return cols_names + + +def _iterate_results( + con: Any, + cursor_args: List[Any], chunksize: int, - cols_names: List[str], - index: Optional[Union[str, List[str]]], + index_col: Optional[Union[str, List[str]]], safe: bool, dtype: Optional[Dict[str, pa.DataType]], ) -> Iterator[pd.DataFrame]: - while True: - records = cursor.fetchmany(chunksize) - if not records: - break - yield _records2df(records=records, cols_names=cols_names, index=index, safe=safe, dtype=dtype) + with con.cursor() as cursor: + cursor.execute(*cursor_args) + cols_names = _get_cols_names(cursor.description) + 
while True: + records = cursor.fetchmany(chunksize) + if not records: + break + yield _records2df(records=records, cols_names=cols_names, index=index_col, safe=safe, dtype=dtype) + + +def _fetch_all_results( + con: Any, + cursor_args: List[Any], + index_col: Optional[Union[str, List[str]]] = None, + dtype: Optional[Dict[str, pa.DataType]] = None, + safe: bool = True, +) -> pd.DataFrame: + with con.cursor() as cursor: + cursor.execute(*cursor_args) + cols_names = _get_cols_names(cursor.description) + return _records2df( + records=cast(List[Tuple[Any]], cursor.fetchall()), + cols_names=cols_names, + index=index_col, + dtype=dtype, + safe=safe, + ) def read_sql_query( @@ -163,23 +196,23 @@ def read_sql_query( """Read SQL Query (generic).""" args = _convert_params(sql, params) try: - with con.cursor() as cursor: - cursor.execute(*args) - cols_names: List[str] = [ - col[0].decode("utf-8") if isinstance(col[0], bytes) else col[0] for col in cursor.description - ] - _logger.debug("cols_names: %s", cols_names) - if chunksize is None: - return _records2df( - records=cast(List[Tuple[Any]], cursor.fetchall()), - cols_names=cols_names, - index=index_col, - dtype=dtype, - safe=safe, - ) - return _iterate_cursor( - cursor=cursor, chunksize=chunksize, cols_names=cols_names, index=index_col, dtype=dtype, safe=safe + if chunksize is None: + return _fetch_all_results( + con=con, + cursor_args=args, + index_col=index_col, + dtype=dtype, + safe=safe, ) + + return _iterate_results( + con=con, + cursor_args=args, + chunksize=chunksize, + index_col=index_col, + dtype=dtype, + safe=safe, + ) except Exception as ex: con.rollback() _logger.error(ex) diff --git a/awswrangler/catalog/_get.py b/awswrangler/catalog/_get.py index fa5471811..f6aab5132 100644 --- a/awswrangler/catalog/_get.py +++ b/awswrangler/catalog/_get.py @@ -524,7 +524,7 @@ def get_connection( client_kms = _utils.client(service_name="kms", session=boto3_session) pwd = client_kms.decrypt(CiphertextBlob=base64.b64decode(res["ConnectionProperties"]["ENCRYPTED_PASSWORD"]))[ "Plaintext" - ] + ].decode("utf-8") res["ConnectionProperties"]["PASSWORD"] = pwd return cast(Dict[str, Any], res) diff --git a/awswrangler/mysql.py b/awswrangler/mysql.py index 154aa762b..1dc3d5be9 100644 --- a/awswrangler/mysql.py +++ b/awswrangler/mysql.py @@ -127,7 +127,7 @@ def connect( connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session ) if attrs.kind != "mysql": - exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a MySQL connection.)") + raise exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a MySQL connection.)") return pymysql.connect( user=attrs.user, database=attrs.database, @@ -156,8 +156,7 @@ def read_sql_query( sql : str SQL query. con : pymysql.connections.Connection - Use pymysql.connect() to use " - "credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog. + Use pymysql.connect() to use credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog. index_col : Union[str, List[str]], optional Column(s) to set as index(MultiIndex). params : Union[List, Tuple, Dict], optional @@ -214,8 +213,7 @@ def read_sql_table( table : str Table name. con : pymysql.connections.Connection - Use pymysql.connect() to use " - "credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog. + Use pymysql.connect() to use credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog. 
schema : str, optional Name of SQL schema in database to query. Uses default schema if None. @@ -276,8 +274,7 @@ def to_sql( df : pandas.DataFrame Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html con : pymysql.connections.Connection - Use pymysql.connect() to use " - "credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog. + Use pymysql.connect() to use credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog. table : str Table name schema : str diff --git a/awswrangler/postgresql.py b/awswrangler/postgresql.py index d84dde3aa..6a1461079 100644 --- a/awswrangler/postgresql.py +++ b/awswrangler/postgresql.py @@ -131,7 +131,9 @@ def connect( connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session ) if attrs.kind != "postgresql": - exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a postgresql connection.)") + raise exceptions.InvalidDatabaseType( + f"Invalid connection type ({attrs.kind}. It must be a postgresql connection.)" + ) return pg8000.connect( user=attrs.user, database=attrs.database, @@ -160,8 +162,7 @@ def read_sql_query( sql : str SQL query. con : pg8000.Connection - Use pg8000.connect() to use " - "credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog. + Use pg8000.connect() to use credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog. index_col : Union[str, List[str]], optional Column(s) to set as index(MultiIndex). params : Union[List, Tuple, Dict], optional @@ -218,8 +219,7 @@ def read_sql_table( table : str Table name. con : pg8000.Connection - Use pg8000.connect() to use " - "credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog. + Use pg8000.connect() to use credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog. schema : str, optional Name of SQL schema in database to query (if database flavor supports this). Uses default schema if None (default). @@ -280,8 +280,7 @@ def to_sql( df : pandas.DataFrame Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html con : pg8000.Connection - Use pg8000.connect() to use " - "credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog. + Use pg8000.connect() to use credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog. table : str Table name schema : str @@ -310,7 +309,7 @@ def to_sql( >>> import awswrangler as wr >>> con = wr.postgresql.connect("MY_GLUE_CONNECTION") >>> wr.postgresql.to_sql( - ... df=df + ... df=df, ... table="my_table", ... schema="public", ... con=con diff --git a/awswrangler/redshift.py b/awswrangler/redshift.py index 323fb6216..71d6f3336 100644 --- a/awswrangler/redshift.py +++ b/awswrangler/redshift.py @@ -386,7 +386,9 @@ def connect( connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session ) if attrs.kind != "redshift": - exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a redshift connection.)") + raise exceptions.InvalidDatabaseType( + f"Invalid connection type ({attrs.kind}. 
It must be a redshift connection.)" + ) return redshift_connector.connect( user=attrs.user, database=attrs.database, diff --git a/awswrangler/sqlserver.py b/awswrangler/sqlserver.py new file mode 100644 index 000000000..35de8ed8b --- /dev/null +++ b/awswrangler/sqlserver.py @@ -0,0 +1,365 @@ +"""Amazon Microsoft SQL Server Module.""" + + +import importlib +import logging +from typing import Any, Callable, Dict, Iterator, List, Optional, Tuple, TypeVar, Union + +import boto3 +import pandas as pd +import pyarrow as pa + +from awswrangler import _data_types +from awswrangler import _databases as _db_utils +from awswrangler import exceptions + +__all__ = ["connect", "read_sql_query", "read_sql_table", "to_sql"] + +_pyodbc_found = importlib.util.find_spec("pyodbc") +if _pyodbc_found: + import pyodbc + +_logger: logging.Logger = logging.getLogger(__name__) +FuncT = TypeVar("FuncT", bound=Callable[..., Any]) + + +def _check_for_pyodbc(func: FuncT) -> FuncT: + def inner(*args: Any, **kwargs: Any) -> Any: + if not _pyodbc_found: + raise ModuleNotFoundError( + "You need to install pyodbc respectively the " + "AWS Data Wrangler package with the `sqlserver` extra for using the sqlserver module" + ) + return func(*args, **kwargs) + + return inner # type: ignore + + +def _validate_connection(con: "pyodbc.Connection") -> None: + if not isinstance(con, pyodbc.Connection): + raise exceptions.InvalidConnection( + "Invalid 'conn' argument, please pass a " + "pyodbc.Connection object. Use pyodbc.connect() to use " + "credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog." + ) + + +def _get_table_identifier(schema: Optional[str], table: str) -> str: + schema_str = f'"{schema}".' if schema else "" + table_identifier = f'{schema_str}"{table}"' + return table_identifier + + +def _drop_table(cursor: "pyodbc.Cursor", schema: Optional[str], table: str) -> None: + table_identifier = _get_table_identifier(schema, table) + sql = f"IF OBJECT_ID(N'{table_identifier}', N'U') IS NOT NULL DROP TABLE {table_identifier}" + _logger.debug("Drop table query:\n%s", sql) + cursor.execute(sql) + + +def _does_table_exist(cursor: "pyodbc.Cursor", schema: Optional[str], table: str) -> bool: + schema_str = f"TABLE_SCHEMA = '{schema}' AND" if schema else "" + cursor.execute(f"SELECT * FROM INFORMATION_SCHEMA.TABLES WHERE " f"{schema_str} TABLE_NAME = '{table}'") + return len(cursor.fetchall()) > 0 + + +def _create_table( + df: pd.DataFrame, + cursor: "pyodbc.Cursor", + table: str, + schema: str, + mode: str, + index: bool, + dtype: Optional[Dict[str, str]], + varchar_lengths: Optional[Dict[str, int]], +) -> None: + if mode == "overwrite": + _drop_table(cursor=cursor, schema=schema, table=table) + elif _does_table_exist(cursor=cursor, schema=schema, table=table): + return + sqlserver_types: Dict[str, str] = _data_types.database_types_from_pandas( + df=df, + index=index, + dtype=dtype, + varchar_lengths_default="VARCHAR(MAX)", + varchar_lengths=varchar_lengths, + converter_func=_data_types.pyarrow2sqlserver, + ) + cols_str: str = "".join([f"{k} {v},\n" for k, v in sqlserver_types.items()])[:-2] + table_identifier = _get_table_identifier(schema, table) + sql = ( + f"IF OBJECT_ID(N'{table_identifier}', N'U') IS NULL BEGIN CREATE TABLE {table_identifier} (\n{cols_str}); END;" + ) + _logger.debug("Create table query:\n%s", sql) + cursor.execute(sql) + + +@_check_for_pyodbc +def connect( + connection: Optional[str] = None, + secret_id: Optional[str] = None, + catalog_id: Optional[str] = None, + dbname: Optional[str] 
= None, + odbc_driver_version: int = 17, + boto3_session: Optional[boto3.Session] = None, + timeout: Optional[int] = 0, +) -> "pyodbc.Connection": + """Return a pyodbc connection from a Glue Catalog Connection. + + https://github.com/mkleehammer/pyodbc + + Parameters + ---------- + connection : Optional[str] + Glue Catalog Connection name. + secret_id : Optional[str] + Specifies the secret containing the version that you want to retrieve. + You can specify either the Amazon Resource Name (ARN) or the friendly name of the secret. + catalog_id : str, optional + The ID of the Data Catalog. + If none is provided, the AWS account ID is used by default. + dbname: Optional[str] + Optional database name to overwrite the stored one. + odbc_driver_version : int + Major version of the ODBC Driver for SQL Server that is installed and should be used. + boto3_session : boto3.Session(), optional + Boto3 Session. The default boto3 session will be used if boto3_session receive None. + timeout: Optional[int] + This is the time in seconds before the connection to the server will time out. + The default is 0, which means no timeout. + This parameter is forwarded to pyodbc. + https://github.com/mkleehammer/pyodbc/wiki/The-pyodbc-Module#connect + + Returns + ------- + pyodbc.Connection + pyodbc connection. + + Examples + -------- + >>> import awswrangler as wr + >>> con = wr.sqlserver.connect(connection="MY_GLUE_CONNECTION", odbc_driver_version=17) + >>> with con.cursor() as cursor: + >>> cursor.execute("SELECT 1") + >>> print(cursor.fetchall()) + >>> con.close() + + """ + attrs: _db_utils.ConnectionAttributes = _db_utils.get_connection_attributes( + connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session + ) + if attrs.kind != "sqlserver": + raise exceptions.InvalidDatabaseType( + f"Invalid connection type ({attrs.kind}. It must be a sqlserver connection.)" + ) + connection_str = ( + f"DRIVER={{ODBC Driver {odbc_driver_version} for SQL Server}};" + f"SERVER={attrs.host},{attrs.port};" + f"DATABASE={attrs.database};" + f"UID={attrs.user};" + f"PWD={attrs.password}" + ) + + return pyodbc.connect(connection_str, timeout=timeout) + + +@_check_for_pyodbc +def read_sql_query( + sql: str, + con: "pyodbc.Connection", + index_col: Optional[Union[str, List[str]]] = None, + params: Optional[Union[List[Any], Tuple[Any, ...], Dict[Any, Any]]] = None, + chunksize: Optional[int] = None, + dtype: Optional[Dict[str, pa.DataType]] = None, + safe: bool = True, +) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: + """Return a DataFrame corresponding to the result set of the query string. + + Parameters + ---------- + sql : str + SQL query. + con : pyodbc.Connection + Use pyodbc.connect() to use credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog. + index_col : Union[str, List[str]], optional + Column(s) to set as index(MultiIndex). + params : Union[List, Tuple, Dict], optional + List of parameters to pass to execute method. + The syntax used to pass parameters is database driver dependent. + Check your database driver documentation for which of the five syntax styles, + described in PEP 249’s paramstyle, is supported. + chunksize : int, optional + If specified, return an iterator where chunksize is the number of rows to include in each chunk. + dtype : Dict[str, pyarrow.DataType], optional + Specifying the datatype for columns. + The keys should be the column names and the values should be the PyArrow types.
+ safe : bool + Check for overflows or other unsafe data type conversions. + + Returns + ------- + Union[pandas.DataFrame, Iterator[pandas.DataFrame]] + Result as Pandas DataFrame(s). + + Examples + -------- + Reading from Microsoft SQL Server using a Glue Catalog Connection + + >>> import awswrangler as wr + >>> con = wr.sqlserver.connect(connection="MY_GLUE_CONNECTION", odbc_driver_version=17) + >>> df = wr.sqlserver.read_sql_query( + ... sql="SELECT * FROM dbo.my_table", + ... con=con + ... ) + >>> con.close() + """ + _validate_connection(con=con) + return _db_utils.read_sql_query( + sql=sql, con=con, index_col=index_col, params=params, chunksize=chunksize, dtype=dtype, safe=safe + ) + + +@_check_for_pyodbc +def read_sql_table( + table: str, + con: "pyodbc.Connection", + schema: Optional[str] = None, + index_col: Optional[Union[str, List[str]]] = None, + params: Optional[Union[List[Any], Tuple[Any, ...], Dict[Any, Any]]] = None, + chunksize: Optional[int] = None, + dtype: Optional[Dict[str, pa.DataType]] = None, + safe: bool = True, +) -> Union[pd.DataFrame, Iterator[pd.DataFrame]]: + """Return a DataFrame corresponding to the table. + + Parameters + ---------- + table : str + Table name. + con : pyodbc.Connection + Use pyodbc.connect() to use credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog. + schema : str, optional + Name of SQL schema in database to query (if database flavor supports this). + Uses default schema if None (default). + index_col : Union[str, List[str]], optional + Column(s) to set as index(MultiIndex). + params : Union[List, Tuple, Dict], optional + List of parameters to pass to execute method. + The syntax used to pass parameters is database driver dependent. + Check your database driver documentation for which of the five syntax styles, + described in PEP 249’s paramstyle, is supported. + chunksize : int, optional + If specified, return an iterator where chunksize is the number of rows to include in each chunk. + dtype : Dict[str, pyarrow.DataType], optional + Specifying the datatype for columns. + The keys should be the column names and the values should be the PyArrow types. + safe : bool + Check for overflows or other unsafe data type conversions. + + Returns + ------- + Union[pandas.DataFrame, Iterator[pandas.DataFrame]] + Result as Pandas DataFrame(s). + + Examples + -------- + Reading from Microsoft SQL Server using a Glue Catalog Connection + + >>> import awswrangler as wr + >>> con = wr.sqlserver.connect(connection="MY_GLUE_CONNECTION", odbc_driver_version=17) + >>> df = wr.sqlserver.read_sql_table( + ... table="my_table", + ... schema="dbo", + ... con=con + ... ) + >>> con.close() + """ + table_identifier = _get_table_identifier(schema, table) + sql: str = f"SELECT * FROM {table_identifier}" + return read_sql_query( + sql=sql, con=con, index_col=index_col, params=params, chunksize=chunksize, dtype=dtype, safe=safe + ) + + +@_check_for_pyodbc +def to_sql( + df: pd.DataFrame, + con: "pyodbc.Connection", + table: str, + schema: str, + mode: str = "append", + index: bool = False, + dtype: Optional[Dict[str, str]] = None, + varchar_lengths: Optional[Dict[str, int]] = None, +) -> None: + """Write records stored in a DataFrame into Microsoft SQL Server. + + Parameters + ---------- + df : pandas.DataFrame + Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html + con : pyodbc.Connection + Use pyodbc.connect() to use credentials directly or wr.sqlserver.connect() to fetch it from the Glue Catalog.
+ table : str + Table name + schema : str + Schema name + mode : str + Append or overwrite. + index : bool + True to store the DataFrame index as a column in the table, + otherwise False to ignore it. + dtype: Dict[str, str], optional + Dictionary of columns names and Microsoft SQL Server types to be casted. + Useful when you have columns with undetermined or mixed data types. + (e.g. {'col name': 'TEXT', 'col2 name': 'FLOAT'}) + varchar_lengths : Dict[str, int], optional + Dict of VARCHAR length by columns. (e.g. {"col1": 10, "col5": 200}). + + Returns + ------- + None + None. + + Examples + -------- + Writing to Microsoft SQL Server using a Glue Catalog Connections + + >>> import awswrangler as wr + >>> con = wr.sqlserver.connect(connection="MY_GLUE_CONNECTION", odbc_driver_version=17) + >>> wr.sqlserver.to_sql( + ... df=df, + ... table="table", + ... schema="dbo", + ... con=con + ... ) + >>> con.close() + + """ + if df.empty is True: + raise exceptions.EmptyDataFrame() + _validate_connection(con=con) + try: + with con.cursor() as cursor: + _create_table( + df=df, + cursor=cursor, + table=table, + schema=schema, + mode=mode, + index=index, + dtype=dtype, + varchar_lengths=varchar_lengths, + ) + if index: + df.reset_index(level=df.index.names, inplace=True) + placeholders: str = ", ".join(["?"] * len(df.columns)) + table_identifier = _get_table_identifier(schema, table) + sql: str = f"INSERT INTO {table_identifier} VALUES ({placeholders})" + _logger.debug("sql: %s", sql) + parameters: List[List[Any]] = _db_utils.extract_parameters(df=df) + cursor.executemany(sql, parameters) + con.commit() + except Exception as ex: + con.rollback() + _logger.error(ex) + raise diff --git a/cloudformation/databases.yaml b/cloudformation/databases.yaml index 0590c622b..c6351fe3a 100644 --- a/cloudformation/databases.yaml +++ b/cloudformation/databases.yaml @@ -1,6 +1,6 @@ AWSTemplateFormatVersion: 2010-09-09 Description: | - AWS Data Wrangler Development Databases Infrastructure Redshift, Aurora PostgreSQL, Aurora MySQL + AWS Data Wrangler Development Databases Infrastructure Redshift, Aurora PostgreSQL, Aurora MySQL, Microsoft SQL Server Parameters: DatabasesPassword: Type: String @@ -136,7 +136,7 @@ Resources: SubnetIds: - Fn::ImportValue: aws-data-wrangler-base-PublicSubnet1 - Fn::ImportValue: aws-data-wrangler-base-PublicSubnet2 - AuroraRole: + RdsRole: Type: AWS::IAM::Role Properties: Tags: @@ -205,7 +205,7 @@ Resources: - FeatureName: s3Import RoleArn: Fn::GetAtt: - - AuroraRole + - RdsRole - Arn AuroraInstancePostgresql: Type: AWS::RDS::DBInstance @@ -234,15 +234,15 @@ Resources: Parameters: aurora_load_from_s3_role: Fn::GetAtt: - - AuroraRole + - RdsRole - Arn aws_default_s3_role: Fn::GetAtt: - - AuroraRole + - RdsRole - Arn aurora_select_into_s3_role: Fn::GetAtt: - - AuroraRole + - RdsRole - Arn AuroraClusterMysql: Type: AWS::RDS::DBCluster @@ -268,7 +268,7 @@ Resources: AssociatedRoles: - RoleArn: Fn::GetAtt: - - AuroraRole + - RdsRole - Arn AuroraInstanceMysql: Type: AWS::RDS::DBInstance @@ -286,6 +286,32 @@ Resources: DBSubnetGroupName: Ref: RdsSubnetGroup PubliclyAccessible: true + SqlServerInstance: + Type: AWS::RDS::DBInstance + DeletionPolicy: Delete + Properties: + Tags: + - Key: Env + Value: aws-data-wrangler + Engine: sqlserver-ex + EngineVersion: '15.00' + DBInstanceIdentifier: sqlserver-instance-wrangler + DBInstanceClass: db.t3.small + AllocatedStorage: '20' + MasterUsername: test + MasterUserPassword: + Ref: DatabasesPassword + DBSubnetGroupName: + Ref: RdsSubnetGroup + 
VPCSecurityGroups: + - Ref: DatabaseSecurityGroup + PubliclyAccessible: true + AssociatedRoles: + - RoleArn: + Fn::GetAtt: + - RdsRole + - Arn + FeatureName: S3_INTEGRATION RedshiftGlueConnection: Type: AWS::Glue::Connection Properties: @@ -358,6 +384,30 @@ Resources: PASSWORD: Ref: DatabasesPassword Name: aws-data-wrangler-mysql + SqlServerGlueConnection: + Type: AWS::Glue::Connection + Properties: + CatalogId: + Ref: AWS::AccountId + ConnectionInput: + Description: Connect to SQL Server. + ConnectionType: JDBC + PhysicalConnectionRequirements: + AvailabilityZone: + Fn::Select: + - 0 + - Fn::GetAZs: '' + SecurityGroupIdList: + - Ref: DatabaseSecurityGroup + SubnetId: + Fn::ImportValue: aws-data-wrangler-base-PrivateSubnet + ConnectionProperties: + JDBC_CONNECTION_URL: + Fn::Sub: jdbc:sqlserver://${SqlServerInstance.Endpoint.Address}:${SqlServerInstance.Endpoint.Port};databaseName=test + USERNAME: test + PASSWORD: + Ref: DatabasesPassword + Name: aws-data-wrangler-sqlserver GlueCatalogSettings: Type: AWS::Glue::DataCatalogEncryptionSettings Properties: @@ -388,7 +438,7 @@ Resources: Tags: - Key: Env Value: aws-data-wrangler - postgresqlSecret: + PostgresqlSecret: Type: AWS::SecretsManager::Secret Properties: Name: aws-data-wrangler/postgresql @@ -398,7 +448,7 @@ Resources: { "username": "test", "password": "${DatabasesPassword}", - "engine": "postgres", + "engine": "postgresql", "host": "${AuroraInstancePostgresql.Endpoint.Address}", "port": ${AuroraInstancePostgresql.Endpoint.Port}, "dbClusterIdentifier": "${AuroraInstancePostgresql}", @@ -417,7 +467,7 @@ Resources: { "username": "test", "password": "${DatabasesPassword}", - "engine": "postgres", + "engine": "mysql", "host": "${AuroraInstanceMysql.Endpoint.Address}", "port": ${AuroraInstanceMysql.Endpoint.Port}, "dbClusterIdentifier": "${AuroraInstanceMysql}", @@ -426,6 +476,25 @@ Resources: Tags: - Key: Env Value: aws-data-wrangler + SqlServerSecret: + Type: AWS::SecretsManager::Secret + Properties: + Name: aws-data-wrangler/sqlserver + Description: SQL Server credentials + SecretString: + Fn::Sub: | + { + "username": "test", + "password": "${DatabasesPassword}", + "engine": "sqlserver", + "host": "${SqlServerInstance.Endpoint.Address}", + "port": ${SqlServerInstance.Endpoint.Port}, + "dbClusterIdentifier": "${SqlServerInstance}", + "dbname": "test" + } + Tags: + - Key: Env + Value: aws-data-wrangler Outputs: DatabasesPassword: Value: @@ -464,6 +533,12 @@ Outputs: - AuroraInstanceMysql - Endpoint.Address Description: Mysql Address + SqlServerAddress: + Value: + Fn::GetAtt: + - SqlServerInstance + - Endpoint.Address + Description: SQL Server Address DatabaseSecurityGroupId: Value: Fn::GetAtt: diff --git a/docs/source/api.rst b/docs/source/api.rst index 443a3c2a0..d9af29819 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -7,6 +7,7 @@ API Reference * `Amazon Redshift`_ * `PostgreSQL`_ * `MySQL`_ +* `Microsoft SQL Server`_ * `DynamoDB`_ * `Amazon Timestream`_ * `Amazon EMR`_ @@ -145,6 +146,19 @@ MySQL .. currentmodule:: awswrangler.mysql +.. autosummary:: + :toctree: stubs + + connect + read_sql_query + read_sql_table + to_sql + +Microsoft SQL Server +____________________ + +.. currentmodule:: awswrangler.sqlserver + .. 
autosummary:: + :toctree: stubs + + connect + read_sql_query + read_sql_table + to_sql + diff --git a/docs/source/install.rst b/docs/source/install.rst index dc54b5f0a..25674dab8 100644 --- a/docs/source/install.rst +++ b/docs/source/install.rst @@ -9,6 +9,8 @@ Some good practices for most of the methods bellow are: - Use new and individual Virtual Environments for each project (`venv `_). - On Notebooks, always restart your kernel after installations. +.. note:: If you want to use ``awswrangler`` for connecting to Microsoft SQL Server, some additional configuration is needed. Please have a look at the corresponding section below. + PyPI (pip) ---------- @@ -150,3 +152,34 @@ From Source >>> git clone https://github.com/awslabs/aws-data-wrangler.git >>> cd aws-data-wrangler >>> pip install . + + +Notes for Microsoft SQL Server +------------------------------ + +``awswrangler`` uses the `pyodbc `_ package +for interacting with Microsoft SQL Server. To install this package you need the ODBC header files, +which can be installed, for example, with the following commands: + + >>> sudo apt install unixodbc-dev + >>> yum install unixODBC-devel + +After installing these header files you can either just install ``pyodbc`` or +``awswrangler`` with the ``sqlserver`` extra, which will also install ``pyodbc``: + + >>> pip install pyodbc + >>> pip install awswrangler[sqlserver] + +Finally, you also need the correct ODBC Driver for SQL Server. You can have a look at the +`documentation from Microsoft `_ +to see how it can be installed in your environment. + +If you want to connect to Microsoft SQL Server from AWS Lambda, you can build a separate Layer including the +needed ODBC drivers and `pyodbc` by following +`these instructions `_. + +If you maintain your own environment, you need to take care of the above steps yourself. +Because of this limitation, usage in combination with Glue jobs is limited and you need to rely on the +provided `functionality inside Glue itself `_. diff --git a/requirements-dev.txt b/requirements-dev.txt index 75caad733..f87389ca7 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -20,4 +20,5 @@ moto==1.3.16 jupyterlab==3.0.0 jupyter==1.0.0 s3fs==0.4.2 +pyodbc~=4.0.30 -e . 
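A short sketch to complement the install notes above: when no Glue Catalog connection is available, the same style of connection string that the new wr.sqlserver.connect() builds can be assembled by hand with pyodbc. The host, port, database and credentials below are placeholders, and the sketch assumes "ODBC Driver 17 for SQL Server" is installed on the machine.

    # Hand-built pyodbc connection, mirroring the string wr.sqlserver.connect() produces.
    # Placeholder host/credentials -- replace with your own SQL Server instance.
    import pyodbc

    connection_str = (
        "DRIVER={ODBC Driver 17 for SQL Server};"
        "SERVER=my-sqlserver-host,1433;"
        "DATABASE=test;"
        "UID=test;"
        "PWD=my-password"
    )
    con = pyodbc.connect(connection_str, timeout=10)
    with con.cursor() as cursor:
        cursor.execute("SELECT 1")
        print(cursor.fetchall())
    con.close()
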
diff --git a/setup.py b/setup.py index 6bdfe1b97..4159fff33 100644 --- a/setup.py +++ b/setup.py @@ -31,4 +31,5 @@ "Programming Language :: Python :: 3.7", "Programming Language :: Python :: 3.8", ], + extras_require={"sqlserver": ["pyodbc~=4.0.30"]}, ) diff --git a/tests/_utils.py b/tests/_utils.py index b51dbcbd4..85df69484 100644 --- a/tests/_utils.py +++ b/tests/_utils.py @@ -25,7 +25,7 @@ def get_df(): "iint32": [1, None, 2], "iint64": [1, None, 2], "float": [0.0, None, 1.1], - "double": [0.0, None, 1.1], + "ddouble": [0.0, None, 1.1], "decimal": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))], "string_object": ["foo", None, "boo"], "string": ["foo", None, "boo"], @@ -56,7 +56,7 @@ def get_df_list(): "iint32": [1, None, 2], "iint64": [1, None, 2], "float": [0.0, None, 1.1], - "double": [0.0, None, 1.1], + "ddouble": [0.0, None, 1.1], "decimal": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))], "string_object": ["foo", None, "boo"], "string": ["foo", None, "boo"], @@ -90,7 +90,7 @@ def get_df_cast(): "iint32": [None, None, None], "iint64": [None, None, None], "float": [None, None, None], - "double": [None, None, None], + "ddouble": [None, None, None], "decimal": [None, None, None], "string": [None, None, None], "date": [None, None, dt("2020-01-02")], @@ -201,7 +201,7 @@ def get_df_quicksight(): "iint32": [1, None, 2], "iint64": [1, None, 2], "float": [0.0, None, 1.1], - "double": [0.0, None, 1.1], + "ddouble": [0.0, None, 1.1], "decimal": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))], "string_object": ["foo", None, "boo"], "string": ["foo", None, "boo"], @@ -425,7 +425,7 @@ def ensure_data_types(df, has_list=False): assert str(df["iint32"].dtype).startswith("Int") assert str(df["iint64"].dtype) == "Int64" assert str(df["float"].dtype).startswith("float") - assert str(df["double"].dtype) == "float64" + assert str(df["ddouble"].dtype) == "float64" assert str(df["decimal"].dtype) == "object" if "string_object" in df.columns: assert str(df["string_object"].dtype) == "string" diff --git a/tests/conftest.py b/tests/conftest.py index 39eaef776..011fccfca 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -129,7 +129,7 @@ def workgroup3(bucket, kms_key): @pytest.fixture(scope="session") def databases_parameters(cloudformation_outputs): - parameters = dict(postgresql={}, mysql={}, redshift={}) + parameters = dict(postgresql={}, mysql={}, redshift={}, sqlserver={}) parameters["postgresql"]["host"] = cloudformation_outputs["PostgresqlAddress"] parameters["postgresql"]["port"] = 3306 parameters["postgresql"]["schema"] = "public" @@ -146,6 +146,10 @@ def databases_parameters(cloudformation_outputs): parameters["redshift"]["role"] = cloudformation_outputs["RedshiftRole"] parameters["password"] = cloudformation_outputs["DatabasesPassword"] parameters["user"] = "test" + parameters["sqlserver"]["host"] = cloudformation_outputs["SqlServerAddress"] + parameters["sqlserver"]["port"] = 1433 + parameters["sqlserver"]["schema"] = "dbo" + parameters["sqlserver"]["database"] = "test" return parameters @@ -236,6 +240,18 @@ def mysql_table(): con.close() +@pytest.fixture(scope="function") +def sqlserver_table(): + name = f"tbl_{get_time_str_with_random_suffix()}" + print(f"Table name: {name}") + yield name + con = wr.sqlserver.connect("aws-data-wrangler-sqlserver") + with con.cursor() as cursor: + cursor.execute(f"IF OBJECT_ID(N'dbo.{name}', N'U') IS NOT NULL DROP TABLE dbo.{name}") + con.commit() + con.close() + + @pytest.fixture(scope="function") def 
timestream_database_and_table(): name = f"tbl_{get_time_str_with_random_suffix()}" diff --git a/tests/test_athena.py b/tests/test_athena.py index a8657569e..8aedc8bb5 100644 --- a/tests/test_athena.py +++ b/tests/test_athena.py @@ -163,7 +163,7 @@ def test_athena(path, glue_database, glue_table, kms_key, workgroup0, workgroup1 f" `iint32` int," f" `iint64` bigint," f" `float` float," - f" `double` double," + f" `ddouble` double," f" `decimal` decimal(3,2)," f" `string_object` string," f" `string` string," diff --git a/tests/test_athena_parquet.py b/tests/test_athena_parquet.py index 348492b34..cfde1c760 100644 --- a/tests/test_athena_parquet.py +++ b/tests/test_athena_parquet.py @@ -112,7 +112,7 @@ def test_parquet_catalog_casting(path, glue_database): "iint32": "int", "iint64": "bigint", "float": "float", - "double": "double", + "ddouble": "double", "decimal": "decimal(3,2)", "string": "string", "date": "date", @@ -408,7 +408,7 @@ def test_parquet_catalog_casting_to_string(path, glue_table, glue_database): "iint32": "string", "iint64": "string", "float": "string", - "double": "string", + "ddouble": "string", "decimal": "string", "string": "string", "date": "string", diff --git a/tests/test_mysql.py b/tests/test_mysql.py index b7746e354..6cccbcfb6 100644 --- a/tests/test_mysql.py +++ b/tests/test_mysql.py @@ -63,7 +63,7 @@ def test_sql_types(mysql_table): "iint32": pa.int32(), "iint64": pa.int64(), "float": pa.float32(), - "double": pa.float64(), + "ddouble": pa.float64(), "decimal": pa.decimal128(3, 2), "string_object": pa.string(), "string": pa.string(), diff --git a/tests/test_postgresql.py b/tests/test_postgresql.py index 0d00e0a4e..c697a3685 100644 --- a/tests/test_postgresql.py +++ b/tests/test_postgresql.py @@ -63,7 +63,7 @@ def test_sql_types(postgresql_table): "iint32": pa.int32(), "iint64": pa.int64(), "float": pa.float32(), - "double": pa.float64(), + "ddouble": pa.float64(), "decimal": pa.decimal128(3, 2), "string_object": pa.string(), "string": pa.string(), diff --git a/tests/test_redshift.py b/tests/test_redshift.py index 528ac38ad..160fa42dc 100644 --- a/tests/test_redshift.py +++ b/tests/test_redshift.py @@ -67,7 +67,7 @@ def test_sql_types(redshift_table): "iint32": pa.int32(), "iint64": pa.int64(), "float": pa.float32(), - "double": pa.float64(), + "ddouble": pa.float64(), "decimal": pa.decimal128(3, 2), "string_object": pa.string(), "string": pa.string(), diff --git a/tests/test_sqlserver.py b/tests/test_sqlserver.py new file mode 100644 index 000000000..26ee846c9 --- /dev/null +++ b/tests/test_sqlserver.py @@ -0,0 +1,205 @@ +import logging +from decimal import Decimal + +import pandas as pd +import pyarrow as pa +import pyodbc +import pytest + +import awswrangler as wr + +from ._utils import ensure_data_types, get_df + +logging.getLogger("awswrangler").setLevel(logging.DEBUG) + + +@pytest.fixture(scope="module", autouse=True) +def create_sql_server_database(databases_parameters): + connection_str = ( + f"DRIVER={{ODBC Driver 17 for SQL Server}};" + f"SERVER={databases_parameters['sqlserver']['host']},{databases_parameters['sqlserver']['port']};" + f"UID={databases_parameters['user']};" + f"PWD={databases_parameters['password']}" + ) + con = pyodbc.connect(connection_str, autocommit=True) + sql_create_db = ( + f"IF NOT EXISTS(SELECT * FROM sys.databases WHERE name = '{databases_parameters['sqlserver']['database']}') " + "BEGIN " + f"CREATE DATABASE {databases_parameters['sqlserver']['database']} " + "END" + ) + with con.cursor() as cursor: + cursor.execute(sql_create_db) 
+ con.commit() + + yield + + sql_drop_db = ( + f"IF EXISTS (SELECT * FROM sys.databases WHERE name = '{databases_parameters['sqlserver']['database']}') " + "BEGIN " + "USE master; " + f"ALTER DATABASE {databases_parameters['sqlserver']['database']} SET SINGLE_USER WITH ROLLBACK IMMEDIATE; " + f"DROP DATABASE {databases_parameters['sqlserver']['database']}; " + "END" + ) + with con.cursor() as cursor: + cursor.execute(sql_drop_db) + con.commit() + + con.close() + + +@pytest.fixture(scope="function") +def sqlserver_con(): + con = wr.sqlserver.connect("aws-data-wrangler-sqlserver") + yield con + con.close() + + +def test_connection(): + wr.sqlserver.connect("aws-data-wrangler-sqlserver", timeout=10).close() + + +def test_read_sql_query_simple(databases_parameters, sqlserver_con): + df = wr.sqlserver.read_sql_query("SELECT 1", con=sqlserver_con) + assert df.shape == (1, 1) + + +def test_to_sql_simple(sqlserver_table, sqlserver_con): + df = pd.DataFrame({"c0": [1, 2, 3], "c1": ["foo", "boo", "bar"]}) + wr.sqlserver.to_sql(df, sqlserver_con, sqlserver_table, "dbo", "overwrite", True) + + +def test_sql_types(sqlserver_table, sqlserver_con): + table = sqlserver_table + df = get_df() + df.drop(["binary"], axis=1, inplace=True) + wr.sqlserver.to_sql( + df=df, + con=sqlserver_con, + table=table, + schema="dbo", + mode="overwrite", + index=True, + dtype={"iint32": "INTEGER"}, + ) + df = wr.sqlserver.read_sql_query(f"SELECT * FROM dbo.{table}", sqlserver_con) + ensure_data_types(df, has_list=False) + dfs = wr.sqlserver.read_sql_query( + sql=f"SELECT * FROM dbo.{table}", + con=sqlserver_con, + chunksize=1, + dtype={ + "iint8": pa.int8(), + "iint16": pa.int16(), + "iint32": pa.int32(), + "iint64": pa.int64(), + "float": pa.float32(), + "ddouble": pa.float64(), + "decimal": pa.decimal128(3, 2), + "string_object": pa.string(), + "string": pa.string(), + "date": pa.date32(), + "timestamp": pa.timestamp(unit="ns"), + "binary": pa.binary(), + "category": pa.float64(), + }, + ) + for df in dfs: + ensure_data_types(df, has_list=False) + + +def test_to_sql_cast(sqlserver_table, sqlserver_con): + table = sqlserver_table + df = pd.DataFrame( + { + "col": [ + "".join([str(i)[-1] for i in range(1_024)]), + "".join([str(i)[-1] for i in range(1_024)]), + "".join([str(i)[-1] for i in range(1_024)]), + ] + }, + dtype="string", + ) + wr.sqlserver.to_sql( + df=df, + con=sqlserver_con, + table=table, + schema="dbo", + mode="overwrite", + index=False, + dtype={"col": "VARCHAR(1024)"}, + ) + df2 = wr.sqlserver.read_sql_query(sql=f"SELECT * FROM dbo.{table}", con=sqlserver_con) + assert df.equals(df2) + + +def test_null(sqlserver_table, sqlserver_con): + table = sqlserver_table + df = pd.DataFrame({"id": [1, 2, 3], "nothing": [None, None, None]}) + wr.sqlserver.to_sql( + df=df, + con=sqlserver_con, + table=table, + schema="dbo", + mode="overwrite", + index=False, + dtype={"nothing": "INTEGER"}, + ) + wr.sqlserver.to_sql( + df=df, + con=sqlserver_con, + table=table, + schema="dbo", + mode="append", + index=False, + ) + df2 = wr.sqlserver.read_sql_table(table=table, schema="dbo", con=sqlserver_con) + df["id"] = df["id"].astype("Int64") + assert pd.concat(objs=[df, df], ignore_index=True).equals(df2) + + +def test_decimal_cast(sqlserver_table, sqlserver_con): + table = sqlserver_table + df = pd.DataFrame( + { + "col0": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))], + "col1": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), -2))], + "col2": [Decimal((0, (1, 9, 9), -2)), None, Decimal((0, (1, 9, 0), 
-2))], + } + ) + wr.sqlserver.to_sql(df, sqlserver_con, table, "dbo") + df2 = wr.sqlserver.read_sql_table( + schema="dbo", table=table, con=sqlserver_con, dtype={"col0": "float32", "col1": "float64", "col2": "Int64"} + ) + assert df2.dtypes.to_list() == ["float32", "float64", "Int64"] + assert 3.88 <= df2.col0.sum() <= 3.89 + assert 3.88 <= df2.col1.sum() <= 3.89 + assert df2.col2.sum() == 2 + + +def test_read_retry(sqlserver_con): + try: + wr.sqlserver.read_sql_query("ERROR", sqlserver_con) + except: # noqa + pass + df = wr.sqlserver.read_sql_query("SELECT 1", sqlserver_con) + assert df.shape == (1, 1) + + +def test_table_name(sqlserver_con): + df = pd.DataFrame({"col0": [1]}) + wr.sqlserver.to_sql(df, sqlserver_con, "Test Name", "dbo", mode="overwrite") + df = wr.sqlserver.read_sql_table(schema="dbo", con=sqlserver_con, table="Test Name") + assert df.shape == (1, 1) + with sqlserver_con.cursor() as cursor: + cursor.execute('DROP TABLE "Test Name"') + sqlserver_con.commit() + + +@pytest.mark.parametrize("dbname", [None, "test"]) +def test_connect_secret_manager(dbname): + con = wr.sqlserver.connect(secret_id="aws-data-wrangler/sqlserver", dbname=dbname) + df = wr.sqlserver.read_sql_query("SELECT 1", con=con) + con.close() + assert df.shape == (1, 1) diff --git a/tutorials/007 - Redshift, MySQL, PostgreSQL.ipynb b/tutorials/007 - Redshift, MySQL, PostgreSQL, SQL Server.ipynb similarity index 85% rename from tutorials/007 - Redshift, MySQL, PostgreSQL.ipynb rename to tutorials/007 - Redshift, MySQL, PostgreSQL, SQL Server.ipynb index 6da12b68f..8a5e36255 100644 --- a/tutorials/007 - Redshift, MySQL, PostgreSQL.ipynb +++ b/tutorials/007 - Redshift, MySQL, PostgreSQL, SQL Server.ipynb @@ -6,7 +6,7 @@ "source": [ "[![AWS Data Wrangler](_static/logo.png \"AWS Data Wrangler\")](https://github.com/awslabs/aws-data-wrangler)\n", "\n", - "# 7 - Redshift, MySQL and PostgreSQL\n", + "# 7 - Redshift, MySQL, PostgreSQL and SQL Server\n", "\n", "[Wrangler](https://github.com/awslabs/aws-data-wrangler)'s Redshift, MySQL and PostgreSQL have two basic function in common that tries to follow the Pandas conventions, but add more data type consistency.\n", "\n", @@ -15,7 +15,9 @@ "- [wr.mysql.to_sql()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.mysql.to_sql.html)\n", "- [wr.mysql.read_sql_query()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.mysql.read_sql_query.html)\n", "- [wr.postgresql.to_sql()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.postgresql.to_sql.html)\n", - "- [wr.postgresql.read_sql_query()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.postgresql.read_sql_query.html)" + "- [wr.postgresql.read_sql_query()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.postgresql.read_sql_query.html)\n", + "- [wr.sqlserver.to_sql()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.sqlserver.to_sql.html)\n", + "- [wr.sqlserver.read_sql_query()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.sqlserver.read_sql_query.html)" ] }, { @@ -41,7 +43,8 @@ "\n", "- [wr.redshift.connect()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.redshift.connect.html)\n", "- [wr.mysql.connect()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.mysql.connect.html)\n", - "- [wr.postgresql.connect()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.postgresql.connect.html)" + "- 
[wr.postgresql.connect()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.postgresql.connect.html)\n", + "- [wr.sqlserver.connect()](https://aws-data-wrangler.readthedocs.io/en/stable/stubs/awswrangler.sqlserver.connect.html)" ] }, { @@ -52,7 +55,8 @@ "source": [ "con_redshift = wr.redshift.connect(\"aws-data-wrangler-redshift\")\n", "con_mysql = wr.mysql.connect(\"aws-data-wrangler-mysql\")\n", - "con_postgresql = wr.postgresql.connect(\"aws-data-wrangler-postgresql\")" + "con_postgresql = wr.postgresql.connect(\"aws-data-wrangler-postgresql\")\n", + "con_sqlserver = wr.sqlserver.connect(\"aws-data-wrangler-sqlserver\")" ] }, { @@ -96,7 +100,8 @@ "source": [ "wr.redshift.to_sql(df, con_redshift, schema=\"public\", table=\"tutorial\", mode=\"overwrite\")\n", "wr.mysql.to_sql(df, con_mysql, schema=\"test\", table=\"tutorial\", mode=\"overwrite\")\n", - "wr.postgresql.to_sql(df, con_postgresql, schema=\"public\", table=\"tutorial\", mode=\"overwrite\")" + "wr.postgresql.to_sql(df, con_postgresql, schema=\"public\", table=\"tutorial\", mode=\"overwrite\")\n", + "wr.sqlserver.to_sql(df, con_sqlserver, schema=\"dbo\", table=\"tutorial\", mode=\"overwrite\")" ] }, { @@ -165,7 +170,8 @@ "source": [ "wr.redshift.read_sql_query(\"SELECT * FROM public.tutorial\", con=con_redshift)\n", "wr.mysql.read_sql_query(\"SELECT * FROM test.tutorial\", con=con_mysql)\n", - "wr.postgresql.read_sql_query(\"SELECT * FROM public.tutorial\", con=con_postgresql)" + "wr.postgresql.read_sql_query(\"SELECT * FROM public.tutorial\", con=con_postgresql)\n", + "wr.sqlserver.read_sql_query(\"SELECT * FROM dbo.tutorial\", con=con_sqlserver)" ] }, { @@ -176,7 +182,8 @@ "source": [ "con_redshift.close()\n", "con_mysql.close()\n", - "con_postgresql.close()" + "con_postgresql.close()\n", + "con_sqlserver.close()" ] } ], @@ -197,17 +204,8 @@ "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.10" - }, - "pycharm": { - "stem_cell": { - "cell_type": "raw", - "metadata": { - "collapsed": false - }, - "source": [] - } } }, "nbformat": 4, "nbformat_minor": 4 -} +} \ No newline at end of file
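To close, an end-to-end sketch of the new sqlserver module as exercised by the updated tutorial and tests. The Glue connection name is the one provisioned by cloudformation/databases.yaml in this change; adapt it, and the schema/table names, to your own environment.

    # Round trip against SQL Server: write a DataFrame, read it back, close the connection.
    import pandas as pd
    import awswrangler as wr

    df = pd.DataFrame({"id": [1, 2], "name": ["foo", "boo"]})

    con = wr.sqlserver.connect("aws-data-wrangler-sqlserver")
    wr.sqlserver.to_sql(df, con, schema="dbo", table="tutorial", mode="overwrite")
    out = wr.sqlserver.read_sql_query("SELECT * FROM dbo.tutorial", con=con)
    con.close()
    print(out)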