SqlToS3Operator: feat/ add max_rows_per_file parameter (#37055)
---------

Co-authored-by: Selim Chergui <selim.chergui@setec.com>
Co-authored-by: Jarek Potiuk <jarek@potiuk.com>
3 people committed Jan 30, 2024
1 parent c36c4db commit 8914e49
Showing 2 changed files with 83 additions and 3 deletions.
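For context, a minimal usage sketch of the ``max_rows_per_file`` parameter introduced by this commit (connection IDs, bucket, key, and the row limit below are illustrative placeholders, not values taken from the change):

from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator

# Export the query result to S3, splitting the output into files of at most
# 100_000 rows each; connections, bucket, and key are hypothetical.
export_orders = SqlToS3Operator(
    task_id="export_orders",
    query="SELECT * FROM orders",
    s3_bucket="my-data-bucket",
    s3_key="exports/orders.csv",
    sql_conn_id="my_sql_conn",
    aws_conn_id="aws_default",
    file_format="csv",
    pd_kwargs={"index": False},
    max_rows_per_file=100_000,
    replace=True,
)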
34 changes: 31 additions & 3 deletions airflow/providers/amazon/aws/transfers/sql_to_s3.py
@@ -81,6 +81,9 @@ class SqlToS3Operator(BaseOperator):
You can specify this argument if you want to use a different
CA cert bundle than the one used by botocore.
:param file_format: the destination file format, only string 'csv', 'json' or 'parquet' is accepted.
:param max_rows_per_file: (optional) maximum number of rows written to each destination file; if the source
    data contains more rows than this limit, it is split across multiple files.
    Cannot be combined with the ``groupby_kwargs`` argument.
:param pd_kwargs: arguments to include in DataFrame ``.to_parquet()``, ``.to_json()`` or ``.to_csv()``.
:param groupby_kwargs: argument to include in DataFrame ``groupby()``.
"""
@@ -110,6 +113,7 @@ def __init__(
aws_conn_id: str = "aws_default",
verify: bool | str | None = None,
file_format: Literal["csv", "json", "parquet"] = "csv",
max_rows_per_file: int = 0,
pd_kwargs: dict | None = None,
groupby_kwargs: dict | None = None,
**kwargs,
@@ -124,12 +128,19 @@ def __init__(
self.replace = replace
self.pd_kwargs = pd_kwargs or {}
self.parameters = parameters
self.max_rows_per_file = max_rows_per_file
self.groupby_kwargs = groupby_kwargs or {}
self.sql_hook_params = sql_hook_params

if "path_or_buf" in self.pd_kwargs:
raise AirflowException("The argument path_or_buf is not allowed, please remove it")

if self.max_rows_per_file and self.groupby_kwargs:
raise AirflowException(
    "SqlToS3Operator arguments max_rows_per_file and groupby_kwargs "
    "cannot both be specified. Please choose one."
)

try:
self.file_format = FILE_FORMAT[file_format.upper()]
except KeyError:
@@ -177,10 +188,8 @@ def execute(self, context: Context) -> None:
s3_conn = S3Hook(aws_conn_id=self.aws_conn_id, verify=self.verify)
data_df = sql_hook.get_pandas_df(sql=self.query, parameters=self.parameters)
self.log.info("Data from SQL obtained")

self._fix_dtypes(data_df, self.file_format)
file_options = FILE_OPTIONS_MAP[self.file_format]

for group_name, df in self._partition_dataframe(df=data_df):
with NamedTemporaryFile(mode=file_options.mode, suffix=file_options.suffix) as tmp_file:
self.log.info("Writing data to temp file")
@@ -194,13 +203,32 @@ def execute(self, context: Context) -> None:

def _partition_dataframe(self, df: pd.DataFrame) -> Iterable[tuple[str, pd.DataFrame]]:
"""Partition dataframe using pandas groupby() method."""
try:
import secrets
import string

import numpy as np
except ImportError:
pass
# If max_rows_per_file is specified, add a temporary column with a random, collision-unlikely name to the
# dataframe and use it to split the dataframe into fixed-size chunks via groupby().
random_column_name = ""
if self.max_rows_per_file and not self.groupby_kwargs:
random_column_name = "".join(secrets.choice(string.ascii_letters) for _ in range(20))
df[random_column_name] = np.arange(len(df)) // self.max_rows_per_file
self.groupby_kwargs = {"by": random_column_name}
if not self.groupby_kwargs:
yield "", df
return
for group_label in (grouped_df := df.groupby(**self.groupby_kwargs)).groups:
yield (
    cast(str, group_label),
    cast(
        "pd.DataFrame",
        grouped_df.get_group(group_label)
        .drop(random_column_name, axis=1, errors="ignore")
        .reset_index(drop=True),
    ),
)

def _get_hook(self) -> DbApiHook:
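The chunking above boils down to integer-dividing a running row counter by the limit and grouping on the result; a standalone sketch of that idea using plain pandas/numpy (illustrative data, not part of this commit):

import numpy as np
import pandas as pd

max_rows_per_file = 3
df = pd.DataFrame({"id": range(8)})

# Rows 0-2 get label 0, rows 3-5 get label 1, rows 6-7 get label 2,
# so groupby() yields chunks of at most max_rows_per_file rows each.
labels = np.arange(len(df)) // max_rows_per_file
for label, chunk in df.groupby(labels):
    print(label, len(chunk))  # prints: 0 3, then 1 3, then 2 2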
52 changes: 52 additions & 0 deletions tests/providers/amazon/aws/transfers/test_sql_to_s3.py
@@ -271,6 +271,58 @@ def test_without_groupby_kwarg(self):
)
)

def test_with_max_rows_per_file(self):
    """
    Test the operator when max_rows_per_file is specified.
    """
query = "query"
s3_bucket = "bucket"
s3_key = "key"

op = SqlToS3Operator(
query=query,
s3_bucket=s3_bucket,
s3_key=s3_key,
sql_conn_id="mysql_conn_id",
aws_conn_id="aws_conn_id",
task_id="task_id",
replace=True,
pd_kwargs={"index": False, "header": False},
max_rows_per_file=3,
dag=None,
)
example = {
"Team": ["Australia", "Australia", "India", "India"],
"Player": ["Ricky", "David Warner", "Virat Kohli", "Rohit Sharma"],
"Runs": [345, 490, 672, 560],
}

df = pd.DataFrame(example)
data = []
for group_name, df in op._partition_dataframe(df):
data.append((group_name, df))
data.sort(key=lambda d: d[0])
team, df = data[0]
assert df.equals(
pd.DataFrame(
{
"Team": ["Australia", "Australia", "India"],
"Player": ["Ricky", "David Warner", "Virat Kohli"],
"Runs": [345, 490, 672],
}
)
)
team, df = data[1]
assert df.equals(
pd.DataFrame(
{
"Team": ["India"],
"Player": ["Rohit Sharma"],
"Runs": [560],
}
)
)

@mock.patch("airflow.providers.common.sql.operators.sql.BaseHook.get_connection")
def test_hook_params(self, mock_get_conn):
mock_get_conn.return_value = Connection(conn_id="postgres_test", conn_type="postgres")
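The new test above covers the happy path only; the mutual-exclusion check added in ``__init__`` could be exercised with a sketch like the following (illustrative values, not part of this commit):

import pytest

from airflow.exceptions import AirflowException
from airflow.providers.amazon.aws.transfers.sql_to_s3 import SqlToS3Operator

# Passing both max_rows_per_file and groupby_kwargs is rejected at construction time.
with pytest.raises(AirflowException):
    SqlToS3Operator(
        task_id="task_id",
        query="query",
        s3_bucket="bucket",
        s3_key="key",
        sql_conn_id="mysql_conn_id",
        max_rows_per_file=3,
        groupby_kwargs={"by": "Team"},
    )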
