Merged

24 commits
838634c
WIP: Add support for SQL Server
maxispeicher Dec 31, 2020
30c1f03
WIP: SQL Server feature complete
maxispeicher Dec 31, 2020
0ecadbc
WIP: Adapt databases cfn template
maxispeicher Dec 31, 2020
0e02973
WIP: Add docstrings and formatting
maxispeicher Dec 31, 2020
2f749a9
Fix raising of exceptions
maxispeicher Dec 31, 2020
e58b006
Adapt README and documentation
maxispeicher Dec 31, 2020
ec4b7b1
Decode password to string
maxispeicher Dec 31, 2020
80afc70
WIP: Fix SQLServer tests
maxispeicher Dec 31, 2020
1ce7c31
WIP: Fix cfn template
maxispeicher Dec 31, 2020
256e8c0
Fix tests for Linux
maxispeicher Jan 1, 2021
eb18b55
Add missing ;
maxispeicher Jan 1, 2021
878ee71
Merge branch 'master' into add-sql-server-support
maxispeicher Jan 2, 2021
721baec
Swap from pymssql to pyodbc
maxispeicher Jan 2, 2021
f1f3694
Dynamically import pyodbc
maxispeicher Jan 2, 2021
c5ea2ef
Merge branch 'master' into add-sql-server-support
maxispeicher Jan 3, 2021
0480992
Add pyodbc to Lambda layer
maxispeicher Jan 3, 2021
ad3253e
Merge branch 'master' into swap-to-pyodbc
maxispeicher Jan 3, 2021
e892fc7
Fix for 3.6 and 3.7
maxispeicher Jan 3, 2021
6955b54
Fix for 3.8
maxispeicher Jan 3, 2021
4f7edd3
Update documentation
maxispeicher Jan 4, 2021
2bc1c8f
Revert changes to Lambda layer build
maxispeicher Jan 4, 2021
c2bde90
Merge branch 'swap-to-pyodbc' into add-sql-server-support
maxispeicher Jan 4, 2021
d4b1d89
Merge branch 'master' into add-sql-server-support
maxispeicher Jan 4, 2021
d869eeb
Fix formatting in docstrings
maxispeicher Jan 4, 2021
2 changes: 1 addition & 1 deletion .pylintrc
@@ -3,7 +3,7 @@
# A comma-separated list of package or module names from where C extensions may
# be loaded. Extensions are loading into the active Python interpreter and may
# run arbitrary code.
-extension-pkg-whitelist=pyarrow.lib
+extension-pkg-whitelist=pyarrow.lib,pyodbc

# Specify a score threshold to be exceeded before program exits with error.
fail-under=10
3 changes: 2 additions & 1 deletion README.md
@@ -105,7 +105,7 @@ FROM "sampleDB"."sampleTable" ORDER BY time DESC LIMIT 3
- [004 - Parquet Datasets](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/004%20-%20Parquet%20Datasets.ipynb)
- [005 - Glue Catalog](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/005%20-%20Glue%20Catalog.ipynb)
- [006 - Amazon Athena](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/006%20-%20Amazon%20Athena.ipynb)
-- [007 - Databases (Redshift, MySQL and PostgreSQL)](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/007%20-%20Redshift%2C%20MySQL%2C%20PostgreSQL.ipynb)
+- [007 - Databases (Redshift, MySQL, PostgreSQL and SQL Server)](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/007%20-%20Redshift%2C%20MySQL%2C%20PostgreSQL%2C%20SQL%20Server.ipynb)
- [008 - Redshift - Copy & Unload.ipynb](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/008%20-%20Redshift%20-%20Copy%20%26%20Unload.ipynb)
- [009 - Redshift - Append, Overwrite and Upsert](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/009%20-%20Redshift%20-%20Append%2C%20Overwrite%2C%20Upsert.ipynb)
- [010 - Parquet Crawler](https://github.com/awslabs/aws-data-wrangler/blob/master/tutorials/010%20-%20Parquet%20Crawler.ipynb)
@@ -134,6 +134,7 @@ FROM "sampleDB"."sampleTable" ORDER BY time DESC LIMIT 3
- [Amazon Redshift](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#amazon-redshift)
- [PostgreSQL](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#postgresql)
- [MySQL](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#mysql)
+- [SQL Server](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#sqlserver)
- [DynamoDB](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#dynamodb)
- [Amazon Timestream](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#amazon-timestream)
- [Amazon EMR](https://aws-data-wrangler.readthedocs.io/en/stable/api.html#amazon-emr)
2 changes: 2 additions & 0 deletions awswrangler/__init__.py
@@ -21,6 +21,7 @@
redshift,
s3,
secretsmanager,
sqlserver,
sts,
timestream,
)
@@ -42,6 +43,7 @@
"mysql",
"postgresql",
"secretsmanager",
"sqlserver",
"config",
"timestream",
"__description__",
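With the module registered in `__init__.py`, usage presumably mirrors the existing `mysql` and `postgresql` modules. A minimal sketch (the connection name, schema, and table are hypothetical; the `wr.sqlserver.*` signatures are assumed to match the MySQL/PostgreSQL counterparts documented later in this diff):

```python
import pandas as pd
import awswrangler as wr

# Assumed API, analogous to wr.mysql.connect / wr.postgresql.connect:
# resolves host, port, database, and credentials from a Glue Catalog connection.
con = wr.sqlserver.connect("MY_GLUE_CONNECTION")

# Write a DataFrame to SQL Server.
wr.sqlserver.to_sql(
    df=pd.DataFrame({"id": [1, 2], "name": ["foo", "boo"]}),
    con=con,
    schema="dbo",
    table="my_table",
)

# Read it back. chunksize=None returns one DataFrame; an int returns an
# iterator of DataFrames (see the _databases.py changes below).
df = wr.sqlserver.read_sql_query("SELECT * FROM dbo.my_table", con=con)

con.close()
```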
4 changes: 3 additions & 1 deletion awswrangler/_config.py
@@ -154,7 +154,9 @@ def _apply_type(name: str, value: Any, dtype: Type[Union[str, bool, int]], nulla
if _Config._is_null(value=value):
if nullable is True:
return None
exceptions.InvalidArgumentValue(f"{name} configuration does not accept a null value. Please pass {dtype}.")
raise exceptions.InvalidArgumentValue(
f"{name} configuration does not accept a null value. Please pass {dtype}."
)
try:
return dtype(value) if isinstance(value, dtype) is False else value
except ValueError as ex:
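The `_config.py` change fixes a subtle bug: the exception was instantiated but never raised, so a null value sailed straight through the guard. The same missing-`raise` pattern is fixed in `mysql.py`, `postgresql.py`, and `redshift.py` below. A standalone illustration (not library code):

```python
class InvalidArgumentValue(Exception):
    pass

def check_before_fix(value):
    if value is None:
        # Bug: constructs the exception object, then discards it. Nothing is raised.
        InvalidArgumentValue("does not accept a null value")
    return value

def check_after_fix(value):
    if value is None:
        raise InvalidArgumentValue("does not accept a null value")
    return value

check_before_fix(None)  # silently returns None
check_after_fix(None)   # raises InvalidArgumentValue
```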
35 changes: 35 additions & 0 deletions awswrangler/_data_types.py
@@ -166,6 +166,41 @@ def pyarrow2postgresql( # pylint: disable=too-many-branches,too-many-return-sta
raise exceptions.UnsupportedType(f"Unsupported PostgreSQL type: {dtype}")


def pyarrow2sqlserver( # pylint: disable=too-many-branches,too-many-return-statements
dtype: pa.DataType, string_type: str
) -> str:
"""Pyarrow to Microsoft SQL Server data types conversion."""
if pa.types.is_int8(dtype):
return "SMALLINT"
if pa.types.is_int16(dtype) or pa.types.is_uint8(dtype):
return "SMALLINT"
if pa.types.is_int32(dtype) or pa.types.is_uint16(dtype):
return "INT"
if pa.types.is_int64(dtype) or pa.types.is_uint32(dtype):
return "BIGINT"
if pa.types.is_uint64(dtype):
raise exceptions.UnsupportedType("There is no support for uint64, please consider int64 or uint32.")
if pa.types.is_float32(dtype):
return "FLOAT(24)"
if pa.types.is_float64(dtype):
return "FLOAT"
if pa.types.is_boolean(dtype):
return "BIT"
if pa.types.is_string(dtype):
return string_type
if pa.types.is_timestamp(dtype):
return "DATETIME2"
if pa.types.is_date(dtype):
return "DATE"
if pa.types.is_decimal(dtype):
return f"DECIMAL({dtype.precision},{dtype.scale})"
if pa.types.is_dictionary(dtype):
return pyarrow2sqlserver(dtype=dtype.value_type, string_type=string_type)
if pa.types.is_binary(dtype):
return "VARBINARY"
raise exceptions.UnsupportedType(f"Unsupported PostgreSQL type: {dtype}")


def pyarrow2timestream(dtype: pa.DataType) -> str: # pylint: disable=too-many-branches,too-many-return-statements
"""Pyarrow to Amazon Timestream data types conversion."""
if pa.types.is_int8(dtype):
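A quick sanity check of the new mapping as I read it. Note that signed `int8` has to widen to `SMALLINT` because SQL Server's `TINYINT` is unsigned (0-255). The import path is the private module touched in this diff, and `string_type="VARCHAR(MAX)"` is just a plausible argument, not necessarily the library's default:

```python
import pyarrow as pa

from awswrangler._data_types import pyarrow2sqlserver  # private module, per this diff

for dtype in (pa.int8(), pa.int32(), pa.bool_(), pa.timestamp("ns"), pa.decimal128(10, 2)):
    print(dtype, "->", pyarrow2sqlserver(dtype=dtype, string_type="VARCHAR(MAX)"))
# int8 -> SMALLINT
# int32 -> INT
# bool -> BIT
# timestamp[ns] -> DATETIME2
# decimal128(10, 2) -> DECIMAL(10,2)
```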
85 changes: 59 additions & 26 deletions awswrangler/_databases.py
@@ -36,7 +36,11 @@ def _get_connection_attributes_from_catalog(
details: Dict[str, Any] = get_connection(name=connection, catalog_id=catalog_id, boto3_session=boto3_session)[
"ConnectionProperties"
]
port, database = details["JDBC_CONNECTION_URL"].split(":")[3].split("/")
if ";databaseName=" in details["JDBC_CONNECTION_URL"]:
database_sep = ";databaseName="
else:
database_sep = "/"
port, database = details["JDBC_CONNECTION_URL"].split(":")[3].split(database_sep)
return ConnectionAttributes(
kind=details["JDBC_CONNECTION_URL"].split(":")[1].lower(),
user=details["USERNAME"],
@@ -136,19 +140,48 @@ def _records2df(
return df


-def _iterate_cursor(
-    cursor: Any,
+def _get_cols_names(cursor_description: Any) -> List[str]:
+    cols_names = [col[0].decode("utf-8") if isinstance(col[0], bytes) else col[0] for col in cursor_description]
+    _logger.debug("cols_names: %s", cols_names)
+
+    return cols_names
+
+
+def _iterate_results(
+    con: Any,
+    cursor_args: List[Any],
     chunksize: int,
-    cols_names: List[str],
-    index: Optional[Union[str, List[str]]],
+    index_col: Optional[Union[str, List[str]]],
     safe: bool,
     dtype: Optional[Dict[str, pa.DataType]],
 ) -> Iterator[pd.DataFrame]:
-    while True:
-        records = cursor.fetchmany(chunksize)
-        if not records:
-            break
-        yield _records2df(records=records, cols_names=cols_names, index=index, safe=safe, dtype=dtype)
+    with con.cursor() as cursor:
+        cursor.execute(*cursor_args)
+        cols_names = _get_cols_names(cursor.description)
+        while True:
+            records = cursor.fetchmany(chunksize)
+            if not records:
+                break
+            yield _records2df(records=records, cols_names=cols_names, index=index_col, safe=safe, dtype=dtype)
+
+
+def _fetch_all_results(
+    con: Any,
+    cursor_args: List[Any],
+    index_col: Optional[Union[str, List[str]]] = None,
+    dtype: Optional[Dict[str, pa.DataType]] = None,
+    safe: bool = True,
+) -> pd.DataFrame:
+    with con.cursor() as cursor:
+        cursor.execute(*cursor_args)
+        cols_names = _get_cols_names(cursor.description)
+        return _records2df(
+            records=cast(List[Tuple[Any]], cursor.fetchall()),
+            cols_names=cols_names,
+            index=index_col,
+            dtype=dtype,
+            safe=safe,
+        )


def read_sql_query(
@@ -163,23 +196,23 @@ def read_sql_query(
"""Read SQL Query (generic)."""
args = _convert_params(sql, params)
try:
-        with con.cursor() as cursor:
-            cursor.execute(*args)
-            cols_names: List[str] = [
-                col[0].decode("utf-8") if isinstance(col[0], bytes) else col[0] for col in cursor.description
-            ]
-            _logger.debug("cols_names: %s", cols_names)
-            if chunksize is None:
-                return _records2df(
-                    records=cast(List[Tuple[Any]], cursor.fetchall()),
-                    cols_names=cols_names,
-                    index=index_col,
-                    dtype=dtype,
-                    safe=safe,
-                )
-            return _iterate_cursor(
-                cursor=cursor, chunksize=chunksize, cols_names=cols_names, index=index_col, dtype=dtype, safe=safe
-            )
+        if chunksize is None:
+            return _fetch_all_results(
+                con=con,
+                cursor_args=args,
+                index_col=index_col,
+                dtype=dtype,
+                safe=safe,
+            )
+
+        return _iterate_results(
+            con=con,
+            cursor_args=args,
+            chunksize=chunksize,
+            index_col=index_col,
+            dtype=dtype,
+            safe=safe,
+        )
except Exception as ex:
con.rollback()
_logger.error(ex)
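Two things are going on in this file. First, the separator logic in `_get_connection_attributes_from_catalog`: SQL Server JDBC URLs carry the database as `;databaseName=...` rather than as a path segment. A standalone sketch of that parsing, with hypothetical URLs:

```python
def split_port_and_database(jdbc_url: str):
    # SQL Server:  jdbc:sqlserver://host:1433;databaseName=mydb
    # PostgreSQL:  jdbc:postgresql://host:5432/mydb
    database_sep = ";databaseName=" if ";databaseName=" in jdbc_url else "/"
    port, database = jdbc_url.split(":")[3].split(database_sep)
    return port, database

print(split_port_and_database("jdbc:sqlserver://host:1433;databaseName=mydb"))  # ('1433', 'mydb')
print(split_port_and_database("jdbc:postgresql://host:5432/mydb"))              # ('5432', 'mydb')
```

Second, the cursor refactor looks like a lifetime fix as much as a cleanup: the old `read_sql_query` opened the cursor in a `with` block and then returned the `_iterate_cursor` generator, so the `with` block had already exited (closing the cursor) before the caller consumed the first chunk. Moving cursor creation inside the `_iterate_results` generator keeps the cursor open for as long as the iteration runs.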
2 changes: 1 addition & 1 deletion awswrangler/catalog/_get.py
@@ -524,7 +524,7 @@ def get_connection(
client_kms = _utils.client(service_name="kms", session=boto3_session)
pwd = client_kms.decrypt(CiphertextBlob=base64.b64decode(res["ConnectionProperties"]["ENCRYPTED_PASSWORD"]))[
"Plaintext"
-    ]
+    ].decode("utf-8")
res["ConnectionProperties"]["PASSWORD"] = pwd
return cast(Dict[str, Any], res)

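Context for the one-liner: boto3's `KMS.Client.decrypt` returns the `Plaintext` field as `bytes`, so without the added `.decode("utf-8")` the Glue connection password would propagate as a bytes object into DB drivers that expect `str`. A minimal illustration (hypothetical value):

```python
plaintext = b"s3cr3t"  # shape of decrypt(...)["Plaintext"]
password = plaintext.decode("utf-8")
print(type(plaintext).__name__, type(password).__name__)  # bytes str
```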
11 changes: 4 additions & 7 deletions awswrangler/mysql.py
@@ -127,7 +127,7 @@ def connect(
connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session
)
if attrs.kind != "mysql":
exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a MySQL connection.)")
raise exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a MySQL connection.)")
return pymysql.connect(
user=attrs.user,
database=attrs.database,
@@ -156,8 +156,7 @@ def read_sql_query(
sql : str
SQL query.
con : pymysql.connections.Connection
Use pymysql.connect() to use "
"credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog.
Use pymysql.connect() to use credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog.
index_col : Union[str, List[str]], optional
Column(s) to set as index(MultiIndex).
params : Union[List, Tuple, Dict], optional
@@ -214,8 +213,7 @@ def read_sql_table(
table : str
Table name.
con : pymysql.connections.Connection
Use pymysql.connect() to use "
"credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog.
Use pymysql.connect() to use credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog.
schema : str, optional
Name of SQL schema in database to query.
Uses default schema if None.
@@ -276,8 +274,7 @@ def to_sql(
df : pandas.DataFrame
Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
con : pymysql.connections.Connection
Use pymysql.connect() to use "
"credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog.
Use pymysql.connect() to use credentials directly or wr.mysql.connect() to fetch it from the Glue Catalog.
table : str
Table name
schema : str
15 changes: 7 additions & 8 deletions awswrangler/postgresql.py
@@ -131,7 +131,9 @@ def connect(
connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session
)
if attrs.kind != "postgresql":
exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a postgresql connection.)")
raise exceptions.InvalidDatabaseType(
f"Invalid connection type ({attrs.kind}. It must be a postgresql connection.)"
)
return pg8000.connect(
user=attrs.user,
database=attrs.database,
@@ -160,8 +162,7 @@ def read_sql_query(
sql : str
SQL query.
con : pg8000.Connection
Use pg8000.connect() to use "
"credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog.
Use pg8000.connect() to use credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog.
index_col : Union[str, List[str]], optional
Column(s) to set as index(MultiIndex).
params : Union[List, Tuple, Dict], optional
@@ -218,8 +219,7 @@ def read_sql_table(
table : str
Table name.
con : pg8000.Connection
Use pg8000.connect() to use "
"credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog.
Use pg8000.connect() to use credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog.
schema : str, optional
Name of SQL schema in database to query (if database flavor supports this).
Uses default schema if None (default).
@@ -280,8 +280,7 @@ def to_sql(
df : pandas.DataFrame
Pandas DataFrame https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html
con : pg8000.Connection
Use pg8000.connect() to use "
"credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog.
Use pg8000.connect() to use credentials directly or wr.postgresql.connect() to fetch it from the Glue Catalog.
table : str
Table name
schema : str
@@ -310,7 +309,7 @@ def to_sql(
>>> import awswrangler as wr
>>> con = wr.postgresql.connect("MY_GLUE_CONNECTION")
>>> wr.postgresql.to_sql(
-    ...     df=df
+    ...     df=df,
... table="my_table",
... schema="public",
... con=con
4 changes: 3 additions & 1 deletion awswrangler/redshift.py
@@ -386,7 +386,9 @@ def connect(
connection=connection, secret_id=secret_id, catalog_id=catalog_id, dbname=dbname, boto3_session=boto3_session
)
if attrs.kind != "redshift":
exceptions.InvalidDatabaseType(f"Invalid connection type ({attrs.kind}. It must be a redshift connection.)")
raise exceptions.InvalidDatabaseType(
f"Invalid connection type ({attrs.kind}. It must be a redshift connection.)"
)
return redshift_connector.connect(
user=attrs.user,
database=attrs.database,