Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-427: [C++] Implement dictionary array type #268

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/src/arrow/CMakeLists.txt
Expand Up @@ -47,6 +47,7 @@ install(

ADD_ARROW_TEST(array-test)
ADD_ARROW_TEST(array-decimal-test)
ADD_ARROW_TEST(array-dictionary-test)
ADD_ARROW_TEST(array-list-test)
ADD_ARROW_TEST(array-primitive-test)
ADD_ARROW_TEST(array-string-test)
Expand Down
128 changes: 128 additions & 0 deletions cpp/src/arrow/array-dictionary-test.cc
@@ -0,0 +1,128 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <cstdint>
#include <cstdlib>
#include <memory>
#include <numeric>
#include <vector>

#include "gtest/gtest.h"

#include "arrow/array.h"
#include "arrow/buffer.h"
#include "arrow/memory_pool.h"
#include "arrow/test-util.h"
#include "arrow/type.h"

namespace arrow {

TEST(TestDictionary, Basics) {
std::vector<int32_t> values = {100, 1000, 10000, 100000};
std::shared_ptr<Array> dict;
ArrayFromVector<Int32Type, int32_t>(int32(), values, &dict);

std::shared_ptr<DictionaryType> type1 =
std::dynamic_pointer_cast<DictionaryType>(dictionary(int16(), dict));
DictionaryType type2(int16(), dict);

ASSERT_TRUE(int16()->Equals(type1->index_type()));
ASSERT_TRUE(type1->dictionary()->Equals(dict));

ASSERT_TRUE(int16()->Equals(type2.index_type()));
ASSERT_TRUE(type2.dictionary()->Equals(dict));

ASSERT_EQ("dictionary<int32, int16>", type1->ToString());
}

TEST(TestDictionary, Equals) {
std::vector<bool> is_valid = {true, true, false, true, true, true};

std::shared_ptr<Array> dict;
std::vector<std::string> dict_values = {"foo", "bar", "baz"};
ArrayFromVector<StringType, std::string>(utf8(), dict_values, &dict);
std::shared_ptr<DataType> dict_type = dictionary(int16(), dict);

std::shared_ptr<Array> dict2;
std::vector<std::string> dict2_values = {"foo", "bar", "baz", "qux"};
ArrayFromVector<StringType, std::string>(utf8(), dict2_values, &dict2);
std::shared_ptr<DataType> dict2_type = dictionary(int16(), dict2);

std::shared_ptr<Array> indices;
std::vector<int16_t> indices_values = {1, 2, -1, 0, 2, 0};
ArrayFromVector<Int16Type, int16_t>(int16(), is_valid, indices_values, &indices);

std::shared_ptr<Array> indices2;
std::vector<int16_t> indices2_values = {1, 2, 0, 0, 2, 0};
ArrayFromVector<Int16Type, int16_t>(int16(), is_valid, indices2_values, &indices2);

std::shared_ptr<Array> indices3;
std::vector<int16_t> indices3_values = {1, 1, 0, 0, 2, 0};
ArrayFromVector<Int16Type, int16_t>(int16(), is_valid, indices3_values, &indices3);

auto arr = std::make_shared<DictionaryArray>(dict_type, indices);
auto arr2 = std::make_shared<DictionaryArray>(dict_type, indices2);
auto arr3 = std::make_shared<DictionaryArray>(dict2_type, indices);
auto arr4 = std::make_shared<DictionaryArray>(dict_type, indices3);

ASSERT_TRUE(arr->Equals(arr));

// Equal, because the unequal index is masked by null
ASSERT_TRUE(arr->Equals(arr2));

// Unequal dictionaries
ASSERT_FALSE(arr->Equals(arr3));

// Unequal indices
ASSERT_FALSE(arr->Equals(arr4));

// RangeEquals
ASSERT_TRUE(arr->RangeEquals(3, 6, 3, arr4));
ASSERT_FALSE(arr->RangeEquals(1, 3, 1, arr4));
}

TEST(TestDictionary, Validate) {
std::vector<bool> is_valid = {true, true, false, true, true, true};

std::shared_ptr<Array> dict;
std::vector<std::string> dict_values = {"foo", "bar", "baz"};
ArrayFromVector<StringType, std::string>(utf8(), dict_values, &dict);
std::shared_ptr<DataType> dict_type = dictionary(int16(), dict);

std::shared_ptr<Array> indices;
std::vector<uint8_t> indices_values = {1, 2, 0, 0, 2, 0};
ArrayFromVector<UInt8Type, uint8_t>(uint8(), is_valid, indices_values, &indices);

std::shared_ptr<Array> indices2;
std::vector<float> indices2_values = {1., 2., 0., 0., 2., 0.};
ArrayFromVector<FloatType, float>(float32(), is_valid, indices2_values, &indices2);

std::shared_ptr<Array> indices3;
std::vector<int64_t> indices3_values = {1, 2, 0, 0, 2, 0};
ArrayFromVector<Int64Type, int64_t>(int64(), is_valid, indices3_values, &indices3);

std::shared_ptr<Array> arr = std::make_shared<DictionaryArray>(dict_type, indices);
std::shared_ptr<Array> arr2 = std::make_shared<DictionaryArray>(dict_type, indices2);
std::shared_ptr<Array> arr3 = std::make_shared<DictionaryArray>(dict_type, indices3);

// Only checking index type for now
ASSERT_OK(arr->Validate());
ASSERT_RAISES(Invalid, arr2->Validate());
ASSERT_OK(arr3->Validate());
}

} // namespace arrow
4 changes: 2 additions & 2 deletions cpp/src/arrow/array-string-test.cc
Expand Up @@ -36,8 +36,8 @@ TEST(TypesTest, BinaryType) {
BinaryType t1;
BinaryType e1;
StringType t2;
EXPECT_TRUE(t1.Equals(&e1));
EXPECT_FALSE(t1.Equals(&t2));
EXPECT_TRUE(t1.Equals(e1));
EXPECT_FALSE(t1.Equals(t2));
ASSERT_EQ(t1.type, Type::BINARY);
ASSERT_EQ(t1.ToString(), std::string("binary"));
}
Expand Down
94 changes: 77 additions & 17 deletions cpp/src/arrow/array.cc
Expand Up @@ -42,7 +42,7 @@ Status GetEmptyBitmap(
// ----------------------------------------------------------------------
// Base array class

Array::Array(const TypePtr& type, int32_t length, int32_t null_count,
Array::Array(const std::shared_ptr<DataType>& type, int32_t length, int32_t null_count,
const std::shared_ptr<Buffer>& null_bitmap) {
type_ = type;
length_ = length;
Expand All @@ -51,6 +51,12 @@ Array::Array(const TypePtr& type, int32_t length, int32_t null_count,
if (null_bitmap_) { null_bitmap_data_ = null_bitmap_->data(); }
}

bool Array::BaseEquals(const std::shared_ptr<Array>& other) const {
if (this == other.get()) { return true; }
if (!other) { return false; }
return EqualsExact(*other.get());
}

bool Array::EqualsExact(const Array& other) const {
if (this == &other) { return true; }
if (length_ != other.length_ || null_count_ != other.null_count_ ||
Expand Down Expand Up @@ -91,7 +97,7 @@ Status NullArray::Accept(ArrayVisitor* visitor) const {
// ----------------------------------------------------------------------
// Primitive array base

PrimitiveArray::PrimitiveArray(const TypePtr& type, int32_t length,
PrimitiveArray::PrimitiveArray(const std::shared_ptr<DataType>& type, int32_t length,
const std::shared_ptr<Buffer>& data, int32_t null_count,
const std::shared_ptr<Buffer>& null_bitmap)
: Array(type, length, null_count, null_bitmap) {
Expand All @@ -100,14 +106,9 @@ PrimitiveArray::PrimitiveArray(const TypePtr& type, int32_t length,
}

bool PrimitiveArray::EqualsExact(const PrimitiveArray& other) const {
if (this == &other) { return true; }
if (null_count_ != other.null_count_) { return false; }
if (!Array::EqualsExact(other)) { return false; }

if (null_count_ > 0) {
bool equal_bitmap =
null_bitmap_->Equals(*other.null_bitmap_, BitUtil::CeilByte(length_) / 8);
if (!equal_bitmap) { return false; }

const uint8_t* this_data = raw_data_;
const uint8_t* other_data = other.raw_data_;

Expand All @@ -131,7 +132,7 @@ bool PrimitiveArray::Equals(const std::shared_ptr<Array>& arr) const {
if (this == arr.get()) { return true; }
if (!arr) { return false; }
if (this->type_enum() != arr->type_enum()) { return false; }
return EqualsExact(*static_cast<const PrimitiveArray*>(arr.get()));
return EqualsExact(static_cast<const PrimitiveArray&>(*arr.get()));
}

template <typename T>
Expand Down Expand Up @@ -161,7 +162,7 @@ BooleanArray::BooleanArray(int32_t length, const std::shared_ptr<Buffer>& data,
: PrimitiveArray(
std::make_shared<BooleanType>(), length, data, null_count, null_bitmap) {}

BooleanArray::BooleanArray(const TypePtr& type, int32_t length,
BooleanArray::BooleanArray(const std::shared_ptr<DataType>& type, int32_t length,
const std::shared_ptr<Buffer>& data, int32_t null_count,
const std::shared_ptr<Buffer>& null_bitmap)
: PrimitiveArray(type, length, data, null_count, null_bitmap) {}
Expand Down Expand Up @@ -192,7 +193,7 @@ bool BooleanArray::EqualsExact(const BooleanArray& other) const {
bool BooleanArray::Equals(const std::shared_ptr<Array>& arr) const {
if (this == arr.get()) return true;
if (Type::BOOL != arr->type_enum()) { return false; }
return EqualsExact(*static_cast<const BooleanArray*>(arr.get()));
return EqualsExact(static_cast<const BooleanArray&>(*arr.get()));
}

bool BooleanArray::RangeEquals(int32_t start_idx, int32_t end_idx,
Expand Down Expand Up @@ -238,7 +239,7 @@ bool ListArray::EqualsExact(const ListArray& other) const {
bool ListArray::Equals(const std::shared_ptr<Array>& arr) const {
if (this == arr.get()) { return true; }
if (this->type_enum() != arr->type_enum()) { return false; }
return EqualsExact(*static_cast<const ListArray*>(arr.get()));
return EqualsExact(static_cast<const ListArray&>(*arr.get()));
}

bool ListArray::RangeEquals(int32_t start_idx, int32_t end_idx, int32_t other_start_idx,
Expand Down Expand Up @@ -333,7 +334,7 @@ BinaryArray::BinaryArray(int32_t length, const std::shared_ptr<Buffer>& offsets,
const std::shared_ptr<Buffer>& null_bitmap)
: BinaryArray(kBinary, length, offsets, data, null_count, null_bitmap) {}

BinaryArray::BinaryArray(const TypePtr& type, int32_t length,
BinaryArray::BinaryArray(const std::shared_ptr<DataType>& type, int32_t length,
const std::shared_ptr<Buffer>& offsets, const std::shared_ptr<Buffer>& data,
int32_t null_count, const std::shared_ptr<Buffer>& null_bitmap)
: Array(type, length, null_count, null_bitmap),
Expand Down Expand Up @@ -364,7 +365,7 @@ bool BinaryArray::EqualsExact(const BinaryArray& other) const {
bool BinaryArray::Equals(const std::shared_ptr<Array>& arr) const {
if (this == arr.get()) { return true; }
if (this->type_enum() != arr->type_enum()) { return false; }
return EqualsExact(*static_cast<const BinaryArray*>(arr.get()));
return EqualsExact(static_cast<const BinaryArray&>(*arr.get()));
}

bool BinaryArray::RangeEquals(int32_t start_idx, int32_t end_idx, int32_t other_start_idx,
Expand Down Expand Up @@ -493,7 +494,7 @@ Status StructArray::Accept(ArrayVisitor* visitor) const {
// ----------------------------------------------------------------------
// UnionArray

UnionArray::UnionArray(const TypePtr& type, int32_t length,
UnionArray::UnionArray(const std::shared_ptr<DataType>& type, int32_t length,
const std::vector<std::shared_ptr<Array>>& children,
const std::shared_ptr<Buffer>& type_ids, const std::shared_ptr<Buffer>& offsets,
int32_t null_count, const std::shared_ptr<Buffer>& null_bitmap)
Expand Down Expand Up @@ -586,14 +587,74 @@ Status UnionArray::Accept(ArrayVisitor* visitor) const {
return visitor->Visit(*this);
}

// ----------------------------------------------------------------------
// DictionaryArray

Status DictionaryArray::FromBuffer(const std::shared_ptr<DataType>& type, int32_t length,
const std::shared_ptr<Buffer>& indices, int32_t null_count,
const std::shared_ptr<Buffer>& null_bitmap, std::shared_ptr<DictionaryArray>* out) {
DCHECK_EQ(type->type, Type::DICTIONARY);
const auto& dict_type = static_cast<const DictionaryType*>(type.get());

std::shared_ptr<Array> boxed_indices;
RETURN_NOT_OK(MakePrimitiveArray(
dict_type->index_type(), length, indices, null_count, null_bitmap, &boxed_indices));

*out = std::make_shared<DictionaryArray>(type, boxed_indices);
return Status::OK();
}

DictionaryArray::DictionaryArray(
const std::shared_ptr<DataType>& type, const std::shared_ptr<Array>& indices)
: Array(type, indices->length(), indices->null_count(), indices->null_bitmap()),
dict_type_(static_cast<const DictionaryType*>(type.get())),
indices_(indices) {
DCHECK_EQ(type->type, Type::DICTIONARY);
}

Status DictionaryArray::Validate() const {
Type::type index_type_id = indices_->type()->type;
if (!is_integer(index_type_id)) {
return Status::Invalid("Dictionary indices must be integer type");
}
return Status::OK();
}

std::shared_ptr<Array> DictionaryArray::dictionary() const {
return dict_type_->dictionary();
}

bool DictionaryArray::EqualsExact(const DictionaryArray& other) const {
if (!dictionary()->Equals(other.dictionary())) { return false; }
return indices_->Equals(other.indices());
}

bool DictionaryArray::Equals(const std::shared_ptr<Array>& arr) const {
if (this == arr.get()) { return true; }
if (Type::DICTIONARY != arr->type_enum()) { return false; }
return EqualsExact(static_cast<const DictionaryArray&>(*arr.get()));
}

bool DictionaryArray::RangeEquals(int32_t start_idx, int32_t end_idx,
int32_t other_start_idx, const std::shared_ptr<Array>& arr) const {
if (Type::DICTIONARY != arr->type_enum()) { return false; }
const auto& dict_other = static_cast<const DictionaryArray&>(*arr.get());
if (!dictionary()->Equals(dict_other.dictionary())) { return false; }
return indices_->RangeEquals(start_idx, end_idx, other_start_idx, dict_other.indices());
}

Status DictionaryArray::Accept(ArrayVisitor* visitor) const {
return visitor->Visit(*this);
}

// ----------------------------------------------------------------------

#define MAKE_PRIMITIVE_ARRAY_CASE(ENUM, ArrayType) \
case Type::ENUM: \
out->reset(new ArrayType(type, length, data, null_count, null_bitmap)); \
break;

Status MakePrimitiveArray(const TypePtr& type, int32_t length,
Status MakePrimitiveArray(const std::shared_ptr<DataType>& type, int32_t length,
const std::shared_ptr<Buffer>& data, int32_t null_count,
const std::shared_ptr<Buffer>& null_bitmap, std::shared_ptr<Array>* out) {
switch (type->type) {
Expand All @@ -610,7 +671,6 @@ Status MakePrimitiveArray(const TypePtr& type, int32_t length,
MAKE_PRIMITIVE_ARRAY_CASE(DOUBLE, DoubleArray);
MAKE_PRIMITIVE_ARRAY_CASE(TIME, Int64Array);
MAKE_PRIMITIVE_ARRAY_CASE(TIMESTAMP, TimestampArray);
MAKE_PRIMITIVE_ARRAY_CASE(TIMESTAMP_DOUBLE, DoubleArray);
default:
return Status::NotImplemented(type->ToString());
}
Expand Down