Merge pull request #115 from awslabs/categorical-partitions
Handling categorical partitions
igorborgest committed Jan 21, 2020
2 parents 5cc6360 + a81769c commit 5d9c525
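
In short: pandas' `groupby` on a categorical column yields one group per category *level* rather than per observed value unless `observed=True` is passed, so a categorical partition column would send empty subgroups through the dataset-writer loop for every unobserved category. A minimal sketch in plain pandas (no AWS needed; the `year`/`value` names are illustrative, not from the library):

```python
import pandas as pd

# A categorical partition column with ten levels but only one observed value.
df = pd.DataFrame({
    "year": pd.Categorical([1990, 1990, 1990], categories=range(1990, 2000)),
    "value": [1.0, 2.0, 3.0],
})

# observed=False (the old default): one group per category level,
# so nine of the ten subgroups are empty.
print(len(list(df.groupby("year", observed=False))))  # 10

# observed=True (the fix in this commit): only combinations
# actually present in the data.
print(len(list(df.groupby("year", observed=True))))   # 1
```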
Showing 2 changed files with 16 additions and 3 deletions.
6 changes: 3 additions & 3 deletions awswrangler/pandas.py
@@ -1004,7 +1004,7 @@ def _data_to_s3_dataset_writer(dataframe: pd.DataFrame,
             objects_paths.append(object_path)
         else:
             dataframe = Pandas._cast_pandas(dataframe=dataframe, cast_columns=cast_columns)
-            for keys, subgroup in dataframe.groupby(partition_cols):
+            for keys, subgroup in dataframe.groupby(by=partition_cols, observed=True):
                 subgroup = subgroup.drop(partition_cols, axis="columns")
                 if not isinstance(keys, tuple):
                     keys = (keys, )
@@ -1407,7 +1407,7 @@ def read_parquet(self,
         if len(dfs) == 1:
             df: pd.DataFrame = dfs[0]
         else:
-            df = pd.concat(objs=dfs, ignore_index=True)
+            df = pd.concat(objs=dfs, ignore_index=True, sort=False)
         return df
 
     @staticmethod
@@ -1870,7 +1870,7 @@ def read_csv_list(
             logger.debug(f"Closing proc number: {i}")
             receive_pipes[i].close()
         logger.debug(f"Concatenating all {len(paths)} DataFrames...")
-        df = pd.concat(objs=dfs, ignore_index=True)
+        df = pd.concat(objs=dfs, ignore_index=True, sort=False)
         return df
 
     def _read_csv_list_iterator(
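The two `pd.concat` changes pass `sort=False` explicitly. In the pandas releases current at the time (0.23–0.25), concatenating frames whose columns were not perfectly aligned sorted the column union alphabetically by default and emitted a `FutureWarning`; `sort=False` keeps first-appearance column order and silences the warning. A minimal illustration (hypothetical frames, not library code):

```python
import pandas as pd

a = pd.DataFrame({"b": [1], "a": [2]})
b = pd.DataFrame({"b": [3], "c": [4]})

# With sort=True the union of columns is alphabetized: ['a', 'b', 'c'].
# With sort=False it keeps first-appearance order; missing cells become NaN.
out = pd.concat([a, b], ignore_index=True, sort=False)
print(list(out.columns))  # ['b', 'a', 'c']
```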
13 changes: 13 additions & 0 deletions testing/test_awswrangler/test_pandas.py
@@ -2173,3 +2173,16 @@ def test_aurora_mysql_load_special2(bucket, mysql_parameters):
     assert rows[0][2] is None
     assert rows[1][3] is None
     conn.close()
+
+
+def test_to_parquet_categorical_partitions(bucket):
+    path = f"s3://{bucket}/test_to_parquet_categorical_partitions"
+    wr.s3.delete_objects(path=path)
+    d = pd.date_range("1990-01-01", freq="D", periods=10000)
+    vals = pd.np.random.randn(len(d), 4)
+    x = pd.DataFrame(vals, index=d, columns=["A", "B", "C", "D"])
+    x['Year'] = x.index.year
+    x['Year'] = x['Year'].astype('category')
+    wr.pandas.to_parquet(x[x.Year == 1990], path=path, partition_cols=["Year"])
+    y = wr.pandas.read_parquet(path=path)
+    assert len(x[x.Year == 1990].index) == len(y.index)
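
Note on the test: `pd.np` was pandas' re-exported NumPy namespace; it was deprecated in pandas 1.0 and removed in pandas 2.0, so on current pandas the equivalent is `import numpy as np` followed by `np.random.randn(len(d), 4)`.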
