From f930c78fc812189108b7a14bf7c1a0dab8b70be1 Mon Sep 17 00:00:00 2001 From: clarkzinzow Date: Fri, 11 Aug 2023 11:35:18 -0700 Subject: [PATCH] Add support for joins. --- daft/dataframe/dataframe.py | 10 ++- daft/execution/execution_step.py | 4 +- daft/execution/physical_plan.py | 3 +- daft/execution/rust_physical_plan_shim.py | 23 +++++- daft/logical/builder.py | 17 ++--- daft/logical/logical_plan.py | 20 +++-- daft/logical/rust_logical_plan.py | 40 ++++++++-- daft/table/table.py | 5 +- src/daft-plan/src/builder.rs | 37 ++++++++- src/daft-plan/src/join.rs | 91 +++++++++++++++++++++++ src/daft-plan/src/lib.rs | 3 + src/daft-plan/src/logical_plan.rs | 28 ++++++- src/daft-plan/src/ops/join.rs | 56 ++++++++++++++ src/daft-plan/src/ops/mod.rs | 2 + src/daft-plan/src/physical_ops/join.rs | 36 +++++++++ src/daft-plan/src/physical_ops/mod.rs | 2 + src/daft-plan/src/physical_plan.rs | 36 +++++++++ src/daft-plan/src/planner.rs | 60 ++++++++++++++- tests/cookbook/test_joins.py | 8 +- tests/dataframe/test_joins.py | 12 +-- tests/table/test_joins.py | 6 +- 21 files changed, 449 insertions(+), 50 deletions(-) create mode 100644 src/daft-plan/src/join.rs create mode 100644 src/daft-plan/src/ops/join.rs create mode 100644 src/daft-plan/src/physical_ops/join.rs diff --git a/daft/dataframe/dataframe.py b/daft/dataframe/dataframe.py index d4e682e175..384a486e9c 100644 --- a/daft/dataframe/dataframe.py +++ b/daft/dataframe/dataframe.py @@ -24,12 +24,12 @@ from daft.api_annotations import DataframePublicAPI from daft.context import get_context from daft.convert import InputListType -from daft.daft import FileFormat, PartitionScheme, PartitionSpec +from daft.daft import FileFormat, JoinType, PartitionScheme, PartitionSpec from daft.dataframe.preview import DataFramePreview from daft.datatype import DataType from daft.errors import ExpressionTypeError from daft.expressions import Expression, ExpressionsProjection, col, lit -from daft.logical.builder import JoinType, LogicalPlanBuilder +from daft.logical.builder import LogicalPlanBuilder from daft.resource_request import ResourceRequest from daft.runners.partitioning import PartitionCacheEntry, PartitionSet from daft.runners.pyrunner import LocalPartitionSet @@ -700,11 +700,13 @@ def join( raise ValueError("If `on` is not None then both `left_on` and `right_on` must be None") left_on = on right_on = on - assert how == "inner", "only inner joins are currently supported" + join_type = JoinType.from_join_type_str(how) + if join_type != JoinType.Inner: + raise ValueError(f"Only inner joins are currently supported, but got: {how}") left_exprs = self.__column_input_to_expression(tuple(left_on) if isinstance(left_on, list) else (left_on,)) right_exprs = self.__column_input_to_expression(tuple(right_on) if isinstance(right_on, list) else (right_on,)) - builder = self._builder.join(other._builder, left_on=left_exprs, right_on=right_exprs, how=JoinType.INNER) + builder = self._builder.join(other._builder, left_on=left_exprs, right_on=right_exprs, how=join_type) return DataFrame(builder) @DataframePublicAPI diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 005d4a0ec7..deadc561f7 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -18,11 +18,11 @@ CsvSourceConfig, FileFormat, FileFormatConfig, + JoinType, JsonSourceConfig, ParquetSourceConfig, ) from daft.expressions import Expression, ExpressionsProjection, col -from daft.logical.builder import JoinType from daft.logical.map_partition_ops import 
MapPartitionOp from daft.logical.schema import Schema from daft.resource_request import ResourceRequest @@ -654,7 +654,7 @@ def _join(self, inputs: list[Table]) -> list[Table]: left_on=self.left_on, right_on=self.right_on, output_projection=self.output_projection, - how=self.how.value, + how=self.how, ) return [result] diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 7e2176d563..e90985af26 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -23,7 +23,7 @@ from loguru import logger -from daft.daft import FileFormat, FileFormatConfig +from daft.daft import FileFormat, FileFormatConfig, JoinType from daft.execution import execution_step from daft.execution.execution_step import ( Instruction, @@ -34,7 +34,6 @@ SingleOutputPartitionTask, ) from daft.expressions import ExpressionsProjection -from daft.logical.builder import JoinType from daft.logical.schema import Schema from daft.resource_request import ResourceRequest from daft.runners.partitioning import PartialPartitionMetadata diff --git a/daft/execution/rust_physical_plan_shim.py b/daft/execution/rust_physical_plan_shim.py index bc5fa03cf6..690117802f 100644 --- a/daft/execution/rust_physical_plan_shim.py +++ b/daft/execution/rust_physical_plan_shim.py @@ -3,7 +3,7 @@ from typing import Iterator, TypeVar, cast from daft.context import get_context -from daft.daft import FileFormat, FileFormatConfig, PyExpr, PySchema, PyTable +from daft.daft import FileFormat, FileFormatConfig, JoinType, PyExpr, PySchema, PyTable from daft.execution import execution_step, physical_plan from daft.expressions import Expression, ExpressionsProjection from daft.logical.map_partition_ops import MapPartitionOp @@ -117,6 +117,27 @@ def reduce_merge( return physical_plan.reduce(input, reduce_instruction) +def join( + input: physical_plan.InProgressPhysicalPlan[PartitionT], + right: physical_plan.InProgressPhysicalPlan[PartitionT], + left_on: list[PyExpr], + right_on: list[PyExpr], + output_projection: list[PyExpr], + join_type: JoinType, +) -> physical_plan.InProgressPhysicalPlan[PartitionT]: + left_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in left_on]) + right_on_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in right_on]) + output_expr_proj = ExpressionsProjection([Expression._from_pyexpr(expr) for expr in output_projection]) + return physical_plan.join( + left_plan=input, + right_plan=right, + left_on=left_on_expr_proj, + right_on=right_on_expr_proj, + output_projection=output_expr_proj, + how=join_type, + ) + + def write_file( input: physical_plan.InProgressPhysicalPlan[PartitionT], file_format: FileFormat, diff --git a/daft/logical/builder.py b/daft/logical/builder.py index 644d49aa85..f1e085de28 100644 --- a/daft/logical/builder.py +++ b/daft/logical/builder.py @@ -2,12 +2,17 @@ import pathlib from abc import ABC, abstractmethod -from enum import Enum from typing import TYPE_CHECKING import fsspec -from daft.daft import FileFormat, FileFormatConfig, PartitionScheme, PartitionSpec +from daft.daft import ( + FileFormat, + FileFormatConfig, + JoinType, + PartitionScheme, + PartitionSpec, +) from daft.expressions.expressions import Expression, ExpressionsProjection from daft.logical.schema import Schema from daft.resource_request import ResourceRequest @@ -17,12 +22,6 @@ from daft.planner import QueryPlanner -class JoinType(Enum): - INNER = "inner" - LEFT = "left" - RIGHT = "right" - - class LogicalPlanBuilder(ABC): """ An interface 
for building a logical plan for the Daft DataFrame. @@ -149,7 +148,7 @@ def join( right: LogicalPlanBuilder, left_on: ExpressionsProjection, right_on: ExpressionsProjection, - how: JoinType = JoinType.INNER, + how: JoinType = JoinType.Inner, ) -> LogicalPlanBuilder: pass diff --git a/daft/logical/logical_plan.py b/daft/logical/logical_plan.py index 0e06341c57..9f2fcc4616 100644 --- a/daft/logical/logical_plan.py +++ b/daft/logical/logical_plan.py @@ -10,14 +10,20 @@ import fsspec from daft.context import get_context -from daft.daft import FileFormat, FileFormatConfig, PartitionScheme, PartitionSpec +from daft.daft import ( + FileFormat, + FileFormatConfig, + JoinType, + PartitionScheme, + PartitionSpec, +) from daft.datatype import DataType from daft.errors import ExpressionTypeError from daft.expressions import Expression, ExpressionsProjection, col from daft.expressions.testing import expr_structurally_equal from daft.internal.treenode import TreeNode from daft.logical.aggregation_plan_builder import AggregationPlanBuilder -from daft.logical.builder import JoinType, LogicalPlanBuilder +from daft.logical.builder import LogicalPlanBuilder from daft.logical.map_partition_ops import ExplodeOp, MapPartitionOp from daft.logical.schema import Schema from daft.resource_request import ResourceRequest @@ -198,7 +204,7 @@ def join( # type: ignore[override] right: PyLogicalPlanBuilder, left_on: ExpressionsProjection, right_on: ExpressionsProjection, - how: JoinType = JoinType.INNER, + how: JoinType = JoinType.Inner, ) -> PyLogicalPlanBuilder: return Join( self._plan, @@ -1146,7 +1152,7 @@ def __init__( right: LogicalPlan, left_on: ExpressionsProjection, right_on: ExpressionsProjection, - how: JoinType = JoinType.INNER, + how: JoinType = JoinType.Inner, ) -> None: assert len(left_on) == len(right_on), "left_on and right_on must match size" @@ -1165,13 +1171,13 @@ def __init__( self._how = how output_schema: Schema - if how == JoinType.LEFT: + if how == JoinType.Left: num_partitions = left.num_partitions() raise NotImplementedError() - elif how == JoinType.RIGHT: + elif how == JoinType.Right: num_partitions = right.num_partitions() raise NotImplementedError() - elif how == JoinType.INNER: + elif how == JoinType.Inner: num_partitions = max(left.num_partitions(), right.num_partitions()) right_drop_set = {r.name() for l, r in zip(left_on, right_on) if l.name() == r.name()} left_columns = ExpressionsProjection.from_schema(left.schema()) diff --git a/daft/logical/rust_logical_plan.py b/daft/logical/rust_logical_plan.py index 9f723949df..f48db728d1 100644 --- a/daft/logical/rust_logical_plan.py +++ b/daft/logical/rust_logical_plan.py @@ -5,14 +5,14 @@ import fsspec -from daft import DataType +from daft import DataType, col from daft.context import get_context -from daft.daft import FileFormat, FileFormatConfig +from daft.daft import FileFormat, FileFormatConfig, JoinType from daft.daft import LogicalPlanBuilder as _LogicalPlanBuilder from daft.daft import PartitionScheme, PartitionSpec from daft.errors import ExpressionTypeError from daft.expressions.expressions import Expression, ExpressionsProjection -from daft.logical.builder import JoinType, LogicalPlanBuilder +from daft.logical.builder import LogicalPlanBuilder from daft.logical.schema import Schema from daft.resource_request import ResourceRequest from daft.runners.partitioning import PartitionCacheEntry @@ -192,9 +192,39 @@ def join( # type: ignore[override] right: RustLogicalPlanBuilder, left_on: ExpressionsProjection, right_on: 
ExpressionsProjection,
-        how: JoinType = JoinType.INNER,
+        how: JoinType = JoinType.Inner,
     ) -> RustLogicalPlanBuilder:
-        raise NotImplementedError("not implemented")
+        for schema, exprs in ((self.schema(), left_on), (right.schema(), right_on)):
+            resolved_schema = exprs.resolve_schema(schema)
+            for f, expr in zip(resolved_schema, exprs):
+                if f.dtype == DataType.null():
+                    raise ExpressionTypeError(f"Cannot join on null type expression: {expr}")
+        if how == JoinType.Left:
+            raise NotImplementedError("Left join not implemented.")
+        elif how == JoinType.Right:
+            raise NotImplementedError("Right join not implemented.")
+        elif how == JoinType.Inner:
+            # TODO(Clark): Port this logic to Rust-side once ExpressionsProjection has been ported.
+            right_drop_set = {r.name() for l, r in zip(left_on, right_on) if l.name() == r.name()}
+            left_columns = ExpressionsProjection.from_schema(self.schema())
+            right_columns = ExpressionsProjection([col(f.name) for f in right.schema() if f.name not in right_drop_set])
+            output_projection = left_columns.union(right_columns, rename_dup="right.")
+            left_columns = left_columns
+            right_columns = ExpressionsProjection(list(output_projection)[len(left_columns) :])
+            output_schema = left_columns.resolve_schema(self.schema()).union(
+                right_columns.resolve_schema(right.schema())
+            )
+            builder = self._builder.join(
+                right._builder,
+                left_on.to_inner_py_exprs(),
+                right_on.to_inner_py_exprs(),
+                output_projection.to_inner_py_exprs(),
+                output_schema._schema,
+                how,
+            )
+            return RustLogicalPlanBuilder(builder)
+        else:
+            raise NotImplementedError(f"{how} join not implemented.")
 
     def concat(self, other: RustLogicalPlanBuilder) -> RustLogicalPlanBuilder:  # type: ignore[override]
         builder = self._builder.concat(other._builder)
diff --git a/daft/table/table.py b/daft/table/table.py
index 78eec94404..acc80cb54f 100644
--- a/daft/table/table.py
+++ b/daft/table/table.py
@@ -7,6 +7,7 @@
 from loguru import logger
 
 from daft.arrow_utils import ensure_table
+from daft.daft import JoinType
 from daft.daft import PyTable as _PyTable
 from daft.daft import read_parquet as _read_parquet
 from daft.daft import read_parquet_bulk as _read_parquet_bulk
@@ -284,9 +285,9 @@ def join(
         left_on: ExpressionsProjection,
         right_on: ExpressionsProjection,
         output_projection: ExpressionsProjection | None = None,
-        how: str = "inner",
+        how: JoinType = JoinType.Inner,
     ) -> Table:
-        if how != "inner":
+        if how != JoinType.Inner:
             raise NotImplementedError("TODO: [RUST] Implement Other Join types")
         if len(left_on) != len(right_on):
             raise ValueError(
diff --git a/src/daft-plan/src/builder.rs b/src/daft-plan/src/builder.rs
index ea4cece56b..9e756ad3a2 100644
--- a/src/daft-plan/src/builder.rs
+++ b/src/daft-plan/src/builder.rs
@@ -1,6 +1,6 @@
 use std::sync::Arc;
 
-use crate::logical_plan::LogicalPlan;
+use crate::{logical_plan::LogicalPlan, JoinType};
 
 #[cfg(feature = "python")]
 use {
@@ -193,6 +193,41 @@ impl LogicalPlanBuilder {
         Ok(logical_plan_builder)
     }
 
+    pub fn join(
+        &self,
+        other: &Self,
+        left_on: Vec<PyExpr>,
+        right_on: Vec<PyExpr>,
+        output_projection: Vec<PyExpr>,
+        output_schema: &PySchema,
+        join_type: JoinType,
+    ) -> PyResult<Self> {
+        let left_on_exprs = left_on
+            .iter()
+            .map(|e| e.clone().into())
+            .collect::<Vec<Expr>>();
+        let right_on_exprs = right_on
+            .iter()
+            .map(|e| e.clone().into())
+            .collect::<Vec<Expr>>();
+        let output_projection_exprs = output_projection
+            .iter()
+            .map(|e| e.clone().into())
+            .collect::<Vec<Expr>>();
+        let logical_plan: LogicalPlan = ops::Join::new(
+            other.plan.clone(),
+            left_on_exprs,
+            right_on_exprs,
+            output_projection_exprs,
+            output_schema.clone().into(),
+            join_type,
+            self.plan.clone(),
+        )
+        .into();
+        let logical_plan_builder = LogicalPlanBuilder::new(logical_plan.into());
+        Ok(logical_plan_builder)
+    }
+
     pub fn concat(&self, other: &Self) -> PyResult<Self> {
         let self_schema = self.plan.schema();
         let other_schema = other.plan.schema();
diff --git a/src/daft-plan/src/join.rs b/src/daft-plan/src/join.rs
new file mode 100644
index 0000000000..ef049d6ff1
--- /dev/null
+++ b/src/daft-plan/src/join.rs
@@ -0,0 +1,91 @@
+use std::{
+    fmt::{Display, Formatter, Result},
+    str::FromStr,
+};
+
+use common_error::{DaftError, DaftResult};
+use daft_core::impl_bincode_py_state_serialization;
+#[cfg(feature = "python")]
+use pyo3::{
+    exceptions::PyValueError,
+    pyclass, pymethods,
+    types::{PyBytes, PyTuple},
+    PyResult, Python,
+};
+
+use serde::{Deserialize, Serialize};
+
+#[derive(Clone, Copy, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)]
+#[cfg_attr(feature = "python", pyclass(module = "daft.daft"))]
+pub enum JoinType {
+    Inner,
+    Left,
+    Right,
+}
+
+#[cfg(feature = "python")]
+#[pymethods]
+impl JoinType {
+    #[new]
+    #[pyo3(signature = (*args))]
+    pub fn new(args: &PyTuple) -> PyResult<Self> {
+        match args.len() {
+            // Create dummy variant, to be overridden by __setstate__.
+            0 => Ok(Self::Inner),
+            _ => Err(PyValueError::new_err(format!(
+                "expected no arguments to make new JoinType, got : {}",
+                args.len()
+            ))),
+        }
+    }
+
+    /// Create a JoinType from its string representation.
+    ///
+    /// Args:
+    ///     join_type: String representation of the join type, e.g. "inner", "left", or "right".
+    #[staticmethod]
+    pub fn from_join_type_str(join_type: &str) -> PyResult<Self> {
+        Self::from_str(join_type).map_err(|e| PyValueError::new_err(e.to_string()))
+    }
+
+    pub fn __str__(&self) -> PyResult<String> {
+        Ok(self.to_string())
+    }
+}
+
+impl_bincode_py_state_serialization!(JoinType);
+
+impl JoinType {
+    pub fn iterator() -> std::slice::Iter<'static, JoinType> {
+        use JoinType::*;
+
+        static JOIN_TYPES: [JoinType; 3] = [Inner, Left, Right];
+        JOIN_TYPES.iter()
+    }
+}
+
+impl FromStr for JoinType {
+    type Err = DaftError;
+
+    fn from_str(join_type: &str) -> DaftResult<Self> {
+        use JoinType::*;
+
+        match join_type {
+            "inner" => Ok(Inner),
+            "left" => Ok(Left),
+            "right" => Ok(Right),
+            _ => Err(DaftError::TypeError(format!(
+                "Join type {} is not supported; only the following join types are supported: {:?}",
+                join_type,
+                JoinType::iterator().as_slice()
+            ))),
+        }
+    }
+}
+
+impl Display for JoinType {
+    fn fmt(&self, f: &mut Formatter) -> Result {
+        // Leverage Debug trait implementation, which will already return the enum variant as a string.
+        write!(f, "{:?}", self)
+    }
+}
diff --git a/src/daft-plan/src/lib.rs b/src/daft-plan/src/lib.rs
index 41f34ac156..aa1bbc1603 100644
--- a/src/daft-plan/src/lib.rs
+++ b/src/daft-plan/src/lib.rs
@@ -1,5 +1,6 @@
 mod builder;
 mod display;
+mod join;
 mod logical_plan;
 mod ops;
 mod partitioning;
@@ -10,6 +11,7 @@ mod sink_info;
 mod source_info;
 
 pub use builder::LogicalPlanBuilder;
+pub use join::JoinType;
 pub use logical_plan::LogicalPlan;
 pub use partitioning::{PartitionScheme, PartitionSpec};
 pub use source_info::{
@@ -29,6 +31,7 @@ pub fn register_modules(_py: Python, parent: &PyModule) -> PyResult<()> {
     parent.add_class::()?;
     parent.add_class::()?;
     parent.add_class::()?;
+    parent.add_class::<JoinType>()?;
     Ok(())
 }
diff --git a/src/daft-plan/src/logical_plan.rs b/src/daft-plan/src/logical_plan.rs
index 6cceaf312e..ff2ebddfb2 100644
--- a/src/daft-plan/src/logical_plan.rs
+++ b/src/daft-plan/src/logical_plan.rs
@@ -1,4 +1,4 @@
-use std::sync::Arc;
+use std::{cmp::max, sync::Arc};
 
 use daft_core::schema::SchemaRef;
 
@@ -17,6 +17,7 @@ pub enum LogicalPlan {
     Distinct(Distinct),
     Aggregate(Aggregate),
     Concat(Concat),
+    Join(Join),
     Sink(Sink),
 }
 
@@ -38,6 +39,7 @@ impl LogicalPlan {
             Self::Distinct(Distinct { input, .. }) => input.schema(),
             Self::Aggregate(aggregate) => aggregate.schema(),
             Self::Concat(Concat { input, .. }) => input.schema(),
+            Self::Join(Join { output_schema, .. }) => output_schema.clone(),
             Self::Sink(Sink { schema, .. }) => schema.clone(),
         }
     }
@@ -77,6 +79,27 @@ impl LogicalPlan {
                 None,
             )
            .into(),
+            Self::Join(Join {
+                input,
+                right,
+                left_on,
+                ..
+            }) => match max(
+                input.partition_spec().num_partitions,
+                right.partition_spec().num_partitions,
+            ) {
+                // NOTE: This duplicates the repartitioning logic in the planner, where we
+                // conditionally repartition the left and right tables.
+                // TODO(Clark): Consolidate this logic with the planner logic when we push the partition spec
+                // to be an entirely planner-side concept.
+                1 => input.partition_spec(),
+                num_partitions => PartitionSpec::new_internal(
+                    PartitionScheme::Hash,
+                    num_partitions,
+                    Some(left_on.clone()),
+                )
+                .into(),
+            },
             Self::Sink(Sink { input, .. }) => input.partition_spec(),
         }
     }
@@ -94,6 +117,7 @@
             Self::Distinct(Distinct { input, .. }) => vec![input],
             Self::Aggregate(Aggregate { input, .. }) => vec![input],
             Self::Concat(Concat { input, other }) => vec![input, other],
+            Self::Join(Join { input, right, .. }) => vec![input, right],
             Self::Sink(Sink { input, ..
}) => vec![input],
         }
     }
@@ -113,6 +137,7 @@ impl LogicalPlan {
             Self::Distinct(_) => vec!["Distinct".to_string()],
             Self::Aggregate(aggregate) => aggregate.multiline_display(),
             Self::Concat(_) => vec!["Concat".to_string()],
+            Self::Join(join) => join.multiline_display(),
             Self::Sink(sink) => sink.multiline_display(),
         }
     }
@@ -145,4 +170,5 @@ impl_from_data_struct_for_logical_plan!(Coalesce);
 impl_from_data_struct_for_logical_plan!(Distinct);
 impl_from_data_struct_for_logical_plan!(Aggregate);
 impl_from_data_struct_for_logical_plan!(Concat);
+impl_from_data_struct_for_logical_plan!(Join);
 impl_from_data_struct_for_logical_plan!(Sink);
diff --git a/src/daft-plan/src/ops/join.rs b/src/daft-plan/src/ops/join.rs
new file mode 100644
index 0000000000..f8f1f69867
--- /dev/null
+++ b/src/daft-plan/src/ops/join.rs
@@ -0,0 +1,56 @@
+use std::sync::Arc;
+
+use daft_core::schema::SchemaRef;
+use daft_dsl::Expr;
+
+use crate::{JoinType, LogicalPlan};
+
+#[derive(Clone, Debug)]
+pub struct Join {
+    pub right: Arc<LogicalPlan>,
+    pub left_on: Vec<Expr>,
+    pub right_on: Vec<Expr>,
+    pub output_projection: Vec<Expr>,
+    pub output_schema: SchemaRef,
+    pub join_type: JoinType,
+    // Upstream node.
+    pub input: Arc<LogicalPlan>,
+}
+
+impl Join {
+    pub(crate) fn new(
+        right: Arc<LogicalPlan>,
+        left_on: Vec<Expr>,
+        right_on: Vec<Expr>,
+        output_projection: Vec<Expr>,
+        output_schema: SchemaRef,
+        join_type: JoinType,
+        input: Arc<LogicalPlan>,
+    ) -> Self {
+        Self {
+            right,
+            left_on,
+            right_on,
+            output_projection,
+            output_schema,
+            join_type,
+            input,
+        }
+    }
+
+    pub fn multiline_display(&self) -> Vec<String> {
+        let mut res = vec![];
+        res.push(format!("Join ({}):", self.join_type));
+        if !self.left_on.is_empty() {
+            res.push(format!(" Left on: {:?}", self.left_on));
+        }
+        if !self.right_on.is_empty() {
+            res.push(format!(" Right on: {:?}", self.right_on));
+        }
+        res.push(format!(
+            " Output schema: {}",
+            self.output_schema.short_string()
+        ));
+        res
+    }
+}
diff --git a/src/daft-plan/src/ops/mod.rs b/src/daft-plan/src/ops/mod.rs
index c140e9bcf8..ad8e4130cf 100644
--- a/src/daft-plan/src/ops/mod.rs
+++ b/src/daft-plan/src/ops/mod.rs
@@ -4,6 +4,7 @@ mod concat;
 mod distinct;
 mod explode;
 mod filter;
+mod join;
 mod limit;
 mod project;
 mod repartition;
@@ -17,6 +18,7 @@ pub use concat::Concat;
 pub use distinct::Distinct;
 pub use explode::Explode;
 pub use filter::Filter;
+pub use join::Join;
 pub use limit::Limit;
 pub use project::Project;
 pub use repartition::Repartition;
diff --git a/src/daft-plan/src/physical_ops/join.rs b/src/daft-plan/src/physical_ops/join.rs
new file mode 100644
index 0000000000..2e4896244e
--- /dev/null
+++ b/src/daft-plan/src/physical_ops/join.rs
@@ -0,0 +1,36 @@
+use std::sync::Arc;
+
+use daft_dsl::Expr;
+
+use crate::{physical_plan::PhysicalPlan, JoinType};
+
+#[derive(Clone, Debug)]
+pub struct Join {
+    pub right: Arc<PhysicalPlan>,
+    pub left_on: Vec<Expr>,
+    pub right_on: Vec<Expr>,
+    pub output_projection: Vec<Expr>,
+    pub join_type: JoinType,
+    // Upstream node.
+    pub input: Arc<PhysicalPlan>,
+}
+
+impl Join {
+    pub(crate) fn new(
+        right: Arc<PhysicalPlan>,
+        left_on: Vec<Expr>,
+        right_on: Vec<Expr>,
+        output_projection: Vec<Expr>,
+        join_type: JoinType,
+        input: Arc<PhysicalPlan>,
+    ) -> Self {
+        Self {
+            right,
+            left_on,
+            right_on,
+            output_projection,
+            join_type,
+            input,
+        }
+    }
+}
diff --git a/src/daft-plan/src/physical_ops/mod.rs b/src/daft-plan/src/physical_ops/mod.rs
index cbce384e61..d1e12d6a69 100644
--- a/src/daft-plan/src/physical_ops/mod.rs
+++ b/src/daft-plan/src/physical_ops/mod.rs
@@ -8,6 +8,7 @@ mod filter;
 mod flatten;
 #[cfg(feature = "python")]
 mod in_memory;
+mod join;
 mod json;
 mod limit;
 mod parquet;
@@ -26,6 +27,7 @@ pub use filter::Filter;
 pub use flatten::Flatten;
 #[cfg(feature = "python")]
 pub use in_memory::InMemoryScan;
+pub use join::Join;
 pub use json::{TabularScanJson, TabularWriteJson};
 pub use limit::Limit;
 pub use parquet::{TabularScanParquet, TabularWriteParquet};
diff --git a/src/daft-plan/src/physical_plan.rs b/src/daft-plan/src/physical_plan.rs
index 61bb8981a2..2310d8bd34 100644
--- a/src/daft-plan/src/physical_plan.rs
+++ b/src/daft-plan/src/physical_plan.rs
@@ -40,6 +40,7 @@ pub enum PhysicalPlan {
     Aggregate(Aggregate),
     Coalesce(Coalesce),
     Concat(Concat),
+    Join(Join),
     TabularWriteParquet(TabularWriteParquet),
     TabularWriteJson(TabularWriteJson),
     TabularWriteCsv(TabularWriteCsv),
@@ -360,6 +361,41 @@ impl PhysicalPlan {
                     .call1((upstream_input_iter, upstream_other_iter))?;
                 Ok(py_iter.into())
             }
+            PhysicalPlan::Join(Join {
+                right,
+                left_on,
+                right_on,
+                output_projection,
+                join_type,
+                input,
+            }) => {
+                let upstream_input_iter = input.to_partition_tasks(py, psets)?;
+                let upstream_right_iter = right.to_partition_tasks(py, psets)?;
+                let left_on_pyexprs: Vec<PyExpr> = left_on
+                    .iter()
+                    .map(|expr| PyExpr::from(expr.clone()))
+                    .collect();
+                let right_on_pyexprs: Vec<PyExpr> = right_on
+                    .iter()
+                    .map(|expr| PyExpr::from(expr.clone()))
+                    .collect();
+                let output_projection_pyexprs: Vec<PyExpr> = output_projection
+                    .iter()
+                    .map(|expr| PyExpr::from(expr.clone()))
+                    .collect();
+                let py_iter = py
+                    .import(pyo3::intern!(py, "daft.execution.rust_physical_plan_shim"))?
+                    .getattr(pyo3::intern!(py, "join"))?
+ .call1(( + upstream_input_iter, + upstream_right_iter, + left_on_pyexprs, + right_on_pyexprs, + output_projection_pyexprs, + *join_type, + ))?; + Ok(py_iter.into()) + } PhysicalPlan::TabularWriteParquet(TabularWriteParquet { schema, file_info: diff --git a/src/daft-plan/src/planner.rs b/src/daft-plan/src/planner.rs index 14421e6096..bb2bb10d8c 100644 --- a/src/daft-plan/src/planner.rs +++ b/src/daft-plan/src/planner.rs @@ -1,3 +1,4 @@ +use std::cmp::max; use std::sync::Arc; use common_error::DaftResult; @@ -7,13 +8,13 @@ use crate::logical_plan::LogicalPlan; use crate::ops::{ Aggregate as LogicalAggregate, Coalesce as LogicalCoalesce, Concat as LogicalConcat, Distinct as LogicalDistinct, Explode as LogicalExplode, Filter as LogicalFilter, - Limit as LogicalLimit, Project as LogicalProject, Repartition as LogicalRepartition, - Sink as LogicalSink, Sort as LogicalSort, Source, + Join as LogicalJoin, Limit as LogicalLimit, Project as LogicalProject, + Repartition as LogicalRepartition, Sink as LogicalSink, Sort as LogicalSort, Source, }; -use crate::physical_ops::*; use crate::physical_plan::PhysicalPlan; use crate::sink_info::{OutputFileInfo, SinkInfo}; use crate::source_info::{ExternalInfo as ExternalSourceInfo, FileFormatConfig, SourceInfo}; +use crate::{physical_ops::*, PartitionSpec}; use crate::{FileFormat, PartitionScheme}; #[cfg(feature = "python")] @@ -292,6 +293,59 @@ pub fn plan(logical_plan: &LogicalPlan) -> DaftResult { input_physical.into(), ))) } + LogicalPlan::Join(LogicalJoin { + right, + input, + left_on, + right_on, + output_projection, + join_type, + .. + }) => { + let mut left_physical = plan(input)?; + let mut right_physical = plan(right)?; + let left_pspec = input.partition_spec(); + let right_pspec = right.partition_spec(); + let num_partitions = max(left_pspec.num_partitions, right_pspec.num_partitions); + let new_left_pspec = Arc::new(PartitionSpec::new_internal( + PartitionScheme::Hash, + num_partitions, + Some(left_on.clone()), + )); + let new_right_pspec = Arc::new(PartitionSpec::new_internal( + PartitionScheme::Hash, + num_partitions, + Some(right_on.clone()), + )); + if (num_partitions > 1 || left_pspec.num_partitions != num_partitions) + && left_pspec != new_left_pspec + { + let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new( + num_partitions, + left_on.clone(), + left_physical.into(), + )); + left_physical = PhysicalPlan::ReduceMerge(ReduceMerge::new(split_op.into())); + } + if (num_partitions > 1 || right_pspec.num_partitions != num_partitions) + && right_pspec != new_right_pspec + { + let split_op = PhysicalPlan::FanoutByHash(FanoutByHash::new( + num_partitions, + right_on.clone(), + right_physical.into(), + )); + right_physical = PhysicalPlan::ReduceMerge(ReduceMerge::new(split_op.into())); + } + Ok(PhysicalPlan::Join(Join::new( + right_physical.into(), + left_on.clone(), + right_on.clone(), + output_projection.clone(), + *join_type, + left_physical.into(), + ))) + } LogicalPlan::Sink(LogicalSink { schema, sink_info, diff --git a/tests/cookbook/test_joins.py b/tests/cookbook/test_joins.py index 4ec962e328..61a555902f 100644 --- a/tests/cookbook/test_joins.py +++ b/tests/cookbook/test_joins.py @@ -4,7 +4,7 @@ from tests.conftest import assert_df_equals -def test_simple_join(daft_df, service_requests_csv_pd_df, repartition_nparts): +def test_simple_join(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner): daft_df = daft_df.repartition(repartition_nparts) daft_df_left = daft_df.select(col("Unique Key"), col("Borough")) 
daft_df_right = daft_df.select(col("Unique Key"), col("Created Date")) @@ -21,7 +21,7 @@ def test_simple_join(daft_df, service_requests_csv_pd_df, repartition_nparts): assert_df_equals(daft_pd_df, service_requests_csv_pd_df) -def test_simple_self_join(daft_df, service_requests_csv_pd_df, repartition_nparts): +def test_simple_self_join(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner): daft_df = daft_df.repartition(repartition_nparts) daft_df = daft_df.select(col("Unique Key"), col("Borough")) @@ -38,7 +38,7 @@ def test_simple_self_join(daft_df, service_requests_csv_pd_df, repartition_npart assert_df_equals(daft_pd_df, service_requests_csv_pd_df) -def test_simple_join_missing_rvalues(daft_df, service_requests_csv_pd_df, repartition_nparts): +def test_simple_join_missing_rvalues(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner): daft_df_right = daft_df.sort("Unique Key").limit(25).repartition(repartition_nparts) daft_df_left = daft_df.repartition(repartition_nparts) daft_df_left = daft_df_left.select(col("Unique Key"), col("Borough")) @@ -58,7 +58,7 @@ def test_simple_join_missing_rvalues(daft_df, service_requests_csv_pd_df, repart assert_df_equals(daft_pd_df, service_requests_csv_pd_df) -def test_simple_join_missing_lvalues(daft_df, service_requests_csv_pd_df, repartition_nparts): +def test_simple_join_missing_lvalues(daft_df, service_requests_csv_pd_df, repartition_nparts, use_new_planner): daft_df_right = daft_df.repartition(repartition_nparts) daft_df_left = daft_df.sort(col("Unique Key")).limit(25).repartition(repartition_nparts) daft_df_left = daft_df_left.select(col("Unique Key"), col("Borough")) diff --git a/tests/dataframe/test_joins.py b/tests/dataframe/test_joins.py index 55797792ea..933b737c94 100644 --- a/tests/dataframe/test_joins.py +++ b/tests/dataframe/test_joins.py @@ -10,7 +10,7 @@ @pytest.mark.parametrize("n_partitions", [1, 2, 4]) -def test_multicol_joins(n_partitions: int): +def test_multicol_joins(n_partitions: int, use_new_planner): df = daft.from_pydict( { "A": [1, 2, 3], @@ -31,7 +31,7 @@ def test_multicol_joins(n_partitions: int): @pytest.mark.parametrize("n_partitions", [1, 2, 4]) -def test_limit_after_join(n_partitions: int): +def test_limit_after_join(n_partitions: int, use_new_planner): data = { "A": [1, 2, 3], } @@ -50,7 +50,7 @@ def test_limit_after_join(n_partitions: int): @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) -def test_inner_join(repartition_nparts): +def test_inner_join(repartition_nparts, use_new_planner): daft_df = daft.from_pydict( { "id": [1, None, 3], @@ -76,7 +76,7 @@ def test_inner_join(repartition_nparts): @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) -def test_inner_join_multikey(repartition_nparts): +def test_inner_join_multikey(repartition_nparts, use_new_planner): daft_df = daft.from_pydict( { "id": [1, None, None], @@ -105,7 +105,7 @@ def test_inner_join_multikey(repartition_nparts): @pytest.mark.parametrize("repartition_nparts", [1, 2, 4]) -def test_inner_join_all_null(repartition_nparts): +def test_inner_join_all_null(repartition_nparts, use_new_planner): daft_df = daft.from_pydict( { "id": [None, None, None], @@ -130,7 +130,7 @@ def test_inner_join_all_null(repartition_nparts): ) -def test_inner_join_null_type_column(): +def test_inner_join_null_type_column(use_new_planner): daft_df = daft.from_pydict( { "id": [None, None, None], diff --git a/tests/table/test_joins.py b/tests/table/test_joins.py index 0e25f63899..b12c15eb3c 100644 --- a/tests/table/test_joins.py 
+++ b/tests/table/test_joins.py @@ -5,6 +5,7 @@ import pytest from daft import utils +from daft.daft import JoinType from daft.datatype import DataType from daft.expressions import col from daft.series import Series @@ -45,7 +46,7 @@ def test_table_join_single_column(dtype, data) -> None: [col("x").cast(dtype), col("x_ind")] ) right_table = Table.from_pydict({"y": r, "y_ind": list(range(len(r)))}) - result_table = left_table.join(right_table, left_on=[col("x")], right_on=[col("y")], how="inner") + result_table = left_table.join(right_table, left_on=[col("x")], right_on=[col("y")], how=JoinType.Inner) assert result_table.column_names() == ["x", "x_ind", "y", "y_ind"] @@ -60,7 +61,7 @@ def test_table_join_single_column(dtype, data) -> None: assert result_table.get_column("y").to_pylist() == result_r # make sure the result is the same with right table on left - result_table = right_table.join(left_table, right_on=[col("x")], left_on=[col("y")], how="inner") + result_table = right_table.join(left_table, right_on=[col("x")], left_on=[col("y")], how=JoinType.Inner) assert result_table.column_names() == ["y", "y_ind", "x", "x_ind"] @@ -233,7 +234,6 @@ def test_table_join_single_column_name_conflicts_different_named_join() -> None: def test_table_join_single_column_name_multiple_conflicts() -> None: - left_table = Table.from_pydict({"x": [0, 1, 2, 3], "y": [2, 3, 4, 5], "right.y": [6, 7, 8, 9]}) right_table = Table.from_pydict({"x": [3, 2, 1, 0], "y": [10, 11, 12, 13]})