Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 52 additions & 17 deletions awswrangler/aurora.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import TYPE_CHECKING, Union, List, Dict, Tuple, Any
from typing import TYPE_CHECKING, Union, List, Dict, Tuple, Any, Optional
from logging import getLogger, Logger, INFO
import json
import warnings
Expand Down Expand Up @@ -137,6 +137,7 @@ def load_table(dataframe: pd.DataFrame,
table_name: str,
connection: Any,
num_files: int,
columns: Optional[List[str]] = None,
mode: str = "append",
preserve_index: bool = False,
engine: str = "mysql",
Expand All @@ -152,6 +153,7 @@ def load_table(dataframe: pd.DataFrame,
:param table_name: Aurora table name
:param connection: A PEP 249 compatible connection (Can be generated with Aurora.generate_connection())
:param num_files: Number of files to be loaded
:param columns: List of columns to load
:param mode: append or overwrite
:param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
:param engine: "mysql" or "postgres"
Expand All @@ -167,7 +169,8 @@ def load_table(dataframe: pd.DataFrame,
connection=connection,
mode=mode,
preserve_index=preserve_index,
region=region)
region=region,
columns=columns)
elif "mysql" in engine.lower():
Aurora.load_table_mysql(dataframe=dataframe,
dataframe_type=dataframe_type,
Expand All @@ -177,7 +180,8 @@ def load_table(dataframe: pd.DataFrame,
connection=connection,
mode=mode,
preserve_index=preserve_index,
num_files=num_files)
num_files=num_files,
columns=columns)
else:
raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")

Expand All @@ -190,7 +194,8 @@ def load_table_postgres(dataframe: pd.DataFrame,
connection: Any,
mode: str = "append",
preserve_index: bool = False,
region: str = "us-east-1"):
region: str = "us-east-1",
columns: Optional[List[str]] = None):
"""
Load text/CSV files into an Aurora table using a manifest file.
Creates the table if necessary.
Expand All @@ -204,6 +209,7 @@ def load_table_postgres(dataframe: pd.DataFrame,
:param mode: append or overwrite
:param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
:param region: AWS S3 bucket region (Required only for postgres engine)
:param columns: List of columns to load
:return: None
"""
with connection.cursor() as cursor:
Expand All @@ -214,15 +220,17 @@ def load_table_postgres(dataframe: pd.DataFrame,
schema_name=schema_name,
table_name=table_name,
preserve_index=preserve_index,
engine="postgres")
engine="postgres",
columns=columns)
connection.commit()
logger.debug("CREATE TABLE committed.")
for path in load_paths:
sql = Aurora._get_load_sql(path=path,
schema_name=schema_name,
table_name=table_name,
engine="postgres",
region=region)
region=region,
columns=columns)
Aurora._load_object_postgres_with_retry(connection=connection, sql=sql)
logger.debug(f"Load committed for: {path}.")

Expand Down Expand Up @@ -257,7 +265,8 @@ def load_table_mysql(dataframe: pd.DataFrame,
connection: Any,
num_files: int,
mode: str = "append",
preserve_index: bool = False):
preserve_index: bool = False,
columns: Optional[List[str]] = None):
"""
Load text/CSV files into an Aurora table using a manifest file.
Creates the table if necessary.
Expand All @@ -271,6 +280,7 @@ def load_table_mysql(dataframe: pd.DataFrame,
:param num_files: Number of files to be loaded
:param mode: append or overwrite
:param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
:param columns: List of columns to load
:return: None
"""
with connection.cursor() as cursor:
Expand All @@ -281,11 +291,13 @@ def load_table_mysql(dataframe: pd.DataFrame,
schema_name=schema_name,
table_name=table_name,
preserve_index=preserve_index,
engine="mysql")
engine="mysql",
columns=columns)
sql = Aurora._get_load_sql(path=manifest_path,
schema_name=schema_name,
table_name=table_name,
engine="mysql")
engine="mysql",
columns=columns)
logger.debug(sql)
cursor.execute(sql)
logger.debug(f"Load done for: {manifest_path}")
Expand All @@ -310,22 +322,40 @@ def _parse_path(path):
return parts[0], parts[2]

@staticmethod
def _get_load_sql(path: str,
                  schema_name: str,
                  table_name: str,
                  engine: str,
                  region: str = "us-east-1",
                  columns: Optional[List[str]] = None) -> str:
    """
    Build the SQL statement that loads S3 data into an Aurora table.

    :param path: S3 path of the object to load (postgres) or of the manifest file (mysql)
    :param schema_name: Aurora schema name
    :param table_name: Aurora table name
    :param engine: "mysql" or "postgres"
    :param region: AWS S3 bucket region (only used by the postgres statement)
    :param columns: List of columns to load (None or an empty list loads every column)
    :return: The SQL statement
    :raises InvalidEngine: If engine is neither "mysql" nor "postgres"
    """
    engine_lower: str = engine.lower()
    if "postgres" in engine_lower:
        bucket, key = Aurora._parse_path(path=path)
        # An empty column list tells aws_s3.table_import_from_s3 to load all columns.
        cols_str: str = ",".join(columns) if columns else ""
        sql: str = ("-- AWS DATA WRANGLER\n"
                    "SELECT aws_s3.table_import_from_s3(\n"
                    f"'{schema_name}.{table_name}',\n"
                    f"'{cols_str}',\n"
                    "'(FORMAT CSV, DELIMITER '','', QUOTE ''\"'', ESCAPE ''\"'')',\n"
                    f"'({bucket},{key},{region})')")
    elif "mysql" in engine_lower:
        if columns:
            # building something like: (@col1,@col2) SET col1=@col1,col2=@col2
            user_vars: List[str] = [f"@{col}" for col in columns]
            assignments: List[str] = [f"{col}=@{col}" for col in columns]
            cols_str = f"({','.join(user_vars)}) SET {','.join(assignments)}"
        else:
            # None or []: emit no column clause at all; "() SET " would be invalid SQL.
            cols_str = ""
        logger.debug(f"cols_str: {cols_str}")
        # Keep the column clause on its own line instead of gluing it to the LINES clause.
        cols_clause: str = f"\n{cols_str}" if cols_str else ""
        sql = ("-- AWS DATA WRANGLER\n"
               f"LOAD DATA FROM S3 MANIFEST '{path}'\n"
               "REPLACE\n"
               f"INTO TABLE {schema_name}.{table_name}\n"
               "FIELDS TERMINATED BY ',' OPTIONALLY ENCLOSED BY '\"' ESCAPED BY '\"'\n"
               "LINES TERMINATED BY '\\n'"
               f"{cols_clause}")
    else:
        raise InvalidEngine(f"{engine} is not a valid engine. Please use 'mysql' or 'postgres'!")
    return sql
Expand All @@ -337,7 +367,8 @@ def _create_table(cursor,
schema_name,
table_name,
preserve_index=False,
engine: str = "mysql"):
engine: str = "mysql",
columns: Optional[List[str]] = None):
"""
Creates Aurora table.

Expand All @@ -348,6 +379,7 @@ def _create_table(cursor,
:param table_name: Aurora table name
:param preserve_index: Should we preserve the Dataframe index? (ONLY for Pandas Dataframe)
:param engine: "mysql" or "postgres"
:param columns: List of columns to load
:return: None
"""
sql: str = f"-- AWS DATA WRANGLER\n" \
Expand All @@ -364,7 +396,8 @@ def _create_table(cursor,
schema = Aurora._get_schema(dataframe=dataframe,
dataframe_type=dataframe_type,
preserve_index=preserve_index,
engine=engine)
engine=engine,
columns=columns)
cols_str: str = "".join([f"{col[0]} {col[1]},\n" for col in schema])[:-2]
sql = f"-- AWS DATA WRANGLER\n" f"CREATE TABLE IF NOT EXISTS {schema_name}.{table_name} (\n" f"{cols_str})"
logger.debug(f"Create table query:\n{sql}")
Expand All @@ -374,7 +407,8 @@ def _create_table(cursor,
def _get_schema(dataframe,
dataframe_type: str,
preserve_index: bool,
engine: str = "mysql") -> List[Tuple[str, str]]:
engine: str = "mysql",
columns: Optional[List[str]] = None) -> List[Tuple[str, str]]:
schema_built: List[Tuple[str, str]] = []
if "postgres" in engine.lower():
convert_func = data_types.pyarrow2postgres
Expand All @@ -386,8 +420,9 @@ def _get_schema(dataframe,
pyarrow_schema: List[Tuple[str, str]] = data_types.extract_pyarrow_schema_from_pandas(
dataframe=dataframe, preserve_index=preserve_index, indexes_position="right")
for name, dtype in pyarrow_schema:
aurora_type: str = convert_func(dtype)
schema_built.append((name, aurora_type))
if columns is None or name in columns:
aurora_type: str = convert_func(dtype)
schema_built.append((name, aurora_type))
else:
raise InvalidDataframeType(f"{dataframe_type} is not a valid DataFrame type. Please use 'pandas'!")
return schema_built
Expand Down
8 changes: 4 additions & 4 deletions awswrangler/data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,8 @@ def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame,
:param indexes_position: "right" or "left"
:return: Pyarrow schema (e.g. [("col name", "bigint"), ("col2 name", "int")])
"""
cols = []
cols_dtypes = {}
cols: List[str] = []
cols_dtypes: Dict[str, str] = {}
if indexes_position not in ("right", "left"):
raise ValueError(f"indexes_position must be \"right\" or \"left\"")

Expand All @@ -384,10 +384,10 @@ def extract_pyarrow_schema_from_pandas(dataframe: pd.DataFrame,
cols.append(name)

# Filling cols_dtypes and indexes
indexes = []
indexes: List[str] = []
for field in pa.Schema.from_pandas(df=dataframe[cols], preserve_index=preserve_index):
name = str(field.name)
dtype = field.type
dtype = str(field.type)
cols_dtypes[name] = dtype
if name not in dataframe.columns:
indexes.append(name)
Expand Down
Loading