Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-39865: [C++] Strip extension metadata when importing a registered extension #39866

Merged
merged 2 commits into from
Feb 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
6 changes: 6 additions & 0 deletions cpp/src/arrow/c/bridge.cc
Original file line number Diff line number Diff line change
Expand Up @@ -914,6 +914,8 @@ struct DecodedMetadata {
std::shared_ptr<KeyValueMetadata> metadata;
std::string extension_name;
std::string extension_serialized;
int extension_name_index = -1; // index of extension_name in metadata
int extension_serialized_index = -1; // index of extension_serialized in metadata
};

Result<DecodedMetadata> DecodeMetadata(const char* metadata) {
Expand Down Expand Up @@ -956,8 +958,10 @@ Result<DecodedMetadata> DecodeMetadata(const char* metadata) {
RETURN_NOT_OK(read_string(&values[i]));
if (keys[i] == kExtensionTypeKeyName) {
decoded.extension_name = values[i];
decoded.extension_name_index = i;
} else if (keys[i] == kExtensionMetadataKeyName) {
decoded.extension_serialized = values[i];
decoded.extension_serialized_index = i;
}
}
decoded.metadata = key_value_metadata(std::move(keys), std::move(values));
Expand Down Expand Up @@ -1046,6 +1050,8 @@ struct SchemaImporter {
ARROW_ASSIGN_OR_RAISE(
type_, registered_ext_type->Deserialize(std::move(type_),
metadata_.extension_serialized));
RETURN_NOT_OK(metadata_.metadata->DeleteMany(
{metadata_.extension_name_index, metadata_.extension_serialized_index}));
}
}

Expand Down
48 changes: 32 additions & 16 deletions cpp/src/arrow/c/bridge_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1872,7 +1872,7 @@ class TestSchemaImport : public ::testing::Test, public SchemaStructBuilder {
ASSERT_TRUE(ArrowSchemaIsReleased(&c_struct_));
Reset(); // for further tests
cb.AssertCalled(); // was released
AssertTypeEqual(*expected, *type);
AssertTypeEqual(*expected, *type, /*check_metadata=*/true);
}

void CheckImport(const std::shared_ptr<Field>& expected) {
Expand All @@ -1892,7 +1892,7 @@ class TestSchemaImport : public ::testing::Test, public SchemaStructBuilder {
ASSERT_TRUE(ArrowSchemaIsReleased(&c_struct_));
Reset(); // for further tests
cb.AssertCalled(); // was released
AssertSchemaEqual(*expected, *schema);
AssertSchemaEqual(*expected, *schema, /*check_metadata=*/true);
}

void CheckImportError() {
Expand Down Expand Up @@ -3571,7 +3571,7 @@ class TestSchemaRoundtrip : public ::testing::Test {
// Recreate the type
ASSERT_OK_AND_ASSIGN(actual, ImportType(&c_schema));
type = factory_expected();
AssertTypeEqual(*type, *actual);
AssertTypeEqual(*type, *actual, /*check_metadata=*/true);
type.reset();
actual.reset();

Expand Down Expand Up @@ -3602,7 +3602,7 @@ class TestSchemaRoundtrip : public ::testing::Test {
// Recreate the schema
ASSERT_OK_AND_ASSIGN(actual, ImportSchema(&c_schema));
schema = factory();
AssertSchemaEqual(*schema, *actual);
AssertSchemaEqual(*schema, *actual, /*check_metadata=*/true);
schema.reset();
actual.reset();

Expand Down Expand Up @@ -3695,13 +3695,27 @@ TEST_F(TestSchemaRoundtrip, Dictionary) {
}
}

// Test helper: describe an extension type as a plain storage field.
//
// `type` must be an extension type (it is cast to ExtensionType). The
// returned field is named `field_name`, is nullable, has the extension's
// storage type, and carries two key-value metadata entries: the extension
// name and the extension's serialized representation.
std::shared_ptr<Field> GetStorageWithMetadata(const std::string& field_name,
                                              const std::shared_ptr<DataType>& type) {
  const auto& extension = checked_cast<const ExtensionType&>(*type);
  auto storage = extension.storage_type();
  auto extension_metadata =
      KeyValueMetadata::Make({kExtensionTypeKeyName, kExtensionMetadataKeyName},
                             {extension.extension_name(), extension.Serialize()});
  return field(field_name, storage, /*nullable=*/true, extension_metadata);
}

TEST_F(TestSchemaRoundtrip, UnregisteredExtension) {
TestWithTypeFactory(uuid, []() { return fixed_size_binary(16); });
TestWithTypeFactory(dict_extension_type, []() { return dictionary(int8(), utf8()); });

// Inside nested type
TestWithTypeFactory([]() { return list(dict_extension_type()); },
[]() { return list(dictionary(int8(), utf8())); });
// Inside nested type.
// When an extension type is not known by the importer, it is imported
// as its storage type and the extension metadata is preserved on the field.
TestWithTypeFactory(
[]() { return list(dict_extension_type()); },
[]() { return list(GetStorageWithMetadata("item", dict_extension_type())); });
pitrou marked this conversation as resolved.
Show resolved Hide resolved
}

TEST_F(TestSchemaRoundtrip, RegisteredExtension) {
Expand All @@ -3710,7 +3724,9 @@ TEST_F(TestSchemaRoundtrip, RegisteredExtension) {
TestWithTypeFactory(dict_extension_type);
TestWithTypeFactory(complex128);

// Inside nested type
// Inside nested type.
// When the extension type is registered, the extension metadata is removed
// from the storage type's field to ensure roundtripping (GH-39865).
TestWithTypeFactory([]() { return list(uuid()); });
TestWithTypeFactory([]() { return list(dict_extension_type()); });
TestWithTypeFactory([]() { return list(complex128()); });
Expand Down Expand Up @@ -3810,7 +3826,7 @@ class TestArrayRoundtrip : public ::testing::Test {
{
std::shared_ptr<Array> expected;
ASSERT_OK_AND_ASSIGN(expected, ToResult(factory_expected()));
AssertTypeEqual(*expected->type(), *array->type());
AssertTypeEqual(*expected->type(), *array->type(), /*check_metadata=*/true);
AssertArraysEqual(*expected, *array, true);
}
array.reset();
Expand Down Expand Up @@ -3850,7 +3866,7 @@ class TestArrayRoundtrip : public ::testing::Test {
{
std::shared_ptr<RecordBatch> expected;
ASSERT_OK_AND_ASSIGN(expected, ToResult(factory()));
AssertSchemaEqual(*expected->schema(), *batch->schema());
AssertSchemaEqual(*expected->schema(), *batch->schema(), /*check_metadata=*/true);
AssertBatchesEqual(*expected, *batch);
}
batch.reset();
Expand Down Expand Up @@ -4230,7 +4246,7 @@ class TestDeviceArrayRoundtrip : public ::testing::Test {
{
std::shared_ptr<Array> expected;
ASSERT_OK_AND_ASSIGN(expected, ToResult(factory_expected()));
AssertTypeEqual(*expected->type(), *array->type());
AssertTypeEqual(*expected->type(), *array->type(), /*check_metadata=*/true);
AssertArraysEqual(*expected, *array, true);
}
array.reset();
Expand Down Expand Up @@ -4276,7 +4292,7 @@ class TestDeviceArrayRoundtrip : public ::testing::Test {
{
std::shared_ptr<RecordBatch> expected;
ASSERT_OK_AND_ASSIGN(expected, ToResult(factory()));
AssertSchemaEqual(*expected->schema(), *batch->schema());
AssertSchemaEqual(*expected->schema(), *batch->schema(), /*check_metadata=*/true);
AssertBatchesEqual(*expected, *batch);
}
batch.reset();
Expand Down Expand Up @@ -4353,7 +4369,7 @@ class TestArrayStreamExport : public BaseArrayStreamTest {
SchemaExportGuard schema_guard(&c_schema);
ASSERT_FALSE(ArrowSchemaIsReleased(&c_schema));
ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
AssertSchemaEqual(expected, *schema);
AssertSchemaEqual(expected, *schema, /*check_metadata=*/true);
}

void AssertStreamEnd(struct ArrowArrayStream* c_stream) {
Expand Down Expand Up @@ -4437,7 +4453,7 @@ TEST_F(TestArrayStreamExport, ArrayLifetime) {
{
SchemaExportGuard schema_guard(&c_schema);
ASSERT_OK_AND_ASSIGN(auto got_schema, ImportSchema(&c_schema));
AssertSchemaEqual(*schema, *got_schema);
AssertSchemaEqual(*schema, *got_schema, /*check_metadata=*/true);
}

ASSERT_GT(pool_->bytes_allocated(), orig_allocated_);
Expand All @@ -4462,7 +4478,7 @@ TEST_F(TestArrayStreamExport, Errors) {
{
SchemaExportGuard schema_guard(&c_schema);
ASSERT_OK_AND_ASSIGN(auto schema, ImportSchema(&c_schema));
AssertSchemaEqual(schema, arrow::schema({}));
AssertSchemaEqual(schema, arrow::schema({}), /*check_metadata=*/true);
}

struct ArrowArray c_array;
Expand Down Expand Up @@ -4539,7 +4555,7 @@ TEST_F(TestArrayStreamRoundtrip, Simple) {
ASSERT_OK_AND_ASSIGN(auto reader, RecordBatchReader::Make(batches, orig_schema));

Roundtrip(std::move(reader), [&](const std::shared_ptr<RecordBatchReader>& reader) {
AssertSchemaEqual(*orig_schema, *reader->schema());
AssertSchemaEqual(*orig_schema, *reader->schema(), /*check_metadata=*/true);
AssertReaderNext(reader, *batches[0]);
AssertReaderNext(reader, *batches[1]);
AssertReaderEnd(reader);
Expand Down
18 changes: 8 additions & 10 deletions cpp/src/arrow/util/key_value_metadata.cc
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ void KeyValueMetadata::Append(std::string key, std::string value) {
values_.push_back(std::move(value));
}

Result<std::string> KeyValueMetadata::Get(const std::string& key) const {
Result<std::string> KeyValueMetadata::Get(std::string_view key) const {
auto index = FindKey(key);
if (index < 0) {
return Status::KeyError(key);
Expand Down Expand Up @@ -129,7 +129,7 @@ Status KeyValueMetadata::DeleteMany(std::vector<int64_t> indices) {
return Status::OK();
}

Status KeyValueMetadata::Delete(const std::string& key) {
Status KeyValueMetadata::Delete(std::string_view key) {
auto index = FindKey(key);
if (index < 0) {
return Status::KeyError(key);
Expand All @@ -138,20 +138,18 @@ Status KeyValueMetadata::Delete(const std::string& key) {
}
}

Status KeyValueMetadata::Set(const std::string& key, const std::string& value) {
Status KeyValueMetadata::Set(std::string key, std::string value) {
auto index = FindKey(key);
if (index < 0) {
Append(key, value);
Append(std::move(key), std::move(value));
} else {
keys_[index] = key;
values_[index] = value;
keys_[index] = std::move(key);
values_[index] = std::move(value);
}
return Status::OK();
}

bool KeyValueMetadata::Contains(const std::string& key) const {
return FindKey(key) >= 0;
}
bool KeyValueMetadata::Contains(std::string_view key) const { return FindKey(key) >= 0; }

void KeyValueMetadata::reserve(int64_t n) {
DCHECK_GE(n, 0);
Expand Down Expand Up @@ -188,7 +186,7 @@ std::vector<std::pair<std::string, std::string>> KeyValueMetadata::sorted_pairs(
return pairs;
}

int KeyValueMetadata::FindKey(const std::string& key) const {
int KeyValueMetadata::FindKey(std::string_view key) const {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

👍

for (size_t i = 0; i < keys_.size(); ++i) {
if (keys_[i] == key) {
return static_cast<int>(i);
Expand Down
11 changes: 6 additions & 5 deletions cpp/src/arrow/util/key_value_metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
#include <cstdint>
#include <memory>
#include <string>
#include <string_view>
#include <unordered_map>
#include <utility>
#include <vector>
Expand All @@ -44,13 +45,13 @@ class ARROW_EXPORT KeyValueMetadata {
void ToUnorderedMap(std::unordered_map<std::string, std::string>* out) const;
void Append(std::string key, std::string value);

Result<std::string> Get(const std::string& key) const;
bool Contains(const std::string& key) const;
Result<std::string> Get(std::string_view key) const;
bool Contains(std::string_view key) const;
// Note that deleting may invalidate known indices
Status Delete(const std::string& key);
Status Delete(std::string_view key);
Status Delete(int64_t index);
Status DeleteMany(std::vector<int64_t> indices);
Status Set(const std::string& key, const std::string& value);
Status Set(std::string key, std::string value);

void reserve(int64_t n);

Expand All @@ -63,7 +64,7 @@ class ARROW_EXPORT KeyValueMetadata {
std::vector<std::pair<std::string, std::string>> sorted_pairs() const;

/// \brief Perform linear search for key, returning -1 if not found
int FindKey(const std::string& key) const;
int FindKey(std::string_view key) const;

std::shared_ptr<KeyValueMetadata> Copy() const;

Expand Down