Skip to content

Commit

Permalink
GH-35004: [C++] Remove RelationInfo (#35005)
Browse files Browse the repository at this point in the history
See #35004
* Closes: #35004

Lead-authored-by: Yaron Gvili <rtpsw@hotmail.com>
Co-authored-by: rtpsw <rtpsw@hotmail.com>
Signed-off-by: Li Jin <ice.xelloss@gmail.com>
  • Loading branch information
rtpsw committed Apr 20, 2023
1 parent 0bf777a commit d5866ec
Show file tree
Hide file tree
Showing 7 changed files with 61 additions and 131 deletions.
22 changes: 4 additions & 18 deletions cpp/src/arrow/acero/asof_join_node.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1315,25 +1315,19 @@ class AsofJoinNode : public ExecNode {

/// \brief Make the output schema of an as-of-join node
///
/// Optionally, also provides the field output indices for this node.
/// \see arrow::engine::RelationInfo
///
/// \param[in] input_schema the schema of each input to the node
/// \param[in] indices_of_on_key the on-key index of each input to the node
/// \param[in] indices_of_by_key the by-key indices of each input to the node
/// \param[out] field_output_indices the output index of each field
static arrow::Result<std::shared_ptr<Schema>> MakeOutputSchema(
const std::vector<std::shared_ptr<Schema>> input_schema,
const std::vector<col_index_t>& indices_of_on_key,
const std::vector<std::vector<col_index_t>>& indices_of_by_key,
std::vector<int>* field_output_indices = nullptr) {
const std::vector<std::vector<col_index_t>>& indices_of_by_key) {
std::vector<std::shared_ptr<arrow::Field>> fields;

size_t n_by = indices_of_by_key.size() == 0 ? 0 : indices_of_by_key[0].size();
const DataType* on_key_type = NULLPTR;
std::vector<const DataType*> by_key_type(n_by, NULLPTR);
// Take all non-key, non-time RHS fields
int output_field_idx = 0;
for (size_t j = 0; j < input_schema.size(); ++j) {
const auto& on_field_ix = indices_of_on_key[j];
const auto& by_field_ix = indices_of_by_key[j];
Expand Down Expand Up @@ -1367,30 +1361,22 @@ class AsofJoinNode : public ExecNode {

for (int i = 0; i < input_schema[j]->num_fields(); ++i) {
const auto field = input_schema[j]->field(i);
bool as_output; // true if the field appears as an output
int final_output_idx; // the final output index for the field
bool as_output; // true if the field appears as an output
if (i == on_field_ix) {
ARROW_RETURN_NOT_OK(is_valid_on_field(field));
// Only add on field from the left table
as_output = (j == 0);
final_output_idx = as_output ? output_field_idx++ : indices_of_on_key[0];
} else if (std_has(by_field_ix, i)) {
ARROW_RETURN_NOT_OK(is_valid_by_field(field));
// Only add by field from the left table
as_output = (j == 0);
final_output_idx = as_output ? output_field_idx++
: indices_of_by_key[0][std_index(by_field_ix, i)];
} else {
ARROW_RETURN_NOT_OK(is_valid_data_field(field));
as_output = true;
final_output_idx = output_field_idx++;
}
if (as_output) {
fields.push_back(field);
}
if (field_output_indices) {
field_output_indices->push_back(final_output_idx);
}
}
}
return std::make_shared<arrow::Schema>(fields);
Expand Down Expand Up @@ -1604,13 +1590,13 @@ namespace asofjoin {

Result<std::shared_ptr<Schema>> MakeOutputSchema(
const std::vector<std::shared_ptr<Schema>>& input_schema,
const std::vector<AsofJoinKeys>& input_keys, std::vector<int>* field_output_indices) {
const std::vector<AsofJoinKeys>& input_keys) {
ARROW_ASSIGN_OR_RAISE(std::vector<col_index_t> indices_of_on_key,
AsofJoinNode::GetIndicesOfOnKey(input_schema, input_keys));
ARROW_ASSIGN_OR_RAISE(std::vector<std::vector<col_index_t>> indices_of_by_key,
AsofJoinNode::GetIndicesOfByKey(input_schema, input_keys));
return AsofJoinNode::MakeOutputSchema(input_schema, indices_of_on_key,
indices_of_by_key, field_output_indices);
indices_of_by_key);
}

} // namespace asofjoin
Expand Down
7 changes: 1 addition & 6 deletions cpp/src/arrow/acero/asof_join_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,11 @@ using AsofJoinKeys = AsofJoinNodeOptions::Keys;

/// \brief Make the output schema of an as-of-join node
///
/// Optionally, also provides the field output indices for this node.
/// \see arrow::engine::RelationInfo
///
/// \param[in] input_schema the schema of each input to the node
/// \param[in] input_keys the key of each input to the node
/// \param[out] field_output_indices the output index of each field
ARROW_ACERO_EXPORT Result<std::shared_ptr<Schema>> MakeOutputSchema(
const std::vector<std::shared_ptr<Schema>>& input_schema,
const std::vector<AsofJoinKeys>& input_keys,
std::vector<int>* field_output_indices = NULLPTR);
const std::vector<AsofJoinKeys>& input_keys);

} // namespace asofjoin
} // namespace acero
Expand Down
75 changes: 32 additions & 43 deletions cpp/src/arrow/engine/substrait/options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -46,26 +46,26 @@ std::vector<acero::Declaration::Input> MakeDeclarationInputs(

class BaseExtensionProvider : public ExtensionProvider {
public:
Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const ExtensionDetails& ext_details,
const ExtensionSet& ext_set) override {
Result<DeclarationInfo> MakeRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const ExtensionDetails& ext_details,
const ExtensionSet& ext_set) override {
auto details = dynamic_cast<const DefaultExtensionDetails&>(ext_details);
return MakeRel(conv_opts, inputs, details.rel, ext_set);
}

virtual Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const google::protobuf::Any& rel,
const ExtensionSet& ext_set) = 0;
virtual Result<DeclarationInfo> MakeRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const google::protobuf::Any& rel,
const ExtensionSet& ext_set) = 0;
};

class DefaultExtensionProvider : public BaseExtensionProvider {
public:
Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const google::protobuf::Any& rel,
const ExtensionSet& ext_set) override {
Result<DeclarationInfo> MakeRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const google::protobuf::Any& rel,
const ExtensionSet& ext_set) override {
if (rel.Is<substrait_ext::AsOfJoinRel>()) {
substrait_ext::AsOfJoinRel as_of_join_rel;
rel.UnpackTo(&as_of_join_rel);
Expand All @@ -86,9 +86,9 @@ class DefaultExtensionProvider : public BaseExtensionProvider {
}

private:
Result<RelationInfo> MakeAsOfJoinRel(const std::vector<DeclarationInfo>& inputs,
const substrait_ext::AsOfJoinRel& as_of_join_rel,
const ExtensionSet& ext_set) {
Result<DeclarationInfo> MakeAsOfJoinRel(
const std::vector<DeclarationInfo>& inputs,
const substrait_ext::AsOfJoinRel& as_of_join_rel, const ExtensionSet& ext_set) {
if (inputs.size() < 2) {
return Status::Invalid("substrait_ext::AsOfJoinNode too few input tables: ",
inputs.size());
Expand Down Expand Up @@ -133,24 +133,21 @@ class DefaultExtensionProvider : public BaseExtensionProvider {
for (size_t i = 0; i < inputs.size(); i++) {
input_schema[i] = inputs[i].output_schema;
}
std::vector<int> field_output_indices;
ARROW_ASSIGN_OR_RAISE(auto schema,
acero::asofjoin::MakeOutputSchema(input_schema, input_keys,
&field_output_indices));
acero::asofjoin::MakeOutputSchema(input_schema, input_keys));
acero::AsofJoinNodeOptions asofjoin_node_opts{std::move(input_keys), tolerance};

// declaration
auto input_decls = MakeDeclarationInputs(inputs);
return RelationInfo{
{acero::Declaration("asofjoin", input_decls, std::move(asofjoin_node_opts)),
std::move(schema)},
std::move(field_output_indices)};
return DeclarationInfo{
acero::Declaration("asofjoin", input_decls, std::move(asofjoin_node_opts)),
std::move(schema)};
}

Result<RelationInfo> MakeNamedTapRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const substrait_ext::NamedTapRel& named_tap_rel,
const ExtensionSet& ext_set) {
Result<DeclarationInfo> MakeNamedTapRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const substrait_ext::NamedTapRel& named_tap_rel,
const ExtensionSet& ext_set) {
if (inputs.size() != 1) {
return Status::Invalid(
"substrait_ext::NamedTapRel requires a single input but got: ", inputs.size());
Expand All @@ -169,10 +166,10 @@ class DefaultExtensionProvider : public BaseExtensionProvider {
ARROW_ASSIGN_OR_RAISE(
auto decl, conv_opts.named_tap_provider(named_tap_rel.kind(), input_decls,
named_tap_rel.name(), renamed_schema));
return RelationInfo{{std::move(decl), std::move(renamed_schema)}, std::nullopt};
return DeclarationInfo{std::move(decl), std::move(renamed_schema)};
}

Result<RelationInfo> MakeSegmentedAggregateRel(
Result<DeclarationInfo> MakeSegmentedAggregateRel(
const ConversionOptions& conv_opts, const std::vector<DeclarationInfo>& inputs,
const substrait_ext::SegmentedAggregateRel& seg_agg_rel,
const ExtensionSet& ext_set) {
Expand Down Expand Up @@ -211,21 +208,13 @@ class DefaultExtensionProvider : public BaseExtensionProvider {
aggregates.push_back(std::move(aggregate));
}

ARROW_ASSIGN_OR_RAISE(
auto output_schema,
acero::aggregate::MakeOutputSchema(input_schema, keys, segment_keys, aggregates));

ARROW_ASSIGN_OR_RAISE(auto decl_info, internal::MakeAggregateDeclaration(
std::move(inputs[0].declaration),
output_schema, std::move(aggregates),
std::move(keys), std::move(segment_keys)));

size_t out_size = output_schema->num_fields();
std::vector<int> field_output_indices(out_size);
for (int i = 0; i < static_cast<int>(out_size); i++) {
field_output_indices[i] = i;
}
return RelationInfo{decl_info, std::move(field_output_indices)};
ARROW_ASSIGN_OR_RAISE(auto aggregate_schema,
acero::aggregate::MakeOutputSchema(
input_schema, keys, /*segment_keys=*/{}, aggregates));

return internal::MakeAggregateDeclaration(
std::move(inputs[0].declaration), std::move(aggregate_schema),
std::move(aggregates), std::move(keys), std::move(segment_keys));
}
};

Expand Down
8 changes: 4 additions & 4 deletions cpp/src/arrow/engine/substrait/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,10 +81,10 @@ class ARROW_ENGINE_EXPORT ExtensionDetails {
class ARROW_ENGINE_EXPORT ExtensionProvider {
public:
virtual ~ExtensionProvider() = default;
virtual Result<RelationInfo> MakeRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const ExtensionDetails& ext_details,
const ExtensionSet& ext_set) = 0;
virtual Result<DeclarationInfo> MakeRel(const ConversionOptions& conv_opts,
const std::vector<DeclarationInfo>& inputs,
const ExtensionDetails& ext_details,
const ExtensionSet& ext_set) = 0;
};

/// \brief Get the default extension provider
Expand Down
16 changes: 0 additions & 16 deletions cpp/src/arrow/engine/substrait/relation.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,6 @@ struct ARROW_ENGINE_EXPORT DeclarationInfo {
std::shared_ptr<Schema> output_schema;
};

/// Information resulting from converting a Substrait relation.
///
/// RelationInfo adds the "output indices" field for the extension to define how the
/// fields should be mapped to get the standard indices expected by Substrait.
struct ARROW_ENGINE_EXPORT RelationInfo {
/// The execution information produced thus far.
DeclarationInfo decl_info;
/// A vector of indices, one per input field per input in order, each index referring
/// to the corresponding field within the output schema, if it is in the output, or -1
/// otherwise. Each location in this vector is a field input index. This vector is
/// useful for translating selected field input indices (often from an output mapping in
/// a Substrait plan) of a join-type relation to their locations in the output schema of
/// the relation. This vector is undefined if the translation is unsupported.
std::optional<std::vector<int>> field_output_indices;
};

/// Information resulting from converting a Substrait plan
struct ARROW_ENGINE_EXPORT PlanInfo {
/// The root declaration.
Expand Down
52 changes: 15 additions & 37 deletions cpp/src/arrow/engine/substrait/relation_internal.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,25 +162,20 @@ Result<DeclarationInfo> ProcessEmit(const substrait::ProjectRel& rel,
no_emit_declr, schema);
}

Result<DeclarationInfo> ProcessExtensionEmit(
const DeclarationInfo& no_emit_declr, const std::vector<int>& emit_order,
const std::vector<int>& field_output_indices) {
Result<DeclarationInfo> ProcessExtensionEmit(const DeclarationInfo& no_emit_declr,
const std::vector<int>& emit_order) {
const std::shared_ptr<Schema>& input_schema = no_emit_declr.output_schema;
std::vector<compute::Expression> proj_field_refs;
proj_field_refs.reserve(emit_order.size());
FieldVector emit_fields;
emit_fields.reserve(emit_order.size());

for (int emit_idx : emit_order) {
if (emit_idx < 0 || static_cast<size_t>(emit_idx) >= field_output_indices.size()) {
if (emit_idx < 0 || emit_idx >= input_schema->num_fields()) {
return Status::Invalid("Out of bounds emit index ", emit_idx);
}
int field_idx = field_output_indices[emit_idx];
if (field_idx < 0) {
return Status::Invalid("Non-output emit index ", emit_idx);
}
proj_field_refs.push_back(compute::field_ref(FieldRef(field_idx)));
emit_fields.push_back(input_schema->field(field_idx));
proj_field_refs.push_back(compute::field_ref(FieldRef(emit_idx)));
emit_fields.push_back(input_schema->field(emit_idx));
}

std::shared_ptr<Schema> emit_schema = schema(std::move(emit_fields));
Expand All @@ -192,13 +187,13 @@ Result<DeclarationInfo> ProcessExtensionEmit(
std::move(emit_schema)};
}

Result<RelationInfo> GetExtensionRelationInfo(const substrait::Rel& rel,
const ExtensionSet& ext_set,
const ConversionOptions& conv_opts,
std::vector<DeclarationInfo>* inputs_arg) {
Result<DeclarationInfo> GetExtensionInfo(const substrait::Rel& rel,
const ExtensionSet& ext_set,
const ConversionOptions& conv_opts,
std::vector<DeclarationInfo>* inputs_arg) {
if (inputs_arg == nullptr) {
std::vector<DeclarationInfo> inputs_tmp;
return GetExtensionRelationInfo(rel, ext_set, conv_opts, &inputs_tmp);
return GetExtensionInfo(rel, ext_set, conv_opts, &inputs_tmp);
}
std::vector<DeclarationInfo>& inputs = *inputs_arg;
inputs.clear();
Expand Down Expand Up @@ -788,44 +783,27 @@ Result<DeclarationInfo> FromProto(const substrait::Rel& rel, const ExtensionSet&
case substrait::Rel::RelTypeCase::kExtensionMulti: {
std::vector<DeclarationInfo> ext_rel_inputs;
ARROW_ASSIGN_OR_RAISE(
auto ext_rel_info,
GetExtensionRelationInfo(rel, ext_set, conversion_options, &ext_rel_inputs));
const auto& ext_decl_info = ext_rel_info.decl_info;
auto ext_decl_info,
GetExtensionInfo(rel, ext_set, conversion_options, &ext_rel_inputs));
auto ext_common_opt = GetExtensionRelCommon(rel);
bool has_emit = ext_common_opt && ext_common_opt->emit_kind_case() ==
substrait::RelCommon::EmitKindCase::kEmit;
if (!ext_rel_info.field_output_indices) {
if (!has_emit) {
return ext_decl_info;
}
return Status::NotImplemented("Emit not supported by ",
ext_decl_info.declaration.factory_name);
}
// Set up the emit order - an ordered list of indices that specifies an output
// mapping as expected by Substrait. This is a sublist of [0..N), where N is the
// total number of input fields across all inputs of the relation, that selects
// from these input fields.
std::vector<int> emit_order;
if (has_emit) {
std::vector<int> emit_order;
// the emit order is defined in the Substrait plan - pick it up
const auto& emit_info = ext_common_opt->emit();
emit_order.reserve(emit_info.output_mapping_size());
for (const auto& emit_idx : emit_info.output_mapping()) {
emit_order.push_back(emit_idx);
}
return ProcessExtensionEmit(std::move(ext_decl_info), emit_order);
} else {
// the emit order is the default output mapping [0..N)
int emit_size = 0;
for (const auto& input : ext_rel_inputs) {
emit_size += input.output_schema->num_fields();
}
emit_order.reserve(emit_size);
for (int emit_idx = 0; emit_idx < emit_size; emit_idx++) {
emit_order.push_back(emit_idx);
}
return ext_decl_info;
}
return ProcessExtensionEmit(ext_decl_info, emit_order,
*ext_rel_info.field_output_indices);
}

case substrait::Rel::RelTypeCase::kSet: {
Expand Down
12 changes: 5 additions & 7 deletions cpp/src/arrow/engine/substrait/serde_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4356,7 +4356,7 @@ TEST(Substrait, PlanWithAsOfJoinExtension) {
"extension_multi": {
"common": {
"emit": {
"outputMapping": [0, 1, 2, 5]
"outputMapping": [0, 1, 2, 3]
}
},
"inputs": [
Expand Down Expand Up @@ -5156,7 +5156,7 @@ TEST(Substrait, PlanWithExtension) {
"extension_multi": {
"common": {
"emit": {
"outputMapping": [0, 1, 2, 5]
"outputMapping": [0, 1, 2, 3]
}
},
"inputs": [
Expand Down Expand Up @@ -5475,7 +5475,7 @@ TEST(Substrait, AsOfJoinDefaultEmit) {
}
}
},
"names": ["time", "key", "value1", "time2", "key2", "value2"]
"names": ["time", "key", "value1", "value2"]
}
}],
"expectedTypeUrls": []
Expand Down Expand Up @@ -5507,12 +5507,10 @@ TEST(Substrait, AsOfJoinDefaultEmit) {
ASSERT_OK_AND_ASSIGN(auto buf, internal::SubstraitFromJSON("Plan", substrait_json));

auto out_schema = schema({field("time", int32()), field("key", int32()),
field("value1", float64()), field("time2", int32()),
field("key2", int32()), field("value2", float64())});
field("value1", float64()), field("value2", float64())});

auto expected_table = TableFromJSON(
out_schema,
{"[[2, 1, 1.1, 2, 1, 1.2], [4, 1, 2.1, 4, 1, 1.2], [6, 2, 3.1, 6, 2, 3.2]]"});
out_schema, {"[[2, 1, 1.1, 1.2], [4, 1, 2.1, 1.2], [6, 2, 3.1, 3.2]]"});
CheckRoundTripResult(std::move(expected_table), buf, {}, conversion_options);
}

Expand Down

0 comments on commit d5866ec

Please sign in to comment.