From 5292c859548ab68be621f84b2eb03d1760718ee8 Mon Sep 17 00:00:00 2001 From: francis-du Date: Fri, 9 Sep 2022 14:17:59 +0800 Subject: [PATCH 1/4] fix: conflicting --- datafusion/tests/test_dataframe.py | 50 ++++++++++++++++++++++++++++++ src/dataframe.rs | 20 ++++++++++++ 2 files changed, 70 insertions(+) diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py index 760c37610..0ef91353f 100644 --- a/datafusion/tests/test_dataframe.py +++ b/datafusion/tests/test_dataframe.py @@ -22,6 +22,11 @@ from datafusion import DataFrame, SessionContext, column, literal, udf +@pytest.fixture +def ctx(): + return SessionContext() + + @pytest.fixture def df(): ctx = SessionContext() @@ -323,3 +328,48 @@ def test_collect_partitioned(): ) assert [[batch]] == ctx.create_dataframe([[batch]]).collect_partitioned() + + +def test_union(ctx): + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + df_a = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([3, 4, 5]), pa.array([6, 7, 8])], + names=["a", "b"], + ) + df_b = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3, 3, 4, 5]), pa.array([4, 5, 6, 6, 7, 8])], + names=["a", "b"], + ) + df_c = ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True)) + + assert df_c.collect() == df_a.union(df_b).sort(column("a").sort(ascending=True)).collect() + + +def test_union_distinct(ctx): + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + df_a = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([3, 4, 5]), pa.array([6, 7, 8])], + names=["a", "b"], + ) + df_b = ctx.create_dataframe([[batch]]) + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3, 4, 5]), pa.array([4, 5, 6, 7, 8])], + names=["a", "b"], + ) + df_c = 
ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True)) + + assert df_c.collect() == df_a.union(df_b, True).sort(column("a").sort(ascending=True)).collect() + assert df_c.collect() == df_a.union_distinct(df_b).sort(column("a").sort(ascending=True)).collect() diff --git a/src/dataframe.rs b/src/dataframe.rs index 4ae0160a9..4d8c0a3fb 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -204,6 +204,26 @@ impl PyDataFrame { Ok(Self::new(new_df)) } + /// Calculate the union of two `DataFrame`s, preserving duplicate rows unless + /// `distinct` is true. The two `DataFrame`s must have exactly the same schema + #[args(distinct = false)] + fn union(&self, py_df: PyDataFrame, distinct: bool) -> PyResult { + let new_df = if distinct { + self.df.union_distinct(py_df.df)? + } else { + self.df.union(py_df.df)? + }; + + Ok(Self::new(new_df)) + } + + /// Calculate the distinct union of two `DataFrame`s. The + /// two `DataFrame`s must have exactly the same schema + fn union_distinct(&self, py_df: PyDataFrame) -> PyResult { + let new_df = self.df.union_distinct(py_df.df)?; + Ok(Self::new(new_df)) + } + /// Calculate the intersection of two `DataFrame`s. 
The two `DataFrame`s must have exactly the same schema fn intersect(&self, py_df: PyDataFrame) -> PyResult { let new_df = self.df.intersect(py_df.df)?; From 933130122f2319fdf961da7eb9a8ec62bc6a751d Mon Sep 17 00:00:00 2001 From: francis-du Date: Tue, 6 Sep 2022 11:41:48 +0800 Subject: [PATCH 2/4] fix: python linter --- datafusion/tests/test_dataframe.py | 27 ++++++++++++++++++++++----- 1 file changed, 22 insertions(+), 5 deletions(-) diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py index 0ef91353f..9bbc8030c 100644 --- a/datafusion/tests/test_dataframe.py +++ b/datafusion/tests/test_dataframe.py @@ -347,9 +347,14 @@ def test_union(ctx): [pa.array([1, 2, 3, 3, 4, 5]), pa.array([4, 5, 6, 6, 7, 8])], names=["a", "b"], ) - df_c = ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True)) + df_c = ctx.create_dataframe([[batch]]).sort( + column("a").sort(ascending=True) + ) - assert df_c.collect() == df_a.union(df_b).sort(column("a").sort(ascending=True)).collect() + assert ( + df_c.collect() + == df_a.union(df_b).sort(column("a").sort(ascending=True)).collect() + ) def test_union_distinct(ctx): @@ -369,7 +374,19 @@ def test_union_distinct(ctx): [pa.array([1, 2, 3, 4, 5]), pa.array([4, 5, 6, 7, 8])], names=["a", "b"], ) - df_c = ctx.create_dataframe([[batch]]).sort(column("a").sort(ascending=True)) + df_c = ctx.create_dataframe([[batch]]).sort( + column("a").sort(ascending=True) + ) - assert df_c.collect() == df_a.union(df_b, True).sort(column("a").sort(ascending=True)).collect() - assert df_c.collect() == df_a.union_distinct(df_b).sort(column("a").sort(ascending=True)).collect() + assert ( + df_c.collect() + == df_a.union(df_b, True) + .sort(column("a").sort(ascending=True)) + .collect() + ) + assert ( + df_c.collect() + == df_a.union_distinct(df_b) + .sort(column("a").sort(ascending=True)) + .collect() + ) From 2be0f794589b8c728ac1b8cd5a6fe565b4a3c944 Mon Sep 17 00:00:00 2001 From: francis-du Date: Thu, 8 Sep 2022 
16:00:57 +0800 Subject: [PATCH 3/4] fix: flake8 W503 issue --- datafusion/tests/test_dataframe.py | 19 +++++-------------- 1 file changed, 5 insertions(+), 14 deletions(-) diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py index 9bbc8030c..4587175d7 100644 --- a/datafusion/tests/test_dataframe.py +++ b/datafusion/tests/test_dataframe.py @@ -351,10 +351,7 @@ def test_union(ctx): column("a").sort(ascending=True) ) - assert ( - df_c.collect() - == df_a.union(df_b).sort(column("a").sort(ascending=True)).collect() - ) + assert df_c == df_a.union(df_b).sort(column("a").sort(ascending=True)) def test_union_distinct(ctx): @@ -378,15 +375,9 @@ def test_union_distinct(ctx): column("a").sort(ascending=True) ) - assert ( - df_c.collect() - == df_a.union(df_b, True) - .sort(column("a").sort(ascending=True)) - .collect() + assert df_c == df_a.union(df_b, True).sort( + column("a").sort(ascending=True) ) - assert ( - df_c.collect() - == df_a.union_distinct(df_b) - .sort(column("a").sort(ascending=True)) - .collect() + assert df_c == df_a.union_distinct(df_b).sort( + column("a").sort(ascending=True) ) From b7d87b98aadbbfbc28e58c3e4b88c3334abc0930 Mon Sep 17 00:00:00 2001 From: francis-du Date: Thu, 8 Sep 2022 16:20:40 +0800 Subject: [PATCH 4/4] fix: test error --- datafusion/tests/test_dataframe.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py index 4587175d7..9880b6d33 100644 --- a/datafusion/tests/test_dataframe.py +++ b/datafusion/tests/test_dataframe.py @@ -351,7 +351,9 @@ def test_union(ctx): column("a").sort(ascending=True) ) - assert df_c == df_a.union(df_b).sort(column("a").sort(ascending=True)) + df_a_u_b = df_a.union(df_b).sort(column("a").sort(ascending=True)) + + assert df_c.collect() == df_a_u_b.collect() def test_union_distinct(ctx): @@ -375,9 +377,7 @@ def test_union_distinct(ctx): column("a").sort(ascending=True) ) - assert 
df_c == df_a.union(df_b, True).sort( - column("a").sort(ascending=True) - ) - assert df_c == df_a.union_distinct(df_b).sort( - column("a").sort(ascending=True) - ) + df_a_u_b = df_a.union(df_b, True).sort(column("a").sort(ascending=True)) + + assert df_c.collect() == df_a_u_b.collect() + assert df_c.collect() == df_a_u_b.collect()