Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ARROW-12995: [C++] Add validation to CSV options #10505

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions cpp/src/arrow/csv/options.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,19 @@ namespace csv {

ParseOptions ParseOptions::Defaults() { return ParseOptions(); }

Status ParseOptions::Validate() const {
if (ARROW_PREDICT_FALSE(delimiter == '\n' || delimiter == '\r')) {
return Status::Invalid("ParseOptions: delimiter cannot be \\r or \\n");
}
if (ARROW_PREDICT_FALSE(quoting && (quote_char == '\n' || quote_char == '\r'))) {
return Status::Invalid("ParseOptions: quote_char cannot be \\r or \\n");
}
if (ARROW_PREDICT_FALSE(escaping && (escape_char == '\n' || escape_char == '\r'))) {
return Status::Invalid("ParseOptions: escape_char cannot be \\r or \\n");
}
return Status::OK();
}

ConvertOptions ConvertOptions::Defaults() {
auto options = ConvertOptions();
// Same default null / true / false spellings as in Pandas.
Expand All @@ -33,8 +46,38 @@ ConvertOptions ConvertOptions::Defaults() {
return options;
}

Status ConvertOptions::Validate() const { return Status::OK(); }

ReadOptions ReadOptions::Defaults() { return ReadOptions(); }

Status ReadOptions::Validate() const {
if (ARROW_PREDICT_FALSE(block_size < 1)) {
// Min is 1 because some tests use really small block sizes
return Status::Invalid("ReadOptions: block_size must be at least 1: ", block_size);
}
if (ARROW_PREDICT_FALSE(skip_rows < 0)) {
return Status::Invalid("ReadOptions: skip_rows cannot be negative: ", skip_rows);
}
if (ARROW_PREDICT_FALSE(skip_rows_after_names < 0)) {
return Status::Invalid("ReadOptions: skip_rows_after_names cannot be negative: ",
skip_rows_after_names);
}
if (ARROW_PREDICT_FALSE(autogenerate_column_names && !column_names.empty())) {
return Status::Invalid(
"ReadOptions: autogenerate_column_names cannot be true when column_names are "
"provided");
}
return Status::OK();
}

WriteOptions WriteOptions::Defaults() { return WriteOptions(); }

Status WriteOptions::Validate() const {
if (ARROW_PREDICT_FALSE(batch_size < 1)) {
return Status::Invalid("WriteOptions: batch_size must be at least 1: ", batch_size);
}
return Status::OK();
}

} // namespace csv
} // namespace arrow
14 changes: 14 additions & 0 deletions cpp/src/arrow/csv/options.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <vector>

#include "arrow/csv/type_fwd.h"
#include "arrow/status.h"
#include "arrow/util/visibility.h"

namespace arrow {
Expand Down Expand Up @@ -59,6 +60,9 @@ struct ARROW_EXPORT ParseOptions {

/// Create parsing options with default values
static ParseOptions Defaults();

/// \brief Test that all set options are valid
Status Validate() const;
};

struct ARROW_EXPORT ConvertOptions {
Expand Down Expand Up @@ -112,6 +116,9 @@ struct ARROW_EXPORT ConvertOptions {
/// Create conversion options with default values, including conventional
/// values for `null_values`, `true_values` and `false_values`
static ConvertOptions Defaults();

/// \brief Test that all set options are valid
Status Validate() const;
};

struct ARROW_EXPORT ReadOptions {
Expand All @@ -124,6 +131,7 @@ struct ARROW_EXPORT ReadOptions {
///
/// This will determine multi-threading granularity as well as
/// the size of individual record batches.
/// Minimum valid value for block size is 1
int32_t block_size = 1 << 20; // 1 MB

/// Number of header rows to skip (not including the row of column names, if any)
Expand All @@ -143,6 +151,9 @@ struct ARROW_EXPORT ReadOptions {

/// Create read options with default values
static ReadOptions Defaults();

/// \brief Test that all set options are valid
Status Validate() const;
};

/// Experimental
Expand All @@ -158,6 +169,9 @@ struct ARROW_EXPORT WriteOptions {

/// Create write options with default values
static WriteOptions Defaults();

/// \brief Test that all set options are valid
Status Validate() const;
};

} // namespace csv
Expand Down
8 changes: 8 additions & 0 deletions cpp/src/arrow/csv/reader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1033,6 +1033,9 @@ Result<std::shared_ptr<TableReader>> MakeTableReader(
MemoryPool* pool, io::IOContext io_context, std::shared_ptr<io::InputStream> input,
const ReadOptions& read_options, const ParseOptions& parse_options,
const ConvertOptions& convert_options) {
RETURN_NOT_OK(parse_options.Validate());
RETURN_NOT_OK(read_options.Validate());
RETURN_NOT_OK(convert_options.Validate());
std::shared_ptr<BaseTableReader> reader;
if (read_options.use_threads) {
auto cpu_executor = internal::GetCpuThreadPool();
Expand All @@ -1051,6 +1054,9 @@ Future<std::shared_ptr<StreamingReader>> MakeStreamingReader(
io::IOContext io_context, std::shared_ptr<io::InputStream> input,
internal::Executor* cpu_executor, const ReadOptions& read_options,
const ParseOptions& parse_options, const ConvertOptions& convert_options) {
RETURN_NOT_OK(parse_options.Validate());
RETURN_NOT_OK(read_options.Validate());
RETURN_NOT_OK(convert_options.Validate());
std::shared_ptr<BaseStreamingReader> reader;
reader = std::make_shared<SerialStreamingReader>(
io_context, cpu_executor, input, read_options, parse_options, convert_options,
Expand Down Expand Up @@ -1182,6 +1188,8 @@ Future<int64_t> CountRowsAsync(io::IOContext io_context,
internal::Executor* cpu_executor,
const ReadOptions& read_options,
const ParseOptions& parse_options) {
RETURN_NOT_OK(parse_options.Validate());
RETURN_NOT_OK(read_options.Validate());
auto counter = std::make_shared<CSVRowCounter>(
io_context, cpu_executor, std::move(input), read_options, parse_options);
return counter->Count();
Expand Down
2 changes: 2 additions & 0 deletions cpp/src/arrow/csv/writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ class CSVConverter {

Status WriteCSV(const Table& table, const WriteOptions& options, MemoryPool* pool,
arrow::io::OutputStream* output) {
RETURN_NOT_OK(options.Validate());
if (pool == nullptr) {
pool = default_memory_pool();
}
Expand All @@ -424,6 +425,7 @@ Status WriteCSV(const Table& table, const WriteOptions& options, MemoryPool* poo

Status WriteCSV(const RecordBatch& batch, const WriteOptions& options, MemoryPool* pool,
arrow::io::OutputStream* output) {
RETURN_NOT_OK(options.Validate());
if (pool == nullptr) {
pool = default_memory_pool();
}
Expand Down
13 changes: 13 additions & 0 deletions python/pyarrow/_csv.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ cdef class ReadOptions(_Weakrefable):
How much bytes to process at a time from the input stream.
This will determine multi-threading granularity as well as
the size of individual record batches or table chunks.
Minimum valid value for block size is 1
skip_rows: int, optional (default 0)
The number of rows to skip before the column names (if any)
and the CSV data.
Expand Down Expand Up @@ -189,6 +190,9 @@ cdef class ReadOptions(_Weakrefable):
def skip_rows_after_names(self, value):
deref(self.options).skip_rows_after_names = value

def validate(self):
check_status(deref(self.options).Validate())

def equals(self, ReadOptions other):
return (
self.use_threads == other.use_threads and
Expand Down Expand Up @@ -359,6 +363,9 @@ cdef class ParseOptions(_Weakrefable):
def ignore_empty_lines(self, value):
deref(self.options).ignore_empty_lines = value

def validate(self):
check_status(deref(self.options).Validate())

def equals(self, ParseOptions other):
return (
self.delimiter == other.delimiter and
Expand Down Expand Up @@ -680,6 +687,9 @@ cdef class ConvertOptions(_Weakrefable):
out.options.reset(new CCSVConvertOptions(move(options)))
return out

def validate(self):
check_status(deref(self.options).Validate())

def equals(self, ConvertOptions other):
return (
self.check_utf8 == other.check_utf8 and
Expand Down Expand Up @@ -941,6 +951,9 @@ cdef class WriteOptions(_Weakrefable):
def batch_size(self, value):
self.options.batch_size = value

def validate(self):
check_status(self.options.Validate())


cdef _get_write_options(WriteOptions write_options, CCSVWriteOptions* out):
if write_options is None:
Expand Down
8 changes: 8 additions & 0 deletions python/pyarrow/includes/libarrow.pxd
Original file line number Diff line number Diff line change
Expand Up @@ -1592,6 +1592,8 @@ cdef extern from "arrow/csv/api.h" namespace "arrow::csv" nogil:
@staticmethod
CCSVParseOptions Defaults()

CStatus Validate()

cdef cppclass CCSVConvertOptions" arrow::csv::ConvertOptions":
c_bool check_utf8
unordered_map[c_string, shared_ptr[CDataType]] column_types
Expand All @@ -1613,6 +1615,8 @@ cdef extern from "arrow/csv/api.h" namespace "arrow::csv" nogil:
@staticmethod
CCSVConvertOptions Defaults()

CStatus Validate()

cdef cppclass CCSVReadOptions" arrow::csv::ReadOptions":
c_bool use_threads
int32_t block_size
Expand All @@ -1627,13 +1631,17 @@ cdef extern from "arrow/csv/api.h" namespace "arrow::csv" nogil:
@staticmethod
CCSVReadOptions Defaults()

CStatus Validate()

cdef cppclass CCSVWriteOptions" arrow::csv::WriteOptions":
c_bool include_header
int32_t batch_size

@staticmethod
CCSVWriteOptions Defaults()

CStatus Validate()

cdef cppclass CCSVReader" arrow::csv::TableReader":
@staticmethod
CResult[shared_ptr[CCSVReader]] Make(
Expand Down
74 changes: 74 additions & 0 deletions python/pyarrow/tests/test_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,34 @@ def test_read_options():
opts = cls(block_size=1234)
assert opts.block_size == 1234

opts.validate()

match = "ReadOptions: block_size must be at least 1: 0"
with pytest.raises(pa.ArrowInvalid, match=match):
opts = cls()
opts.block_size = 0
opts.validate()

match = "ReadOptions: skip_rows cannot be negative: -1"
with pytest.raises(pa.ArrowInvalid, match=match):
opts = cls()
opts.skip_rows = -1
opts.validate()

match = "ReadOptions: skip_rows_after_names cannot be negative: -1"
with pytest.raises(pa.ArrowInvalid, match=match):
opts = cls()
opts.skip_rows_after_names = -1
opts.validate()

match = "ReadOptions: autogenerate_column_names cannot be true when" \
" column_names are provided"
with pytest.raises(pa.ArrowInvalid, match=match):
opts = cls()
opts.autogenerate_column_names = True
opts.column_names = ('a', 'b')
opts.validate()


def test_parse_options():
cls = ParseOptions
Expand All @@ -150,6 +178,44 @@ def test_parse_options():
newlines_in_values=True,
ignore_empty_lines=False)

cls().validate()
opts = cls()
opts.delimiter = "\t"
opts.validate()

match = "ParseOptions: delimiter cannot be \\\\r or \\\\n"
with pytest.raises(pa.ArrowInvalid, match=match):
opts = cls()
opts.delimiter = "\n"
opts.validate()

with pytest.raises(pa.ArrowInvalid, match=match):
opts = cls()
opts.delimiter = "\r"
opts.validate()

match = "ParseOptions: quote_char cannot be \\\\r or \\\\n"
with pytest.raises(pa.ArrowInvalid, match=match):
opts = cls()
opts.quote_char = "\n"
opts.validate()

with pytest.raises(pa.ArrowInvalid, match=match):
opts = cls()
opts.quote_char = "\r"
opts.validate()

match = "ParseOptions: escape_char cannot be \\\\r or \\\\n"
with pytest.raises(pa.ArrowInvalid, match=match):
opts = cls()
opts.escape_char = "\n"
opts.validate()

with pytest.raises(pa.ArrowInvalid, match=match):
opts = cls()
opts.escape_char = "\r"
opts.validate()


def test_convert_options():
cls = ConvertOptions
Expand Down Expand Up @@ -238,6 +304,14 @@ def test_write_options():
opts = cls(batch_size=9876)
assert opts.batch_size == 9876

opts.validate()

match = "WriteOptions: batch_size must be at least 1: 0"
with pytest.raises(pa.ArrowInvalid, match=match):
opts = cls()
opts.batch_size = 0
opts.validate()


class BaseTestCSVRead:

Expand Down