diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 59ee4c1e737..26456a2fd47 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -1160,6 +1160,7 @@ cc_library( tf_gen_op_libs( is_external = False, op_lib_names = [ + "parquet_ops", "batch_ops", "bitwise_ops", "boosted_trees_ops", @@ -1415,6 +1416,7 @@ cc_library( visibility = ["//visibility:public"], deps = [ ":array_ops_op_lib", + ":parquet_ops_op_lib", ":audio_ops_op_lib", ":batch_ops_op_lib", ":bitwise_ops_op_lib", @@ -1619,6 +1621,7 @@ cc_library( "//tensorflow/core/kernels:function_ops", "//tensorflow/core/kernels:functional_ops", "//tensorflow/core/kernels:fused_embedding_ops", + "//tensorflow/core/kernels/data:parquet_dataset_ops", "//tensorflow/core/kernels:grappler", "//tensorflow/core/kernels:hash_ops", "//tensorflow/core/kernels:histogram_op", diff --git a/tensorflow/core/api_def/base_api/api_def_ParquetTabularDatasetV1.pbtxt b/tensorflow/core/api_def/base_api/api_def_ParquetTabularDatasetV1.pbtxt new file mode 100644 index 00000000000..acc51c40081 --- /dev/null +++ b/tensorflow/core/api_def/base_api/api_def_ParquetTabularDatasetV1.pbtxt @@ -0,0 +1,3 @@ +op { + graph_op_name: "ParquetTabularDatasetV1" +} diff --git a/tensorflow/core/kernels/data/BUILD b/tensorflow/core/kernels/data/BUILD index 18d9ab0be83..3ba229c47e5 100644 --- a/tensorflow/core/kernels/data/BUILD +++ b/tensorflow/core/kernels/data/BUILD @@ -5,6 +5,7 @@ load( "//tensorflow:tensorflow.bzl", "tf_cc_test", "tf_kernel_library", + "pybind_extension", ) package( @@ -1263,6 +1264,7 @@ tf_kernel_library( tf_kernel_library( name = "data", deps = [ + ":parquet_dataset_ops", ":batch_dataset_op", ":cache_dataset_ops", ":concatenate_dataset_op", @@ -1365,3 +1367,47 @@ tf_kernel_library( "//tensorflow/core:lib_internal", ], ) + +cc_library( + name = "parquet_dataset_ops", + srcs = [ + "parquet_dataset_ops.cc", + "parquet_batch_reader.h", + "parquet_batch_reader.cc", + ], + hdrs = ["parquet_dataset_ops.h"], + deps = [ + ":arrow_util", + ":dataset_ops", + "//tensorflow/core:framework", + ], +) + +pybind_extension( + name = "_parquet_pybind", + srcs = ["parquet_pybind.cc"], + copts = ["-fexceptions"], + features = ["-use_header_modules"], + module_name = "_parquet_pybind", + deps = [ + ":arrow_util", + "@pybind11", + ], +) + +cc_library( + name = "arrow_util", + srcs = ["arrow_util.cc", + "eigen.h"], + hdrs = ["arrow_util.h"], + deps = [ + "@arrow", + "//third_party/eigen3", + "//tensorflow/core:framework", + ], + defines = [ + "DEEPREC_ARROW_HDFS", + "DEEPREC_ARROW_S3", + "DEEPREC_ARROW_ZEROCOPY", + ] +) diff --git a/tensorflow/core/kernels/data/arrow_util.cc b/tensorflow/core/kernels/data/arrow_util.cc new file mode 100644 index 00000000000..6c5b414ea0e --- /dev/null +++ b/tensorflow/core/kernels/data/arrow_util.cc @@ -0,0 +1,389 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
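A small sketch of the two environment knobs that the arrow_util.cc helpers below read (ARROW_NUM_THREADS and ARROW_FILE_BUFFER_SIZE); the values shown are illustrative only and would be set before the input pipeline is built:

# Illustrative only: configure Arrow's CPU thread pool and buffered-stream size
# before constructing the dataset. Semantics follow arrow_util.cc below:
# ARROW_NUM_THREADS > 0 uses that many threads, 0 disables threading, < 0 uses all.
import os
os.environ['ARROW_NUM_THREADS'] = '4'
os.environ['ARROW_FILE_BUFFER_SIZE'] = str(4096 * 4)  # bytes; same as the built-in default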
+==============================================================================*/ +#include "tensorflow/core/kernels/data/arrow_util.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "arrow/array.h" +#include "arrow/util/thread_pool.h" +#include "tensorflow/core/kernels/data/eigen.h" +#include "tensorflow/core/framework/allocation_description.pb.h" + +namespace tensorflow { +namespace data { +namespace ArrowUtil { + +namespace { + +int EnvGetInt(const std::string& env_var, int default_val) { + const char* env_var_val = getenv(env_var.c_str()); + if (env_var_val == nullptr) { + return default_val; + } + std::string env_var_val_str(env_var_val); + std::istringstream ss(env_var_val_str); + int result; + if (!(ss >> result)) { + result = default_val; + } + return result; +} + +int SetArrowCpuThreadPoolCapacityFromEnv() { + int arrow_threads = EnvGetInt("ARROW_NUM_THREADS", 0); + if (arrow_threads > 0) { // Set from environment variable + auto s = ::arrow::SetCpuThreadPoolCapacity(arrow_threads); + if (ARROW_PREDICT_FALSE(!s.ok())) { + return 0; + } + } + return arrow_threads; +} + +::arrow::Status MakeNumpyDtypeAndRaggedRankFromArrowDataType( + std::string* numpy_dtype, int* ragged_rank, + const std::shared_ptr<::arrow::DataType>& arrow_dtype) { + if (arrow_dtype->id() == ::arrow::Type::LIST) { + ++(*ragged_rank); + return MakeNumpyDtypeAndRaggedRankFromArrowDataType( + numpy_dtype, ragged_rank, arrow_dtype->field(0)->type()); + } + + switch (arrow_dtype->id()) { + case ::arrow::Type::INT8: + case ::arrow::Type::UINT8: + case ::arrow::Type::INT32: + case ::arrow::Type::INT64: + case ::arrow::Type::UINT64: + *numpy_dtype = arrow_dtype->name(); + break; + case ::arrow::Type::HALF_FLOAT: + *numpy_dtype = "float16"; + break; + case ::arrow::Type::FLOAT: + *numpy_dtype = "float32"; + break; + case ::arrow::Type::DOUBLE: + *numpy_dtype = "float64"; + break; + case ::arrow::Type::STRING: + *numpy_dtype = "O"; + break; + default: + return ::arrow::Status::Invalid( + "Arrow data type ", arrow_dtype->ToString(), " not supported."); + } + return ::arrow::Status::OK(); +} + +#if DEEPREC_ARROW_ZEROCOPY +class ArrowPrimitiveTensorBuffer : public TensorBuffer { + public: + ArrowPrimitiveTensorBuffer() = delete; + + explicit ArrowPrimitiveTensorBuffer( + const std::shared_ptr& arrow_buffer) + : TensorBuffer(const_cast(arrow_buffer->data())), + arrow_buffer_(arrow_buffer) {} + + size_t size() const override { return arrow_buffer_->size(); } + + TensorBuffer* root_buffer() override { return this; } + + void FillAllocationDescription(AllocationDescription* proto) const override { + proto->set_requested_bytes(size()); + proto->set_allocator_name(::tensorflow::cpu_allocator()->Name()); + } + + bool OwnsMemory() const override { return false; } + + private: + std::shared_ptr<::arrow::Buffer> arrow_buffer_; +}; +#endif + +::arrow::Status MakeTensorFromArrowBuffer( + DataType dtype, const std::shared_ptr<::arrow::Buffer>& arrow_buffer, + Tensor* tensor) { + const TensorShape shape = {arrow_buffer->size() / DataTypeSize(dtype)}; + +#if DEEPREC_ARROW_ZEROCOPY + // NOTE: Alignment is 64 in Arrow 4.x, same to EIGEN_MAX_ALIGN_BYTES. 
See: + // https://github.com/apache/arrow/blob/apache-arrow-4.0.1/cpp/src/arrow/memory_pool.cc#L97 + if (TF_PREDICT_FALSE(!CHECK_EIGEN_ALIGN(arrow_buffer->data()))) { + *tensor = Tensor(dtype, shape); + std::memcpy(const_cast(tensor->tensor_data().data()), + arrow_buffer->data(), arrow_buffer->size()); + return ::arrow::Status::OK(); + } + + ArrowPrimitiveTensorBuffer* tensor_buffer = + new ArrowPrimitiveTensorBuffer(arrow_buffer); + core::ScopedUnref unref(tensor_buffer); + *tensor = Tensor(dtype, shape, tensor_buffer); + return ::arrow::Status::OK(); +#else + *tensor = Tensor(dtype, shape); + std::memcpy(const_cast(tensor->tensor_data().data()), + arrow_buffer->data(), arrow_buffer->size()); + return ::arrow::Status::OK(); +#endif +} + +::arrow::Status MakeStringTensorFromArrowArray( + const ::arrow::StringArray& array, Tensor* tensor) { + if (array.null_count() != 0) { + return ::arrow::Status::Invalid("Null elements not supported"); + } + + const auto num_strings = array.length(); + + *tensor = Tensor(DT_STRING, TensorShape({num_strings})); + auto tensor_vec = tensor->vec(); + + for (auto i = 0; i < num_strings; ++i) { + int string_size; + auto string_data = array.GetValue(i, &string_size); + tensor_vec(i).assign(reinterpret_cast(string_data), + string_size); + } + return ::arrow::Status::OK(); +} + +// Primitive Arrow arrays have validity and value buffers. +#define RAGGED_TENSOR_BUILDER_PRIMITIVE_VISIT(ARRAY_CLASS) \ + ::arrow::Status Visit(const ARRAY_CLASS& array) override { \ + if (TF_PREDICT_FALSE(ragged_rank_ != 0)) { \ + return ::arrow::Status::Invalid("Inconsistent ragged rank"); \ + } \ + Tensor tensor; \ + auto st = \ + MakeTensorFromArrowBuffer(dtype_, array.data()->buffers[1], &tensor); \ + if (!st.ok()) { \ + return st; \ + } \ + ragged_tensor_.push_front(std::move(tensor)); \ + return ::arrow::Status::OK(); \ + } + +#define RAGGED_TENSOR_BUILDER_STRING_VISIT(ARRAY_CLASS) \ + ::arrow::Status Visit(const ARRAY_CLASS& array) override { \ + if (TF_PREDICT_FALSE(ragged_rank_ != 0)) { \ + return ::arrow::Status::Invalid("Inconsistent ragged rank"); \ + } \ + Tensor tensor; \ + auto st = MakeStringTensorFromArrowArray(array, &tensor); \ + if (!st.ok()) { \ + return st; \ + } \ + ragged_tensor_.push_front(std::move(tensor)); \ + return ::arrow::Status::OK(); \ + } + +class RaggedTensorBuilder : public ::arrow::ArrayVisitor { + public: + RaggedTensorBuilder(DataType dtype, int32 ragged_rank) + : dtype_(dtype), ragged_rank_(ragged_rank) {} + + ::arrow::Status Build(const std::shared_ptr<::arrow::Array>& array, + std::vector* output_tensors) { + auto st = array->Accept(this); + if (!st.ok()) { + return st; + } + output_tensors->insert(output_tensors->end(), ragged_tensor_.begin(), + ragged_tensor_.end()); + return ::arrow::Status::OK(); + } + + ::arrow::Status Visit(const ::arrow::ListArray& array) override { + --ragged_rank_; + Tensor tensor; + auto st = + MakeTensorFromArrowBuffer(DT_INT32, array.value_offsets(), &tensor); + if (!st.ok()) { + return st; + } + ragged_tensor_.push_front(std::move(tensor)); + return array.values()->Accept(this); + } + + RAGGED_TENSOR_BUILDER_PRIMITIVE_VISIT(::arrow::Int8Array); + RAGGED_TENSOR_BUILDER_PRIMITIVE_VISIT(::arrow::UInt8Array); + RAGGED_TENSOR_BUILDER_PRIMITIVE_VISIT(::arrow::Int32Array); + RAGGED_TENSOR_BUILDER_PRIMITIVE_VISIT(::arrow::UInt32Array); + RAGGED_TENSOR_BUILDER_PRIMITIVE_VISIT(::arrow::Int64Array); + RAGGED_TENSOR_BUILDER_PRIMITIVE_VISIT(::arrow::UInt64Array); + RAGGED_TENSOR_BUILDER_PRIMITIVE_VISIT(::arrow::HalfFloatArray); 
+ RAGGED_TENSOR_BUILDER_PRIMITIVE_VISIT(::arrow::FloatArray); + RAGGED_TENSOR_BUILDER_PRIMITIVE_VISIT(::arrow::DoubleArray); + + RAGGED_TENSOR_BUILDER_STRING_VISIT(::arrow::StringArray); + + private: + const DataType dtype_; + int32 ragged_rank_; + std::deque ragged_tensor_; +}; + +} // namespace + +#define CASE_ARROW_ENUM_SET_DTYPE(PTR, ENUM) \ + case ENUM: { \ + *PTR = DataTypeToEnum::Type>::value; \ + return Status::OK(); \ + } + +Status MakeDataTypeAndRaggedRankFromArrowDataType( + const std::shared_ptr<::arrow::DataType>& arrow_dtype, DataType* dtype, + int32* ragged_rank) { + if (arrow_dtype->id() == ::arrow::Type::LIST) { + ++(*ragged_rank); + return MakeDataTypeAndRaggedRankFromArrowDataType( + arrow_dtype->field(0)->type(), dtype, ragged_rank); + } + + switch (arrow_dtype->id()) { + CASE_ARROW_ENUM_SET_DTYPE(dtype, ::arrow::Type::INT8); + CASE_ARROW_ENUM_SET_DTYPE(dtype, ::arrow::Type::UINT8); + CASE_ARROW_ENUM_SET_DTYPE(dtype, ::arrow::Type::INT32); + CASE_ARROW_ENUM_SET_DTYPE(dtype, ::arrow::Type::UINT32); + CASE_ARROW_ENUM_SET_DTYPE(dtype, ::arrow::Type::INT64); + CASE_ARROW_ENUM_SET_DTYPE(dtype, ::arrow::Type::UINT64); + CASE_ARROW_ENUM_SET_DTYPE(dtype, ::arrow::Type::HALF_FLOAT); + CASE_ARROW_ENUM_SET_DTYPE(dtype, ::arrow::Type::FLOAT); + CASE_ARROW_ENUM_SET_DTYPE(dtype, ::arrow::Type::DOUBLE); + CASE_ARROW_ENUM_SET_DTYPE(dtype, ::arrow::Type::STRING); + default: + return errors::Unimplemented("Arrow data type ", arrow_dtype->ToString(), + " not supported."); + } + return Status::OK(); +} + +Status MakeTensorsFromArrowArray( + DataType dtype, int32 ragged_rank, + const std::shared_ptr<::arrow::Array>& arrow_array, + std::vector* output_tensors) { + if (TF_PREDICT_FALSE(arrow_array->null_count() != 0)) { + return errors::Internal("Arrow array with null values not supported"); + } + + if (TF_PREDICT_FALSE(arrow_array->data()->offset != 0)) { + return errors::Internal("Arrow array has zero non-offset not supported"); + } + + RaggedTensorBuilder builder(dtype, ragged_rank); + TF_RETURN_IF_ARROW_ERROR(builder.Build(arrow_array, output_tensors)); + return Status::OK(); +} + +int UpdateArrowCpuThreadPoolCapacityFromEnv() { + static int arrow_threads = SetArrowCpuThreadPoolCapacityFromEnv(); + return arrow_threads; +} + +int GetArrowFileBufferSizeFromEnv() { + static int buffer_size = EnvGetInt("ARROW_FILE_BUFFER_SIZE", 4096 * 4); + return buffer_size; +} + +::arrow::Status OpenArrowFile( + std::shared_ptr<::arrow::io::RandomAccessFile>* file, + const std::string& filename) { +#if DEEPREC_ARROW_HDFS + if (filename.rfind("hdfs://", 0) == 0) { + ::arrow::internal::Uri uri; + ARROW_RETURN_NOT_OK(uri.Parse(filename)); + ARROW_ASSIGN_OR_RAISE(auto options, ::arrow::fs::HdfsOptions::FromUri(uri)); + std::shared_ptr<::arrow::io::HadoopFileSystem> fs; + ARROW_RETURN_NOT_OK(::arrow::io::HadoopFileSystem::Connect( + &options.connection_config, &fs)); + std::shared_ptr<::arrow::io::HdfsReadableFile> hdfs_file; + ARROW_RETURN_NOT_OK(fs->OpenReadable(uri.path(), &hdfs_file)); + *file = hdfs_file; + return ::arrow::Status::OK(); + } +#endif +#if DEEPREC_ARROW_S3 + if (filename.rfind("s3://", 0) == 0 || filename.rfind("oss://", 0) == 0) { + ARROW_RETURN_NOT_OK(::arrow::fs::EnsureS3Initialized()); + ::arrow::internal::Uri uri; + ARROW_RETURN_NOT_OK(uri.Parse(filename)); + std::string path; + ARROW_ASSIGN_OR_RAISE(auto options, + ::arrow::fs::S3Options::FromUri(uri, &path)); + ARROW_ASSIGN_OR_RAISE(auto fs, ::arrow::fs::S3FileSystem::Make(options)); + ARROW_ASSIGN_OR_RAISE(*file, 
fs->OpenInputFile(path)); + return ::arrow::Status::OK(); + } +#endif + auto fs = std::make_shared<::arrow::fs::LocalFileSystem>(); + ARROW_ASSIGN_OR_RAISE(*file, fs->OpenInputFile(filename)); + return ::arrow::Status::OK(); +} + +::arrow::Status OpenParquetReader( + std::unique_ptr<::parquet::arrow::FileReader>* reader, + const std::shared_ptr<::arrow::io::RandomAccessFile>& file) { + auto config = ::parquet::ReaderProperties(); + config.enable_buffered_stream(); + config.set_buffer_size(GetArrowFileBufferSizeFromEnv()); + ARROW_RETURN_NOT_OK(::parquet::arrow::FileReader::Make( + ::arrow::default_memory_pool(), + ::parquet::ParquetFileReader::Open(file, config), reader)); + // If ARROW_NUM_THREADS > 0, specified number of threads will be used. + // If ARROW_NUM_THREADS = 0, no threads will be used. + // If ARROW_NUM_THREADS < 0, all threads will be used. + (*reader)->set_use_threads(UpdateArrowCpuThreadPoolCapacityFromEnv() != 0); + return ::arrow::Status::OK(); +} + +::arrow::Status GetParquetDataFrameFields( + std::vector* field_names, + std::vector* field_dtypes, + std::vector* field_ragged_ranks, const std::string& filename) { + std::shared_ptr<::arrow::io::RandomAccessFile> file; + ARROW_RETURN_NOT_OK(OpenArrowFile(&file, filename)); + std::unique_ptr<::parquet::arrow::FileReader> reader; + ARROW_RETURN_NOT_OK(OpenParquetReader(&reader, file)); + + std::shared_ptr<::arrow::Schema> schema; + ARROW_RETURN_NOT_OK(reader->GetSchema(&schema)); + if (ARROW_PREDICT_FALSE(!schema->HasDistinctFieldNames())) { + return ::arrow::Status::Invalid(filename, + " must has distinct column names"); + } + for (const auto& field : schema->fields()) { + field_names->push_back(field->name()); + std::string dtype; + int ragged_rank = 0; + ARROW_RETURN_NOT_OK(MakeNumpyDtypeAndRaggedRankFromArrowDataType( + &dtype, &ragged_rank, field->type())); + field_dtypes->push_back(dtype); + field_ragged_ranks->push_back(ragged_rank); + } + return ::arrow::Status::OK(); +} + +} // namespace ArrowUtil +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/arrow_util.h b/tensorflow/core/kernels/data/arrow_util.h new file mode 100644 index 00000000000..87d527f8934 --- /dev/null +++ b/tensorflow/core/kernels/data/arrow_util.h @@ -0,0 +1,116 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_ARROW_UTIL_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_ARROW_UTIL_H_ + +#include +#include + +#include "arrow/dataset/api.h" +#include "arrow/record_batch.h" +#include "parquet/arrow/reader.h" +#include "parquet/properties.h" +#include "arrow/filesystem/localfs.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" + +#if DEEPREC_ARROW_HDFS +#include +#endif +#if DEEPREC_ARROW_S3 +#include +#endif + +#define TF_RETURN_IF_ARROW_ERROR(...) 
\ + do { \ + const ::arrow::Status _status = (__VA_ARGS__); \ + if (TF_PREDICT_FALSE(!_status.ok())) \ + return errors::Internal(_status.ToString()); \ + } while (0) + +#define TF_CHECKED_ARROW_ASSIGN(lhs, rexpr) \ + do { \ + auto&& _result = (rexpr); \ + if (TF_PREDICT_FALSE(!_result.ok())) \ + return errors::Internal(_result.status().ToString()); \ + lhs = std::move(_result).ValueUnsafe(); \ + } while (0) + +namespace tensorflow { +namespace data { +namespace ArrowUtil { + +int UpdateArrowCpuThreadPoolCapacityFromEnv(); + +int GetArrowFileBufferSizeFromEnv(); + +::arrow::Status OpenArrowFile( + std::shared_ptr<::arrow::io::RandomAccessFile>* file, + const std::string& filename); + +::arrow::Status OpenParquetReader( + std::unique_ptr<::parquet::arrow::FileReader>* reader, + const std::shared_ptr<::arrow::io::RandomAccessFile>& file); + +::arrow::Status GetParquetDataFrameFields( + std::vector* field_names, + std::vector* field_dtypes, + std::vector* field_ragged_ranks, const std::string& filename); + +template +struct DataTypeToArrowEnum { + static constexpr ::arrow::Type::type value = ::arrow::Type::NA; +}; + +template <::arrow::Type::type VALUE> +struct ArrowEnumToDataType { + typedef uint8 Type; +}; + +#define MATCH_TYPE_AND_ARROW_ENUM(TYPE, ENUM) \ + template <> \ + struct DataTypeToArrowEnum { \ + static constexpr ::arrow::Type::type value = ENUM; \ + }; \ + template <> \ + struct ArrowEnumToDataType { \ + typedef TYPE Type; \ + } + +MATCH_TYPE_AND_ARROW_ENUM(int8, ::arrow::Type::INT8); +MATCH_TYPE_AND_ARROW_ENUM(uint8, ::arrow::Type::UINT8); +MATCH_TYPE_AND_ARROW_ENUM(int32, ::arrow::Type::INT32); +MATCH_TYPE_AND_ARROW_ENUM(uint32, ::arrow::Type::UINT32); +MATCH_TYPE_AND_ARROW_ENUM(int64, ::arrow::Type::INT64); +MATCH_TYPE_AND_ARROW_ENUM(uint64, ::arrow::Type::UINT64); +MATCH_TYPE_AND_ARROW_ENUM(Eigen::half, ::arrow::Type::HALF_FLOAT); +MATCH_TYPE_AND_ARROW_ENUM(float, ::arrow::Type::FLOAT); +MATCH_TYPE_AND_ARROW_ENUM(double, ::arrow::Type::DOUBLE); +MATCH_TYPE_AND_ARROW_ENUM(string, ::arrow::Type::STRING); + +Status MakeDataTypeAndRaggedRankFromArrowDataType( + const std::shared_ptr<::arrow::DataType>& arrow_dtype, DataType* dtype, + int32* ragged_rank); + +Status MakeTensorsFromArrowArray( + DataType type, int32 ragged_rank, + const std::shared_ptr<::arrow::Array>& arrow_array, + std::vector* output_tensors); + +} // namespace ArrowUtil +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_ARROW_UTIL_H_ diff --git a/tensorflow/core/kernels/data/eigen.h b/tensorflow/core/kernels/data/eigen.h new file mode 100644 index 00000000000..f84dc9f3937 --- /dev/null +++ b/tensorflow/core/kernels/data/eigen.h @@ -0,0 +1,32 @@ +/* Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
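A minimal Python analogue of MakeDataTypeAndRaggedRankFromArrowDataType, assuming pyarrow is available: each LIST nesting level contributes one ragged dimension, and the leaf type determines the value dtype.

import pyarrow as pa

def dtype_and_ragged_rank(arrow_type, ragged_rank=0):
    # Peel LIST layers, counting one ragged dimension per layer,
    # then return the leaf type, mirroring the C++ helper above.
    if pa.types.is_list(arrow_type):
        return dtype_and_ragged_rank(arrow_type.value_type, ragged_rank + 1)
    return arrow_type, ragged_rank

leaf, rank = dtype_and_ragged_rank(pa.list_(pa.list_(pa.int64())))
print(leaf, rank)  # int64 2 -> DT_INT64 with ragged_rank = 2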
+==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_EIGEN_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_EIGEN_H_ + +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/lib/core/errors.h" +#include "tensorflow/core/public/version.h" +#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor" + +// NOTE: EIGEN_MAX_ALIGN_BYTES is 64 in TF 1.x. See: +// DeepRec/third_party/eigen.BUILD#L67 +#if EIGEN_MAX_ALIGN_BYTES == 0 +#define CHECK_EIGEN_ALIGN(...) (true) +#else +#define CHECK_EIGEN_ALIGN(...) \ + (0 == reinterpret_cast(__VA_ARGS__) % EIGEN_MAX_ALIGN_BYTES) +#endif + +#endif // TENSORFLOW_CORE_KERNELS_DATA_EIGEN_H_ diff --git a/tensorflow/core/kernels/data/parquet_batch_reader.cc b/tensorflow/core/kernels/data/parquet_batch_reader.cc new file mode 100644 index 00000000000..f0a56094415 --- /dev/null +++ b/tensorflow/core/kernels/data/parquet_batch_reader.cc @@ -0,0 +1,162 @@ +/* Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/core/kernels/data/parquet_batch_reader.h" + +#include +#include +#include +#include + +#include "absl/strings/match.h" +#include "tensorflow/core/kernels/data/arrow_util.h" + +namespace tensorflow { +namespace data { + +class ParquetBatchReader::Impl { + public: + Impl(const string& filename, const int64 batch_size, + const std::vector& field_names, + const DataTypeVector& field_dtypes, + const std::vector& field_ragged_ranks, + const int64 partition_count, const int64 partition_index, + const bool drop_remainder) + : filename_(filename), + batch_size_(batch_size), + field_names_(field_names), + field_dtypes_(field_dtypes), + field_ragged_ranks_(field_ragged_ranks), + partition_count_(partition_count), + partition_index_(partition_index), + drop_remainder_(drop_remainder) {} + + Status Open() { + if (TF_PREDICT_TRUE(batch_reader_)) { + return Status::OK(); + } + if (TF_PREDICT_FALSE(partition_index_ >= partition_count_)) { + return errors::InvalidArgument("Partition index ", partition_index_, + " must be smaller than partition count ", + partition_count_); + } + if (TF_PREDICT_FALSE(partition_index_ < 0)) { + return errors::InvalidArgument("Partition index ", partition_index_, + "must be greater than 0"); + } + + std::shared_ptr<::arrow::io::RandomAccessFile> file; + TF_RETURN_IF_ARROW_ERROR(ArrowUtil::OpenArrowFile(&file, filename_)); + TF_RETURN_IF_ARROW_ERROR(ArrowUtil::OpenParquetReader(&reader_, file)); + + int num_row_groups = reader_->num_row_groups(); + for (int g = partition_index_; g < num_row_groups; g += partition_count_) { + row_group_indices_.push_back(g); + } + std::shared_ptr<::arrow::Schema> schema; + TF_RETURN_IF_ARROW_ERROR(reader_->GetSchema(&schema)); + if (TF_PREDICT_FALSE(!schema->HasDistinctFieldNames())) { + return errors::InvalidArgument(filename_, + " must has distinct column names"); + } + for (size_t i = 0; i < 
field_names_.size(); ++i) { + auto& cname = field_names_[i]; + int column_index = schema->GetFieldIndex(cname); + if (TF_PREDICT_FALSE(column_index < 0)) { + return errors::NotFound("No column called `", cname, "` found in ", + filename_); + } + column_indices_.push_back(column_index); + const auto& expected_dtype = field_dtypes_[i]; + const auto& expected_ragged_rank = field_ragged_ranks_[i]; + DataType actual_dtype; + int32 actual_ragged_rank = 0; + TF_RETURN_IF_ERROR(ArrowUtil::MakeDataTypeAndRaggedRankFromArrowDataType( + schema->field(column_index)->type(), &actual_dtype, + &actual_ragged_rank)); + if (TF_PREDICT_FALSE(actual_dtype != expected_dtype)) { + return errors::InvalidArgument( + "Field ", cname, " in ", filename_, " has unexpected data type ", + DataTypeString(actual_dtype), ", which should be ", + DataTypeString(expected_dtype)); + } + if (TF_PREDICT_FALSE(actual_ragged_rank != expected_ragged_rank)) { + return errors::InvalidArgument( + "Field ", cname, " in ", filename_, " has unexpected ragged rank ", + actual_ragged_rank, ", which should be ", expected_ragged_rank); + } + } + reader_->set_batch_size(batch_size_); + + TF_RETURN_IF_ARROW_ERROR(reader_->GetRecordBatchReader( + row_group_indices_, column_indices_, &batch_reader_)); + return Status::OK(); + } + + Status Read(std::vector* output_tensors) { + // Read next batch from parquet file. + std::shared_ptr<::arrow::RecordBatch> batch; + TF_RETURN_IF_ARROW_ERROR(batch_reader_->ReadNext(&batch)); + if (TF_PREDICT_FALSE(!batch)) { + return errors::OutOfRange("Reached end of parquet file ", filename_); + } + if (TF_PREDICT_FALSE(drop_remainder_ && batch->num_rows() < batch_size_)) { + return errors::OutOfRange("Reached end of parquet file ", filename_, + " after dropping reminder batch"); + } + + // Populate tensors from record batch. + auto arrays = batch->columns(); + for (size_t i = 0; i < arrays.size(); ++i) { + TF_RETURN_IF_ERROR(ArrowUtil::MakeTensorsFromArrowArray( + field_dtypes_[i], field_ragged_ranks_[i], arrays[i], output_tensors)); + } + + return Status::OK(); + } + + private: + const string filename_; + const int64 batch_size_; + std::vector field_names_; + DataTypeVector field_dtypes_; + std::vector field_ragged_ranks_; + int64 partition_count_; + int64 partition_index_; + bool drop_remainder_; + std::unique_ptr<::parquet::arrow::FileReader> reader_; + std::unique_ptr<::arrow::RecordBatchReader> batch_reader_; + std::vector row_group_indices_; + std::vector column_indices_; +}; + +ParquetBatchReader::ParquetBatchReader( + const string& filename, const int64 batch_size, + const std::vector& field_names, const DataTypeVector& field_dtypes, + const std::vector& field_ragged_ranks, const int64 partition_count, + const int64 partition_index, const bool drop_remainder) + : pimpl_(new ParquetBatchReader::Impl( + filename, batch_size, field_names, field_dtypes, field_ragged_ranks, + partition_count, partition_index, drop_remainder)) {} + +Status ParquetBatchReader::Open() { return pimpl_->Open(); } + +Status ParquetBatchReader::Read(std::vector* output_tensors) { + return pimpl_->Read(output_tensors); +} + +ParquetBatchReader::~ParquetBatchReader() {} + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parquet_batch_reader.h b/tensorflow/core/kernels/data/parquet_batch_reader.h new file mode 100644 index 00000000000..15523d3faa5 --- /dev/null +++ b/tensorflow/core/kernels/data/parquet_batch_reader.h @@ -0,0 +1,50 @@ +/* Copyright 2021 Alibaba Group Holding Limited. 
All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_PARQUET_BATCH_READER_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_PARQUET_BATCH_READER_H_ + +#include +#include + +#include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/types.h" + +namespace tensorflow { +namespace data { + +class ParquetBatchReader { + public: + ParquetBatchReader(const string& filename, const int64 batch_size, + const std::vector& field_names, + const DataTypeVector& field_dtypes, + const std::vector& field_ragged_ranks, + const int64 partition_count, const int64 partition_index, + const bool drop_remainder); + + Status Open(); + + Status Read(std::vector* output_tensors); + + virtual ~ParquetBatchReader(); + + private: + class Impl; + std::unique_ptr pimpl_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_PARQUET_BATCH_READER_H_ diff --git a/tensorflow/core/kernels/data/parquet_dataset_ops.cc b/tensorflow/core/kernels/data/parquet_dataset_ops.cc new file mode 100644 index 00000000000..7b6c977495a --- /dev/null +++ b/tensorflow/core/kernels/data/parquet_dataset_ops.cc @@ -0,0 +1,217 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
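A sketch of how ParquetBatchReader shards work across readers: row groups are assigned round-robin from (partition_index, partition_count), mirroring the loop in ParquetBatchReader::Impl::Open() above.

def assigned_row_groups(num_row_groups, partition_count, partition_index):
    # Row group g is read by the partition with index g % partition_count.
    return list(range(partition_index, num_row_groups, partition_count))

print(assigned_row_groups(10, 4, 1))  # [1, 5, 9]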
+==============================================================================*/ +#include "tensorflow/core/kernels/data/parquet_dataset_ops.h" + +#include + +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/core/framework/partial_tensor_shape.h" +#include "tensorflow/core/framework/shape_inference.h" +#include "tensorflow/core/framework/tensor.h" +#include "tensorflow/core/platform/file_system.h" +#include "tensorflow/core/lib/io/buffered_inputstream.h" +#include "tensorflow/core/lib/io/inputbuffer.h" + +namespace tensorflow { +namespace data { + +#define PARSE_SCALAR tensorflow::data::ParseScalarArgument + +class ParquetTabularDatasetOp::Dataset : public DatasetBase { + public: + Dataset(OpKernelContext* ctx, const string& filename, const int64 batch_size, + const std::vector& field_names, + const DataTypeVector& field_dtypes, + const std::vector& field_ragged_ranks, + const int64 partition_count, const int64 partition_index, + const bool drop_remainder) + : DatasetBase(DatasetContext(ctx)), + filename_(std::move(filename)), + batch_size_(batch_size), + field_names_(std::move(field_names)), + field_dtypes_(std::move(field_dtypes)), + field_ragged_ranks_(std::move(field_ragged_ranks)), + partition_count_(partition_count), + partition_index_(partition_index), + drop_remainder_(drop_remainder) { + int64 num_outputs = field_names.size(); + for (int64 i = 0; i < field_names.size(); ++i) { + output_dtypes_.push_back(std::move(field_dtypes[i])); + for (int64 j = 0; j < field_ragged_ranks_[i]; ++j) { + output_dtypes_.push_back(DT_INT32); + } + num_outputs += field_ragged_ranks_[i]; + } + int64 actual_batch_size(drop_remainder ? batch_size : -1); + for (size_t i = 0; i < num_outputs; ++i) { + output_shapes_.push_back(PartialTensorShape({actual_batch_size})); + } + + reader_ = absl::make_unique( + filename_, batch_size_, field_names_, field_dtypes_, + field_ragged_ranks_, partition_count_, partition_index_, + drop_remainder_); + } + + Status Open() { + VLOG(1) << "Starting to read " << filename_ << " ..."; + return reader_->Open(); + } + + std::unique_ptr MakeIteratorInternal( + const string& prefix) const override; + + const DataTypeVector& output_dtypes() const override { + return output_dtypes_; + } + + const std::vector& output_shapes() const override { + return output_shapes_; + } + + string DebugString() const override { + return "ParquetTabularDatasetOp::Dataset"; + } + + protected: + Status AsGraphDefInternal(SerializationContext* ctx, + DatasetGraphDefBuilder* b, + Node** output) const override { + Node* filename = nullptr; + TF_RETURN_IF_ERROR(b->AddScalar(filename_, &filename)); + Node* batch_size; + TF_RETURN_IF_ERROR(b->AddScalar(batch_size_, &batch_size)); + AttrValue field_names; + b->BuildAttrValue(field_names_, &field_names); + AttrValue field_dtypes; + b->BuildAttrValue(field_dtypes_, &field_dtypes); + AttrValue field_ragged_ranks; + b->BuildAttrValue(field_ragged_ranks_, &field_ragged_ranks); + AttrValue partition_count; + b->BuildAttrValue(partition_count_, &partition_count); + AttrValue partition_index; + b->BuildAttrValue(partition_index_, &partition_index); + AttrValue drop_remainder; + b->BuildAttrValue(drop_remainder_, &drop_remainder); + TF_RETURN_IF_ERROR( + b->AddDataset(this, {{0, filename}, {1, batch_size}}, {}, + {{"field_names", field_names}, + {"field_dtypes", field_dtypes}, + {"field_ragged_ranks", field_ragged_ranks}, + 
{"partition_count", partition_count}, + {"partition_index", partition_index}, + {"drop_remainder", drop_remainder}}, + output)); + return Status::OK(); + } + + private: + class Iterator; + const string filename_; + const int64 batch_size_; + const std::vector field_names_; + const DataTypeVector field_dtypes_; + const std::vector field_ragged_ranks_; + const int64 partition_count_; + const int64 partition_index_; + const bool drop_remainder_; + DataTypeVector output_dtypes_; + std::vector output_shapes_; + std::unique_ptr reader_; +}; + +class ParquetTabularDatasetOp::Dataset::Iterator + : public DatasetIterator { + public: + explicit Iterator(const Params& params) + : DatasetIterator(params) {} + + Status GetNextInternal(IteratorContext* ctx, std::vector* out_tensors, + bool* end_of_sequence) override { + mutex_lock l(mu_); + Status s = dataset()->reader_->Read(out_tensors); + + if (s.ok()) { + *end_of_sequence = false; + return s; + } + if (!errors::IsOutOfRange(s)) { + return s; + } + *end_of_sequence = true; + return Status::OK(); + } + + protected: + Status SaveInternal(IteratorStateWriter* writer) override { + return errors::Unimplemented("SaveInternal is currently not supported"); + } + + Status RestoreInternal(IteratorContext* ctx, + IteratorStateReader* reader) override { + return errors::Unimplemented("RestoreInternal is currently not supported"); + } + + private: + mutex mu_; +}; + +std::unique_ptr +ParquetTabularDatasetOp::Dataset::MakeIteratorInternal( + const string& prefix) const { + return std::unique_ptr( + new Iterator({this, strings::StrCat(prefix, "::ParquetTabular")})); +} + +ParquetTabularDatasetOp::ParquetTabularDatasetOp(OpKernelConstruction* ctx) + : DatasetOpKernel(ctx), + partition_count_(1), + partition_index_(0), + drop_remainder_(false) { + OP_REQUIRES_OK(ctx, ctx->GetAttr("field_names", &field_names_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("field_dtypes", &field_dtypes_)); + OP_REQUIRES_OK(ctx, + ctx->GetAttr("field_ragged_ranks", &field_ragged_ranks_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("partition_count", &partition_count_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("partition_index", &partition_index_)); + OP_REQUIRES_OK(ctx, ctx->GetAttr("drop_remainder", &drop_remainder_)); +} + +void ParquetTabularDatasetOp::MakeDataset(OpKernelContext* ctx, + DatasetBase** output) { + string filename; + OP_REQUIRES_OK(ctx, PARSE_SCALAR(ctx, "filename", &filename)); + + int64 batch_size = 0; + OP_REQUIRES_OK(ctx, PARSE_SCALAR(ctx, "batch_size", &batch_size)); + OP_REQUIRES(ctx, batch_size > 0, + errors::InvalidArgument("batch_size must be greater than zero.")); + + Dataset* ds = new Dataset( + ctx, filename, batch_size, field_names_, field_dtypes_, + field_ragged_ranks_, partition_count_, partition_index_, drop_remainder_); + OP_REQUIRES_OK(ctx, ds->Open()); + *output = ds; +} + +REGISTER_KERNEL_BUILDER(Name("ParquetTabularDatasetV1").Device(DEVICE_CPU), + ParquetTabularDatasetOp); + +WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("ParquetTabularDatasetV1"); + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/kernels/data/parquet_dataset_ops.h b/tensorflow/core/kernels/data/parquet_dataset_ops.h new file mode 100644 index 00000000000..5ababf6c0f5 --- /dev/null +++ b/tensorflow/core/kernels/data/parquet_dataset_ops.h @@ -0,0 +1,43 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
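A sketch of the flattened output signature built in the Dataset constructor above: each field contributes one values tensor of its own dtype plus ragged_rank int32 row-splits tensors, all shaped [batch_size] (or [None] when drop_remainder is false).

import tensorflow as tf

def flat_output_dtypes(field_dtypes, field_ragged_ranks):
    # One values dtype per field, followed by int32 splits per ragged level.
    dtypes = []
    for dtype, ragged_rank in zip(field_dtypes, field_ragged_ranks):
        dtypes.append(dtype)
        dtypes.extend([tf.int32] * ragged_rank)
    return dtypes

print(flat_output_dtypes([tf.int64, tf.float32], [0, 2]))
# [tf.int64, tf.float32, tf.int32, tf.int32]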
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_CORE_KERNELS_DATA_PARQUET_DATASET_OPS_H_ +#define TENSORFLOW_CORE_KERNELS_DATA_PARQUET_DATASET_OPS_H_ + +#include "tensorflow/core/framework/dataset.h" +#include "tensorflow/core/kernels/data/parquet_batch_reader.h" + +namespace tensorflow { +namespace data { + +class ParquetTabularDatasetOp : public DatasetOpKernel { + public: + explicit ParquetTabularDatasetOp(OpKernelConstruction* ctx); + + void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override; + + private: + class Dataset; + std::vector field_names_; + DataTypeVector field_dtypes_; + std::vector field_ragged_ranks_; + int64 partition_count_; + int64 partition_index_; + bool drop_remainder_; +}; + +} // namespace data +} // namespace tensorflow + +#endif // TENSORFLOW_CORE_KERNELS_DATA_PARQUET_DATASET_OPS_H_ diff --git a/tensorflow/core/kernels/data/parquet_pybind.cc b/tensorflow/core/kernels/data/parquet_pybind.cc new file mode 100644 index 00000000000..26688d07677 --- /dev/null +++ b/tensorflow/core/kernels/data/parquet_pybind.cc @@ -0,0 +1,69 @@ +/* Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
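A sketch of querying a Parquet file's columns through the pybind module defined below; the import path is an assumption (the extension is built as _parquet_pybind under tensorflow/core/kernels/data), while the function name and the (field_name, numpy_dtype, ragged_rank) tuple layout come from parquet_pybind.cc.

# Hypothetical import path for the built extension module.
from tensorflow.core.kernels.data import _parquet_pybind

for name, numpy_dtype, ragged_rank in _parquet_pybind.parquet_file_get_fields(
        '/tmp/test.parquet'):  # illustrative path
    print(name, numpy_dtype, ragged_rank)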
+==============================================================================*/ +#include +#include +#include +#include + +#include "pybind11/complex.h" +#include "pybind11/pybind11.h" +#include "pybind11/stl.h" +#include "tensorflow/core/kernels/data/arrow_util.h" + +namespace tensorflow { +namespace data { + +namespace { + +std::string make_buildinfo() { + std::string message = "deeprec buildinfo"; + return message; +} + +std::string buildinfo() { + static std::string kBuildInfo = make_buildinfo(); + return kBuildInfo; +} + +typedef std::tuple parquet_file_field_t; + +std::vector parquet_file_get_fields( + const std::string& filename) { + std::vector field_names; + std::vector field_dtypes; + std::vector field_ragged_ranks; + auto s = ArrowUtil::GetParquetDataFrameFields( + &field_names, &field_dtypes, &field_ragged_ranks, filename); + std::vector fields; + if (!s.ok()) { + std::cerr << "parquet_file_get_fields failed: " << s.message() << std::endl; + return fields; + } + for (size_t i = 0; i < field_names.size(); ++i) { + fields.emplace_back(field_names[i], field_dtypes[i], field_ragged_ranks[i]); + } + return fields; +} + +} // namespace + +PYBIND11_MODULE(_parquet_pybind, m) { + m.def("buildinfo", &buildinfo, "Get building information."); + m.def("parquet_file_get_fields", &parquet_file_get_fields, + "Get fields of a parquet file."); +} + +} // namespace data +} // namespace tensorflow diff --git a/tensorflow/core/ops/parquet_ops.cc b/tensorflow/core/ops/parquet_ops.cc new file mode 100644 index 00000000000..05d5da8f263 --- /dev/null +++ b/tensorflow/core/ops/parquet_ops.cc @@ -0,0 +1,40 @@ +// Copyright 2017 The TensorFlow Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +#include "tensorflow/core/framework/common_shape_fns.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/shape_inference.h" + +namespace tensorflow { + +REGISTER_OP("ParquetTabularDatasetV1") + .Output("handle: variant") + .Input("filename: string") + .Input("batch_size: int64") + .Attr("field_names: list(string) >= 1") + .Attr("field_dtypes: list(type) >= 1") + .Attr("field_ragged_ranks: list(int) >= 1") + .Attr("partition_count: int = 1") + .Attr("partition_index: int = 0") + .Attr("drop_remainder: bool = false") + .SetIsStateful() // NOTE: Source dataset ops must be marked stateful to + // inhibit constant folding. + .SetShapeFn([](shape_inference::InferenceContext* c) { + shape_inference::ShapeHandle unused; + // batch_size should be a scalar. 
+ TF_RETURN_IF_ERROR(c->WithRank(c->input(1), 0, &unused)); + return shape_inference::ScalarShape(c); + }); + +} // namespace tensorflow diff --git a/tensorflow/core/platform/s3/aws_logging.h b/tensorflow/core/platform/s3/aws_logging.h index b0da8f3c835..0688a35ccb6 100644 --- a/tensorflow/core/platform/s3/aws_logging.h +++ b/tensorflow/core/platform/s3/aws_logging.h @@ -43,6 +43,11 @@ class AWSLogSystem : public Aws::Utils::Logging::LogSystemInterface { log_level_.store(log_level); } + // Writes any buffered messages to the underlying device if the logger supports buffering. + void Flush() override { + LogMessage(Aws::Utils::Logging::LogLevel::Error, "AWSLogSystem::Flush() not supported"); + } + // Does a printf style output to ProcessFormattedStatement. Don't use this, // it's unsafe. See LogStream. // Since non-static C++ methods have an implicit this argument, diff --git a/tensorflow/python/BUILD b/tensorflow/python/BUILD index c34e69cf0f9..9fabf845128 100644 --- a/tensorflow/python/BUILD +++ b/tensorflow/python/BUILD @@ -2105,6 +2105,16 @@ tf_gen_op_wrapper_private_py( ] ) +tf_gen_op_wrapper_private_py( + name = "parquet_ops_gen", + visibility = [ + "//tensorflow:__subpackages__", + ], + deps = [ + "//tensorflow/core:parquet_ops_op_lib" + ] +) + tf_gen_op_wrapper_private_py( name = "image_ops_gen", visibility = ["//learning/brain/python/ops:__pkg__"], diff --git a/tensorflow/python/data/experimental/__init__.py b/tensorflow/python/data/experimental/__init__.py index 6035e7f001b..6e7894b9e2d 100644 --- a/tensorflow/python/data/experimental/__init__.py +++ b/tensorflow/python/data/experimental/__init__.py @@ -43,6 +43,7 @@ @@TensorArrayStructure @@TensorStructure @@ThreadingOptions +@@ParquetDataset @@bucket_by_sequence_length @@bytes_produced_stats @@ -128,6 +129,7 @@ from tensorflow.python.data.experimental.ops.threading_options import ThreadingOptions from tensorflow.python.data.experimental.ops.unique import unique from tensorflow.python.data.experimental.ops.writers import TFRecordWriter +from tensorflow.python.data.experimental.ops.parquet_dataset_ops import ParquetDataset from tensorflow.python.data.ops.dataset_ops import AUTOTUNE from tensorflow.python.data.ops.dataset_ops import DatasetSpec as DatasetStructure from tensorflow.python.data.ops.dataset_ops import from_variant diff --git a/tensorflow/python/data/experimental/kernel_tests/BUILD b/tensorflow/python/data/experimental/kernel_tests/BUILD index ef71774ec0e..215f4aebbce 100644 --- a/tensorflow/python/data/experimental/kernel_tests/BUILD +++ b/tensorflow/python/data/experimental/kernel_tests/BUILD @@ -99,6 +99,20 @@ py_test( ], ) +py_test( + name = "parquet_dataset_ops_test", + size = "medium", + srcs = ["parquet_dataset_ops_test.py"], + python_version = "PY2", + srcs_version = "PY2AND3", + tags = ["no_pip"], + deps = [ + "//tensorflow/python/data/experimental/ops:parquet_dataset_ops", + "//tensorflow/python/data/kernel_tests:test_base", + "//tensorflow:tensorflow_py", + ], +) + py_test( name = "dense_to_sparse_batch_test", srcs = ["dense_to_sparse_batch_test.py"], diff --git a/tensorflow/python/data/experimental/kernel_tests/parquet_dataset_ops_test.py b/tensorflow/python/data/experimental/kernel_tests/parquet_dataset_ops_test.py new file mode 100644 index 00000000000..e99495f7ddf --- /dev/null +++ b/tensorflow/python/data/experimental/kernel_tests/parquet_dataset_ops_test.py @@ -0,0 +1,197 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 
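A minimal usage sketch of the Python API exported above as tf.data.experimental.ParquetDataset, mirroring the tests that follow; the file path and batch size are illustrative only.

import tensorflow as tf
from tensorflow.python.data.experimental.ops import parquet_dataset_ops

# Read two int64 columns in batches of 32 from a local Parquet file.
ds = parquet_dataset_ops.ParquetDataset(
    '/tmp/test.parquet',  # illustrative path
    batch_size=32,
    fields=[parquet_dataset_ops.DataFrame.Field('A', tf.int64)])
ds = ds.prefetch(4)
batch = tf.data.make_one_shot_iterator(ds).get_next()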
+# +# Licensed under the Apache License, Version 2.0 (the "License"); you may not +# use this file except in compliance with the License. You may obtain a copy of +# the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations under +# the License. +# ============================================================================== +"""Tests for read_parquet and ParquetDataset.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +import pandas as pd +import os +from six.moves import xrange # pylint: disable=redefined-builtin +import tempfile + +import tensorflow as tf +from tensorflow.python.data.experimental.ops import parquet_dataset_ops +from tensorflow.python.data.kernel_tests import test_base +from tensorflow.python.platform import test +from tensorflow.python.data.ops.dataset_ops import AUTOTUNE + + +class ParquetDatasetTest(test_base.DatasetTestBase): + @classmethod + def setUpClass(self): + os.environ['CUDA_VISIBLE_DEVICES'] = '' + self._workspace = tempfile.mkdtemp() + self._filename = os.path.join(self._workspace, 'test.parquet') + self._df = pd.DataFrame( + np.random.randint(0, 100, size=(200, 4), dtype=np.int64), + columns=list('ABCd')) + self._df.to_parquet(self._filename) + + def test_read(self): + batch_size = 32 + with tf.Graph().as_default() as graph: + ds = parquet_dataset_ops.ParquetDataset( + self._filename, + batch_size=batch_size, + fields=[parquet_dataset_ops.DataFrame.Field('A', tf.int64), + parquet_dataset_ops.DataFrame.Field('C', tf.int64)]) + ds = ds.prefetch(4) + batch = tf.data.make_one_shot_iterator(ds).get_next() + + a = self._df['A'] + c = self._df['C'] + with tf.Session(graph=graph) as sess: + for i in xrange(3): + result = sess.run(batch) + start_row = i * batch_size + end_row = (i + 1) * batch_size + np.testing.assert_equal(result['A'], a[start_row:end_row].to_numpy()) + np.testing.assert_equal(result['C'], c[start_row:end_row].to_numpy()) + + def test_schema_auto_detection_read(self): + batch_size = 32 + with tf.Graph().as_default() as graph: + ds = parquet_dataset_ops.ParquetDataset([self._filename], batch_size=batch_size) + ds = ds.prefetch(4) + batch = tf.data.make_one_shot_iterator(ds).get_next() + + c = self._df['C'] + with tf.Session(graph=graph) as sess: + for i in xrange(3): + result = sess.run(batch) + start_row = i * batch_size + end_row = (i + 1) * batch_size + np.testing.assert_equal(result['C'], c[start_row:end_row].to_numpy()) + + def test_dtype_auto_detection_read(self): + batch_size = 32 + with tf.Graph().as_default() as graph: + ds = parquet_dataset_ops.ParquetDataset( + [self._filename], + batch_size=batch_size, + fields=['B', 'C']) + ds = ds.prefetch(4) + batch = tf.data.make_one_shot_iterator(ds).get_next() + + c = self._df['C'] + with tf.Session(graph=graph) as sess: + for i in xrange(3): + result = sess.run(batch) + start_row = i * batch_size + end_row = (i + 1) * batch_size + np.testing.assert_equal(result['C'], c[start_row:end_row].to_numpy()) + + def test_dtype_auto_detection_read_lower(self): + batch_size = 32 + with tf.Graph().as_default() as graph: + actual_fields = parquet_dataset_ops.ParquetDataset.read_schema( + self._filename, ['B', 'D'], 
lower=True) + fld = actual_fields[1].name + ds = parquet_dataset_ops.ParquetDataset( + [self._filename], + batch_size=batch_size, + fields=actual_fields) + ds = ds.prefetch(4) + batch = tf.data.make_one_shot_iterator(ds).get_next() + + c = self._df[fld] + with tf.Session(graph=graph) as sess: + for i in xrange(3): + result = sess.run(batch) + start_row = i * batch_size + end_row = (i + 1) * batch_size + np.testing.assert_equal(result[fld], c[start_row:end_row].to_numpy()) + + def test_read_from_generator(self): + num_epochs = 2 + batch_size = 100 + with tf.Graph().as_default() as graph: + def gen_filenames(): + for i in xrange(num_epochs + 1): + if i == num_epochs: + return # raise StopIteration + yield self._filename + filenames = tf.data.Dataset.from_generator( + gen_filenames, tf.string, tf.TensorShape([])) + fields = [ + parquet_dataset_ops.DataFrame.Field('A', tf.int64, 0), + parquet_dataset_ops.DataFrame.Field('C', tf.int64, 0)] + ds = filenames.apply(parquet_dataset_ops.read_parquet(batch_size, fields=fields)) + ds = ds.prefetch(4) + batch = tf.data.make_one_shot_iterator(ds).get_next() + + with tf.Session(graph=graph) as sess: + for _ in xrange(len(self._df) * num_epochs // batch_size): + sess.run(batch) + with self.assertRaises(tf.errors.OutOfRangeError): + sess.run(batch) + + def test_read_from_generator_parallel(self): + num_epochs = 2 + batch_size = 100 + with tf.Graph().as_default() as graph: + def gen_filenames(): + for i in xrange(num_epochs + 1): + if i == num_epochs: + return # raise StopIteration + yield self._filename + filenames = tf.data.Dataset.from_generator( + gen_filenames, tf.string, tf.TensorShape([])) + fields = [ + parquet_dataset_ops.DataFrame.Field('A', tf.int64, 0), + parquet_dataset_ops.DataFrame.Field('C', tf.int64, 0)] + ds = filenames.apply( + parquet_dataset_ops.read_parquet(batch_size, fields=fields, num_parallel_reads=3)) + ds = ds.prefetch(4) + batch = tf.data.make_one_shot_iterator(ds).get_next() + + with tf.Session(graph=graph) as sess: + for _ in xrange(len(self._df) * num_epochs // batch_size): + sess.run(batch) + with self.assertRaises(tf.errors.OutOfRangeError): + sess.run(batch) + + def test_read_from_generator_parallel_auto(self): + num_epochs = 2 + batch_size = 100 + with tf.Graph().as_default() as graph: + def gen_filenames(): + for i in xrange(num_epochs + 1): + if i == num_epochs: + return # raise StopIteration + yield self._filename + filenames = tf.data.Dataset.from_generator( + gen_filenames, tf.string, tf.TensorShape([])) + fields = [ + parquet_dataset_ops.DataFrame.Field('A', tf.int64, 0), + parquet_dataset_ops.DataFrame.Field('C', tf.int64, 0)] + ds = filenames.apply( + parquet_dataset_ops.read_parquet( + batch_size, fields=fields, num_parallel_reads=AUTOTUNE)) + ds = ds.prefetch(4) + batch = tf.data.make_one_shot_iterator(ds).get_next() + + with tf.Session(graph=graph) as sess: + for _ in xrange(len(self._df) * num_epochs // batch_size): + sess.run(batch) + with self.assertRaises(tf.errors.OutOfRangeError): + sess.run(batch) + + +if __name__ == "__main__": + test.main() diff --git a/tensorflow/python/data/experimental/ops/BUILD b/tensorflow/python/data/experimental/ops/BUILD index a581bc9fccc..84c8b9dc9d8 100644 --- a/tensorflow/python/data/experimental/ops/BUILD +++ b/tensorflow/python/data/experimental/ops/BUILD @@ -472,6 +472,7 @@ py_library( ":matching_files", ":optimization", ":prefetching_ops", + ":parquet_dataset_ops", ":readers", ":resampling", ":scan_ops", @@ -489,3 +490,41 @@ py_library( 
"//tensorflow/python/data/util:nest", ], ) + + +py_library( + name = "parquet_dataset_ops", + srcs = [ + "parquet_dataset_ops.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":parquet_pybind", + ":dataframe", + "//tensorflow/python:parquet_ops_gen", + "//tensorflow/python/data/ops:dataset_ops", + "//tensorflow/python/data/ops:readers", + "//tensorflow/python/data/util:nest", + ], +) + +py_library( + name = "parquet_pybind", + srcs = ["parquet_pybind.py"], + srcs_version = "PY2AND3", + visibility = ["//visibility:public"], + deps = [ + "//tensorflow/python:framework", + "//tensorflow/core/kernels/data:_parquet_pybind", + ], +) + +py_library( + name = "dataframe", + srcs = ["dataframe.py"], + srcs_version = "PY2AND3", + deps = [ + "//tensorflow/python:framework", + "//tensorflow/python:ops", + ], +) diff --git a/tensorflow/python/data/experimental/ops/dataframe.py b/tensorflow/python/data/experimental/ops/dataframe.py new file mode 100644 index 00000000000..003f75259f1 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/dataframe.py @@ -0,0 +1,274 @@ +# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Data frame releated classes. + +See https://pandas.pydata.org/pandas-docs/stable/user_guide/dsintro.html for +more information about data frame concept. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import collections +import numpy as np +from six.moves import xrange # pylint: disable=redefined-builtin + +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor +from tensorflow.python.framework import tensor_shape +from tensorflow.python.framework import tensor_spec +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import gen_ragged_conversion_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import sparse_ops +from tensorflow.python.ops.ragged import ragged_tensor + + +class DataFrame(object): # pylint: disable=useless-object-inheritance + """Tabular data to train in a deep recommender.""" + + class Field(object): # pylint: disable=useless-object-inheritance + """Definition of a field in a data frame.""" + + def __init__(self, name, dtype=None, ragged_rank=None, shape=None): + self._name = name + if dtype is None: + self._dtype = dtype + else: + try: + self._dtype = dtypes.as_dtype(dtype) + except TypeError: + if dtype == np.object_: + self._dtype = dtypes.as_dtype(np.object) + else: + raise + self._ragged_rank = ragged_rank + if shape: + shape = tensor_shape.TensorShape(shape) + for d in shape: + if d.value is None: + raise ValueError( + f'Field {name} has incomplete shape: {shape}') + if ragged_rank is not None and ragged_rank > 1: + raise ValueError( + f'Field {name} is a nested list ({ragged_rank}) ' + f'with shape {shape}') + self._shape = shape + + @property + def name(self): + return self._name + + @property + def incomplete(self): + return self.dtype is None or self.ragged_rank is None + + @property + def dtype(self): + return self._dtype + + @property + def ragged_rank(self): + return self._ragged_rank + + @property + def shape(self): + return self._shape + + def __repr__(self): + if self._dtype is None: + dtypestr = 'unkown' + else: + dtypestr = self._dtype.name + if self._ragged_rank is None: + dtypestr = f'unkown<{dtypestr}>' + else: + if self._ragged_rank > 1: + dtypestr = f'list^{self._ragged_rank}<{dtypestr}>' + elif self._ragged_rank > 0: + dtypestr = f'list<{dtypestr}>' + if self._shape is None: + shapestr = 'unknown' + else: + shapestr = str(self._shape) + return f'{self._name} (dtype={dtypestr}, shape={shapestr})' + + def map(self, func, rank=None): + if rank is None: + rank = self.ragged_rank + if self.incomplete: + raise ValueError( + f'Field {self} is incomplete, please specify dtype and ragged_rank') + if rank == 0: + return func(0) + return DataFrame.Value( + func(0), + [func(i + 1) for i in xrange(rank)]) + + @property + def ragged_indices(self): + return self.map(lambda i: i) + + @property + def output_classes(self): + return self.map(lambda _: ops.Tensor) + + @property + def output_types(self): + return self.map(lambda i: self._dtype if i == 0 else dtypes.int32) + + @property + def output_shapes(self): + if self._shape is None: + return self.map(lambda _: tensor_shape.vector(None)) + return self.map( + lambda i: tensor_shape.vector(None).concatenate(self._shape) if i == 0 + else tensor_shape.vector(None)) + + @property + def output_specs(self): + shape = tensor_shape.vector(None) + if self._shape is not None: + shape = shape.concatenate(self._shape) + specs = [tensor_spec.TensorSpec(shape, dtype=self._dtype)] + specs += [ + tensor_spec.TensorSpec([None], dtype=dtypes.int32) + for _ in 
xrange(self._ragged_rank)] + return specs + + # pylint: disable=inherit-non-class + class Value( + collections.namedtuple( + 'DataFrameValue', + ['values', 'nested_row_splits'])): + """A structure represents a value in DataFrame.""" + + def __new__(cls, values, nested_row_splits=None): + if nested_row_splits is None: + nested_row_splits = tuple() + else: + nested_row_splits = tuple(nested_row_splits) + return super(DataFrame.Value, cls).__new__( + cls, values, nested_row_splits) + + def __repr__(self): + return f'{{{self.values}, splits={self.nested_row_splits}}}' + + def to_sparse(self, name=None): + if len(self.nested_row_splits) == 0: + return self.values + if len(self.nested_row_splits) == 1 and self.values.shape.ndims > 1: + return self.values + sparse_value = gen_ragged_conversion_ops.ragged_tensor_to_sparse( + self.nested_row_splits, self.values, name=name) + return sparse_tensor.SparseTensor( + sparse_value.sparse_indices, + sparse_value.sparse_values, + sparse_value.sparse_dense_shape) + + @classmethod + def to_sparse(cls, features): + """Convert DataFrame values to tensors or sparse tensors.""" + if isinstance(features, dict): + return {f: cls.to_sparse(features[f]) for f in features} + if isinstance(features, DataFrame.Value): + return features.to_sparse() + if isinstance(features, ragged_tensor.RaggedTensor): + if features.ragged_rank >= 1: + features = features.to_sparse() + return features + if isinstance(features, ops.Tensor): + return features + raise ValueError(f'{features} not supported') + + @classmethod + def unbatch_and_to_sparse(cls, features): + """Unbatch and convert a row of DataFrame to tensors or sparse tensors.""" + if isinstance(features, dict): + return { + f: cls.unbatch_and_to_sparse(features[f]) + for f in features} + if isinstance(features, DataFrame.Value): + if len(features.nested_row_splits) > 1: + features = features.to_sparse() + features = sparse_ops.sparse_reshape( + features, features.dense_shape[1:]) + elif len(features.nested_row_splits) == 1: + num_elems = math_ops.cast( + features.nested_row_splits[0][1], dtype=dtypes.int64) + indices = math_ops.range(num_elems) + indices = array_ops.reshape(indices, [-1, 1]) + features = sparse_tensor.SparseTensor( + indices, features.values, [-1]) + else: + features = features.values + return features + if isinstance(features, ragged_tensor.RaggedTensor): + if features.ragged_rank > 1: + features = features.to_sparse() + features = sparse_ops.sparse_reshape( + features, features.dense_shape[1:]) + elif features.ragged_rank == 1: + actual_batch_size = math_ops.cast( + features.row_splits[1], dtype=dtypes.int64) + indices = math_ops.range(actual_batch_size) + indices = array_ops.reshape(indices, [-1, 1]) + features = sparse_tensor.SparseTensor( + indices, features.values, [-1]) + return features + if isinstance(features, ops.Tensor): + return features + raise ValueError(f'{features} not supported for transformation') + + +def to_sparse(num_parallel_calls=None): + """Convert values to tensors or sparse tensors from input dataset.""" + def _apply_fn(dataset): + return dataset.map( + DataFrame.to_sparse, + num_parallel_calls=num_parallel_calls) + return _apply_fn + + +def unbatch_and_to_sparse(num_parallel_calls=None): + """Unbatch and convert a row to tensors or sparse tensors from input dataset.""" + def _apply_fn(dataset): + return dataset.map( + DataFrame.unbatch_and_to_sparse, + num_parallel_calls=num_parallel_calls) + return _apply_fn + + +def input_fields(input_dataset, fields=None): + """Fetch and validate 
fields from input dataset.""" + if fields is None: + ds = input_dataset + while ds: + if hasattr(ds, 'fields'): + fields = ds.fields + break + if not hasattr(ds, '_input_dataset'): + break + ds = ds._input_dataset # pylint: disable=protected-access + if not fields: + raise ValueError('`fields` must be specified') + if not isinstance(fields, (tuple, list)): + raise ValueError('`fields` must be a list of `hb.data.DataFrame.Field`.') + for f in fields: + if not isinstance(f, DataFrame.Field): + raise ValueError(f'{f} must be `hb.data.DataFrame.Field`.') + return fields diff --git a/tensorflow/python/data/experimental/ops/parquet_dataset_ops.py b/tensorflow/python/data/experimental/ops/parquet_dataset_ops.py new file mode 100644 index 00000000000..86e2b9b4ec7 --- /dev/null +++ b/tensorflow/python/data/experimental/ops/parquet_dataset_ops.py @@ -0,0 +1,281 @@ +# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Dataset that reads Parquet files. This class is compatible with TensorFlow 1.15.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.data.ops import readers +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_spec +from tensorflow.python.framework import type_spec +from tensorflow.python.util import nest + +from tensorflow.python.ops import gen_parquet_ops +from tensorflow.python.data.experimental.ops.parquet_pybind import parquet_fields +from tensorflow.python.data.experimental.ops.parquet_pybind import parquet_filenames_and_fields +from tensorflow.python.data.experimental.ops.dataframe import DataFrame + + +class DataFrameValueSpec(type_spec.BatchableTypeSpec): + """A TypeSpec for reading batch of DataFrame.Value from dataset.""" + + def value_type(self): + return DataFrame.Value if self._ragged_rank > 0 else ops.Tensor + + def __init__(self, field): + """Constructs a type specification for a `tf.RaggedTensor`. + + Args: + field: The field definition. 
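The to_sparse and unbatch_and_to_sparse helpers above are intended to be applied as dataset transformations. A minimal sketch, assuming ds is a hypothetical dataset whose elements are dicts of DataFrame.Value batches (for instance the ParquetDataset defined below):

from tensorflow.python.data.experimental.ops import dataframe

# Convert every ragged column from DataFrame.Value (values plus row splits)
# into a tf.SparseTensor; dense columns pass through unchanged as tf.Tensor.
ds = ds.apply(dataframe.to_sparse())

# Or unbatch each row and convert it (see unbatch_and_to_sparse above):
# ds = ds.apply(dataframe.unbatch_and_to_sparse())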
+ """ + if field.incomplete: + raise ValueError( + f'Field {field} is incomplete, please specify dtype and ragged_rank') + self._field = field + + def _serialize(self): + return (self._field.dtype, self._field.ragged_rank) + + @property + def _component_specs(self): + return self._field.output_specs + + def _to_components(self, value): + if isinstance(value, DataFrame.Value): + return [value.values] + list(value.nested_row_splits) + return [value] + + def _from_components(self, tensor_list): + if len(tensor_list) < 1: + return None + if len(tensor_list) == 1: + return tensor_list[0] + return DataFrame.Value(tensor_list[0], tensor_list[1:]) + + def _batch(self, batch_size): + raise NotImplementedError('batching of a bacthed tensor not supported') + + def _unbatch(self): + raise NotImplementedError('unbatching of a bacthed tensor not supported') + + def _to_legacy_output_types(self): + return self._field.output_types + + def _to_legacy_output_shapes(self): + return self._field.output_shapes + + def _to_legacy_output_classes(self): + return self._field.output_classes + + +class _ParquetDataset(dataset_ops.DatasetSource): # pylint: disable=abstract-method + """A Parquet Dataset that reads batches from parquet files.""" + + def __init__( + self, filename, batch_size, fields, + partition_count=1, + partition_index=0, + drop_remainder=False): + """Create a `ParquetDataset`. + + Args: + filename: A 0-D `tf.string` tensor containing one filename. + batch_size: Maxium number of samples in an output batch. + fields: List of DataFrame fields. + partition_count: (Optional.) Count of row group partitions. + partition_index: (Optional.) Index of row group partitions. + drop_remainder: (Optional.) If True, only keep batches with exactly + `batch_size` samples. + """ + self._filename = ops.convert_to_tensor( + filename, dtype=dtypes.string, name='filename') + self._batch_size = ops.convert_to_tensor( + batch_size, dtype=dtypes.int64, name='batch_size') + self._fields = fields + self._output_specs = { + f.name: ( + DataFrameValueSpec(f) + if f.ragged_rank > 0 + else tensor_spec.TensorSpec(shape=[None], dtype=f.dtype)) + for f in self._fields} + self._field_names = nest.flatten({f.name: f.name for f in self._fields}) + self._field_dtypes = nest.flatten({f.name: f.dtype for f in self._fields}) + self._field_ragged_ranks = nest.flatten( + {f.name: f.ragged_rank for f in self._fields}) + self._partition_count = partition_count + self._partition_index = partition_index + self._drop_remainder = drop_remainder + + variant_tensor = gen_parquet_ops.parquet_tabular_dataset_v1( + self._filename, + self._batch_size, + field_names=self._field_names, + field_dtypes=self._field_dtypes, + field_ragged_ranks=self._field_ragged_ranks, + partition_count=self._partition_count, + partition_index=self._partition_index, + drop_remainder=self._drop_remainder) + super().__init__(variant_tensor) + + @property + def element_spec(self): + return self._output_specs + + +class ParquetDataset(dataset_ops.DatasetV2): # pylint: disable=abstract-method + """A Parquet Dataset that reads batches from parquet files.""" + + VERSION = 2002 + + @classmethod + def read_schema(cls, filename, fields=None, lower=False): + """Read schema from a parquet file. + + Args: + filename: Path of the parquet file. + fields: Existing field definitions or field names. + lower: Convert field name to lower case if not found. + + Returns: + Field definition list. 
+ """ + return parquet_fields(filename, fields, lower=lower) + + def __init__( + self, filenames, + batch_size=1, + fields=None, + partition_count=1, + partition_index=0, + drop_remainder=False, + num_parallel_reads=None, + num_sequential_reads=1): + """Create a `ParquetDataset`. + + Args: + filenames: A 0-D or 1-D `tf.string` tensor containing one or more + filenames. + batch_size: (Optional.) Maxium number of samples in an output batch. + fields: (Optional.) List of DataFrame fields. + partition_count: (Optional.) Count of row group partitions. + partition_index: (Optional.) Index of row group partitions. + drop_remainder: (Optional.) If True, only keep batches with exactly + `batch_size` samples. + num_parallel_reads: (Optional.) A `tf.int64` scalar representing the + number of files to read in parallel. Defaults to reading files + sequentially. + num_sequential_reads: (Optional.) A `tf.int64` scalar representing the + number of batches to read in sequential. Defaults to 1. + """ + filenames, self._fields = parquet_filenames_and_fields(filenames, fields) + self._partition_count = partition_count + self._partition_index = partition_index + self._drop_remainder = drop_remainder + + def _create_dataset(f): + f = ops.convert_to_tensor(f, dtypes.string, name='filename') + return _ParquetDataset( # pylint: disable=abstract-class-instantiated + f, batch_size, + fields=self._fields, + partition_count=self._partition_count, + partition_index=self._partition_index, + drop_remainder=self._drop_remainder) + self._impl = self._build_dataset( + _create_dataset, filenames, + num_parallel_reads=num_parallel_reads, + num_sequential_reads=num_sequential_reads) + super().__init__(self._impl._variant_tensor) # pylint: disable=protected-access + + @property + def fields(self): + return self._fields + + @property + def partition_count(self): + return self._partition_count + + @property + def partition_index(self): + return self._partition_index + + @property + def drop_remainder(self): + return self._drop_remainder + + def _inputs(self): + return self._impl._inputs() # pylint: disable=protected-access + + @property + def element_spec(self): + return self._impl.element_spec # pylint: disable=protected-access + + def _build_dataset( + self, dataset_creator, filenames, + num_parallel_reads=None, + num_sequential_reads=1): + """Internal method to create a `ParquetDataset`.""" + if num_parallel_reads is None: + return filenames.flat_map(dataset_creator) + if num_parallel_reads == dataset_ops.AUTOTUNE: + return filenames.interleave( + dataset_creator, num_parallel_calls=num_parallel_reads) + return readers.ParallelInterleaveDataset( + filenames, dataset_creator, + cycle_length=num_parallel_reads, + block_length=num_sequential_reads, + sloppy=True, + buffer_output_elements=None, + prefetch_input_elements=1) + + +def read_parquet( + batch_size, + fields=None, + partition_count=1, + partition_index=0, + drop_remainder=False, + num_parallel_reads=None, + num_sequential_reads=1): + """Create a `ParquetDataset` from filenames dataset. + + Args: + batch_size: Maxium number of samples in an output batch. + fields: (Optional.) List of DataFrame fields. + partition_count: (Optional.) Count of row group partitions. + partition_index: (Optional.) Index of row group partitions. + drop_remainder: (Optional.) If True, only keep batches with exactly + `batch_size` samples. + num_parallel_reads: (Optional.) A `tf.int64` scalar representing the + number of files to read in parallel. Defaults to reading files + sequentially. 
+ num_sequential_reads: (Optional.) A `tf.int64` scalar representing the + number of batches to read in sequential. Defaults to 1. + """ + def _apply_fn(filenames): + return ParquetDataset( + filenames, + batch_size=batch_size, + fields=fields, + partition_count=partition_count, + partition_index=partition_index, + drop_remainder=drop_remainder, + num_parallel_reads=num_parallel_reads, + num_sequential_reads=num_sequential_reads) + + return _apply_fn diff --git a/tensorflow/python/data/experimental/ops/parquet_pybind.py b/tensorflow/python/data/experimental/ops/parquet_pybind.py new file mode 100644 index 00000000000..b9e4e9c223b --- /dev/null +++ b/tensorflow/python/data/experimental/ops/parquet_pybind.py @@ -0,0 +1,168 @@ +# Copyright 2021 Alibaba Group Holding Limited. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================= +"""Parquet related utilities.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import numpy as np +from six import string_types as string + +from tensorflow.python.data.ops import dataset_ops +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.framework import tensor_shape +from tensorflow.python.ops import array_ops +from tensorflow.python.platform import tf_logging as logging +from tensorflow.python.data.experimental.ops.dataframe import DataFrame +from tensorflow.core.kernels.data import _parquet_pybind as _lib + + +def parquet_fields(filename, fields=None, lower=False): + """Get fields from a parquet file. + + Args: + filename: Path of the parquet file. + fields: Existing field definitions or field names. + lower: Convert field name to lower case if not found. + + Returns: + Field definitions. 
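As an illustrative sketch (the file name and the printed schema are hypothetical), this helper can also be called directly to inspect a file:

from tensorflow.python.data.experimental.ops import parquet_pybind

fields = parquet_pybind.parquet_fields('events.parquet')
for f in fields:
  # Prints, for example: user_ids (dtype=list<int64>, shape=unknown)
  print(f)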
+ """ + logging.info(f'Reading fields from {filename} ...') + all_field_tuples = _lib.parquet_file_get_fields(filename) # pylint: disable=c-extension-no-member + if not all_field_tuples: + raise ValueError( + f'No supported fields found in parquet file {filename}') + all_fields = { + f[0]: {'dtype': f[1], 'ragged_rank': f[2]} + for f in all_field_tuples} + if fields is None: + fields = all_fields.keys() + fields = tuple(fields) + new_fields = [] + for f in fields: + if isinstance(f, DataFrame.Field): + if lower and f.name not in all_fields: + f = DataFrame.Field( + f.name.lower(), + dtype=f.dtype, + shape=f.shape, + ragged_rank=f.ragged_rank) + if f.name not in all_fields: + raise ValueError( + f'Field {f.name} is not found in the parquet file {filename}') + dtype = f.dtype + actual_dtype = np.dtype(all_fields[f.name]['dtype']) + if dtype is None: + dtype = actual_dtype + elif dtype != actual_dtype: + raise ValueError( + f'Field {f.name} should has dtype {actual_dtype} not {dtype}') + ragged_rank = f.ragged_rank + actual_ragged_rank = all_fields[f.name]['ragged_rank'] + if ragged_rank is None: + ragged_rank = actual_ragged_rank + elif ragged_rank != actual_ragged_rank: + raise ValueError( + f'Field {f.name} should has ragged_rank {actual_ragged_rank} ' + f'not {ragged_rank}') + f = DataFrame.Field( + f.name, + dtype=dtype, + ragged_rank=ragged_rank, + shape=f.shape) + new_fields.append(f) + continue + if not isinstance(f, string): + raise ValueError( + f'Field {f} is not a DataFrame.Field or a string') + if lower and f not in all_fields: + f = f.lower() + if f not in all_fields: + raise ValueError( + f'Field {f} is not found in the parquet file {filename}') + new_fields.append(DataFrame.Field( + f, + dtype=np.dtype(all_fields[f]['dtype']), + ragged_rank=all_fields[f]['ragged_rank'], + shape=None)) + return tuple(new_fields) + + +def parquet_filenames_and_fields(filenames, fields, lower=False): + """Check and fetch parquet filenames and fields. + + Args: + filenames: List of Path of parquet file list. + fields: Existing field definitions or field names. + lower: Convert field name to lower case if not found. + + Returns: + Validated file names and fields. 
+ """ + if isinstance(filenames, string): + filenames = [filenames] + fields = parquet_fields(filenames[0], fields, lower=lower) + elif isinstance(filenames, (tuple, list)): + for f in filenames: + if not isinstance(f, string): + raise ValueError(f'{f} in `filenames` must be a string') + fields = parquet_fields(filenames[0], fields, lower=lower) + elif isinstance(filenames, dataset_ops.Dataset): + if filenames.output_types != dtypes.string: + raise TypeError( + '`filenames` must be a `tf.data.Dataset` of `tf.string` elements.') + if not filenames.output_shapes.is_compatible_with( + tensor_shape.TensorShape([])): + raise ValueError( + '`filenames` must be a `tf.data.Dataset` of scalar `tf.string` ' + 'elements.') + if fields is None: + raise ValueError('`fields` must be specified.') + if not isinstance(fields, (tuple, list)): + raise ValueError('`fields` must be a list of `hb.data.DataFrame.Field`.') + for f in fields: + if not isinstance(f, DataFrame.Field): + raise ValueError(f'Field {f} must be `hb.data.DataFrame.Field`.') + if f.incomplete: + raise ValueError( + f'Field {f} is incomplete, please specify dtype and ragged_rank') + elif isinstance(filenames, ops.Tensor): + if filenames.dtype != dtypes.string: + raise TypeError( + '`filenames` must be a `tf.Tensor` of `tf.string`.') + if fields is None: + raise ValueError('`fields` must be specified.') + if not isinstance(fields, (tuple, list)): + raise ValueError('`fields` must be a list of `hb.data.DataFrame.Field`.') + for f in fields: + if not isinstance(f, DataFrame.Field): + raise ValueError(f'Field {f} must be `hb.data.DataFrame.Field`.') + if f.incomplete: + raise ValueError( + f'Field {f} is incomplete, please specify dtype and ragged_rank') + else: + raise ValueError( + f'`filenames` {filenames} must be a `tf.data.Dataset` of scalar ' + '`tf.string` elements or can be converted to a `tf.Tensor` of ' + '`tf.string`.') + + if not isinstance(filenames, dataset_ops.Dataset): + filenames = ops.convert_to_tensor(filenames, dtype=dtypes.string) + filenames = array_ops.reshape(filenames, [-1], name='filenames') + filenames = dataset_ops.Dataset.from_tensor_slices(filenames) + return filenames, fields diff --git a/tensorflow/tools/api/golden/v1/tensorflow.-storage-option.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.-storage-option.pbtxt index 8cf8fd31068..2bde270c091 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.-storage-option.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.-storage-option.pbtxt @@ -4,6 +4,6 @@ tf_class { is_instance: "" member_method { name: "__init__" - argspec: "args=[\'self\', \'storage_type\', \'storage_path\', \'storage_size\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'None\'], " + argspec: "args=[\'self\', \'storage_type\', \'storage_path\', \'storage_size\'], varargs=None, keywords=None, defaults=[\'None\', \'None\', \'[1073741824, 1073741824, 1073741824, 1073741824]\'], " } } diff --git a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt index dfea3c91f45..94f5c7bfe0f 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.pbtxt @@ -1164,26 +1164,6 @@ tf_module { name: "custom_gradient" argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "tensor_buffer_cancel" - argspec: "args=[\'container\', \'is_cancelled\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'\', \'1\', 
\'None\'], " - } - member_method { - name: "tensor_buffer_close" - argspec: "args=[\'container\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'None\'], " - } - member_method { - name: "tensor_buffer_put" - argspec: "args=[\'record\', \'container\', \'shared_name\', \'shared_capacity\', \'timeout_millis\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'1000\', \'None\'], " - } - member_method { - name: "tensor_buffer_size" - argspec: "args=[\'container\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'None\'], " - } - member_method { - name: "tensor_buffer_take" - argspec: "args=[\'dtypes\', \'container\', \'shared_name\', \'shared_capacity\', \'shared_threads\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'1\', \'None\'], " - } member_method { name: "decode_base64" argspec: "args=[\'input\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -1738,7 +1718,7 @@ tf_module { } member_method { name: "initialize_kv_variable_op" - argspec: "args=[\'resource_self\', \'resource_primary\', \'value\', \'empty_key\', \'shape\', \'counter_type\', \'slot_num\', \'initial_num_buckets\', \'max_load_factor\', \'steps_to_live\', \'ht_type\', \'emb_index\', \'block_num\', \'slot_index\', \'ht_partition_num\', \'filter_freq\', \'max_freq\', \'max_element_size\', \'false_positive_probability\', \'l2_weight_threshold\', \'layout\', \'storage_type\', \'storage_path\', \'storage_size\', \'default_value_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'131072\', \'0.8\', \'0\', \'\', \'0\', \'1\', \'0\', \'1000\', \'0\', \'999999\', \'0\', \'-1\', \'-1\', \'normal\', \'1\', \'.\', \'0\', \'4096\', \'None\'], " + argspec: "args=[\'resource_self\', \'resource_primary\', \'value\', \'empty_key\', \'shape\', \'counter_type\', \'slot_num\', \'initial_num_buckets\', \'max_load_factor\', \'steps_to_live\', \'ht_type\', \'emb_index\', \'block_num\', \'slot_index\', \'ht_partition_num\', \'filter_freq\', \'max_freq\', \'max_element_size\', \'false_positive_probability\', \'l2_weight_threshold\', \'layout\', \'storage_type\', \'storage_path\', \'storage_size\', \'default_value_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'131072\', \'0.8\', \'0\', \'\', \'0\', \'1\', \'0\', \'1000\', \'0\', \'999999\', \'0\', \'-1\', \'-1\', \'normal\', \'1\', \'.\', \'[]\', \'4096\', \'None\'], " } member_method { name: "initialize_local_variables" @@ -1842,7 +1822,7 @@ tf_module { } member_method { name: "kv_resource_gather_v1" - argspec: "args=[\'resource\', \'indices\', \'default_value\', \'counts\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + argspec: "args=[\'resource\', \'indices\', \'default_value\', \'counts\', \'validate_indices\', \'is_use_default_value_tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], " } member_method { name: "kv_resource_import" @@ -1850,7 +1830,7 @@ tf_module { } member_method { name: "kv_resource_import_v2" - argspec: "args=[\'prefix\', \'resource_self\', \'resource_primary\', \'value\', \'tensor_names\', \'empty_key\', \'shape\', \'counter_type\', \'slot_num\', \'emb_index\', \'slot_index\', \'block_num\', \'steps_to_live\', \'partition_id\', \'partition_num\', \'ht_type\', \'filter_freq\', \'ht_partition_num\', \'max_element_size\', \'false_positive_probability\', \'l2_weight_threshold\', 
\'layout\', \'max_freq\', \'storage_type\', \'storage_path\', \'storage_size\', \'default_value_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'1\', \'0\', \'0\', \'1\', \'\', \'0\', \'1000\', \'0\', \'-1\', \'-1\', \'normal\', \'999999\', \'1\', \'.\', \'0\', \'4096\', \'None\'], " + argspec: "args=[\'prefix\', \'resource_self\', \'resource_primary\', \'value\', \'tensor_names\', \'empty_key\', \'shape\', \'counter_type\', \'slot_num\', \'emb_index\', \'slot_index\', \'block_num\', \'steps_to_live\', \'partition_id\', \'partition_num\', \'ht_type\', \'filter_freq\', \'ht_partition_num\', \'max_element_size\', \'false_positive_probability\', \'l2_weight_threshold\', \'layout\', \'max_freq\', \'storage_type\', \'storage_path\', \'storage_size\', \'default_value_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'1\', \'0\', \'0\', \'1\', \'\', \'0\', \'1000\', \'0\', \'-1\', \'-1\', \'normal\', \'999999\', \'1\', \'.\', \'[]\', \'4096\', \'None\'], " } member_method { name: "kv_resource_incr_import" @@ -2140,6 +2120,10 @@ tf_module { name: "parallel_stack" argspec: "args=[\'values\', \'name\'], varargs=None, keywords=None, defaults=[\'parallel_stack\'], " } + member_method { + name: "parquet_tabular_dataset_v1" + argspec: "args=[\'filename\', \'batch_size\', \'field_names\', \'field_dtypes\', \'field_ragged_ranks\', \'partition_count\', \'partition_index\', \'drop_remainder\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'False\', \'None\'], " + } member_method { name: "parse_example" argspec: "args=[\'serialized\', \'features\', \'name\', \'example_names\'], varargs=None, keywords=None, defaults=[\'None\', \'None\'], " @@ -2884,6 +2868,26 @@ tf_module { name: "tensible_variable_scatter_update" argspec: "args=[\'resource\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "tensor_buffer_cancel" + argspec: "args=[\'container\', \'is_cancelled\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'\', \'1\', \'None\'], " + } + member_method { + name: "tensor_buffer_close" + argspec: "args=[\'container\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'None\'], " + } + member_method { + name: "tensor_buffer_put" + argspec: "args=[\'record\', \'container\', \'shared_name\', \'shared_capacity\', \'timeout_millis\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'1000\', \'None\'], " + } + member_method { + name: "tensor_buffer_size" + argspec: "args=[\'container\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'None\'], " + } + member_method { + name: "tensor_buffer_take" + argspec: "args=[\'dtypes\', \'container\', \'shared_name\', \'shared_capacity\', \'shared_threads\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'1\', \'None\'], " + } member_method { name: "tensor_scatter_add" argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt index aef3daa8785..8a9c63f6a08 100644 --- a/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v1/tensorflow.raw_ops.pbtxt @@ -936,26 +936,6 @@ tf_module { name: 
"CumulativeLogsumexp" argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } - member_method { - name: "TensorBufferCancel" - argspec: "args=[\'container\', \'is_cancelled\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'\', \'1\', \'None\'], " - } - member_method { - name: "TensorBufferClose" - argspec: "args=[\'container\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'None\'], " - } - member_method { - name: "TensorBufferPut" - argspec: "args=[\'record\', \'container\', \'shared_name\', \'shared_capacity\', \'timeout_millis\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'1000\', \'None\'], " - } - member_method { - name: "TensorBufferSize" - argspec: "args=[\'container\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'None\'], " - } - member_method { - name: "TensorBufferTake" - argspec: "args=[\'dtypes\', \'container\', \'shared_name\', \'shared_capacity\', \'shared_threads\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'1\', \'None\'], " - } member_method { name: "DataFormatDimMap" argspec: "args=[\'x\', \'src_format\', \'dst_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'NCHW\', \'None\'], " @@ -1950,7 +1930,7 @@ tf_module { } member_method { name: "InitializeKvVariableOp" - argspec: "args=[\'resource_self\', \'resource_primary\', \'value\', \'empty_key\', \'shape\', \'counter_type\', \'slot_num\', \'initial_num_buckets\', \'max_load_factor\', \'steps_to_live\', \'ht_type\', \'emb_index\', \'block_num\', \'slot_index\', \'ht_partition_num\', \'filter_freq\', \'max_freq\', \'max_element_size\', \'false_positive_probability\', \'l2_weight_threshold\', \'layout\', \'storage_type\', \'storage_path\', \'storage_size\', \'default_value_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'131072\', \'0.8\', \'0\', \'\', \'0\', \'1\', \'0\', \'1000\', \'0\', \'999999\', \'0\', \'-1\', \'-1\', \'normal\', \'1\', \'.\', \'0\', \'4096\', \'None\'], " + argspec: "args=[\'resource_self\', \'resource_primary\', \'value\', \'empty_key\', \'shape\', \'counter_type\', \'slot_num\', \'initial_num_buckets\', \'max_load_factor\', \'steps_to_live\', \'ht_type\', \'emb_index\', \'block_num\', \'slot_index\', \'ht_partition_num\', \'filter_freq\', \'max_freq\', \'max_element_size\', \'false_positive_probability\', \'l2_weight_threshold\', \'layout\', \'storage_type\', \'storage_path\', \'storage_size\', \'default_value_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'131072\', \'0.8\', \'0\', \'\', \'0\', \'1\', \'0\', \'1000\', \'0\', \'999999\', \'0\', \'-1\', \'-1\', \'normal\', \'1\', \'.\', \'[]\', \'4096\', \'None\'], " } member_method { name: "InitializeTable" @@ -2078,7 +2058,7 @@ tf_module { } member_method { name: "KvResourceGatherV1" - argspec: "args=[\'resource\', \'indices\', \'default_value\', \'counts\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + argspec: "args=[\'resource\', \'indices\', \'default_value\', \'counts\', \'validate_indices\', \'is_use_default_value_tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], " } member_method { name: "KvResourceImport" @@ -2086,7 +2066,7 @@ tf_module { } member_method { name: 
"KvResourceImportV2" - argspec: "args=[\'prefix\', \'resource_self\', \'resource_primary\', \'value\', \'tensor_names\', \'empty_key\', \'shape\', \'counter_type\', \'slot_num\', \'emb_index\', \'slot_index\', \'block_num\', \'steps_to_live\', \'partition_id\', \'partition_num\', \'ht_type\', \'filter_freq\', \'ht_partition_num\', \'max_element_size\', \'false_positive_probability\', \'l2_weight_threshold\', \'layout\', \'max_freq\', \'storage_type\', \'storage_path\', \'storage_size\', \'default_value_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'1\', \'0\', \'0\', \'1\', \'\', \'0\', \'1000\', \'0\', \'-1\', \'-1\', \'normal\', \'999999\', \'1\', \'.\', \'0\', \'4096\', \'None\'], " + argspec: "args=[\'prefix\', \'resource_self\', \'resource_primary\', \'value\', \'tensor_names\', \'empty_key\', \'shape\', \'counter_type\', \'slot_num\', \'emb_index\', \'slot_index\', \'block_num\', \'steps_to_live\', \'partition_id\', \'partition_num\', \'ht_type\', \'filter_freq\', \'ht_partition_num\', \'max_element_size\', \'false_positive_probability\', \'l2_weight_threshold\', \'layout\', \'max_freq\', \'storage_type\', \'storage_path\', \'storage_size\', \'default_value_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'1\', \'0\', \'0\', \'1\', \'\', \'0\', \'1000\', \'0\', \'-1\', \'-1\', \'normal\', \'999999\', \'1\', \'.\', \'[]\', \'4096\', \'None\'], " } member_method { name: "KvResourceIncrImport" @@ -2848,6 +2828,10 @@ tf_module { name: "ParameterizedTruncatedNormal" argspec: "args=[\'shape\', \'means\', \'stdevs\', \'minvals\', \'maxvals\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], " } + member_method { + name: "ParquetTabularDatasetV1" + argspec: "args=[\'filename\', \'batch_size\', \'field_names\', \'field_dtypes\', \'field_ragged_ranks\', \'partition_count\', \'partition_index\', \'drop_remainder\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'False\', \'None\'], " + } member_method { name: "ParseExample" argspec: "args=[\'serialized\', \'names\', \'sparse_keys\', \'dense_keys\', \'dense_defaults\', \'sparse_types\', \'dense_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -4928,6 +4912,26 @@ tf_module { name: "TensorArrayWriteV3" argspec: "args=[\'handle\', \'index\', \'value\', \'flow_in\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "TensorBufferCancel" + argspec: "args=[\'container\', \'is_cancelled\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'\', \'1\', \'None\'], " + } + member_method { + name: "TensorBufferClose" + argspec: "args=[\'container\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'None\'], " + } + member_method { + name: "TensorBufferPut" + argspec: "args=[\'record\', \'container\', \'shared_name\', \'shared_capacity\', \'timeout_millis\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'1000\', \'None\'], " + } + member_method { + name: "TensorBufferSize" + argspec: "args=[\'container\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'None\'], " + } + member_method { + name: "TensorBufferTake" + argspec: "args=[\'dtypes\', \'container\', \'shared_name\', \'shared_capacity\', \'shared_threads\', \'name\'], varargs=None, keywords=None, defaults=[\'\', 
\'\', \'1\', \'1\', \'None\'], " + } member_method { name: "TensorDataset" argspec: "args=[\'components\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt index 98a1271ad0c..7fb7651d951 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.pbtxt @@ -624,26 +624,6 @@ tf_module { name: "custom_gradient" argspec: "args=[\'f\'], varargs=None, keywords=None, defaults=None" } - member_method { - name: "tensor_buffer_cancel" - argspec: "args=[\'container\', \'is_cancelled\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'\', \'1\', \'None\'], " - } - member_method { - name: "tensor_buffer_close" - argspec: "args=[\'container\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'None\'], " - } - member_method { - name: "tensor_buffer_put" - argspec: "args=[\'record\', \'container\', \'shared_name\', \'shared_capacity\', \'timeout_millis\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'1000\', \'None\'], " - } - member_method { - name: "tensor_buffer_size" - argspec: "args=[\'container\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'None\'], " - } - member_method { - name: "tensor_buffer_take" - argspec: "args=[\'dtypes\', \'container\', \'shared_name\', \'shared_capacity\', \'shared_threads\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'1\', \'None\'], " - } member_method { name: "decode_dense" argspec: "args=[\'values\', \'dtype\', \'name\'], varargs=None, keywords=None, defaults=[\"\", \'None\'], " @@ -914,7 +894,7 @@ tf_module { } member_method { name: "initialize_kv_variable_op" - argspec: "args=[\'resource_self\', \'resource_primary\', \'value\', \'empty_key\', \'shape\', \'counter_type\', \'slot_num\', \'initial_num_buckets\', \'max_load_factor\', \'steps_to_live\', \'ht_type\', \'emb_index\', \'block_num\', \'slot_index\', \'ht_partition_num\', \'filter_freq\', \'max_freq\', \'max_element_size\', \'false_positive_probability\', \'l2_weight_threshold\', \'layout\', \'storage_type\', \'storage_path\', \'storage_size\', \'default_value_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'131072\', \'0.8\', \'0\', \'\', \'0\', \'1\', \'0\', \'1000\', \'0\', \'999999\', \'0\', \'-1\', \'-1\', \'normal\', \'1\', \'.\', \'0\', \'4096\', \'None\'], " + argspec: "args=[\'resource_self\', \'resource_primary\', \'value\', \'empty_key\', \'shape\', \'counter_type\', \'slot_num\', \'initial_num_buckets\', \'max_load_factor\', \'steps_to_live\', \'ht_type\', \'emb_index\', \'block_num\', \'slot_index\', \'ht_partition_num\', \'filter_freq\', \'max_freq\', \'max_element_size\', \'false_positive_probability\', \'l2_weight_threshold\', \'layout\', \'storage_type\', \'storage_path\', \'storage_size\', \'default_value_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'131072\', \'0.8\', \'0\', \'\', \'0\', \'1\', \'0\', \'1000\', \'0\', \'999999\', \'0\', \'-1\', \'-1\', \'normal\', \'1\', \'.\', \'[]\', \'4096\', \'None\'], " } member_method { name: "io_kafka_dataset" @@ -978,7 +958,7 @@ tf_module { } member_method { name: "kv_resource_gather_v1" - argspec: "args=[\'resource\', \'indices\', \'default_value\', \'counts\', \'validate_indices\', \'name\'], 
varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + argspec: "args=[\'resource\', \'indices\', \'default_value\', \'counts\', \'validate_indices\', \'is_use_default_value_tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], " } member_method { name: "kv_resource_import" @@ -986,7 +966,7 @@ tf_module { } member_method { name: "kv_resource_import_v2" - argspec: "args=[\'prefix\', \'resource_self\', \'resource_primary\', \'value\', \'tensor_names\', \'empty_key\', \'shape\', \'counter_type\', \'slot_num\', \'emb_index\', \'slot_index\', \'block_num\', \'steps_to_live\', \'partition_id\', \'partition_num\', \'ht_type\', \'filter_freq\', \'ht_partition_num\', \'max_element_size\', \'false_positive_probability\', \'l2_weight_threshold\', \'layout\', \'max_freq\', \'storage_type\', \'storage_path\', \'storage_size\', \'default_value_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'1\', \'0\', \'0\', \'1\', \'\', \'0\', \'1000\', \'0\', \'-1\', \'-1\', \'normal\', \'999999\', \'1\', \'.\', \'0\', \'4096\', \'None\'], " + argspec: "args=[\'prefix\', \'resource_self\', \'resource_primary\', \'value\', \'tensor_names\', \'empty_key\', \'shape\', \'counter_type\', \'slot_num\', \'emb_index\', \'slot_index\', \'block_num\', \'steps_to_live\', \'partition_id\', \'partition_num\', \'ht_type\', \'filter_freq\', \'ht_partition_num\', \'max_element_size\', \'false_positive_probability\', \'l2_weight_threshold\', \'layout\', \'max_freq\', \'storage_type\', \'storage_path\', \'storage_size\', \'default_value_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'1\', \'0\', \'0\', \'1\', \'\', \'0\', \'1000\', \'0\', \'-1\', \'-1\', \'normal\', \'999999\', \'1\', \'.\', \'[]\', \'4096\', \'None\'], " } member_method { name: "kv_resource_incr_import" @@ -1156,6 +1136,10 @@ tf_module { name: "parallel_stack" argspec: "args=[\'values\', \'name\'], varargs=None, keywords=None, defaults=[\'parallel_stack\'], " } + member_method { + name: "parquet_tabular_dataset_v1" + argspec: "args=[\'filename\', \'batch_size\', \'field_names\', \'field_dtypes\', \'field_ragged_ranks\', \'partition_count\', \'partition_index\', \'drop_remainder\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'False\', \'None\'], " + } member_method { name: "pow" argspec: "args=[\'x\', \'y\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -1528,6 +1512,26 @@ tf_module { name: "tensible_variable_scatter_update" argspec: "args=[\'resource\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "tensor_buffer_cancel" + argspec: "args=[\'container\', \'is_cancelled\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'\', \'1\', \'None\'], " + } + member_method { + name: "tensor_buffer_close" + argspec: "args=[\'container\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'None\'], " + } + member_method { + name: "tensor_buffer_put" + argspec: "args=[\'record\', \'container\', \'shared_name\', \'shared_capacity\', \'timeout_millis\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'1000\', \'None\'], " + } + member_method { + name: "tensor_buffer_size" + argspec: "args=[\'container\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'None\'], " + } + 
member_method { + name: "tensor_buffer_take" + argspec: "args=[\'dtypes\', \'container\', \'shared_name\', \'shared_capacity\', \'shared_threads\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'1\', \'None\'], " + } member_method { name: "tensor_scatter_nd_add" argspec: "args=[\'tensor\', \'indices\', \'updates\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index aef3daa8785..8a9c63f6a08 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -936,26 +936,6 @@ tf_module { name: "CumulativeLogsumexp" argspec: "args=[\'x\', \'axis\', \'exclusive\', \'reverse\', \'name\'], varargs=None, keywords=None, defaults=[\'False\', \'False\', \'None\'], " } - member_method { - name: "TensorBufferCancel" - argspec: "args=[\'container\', \'is_cancelled\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'\', \'1\', \'None\'], " - } - member_method { - name: "TensorBufferClose" - argspec: "args=[\'container\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'None\'], " - } - member_method { - name: "TensorBufferPut" - argspec: "args=[\'record\', \'container\', \'shared_name\', \'shared_capacity\', \'timeout_millis\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'1000\', \'None\'], " - } - member_method { - name: "TensorBufferSize" - argspec: "args=[\'container\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'None\'], " - } - member_method { - name: "TensorBufferTake" - argspec: "args=[\'dtypes\', \'container\', \'shared_name\', \'shared_capacity\', \'shared_threads\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'1\', \'None\'], " - } member_method { name: "DataFormatDimMap" argspec: "args=[\'x\', \'src_format\', \'dst_format\', \'name\'], varargs=None, keywords=None, defaults=[\'NHWC\', \'NCHW\', \'None\'], " @@ -1950,7 +1930,7 @@ tf_module { } member_method { name: "InitializeKvVariableOp" - argspec: "args=[\'resource_self\', \'resource_primary\', \'value\', \'empty_key\', \'shape\', \'counter_type\', \'slot_num\', \'initial_num_buckets\', \'max_load_factor\', \'steps_to_live\', \'ht_type\', \'emb_index\', \'block_num\', \'slot_index\', \'ht_partition_num\', \'filter_freq\', \'max_freq\', \'max_element_size\', \'false_positive_probability\', \'l2_weight_threshold\', \'layout\', \'storage_type\', \'storage_path\', \'storage_size\', \'default_value_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'131072\', \'0.8\', \'0\', \'\', \'0\', \'1\', \'0\', \'1000\', \'0\', \'999999\', \'0\', \'-1\', \'-1\', \'normal\', \'1\', \'.\', \'0\', \'4096\', \'None\'], " + argspec: "args=[\'resource_self\', \'resource_primary\', \'value\', \'empty_key\', \'shape\', \'counter_type\', \'slot_num\', \'initial_num_buckets\', \'max_load_factor\', \'steps_to_live\', \'ht_type\', \'emb_index\', \'block_num\', \'slot_index\', \'ht_partition_num\', \'filter_freq\', \'max_freq\', \'max_element_size\', \'false_positive_probability\', \'l2_weight_threshold\', \'layout\', \'storage_type\', \'storage_path\', \'storage_size\', \'default_value_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'131072\', \'0.8\', \'0\', \'\', \'0\', 
\'1\', \'0\', \'1000\', \'0\', \'999999\', \'0\', \'-1\', \'-1\', \'normal\', \'1\', \'.\', \'[]\', \'4096\', \'None\'], " } member_method { name: "InitializeTable" @@ -2078,7 +2058,7 @@ tf_module { } member_method { name: "KvResourceGatherV1" - argspec: "args=[\'resource\', \'indices\', \'default_value\', \'counts\', \'validate_indices\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'None\'], " + argspec: "args=[\'resource\', \'indices\', \'default_value\', \'counts\', \'validate_indices\', \'is_use_default_value_tensor\', \'name\'], varargs=None, keywords=None, defaults=[\'True\', \'False\', \'None\'], " } member_method { name: "KvResourceImport" @@ -2086,7 +2066,7 @@ tf_module { } member_method { name: "KvResourceImportV2" - argspec: "args=[\'prefix\', \'resource_self\', \'resource_primary\', \'value\', \'tensor_names\', \'empty_key\', \'shape\', \'counter_type\', \'slot_num\', \'emb_index\', \'slot_index\', \'block_num\', \'steps_to_live\', \'partition_id\', \'partition_num\', \'ht_type\', \'filter_freq\', \'ht_partition_num\', \'max_element_size\', \'false_positive_probability\', \'l2_weight_threshold\', \'layout\', \'max_freq\', \'storage_type\', \'storage_path\', \'storage_size\', \'default_value_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'1\', \'0\', \'0\', \'1\', \'\', \'0\', \'1000\', \'0\', \'-1\', \'-1\', \'normal\', \'999999\', \'1\', \'.\', \'0\', \'4096\', \'None\'], " + argspec: "args=[\'prefix\', \'resource_self\', \'resource_primary\', \'value\', \'tensor_names\', \'empty_key\', \'shape\', \'counter_type\', \'slot_num\', \'emb_index\', \'slot_index\', \'block_num\', \'steps_to_live\', \'partition_id\', \'partition_num\', \'ht_type\', \'filter_freq\', \'ht_partition_num\', \'max_element_size\', \'false_positive_probability\', \'l2_weight_threshold\', \'layout\', \'max_freq\', \'storage_type\', \'storage_path\', \'storage_size\', \'default_value_dim\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'0\', \'1\', \'0\', \'0\', \'1\', \'\', \'0\', \'1000\', \'0\', \'-1\', \'-1\', \'normal\', \'999999\', \'1\', \'.\', \'[]\', \'4096\', \'None\'], " } member_method { name: "KvResourceIncrImport" @@ -2848,6 +2828,10 @@ tf_module { name: "ParameterizedTruncatedNormal" argspec: "args=[\'shape\', \'means\', \'stdevs\', \'minvals\', \'maxvals\', \'seed\', \'seed2\', \'name\'], varargs=None, keywords=None, defaults=[\'0\', \'0\', \'None\'], " } + member_method { + name: "ParquetTabularDatasetV1" + argspec: "args=[\'filename\', \'batch_size\', \'field_names\', \'field_dtypes\', \'field_ragged_ranks\', \'partition_count\', \'partition_index\', \'drop_remainder\', \'name\'], varargs=None, keywords=None, defaults=[\'1\', \'0\', \'False\', \'None\'], " + } member_method { name: "ParseExample" argspec: "args=[\'serialized\', \'names\', \'sparse_keys\', \'dense_keys\', \'dense_defaults\', \'sparse_types\', \'dense_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " @@ -4928,6 +4912,26 @@ tf_module { name: "TensorArrayWriteV3" argspec: "args=[\'handle\', \'index\', \'value\', \'flow_in\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " } + member_method { + name: "TensorBufferCancel" + argspec: "args=[\'container\', \'is_cancelled\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'True\', \'\', \'1\', \'None\'], " + } + member_method { + name: "TensorBufferClose" + argspec: "args=[\'container\', \'shared_name\', \'shared_capacity\', 
\'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'None\'], " + } + member_method { + name: "TensorBufferPut" + argspec: "args=[\'record\', \'container\', \'shared_name\', \'shared_capacity\', \'timeout_millis\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'1000\', \'None\'], " + } + member_method { + name: "TensorBufferSize" + argspec: "args=[\'container\', \'shared_name\', \'shared_capacity\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'None\'], " + } + member_method { + name: "TensorBufferTake" + argspec: "args=[\'dtypes\', \'container\', \'shared_name\', \'shared_capacity\', \'shared_threads\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'1\', \'1\', \'None\'], " + } member_method { name: "TensorDataset" argspec: "args=[\'components\', \'output_shapes\', \'name\'], varargs=None, keywords=None, defaults=[\'None\'], " diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 75a8a5c3703..0f66731e0d8 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -679,15 +679,15 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ], ) - tf_http_archive( + # Note: snappy is placed earlier as tensorflow's snappy does not include snappy-c + http_archive( name = "snappy", - build_file = clean_dep("//third_party:snappy.BUILD"), - sha256 = "3dfa02e873ff51a11ee02b9ca391807f0c8ea0529a4924afa645fbf97163f9d4", - strip_prefix = "snappy-1.1.7", - system_build_file = clean_dep("//third_party/systemlibs:snappy.BUILD"), + build_file = "//third_party:snappy.BUILD", + sha256 = "16b677f07832a612b0836178db7f374e414f94657c138e6993cbfc5dcc58651f", + strip_prefix = "snappy-1.1.8", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/snappy/archive/1.1.7.tar.gz", - "https://github.com/google/snappy/archive/1.1.7.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/snappy/archive/1.1.8.tar.gz", + "https://github.com/google/snappy/archive/1.1.8.tar.gz", ], ) @@ -1077,6 +1077,142 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ], ) + http_archive( + name = "arrow", + build_file = "//third_party:arrow.BUILD", + patch_cmds = [ + # TODO: Remove the fowllowing once arrow issue is resolved. 
+ """sed -i.bak 's/type_traits/std::max(sizeof(int16_t), type_traits/g' cpp/src/parquet/column_reader.cc""", + """sed -i.bak 's/value_byte_size/value_byte_size)/g' cpp/src/parquet/column_reader.cc""", + ], + sha256 = "a27971e2a71c412ae43d998b7b6d06201c7a3da382c804dcdc4a8126ccbabe67", + strip_prefix = "arrow-apache-arrow-4.0.0", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/apache/arrow/archive/apache-arrow-4.0.0.tar.gz", + "https://github.com/apache/arrow/archive/apache-arrow-4.0.0.tar.gz", + ], + ) + + http_archive( + name = "brotli", + build_file = "//third_party:brotli.BUILD", + sha256 = "4c61bfb0faca87219ea587326c467b95acb25555b53d1a421ffa3c8a9296ee2c", + strip_prefix = "brotli-1.0.7", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/google/brotli/archive/v1.0.7.tar.gz", + "https://github.com/google/brotli/archive/v1.0.7.tar.gz", + ], + ) + + http_archive( + name = "bzip2", + build_file = "//third_party:bzip2.BUILD", + sha256 = "ab5a03176ee106d3f0fa90e381da478ddae405918153cca248e682cd0c4a2269", + strip_prefix = "bzip2-1.0.8", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/sourceware.org/pub/bzip2/bzip2-1.0.8.tar.gz", + "https://sourceware.org/pub/bzip2/bzip2-1.0.8.tar.gz", + ], + ) + + http_archive( + name = "thrift", + build_file = "//third_party:thrift.BUILD", + sha256 = "5da60088e60984f4f0801deeea628d193c33cec621e78c8a43a5d8c4055f7ad9", + strip_prefix = "thrift-0.13.0", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/apache/thrift/archive/v0.13.0.tar.gz", + "https://github.com/apache/thrift/archive/v0.13.0.tar.gz", + ], + ) + + http_archive( + name = "xsimd", + build_file = "//third_party:xsimd.BUILD", + sha256 = "45337317c7f238fe0d64bb5d5418d264a427efc53400ddf8e6a964b6bcb31ce9", + strip_prefix = "xsimd-7.5.0", + urls = [ + "https://github.com/xtensor-stack/xsimd/archive/refs/tags/7.5.0.tar.gz", + ], + ) + + http_archive( + name = "zstd", + build_file = "//third_party:zstd.BUILD", + sha256 = "a364f5162c7d1a455cc915e8e3cf5f4bd8b75d09bc0f53965b0c9ca1383c52c8", + strip_prefix = "zstd-1.4.4", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/facebook/zstd/archive/v1.4.4.tar.gz", + "https://github.com/facebook/zstd/archive/v1.4.4.tar.gz", + ], + ) + + http_archive( + name = "rapidjson", + build_file = "//third_party:rapidjson.BUILD", + sha256 = "30bd2c428216e50400d493b38ca33a25efb1dd65f79dfc614ab0c957a3ac2c28", + strip_prefix = "rapidjson-418331e99f859f00bdc8306f69eba67e8693c55e", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/miloyip/rapidjson/archive/418331e99f859f00bdc8306f69eba67e8693c55e.tar.gz", + "https://github.com/miloyip/rapidjson/archive/418331e99f859f00bdc8306f69eba67e8693c55e.tar.gz", + ], + ) + + http_archive( + name = "aws_c_common", + build_file = "//third_party/aws_util:aws_c_common.BUILD", + sha256 = "e9462a141b5db30006704f537d19b92357a59be38d590272e6118976b0356ccd", + strip_prefix = "aws-c-common-0.7.4", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/awslabs/aws-c-common/archive/refs/tags/v0.7.4.tar.gz", + "https://github.com/awslabs/aws-c-common/archive/refs/tags/v0.7.4.tar.gz", + ], + ) + + http_archive( + name = "aws_c_io", + build_file = "//third_party/aws_util:aws_c_io.BUILD", + sha256 = "b60270d23b6e2f4a5d80e64ca6538ba114cd6044b53752964c940f87e59bf0d9", + strip_prefix = "aws-c-io-0.11.2", + urls = [ + 
"https://storage.googleapis.com/mirror.tensorflow.org/github.com/awslabs/aws-c-io/archive/refs/tags/v0.11.2.tar.gz", + "https://github.com/awslabs/aws-c-io/archive/refs/tags/v0.11.2.tar.gz", + ], + ) + + http_archive( + name = "aws_c_event_stream", + build_file = "//third_party/aws_util:aws_c_event_stream.BUILD", + sha256 = "bae0c762b6a4b779a0db0f4730512da6cb500e76681ffdcb9f7286d8e26e547a", + strip_prefix = "aws-c-event-stream-0.2.6", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/awslabs/aws-c-event-stream/archive/refs/tags/v0.2.6.tar.gz", + "https://github.com/awslabs/aws-c-event-stream/archive/refs/tags/v0.2.6.tar.gz", + ], + ) + + http_archive( + name = "aws_checksums", + build_file = "//third_party/aws_util:aws_checksums.BUILD", + sha256 = "394723034b81cc7cd528401775bc7aca2b12c7471c92350c80a0e2fb9d2909fe", + strip_prefix = "aws-checksums-0.1.12", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/awslabs/aws-checksums/archive/refs/tags/v0.1.12.tar.gz", + "https://github.com/awslabs/aws-checksums/archive/refs/tags/v0.1.12.tar.gz", + ], + ) + + http_archive( + name = "aws_c_cal", + build_file = "//third_party/aws_util:aws_c_cal.BUILD", + sha256 = "40297da04443d4ee2988d1c5fb0dc4a156d0e4cfaf80e6a1df1867452566d540", + strip_prefix = "aws-c-cal-0.5.17", + urls = [ + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/awslabs/aws-c-cal/archive/refs/tags/v0.5.17.tar.gz", + "https://github.com/awslabs/aws-c-cal/archive/refs/tags/v0.5.17.tar.gz", + ], + ) + def tf_bind(): """Bind targets for some external repositories""" ############################################################################## diff --git a/third_party/arrow.BUILD b/third_party/arrow.BUILD new file mode 100644 index 00000000000..f81e2cfda0e --- /dev/null +++ b/third_party/arrow.BUILD @@ -0,0 +1,142 @@ +# Description: +# Apache Arrow library + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE.txt"]) + +load("@flatbuffers//:build_defs.bzl", "flatbuffer_cc_library") + +flatbuffer_cc_library( + name = "arrow_format", + srcs = [ + "cpp/src/arrow/ipc/feather.fbs", + "format/File.fbs", + "format/Message.fbs", + "format/Schema.fbs", + "format/SparseTensor.fbs", + "format/Tensor.fbs", + ], + flatc_args = [ + "--scoped-enums", + "--gen-object-api", + ], + out_prefix = "cpp/src/generated/", +) + +genrule( + name = "arrow_util_config", + srcs = ["cpp/src/arrow/util/config.h.cmake"], + outs = ["cpp/src/arrow/util/config.h"], + cmd = ("sed " + + "-e 's/@ARROW_VERSION_MAJOR@/3/g' " + + "-e 's/@ARROW_VERSION_MINOR@/0/g' " + + "-e 's/@ARROW_VERSION_PATCH@/0/g' " + + "-e 's/cmakedefine ARROW_USE_NATIVE_INT128/undef ARROW_USE_NATIVE_INT128/g' " + + "-e 's/cmakedefine/define/g' " + + "$< >$@"), +) + +genrule( + name = "parquet_version_h", + srcs = ["cpp/src/parquet/parquet_version.h.in"], + outs = ["cpp/src/parquet/parquet_version.h"], + cmd = ("sed " + + "-e 's/@PARQUET_VERSION_MAJOR@/1/g' " + + "-e 's/@PARQUET_VERSION_MINOR@/5/g' " + + "-e 's/@PARQUET_VERSION_PATCH@/1/g' " + + "$< >$@"), +) + +cc_library( + name = "arrow", + srcs = glob( + [ + "cpp/src/arrow/*.cc", + "cpp/src/arrow/array/*.cc", + "cpp/src/arrow/compute/*.cc", + "cpp/src/arrow/compute/**/*.h", + "cpp/src/arrow/compute/**/*.cc", + "cpp/src/arrow/csv/*.cc", + "cpp/src/arrow/io/*.cc", + "cpp/src/arrow/ipc/*.cc", + "cpp/src/arrow/json/*.cc", + "cpp/src/arrow/tensor/*.cc", + "cpp/src/arrow/util/*.cc", + "cpp/src/arrow/vendored/base64.cpp", + 
"cpp/src/arrow/vendored/musl/strptime.c", + "cpp/src/arrow/vendored/optional.hpp", + "cpp/src/arrow/vendored/string_view.hpp", + "cpp/src/arrow/vendored/variant.hpp", + "cpp/src/arrow/filesystem/*.cc", + "cpp/src/arrow/vendored/uriparser/*.c", + "cpp/src/arrow/**/*.h", + "cpp/src/parquet/**/*.h", + "cpp/src/parquet/**/*.cc", + "cpp/src/generated/*.h", + "cpp/src/generated/*.cpp", + ], + exclude = [ + "cpp/src/**/*_benchmark.cc", + "cpp/src/**/*_main.cc", + "cpp/src/**/*_nossl.cc", + "cpp/src/**/*_test.cc", + "cpp/src/**/test_*.cc", + "cpp/src/**/*fuzz*.cc", + "cpp/src/**/file_to_stream.cc", + "cpp/src/**/stream_to_file.cc", + "cpp/src/arrow/util/bpacking_avx2.cc", + "cpp/src/arrow/util/bpacking_avx512.cc", + ], + ), + hdrs = [ + # declare header from above genrule + "cpp/src/arrow/util/config.h", + "cpp/src/parquet/parquet_version.h", + ], + copts = [], + defines = [ + "ARROW_WITH_BROTLI", + "ARROW_WITH_SNAPPY", + "ARROW_WITH_LZ4", + "ARROW_WITH_ZLIB", + "ARROW_WITH_ZSTD", + "ARROW_WITH_BZ2", + "ARROW_STATIC", + "ARROW_HDFS=ON", + "ARROW_S3=ON", + "ARROW_EXPORT=", + "PARQUET_STATIC", + "PARQUET_EXPORT=", + "WIN32_LEAN_AND_MEAN", + "DARROW_FILESYSTEM" + ], + includes = [ + "cpp/src", + "cpp/src/arrow/vendored/xxhash", + "cpp/thirdparty/flatbuffers/include", + ], + textual_hdrs = [ + "cpp/src/arrow/vendored/xxhash/xxhash.c", + ], + deps = [ + ":arrow_format", + "@boringssl//:crypto", + "@aws", + "@brotli", + "@bzip2", + "@double_conversion//:double-conversion", + "@lz4", + "@rapidjson", + "@snappy", + "@thrift", + "@xsimd", + "@zlib_archive//:zlib", + "@zstd", + "@boost//:multiprecision", + "@org_tensorflow//third_party/hadoop:hdfs", + ], + alwayslink = 1, +) diff --git a/third_party/aws/BUILD.bazel b/third_party/aws/BUILD.bazel index 36f7ca2fd3f..d3145d595b2 100644 --- a/third_party/aws/BUILD.bazel +++ b/third_party/aws/BUILD.bazel @@ -52,10 +52,18 @@ cc_library( "aws-cpp-sdk-core/source/utils/xml/**/*.cpp", "aws-cpp-sdk-core/source/utils/crypto/*.cpp", "aws-cpp-sdk-core/source/utils/crypto/factory/**/*.cpp", + "aws-cpp-sdk-core/source/utils/event/*.cpp", + "aws-cpp-sdk-core/source/monitoring/*.cpp", + "aws-cpp-sdk-core/source/net/*.cpp", "aws-cpp-sdk-kinesis/include/**/*.h", "aws-cpp-sdk-kinesis/source/**/*.cpp", "aws-cpp-sdk-s3/include/**/*.h", "aws-cpp-sdk-s3/source/**/*.cpp", + "aws-cpp-sdk-sts/include/**/*.h", + "aws-cpp-sdk-sts/source/**/*.cpp", + "aws-cpp-sdk-identity-management/include/**/STSAssumeRoleCredentialsProvider.h", + "aws-cpp-sdk-identity-management/include/**/IdentityManagment_EXPORTS.h", + "aws-cpp-sdk-identity-management/source/**/STSAssumeRoleCredentialsProvider.cpp", ]), hdrs = [ "aws-cpp-sdk-core/include/aws/core/SDKConfig.h", @@ -92,9 +100,12 @@ cc_library( "aws-cpp-sdk-core/include/", "aws-cpp-sdk-kinesis/include/", "aws-cpp-sdk-s3/include/", + "aws-cpp-sdk-identity-management/include/", + "aws-cpp-sdk-sts/include/", ], deps = [ "@curl", + "@aws_c_event_stream", ], ) diff --git a/third_party/aws/workspace.bzl b/third_party/aws/workspace.bzl index f37699e34c5..1e917e35300 100644 --- a/third_party/aws/workspace.bzl +++ b/third_party/aws/workspace.bzl @@ -9,10 +9,10 @@ def repo(): third_party_http_archive( name = "aws", urls = [ - "https://storage.googleapis.com/mirror.tensorflow.org/github.com/aws/aws-sdk-cpp/archive/1.5.8.tar.gz", - "https://github.com/aws/aws-sdk-cpp/archive/1.5.8.tar.gz", + "https://storage.googleapis.com/mirror.tensorflow.org/github.com/aws/aws-sdk-cpp/archive/1.8.0.tar.gz", + "https://github.com/aws/aws-sdk-cpp/archive/1.8.0.tar.gz", ], - sha256 
= "89905075fe50aa13e0337ff905c2e8c1ce9caf77a3504484a7cda39179120ffc", - strip_prefix = "aws-sdk-cpp-1.5.8", + sha256 = "2a69fb2d1a5effe2f053adafcf820535dc6d04bf37e2501cc8c0a8243b8c1f09", + strip_prefix = "aws-sdk-cpp-1.8.0", build_file = "//third_party/aws:BUILD.bazel", ) diff --git a/third_party/aws_util/BUILD b/third_party/aws_util/BUILD new file mode 100644 index 00000000000..2f5d02becb9 --- /dev/null +++ b/third_party/aws_util/BUILD @@ -0,0 +1 @@ +# Dummy BUILD file to make this directory a package. diff --git a/third_party/aws_util/aws_c_cal.BUILD b/third_party/aws_util/aws_c_cal.BUILD new file mode 100644 index 00000000000..ddb9ea6a761 --- /dev/null +++ b/third_party/aws_util/aws_c_cal.BUILD @@ -0,0 +1,23 @@ +# Description: +# AWS C CAL + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +cc_library( + name = "aws_c_cal", + srcs = glob([ + "include/aws/cal/**/*.h", + "source/*.c", + "source/unix/*.c", + ]), + includes = [ + "include/", + ], + deps = [ + "@aws_c_common", + ], +) diff --git a/third_party/aws_util/aws_c_common.BUILD b/third_party/aws_util/aws_c_common.BUILD new file mode 100644 index 00000000000..dc28f3a639b --- /dev/null +++ b/third_party/aws_util/aws_c_common.BUILD @@ -0,0 +1,40 @@ +# Description: +# AWS C COMMON + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +load("@org_tensorflow//third_party:common.bzl", "template_rule") + +cc_library( + name = "aws_c_common", + srcs = glob([ + "include/aws/common/**/*.h", + "include/aws/common/**/*.inl", + "source/*.c", + "source/posix/*.c", + "source/arch/intel/cpuid.c", + "source/arch/intel/asm/cpuid.c", + ]), + hdrs = [ + "include/aws/common/config.h", + ], + includes = [ + "include/", + ], + defines = [ + "AWS_AFFINITY_METHOD=AWS_AFFINITY_METHOD_PTHREAD_ATTR", + ], +) + +template_rule( + name = "COMMONConfig_h", + src = "include/aws/common/config.h.in", + out = "include/aws/common/config.h", + substitutions = { + "cmakedefine": "define", + }, +) diff --git a/third_party/aws_util/aws_c_event_stream.BUILD b/third_party/aws_util/aws_c_event_stream.BUILD new file mode 100644 index 00000000000..49fb10619f0 --- /dev/null +++ b/third_party/aws_util/aws_c_event_stream.BUILD @@ -0,0 +1,24 @@ +# Description: +# AWS C EVENT STREAM + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +cc_library( + name = "aws_c_event_stream", + srcs = glob([ + "include/aws/event-stream/*.h", + "source/event_stream.c", + ]), + includes = [ + "include/", + ], + deps = [ + "@aws_c_common", + "@aws_checksums", + "@aws_c_io", + ], +) diff --git a/third_party/aws_util/aws_c_io.BUILD b/third_party/aws_util/aws_c_io.BUILD new file mode 100644 index 00000000000..88f82ad80dd --- /dev/null +++ b/third_party/aws_util/aws_c_io.BUILD @@ -0,0 +1,31 @@ +# Description: +# AWS C IO + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +cc_library( + name = "aws_c_io", + srcs = glob([ + "include/aws/io/**/*.h", + "source/*.c", + "source/pkcs11/v2.40/*.h", + "source/pkcs11_private.h", + "source/posix/*.c", + "source/linux/*.c", + ]), + includes = [ + "include/", + "source/", + ], + deps = [ + "@aws_c_common", + "@aws_c_cal" + ], + defines = [ + "BYO_CRYPTO", + ], +) diff --git a/third_party/aws_util/aws_checksums.BUILD b/third_party/aws_util/aws_checksums.BUILD new file mode 100644 
index 00000000000..b332a3c5ad4 --- /dev/null +++ b/third_party/aws_util/aws_checksums.BUILD @@ -0,0 +1,23 @@ +# Description: +# AWS CHECKSUMS + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +cc_library( + name = "aws_checksums", + srcs = glob([ + "include/aws/checksums/**/*.h", + "source/*.c", + "source/intel/asm/*.c", + ]), + includes = [ + "include/", + ], + deps = [ + "@aws_c_common", + ], +) diff --git a/third_party/brotli.BUILD b/third_party/brotli.BUILD new file mode 100644 index 00000000000..0e8c87100d0 --- /dev/null +++ b/third_party/brotli.BUILD @@ -0,0 +1,29 @@ +# Description: +# Brotli library + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # MIT license + +exports_files(["LICENSE"]) + +cc_library( + name = "brotli", + srcs = glob([ + "c/common/*.c", + "c/common/*.h", + "c/dec/*.c", + "c/dec/*.h", + "c/enc/*.c", + "c/enc/*.h", + "c/include/brotli/*.h", + ]), + hdrs = [], + defines = [], + includes = [ + "c/dec", + "c/include", + ], + linkopts = [], + visibility = ["//visibility:public"], +) diff --git a/third_party/bzip2.BUILD b/third_party/bzip2.BUILD new file mode 100644 index 00000000000..698378845fd --- /dev/null +++ b/third_party/bzip2.BUILD @@ -0,0 +1,26 @@ +# Description: +# Bzip2 library + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # BSD-like license + +cc_library( + name = "bzip2", + srcs = [ + "blocksort.c", + "bzlib.c", + "bzlib_private.h", + "compress.c", + "crctable.c", + "decompress.c", + "huffman.c", + "randtable.c", + ], + hdrs = [ + "bzlib.h", + ], + copts = [ + ], + includes = ["."], +) diff --git a/third_party/hadoop/BUILD b/third_party/hadoop/BUILD index 563136afe3a..d75f1be8fbc 100644 --- a/third_party/hadoop/BUILD +++ b/third_party/hadoop/BUILD @@ -8,4 +8,5 @@ exports_files(["LICENSE.txt"]) cc_library( name = "hdfs", hdrs = ["hdfs.h"], + includes = ["."], ) diff --git a/third_party/rapidjson.BUILD b/third_party/rapidjson.BUILD new file mode 100644 index 00000000000..b5139169580 --- /dev/null +++ b/third_party/rapidjson.BUILD @@ -0,0 +1,17 @@ +# Description: +# Rapidjson library + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # MIT/JSON license + +cc_library( + name = "rapidjson", + srcs = glob([ + "include/**/*.h", + ]), + copts = [], + includes = [ + "include", + ], +) diff --git a/third_party/snappy.BUILD b/third_party/snappy.BUILD index d93f0307690..4ea8f990344 100644 --- a/third_party/snappy.BUILD +++ b/third_party/snappy.BUILD @@ -18,6 +18,7 @@ cc_library( "snappy-stubs-public.h", ], hdrs = ["snappy.h"], + includes = ["."], copts = ["-DHAVE_CONFIG_H"] + select({ "@org_tensorflow//tensorflow:windows": [], "//conditions:default": [ diff --git a/third_party/thrift.BUILD b/third_party/thrift.BUILD new file mode 100644 index 00000000000..43005c4a081 --- /dev/null +++ b/third_party/thrift.BUILD @@ -0,0 +1,67 @@ +# Description: +# Apache Thrift library + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # Apache 2.0 + +exports_files(["LICENSE"]) + +cc_library( + name = "thrift", + srcs = glob([ + "lib/cpp/src/thrift/**/*.h", + ]) + [ + "lib/cpp/src/thrift/protocol/TProtocol.cpp", + "lib/cpp/src/thrift/transport/TBufferTransports.cpp", + "lib/cpp/src/thrift/transport/TTransportException.cpp", + ], + hdrs = [ + "compiler/cpp/src/thrift/version.h", + "lib/cpp/src/thrift/config.h", + ], + includes = [ + "lib/cpp/src", + ], + textual_hdrs = [ + 
"lib/cpp/src/thrift/protocol/TBinaryProtocol.tcc", + "lib/cpp/src/thrift/protocol/TCompactProtocol.tcc", + ], + deps = [ + "@boost//:asio", + "@boost//:filesystem", + "@boost//:fusion", + "@boost//:lockfree", + "@boost//:program_options", + "@boost//:system", + "@boost//:thread", + "@boost//:variant", + ], +) + +genrule( + name = "version_h", + srcs = [ + "compiler/cpp/src/thrift/version.h.in", + ], + outs = [ + "compiler/cpp/src/thrift/version.h", + ], + cmd = "sed 's/@PACKAGE_VERSION@/0.12.0/g' $< > $@", +) + +genrule( + name = "config_h", + srcs = ["build/cmake/config.h.in"], + outs = ["lib/cpp/src/thrift/config.h"], + cmd = ("sed " + + "-e 's/cmakedefine/define/g' " + + "-e 's/$${PACKAGE}/thrift/g' " + + "-e 's/$${PACKAGE_BUGREPORT}//g' " + + "-e 's/$${PACKAGE_NAME}/thrift/g' " + + "-e 's/$${PACKAGE_TARNAME}/thrift/g' " + + "-e 's/$${PACKAGE_URL}//g' " + + "-e 's/$${PACKAGE_VERSION}/0.12.0/g' " + + "-e 's/$${PACKAGE_STRING}/thrift 0.12.0/g' " + + "$< >$@"), +) diff --git a/third_party/xsimd.BUILD b/third_party/xsimd.BUILD new file mode 100644 index 00000000000..7d49efd24d4 --- /dev/null +++ b/third_party/xsimd.BUILD @@ -0,0 +1,34 @@ +# Description: +# Xsimd library + +package(default_visibility = ["//visibility:public"]) + +licenses(["notice"]) # BSD 3-Clause + +exports_files(["LICENSE"]) + +cc_library( + name = "xsimd", + srcs = [], + hdrs = glob( + [ + "include/xsimd/*.hpp", + "include/xsimd/config/*.hpp", + "include/xsimd/math/*.hpp", + "include/xsimd/memory/*.hpp", + "include/xsimd/stl/*.hpp", + "include/xsimd/types/*.hpp", + ], + exclude = [ + ], + ), + copts = [], + defines = [], + includes = [ + "include", + ], + linkopts = [], + visibility = ["//visibility:public"], + deps = [ + ], +) diff --git a/third_party/zstd.BUILD b/third_party/zstd.BUILD new file mode 100644 index 00000000000..66b14b1b742 --- /dev/null +++ b/third_party/zstd.BUILD @@ -0,0 +1,40 @@ +# Description: +# Zstandard library + +licenses(["notice"]) # BSD license + +exports_files(["LICENSE"]) + +cc_library( + name = "zstd", + srcs = glob( + [ + "lib/common/*.h", + "lib/common/*.c", + "lib/compress/*.c", + "lib/compress/*.h", + "lib/decompress/*.c", + "lib/decompress/*.h", + ], + exclude = [ + "lib/common/xxhash.c", + ], + ), + hdrs = [ + "lib/zstd.h", + ], + defines = [ + "XXH_PRIVATE_API", + "ZSTDLIB_VISIBILITY=", + "ZSTDERRORLIB_VISIBILITY=", + ], + includes = [ + "lib", + "lib/common", + ], + linkopts = [], + textual_hdrs = [ + "lib/common/xxhash.c", + ], + visibility = ["//visibility:public"], +)