Skip to content

Commit

Permalink
apacheGH-15483: [C++] Add a Fixed Shape Tensor canonical ExtensionType (
Browse files Browse the repository at this point in the history
apache#8510)

> [ARROW-1614](https://issues.apache.org/jira/browse/ARROW-1614): In an Arrow table, we would like to add support for a column whose cells each contain a tensor value, with all tensors having the same dimensions. These would be stored as a binary value, plus some metadata to store type and shape/strides.
* Closes: apache#15483

Lead-authored-by: Rok Mihevc <rok@mihevc.org>
Co-authored-by: Rok <rok@mihevc.org>
Co-authored-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
Co-authored-by: Ben Harkins <60872452+benibus@users.noreply.github.com>
Signed-off-by: Joris Van den Bossche <jorisvandenbossche@gmail.com>
  • Loading branch information
3 people authored and ArgusLi committed May 15, 2023
1 parent f3d84f5 commit bfb2a05
Show file tree
Hide file tree
Showing 6 changed files with 515 additions and 0 deletions.
2 changes: 2 additions & 0 deletions cpp/src/arrow/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -520,6 +520,7 @@ endif()
if(ARROW_JSON)
list(APPEND
ARROW_SRCS
extension/fixed_shape_tensor.cc
json/options.cc
json/chunked_builder.cc
json/chunker.cc
Expand Down Expand Up @@ -856,6 +857,7 @@ endif()

if(ARROW_JSON)
add_subdirectory(json)
add_subdirectory(extension)
endif()

if(ARROW_ORC)
Expand Down
24 changes: 24 additions & 0 deletions cpp/src/arrow/extension/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# Register the test built from fixed_shape_tensor_test.cc, passing
# "arrow-fixed-shape-tensor" as the PREFIX argument of add_arrow_test.
add_arrow_test(test
SOURCES
fixed_shape_tensor_test.cc
PREFIX
"arrow-fixed-shape-tensor")

# NOTE(review): presumably installs this directory's public headers under
# arrow/extension -- confirm against the arrow_install_all_headers definition.
arrow_install_all_headers("arrow/extension")
170 changes: 170 additions & 0 deletions cpp/src/arrow/extension/fixed_shape_tensor.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include <limits>
#include <numeric>
#include <sstream>

#include "arrow/extension/fixed_shape_tensor.h"

#include "arrow/array/array_nested.h"
#include "arrow/array/array_primitive.h"
#include "arrow/json/rapidjson_defs.h" // IWYU pragma: keep
#include "arrow/util/int_util_overflow.h"
#include "arrow/util/logging.h"
#include "arrow/util/sort.h"

#include <rapidjson/document.h>
#include <rapidjson/writer.h>

namespace rj = arrow::rapidjson;

namespace arrow {
namespace extension {

bool FixedShapeTensorType::ExtensionEquals(const ExtensionType& other) const {
if (extension_name() != other.extension_name()) {
return false;
}
const auto& other_ext = static_cast<const FixedShapeTensorType&>(other);

auto is_permutation_trivial = [](const std::vector<int64_t>& permutation) {
for (size_t i = 1; i < permutation.size(); ++i) {
if (permutation[i - 1] + 1 != permutation[i]) {
return false;
}
}
return true;
};
const bool permutation_equivalent =
((permutation_ == other_ext.permutation()) ||
(permutation_.empty() && is_permutation_trivial(other_ext.permutation())) ||
(is_permutation_trivial(permutation_) && other_ext.permutation().empty()));

return (storage_type()->Equals(other_ext.storage_type())) &&
(this->shape() == other_ext.shape()) && (dim_names_ == other_ext.dim_names()) &&
permutation_equivalent;
}

std::string FixedShapeTensorType::Serialize() const {
rj::Document document;
document.SetObject();
rj::Document::AllocatorType& allocator = document.GetAllocator();

rj::Value shape(rj::kArrayType);
for (auto v : shape_) {
shape.PushBack(v, allocator);
}
document.AddMember(rj::Value("shape", allocator), shape, allocator);

if (!permutation_.empty()) {
rj::Value permutation(rj::kArrayType);
for (auto v : permutation_) {
permutation.PushBack(v, allocator);
}
document.AddMember(rj::Value("permutation", allocator), permutation, allocator);
}

if (!dim_names_.empty()) {
rj::Value dim_names(rj::kArrayType);
for (std::string v : dim_names_) {
dim_names.PushBack(rj::Value{}.SetString(v.c_str(), allocator), allocator);
}
document.AddMember(rj::Value("dim_names", allocator), dim_names, allocator);
}

rj::StringBuffer buffer;
rj::Writer<rj::StringBuffer> writer(buffer);
document.Accept(writer);
return buffer.GetString();
}

/// \brief Reconstruct a FixedShapeTensorType from its storage type and the
/// JSON metadata produced by Serialize().
///
/// The metadata is treated as untrusted input: the JSON type of every member
/// is checked before it is read, so malformed metadata yields an Invalid
/// status rather than tripping a RapidJSON assertion.
///
/// \param[in] storage_type storage type; must be a FixedSizeList
/// \param[in] serialized_data JSON object with a "shape" array and optional
///            "permutation" and "dim_names" arrays
/// \return the reconstructed type, or Status::Invalid on malformed input
Result<std::shared_ptr<DataType>> FixedShapeTensorType::Deserialize(
    std::shared_ptr<DataType> storage_type, const std::string& serialized_data) const {
  if (storage_type->id() != Type::FIXED_SIZE_LIST) {
    return Status::Invalid("Expected FixedSizeList storage type, got ",
                           storage_type->ToString());
  }
  auto value_type =
      internal::checked_pointer_cast<FixedSizeListType>(storage_type)->value_type();

  rj::Document document;
  // IsObject() must be checked before HasMember(): member lookup on a
  // non-object root (e.g. "[]" or "1") is not allowed by RapidJSON.
  if (document.Parse(serialized_data.data(), serialized_data.length()).HasParseError() ||
      !document.IsObject() || !document.HasMember("shape") ||
      !document["shape"].IsArray()) {
    return Status::Invalid("Invalid serialized JSON data: ", serialized_data);
  }

  std::vector<int64_t> shape;
  for (const auto& x : document["shape"].GetArray()) {
    if (!x.IsInt64()) {
      return Status::Invalid("Invalid serialized JSON data: ", serialized_data);
    }
    shape.emplace_back(x.GetInt64());
  }

  std::vector<int64_t> permutation;
  if (document.HasMember("permutation")) {
    const auto& json_permutation = document["permutation"];
    if (!json_permutation.IsArray()) {
      return Status::Invalid("Invalid permutation");
    }
    for (const auto& x : json_permutation.GetArray()) {
      if (!x.IsInt64()) {
        return Status::Invalid("Invalid permutation");
      }
      permutation.emplace_back(x.GetInt64());
    }
    if (shape.size() != permutation.size()) {
      return Status::Invalid("Invalid permutation");
    }
  }

  std::vector<std::string> dim_names;
  if (document.HasMember("dim_names")) {
    const auto& json_dim_names = document["dim_names"];
    if (!json_dim_names.IsArray()) {
      return Status::Invalid("Invalid dim_names");
    }
    for (const auto& x : json_dim_names.GetArray()) {
      if (!x.IsString()) {
        return Status::Invalid("Invalid dim_names");
      }
      // Use the explicit length so names containing NUL are not truncated.
      dim_names.emplace_back(x.GetString(), x.GetStringLength());
    }
    if (shape.size() != dim_names.size()) {
      return Status::Invalid("Invalid dim_names");
    }
  }

  // Return the Result from Make() directly so that any validation error is
  // propagated to the caller as a Status.
  return FixedShapeTensorType::Make(value_type, shape, permutation, dim_names);
}

std::shared_ptr<Array> FixedShapeTensorType::MakeArray(
std::shared_ptr<ArrayData> data) const {
DCHECK_EQ(data->type->id(), Type::EXTENSION);
DCHECK_EQ("arrow.fixed_shape_tensor",
static_cast<const ExtensionType&>(*data->type).extension_name());
return std::make_shared<ExtensionArray>(data);
}

/// \brief Validate the parameters and create a FixedShapeTensorType.
///
/// \param[in] value_type type of the individual tensor elements
/// \param[in] shape size of each dimension; values must be non-negative and
///            their product must fit in int32_t (the storage list size)
/// \param[in] permutation optional; if given, must have one entry per dimension
/// \param[in] dim_names optional; if given, must have one entry per dimension
/// \return the new type, or Status::Invalid if validation fails
Result<std::shared_ptr<DataType>> FixedShapeTensorType::Make(
    const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
    const std::vector<int64_t>& permutation, const std::vector<std::string>& dim_names) {
  if (!permutation.empty() && shape.size() != permutation.size()) {
    return Status::Invalid("permutation size must match shape size. Expected: ",
                           shape.size(), " Got: ", permutation.size());
  }
  if (!dim_names.empty() && shape.size() != dim_names.size()) {
    return Status::Invalid("dim_names size must match shape size. Expected: ",
                           shape.size(), " Got: ", dim_names.size());
  }

  bool has_zero_dim = false;
  for (const int64_t dim : shape) {
    if (dim < 0) {
      return Status::Invalid("shape must contain only non-negative values. Got: ", dim);
    }
    has_zero_dim = has_zero_dim || (dim == 0);
  }

  // The element count becomes the int32_t list size of the storage type, so
  // it is computed with an explicit bound check instead of multiplying in
  // int64_t and truncating. An empty shape gives a size of 1.
  constexpr int64_t kMaxSize = std::numeric_limits<int32_t>::max();
  int64_t size = has_zero_dim ? 0 : 1;
  if (!has_zero_dim) {
    for (const int64_t dim : shape) {
      // Every dim is >= 1 here, so size >= 1 and the division is safe;
      // dim > kMaxSize / size is equivalent to size * dim > kMaxSize.
      if (dim > kMaxSize / size) {
        return Status::Invalid("tensor size must not exceed ", kMaxSize,
                               " elements for the given shape");
      }
      size *= dim;
    }
  }

  return std::make_shared<FixedShapeTensorType>(value_type, static_cast<int32_t>(size),
                                                shape, permutation, dim_names);
}

/// \brief Convenience factory returning a FixedShapeTensorType instance.
///
/// The arguments are forwarded to FixedShapeTensorType::Make(). This function
/// cannot report a Status, so invalid arguments are treated as a programming
/// error; use FixedShapeTensorType::Make() to handle validation failures.
///
/// \param[in] value_type type of the individual tensor elements
/// \param[in] shape size of each tensor dimension
/// \param[in] permutation optional permutation, one entry per dimension
/// \param[in] dim_names optional dimension names, one entry per dimension
/// \return the new type
std::shared_ptr<DataType> fixed_shape_tensor(const std::shared_ptr<DataType>& value_type,
                                             const std::vector<int64_t>& shape,
                                             const std::vector<int64_t>& permutation,
                                             const std::vector<std::string>& dim_names) {
  // ValueOrDie() checks the status in every build type. A debug-only check
  // followed by MoveValueUnsafe() would read an unset value in release builds
  // whenever Make() fails.
  return FixedShapeTensorType::Make(value_type, shape, permutation, dim_names)
      .ValueOrDie();
}

} // namespace extension
} // namespace arrow
92 changes: 92 additions & 0 deletions cpp/src/arrow/extension/fixed_shape_tensor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "arrow/extension_type.h"

namespace arrow {
namespace extension {

/// \brief Array class for fixed-shape tensor extension data.
///
/// Adds no members of its own; it inherits all constructors of ExtensionArray.
class ARROW_EXPORT FixedShapeTensorArray : public ExtensionArray {
public:
// Re-export the ExtensionArray constructors.
using ExtensionArray::ExtensionArray;
};

/// \brief Concrete type class for constant-size Tensor data.
///
/// This is a canonical arrow extension type.
/// See: https://arrow.apache.org/docs/format/CanonicalExtensions.html
///
/// The storage type is a fixed-size list of `value_type` holding `size`
/// elements per tensor.
class ARROW_EXPORT FixedShapeTensorType : public ExtensionType {
 public:
  /// \brief Construct the type without validating the arguments.
  ///
  /// Prefer FixedShapeTensorType::Make(), which validates its arguments.
  ///
  /// \param[in] value_type type of the individual tensor elements
  /// \param[in] size list size of the fixed-size-list storage type
  /// \param[in] shape size of each tensor dimension
  /// \param[in] permutation optional permutation of the dimensions
  /// \param[in] dim_names optional names of the dimensions
  FixedShapeTensorType(const std::shared_ptr<DataType>& value_type, int32_t size,
                       const std::vector<int64_t>& shape,
                       const std::vector<int64_t>& permutation = {},
                       const std::vector<std::string>& dim_names = {})
      : ExtensionType(fixed_size_list(value_type, size)),
        value_type_(value_type),
        shape_(shape),
        permutation_(permutation),
        dim_names_(dim_names) {}

  std::string extension_name() const override { return "arrow.fixed_shape_tensor"; }

  /// Number of dimensions of tensor elements
  size_t ndim() const { return shape_.size(); }

  /// Shape of tensor elements
  const std::vector<int64_t>& shape() const { return shape_; }

  /// Value type of tensor elements
  const std::shared_ptr<DataType>& value_type() const { return value_type_; }

  /// Permutation mapping from logical to physical memory layout of tensor elements
  const std::vector<int64_t>& permutation() const { return permutation_; }

  /// Dimension names of tensor elements. Dimensions are ordered physically.
  const std::vector<std::string>& dim_names() const { return dim_names_; }

  /// Compare with another extension type (see the implementation for the
  /// exact equivalence rules)
  bool ExtensionEquals(const ExtensionType& other) const override;

  /// Serialize the type parameters to a string
  std::string Serialize() const override;

  /// Reconstruct a type from a storage type and serialized parameters
  Result<std::shared_ptr<DataType>> Deserialize(
      std::shared_ptr<DataType> storage_type,
      const std::string& serialized_data) const override;

  /// Wrap ArrayData of this extension type in an extension array
  std::shared_ptr<Array> MakeArray(std::shared_ptr<ArrayData> data) const override;

  /// \brief Create a FixedShapeTensorType instance
  static Result<std::shared_ptr<DataType>> Make(
      const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
      const std::vector<int64_t>& permutation = {},
      const std::vector<std::string>& dim_names = {});

 private:
  // The storage type is held by the ExtensionType base (it is passed to the
  // base constructor), so no separate member is kept here.
  std::shared_ptr<DataType> value_type_;
  std::vector<int64_t> shape_;
  std::vector<int64_t> permutation_;
  std::vector<std::string> dim_names_;
};

/// \brief Return a FixedShapeTensorType instance.
///
/// \param[in] value_type type of the individual tensor elements (not the
///            fixed-size-list storage type)
/// \param[in] shape size of each tensor dimension
/// \param[in] permutation optional permutation, one entry per dimension
/// \param[in] dim_names optional dimension names, one entry per dimension
ARROW_EXPORT std::shared_ptr<DataType> fixed_shape_tensor(
    const std::shared_ptr<DataType>& value_type, const std::vector<int64_t>& shape,
    const std::vector<int64_t>& permutation = {},
    const std::vector<std::string>& dim_names = {});

} // namespace extension
} // namespace arrow
Loading

0 comments on commit bfb2a05

Please sign in to comment.