Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEAT] [New Query Planner] Add support for joins. #1260

Merged
merged 1 commit into from
Aug 14, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions daft/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@
from daft.api_annotations import DataframePublicAPI
from daft.context import get_context
from daft.convert import InputListType
from daft.daft import FileFormat, PartitionScheme, PartitionSpec
from daft.daft import FileFormat, JoinType, PartitionScheme, PartitionSpec
from daft.dataframe.preview import DataFramePreview
from daft.datatype import DataType
from daft.errors import ExpressionTypeError
from daft.expressions import Expression, ExpressionsProjection, col, lit
from daft.logical.builder import JoinType, LogicalPlanBuilder
from daft.logical.builder import LogicalPlanBuilder
from daft.resource_request import ResourceRequest
from daft.runners.partitioning import PartitionCacheEntry, PartitionSet
from daft.runners.pyrunner import LocalPartitionSet
Expand Down Expand Up @@ -700,11 +700,13 @@
raise ValueError("If `on` is not None then both `left_on` and `right_on` must be None")
left_on = on
right_on = on
assert how == "inner", "only inner joins are currently supported"
join_type = JoinType.from_join_type_str(how)
if join_type != JoinType.Inner:
raise ValueError(f"Only inner joins are currently supported, but got: {how}")

Check warning on line 705 in daft/dataframe/dataframe.py

View check run for this annotation

Codecov / codecov/patch

daft/dataframe/dataframe.py#L705

Added line #L705 was not covered by tests

left_exprs = self.__column_input_to_expression(tuple(left_on) if isinstance(left_on, list) else (left_on,))
right_exprs = self.__column_input_to_expression(tuple(right_on) if isinstance(right_on, list) else (right_on,))
builder = self._builder.join(other._builder, left_on=left_exprs, right_on=right_exprs, how=JoinType.INNER)
builder = self._builder.join(other._builder, left_on=left_exprs, right_on=right_exprs, how=join_type)
return DataFrame(builder)

@DataframePublicAPI
Expand Down
4 changes: 2 additions & 2 deletions daft/execution/execution_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,11 @@
CsvSourceConfig,
FileFormat,
FileFormatConfig,
JoinType,
JsonSourceConfig,
ParquetSourceConfig,
)
from daft.expressions import Expression, ExpressionsProjection, col
from daft.logical.builder import JoinType
from daft.logical.map_partition_ops import MapPartitionOp
from daft.logical.schema import Schema
from daft.resource_request import ResourceRequest
Expand Down Expand Up @@ -654,7 +654,7 @@ def _join(self, inputs: list[Table]) -> list[Table]:
left_on=self.left_on,
right_on=self.right_on,
output_projection=self.output_projection,
how=self.how.value,
how=self.how,
)
return [result]

Expand Down
3 changes: 1 addition & 2 deletions daft/execution/physical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@

from loguru import logger

from daft.daft import FileFormat, FileFormatConfig
from daft.daft import FileFormat, FileFormatConfig, JoinType
from daft.execution import execution_step
from daft.execution.execution_step import (
Instruction,
Expand All @@ -34,7 +34,6 @@
SingleOutputPartitionTask,
)
from daft.expressions import ExpressionsProjection
from daft.logical.builder import JoinType
from daft.logical.schema import Schema
from daft.resource_request import ResourceRequest
from daft.runners.partitioning import PartialPartitionMetadata
Expand Down
23 changes: 22 additions & 1 deletion daft/execution/rust_physical_plan_shim.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Iterator, TypeVar, cast

from daft.context import get_context
from daft.daft import FileFormat, FileFormatConfig, PyExpr, PySchema, PyTable
from daft.daft import FileFormat, FileFormatConfig, JoinType, PyExpr, PySchema, PyTable
from daft.execution import execution_step, physical_plan
from daft.expressions import Expression, ExpressionsProjection
from daft.logical.map_partition_ops import MapPartitionOp
Expand Down Expand Up @@ -117,6 +117,27 @@ def reduce_merge(
return physical_plan.reduce(input, reduce_instruction)


def join(
    input: physical_plan.InProgressPhysicalPlan[PartitionT],
    right: physical_plan.InProgressPhysicalPlan[PartitionT],
    left_on: list[PyExpr],
    right_on: list[PyExpr],
    output_projection: list[PyExpr],
    join_type: JoinType,
) -> physical_plan.InProgressPhysicalPlan[PartitionT]:
    """Shim that forwards a join request to ``physical_plan.join``.

    Each list of raw ``PyExpr`` objects is wrapped into an
    ``ExpressionsProjection`` before being handed to the Python physical plan.

    Args:
        input: In-progress physical plan passed as ``left_plan``.
        right: In-progress physical plan passed as ``right_plan``.
        left_on: Join-key expressions passed as ``left_on``.
        right_on: Join-key expressions passed as ``right_on``.
        output_projection: Expressions passed as ``output_projection``.
        join_type: Join type, forwarded unchanged as ``how``.

    Returns:
        Whatever ``physical_plan.join`` returns for the wrapped arguments.
    """

    def _as_projection(py_exprs: list[PyExpr]) -> ExpressionsProjection:
        # Wrap every raw PyExpr in an Expression, preserving order.
        return ExpressionsProjection(list(map(Expression._from_pyexpr, py_exprs)))

    return physical_plan.join(
        left_plan=input,
        right_plan=right,
        left_on=_as_projection(left_on),
        right_on=_as_projection(right_on),
        output_projection=_as_projection(output_projection),
        how=join_type,
    )


def write_file(
input: physical_plan.InProgressPhysicalPlan[PartitionT],
file_format: FileFormat,
Expand Down
17 changes: 8 additions & 9 deletions daft/logical/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@

import pathlib
from abc import ABC, abstractmethod
from enum import Enum
from typing import TYPE_CHECKING

import fsspec

from daft.daft import FileFormat, FileFormatConfig, PartitionScheme, PartitionSpec
from daft.daft import (
FileFormat,
FileFormatConfig,
JoinType,
PartitionScheme,
PartitionSpec,
)
from daft.expressions.expressions import Expression, ExpressionsProjection
from daft.logical.schema import Schema
from daft.resource_request import ResourceRequest
Expand All @@ -17,12 +22,6 @@
from daft.planner import QueryPlanner


class JoinType(Enum):
INNER = "inner"
LEFT = "left"
RIGHT = "right"


class LogicalPlanBuilder(ABC):
"""
An interface for building a logical plan for the Daft DataFrame.
Expand Down Expand Up @@ -149,7 +148,7 @@ def join(
right: LogicalPlanBuilder,
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
how: JoinType = JoinType.INNER,
how: JoinType = JoinType.Inner,
) -> LogicalPlanBuilder:
pass

Expand Down
20 changes: 13 additions & 7 deletions daft/logical/logical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,20 @@
import fsspec

from daft.context import get_context
from daft.daft import FileFormat, FileFormatConfig, PartitionScheme, PartitionSpec
from daft.daft import (
FileFormat,
FileFormatConfig,
JoinType,
PartitionScheme,
PartitionSpec,
)
from daft.datatype import DataType
from daft.errors import ExpressionTypeError
from daft.expressions import Expression, ExpressionsProjection, col
from daft.expressions.testing import expr_structurally_equal
from daft.internal.treenode import TreeNode
from daft.logical.aggregation_plan_builder import AggregationPlanBuilder
from daft.logical.builder import JoinType, LogicalPlanBuilder
from daft.logical.builder import LogicalPlanBuilder
from daft.logical.map_partition_ops import ExplodeOp, MapPartitionOp
from daft.logical.schema import Schema
from daft.resource_request import ResourceRequest
Expand Down Expand Up @@ -198,7 +204,7 @@ def join( # type: ignore[override]
right: PyLogicalPlanBuilder,
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
how: JoinType = JoinType.INNER,
how: JoinType = JoinType.Inner,
) -> PyLogicalPlanBuilder:
return Join(
self._plan,
Expand Down Expand Up @@ -1146,7 +1152,7 @@ def __init__(
right: LogicalPlan,
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
how: JoinType = JoinType.INNER,
how: JoinType = JoinType.Inner,
) -> None:
assert len(left_on) == len(right_on), "left_on and right_on must match size"

Expand All @@ -1165,13 +1171,13 @@ def __init__(

self._how = how
output_schema: Schema
if how == JoinType.LEFT:
if how == JoinType.Left:
num_partitions = left.num_partitions()
raise NotImplementedError()
elif how == JoinType.RIGHT:
elif how == JoinType.Right:
num_partitions = right.num_partitions()
raise NotImplementedError()
elif how == JoinType.INNER:
elif how == JoinType.Inner:
num_partitions = max(left.num_partitions(), right.num_partitions())
right_drop_set = {r.name() for l, r in zip(left_on, right_on) if l.name() == r.name()}
left_columns = ExpressionsProjection.from_schema(left.schema())
Expand Down
40 changes: 35 additions & 5 deletions daft/logical/rust_logical_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@

import fsspec

from daft import DataType
from daft import DataType, col
from daft.context import get_context
from daft.daft import FileFormat, FileFormatConfig
from daft.daft import FileFormat, FileFormatConfig, JoinType
from daft.daft import LogicalPlanBuilder as _LogicalPlanBuilder
from daft.daft import PartitionScheme, PartitionSpec
from daft.errors import ExpressionTypeError
from daft.expressions.expressions import Expression, ExpressionsProjection
from daft.logical.builder import JoinType, LogicalPlanBuilder
from daft.logical.builder import LogicalPlanBuilder
from daft.logical.schema import Schema
from daft.resource_request import ResourceRequest
from daft.runners.partitioning import PartitionCacheEntry
Expand Down Expand Up @@ -192,9 +192,39 @@
right: RustLogicalPlanBuilder,
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
how: JoinType = JoinType.INNER,
how: JoinType = JoinType.Inner,
) -> RustLogicalPlanBuilder:
raise NotImplementedError("not implemented")
for schema, exprs in ((self.schema(), left_on), (right.schema(), right_on)):
resolved_schema = exprs.resolve_schema(schema)
for f, expr in zip(resolved_schema, exprs):
if f.dtype == DataType.null():
raise ExpressionTypeError(f"Cannot join on null type expression: {expr}")
if how == JoinType.Left:
raise NotImplementedError("Left join not implemented.")

Check warning on line 203 in daft/logical/rust_logical_plan.py

View check run for this annotation

Codecov / codecov/patch

daft/logical/rust_logical_plan.py#L203

Added line #L203 was not covered by tests
elif how == JoinType.Right:
raise NotImplementedError("Right join not implemented.")

Check warning on line 205 in daft/logical/rust_logical_plan.py

View check run for this annotation

Codecov / codecov/patch

daft/logical/rust_logical_plan.py#L205

Added line #L205 was not covered by tests
elif how == JoinType.Inner:
# TODO(Clark): Port this logic to Rust-side once ExpressionsProjection has been ported.
right_drop_set = {r.name() for l, r in zip(left_on, right_on) if l.name() == r.name()}
left_columns = ExpressionsProjection.from_schema(self.schema())
right_columns = ExpressionsProjection([col(f.name) for f in right.schema() if f.name not in right_drop_set])
output_projection = left_columns.union(right_columns, rename_dup="right.")
left_columns = left_columns
right_columns = ExpressionsProjection(list(output_projection)[len(left_columns) :])
output_schema = left_columns.resolve_schema(self.schema()).union(
right_columns.resolve_schema(right.schema())
)
builder = self._builder.join(
right._builder,
left_on.to_inner_py_exprs(),
right_on.to_inner_py_exprs(),
output_projection.to_inner_py_exprs(),
output_schema._schema,
how,
)
return RustLogicalPlanBuilder(builder)
else:
raise NotImplementedError(f"{how} join not implemented.")

Check warning on line 227 in daft/logical/rust_logical_plan.py

View check run for this annotation

Codecov / codecov/patch

daft/logical/rust_logical_plan.py#L227

Added line #L227 was not covered by tests

def concat(self, other: RustLogicalPlanBuilder) -> RustLogicalPlanBuilder: # type: ignore[override]
builder = self._builder.concat(other._builder)
Expand Down
5 changes: 3 additions & 2 deletions daft/table/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from loguru import logger

from daft.arrow_utils import ensure_table
from daft.daft import JoinType
from daft.daft import PyTable as _PyTable
from daft.daft import read_parquet as _read_parquet
from daft.daft import read_parquet_bulk as _read_parquet_bulk
Expand Down Expand Up @@ -284,9 +285,9 @@ def join(
left_on: ExpressionsProjection,
right_on: ExpressionsProjection,
output_projection: ExpressionsProjection | None = None,
how: str = "inner",
how: JoinType = JoinType.Inner,
) -> Table:
if how != "inner":
if how != JoinType.Inner:
raise NotImplementedError("TODO: [RUST] Implement Other Join types")
if len(left_on) != len(right_on):
raise ValueError(
Expand Down
37 changes: 36 additions & 1 deletion src/daft-plan/src/builder.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use std::sync::Arc;

use crate::logical_plan::LogicalPlan;
use crate::{logical_plan::LogicalPlan, JoinType};

#[cfg(feature = "python")]
use {
Expand Down Expand Up @@ -193,6 +193,41 @@ impl LogicalPlanBuilder {
Ok(logical_plan_builder)
}

/// Builds a new `LogicalPlanBuilder` whose plan is an `ops::Join` node.
///
/// The join node is constructed from `other`'s plan, the join-key
/// expressions, the output projection, the output schema, the join type
/// and this builder's own plan, in that order.
///
/// The `PyExpr` vectors are taken by value, so they are consumed and
/// converted into `Expr`s directly instead of cloning every element.
///
/// # Errors
///
/// This body has no failing path of its own and always returns `Ok`.
pub fn join(
    &self,
    other: &Self,
    left_on: Vec<PyExpr>,
    right_on: Vec<PyExpr>,
    output_projection: Vec<PyExpr>,
    output_schema: &PySchema,
    join_type: JoinType,
) -> PyResult<LogicalPlanBuilder> {
    let left_on_exprs = left_on.into_iter().map(Into::into).collect::<Vec<Expr>>();
    let right_on_exprs = right_on.into_iter().map(Into::into).collect::<Vec<Expr>>();
    let output_projection_exprs = output_projection
        .into_iter()
        .map(Into::into)
        .collect::<Vec<Expr>>();
    let logical_plan: LogicalPlan = ops::Join::new(
        other.plan.clone(),
        left_on_exprs,
        right_on_exprs,
        output_projection_exprs,
        // `output_schema` is only borrowed, so it must be cloned before conversion.
        output_schema.clone().into(),
        join_type,
        self.plan.clone(),
    )
    .into();
    Ok(LogicalPlanBuilder::new(logical_plan.into()))
}

pub fn concat(&self, other: &Self) -> PyResult<LogicalPlanBuilder> {
let self_schema = self.plan.schema();
let other_schema = other.plan.schema();
Expand Down
Loading
Loading