diff --git a/ballista/rust/core/proto/ballista.proto b/ballista/rust/core/proto/ballista.proto index 03872147b797..8d5a9df0fc08 100644 --- a/ballista/rust/core/proto/ballista.proto +++ b/ballista/rust/core/proto/ballista.proto @@ -364,6 +364,7 @@ enum JoinType { RIGHT = 2; FULL = 3; SEMI = 4; + ANTI = 5; } message JoinNode { diff --git a/ballista/rust/core/src/serde/logical_plan/from_proto.rs b/ballista/rust/core/src/serde/logical_plan/from_proto.rs index 48471263885f..ca201a7db7b0 100644 --- a/ballista/rust/core/src/serde/logical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/from_proto.rs @@ -266,6 +266,7 @@ impl TryInto for &protobuf::LogicalPlanNode { protobuf::JoinType::Right => JoinType::Right, protobuf::JoinType::Full => JoinType::Full, protobuf::JoinType::Semi => JoinType::Semi, + protobuf::JoinType::Anti => JoinType::Anti, }; LogicalPlanBuilder::from(&convert_box_required!(join.left)?) .join( diff --git a/ballista/rust/core/src/serde/logical_plan/to_proto.rs b/ballista/rust/core/src/serde/logical_plan/to_proto.rs index e1c0c5e44df6..1cd886b175cf 100644 --- a/ballista/rust/core/src/serde/logical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/logical_plan/to_proto.rs @@ -835,6 +835,7 @@ impl TryInto for &LogicalPlan { JoinType::Right => protobuf::JoinType::Right, JoinType::Full => protobuf::JoinType::Full, JoinType::Semi => protobuf::JoinType::Semi, + JoinType::Anti => protobuf::JoinType::Anti, }; let left_join_column = on.iter().map(|on| on.0.to_owned()).collect(); let right_join_column = on.iter().map(|on| on.1.to_owned()).collect(); diff --git a/ballista/rust/core/src/serde/physical_plan/from_proto.rs b/ballista/rust/core/src/serde/physical_plan/from_proto.rs index 7f98a8378b0b..89307027d701 100644 --- a/ballista/rust/core/src/serde/physical_plan/from_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/from_proto.rs @@ -380,6 +380,7 @@ impl TryInto> for &protobuf::PhysicalPlanNode { protobuf::JoinType::Right => JoinType::Right, protobuf::JoinType::Full => JoinType::Full, protobuf::JoinType::Semi => JoinType::Semi, + protobuf::JoinType::Anti => JoinType::Anti, }; Ok(Arc::new(HashJoinExec::try_new( left, diff --git a/ballista/rust/core/src/serde/physical_plan/to_proto.rs b/ballista/rust/core/src/serde/physical_plan/to_proto.rs index c409f9474951..26092e74a096 100644 --- a/ballista/rust/core/src/serde/physical_plan/to_proto.rs +++ b/ballista/rust/core/src/serde/physical_plan/to_proto.rs @@ -134,6 +134,7 @@ impl TryInto for Arc { JoinType::Right => protobuf::JoinType::Right, JoinType::Full => protobuf::JoinType::Full, JoinType::Semi => protobuf::JoinType::Semi, + JoinType::Anti => protobuf::JoinType::Anti, }; Ok(protobuf::PhysicalPlanNode { physical_plan_type: Some(PhysicalPlanType::HashJoin(Box::new( diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs index 71de48cdb8f8..fe4ee65fad8f 100644 --- a/datafusion/src/logical_plan/builder.rs +++ b/datafusion/src/logical_plan/builder.rs @@ -410,7 +410,7 @@ fn build_join_schema( // left then right left_fields.chain(right_fields).cloned().collect() } - JoinType::Semi => { + JoinType::Semi | JoinType::Anti => { // Only use the left side for the schema left.fields().clone() } diff --git a/datafusion/src/logical_plan/plan.rs b/datafusion/src/logical_plan/plan.rs index 5cb94be405e7..5391e76e7576 100644 --- a/datafusion/src/logical_plan/plan.rs +++ b/datafusion/src/logical_plan/plan.rs @@ -46,6 +46,8 @@ pub enum JoinType { Full, /// Semi Join Semi, + /// Anti Join + Anti, } /// A LogicalPlan represents the different types of relational diff --git a/datafusion/src/optimizer/hash_build_probe_order.rs b/datafusion/src/optimizer/hash_build_probe_order.rs index 86d38ef313ce..74d2b0090194 100644 --- a/datafusion/src/optimizer/hash_build_probe_order.rs +++ b/datafusion/src/optimizer/hash_build_probe_order.rs @@ -109,7 +109,7 @@ fn should_swap_join_order(left: &LogicalPlan, right: &LogicalPlan) -> bool { fn supports_swap(join_type: JoinType) -> bool { match join_type { JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => true, - JoinType::Semi => false, + JoinType::Semi | JoinType::Anti => false, } } diff --git a/datafusion/src/physical_plan/hash_join.rs b/datafusion/src/physical_plan/hash_join.rs index 6653b9a356a4..d12e249cbe34 100644 --- a/datafusion/src/physical_plan/hash_join.rs +++ b/datafusion/src/physical_plan/hash_join.rs @@ -184,9 +184,11 @@ impl HashJoinExec { /// Calculates column indices and left/right placement on input / output schemas and jointype fn column_indices_from_schema(&self) -> ArrowResult> { let (primary_is_left, primary_schema, secondary_schema) = match self.join_type { - JoinType::Inner | JoinType::Left | JoinType::Full | JoinType::Semi => { - (true, self.left.schema(), self.right.schema()) - } + JoinType::Inner + | JoinType::Left + | JoinType::Full + | JoinType::Semi + | JoinType::Anti => (true, self.left.schema(), self.right.schema()), JoinType::Right => (false, self.right.schema(), self.left.schema()), }; let mut column_indices = Vec::with_capacity(self.schema.fields().len()); @@ -376,7 +378,9 @@ impl ExecutionPlan for HashJoinExec { let column_indices = self.column_indices_from_schema()?; let num_rows = left_data.1.num_rows(); let visited_left_side = match self.join_type { - JoinType::Left | JoinType::Full | JoinType::Semi => vec![false; num_rows], + JoinType::Left | JoinType::Full | JoinType::Semi | JoinType::Anti => { + vec![false; num_rows] + } JoinType::Inner | JoinType::Right => vec![], }; Ok(Box::pin(HashJoinStream { @@ -544,7 +548,7 @@ fn build_batch( ) .unwrap(); - if join_type == JoinType::Semi { + if matches!(join_type, JoinType::Semi | JoinType::Anti) { return Ok(( RecordBatch::new_empty(Arc::new(schema.clone())), left_indices, @@ -613,7 +617,7 @@ fn build_join_indexes( let left = &left_data.0; match join_type { - JoinType::Inner | JoinType::Semi => { + JoinType::Inner | JoinType::Semi | JoinType::Anti => { // Using a buffer builder to avoid slower normal builder let mut left_indices = UInt64BufferBuilder::new(0); let mut right_indices = UInt32BufferBuilder::new(0); @@ -1190,7 +1194,10 @@ impl Stream for HashJoinStream { self.num_output_rows += batch.num_rows(); match self.join_type { - JoinType::Left | JoinType::Full | JoinType::Semi => { + JoinType::Left + | JoinType::Full + | JoinType::Semi + | JoinType::Anti => { left_side.iter().flatten().for_each(|x| { self.visited_left_side[x as usize] = true; }); @@ -1204,7 +1211,10 @@ impl Stream for HashJoinStream { let start = Instant::now(); // For the left join, produce rows for unmatched rows match self.join_type { - JoinType::Left | JoinType::Full | JoinType::Semi + JoinType::Left + | JoinType::Full + | JoinType::Semi + | JoinType::Anti if !self.is_exhausted => { let result = produce_from_matched( @@ -1230,6 +1240,7 @@ impl Stream for HashJoinStream { JoinType::Left | JoinType::Full | JoinType::Semi + | JoinType::Anti | JoinType::Inner | JoinType::Right => {} } @@ -1725,6 +1736,40 @@ mod tests { Ok(()) } + #[tokio::test] + async fn join_anti() -> Result<()> { + let left = build_table( + ("a1", &vec![1, 2, 2, 3, 5]), + ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9, 11]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30, 40]), + ("b1", &vec![4, 5, 6, 5]), // 5 is double on the right + ("c2", &vec![70, 80, 90, 100]), + ); + let on = &[("b1", "b1")]; + + let join = join(left, right, on, &JoinType::Anti)?; + + let columns = columns(&join.schema()); + assert_eq!(columns, vec!["a1", "b1", "c1"]); + + let stream = join.execute(0).await?; + let batches = common::collect(stream).await?; + + let expected = vec![ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 3 | 7 | 9 |", + "| 5 | 7 | 11 |", + "+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + #[tokio::test] async fn join_right_one() -> Result<()> { let left = build_table( diff --git a/datafusion/src/physical_plan/hash_utils.rs b/datafusion/src/physical_plan/hash_utils.rs index 110319e4bb6b..a48710bfbfc3 100644 --- a/datafusion/src/physical_plan/hash_utils.rs +++ b/datafusion/src/physical_plan/hash_utils.rs @@ -34,6 +34,8 @@ pub enum JoinType { Full, /// Semi Join Semi, + /// Anti Join + Anti, } /// The on clause of the join, as vector of (left, right) columns. @@ -132,7 +134,7 @@ pub fn build_join_schema( // left then right left_fields.chain(right_fields).cloned().collect() } - JoinType::Semi => left.fields().clone(), + JoinType::Semi | JoinType::Anti => left.fields().clone(), }; Schema::new(fields) } diff --git a/datafusion/src/physical_plan/planner.rs b/datafusion/src/physical_plan/planner.rs index 4971a027ef1e..9d86f67cb2e1 100644 --- a/datafusion/src/physical_plan/planner.rs +++ b/datafusion/src/physical_plan/planner.rs @@ -368,6 +368,7 @@ impl DefaultPhysicalPlanner { JoinType::Right => hash_utils::JoinType::Right, JoinType::Full => hash_utils::JoinType::Full, JoinType::Semi => hash_utils::JoinType::Semi, + JoinType::Anti => hash_utils::JoinType::Anti, }; if ctx_state.config.concurrency > 1 && ctx_state.config.repartition_joins {