From 664ae2fe48a9a374b389e76b4e305d41c0bb8f79 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sun, 13 Oct 2024 12:27:43 +0200 Subject: [PATCH 1/2] feat: expose join_on method --- python/datafusion/dataframe.py | 25 ++++++++++++++++++++++- python/tests/test_dataframe.py | 37 ++++++++++++++++++++++++++++++++++ src/dataframe.rs | 25 +++++++++++++++++++++++ 3 files changed, 86 insertions(+), 1 deletion(-) diff --git a/python/datafusion/dataframe.py b/python/datafusion/dataframe.py index c5ac0bb89..f8ea2d521 100644 --- a/python/datafusion/dataframe.py +++ b/python/datafusion/dataframe.py @@ -21,7 +21,7 @@ from __future__ import annotations -from typing import Any, List, TYPE_CHECKING +from typing import Any, List, TYPE_CHECKING, Literal from datafusion.record_batch import RecordBatchStream from typing_extensions import deprecated from datafusion.plan import LogicalPlan, ExecutionPlan @@ -293,6 +293,29 @@ def join( """ return DataFrame(self.df.join(right.df, join_keys, how)) + def join_on( + self, + right: DataFrame, + *on_exprs: Expr, + how: Literal["inner", "left", "right", "full", "semi", "anti"] = "inner", + ) -> DataFrame: + """Join two :py:class:`DataFrame` using the specified expressions. + + On expressions are used to support inequality predicates. Equality + predicates are correctly optimized. + + Args: + right: Other DataFrame to join with. + on_exprs: single or multiple (in)-equality predicates. + how: Type of join to perform. Supported types are "inner", "left", + "right", "full", "semi", "anti". + + Returns: + DataFrame after join. + """ + exprs = [expr.expr for expr in on_exprs] + return DataFrame(self.df.join_on(right.df, exprs, how)) + def explain(self, verbose: bool = False, analyze: bool = False) -> DataFrame: """Return a DataFrame with the explanation of its plan so far. 
diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index e89c57159..77f40bb6f 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -259,6 +259,43 @@ def test_join(): assert table.to_pydict() == expected +def test_join_on(): + ctx = SessionContext() + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2, 3]), pa.array([4, 5, 6])], + names=["a", "b"], + ) + df = ctx.create_dataframe([[batch]], "l") + + batch = pa.RecordBatch.from_arrays( + [pa.array([1, 2]), pa.array([8, 10])], + names=["a", "c"], + ) + df1 = ctx.create_dataframe([[batch]], "r") + + df2 = df.join_on(df1, column("l.a").__eq__(column("r.a")), how="inner") + df2.show() + df2 = df2.sort(column("l.a")) + table = pa.Table.from_batches(df2.collect()) + + expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} + assert table.to_pydict() == expected + + df3 = df.join_on( + df1, + column("l.a").__eq__(column("r.a")), + column("l.a").__lt__(column("r.c")), + how="inner", + ) + df3.show() + df3 = df3.sort(column("l.a")) + table = pa.Table.from_batches(df3.collect()) + + expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} + assert table.to_pydict() == expected + + def test_distinct(): ctx = SessionContext() diff --git a/src/dataframe.rs b/src/dataframe.rs index e77ca8425..57263a481 100644 --- a/src/dataframe.rs +++ b/src/dataframe.rs @@ -293,6 +293,31 @@ impl PyDataFrame { Ok(Self::new(df)) } + fn join_on(&self, right: PyDataFrame, on_exprs: Vec<PyExpr>, how: &str) -> PyResult<Self> { + let join_type = match how { + "inner" => JoinType::Inner, + "left" => JoinType::Left, + "right" => JoinType::Right, + "full" => JoinType::Full, + "semi" => JoinType::LeftSemi, + "anti" => JoinType::LeftAnti, + how => { + return Err(DataFusionError::Common(format!( + "The join type {how} does not exist or is not implemented" + )) + .into()); + } + }; + let exprs: Vec<Expr> = on_exprs.into_iter().map(|e| e.into()).collect(); + + let df = self + .df + .as_ref() + .clone() + 
.join_on(right.df.as_ref().clone(), join_type, exprs)?; + Ok(Self::new(df)) + } + /// Print the query plan #[pyo3(signature = (verbose=false, analyze=false))] fn explain(&self, py: Python, verbose: bool, analyze: bool) -> PyResult<()> { From c34b34d6d0e01085e3a6f2a95fb21709bca74561 Mon Sep 17 00:00:00 2001 From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com> Date: Sun, 13 Oct 2024 16:55:09 +0200 Subject: [PATCH 2/2] test: improve join_on case --- python/tests/test_dataframe.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/python/tests/test_dataframe.py b/python/tests/test_dataframe.py index 77f40bb6f..40943d2af 100644 --- a/python/tests/test_dataframe.py +++ b/python/tests/test_dataframe.py @@ -269,7 +269,7 @@ def test_join_on(): df = ctx.create_dataframe([[batch]], "l") batch = pa.RecordBatch.from_arrays( - [pa.array([1, 2]), pa.array([8, 10])], + [pa.array([1, 2]), pa.array([-8, 10])], names=["a", "c"], ) df1 = ctx.create_dataframe([[batch]], "r") @@ -279,7 +279,7 @@ def test_join_on(): df2 = df2.sort(column("l.a")) table = pa.Table.from_batches(df2.collect()) - expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} + expected = {"a": [1, 2], "c": [-8, 10], "b": [4, 5]} assert table.to_pydict() == expected df3 = df.join_on( @@ -291,8 +291,7 @@ def test_join_on(): df3.show() df3 = df3.sort(column("l.a")) table = pa.Table.from_batches(df3.collect()) - - expected = {"a": [1, 2], "c": [8, 10], "b": [4, 5]} + expected = {"a": [2], "c": [10], "b": [5]} assert table.to_pydict() == expected