Skip to content

Commit

Permalink
Fixes #299 to work without relying on expression IDs (#510)
Browse files Browse the repository at this point in the history
* #299 fixed sorts by avoiding recomputation of Expressions whose
IDs already exist in the current vPartition
* However, moving forward, Expressions will be ID-less
* This PR fixes sorts in a different way: by making expression
evaluation a no-op in `quantile_reduce_func` during sort sampling
  • Loading branch information
jaychia committed Jan 26, 2023
1 parent 301982f commit 1ec273f
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 13 deletions.
5 changes: 4 additions & 1 deletion daft/execution/logical_op_runners.py
Expand Up @@ -359,7 +359,10 @@ def sample_map_func(part: vPartition) -> vPartition:

def quantile_reduce_func(to_reduce: list[vPartition]) -> vPartition:
merged = vPartition.merge_partitions(to_reduce, verify_partition_id=False)
merged_sorted = merged.sort(exprs, descending=descending)

# Skip evaluation of expressions by converting to ColumnExpression, since evaluation was done in sample_map_func
merged_sorted = merged.sort(exprs.to_column_expressions(), descending=descending)

return merged_sorted.quantiles(num_partitions)

prev_part = inputs[child_id]
Expand Down
4 changes: 0 additions & 4 deletions daft/runners/partitioning.py
Expand Up @@ -208,10 +208,6 @@ def get_unresolved_col_expressions(self) -> ExpressionList:
def eval_expression(self, expr: Expression) -> PyListTile:
expr_col_id = expr.get_id()

# Avoid recomputing expressions that have been computed before
if expr_col_id in self.columns and self.columns[expr_col_id].column_name == expr.name():
return self.columns[expr_col_id]

expr_name = expr.name()

assert expr_col_id is not None
Expand Down
16 changes: 13 additions & 3 deletions tests/dataframe_cookbook/test_sorting.py
@@ -1,5 +1,6 @@
from __future__ import annotations

import numpy as np
import pytest

from daft.expressions import col
Expand All @@ -14,12 +15,21 @@
def test_sorted_by_expr(daft_df, service_requests_csv_pd_df, repartition_nparts):
"""Sort by a column that undergoes an expression"""
daft_df = daft_df.repartition(repartition_nparts)
daft_sorted_df = daft_df.sort(col("Unique Key") + 1)

daft_sorted_df = daft_df.sort(((col("Unique Key") % 2) == 0).if_else(col("Unique Key"), col("Unique Key") * -1))
daft_sorted_pd_df = daft_sorted_df.to_pandas()

service_requests_csv_pd_df["tmp"] = service_requests_csv_pd_df["Unique Key"]
service_requests_csv_pd_df["tmp"] = np.where(
service_requests_csv_pd_df["tmp"] % 2 == 0,
service_requests_csv_pd_df["tmp"],
service_requests_csv_pd_df["tmp"] * -1,
)
service_requests_csv_pd_df = service_requests_csv_pd_df.sort_values("tmp", ascending=True)
service_requests_csv_pd_df = service_requests_csv_pd_df.drop(["tmp"], axis=1)

assert_df_equals(
daft_sorted_pd_df,
service_requests_csv_pd_df.sort_values(by=["Unique Key"]),
service_requests_csv_pd_df,
assert_ordering=True,
)

Expand Down
8 changes: 3 additions & 5 deletions tests/runners/test_partitioning.py
Expand Up @@ -149,18 +149,16 @@ def test_vpartition_filter() -> None:
expr = col("x") < 4
expr = resolve_expr(expr)

# Need to make sure there are no conflicts with column ID of `expr`
col_ids = [expr.required_columns()[0].get_id()] + list(range(expr.get_id() + 1, expr.get_id() + 4))

tiles = {}
for i in col_ids:
col_id = expr.required_columns()[0].get_id()
for i in range(col_id, col_id + 4):
block = DataBlock.make_block(np.arange(0, 10, 1))
tiles[i] = PyListTile(column_id=i, column_name=f"col_{i}", partition_id=0, block=block)

part = vPartition(columns=tiles, partition_id=0)
part = part.filter(ExpressionList([expr]))
arrow_table = pa.Table.from_pandas(part.to_pandas())
assert arrow_table.column_names == [f"col_{i}" for i in col_ids]
assert arrow_table.column_names == [f"col_{i}" for i in range(col_id, col_id + 4)]

for i in range(4):
assert np.all(arrow_table[i].to_numpy() < 4)
Expand Down

0 comments on commit 1ec273f

Please sign in to comment.