Skip to content

Commit

Permalink
Show file tree
Hide file tree
Showing 9 changed files with 87 additions and 35 deletions.
27 changes: 16 additions & 11 deletions cpp/src/arrow/util/rle_encoding.h
Expand Up @@ -141,8 +141,8 @@ class RleDecoder {
/// Number of bits needed to encode the value. Must be between 0 and 64.
int bit_width_;
uint64_t current_value_;
uint32_t repeat_count_;
uint32_t literal_count_;
int32_t repeat_count_;
int32_t literal_count_;

private:
/// Fills literal_count_ and repeat_count_ with next values. Returns false if there
Expand Down Expand Up @@ -302,14 +302,14 @@ inline int RleDecoder::GetBatch(T* values, int batch_size) {
int remaining = batch_size - values_read;

if (repeat_count_ > 0) {
int repeat_batch = std::min(remaining, static_cast<int>(repeat_count_));
int repeat_batch = std::min(remaining, repeat_count_);
std::fill(out, out + repeat_batch, static_cast<T>(current_value_));

repeat_count_ -= repeat_batch;
values_read += repeat_batch;
out += repeat_batch;
} else if (literal_count_ > 0) {
int literal_batch = std::min(remaining, static_cast<int>(literal_count_));
int literal_batch = std::min(remaining, literal_count_);
int actual_read = bit_reader_.GetBatch(bit_width_, out, literal_batch);
if (actual_read != literal_batch) {
return values_read;
Expand Down Expand Up @@ -364,8 +364,8 @@ inline int RleDecoder::GetBatchSpaced(int batch_size, int null_count,
out += repeat_batch;
values_read += repeat_batch;
} else if (literal_count_ > 0) {
int literal_batch = std::min(batch_size - values_read - remaining_nulls,
static_cast<int>(literal_count_));
int literal_batch =
std::min(batch_size - values_read - remaining_nulls, literal_count_);

// Decode the literals
constexpr int kBufferSize = 1024;
Expand Down Expand Up @@ -427,7 +427,7 @@ inline int RleDecoder::GetBatchWithDict(const T* dictionary, int32_t dictionary_
}
T val = dictionary[idx];

int repeat_batch = std::min(remaining, static_cast<int>(repeat_count_));
int repeat_batch = std::min(remaining, repeat_count_);
std::fill(out, out + repeat_batch, val);

/* Upkeep counters */
Expand All @@ -438,7 +438,7 @@ inline int RleDecoder::GetBatchWithDict(const T* dictionary, int32_t dictionary_
constexpr int kBufferSize = 1024;
int indices[kBufferSize];

int literal_batch = std::min(remaining, static_cast<int>(literal_count_));
int literal_batch = std::min(remaining, literal_count_);
literal_batch = std::min(literal_batch, kBufferSize);

int actual_read = bit_reader_.GetBatch(bit_width_, indices, literal_batch);
Expand Down Expand Up @@ -511,8 +511,8 @@ inline int RleDecoder::GetBatchWithDictSpaced(const T* dictionary,
out += repeat_batch;
values_read += repeat_batch;
} else if (literal_count_ > 0) {
int literal_batch = std::min(batch_size - values_read - remaining_nulls,
static_cast<int>(literal_count_));
int literal_batch =
std::min(batch_size - values_read - remaining_nulls, literal_count_);

// Decode the literals
constexpr int kBufferSize = 1024;
Expand Down Expand Up @@ -572,9 +572,14 @@ bool RleDecoder::NextCounts() {
bool is_literal = indicator_value & 1;
uint32_t count = indicator_value >> 1;
if (is_literal) {
if (count > UINT32_MAX / 8) return false;
if (ARROW_PREDICT_FALSE(count > static_cast<uint32_t>(INT32_MAX) / 8)) {
return false;
}
literal_count_ = count * 8;
} else {
if (ARROW_PREDICT_FALSE(count > static_cast<uint32_t>(INT32_MAX))) {
return false;
}
repeat_count_ = count;
// XXX (ARROW-4018) this is not big-endian compatible
if (!bit_reader_.GetAligned<T>(static_cast<int>(BitUtil::CeilDiv(bit_width_, 8)),
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/parquet/arrow/reader.cc
Expand Up @@ -704,6 +704,9 @@ Status GetReader(const SchemaField& field, const std::shared_ptr<ReaderContext>&

auto type_id = field.field->type()->id();
if (field.children.size() == 0) {
if (!field.is_leaf()) {
return Status::Invalid("Parquet non-leaf node has no children");
}
std::unique_ptr<FileColumnIterator> input(
ctx->iterator_factory(field.column_index, ctx->reader));
out->reset(new LeafReader(ctx, field.field, std::move(input)));
Expand Down
3 changes: 1 addition & 2 deletions cpp/src/parquet/arrow/reader_internal.cc
Expand Up @@ -94,8 +94,7 @@ using ArrayType = typename ::arrow::TypeTraits<ArrowType>::ArrayType;
static Status MakeArrowDecimal(const LogicalType& logical_type,
std::shared_ptr<DataType>* out) {
const auto& decimal = checked_cast<const DecimalLogicalType&>(logical_type);
*out = ::arrow::decimal(decimal.precision(), decimal.scale());
return Status::OK();
return ::arrow::Decimal128Type::Make(decimal.precision(), decimal.scale(), out);
}

static Status MakeArrowInt(const LogicalType& logical_type,
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/parquet/column_reader.cc
Expand Up @@ -559,6 +559,10 @@ class ColumnReaderImplBase {
const uint8_t* buffer = page.data() + levels_byte_size;
const int64_t data_size = page.size() - levels_byte_size;

if (data_size < 0) {
throw ParquetException("Page smaller than size of encoded levels");
}

Encoding::type encoding = page.encoding();

if (IsDictionaryIndexEncoding(encoding)) {
Expand Down
61 changes: 46 additions & 15 deletions cpp/src/parquet/encoding.cc
Expand Up @@ -1149,15 +1149,19 @@ template <>
inline int DecodePlain<ByteArray>(const uint8_t* data, int64_t data_size, int num_values,
int type_length, ByteArray* out) {
int bytes_decoded = 0;
int increment;
for (int i = 0; i < num_values; ++i) {
uint32_t len = out[i].len = arrow::util::SafeLoadAs<uint32_t>(data);
increment = static_cast<int>(sizeof(uint32_t) + len);
if (data_size < increment) ParquetException::EofException();
const uint32_t len = out[i].len = arrow::util::SafeLoadAs<uint32_t>(data);
const int64_t increment = static_cast<int64_t>(sizeof(uint32_t) + len);
if (ARROW_PREDICT_FALSE(data_size < increment)) {
ParquetException::EofException();
}
out[i].ptr = data + sizeof(uint32_t);
data += increment;
data_size -= increment;
bytes_decoded += increment;
if (ARROW_PREDICT_FALSE(increment > INT_MAX - bytes_decoded)) {
throw ParquetException("BYTE_ARRAY chunk too large");
}
bytes_decoded += static_cast<int>(increment);
}
return bytes_decoded;
}
Expand Down Expand Up @@ -1468,7 +1472,9 @@ class PlainByteArrayDecoder : public PlainDecoder<ByteArrayType>,
}

auto increment = int32_s + value_len;
if (ARROW_PREDICT_FALSE(len_ < increment)) ParquetException::EofException();
if (ARROW_PREDICT_FALSE(len_ < increment)) {
ParquetException::EofException();
}
if (ARROW_PREDICT_FALSE(!helper.CanFit(value_len))) {
// This element would exceed the capacity of a chunk
RETURN_NOT_OK(helper.PushChunk());
Expand Down Expand Up @@ -1500,9 +1506,18 @@ class PlainByteArrayDecoder : public PlainDecoder<ByteArrayType>,
RETURN_NOT_OK(helper.builder->ReserveData(
std::min<int64_t>(len_, helper.chunk_space_remaining)));
for (int i = 0; i < num_values; ++i) {
int32_t value_len = static_cast<int32_t>(arrow::util::SafeLoadAs<uint32_t>(data_));
int increment = static_cast<int>(sizeof(uint32_t) + value_len);
if (ARROW_PREDICT_FALSE(len_ < increment)) ParquetException::EofException();
// For compiler warnings on unsigned/signed arithmetic.
auto int32_s = static_cast<int32_t>(sizeof(int32_t));

auto value_len = arrow::util::SafeLoadAs<int32_t>(data_);
if (ARROW_PREDICT_FALSE(value_len < 0 || value_len > INT32_MAX - int32_s)) {
return Status::Invalid("Invalid or corrupted value_len '", value_len, "'");
}

auto increment = int32_s + value_len;
if (ARROW_PREDICT_FALSE(len_ < increment)) {
ParquetException::EofException();
}
if (ARROW_PREDICT_FALSE(!helper.CanFit(value_len))) {
// This element would exceed the capacity of a chunk
RETURN_NOT_OK(helper.PushChunk());
Expand All @@ -1529,9 +1544,16 @@ class PlainByteArrayDecoder : public PlainDecoder<ByteArrayType>,
int values_decoded = 0;
for (int i = 0; i < num_values; ++i) {
if (bit_reader.IsSet()) {
uint32_t value_len = arrow::util::SafeLoadAs<uint32_t>(data_);
int increment = static_cast<int>(sizeof(uint32_t) + value_len);
if (len_ < increment) {
// For compiler warnings on unsigned/signed arithmetic.
auto int32_s = static_cast<int32_t>(sizeof(int32_t));

auto value_len = arrow::util::SafeLoadAs<int32_t>(data_);
if (ARROW_PREDICT_FALSE(value_len < 0 || value_len > INT32_MAX - int32_s)) {
return Status::Invalid("Invalid or corrupted value_len '", value_len, "'");
}

auto increment = int32_s + value_len;
if (ARROW_PREDICT_FALSE(len_ < increment)) {
ParquetException::EofException();
}
RETURN_NOT_OK(builder->Append(data_ + sizeof(uint32_t), value_len));
Expand All @@ -1553,9 +1575,18 @@ class PlainByteArrayDecoder : public PlainDecoder<ByteArrayType>,
num_values = std::min(num_values, num_values_);
RETURN_NOT_OK(builder->Reserve(num_values));
for (int i = 0; i < num_values; ++i) {
uint32_t value_len = arrow::util::SafeLoadAs<uint32_t>(data_);
int increment = static_cast<int>(sizeof(uint32_t) + value_len);
if (len_ < increment) ParquetException::EofException();
// For compiler warnings on unsigned/signed arithmetic.
auto int32_s = static_cast<int32_t>(sizeof(int32_t));

auto value_len = arrow::util::SafeLoadAs<int32_t>(data_);
if (ARROW_PREDICT_FALSE(value_len < 0 || value_len > INT32_MAX - int32_s)) {
return Status::Invalid("Invalid or corrupted value_len '", value_len, "'");
}

auto increment = int32_s + value_len;
if (ARROW_PREDICT_FALSE(len_ < increment)) {
ParquetException::EofException();
}
RETURN_NOT_OK(builder->Append(data_ + sizeof(uint32_t), value_len));
data_ += increment;
len_ -= increment;
Expand Down
16 changes: 10 additions & 6 deletions cpp/src/parquet/metadata.cc
Expand Up @@ -214,7 +214,7 @@ class ColumnChunkMetaData::ColumnChunkMetaDataImpl {
for (const auto& encoding : column_metadata_->encodings) {
encodings_.push_back(LoadEnumSafe(&encoding));
}
for (auto encoding_stats : column_metadata_->encoding_stats) {
for (const auto& encoding_stats : column_metadata_->encoding_stats) {
encoding_stats_.push_back({LoadEnumSafe(&encoding_stats.page_type),
LoadEnumSafe(&encoding_stats.encoding),
encoding_stats.count});
Expand Down Expand Up @@ -642,10 +642,19 @@ class FileMetaData::FileMetaDataImpl {
friend FileMetaDataBuilder;
uint32_t metadata_len_;
std::unique_ptr<format::FileMetaData> metadata_;
SchemaDescriptor schema_;
ApplicationVersion writer_version_;
std::shared_ptr<const KeyValueMetadata> key_value_metadata_;
std::shared_ptr<InternalFileDecryptor> file_decryptor_;

void InitSchema() {
if (metadata_->schema.empty()) {
throw ParquetException("Empty file schema (no root)");
}
schema_.Init(schema::Unflatten(&metadata_->schema[0],
static_cast<int>(metadata_->schema.size())));
}

void InitColumnOrders() {
// update ColumnOrder
std::vector<parquet::ColumnOrder> column_orders;
Expand All @@ -663,8 +672,6 @@ class FileMetaData::FileMetaDataImpl {

schema_.updateColumnOrders(column_orders);
}
SchemaDescriptor schema_;
ApplicationVersion writer_version_;

void InitKeyValueMetadata() {
std::shared_ptr<KeyValueMetadata> metadata = nullptr;
Expand All @@ -676,9 +683,6 @@ class FileMetaData::FileMetaDataImpl {
}
key_value_metadata_ = std::move(metadata);
}

std::shared_ptr<const KeyValueMetadata> key_value_metadata_;
std::shared_ptr<InternalFileDecryptor> file_decryptor_;
};

std::shared_ptr<FileMetaData> FileMetaData::Make(
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/parquet/schema.cc
Expand Up @@ -549,6 +549,9 @@ std::unique_ptr<Node> Unflatten(const format::SchemaElement* elements, int lengt
}

std::shared_ptr<SchemaDescriptor> FromParquet(const std::vector<SchemaElement>& schema) {
if (schema.empty()) {
throw ParquetException("Empty file schema (no root)");
}
std::unique_ptr<Node> root = Unflatten(&schema[0], static_cast<int>(schema.size()));
std::shared_ptr<SchemaDescriptor> descr = std::make_shared<SchemaDescriptor>();
descr->Init(std::shared_ptr<GroupNode>(static_cast<GroupNode*>(root.release())));
Expand Down
3 changes: 3 additions & 0 deletions cpp/src/parquet/thrift_internal.h
Expand Up @@ -361,6 +361,9 @@ inline void DeserializeThriftUnencryptedMsg(const uint8_t* buf, uint32_t* len,
shared_ptr<ThriftBuffer> tmem_transport(
new ThriftBuffer(const_cast<uint8_t*>(buf), *len));
apache::thrift::protocol::TCompactProtocolFactoryT<ThriftBuffer> tproto_factory;
// Protect against CPU and memory bombs
tproto_factory.setStringSizeLimit(10 * 1000 * 1000);
tproto_factory.setContainerSizeLimit(10 * 1000 * 1000);
shared_ptr<apache::thrift::protocol::TProtocol> tproto = //
tproto_factory.getProtocol(tmem_transport);
try {
Expand Down

0 comments on commit 6b87c6c

Please sign in to comment.