From 9067b521aa6f113c99528942deb23b92c4a8e5f8 Mon Sep 17 00:00:00 2001 From: Lantao Jin Date: Wed, 20 May 2026 09:20:11 +0000 Subject: [PATCH] feat(dataframe): add join, joinOn, and JoinType --- .../java/org/apache/datafusion/DataFrame.java | 126 ++++++++ .../java/org/apache/datafusion/JoinType.java | 61 ++++ .../apache/datafusion/DataFrameJoinTest.java | 294 ++++++++++++++++++ native/src/lib.rs | 134 +++++++- 4 files changed, 613 insertions(+), 2 deletions(-) create mode 100644 core/src/main/java/org/apache/datafusion/JoinType.java create mode 100644 core/src/test/java/org/apache/datafusion/DataFrameJoinTest.java diff --git a/core/src/main/java/org/apache/datafusion/DataFrame.java b/core/src/main/java/org/apache/datafusion/DataFrame.java index 86dd523..5f89a8b 100644 --- a/core/src/main/java/org/apache/datafusion/DataFrame.java +++ b/core/src/main/java/org/apache/datafusion/DataFrame.java @@ -256,6 +256,121 @@ public DataFrame unnestColumns(UnnestOptions options, String... columns) { return new DataFrame(unnestColumns(nativeHandle, columns, options.preserveNulls())); } + /** + * Equi-join this DataFrame with {@code right} on the named columns, using the given {@link + * JoinType}. The receiver and {@code right} both remain usable and must still be closed + * independently. + * + *

Equivalent to SQL {@code left JOIN right ON l.leftCols[0] = r.rightCols[0] AND ...}. + * {@code leftCols} and {@code rightCols} must have the same length. + * + * @throws IllegalArgumentException if any argument is {@code null} or {@code leftCols.length != + * rightCols.length}. + * @throws IllegalStateException if either DataFrame is closed or already collected. + * @throws RuntimeException if join planning fails (column collision in the combined schema, + * unknown column names, etc.). + */ + public DataFrame join(DataFrame right, JoinType type, String[] leftCols, String[] rightCols) { + checkJoinArgs(right, type, leftCols, rightCols); + return new DataFrame( + joinDataFrame(nativeHandle, right.nativeHandle, type.code(), leftCols, rightCols, null)); + } + + /** + * Equi-join this DataFrame with {@code right}, restricting the result with a residual SQL filter + * parsed against the combined schema (left columns followed by right columns; columns + * may be qualified with the relation alias when ambiguous). The receiver and {@code right} both + * remain usable and must still be closed independently. + * + *

For outer joins, the filter is applied only to matched rows; unmatched rows are passed + * through with nulls on the unmatched side, matching DataFusion's semantics. + * + * @throws IllegalArgumentException if any argument is {@code null} or {@code leftCols.length != + * rightCols.length}. + * @throws IllegalStateException if either DataFrame is closed or already collected. + * @throws RuntimeException if join planning or filter parsing fails. + */ + public DataFrame join( + DataFrame right, JoinType type, String[] leftCols, String[] rightCols, String filter) { + checkJoinArgs(right, type, leftCols, rightCols); + if (filter == null) { + throw new IllegalArgumentException("join filter must be non-null"); + } + return new DataFrame( + joinDataFrame(nativeHandle, right.nativeHandle, type.code(), leftCols, rightCols, filter)); + } + + /** + * Join this DataFrame with {@code right} using arbitrary SQL predicates parsed against the + * combined schema. Each predicate is parsed independently and the join evaluates their + * conjunction. Predicates may reference columns from either side and may be qualified with the + * relation alias when ambiguous (e.g. {@code "left.x = right.x"}). The receiver and {@code right} + * both remain usable and must still be closed independently. + * + *

DataFusion's optimiser identifies and rewrites equality predicates into hash-join keys + * automatically, so {@code joinOn(right, INNER, "l.id = r.id")} plans equivalently to {@link + * #join(DataFrame, JoinType, String[], String[])} with a single key. Use {@code joinOn} when the + * predicate is not a simple equality, e.g. inequality joins or range conditions. + * + * @throws IllegalArgumentException if {@code right} or {@code type} is {@code null}, or {@code + * predicates} is {@code null} or empty, or any predicate is {@code null}. + * @throws IllegalStateException if either DataFrame is closed or already collected. + * @throws RuntimeException if predicate parsing or join planning fails. + */ + public DataFrame joinOn(DataFrame right, JoinType type, String... predicates) { + if (right == null) { + throw new IllegalArgumentException("joinOn right must be non-null"); + } + if (type == null) { + throw new IllegalArgumentException("joinOn type must be non-null"); + } + if (predicates == null || predicates.length == 0) { + throw new IllegalArgumentException("joinOn predicates must be non-null and non-empty"); + } + for (String p : predicates) { + if (p == null) { + throw new IllegalArgumentException("joinOn predicates must not contain null"); + } + } + if (nativeHandle == 0) { + throw new IllegalStateException("DataFrame is closed or already collected"); + } + if (right.nativeHandle == 0) { + throw new IllegalStateException("right DataFrame is closed or already collected"); + } + return new DataFrame( + joinOnDataFrame(nativeHandle, right.nativeHandle, type.code(), predicates)); + } + + private void checkJoinArgs( + DataFrame right, JoinType type, String[] leftCols, String[] rightCols) { + if (right == null) { + throw new IllegalArgumentException("join right must be non-null"); + } + if (type == null) { + throw new IllegalArgumentException("join type must be non-null"); + } + if (leftCols == null) { + throw new IllegalArgumentException("join leftCols must be 
non-null"); + } + if (rightCols == null) { + throw new IllegalArgumentException("join rightCols must be non-null"); + } + if (leftCols.length != rightCols.length) { + throw new IllegalArgumentException( + "join leftCols and rightCols must have the same length, got " + + leftCols.length + + " and " + + rightCols.length); + } + if (nativeHandle == 0) { + throw new IllegalStateException("DataFrame is closed or already collected"); + } + if (right.nativeHandle == 0) { + throw new IllegalStateException("right DataFrame is closed or already collected"); + } + } + /** * Materialize this DataFrame as Parquet at {@code path}. The path is treated as a directory * unless overridden via {@link ParquetWriteOptions#singleFileOutput(boolean)}. The receiver @@ -386,6 +501,17 @@ public void close() { private static native long unnestColumns(long handle, String[] columns, boolean preserveNulls); + private static native long joinDataFrame( + long leftHandle, + long rightHandle, + byte joinType, + String[] leftCols, + String[] rightCols, + String filter); + + private static native long joinOnDataFrame( + long leftHandle, long rightHandle, byte joinType, String[] predicates); + private static native void writeParquetWithOptions( long handle, String path, diff --git a/core/src/main/java/org/apache/datafusion/JoinType.java b/core/src/main/java/org/apache/datafusion/JoinType.java new file mode 100644 index 0000000..3c3334a --- /dev/null +++ b/core/src/main/java/org/apache/datafusion/JoinType.java @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +/** + * Join type requested for {@link DataFrame#join} / {@link DataFrame#joinOn}. Mirrors + * DataFusion's {@code JoinType} enum one-to-one. + * + *

+ */ +public enum JoinType { + INNER((byte) 0), + LEFT((byte) 1), + RIGHT((byte) 2), + FULL((byte) 3), + LEFT_SEMI((byte) 4), + RIGHT_SEMI((byte) 5), + LEFT_ANTI((byte) 6), + RIGHT_ANTI((byte) 7), + LEFT_MARK((byte) 8), + RIGHT_MARK((byte) 9); + + private final byte code; + + JoinType(byte code) { + this.code = code; + } + + /** Stable byte code for FFI. */ + public byte code() { + return code; + } +} diff --git a/core/src/test/java/org/apache/datafusion/DataFrameJoinTest.java b/core/src/test/java/org/apache/datafusion/DataFrameJoinTest.java new file mode 100644 index 0000000..b118ea2 --- /dev/null +++ b/core/src/test/java/org/apache/datafusion/DataFrameJoinTest.java @@ -0,0 +1,294 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.datafusion; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import org.junit.jupiter.api.Test; + +class DataFrameJoinTest { + + // Two relations sharing two matching keys (1 and 2), with one unmatched row on each side. + // left: (1,'a'), (2,'b'), (3,'c'); right: (1,10), (2,20), (4,40).
+ private static final String LEFT_SQL = + "SELECT * FROM (VALUES (1, 'a'), (2, 'b'), (3, 'c')) AS l(id, s)"; + private static final String RIGHT_SQL = + "SELECT * FROM (VALUES (1, 10), (2, 20), (4, 40)) AS r(id, v)"; + + @Test + void innerJoinOnSingleColumn() { + try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql(LEFT_SQL); + DataFrame right = ctx.sql(RIGHT_SQL); + DataFrame joined = + left.join(right, JoinType.INNER, new String[] {"id"}, new String[] {"id"})) { + assertEquals(2L, joined.count()); // (1,'a',1,10) and (2,'b',2,20) + } + } + + @Test + void innerJoinOnMultipleColumns() { + try (SessionContext ctx = new SessionContext(); + DataFrame l = + ctx.sql("SELECT * FROM (VALUES (1, 'x', 100), (2, 'y', 200)) AS t(a, b, l_other)"); + DataFrame r = + ctx.sql("SELECT * FROM (VALUES (1, 'x', 'p'), (2, 'z', 'q')) AS t(a2, b2, r_other)"); + DataFrame joined = + l.join(r, JoinType.INNER, new String[] {"a", "b"}, new String[] {"a2", "b2"})) { + assertEquals(1L, joined.count()); // only (1,'x') matches on both keys + } + } + + @Test + void leftJoinPreservesUnmatchedLeft() { + try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql(LEFT_SQL); + DataFrame right = ctx.sql(RIGHT_SQL); + DataFrame joined = + left.join(right, JoinType.LEFT, new String[] {"id"}, new String[] {"id"})) { + // 3 left rows; unmatched (3,'c') gets nulls on the right side. + assertEquals(3L, joined.count()); + } + } + + @Test + void rightJoinPreservesUnmatchedRight() { + try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql(LEFT_SQL); + DataFrame right = ctx.sql(RIGHT_SQL); + DataFrame joined = + left.join(right, JoinType.RIGHT, new String[] {"id"}, new String[] {"id"})) { + // 3 right rows; unmatched (4,40) gets nulls on the left side. 
+ assertEquals(3L, joined.count()); + } + } + + @Test + void fullJoinPreservesBothSides() { + try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql(LEFT_SQL); + DataFrame right = ctx.sql(RIGHT_SQL); + DataFrame joined = + left.join(right, JoinType.FULL, new String[] {"id"}, new String[] {"id"})) { + // 2 matched rows + 1 unmatched-left + 1 unmatched-right = 4. + assertEquals(4L, joined.count()); + } + } + + @Test + void leftSemiJoinReturnsLeftMatchedOnly() { + try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql(LEFT_SQL); + DataFrame right = ctx.sql(RIGHT_SQL); + DataFrame joined = + left.join(right, JoinType.LEFT_SEMI, new String[] {"id"}, new String[] {"id"})) { + // Only the 2 left rows that have a matching right row. + // Output projects left side only (id, s) — right columns dropped. + assertEquals(2L, joined.count()); + } + } + + @Test + void leftAntiJoinReturnsLeftUnmatchedOnly() { + try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql(LEFT_SQL); + DataFrame right = ctx.sql(RIGHT_SQL); + DataFrame joined = + left.join(right, JoinType.LEFT_ANTI, new String[] {"id"}, new String[] {"id"})) { + // Only the 1 left row (3,'c') with no right match. Output projects left side only. + assertEquals(1L, joined.count()); + } + } + + @Test + void rightSemiJoinReturnsRightMatchedOnly() { + try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql(LEFT_SQL); + DataFrame right = ctx.sql(RIGHT_SQL); + DataFrame joined = + left.join(right, JoinType.RIGHT_SEMI, new String[] {"id"}, new String[] {"id"})) { + // Output projects right side only (id, v) — left columns dropped. 
+ assertEquals(2L, joined.count()); + } + } + + @Test + void rightAntiJoinReturnsRightUnmatchedOnly() { + try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql(LEFT_SQL); + DataFrame right = ctx.sql(RIGHT_SQL); + DataFrame joined = + left.join(right, JoinType.RIGHT_ANTI, new String[] {"id"}, new String[] {"id"})) { + assertEquals(1L, joined.count()); // (4, 40) + } + } + + @Test + void leftMarkJoinAddsMarkColumn() { + try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql(LEFT_SQL); + DataFrame right = ctx.sql(RIGHT_SQL); + DataFrame joined = + left.join(right, JoinType.LEFT_MARK, new String[] {"id"}, new String[] {"id"})) { + // One row per left row, plus a 'mark' boolean column. + assertEquals(3L, joined.count()); + } + } + + @Test + void joinWithResidualFilter() { + try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql(LEFT_SQL); + DataFrame right = ctx.sql(RIGHT_SQL); + DataFrame joined = + left.join(right, JoinType.INNER, new String[] {"id"}, new String[] {"id"}, "v >= 20")) { + // Without the filter: 2 matched rows. With v >= 20: only (2,'b',2,20). + assertEquals(1L, joined.count()); + } + } + + @Test + void joinOnSingleEqualityPredicate() { + try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql(LEFT_SQL); + DataFrame right = ctx.sql(RIGHT_SQL); + DataFrame joined = left.joinOn(right, JoinType.INNER, "l.id = r.id")) { + assertEquals(2L, joined.count()); + } + } + + @Test + void joinOnInequalityPredicate() { + try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql(LEFT_SQL); + DataFrame right = ctx.sql(RIGHT_SQL); + DataFrame joined = left.joinOn(right, JoinType.INNER, "l.id < r.id")) { + // Pairs (1,'a')<(2,20), (1,'a')<(4,40), (2,'b')<(4,40), (3,'c')<(4,40) = 4. 
+ assertEquals(4L, joined.count()); + } + } + + @Test + void joinOnMultiplePredicates() { + try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql(LEFT_SQL); + DataFrame right = ctx.sql(RIGHT_SQL); + DataFrame joined = left.joinOn(right, JoinType.INNER, "l.id = r.id", "r.v > 15")) { + // Equality narrows to (1,'a',1,10) and (2,'b',2,20); v > 15 leaves only the second. + assertEquals(1L, joined.count()); + } + } + + @Test + void semiJoinWithFilterToleratesSharedUnqualifiedColumn() { + // Regression for issue surfaced in code review: when both inputs carry an unqualified + // column with the same name (here, `tag`) that the residual filter does NOT reference, + // the join must still plan. Earlier the Rust side merged the schemas via + // DFSchema::join, whose check_names rejected the duplicate before parsing the filter. + try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql("SELECT 1 AS id, 'l' AS tag"); + DataFrame right = ctx.sql("SELECT 1 AS rid, 99 AS rv, 'r' AS tag"); + DataFrame joined = + left.join( + right, JoinType.LEFT_SEMI, new String[] {"id"}, new String[] {"rid"}, "rv > 0")) { + assertEquals(1L, joined.count()); + } + } + + @Test + void joinOnToleratesSharedUnqualifiedColumn() { + // Same regression as the previous test, but exercises the joinOn predicate path. + // Uses LEFT_SEMI so the output schema is one-sided -- INNER joins on inputs that share + // an unqualified column name are genuinely ambiguous in the result and rejected by + // upstream's build_join_schema, which is not specific to our code. 
+ try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql("SELECT 1 AS id, 'l' AS tag"); + DataFrame right = ctx.sql("SELECT 1 AS rid, 99 AS rv, 'r' AS tag"); + DataFrame joined = left.joinOn(right, JoinType.LEFT_SEMI, "id = rid", "rv > 0")) { + assertEquals(1L, joined.count()); + } + } + + @Test + void joinPreservesReceivers() { + try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql(LEFT_SQL); + DataFrame right = ctx.sql(RIGHT_SQL)) { + try (DataFrame joined = + left.join(right, JoinType.INNER, new String[] {"id"}, new String[] {"id"})) { + assertEquals(2L, joined.count()); + } + // Both receivers still usable after join(). + assertEquals(3L, left.count()); + assertEquals(3L, right.count()); + } + } + + @Test + void joinThrowsWhenLeftClosed() { + try (SessionContext ctx = new SessionContext(); + DataFrame right = ctx.sql(RIGHT_SQL)) { + DataFrame left = ctx.sql(LEFT_SQL); + left.close(); + assertThrows( + IllegalStateException.class, + () -> left.join(right, JoinType.INNER, new String[] {"id"}, new String[] {"id"})); + } + } + + @Test + void joinThrowsWhenRightClosed() { + try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql(LEFT_SQL)) { + DataFrame right = ctx.sql(RIGHT_SQL); + right.close(); + assertThrows( + IllegalStateException.class, + () -> left.join(right, JoinType.INNER, new String[] {"id"}, new String[] {"id"})); + } + } + + @Test + void joinNullArgumentValidation() { + try (SessionContext ctx = new SessionContext(); + DataFrame left = ctx.sql(LEFT_SQL); + DataFrame right = ctx.sql(RIGHT_SQL)) { + String[] cols = new String[] {"id"}; + assertThrows( + IllegalArgumentException.class, () -> left.join(null, JoinType.INNER, cols, cols)); + assertThrows(IllegalArgumentException.class, () -> left.join(right, null, cols, cols)); + assertThrows( + IllegalArgumentException.class, () -> left.join(right, JoinType.INNER, null, cols)); + assertThrows( + IllegalArgumentException.class, () -> 
left.join(right, JoinType.INNER, cols, null)); + assertThrows( + IllegalArgumentException.class, + () -> left.join(right, JoinType.INNER, new String[] {"id", "id"}, cols)); + assertThrows( + IllegalArgumentException.class, () -> left.join(right, JoinType.INNER, cols, cols, null)); + assertThrows(IllegalArgumentException.class, () -> left.joinOn(null, JoinType.INNER, "1=1")); + assertThrows(IllegalArgumentException.class, () -> left.joinOn(right, null, "1=1")); + assertThrows(IllegalArgumentException.class, () -> left.joinOn(right, JoinType.INNER)); + assertThrows( + IllegalArgumentException.class, () -> left.joinOn(right, JoinType.INNER, (String) null)); + } + } +} diff --git a/native/src/lib.rs b/native/src/lib.rs index a235cd3..9d9343e 100644 --- a/native/src/lib.rs +++ b/native/src/lib.rs @@ -39,18 +39,19 @@ use datafusion::arrow::datatypes::SchemaRef; use datafusion::arrow::error::ArrowError; use datafusion::arrow::ffi_stream::FFI_ArrowArrayStream; use datafusion::arrow::record_batch::{RecordBatchIterator, RecordBatchReader}; -use datafusion::common::UnnestOptions; +use datafusion::common::{JoinType, UnnestOptions}; use datafusion::config::TableParquetOptions; use datafusion::dataframe::DataFrame; use datafusion::dataframe::DataFrameWriteOptions; use datafusion::error::DataFusionError; use datafusion::execution::runtime_env::RuntimeEnvBuilder; use datafusion::execution::SendableRecordBatchStream; +use datafusion::logical_expr::Expr; use datafusion::logical_expr::{ScalarUDF, Signature}; use datafusion::prelude::{ParquetReadOptions, SessionConfig, SessionContext}; use futures::StreamExt; use jni::objects::{JByteArray, JClass, JObject, JObjectArray, JString}; -use jni::sys::{jboolean, jint, jlong}; +use jni::sys::{jboolean, jbyte, jint, jlong}; use jni::JNIEnv; use jni::JavaVM; use prost::Message; @@ -511,6 +512,135 @@ pub extern "system" fn Java_org_apache_datafusion_DataFrame_unnestColumns<'local }) } +/// Map a Java {@code JoinType.code()} byte back to 
upstream's enum. +fn join_type_from_byte(byte: u8) -> JniResult { + match byte { + 0 => Ok(JoinType::Inner), + 1 => Ok(JoinType::Left), + 2 => Ok(JoinType::Right), + 3 => Ok(JoinType::Full), + 4 => Ok(JoinType::LeftSemi), + 5 => Ok(JoinType::RightSemi), + 6 => Ok(JoinType::LeftAnti), + 7 => Ok(JoinType::RightAnti), + 8 => Ok(JoinType::LeftMark), + 9 => Ok(JoinType::RightMark), + other => Err(format!("unknown join type byte: {other}").into()), + } +} + +/// Build a combined DFSchema for SQL parsing of a join filter or `joinOn` predicate. +/// Mirrors how upstream's `LogicalPlanBuilder::join_detailed` normalises the parsed Expr +/// against `&[&[left_schema, right_schema]]`: tolerate unrelated duplicate-named columns +/// rather than rejecting them via `DFSchema::join`'s `check_names`. `DFSchema::merge` +/// skips duplicates (left side wins for unqualified collisions), which is fine for the +/// SQL-to-Expr step -- the subsequent join planner runs the real ambiguity check. +fn combine_schemas( + left: &datafusion::common::DFSchema, + right: &datafusion::common::DFSchema, +) -> datafusion::common::DFSchema { + let mut combined = left.clone(); + combined.merge(right); + combined +} + +/// Drain a Java {@code String[]} into an owned {@code Vec}. 
+fn collect_jstring_array(env: &mut JNIEnv, arr: &JObjectArray) -> JniResult> { + let len = env.get_array_length(arr)?; + let mut owned: Vec = Vec::with_capacity(len as usize); + for i in 0..len { + let elem = env.get_object_array_element(arr, i)?; + let jstr: JString = elem.into(); + owned.push(env.get_string(&jstr)?.into()); + } + Ok(owned) +} + +#[no_mangle] +#[allow(clippy::too_many_arguments)] +pub extern "system" fn Java_org_apache_datafusion_DataFrame_joinDataFrame<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + left_handle: jlong, + right_handle: jlong, + join_type: jbyte, + left_cols: JObjectArray<'local>, + right_cols: JObjectArray<'local>, + filter: JString<'local>, +) -> jlong { + try_unwrap_or_throw(&mut env, 0, |env| -> JniResult { + if left_handle == 0 { + return Err("left DataFrame handle is null".into()); + } + if right_handle == 0 { + return Err("right DataFrame handle is null".into()); + } + let left = unsafe { &*(left_handle as *const DataFrame) }.clone(); + let right = unsafe { &*(right_handle as *const DataFrame) }.clone(); + let join_type = join_type_from_byte(join_type as u8)?; + + let left_owned: Vec = collect_jstring_array(env, &left_cols)?; + let right_owned: Vec = collect_jstring_array(env, &right_cols)?; + let left_refs: Vec<&str> = left_owned.iter().map(String::as_str).collect(); + let right_refs: Vec<&str> = right_owned.iter().map(String::as_str).collect(); + + // The optional residual filter spans both sides and must be parsed against the + // combined schema. parse_sql_expr only sees one DataFrame's schema, so reach into + // the SessionState via into_parts() on a clone. Use DFSchema::merge rather than + // DFSchema::join so the parser tolerates unrelated duplicate unqualified columns + // shared by both sides (e.g. both inputs carrying a `created_at` field) -- merge + // skips the duplicates while join's check_names rejects them. 
Upstream's join + // path normalises the parsed Expr against both schemas as a precedence list, so + // ambiguous references genuinely used in the filter are still surfaced after + // parsing. + let filter_expr: Option = if filter.is_null() { + None + } else { + let filter_sql: String = env.get_string(&filter)?.into(); + let combined = combine_schemas(left.schema(), right.schema()); + let (state, _plan) = left.clone().into_parts(); + Some(state.create_logical_expr(&filter_sql, &combined)?) + }; + + let new_df = left.join(right, join_type, &left_refs, &right_refs, filter_expr)?; + Ok(Box::into_raw(Box::new(new_df)) as jlong) + }) +} + +#[no_mangle] +pub extern "system" fn Java_org_apache_datafusion_DataFrame_joinOnDataFrame<'local>( + mut env: JNIEnv<'local>, + _class: JClass<'local>, + left_handle: jlong, + right_handle: jlong, + join_type: jbyte, + predicates: JObjectArray<'local>, +) -> jlong { + try_unwrap_or_throw(&mut env, 0, |env| -> JniResult { + if left_handle == 0 { + return Err("left DataFrame handle is null".into()); + } + if right_handle == 0 { + return Err("right DataFrame handle is null".into()); + } + let left = unsafe { &*(left_handle as *const DataFrame) }.clone(); + let right = unsafe { &*(right_handle as *const DataFrame) }.clone(); + let join_type = join_type_from_byte(join_type as u8)?; + + let predicates_owned: Vec = collect_jstring_array(env, &predicates)?; + // See joinDataFrame for the rationale behind combine_schemas vs DFSchema::join. + let combined = combine_schemas(left.schema(), right.schema()); + let (state, _plan) = left.clone().into_parts(); + let exprs: Vec = predicates_owned + .iter() + .map(|sql| state.create_logical_expr(sql, &combined)) + .collect::>>()?; + + let new_df = left.join_on(right, join_type, exprs)?; + Ok(Box::into_raw(Box::new(new_df)) as jlong) + }) +} + #[no_mangle] pub extern "system" fn Java_org_apache_datafusion_DataFrame_writeParquetWithOptions<'local>( mut env: JNIEnv<'local>,