ARROW-7412: [C++][Dataset] Provide FieldRef to disambiguate field references

`FieldRef` is a new utility class which represents a reference to a field. It is intended to replace parameters like `int field_index` and `const std::string& name`; it can be implicitly constructed from either a field index or a name.
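
For example, a rough sketch of the implicit conversions (the `Select` helper below is hypothetical, not code from this change):

```C++
// Select is a hypothetical helper whose single FieldRef parameter replaces a pair of
// `int field_index` / `const std::string& name` overloads.
void Select(FieldRef ref);

void Example() {
  Select(2);                     // implicitly FieldRef(2): a reference by index
  Select(std::string("alpha"));  // implicitly FieldRef("alpha"): a reference by name
}
```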

Nested fields can be referenced as well:
```C++
// the following all indicate schema->GetFieldByName("alpha")->type()->child(0)
FieldRef ref1({FieldRef("alpha"), FieldRef(0)});
FieldRef ref2("alpha", 0);
ARROW_ASSIGN_OR_RAISE(FieldRef ref3,
                      FieldRef::FromDotPath(".alpha[0]"));
```
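
The dot path spelling mirrors the other two: each `.name` segment selects a child by name and each `[index]` segment selects a child by position, so `".alpha[0]"` resolves to the field named "alpha" and then its first child.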

FieldRefs provide a number of accessors for drilling down to potentially nested children. They are overloaded for convenience to support Schema (returns a field), DataType (returns a child field), Field (returns a child field of this field's type), Array (returns a child array), RecordBatch (returns a column), ChunkedArray (returns a ChunkedArray where each chunk is a child array of the corresponding original chunk), and Table (returns a column).

```C++
// Field names can match multiple fields in a Schema
Schema a_is_ambiguous({field("a", null()), field("a", null())});
auto matches = FieldRef("a").FindAll(a_is_ambiguous);
assert(matches.size() == 2);
assert_ok_and_eq(matches[0].Get(a_is_ambiguous), a_is_ambiguous.field(0));

// Convenience accessor raises a helpful error if the field is not found or ambiguous
ARROW_ASSIGN_OR_RAISE(auto column, FieldRef("struct", "field_i32").GetOne(some_table));
```
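
To give a sense of how these overloads compose, here is a minimal sketch; `some_batch` and the field names "s"/"x" are hypothetical (a RecordBatch whose struct column "s" has a child "x"), not taken from the change itself:

```C++
// The same FieldRef resolves against different containers: against the Schema it
// yields the matching Field, against the RecordBatch it yields the child Array.
FieldRef nested("s", "x");
ARROW_ASSIGN_OR_RAISE(auto field, nested.GetOne(*some_batch.schema()));  // -> Field
ARROW_ASSIGN_OR_RAISE(auto child, nested.GetOne(some_batch));            // -> Array
```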

Closes #6545 from bkietz/7412-Dataset-Ensure-that-datas

Authored-by: Benjamin Kietzman <bengilgit@gmail.com>
Signed-off-by: Benjamin Kietzman <bengilgit@gmail.com>
bkietz committed Mar 13, 2020
1 parent 84651cb commit 6d3c085
Showing 18 changed files with 809 additions and 58 deletions.
2 changes: 0 additions & 2 deletions cpp/src/arrow/array.h
@@ -426,8 +426,6 @@ class ARROW_EXPORT Array {
   ARROW_DISALLOW_COPY_AND_ASSIGN(Array);
 };
 
-using ArrayVector = std::vector<std::shared_ptr<Array>>;
-
 namespace internal {
 
 /// Given a number of ArrayVectors, treat each ArrayVector as the
2 changes: 0 additions & 2 deletions cpp/src/arrow/buffer.h
@@ -332,8 +332,6 @@ class ARROW_EXPORT Buffer {
   ARROW_DISALLOW_COPY_AND_ASSIGN(Buffer);
 };
 
-using BufferVector = std::vector<std::shared_ptr<Buffer>>;
-
 /// \defgroup buffer-slicing-functions Functions for slicing buffers
 ///
 /// @{
7 changes: 4 additions & 3 deletions cpp/src/arrow/dataset/dataset_internal.h
@@ -54,9 +54,10 @@ static inline FragmentIterator GetFragmentsFromDatasets(
 inline std::shared_ptr<Schema> SchemaFromColumnNames(
     const std::shared_ptr<Schema>& input, const std::vector<std::string>& column_names) {
   std::vector<std::shared_ptr<Field>> columns;
-  for (const auto& name : column_names) {
-    if (auto field = input->GetFieldByName(name)) {
-      columns.push_back(std::move(field));
+  for (FieldRef ref : column_names) {
+    auto maybe_field = ref.GetOne(*input);
+    if (maybe_field.ok()) {
+      columns.push_back(std::move(maybe_field).ValueOrDie());
     }
   }
 
3 changes: 2 additions & 1 deletion cpp/src/arrow/dataset/filter.cc
@@ -898,7 +898,8 @@ Result<std::shared_ptr<DataType>> ScalarExpression::Validate(const Schema& schem
 }
 
 Result<std::shared_ptr<DataType>> FieldExpression::Validate(const Schema& schema) const {
-  if (auto field = schema.GetFieldByName(name_)) {
+  ARROW_ASSIGN_OR_RAISE(auto field, FieldRef(name_).GetOneOrNone(schema));
+  if (field != nullptr) {
     return field->type();
   }
   return null();
8 changes: 3 additions & 5 deletions cpp/src/arrow/dataset/partition.cc
@@ -71,7 +71,7 @@ Result<std::shared_ptr<Expression>> SegmentDictionaryPartitioning::Parse(
 
 Result<std::shared_ptr<Expression>> KeyValuePartitioning::ConvertKey(
     const Key& key, const Schema& schema) {
-  auto field = schema.GetFieldByName(key.name);
+  ARROW_ASSIGN_OR_RAISE(auto field, FieldRef(key.name).GetOneOrNone(schema));
   if (field == nullptr) {
     return scalar(true);
   }
@@ -141,10 +141,8 @@ class DirectoryPartitioningFactory : public PartitioningFactory {
 
   Result<std::shared_ptr<Partitioning>> Finish(
       const std::shared_ptr<Schema>& schema) const override {
-    for (const auto& field_name : field_names_) {
-      if (schema->GetFieldIndex(field_name) == -1) {
-        return Status::TypeError("no field named '", field_name, "' in schema", *schema);
-      }
+    for (FieldRef ref : field_names_) {
+      RETURN_NOT_OK(ref.FindOne(*schema).status());
     }
 
     // drop fields which aren't in field_names_
31 changes: 21 additions & 10 deletions cpp/src/arrow/dataset/projector.cc
@@ -40,8 +40,15 @@ RecordBatchProjector::RecordBatchProjector(std::shared_ptr<Schema> to)
       column_indices_(to_->num_fields(), kNoMatch),
       scalars_(to_->num_fields(), nullptr) {}
 
-Status RecordBatchProjector::SetDefaultValue(int index, std::shared_ptr<Scalar> scalar) {
+Status RecordBatchProjector::SetDefaultValue(FieldRef ref,
+                                             std::shared_ptr<Scalar> scalar) {
   DCHECK_NE(scalar, nullptr);
+  if (ref.IsNested()) {
+    return Status::NotImplemented("setting default values for nested columns");
+  }
+
+  ARROW_ASSIGN_OR_RAISE(auto match, ref.FindOne(*to_));
+  auto index = match[0];
 
   auto field_type = to_->field(index)->type();
   if (!field_type->Equals(scalar->type)) {
@@ -83,9 +90,19 @@ Status RecordBatchProjector::SetInputSchema(std::shared_ptr<Schema> from,
 
   for (int i = 0; i < to_->num_fields(); ++i) {
     const auto& field = to_->field(i);
-    int matching_index = from_->GetFieldIndex(field->name());
+    FieldRef ref(field->name());
+    auto matches = ref.FindAll(*from_);
 
+    if (matches.empty()) {
+      // Mark column i as missing by setting missing_columns_[i]
+      // to a non-null placeholder.
+      RETURN_NOT_OK(
+          MakeArrayOfNull(pool, to_->field(i)->type(), 0, &missing_columns_[i]));
+      column_indices_[i] = kNoMatch;
+    } else {
+      RETURN_NOT_OK(ref.CheckNonMultiple(matches, *from_));
+      int matching_index = matches[0][0];
+
-    if (matching_index != kNoMatch) {
       if (!from_->field(matching_index)->Equals(field)) {
         return Status::TypeError("fields had matching names but were not equivalent ",
                                  from_->field(matching_index)->ToString(), " vs ",
@@ -94,14 +111,8 @@ Status RecordBatchProjector::SetInputSchema(std::shared_ptr<Schema> from,
 
       // Mark column i as not missing by setting missing_columns_[i] to nullptr
       missing_columns_[i] = nullptr;
-    } else {
-      // Mark column i as missing by setting missing_columns_[i]
-      // to a non-null placeholder.
-      RETURN_NOT_OK(
-          MakeArrayOfNull(pool, to_->field(i)->type(), 0, &missing_columns_[i]));
+      column_indices_[i] = matching_index;
     }
-
-    column_indices_[i] = matching_index;
   }
   return Status::OK();
 }
5 changes: 2 additions & 3 deletions cpp/src/arrow/dataset/projector.h
@@ -18,8 +18,6 @@
 #pragma once
 
 #include <memory>
-#include <string>
-#include <utility>
 #include <vector>
 
 #include "arrow/dataset/type_fwd.h"
@@ -48,7 +46,7 @@ class ARROW_DS_EXPORT RecordBatchProjector {
 
   /// If the indexed field is absent from a record batch it will be added to the projected
   /// record batch with all its slots equal to the provided scalar (instead of null).
-  Status SetDefaultValue(int index, std::shared_ptr<Scalar> scalar);
+  Status SetDefaultValue(FieldRef ref, std::shared_ptr<Scalar> scalar);
 
   Result<std::shared_ptr<RecordBatch>> Project(const RecordBatch& batch,
                                                MemoryPool* pool = default_memory_pool());
@@ -63,6 +61,7 @@
 
   std::shared_ptr<Schema> from_, to_;
   int64_t missing_columns_length_ = 0;
+  // these vectors are indexed parallel to to_->fields()
  std::vector<std::shared_ptr<Array>> missing_columns_;
   std::vector<int> column_indices_;
   std::vector<std::shared_ptr<Scalar>> scalars_;
2 changes: 0 additions & 2 deletions cpp/src/arrow/flight/perf_server.cc
@@ -54,8 +54,6 @@ namespace flight {
    } \
  } while (0)
 
-using ArrayVector = std::vector<std::shared_ptr<Array>>;
-
 // Create record batches with a unique "a" column so we can verify on the
 // client side that the results are correct
 class PerfDataStream : public FlightDataStream {
2 changes: 2 additions & 0 deletions cpp/src/arrow/record_batch.cc
@@ -80,6 +80,8 @@ class SimpleRecordBatch : public RecordBatch {
 
   std::shared_ptr<ArrayData> column_data(int i) const override { return columns_[i]; }
 
+  ArrayDataVector column_data() const override { return columns_; }
+
   Status AddColumn(int i, const std::shared_ptr<Field>& field,
                    const std::shared_ptr<Array>& column,
                    std::shared_ptr<RecordBatch>* out) const override {
5 changes: 4 additions & 1 deletion cpp/src/arrow/record_batch.h
@@ -99,11 +99,14 @@ class ARROW_EXPORT RecordBatch {
   /// \return an Array or null if no field was found
   std::shared_ptr<Array> GetColumnByName(const std::string& name) const;
 
-  /// \brief Retrieve an array's internaldata from the record batch
+  /// \brief Retrieve an array's internal data from the record batch
   /// \param[in] i field index, does not boundscheck
   /// \return an internal ArrayData object
   virtual std::shared_ptr<ArrayData> column_data(int i) const = 0;
 
+  /// \brief Retrieve all arrays' internal data from the record batch.
+  virtual ArrayDataVector column_data() const = 0;
+
   /// \brief Add column to the record batch, producing a new RecordBatch
   ///
   /// \param[in] i field index, which will be boundschecked
15 changes: 7 additions & 8 deletions cpp/src/arrow/table.cc
@@ -41,25 +41,24 @@ using internal::checked_cast;
 // ----------------------------------------------------------------------
 // ChunkedArray methods
 
-ChunkedArray::ChunkedArray(const ArrayVector& chunks) : chunks_(chunks) {
+ChunkedArray::ChunkedArray(ArrayVector chunks) : chunks_(std::move(chunks)) {
   length_ = 0;
   null_count_ = 0;
 
-  ARROW_CHECK_GT(chunks.size(), 0)
+  ARROW_CHECK_GT(chunks_.size(), 0)
       << "cannot construct ChunkedArray from empty vector and omitted type";
-  type_ = chunks[0]->type();
-  for (const std::shared_ptr<Array>& chunk : chunks) {
+  type_ = chunks_[0]->type();
+  for (const std::shared_ptr<Array>& chunk : chunks_) {
     length_ += chunk->length();
     null_count_ += chunk->null_count();
   }
 }
 
-ChunkedArray::ChunkedArray(const ArrayVector& chunks,
-                           const std::shared_ptr<DataType>& type)
-    : chunks_(chunks), type_(type) {
+ChunkedArray::ChunkedArray(ArrayVector chunks, std::shared_ptr<DataType> type)
+    : chunks_(std::move(chunks)), type_(std::move(type)) {
   length_ = 0;
   null_count_ = 0;
-  for (const std::shared_ptr<Array>& chunk : chunks) {
+  for (const std::shared_ptr<Array>& chunk : chunks_) {
     length_ += chunk->length();
     null_count_ += chunk->null_count();
   }
6 changes: 2 additions & 4 deletions cpp/src/arrow/table.h
@@ -30,8 +30,6 @@
 
 namespace arrow {
 
-using ArrayVector = std::vector<std::shared_ptr<Array>>;
-
 /// \class ChunkedArray
 /// \brief A data structure managing a list of primitive Arrow arrays logically
 /// as one large array
@@ -41,7 +39,7 @@ class ARROW_EXPORT ChunkedArray {
   ///
   /// The vector must be non-empty and all its elements must have the same
   /// data type.
-  explicit ChunkedArray(const ArrayVector& chunks);
+  explicit ChunkedArray(ArrayVector chunks);
 
   /// \brief Construct a chunked array from a single Array
   explicit ChunkedArray(const std::shared_ptr<Array>& chunk)
@@ -50,7 +48,7 @@
   /// \brief Construct a chunked array from a vector of arrays and a data type
   ///
   /// As the data type is passed explicitly, the vector may be empty.
-  ChunkedArray(const ArrayVector& chunks, const std::shared_ptr<DataType>& type);
+  ChunkedArray(ArrayVector chunks, std::shared_ptr<DataType> type);
 
   /// \return the total length of the chunked array; computed on construction
   int64_t length() const { return length_; }
2 changes: 0 additions & 2 deletions cpp/src/arrow/testing/gtest_util.h
@@ -164,8 +164,6 @@ struct Datum;
 
 using Datum = compute::Datum;
 
-using ArrayVector = std::vector<std::shared_ptr<Array>>;
-
 #define ASSERT_ARRAYS_EQUAL(lhs, rhs) AssertArraysEqual((lhs), (rhs))
 #define ASSERT_BATCHES_EQUAL(lhs, rhs) AssertBatchesEqual((lhs), (rhs))
 #define ASSERT_TABLES_EQUAL(lhs, rhs) AssertTablesEqual((lhs), (rhs))
8 changes: 0 additions & 8 deletions cpp/src/arrow/testing/util.h
@@ -39,14 +39,6 @@
 
 namespace arrow {
 
-class Array;
-class ChunkedArray;
-class MemoryPool;
-class RecordBatch;
-class Table;
-
-using ArrayVector = std::vector<std::shared_ptr<Array>>;
-
 template <typename T>
 Status CopyBufferFromVector(const std::vector<T>& values, MemoryPool* pool,
                             std::shared_ptr<Buffer>* result) {