diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index da0c615e3b23..03872147b797 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -363,6 +363,7 @@ enum JoinType { LEFT = 1; RIGHT = 2; FULL = 3; + SEMI = 4; } message JoinNode { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 10c4670e809a..48471263885f 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -265,6 +265,7 @@ impl TryInto for &protobuf::LogicalPlanNode { protobuf::JoinType::Left => JoinType::Left, protobuf::JoinType::Right => JoinType::Right, protobuf::JoinType::Full => JoinType::Full, + protobuf::JoinType::Semi => JoinType::Semi, }; LogicalPlanBuilder::from(&convert_box_required!(join.left)?) .join( diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index b630dfcc0d1b..e1c0c5e44df6 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -834,6 +834,7 @@ impl TryInto for &LogicalPlan { JoinType::Left => protobuf::JoinType::Left, JoinType::Right => protobuf::JoinType::Right, JoinType::Full => protobuf::JoinType::Full, + JoinType::Semi => protobuf::JoinType::Semi, }; let left_join_column = on.iter().map(|on| on.0.to_owned()).collect(); let right_join_column = on.iter().map(|on| on.1.to_owned()).collect(); diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 2039def908bc..7f98a8378b0b 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -379,6 +379,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { protobuf::JoinType::Left => JoinType::Left, protobuf::JoinType::Right => JoinType::Right, protobuf::JoinType::Full => JoinType::Full, + protobuf::JoinType::Semi => JoinType::Semi, }; Ok(Arc::new(HashJoinExec::try_new( left, diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index 9571f3de2e76..c409f9474951 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -133,6 +133,7 @@ impl TryInto for Arc { JoinType::Left => protobuf::JoinType::Left, JoinType::Right => protobuf::JoinType::Right, JoinType::Full => protobuf::JoinType::Full, + JoinType::Semi => protobuf::JoinType::Semi, }; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new( diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 9515ac2ff373..5e44a3e097f3 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -410,6 +410,10 @@ fn build_join_schema( // left then right left_fields.chain(right_fields).cloned().collect() } + JoinType::Semi => { + // Only use the left side for the schema + left.fields().clone() + } JoinType::Right => { // remove left-side join keys if they have the same names as the right-side let duplicate_keys = &on diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index 4027916c8a7c..d10f8b573345 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -34,7 +34,7 @@ use std::{ }; /// Join type -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum JoinType { /// Inner Join Inner, @@ -44,6 +44,8 @@ pub enum JoinType { Right, /// Full Join Full, + /// Semi Join + Semi, } /// A LogicalPlan represents the different types of relational diff --git a/datafusion/src/optimizer/hash_build_probe_order.rs b/datafusion/src/optimizer/hash_build_probe_order.rs index 100ae4fb09b7..86d38ef313ce 100644 --- a/datafusion/src/optimizer/hash_build_probe_order.rs +++ b/datafusion/src/optimizer/hash_build_probe_order.rs @@ -106,6 +106,13 @@ fn should_swap_join_order(left: &LogicalPlan, right: &LogicalPlan) -> bool { } } +fn supports_swap(join_type: JoinType) -> bool { + match join_type { + JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => true, + JoinType::Semi => false, + } +} + impl OptimizerRule for HashBuildProbeOrder { fn name(&self) -> &str { "hash_build_probe_order" @@ -128,7 +135,7 @@ impl OptimizerRule for HashBuildProbeOrder { } => { let left = self.optimize(left, execution_props)?; let right = self.optimize(right, execution_props)?; - if should_swap_join_order(&left, &right) { + if should_swap_join_order(&left, &right) && supports_swap(*join_type) { // Swap left and right, change join type and (equi-)join key order Ok(LogicalPlan::Join { left: Arc::new(right), @@ -216,6 +223,7 @@ fn swap_join_type(join_type: JoinType) -> JoinType { JoinType::Full => JoinType::Full, JoinType::Left => JoinType::Right, JoinType::Right => JoinType::Left, + _ => unreachable!(), } } diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index 01551cd4daf4..6653b9a356a4 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -184,7 +184,7 @@ impl HashJoinExec { /// Calculates column indices and left/right placement on input / output schemas and jointype fn column_indices_from_schema(&self) -> ArrowResult> { let (primary_is_left, primary_schema, secondary_schema) = match self.join_type { - JoinType::Inner | JoinType::Left | JoinType::Full => { + JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Semi => { (true, self.left.schema(), self.right.schema()) } JoinType::Right => (false, self.right.schema(), self.left.schema()), @@ -376,7 +376,7 @@ impl ExecutionPlan for HashJoinExec { let column_indices = self.column_indices_from_schema()?; let num_rows = left_data.1.num_rows(); let visited_left_side = match self.join_type { - JoinType::Left | JoinType::Full => vec![false; num_rows], + JoinType::Left | JoinType::Full | JoinType::Semi => vec![false; num_rows], JoinType::Inner | JoinType::Right => vec![], }; Ok(Box::pin(HashJoinStream { @@ -544,6 +544,13 @@ fn build_batch( ) .unwrap(); + if join_type == JoinType::Semi { + return Ok(( + RecordBatch::new_empty(Arc::new(schema.clone())), + left_indices, + )); + } + build_batch_from_indices( schema, &left_data.1, @@ -606,7 +613,7 @@ fn build_join_indexes( let left = &left_data.0; match join_type { - JoinType::Inner => { + JoinType::Inner | JoinType::Semi => { // Using a buffer builder to avoid slower normal builder let mut left_indices = UInt64BufferBuilder::new(0); let mut right_indices = UInt32BufferBuilder::new(0); @@ -1108,23 +1115,35 @@ pub fn create_hashes<'a>( Ok(hashes_buffer) } -// Produces a batch for left-side rows that are not marked as being visited during the whole join -fn produce_unmatched( +// Produces a batch for left-side rows that have/have not been matched during the whole join +fn produce_from_matched( visited_left_side: &[bool], schema: &SchemaRef, column_indices: &[ColumnIndex], left_data: &JoinLeftData, + unmatched: bool, ) -> ArrowResult { // Find indices which didn't match any right row (are false) - let unmatched_indices: Vec = visited_left_side - .iter() - .enumerate() - .filter(|&(_, &value)| !value) - .map(|(index, _)| index as u64) - .collect(); + let indices = if unmatched { + UInt64Array::from_iter_values( + visited_left_side + .iter() + .enumerate() + .filter(|&(_, &value)| !value) + .map(|(index, _)| index as u64), + ) + } else { + // produce those that did match + UInt64Array::from_iter_values( + visited_left_side + .iter() + .enumerate() + .filter(|&(_, &value)| value) + .map(|(index, _)| index as u64), + ) + }; // generate batches by taking values from the left side and generating columns filled with null on the right side - let indices = UInt64Array::from_iter_values(unmatched_indices); let num_rows = indices.len(); let mut columns: Vec> = Vec::with_capacity(schema.fields().len()); for (idx, column_index) in column_indices.iter().enumerate() { @@ -1171,7 +1190,7 @@ impl Stream for HashJoinStream { self.num_output_rows += batch.num_rows(); match self.join_type { - JoinType::Left | JoinType::Full => { + JoinType::Left | JoinType::Full | JoinType::Semi => { left_side.iter().flatten().for_each(|x| { self.visited_left_side[x as usize] = true; }); @@ -1185,12 +1204,15 @@ impl Stream for HashJoinStream { let start = Instant::now(); // For the left join, produce rows for unmatched rows match self.join_type { - JoinType::Left | JoinType::Full if !self.is_exhausted => { - let result = produce_unmatched( + JoinType::Left | JoinType::Full | JoinType::Semi + if !self.is_exhausted => + { + let result = produce_from_matched( &self.visited_left_side, &self.schema, &self.column_indices, &self.left_data, + self.join_type != JoinType::Semi, ); if let Ok(ref batch) = result { self.num_input_batches += 1; @@ -1207,6 +1229,7 @@ impl Stream for HashJoinStream { } JoinType::Left | JoinType::Full + | JoinType::Semi | JoinType::Inner | JoinType::Right => {} } @@ -1666,6 +1689,42 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_semi() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 5, 6, 5]), // 5 is double on the right + ("c2", &vec![70, 80, 90, 100]), + ); + let on = &[("b1", "b1")]; + + let join = join(left, right, on, &JoinType::Semi)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1"]); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + let expected = vec![ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + + Ok(()) + } + #[tokio::test] async fn join_right_one() -> Result<()> { let left = build_table( diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs index 7e030af3a124..110319e4bb6b 100644 --- a/datafusion/src/physical_plan/hash_utils.rs +++ b/datafusion/src/physical_plan/hash_utils.rs @@ -32,6 +32,8 @@ pub enum JoinType { Right, /// Full Join Full, + /// Semi Join + Semi, } /// The on clause of the join, as vector of (left, right) columns. @@ -130,6 +132,7 @@ pub fn build_join_schema( // left then right left_fields.chain(right_fields).cloned().collect() } + JoinType::Semi => left.fields().clone(), }; Schema::new(fields) } diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 7ddfaf8f6897..4971a027ef1e 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -367,6 +367,7 @@ impl DefaultPhysicalPlanner { JoinType::Left => hash_utils::JoinType::Left, JoinType::Right => hash_utils::JoinType::Right, JoinType::Full => hash_utils::JoinType::Full, + JoinType::Semi => hash_utils::JoinType::Semi, }; if ctx_state.config.concurrency > 1 && ctx_state.config.repartition_joins {