Skip to content

Commit

Permalink
more null checks, expose relid as label
Browse files Browse the repository at this point in the history
  • Loading branch information
bkietz committed Oct 28, 2021
1 parent 13bfd15 commit 1d4c37e
Show file tree
Hide file tree
Showing 3 changed files with 209 additions and 84 deletions.
15 changes: 15 additions & 0 deletions cpp/src/arrow/compute/exec/exec_plan.h
Original file line number Diff line number Diff line change
Expand Up @@ -351,13 +351,28 @@ struct ARROW_EXPORT Declaration {
options{std::make_shared<Options>(std::move(options))},
label{this->factory_name} {}

template <typename Options>
Declaration(std::string factory_name, std::vector<Input> inputs, Options options,
std::string label)
: factory_name{std::move(factory_name)},
inputs{std::move(inputs)},
options{std::make_shared<Options>(std::move(options))},
label{std::move(label)} {}

template <typename Options>
Declaration(std::string factory_name, Options options)
: factory_name{std::move(factory_name)},
inputs{},
options{std::make_shared<Options>(std::move(options))},
label{this->factory_name} {}

template <typename Options>
Declaration(std::string factory_name, Options options, std::string label)
: factory_name{std::move(factory_name)},
inputs{},
options{std::make_shared<Options>(std::move(options))},
label{std::move(label)} {}

/// \brief Convenience factory for the common case of a simple sequence of nodes.
///
/// Each of decls will be appended to the inputs of the subsequent declaration,
Expand Down
139 changes: 110 additions & 29 deletions cpp/src/arrow/compute/exec/ir_consumer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ Result<std::shared_ptr<Field>> Convert(const flatbuf::Field& f) {
fields.resize(children->size());
int i = 0;
for (const flatbuf::Field* child : *children) {
if (child) return UnexpectedNullField("Field.children[i]");
ARROW_ASSIGN_OR_RAISE(fields[i++], Convert(*child));
}
}
Expand All @@ -65,12 +66,19 @@ Result<std::shared_ptr<Field>> Convert(const flatbuf::Field& f) {
return field(std::move(name), std::move(type), f.nullable(), std::move(metadata));
}

std::string LabelFromRelId(const ir::RelId* id) {
return id ? std::to_string(id->id()) : "";
}

Result<std::shared_ptr<Buffer>> BufferFromFlatbufferByteVector(
const flatbuffers::Vector<int8_t>* vec) {
if (!vec) return nullptr;

ARROW_ASSIGN_OR_RAISE(auto buf, AllocateBuffer(vec->size()));

if (!vec->data()) return UnexpectedNullField("Vector<int8_t>.data");
std::memcpy(buf->mutable_data(), vec->data(), vec->size());

return buf;
}

Expand Down Expand Up @@ -104,8 +112,8 @@ struct ConvertLiteralImpl {

Result<Datum> Convert(const IntervalType& t) {
ARROW_ASSIGN_OR_RAISE(auto lit, GetLiteral<ir::IntervalLiteral>());
if (!lit->value()) return UnexpectedNullField("IntervalLiteral.value");

if (!lit->value()) return UnexpectedNullField("IntervalLiteral.value");
switch (t.interval_type()) {
case IntervalType::MONTHS:
if (auto value = lit->value_as<ir::IntervalLiteralMonths>()) {
Expand Down Expand Up @@ -133,8 +141,8 @@ struct ConvertLiteralImpl {

Result<Datum> Convert(const DecimalType& t) {
ARROW_ASSIGN_OR_RAISE(auto lit, GetLiteral<ir::DecimalLiteral>());
if (!lit->value()) return UnexpectedNullField("DecimalLiteral.value");

if (!lit->value()) return UnexpectedNullField("DecimalLiteral.value");
if (static_cast<int>(lit->value()->size()) != t.byte_width()) {
return Status::IOError("DecimalLiteral.type was ", t.ToString(),
" (expected byte width ", t.byte_width(), ")",
Expand Down Expand Up @@ -165,11 +173,13 @@ struct ConvertLiteralImpl {

Result<Datum> Convert(const ListType&) {
ARROW_ASSIGN_OR_RAISE(auto lit, GetLiteral<ir::ListLiteral>());
if (!lit->values()) return UnexpectedNullField("ListLiteral.values");

if (!lit->values()) return UnexpectedNullField("ListLiteral.values");
ScalarVector values{lit->values()->size()};

int i = 0;
for (const ir::Literal* v : *lit->values()) {
if (!v) return UnexpectedNullField("ListLiteral.values[i]");
ARROW_ASSIGN_OR_RAISE(Datum value, arrow::compute::Convert(*v));
values[i++] = value.scalar();
}
Expand All @@ -183,11 +193,13 @@ struct ConvertLiteralImpl {

Result<Datum> Convert(const MapType& t) {
ARROW_ASSIGN_OR_RAISE(auto lit, GetLiteral<ir::MapLiteral>());
if (!lit->values()) return UnexpectedNullField("MapLiteral.values");

if (!lit->values()) return UnexpectedNullField("MapLiteral.values");
ScalarVector keys{lit->values()->size()}, values{lit->values()->size()};

int i = 0;
for (const ir::KeyValue* kv : *lit->values()) {
if (!kv) return UnexpectedNullField("MapLiteral.values[i]");
ARROW_ASSIGN_OR_RAISE(Datum key, arrow::compute::Convert(*kv->value()));
ARROW_ASSIGN_OR_RAISE(Datum value, arrow::compute::Convert(*kv->value()));
keys[i] = key.scalar();
Expand Down Expand Up @@ -222,6 +234,7 @@ struct ConvertLiteralImpl {
ScalarVector values{lit->values()->size()};
int i = 0;
for (const ir::Literal* v : *lit->values()) {
if (!v) return UnexpectedNullField("StructLiteral.values[i]");
ARROW_ASSIGN_OR_RAISE(Datum value, arrow::compute::Convert(*v));
if (!value.type()->Equals(*t.field(i)->type())) {
return Status::IOError("StructLiteral.type was ", t.ToString(), " but value ", i,
Expand Down Expand Up @@ -350,6 +363,7 @@ Result<std::pair<std::vector<Expression>, std::vector<Expression>>> Convert(

int i = 0;
for (const ir::CaseFragment* c : cases) {
if (!c) return UnexpectedNullField("Vector<CaseFragment>[i]");
ARROW_ASSIGN_OR_RAISE(conditions[i], Convert(*c->match()));
ARROW_ASSIGN_OR_RAISE(arguments[i], Convert(*c->result()));
++i;
Expand Down Expand Up @@ -380,11 +394,15 @@ Result<Expression> Convert(const ir::Expression& expr) {
case ir::ExpressionImpl::Call: {
auto call = expr.impl_as<ir::Call>();

auto name = call->name()->str();
if (!call->name()) return UnexpectedNullField("Call.name");
auto name = ipc::internal::StringFromFlatbuffers(call->name());

if (!call->arguments()) return UnexpectedNullField("Call.arguments");
std::vector<Expression> arguments(call->arguments()->size());

int i = 0;
for (const ir::Expression* a : *call->arguments()) {
if (!a) return UnexpectedNullField("Call.arguments[i]");
ARROW_ASSIGN_OR_RAISE(arguments[i++], Convert(*a));
}

Expand All @@ -395,28 +413,50 @@ Result<Expression> Convert(const ir::Expression& expr) {
case ir::ExpressionImpl::Cast: {
auto cast = expr.impl_as<ir::Cast>();

if (!cast->operand()) return UnexpectedNullField("Cast.operand");
ARROW_ASSIGN_OR_RAISE(Expression arg, Convert(*cast->operand()));

if (!cast->to()) return UnexpectedNullField("Cast.to");
ARROW_ASSIGN_OR_RAISE(auto to, Convert(*cast->to()));

return call("cast", {std::move(arg)}, CastOptions::Safe(to->type()));
}

case ir::ExpressionImpl::ConditionalCase: {
auto conditional_case = expr.impl_as<ir::ConditionalCase>();

if (!conditional_case->conditions()) {
return UnexpectedNullField("ConditionalCase.conditions");
}
ARROW_ASSIGN_OR_RAISE(auto cases, Convert(*conditional_case->conditions()));

if (!conditional_case->else_()) return UnexpectedNullField("ConditionalCase.else");
ARROW_ASSIGN_OR_RAISE(auto default_value, Convert(*conditional_case->else_()));

return CaseWhen(std::move(cases.first), std::move(cases.second),
std::move(default_value));
}

case ir::ExpressionImpl::SimpleCase: {
auto simple_case = expr.impl_as<ir::SimpleCase>();
ARROW_ASSIGN_OR_RAISE(auto expression, Convert(*simple_case->expression()));
auto expression = simple_case->expression();
auto matches = simple_case->matches();
auto else_ = simple_case->else_();

if (!expression) return UnexpectedNullField("SimpleCase.expression");
ARROW_ASSIGN_OR_RAISE(auto rhs, Convert(*expression));

if (!matches) return UnexpectedNullField("SimpleCase.matches");
ARROW_ASSIGN_OR_RAISE(auto cases, Convert(*simple_case->matches()));

// replace each condition with an equality expression with the rhs
for (auto& condition : cases.first) {
condition = equal(std::move(condition), expression);
condition = equal(std::move(condition), rhs);
}

if (!else_) return UnexpectedNullField("SimpleCase.else");
ARROW_ASSIGN_OR_RAISE(auto default_value, Convert(*simple_case->else_()));

return CaseWhen(std::move(cases.first), std::move(cases.second),
std::move(default_value));
}
Expand All @@ -434,49 +474,73 @@ Result<Declaration> Convert(const ir::Relation& rel) {
switch (rel.impl_type()) {
case ir::RelationImpl::Source: {
auto source = rel.impl_as<ir::Source>();

if (!source->name()) return UnexpectedNullField("Source.name");
auto name = ipc::internal::StringFromFlatbuffers(source->name());
ipc::DictionaryMemo ignore;

std::shared_ptr<Schema> schema;
RETURN_NOT_OK(ipc::internal::GetSchema(source->schema(), &ignore, &schema));
if (source->schema()) {
ipc::DictionaryMemo ignore;
RETURN_NOT_OK(ipc::internal::GetSchema(source->schema(), &ignore, &schema));
}

return Declaration{"catalog_source",
{},
CatalogSourceNodeOptions{std::move(name), std::move(schema)}};
CatalogSourceNodeOptions{std::move(name), std::move(schema)},
LabelFromRelId(source->id())};
}

case ir::RelationImpl::Filter: {
auto filter = rel.impl_as<ir::Filter>();
ARROW_ASSIGN_OR_RAISE(auto arg, Convert(*filter->rel()).As<Declaration::Input>());

if (!filter->predicate()) return UnexpectedNullField("Filter.predicate");
ARROW_ASSIGN_OR_RAISE(auto predicate, Convert(*filter->predicate()));
return Declaration{
"filter", {std::move(arg)}, FilterNodeOptions{std::move(predicate)}};

if (!filter->rel()) return UnexpectedNullField("Filter.rel");
ARROW_ASSIGN_OR_RAISE(auto arg, Convert(*filter->rel()).As<Declaration::Input>());

return Declaration{"filter",
{std::move(arg)},
FilterNodeOptions{std::move(predicate)},
LabelFromRelId(filter->id())};
}

case ir::RelationImpl::Project: {
auto project = rel.impl_as<ir::Project>();

if (!project->rel()) return UnexpectedNullField("Project.rel");
ARROW_ASSIGN_OR_RAISE(auto arg, Convert(*project->rel()).As<Declaration::Input>());

ProjectNodeOptions opts{{}};

if (!project->expressions()) return UnexpectedNullField("Project.expressions");
for (const ir::Expression* expression : *project->expressions()) {
if (!expression) return UnexpectedNullField("Project.expressions[i]");
ARROW_ASSIGN_OR_RAISE(auto expr, Convert(*expression));
opts.expressions.push_back(std::move(expr));
}

return Declaration{"project", {std::move(arg)}, std::move(opts)};
return Declaration{
"project", {std::move(arg)}, std::move(opts), LabelFromRelId(project->id())};
}

case ir::RelationImpl::Aggregate: {
auto aggregate = rel.impl_as<ir::Aggregate>();

if (!aggregate->rel()) return UnexpectedNullField("Aggregate.rel");
ARROW_ASSIGN_OR_RAISE(auto arg,
Convert(*aggregate->rel()).As<Declaration::Input>());

AggregateNodeOptions opts{{}, {}, {}};

for (const ir::Expression* measure : *aggregate->measures()) {
ARROW_ASSIGN_OR_RAISE(auto expr, Convert(*measure));
auto call = expr.call();
if (!aggregate->measures()) return UnexpectedNullField("Aggregate.measures");
for (const ir::Expression* m : *aggregate->measures()) {
if (!m) return UnexpectedNullField("Aggregate.measures[i]");
ARROW_ASSIGN_OR_RAISE(auto measure, Convert(*m));

auto call = measure.call();
if (!call || call->arguments.size() != 1) {
return Status::IOError("One of Aggregate.measures was ", expr.ToString(),
return Status::IOError("One of Aggregate.measures was ", measure.ToString(),
" (expected Expression::Call with one argument)");
}

Expand All @@ -491,13 +555,24 @@ Result<Declaration> Convert(const ir::Relation& rel) {
opts.names.push_back(call->function_name + " " + target->ToString());
}

if (!aggregate->groupings()) return UnexpectedNullField("Aggregate.groupings");
if (aggregate->groupings()->size() > 1) {
return Status::NotImplemented("Support for multiple grouping sets");
}

if (aggregate->groupings()->size() == 1) {
if (!aggregate->groupings()->Get(0)) {
return UnexpectedNullField("Aggregate.groupings[0]");
}

if (!aggregate->groupings()->Get(0)->keys()) {
return UnexpectedNullField("Grouping.keys");
}

for (const ir::Expression* key : *aggregate->groupings()->Get(0)->keys()) {
if (!key) return UnexpectedNullField("Grouping.keys[i]");
ARROW_ASSIGN_OR_RAISE(auto key_expr, Convert(*key));

auto key_ref = key_expr.field_ref();
if (!key_ref) {
return Status::NotImplemented("Support for non-FieldRef grouping keys");
Expand All @@ -506,22 +581,28 @@ Result<Declaration> Convert(const ir::Relation& rel) {
}
}

return Declaration{"aggregate", {std::move(arg)}, std::move(opts)};
return Declaration{"aggregate",
{std::move(arg)},
std::move(opts),
LabelFromRelId(aggregate->id())};
}

case ir::RelationImpl::OrderBy: {
auto aggregate = rel.impl_as<ir::OrderBy>();
ARROW_ASSIGN_OR_RAISE(auto arg,
Convert(*aggregate->rel()).As<Declaration::Input>());
auto order_by = rel.impl_as<ir::OrderBy>();

if (!order_by->rel()) return UnexpectedNullField("OrderBy.rel");
ARROW_ASSIGN_OR_RAISE(auto arg, Convert(*order_by->rel()).As<Declaration::Input>());

if (aggregate->keys()->size() == 0) {
if (!order_by->keys()) return UnexpectedNullField("OrderBy.keys");
if (order_by->keys()->size() == 0) {
return Status::NotImplemented("Empty sort key list");
}

util::optional<NullPlacement> null_placement;
std::vector<SortKey> sort_keys;

for (const ir::SortKey* key : *aggregate->keys()) {
for (const ir::SortKey* key : *order_by->keys()) {
if (!key) return UnexpectedNullField("OrderBy.keys[i]");
ARROW_ASSIGN_OR_RAISE(auto expr, Convert(*key->expression()));

auto target = expr.field_ref();
Expand Down Expand Up @@ -562,11 +643,11 @@ Result<Declaration> Convert(const ir::Relation& rel) {
null_placement = key_null_placement;
}

return Declaration{
"order_by_sink",
{std::move(arg)},
OrderBySinkNodeOptions{SortOptions{std::move(sort_keys), *null_placement},
nullptr}};
return Declaration{"order_by_sink",
{std::move(arg)},
OrderBySinkNodeOptions{
SortOptions{std::move(sort_keys), *null_placement}, nullptr},
LabelFromRelId(order_by->id())};
}

default:
Expand Down

0 comments on commit 1d4c37e

Please sign in to comment.