Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 40 additions & 46 deletions cpp/src/arrow/compute/expression.cc
Original file line number Diff line number Diff line change
Expand Up @@ -545,67 +545,61 @@ Result<Expression> BindNonRecursive(Expression::Call call, bool insert_implicit_
std::vector<TypeHolder> types = GetTypes(call.arguments);
ARROW_ASSIGN_OR_RAISE(call.function, GetFunction(call, exec_context));

auto FinishBind = [&] {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This basically reverts the change from #40223, because we've reworked the fix in #47297. Besides, the original control flow (prior to #40223) is much more straightforward.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM, thanks for working on this!

compute::KernelContext kernel_context(exec_context, call.kernel);
if (call.kernel->init) {
const FunctionOptions* options =
call.options ? call.options.get() : call.function->default_options();
ARROW_ASSIGN_OR_RAISE(
call.kernel_state,
call.kernel->init(&kernel_context, {call.kernel, types, options}));

kernel_context.SetState(call.kernel_state.get());
}

ARROW_ASSIGN_OR_RAISE(
call.type, call.kernel->signature->out_type().Resolve(&kernel_context, types));
return Status::OK();
};

// First try and bind exactly
Result<const Kernel*> maybe_exact_match = call.function->DispatchExact(types);
if (maybe_exact_match.ok()) {
call.kernel = *maybe_exact_match;
if (FinishBind().ok()) {
return Expression(std::move(call));
} else {
if (!insert_implicit_casts) {
return maybe_exact_match.status();
}
}

if (!insert_implicit_casts) {
return maybe_exact_match.status();
}
// If exact binding fails, and we are allowed to cast, then prefer casting literals
// first. Since DispatchBest generally prefers up-casting the best way to do this is
// first down-cast the literals as much as possible
types = GetTypesWithSmallestLiteralRepresentation(call.arguments);
ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchBest(&types));

// If exact binding fails, and we are allowed to cast, then prefer casting literals
// first. Since DispatchBest generally prefers up-casting the best way to do this is
// first down-cast the literals as much as possible
types = GetTypesWithSmallestLiteralRepresentation(call.arguments);
ARROW_ASSIGN_OR_RAISE(call.kernel, call.function->DispatchBest(&types));
for (size_t i = 0; i < types.size(); ++i) {
if (types[i] == call.arguments[i].type()) continue;

for (size_t i = 0; i < types.size(); ++i) {
if (types[i] == call.arguments[i].type()) continue;
if (const Datum* lit = call.arguments[i].literal()) {
ARROW_ASSIGN_OR_RAISE(Datum new_lit,
compute::Cast(*lit, types[i].GetSharedPtr()));
call.arguments[i] = literal(std::move(new_lit));
continue;
}

if (const Datum* lit = call.arguments[i].literal()) {
ARROW_ASSIGN_OR_RAISE(Datum new_lit, compute::Cast(*lit, types[i].GetSharedPtr()));
call.arguments[i] = literal(std::move(new_lit));
continue;
}
// construct an implicit cast Expression with which to replace this argument
Expression::Call implicit_cast;
implicit_cast.function_name = "cast";
implicit_cast.arguments = {std::move(call.arguments[i])};

// construct an implicit cast Expression with which to replace this argument
Expression::Call implicit_cast;
implicit_cast.function_name = "cast";
implicit_cast.arguments = {std::move(call.arguments[i])};
// TODO(wesm): Use TypeHolder in options
implicit_cast.options = std::make_shared<compute::CastOptions>(
compute::CastOptions::Safe(types[i].GetSharedPtr()));

// TODO(wesm): Use TypeHolder in options
implicit_cast.options = std::make_shared<compute::CastOptions>(
compute::CastOptions::Safe(types[i].GetSharedPtr()));
ARROW_ASSIGN_OR_RAISE(
call.arguments[i],
BindNonRecursive(std::move(implicit_cast),
/*insert_implicit_casts=*/false, exec_context));
}
}

compute::KernelContext kernel_context(exec_context, call.kernel);
if (call.kernel->init) {
const FunctionOptions* options =
call.options ? call.options.get() : call.function->default_options();
ARROW_ASSIGN_OR_RAISE(
call.arguments[i],
BindNonRecursive(std::move(implicit_cast),
/*insert_implicit_casts=*/false, exec_context));
call.kernel_state,
call.kernel->init(&kernel_context, {call.kernel, types, options}));

kernel_context.SetState(call.kernel_state.get());
}

RETURN_NOT_OK(FinishBind());
ARROW_ASSIGN_OR_RAISE(
call.type, call.kernel->signature->out_type().Resolve(&kernel_context, types));

return Expression(std::move(call));
}

Expand Down
39 changes: 39 additions & 0 deletions cpp/src/arrow/compute/expression_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -622,6 +622,45 @@ TEST(Expression, BindCall) {
add(cast(field_ref("i32"), float32()), literal(3.5F)));
}

// Test helper: adds a unary scalar function called "invalid_init" to the registry
// returned by GetFunctionRegistry().
//
// The function owns a single int64 -> int64 kernel whose exec is a no-op returning
// OK and whose init always returns Status::Invalid("Invalid Init"). Its
// DispatchBest override always returns an Invalid status as well.
//
// Returns the first non-OK status produced by AddKernel or AddFunction, otherwise
// Status::OK().
static Status RegisterInvalidInit() {
  // Scalar function type that refuses best-match dispatch outright.
  struct CastableFunction : public ScalarFunction {
    using ScalarFunction::ScalarFunction;

    Result<const Kernel*> DispatchBest(std::vector<TypeHolder>* types) const override {
      return Status::Invalid("Shouldn't call DispatchBest on this function");
    }
  };

  // Exec never does any work; only kernel initialization matters here.
  auto noop_exec = [](KernelContext*, const ExecSpan&, ExecResult*) -> Status {
    return Status::OK();
  };
  // Init unconditionally fails with a recognizable message.
  auto failing_init = [](KernelContext*,
                         const KernelInitArgs&) -> Result<std::unique_ptr<KernelState>> {
    return Status::Invalid("Invalid Init");
  };

  const std::string function_name = "invalid_init";
  auto function = std::make_shared<CastableFunction>(function_name, Arity::Unary(),
                                                     FunctionDoc::Empty());

  ScalarKernel failing_kernel({int64()}, int64(), noop_exec, failing_init);
  ARROW_RETURN_NOT_OK(function->AddKernel(failing_kernel));

  ARROW_RETURN_NOT_OK(GetFunctionRegistry()->AddFunction(std::move(function)));

  return Status::OK();
}

// Regression test for GH-47268, where the bad status produced while binding a call
// was discarded. Binding a call to "invalid_init" must report the Invalid status
// returned by the kernel's init.
TEST(Expression, BindCallError) {
  ASSERT_OK(RegisterInvalidInit());

  // A freshly constructed call expression starts out unbound.
  auto unbound_call = call("invalid_init", {field_ref("i64")});
  EXPECT_FALSE(unbound_call.IsBound());

  // Bind has to surface the init failure as its resulting status.
  ASSERT_RAISES_WITH_MESSAGE(Invalid, "Invalid: Invalid Init",
                             unbound_call.Bind(*kBoringSchema).status());
}

TEST(Expression, BindWithAliasCasts) {
auto fm = GetFunctionRegistry();
EXPECT_OK(fm->AddAlias("alias_cast", "cast"));
Expand Down
Loading