Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

GH-40270: [C++] Use LargeStringArray for casting when writing tables to CSV #40271

Merged
merged 3 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 78 additions & 18 deletions cpp/src/arrow/csv/writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,10 +128,20 @@ class ColumnPopulator {
// Populators are intented to be applied to reasonably small data. In most cases
// threading overhead would not be justified.
ctx.set_use_threads(false);
ASSIGN_OR_RAISE(
std::shared_ptr<Array> casted,
compute::Cast(data, /*to_type=*/utf8(), compute::CastOptions(), &ctx));
casted_array_ = checked_pointer_cast<StringArray>(casted);
if (data.type() && is_large_binary_like(data.type()->id())) {
ASSIGN_OR_RAISE(array_, compute::Cast(data, /*to_type=*/large_utf8(),
compute::CastOptions(), &ctx));
} else {
auto casted = compute::Cast(data, /*to_type=*/utf8(), compute::CastOptions(), &ctx);
pitrou marked this conversation as resolved.
Show resolved Hide resolved
if (casted.ok()) {
array_ = std::move(casted).ValueOrDie();
} else if (casted.status().IsCapacityError()) {
ASSIGN_OR_RAISE(array_, compute::Cast(data, /*to_type=*/large_utf8(),
compute::CastOptions(), &ctx));
} else {
return casted.status();
}
}
return UpdateRowLengths(row_lengths);
}

Expand All @@ -146,7 +156,8 @@ class ColumnPopulator {

protected:
virtual Status UpdateRowLengths(int64_t* row_lengths) = 0;
std::shared_ptr<StringArray> casted_array_;
// It must be a `StringArray` or `LargeStringArray`.
std::shared_ptr<Array> array_;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you just a comment for array_ that it would be a StringArray or LargeStringArray?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed.

const std::string end_chars_;
std::shared_ptr<Buffer> null_string_;

Expand Down Expand Up @@ -181,15 +192,28 @@ class UnquotedColumnPopulator : public ColumnPopulator {
reject_values_with_quotes_(reject_values_with_quotes) {}

Status UpdateRowLengths(int64_t* row_lengths) override {
if (ARROW_PREDICT_TRUE(array_->type_id() == Type::STRING)) {
return UpdateRowLengths<StringArray>(row_lengths);
} else if (ARROW_PREDICT_TRUE(array_->type_id() == Type::LARGE_STRING)) {
return UpdateRowLengths<LargeStringArray>(row_lengths);
} else {
return Status::TypeError("The array must be StringArray or LargeStringArray.");
}
}

template <typename StringArrayType>
Status UpdateRowLengths(int64_t* row_lengths) {
auto casted_array = checked_pointer_cast<StringArrayType>(array_);
if (reject_values_with_quotes_) {
// When working on values that, after casting, could produce quotes,
// we need to return an error in accord with RFC4180.
RETURN_NOT_OK(CheckStringArrayHasNoStructuralChars(*casted_array_, delimiter_));
RETURN_NOT_OK(CheckStringArrayHasNoStructuralChars<StringArrayType>(*casted_array,
delimiter_));
}

int64_t row_number = 0;
VisitArraySpanInline<StringType>(
*casted_array_->data(),
VisitArraySpanInline<typename StringArrayType::TypeClass>(
*casted_array->data(),
[&](std::string_view s) {
row_lengths[row_number] += static_cast<int64_t>(s.length());
row_number++;
Expand All @@ -202,6 +226,17 @@ class UnquotedColumnPopulator : public ColumnPopulator {
}

Status PopulateRows(char* output, int64_t* offsets) const override {
if (ARROW_PREDICT_TRUE(array_->type_id() == Type::STRING)) {
return PopulateRows<StringArray>(output, offsets);
} else if (ARROW_PREDICT_TRUE(array_->type_id() == Type::LARGE_STRING)) {
return PopulateRows<LargeStringArray>(output, offsets);
} else {
return Status::TypeError("The array must be StringArray or LargeStringArray.");
}
}

template <typename StringArrayType>
Status PopulateRows(char* output, int64_t* offsets) const {
// Function applied to valid values cast to string.
auto valid_function = [&](std::string_view s) {
memcpy(output + *offsets, s.data(), s.length());
Expand All @@ -222,13 +257,14 @@ class UnquotedColumnPopulator : public ColumnPopulator {
return Status::OK();
};

return VisitArraySpanInline<StringType>(*casted_array_->data(), valid_function,
null_function);
return VisitArraySpanInline<typename StringArrayType::TypeClass>(
*array_->data(), valid_function, null_function);
}

private:
// Returns an error status if string array has any structural characters.
static Status CheckStringArrayHasNoStructuralChars(const StringArray& array,
template <typename ArrayType>
static Status CheckStringArrayHasNoStructuralChars(const ArrayType& array,
const char delimiter) {
// scan the underlying string array buffer as a single big string
const uint8_t* const data = array.raw_data() + array.value_offset(0);
Expand Down Expand Up @@ -282,14 +318,26 @@ class QuotedColumnPopulator : public ColumnPopulator {
: ColumnPopulator(pool, std::move(end_chars), std::move(null_string)) {}

Status UpdateRowLengths(int64_t* row_lengths) override {
const StringArray& input = *casted_array_;
if (ARROW_PREDICT_TRUE(array_->type_id() == Type::STRING)) {
return UpdateRowLengths<StringArray>(row_lengths);
} else if (ARROW_PREDICT_TRUE(array_->type_id() == Type::LARGE_STRING)) {
return UpdateRowLengths<LargeStringArray>(row_lengths);
} else {
return Status::TypeError("The array must be StringArray or LargeStringArray.");
}
}

template <typename StringArrayType>
Status UpdateRowLengths(int64_t* row_lengths) {
auto casted_array = checked_pointer_cast<StringArrayType>(array_);
const StringArrayType& input = *casted_array;

row_needs_escaping_.resize(casted_array_->length(), false);
row_needs_escaping_.resize(casted_array->length(), false);

if (NoQuoteInArray(input)) {
// fast path if no quote
int row_number = 0;
VisitArraySpanInline<StringType>(
VisitArraySpanInline<typename StringArrayType::TypeClass>(
*input.data(),
[&](std::string_view s) {
row_lengths[row_number] += static_cast<int64_t>(s.length()) + kQuoteCount;
Expand All @@ -301,7 +349,7 @@ class QuotedColumnPopulator : public ColumnPopulator {
});
} else {
int row_number = 0;
VisitArraySpanInline<StringType>(
VisitArraySpanInline<typename StringArrayType::TypeClass>(
*input.data(),
[&](std::string_view s) {
// Each quote in the value string needs to be escaped.
Expand All @@ -320,9 +368,20 @@ class QuotedColumnPopulator : public ColumnPopulator {
}

Status PopulateRows(char* output, int64_t* offsets) const override {
if (ARROW_PREDICT_TRUE(array_->type_id() == Type::STRING)) {
return PopulateRows<StringArray>(output, offsets);
} else if (ARROW_PREDICT_TRUE(array_->type_id() == Type::LARGE_STRING)) {
return PopulateRows<LargeStringArray>(output, offsets);
} else {
return Status::TypeError("The array must be StringArray or LargeStringArray.");
}
}

template <typename StringArrayType>
Status PopulateRows(char* output, int64_t* offsets) const {
auto needs_escaping = row_needs_escaping_.begin();
VisitArraySpanInline<StringType>(
*(casted_array_->data()),
VisitArraySpanInline<typename StringArrayType::TypeClass>(
*array_->data(),
[&](std::string_view s) {
// still needs string content length to be added
char* row = output + *offsets;
Expand Down Expand Up @@ -355,7 +414,8 @@ class QuotedColumnPopulator : public ColumnPopulator {

private:
// Returns true if there's no quote in the string array
static bool NoQuoteInArray(const StringArray& array) {
template <typename StringArrayType>
static bool NoQuoteInArray(const StringArrayType& array) {
const uint8_t* data = array.raw_data() + array.value_offset(0);
const int64_t buffer_size = array.total_values_length();
return std::memchr(data, '"', buffer_size) == nullptr;
Expand Down
3 changes: 2 additions & 1 deletion cpp/src/arrow/csv/writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,8 @@ namespace arrow {
namespace csv {

// Functionality for converting Arrow data to Comma separated value text.
// This library supports all primitive types that can be cast to a StringArrays.
// This library supports all primitive types that can be cast to a StringArray or
// a LargeStringArray.
// It applies to following formatting rules:
// - For non-binary types no quotes surround values. Nulls are represented as the empty
// string.
Expand Down
78 changes: 40 additions & 38 deletions cpp/src/arrow/csv/writer_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -73,17 +73,17 @@ WriteOptions DefaultTestOptions(bool include_header = false,
}

std::string UtilGetExpectedWithEOL(const std::string& eol) {
return std::string("1,,-1,,,,") + eol + // line 1
R"(1,"abc""efg",2324,,,,)" + eol + // line 2
R"(,"abcd",5467,,,,)" + eol + // line 3
R"(,,,,,,)" + eol + // line 4
R"(546,"",517,,,,)" + eol + // line 5
R"(124,"a""""b""",,,,,)" + eol + // line 6
R"(,,,1970-01-01,,,)" + eol + // line 7
R"(,,,,1970-01-02,,)" + eol + // line 8
R"(,,,,,2004-02-29 01:02:03,)" + eol + // line 9
R"(,,,,,,3600)" + eol + // line 10
R"(,"NA",,,,,)" + eol; // line 11
return std::string("1,,-1,,,,,") + eol + // line 1
R"(1,"abc""efg",2324,,,,,)" + eol + // line 2
R"(,"abcd",5467,,,,,"efghi")" + eol + // line 3
R"(,,,,,,,)" + eol + // line 4
R"(546,"",517,,,,,)" + eol + // line 5
R"(124,"a""""b""",,,,,,)" + eol + // line 6
R"(,,,1970-01-01,,,,"jklm")" + eol + // line 7
R"(,,,,1970-01-02,,,)" + eol + // line 8
R"(,,,,,2004-02-29 01:02:03,,)" + eol + // line 9
R"(,,,,,,3600,)" + eol + // line 10
R"(,"NA",,,,,,)" + eol; // line 11
}

std::vector<WriterTestParams> GenerateTestCases() {
Expand All @@ -100,20 +100,22 @@ std::vector<WriterTestParams> GenerateTestCases() {
field("e", date64()),
field("f", timestamp(TimeUnit::SECOND)),
field("g", duration(TimeUnit::SECOND)),
field("h", large_utf8()),
});
auto populated_batch = R"([{"a": 1, "c ": -1},
{ "a": 1, "b\"": "abc\"efg", "c ": 2324},
{ "b\"": "abcd", "c ": 5467},
{ "b\"": "abcd", "c ": 5467, "h": "efghi"},
{ },
{ "a": 546, "b\"": "", "c ": 517 },
{ "a": 124, "b\"": "a\"\"b\"" },
{ "d": 0 },
{ "d": 0, "h": "jklm" },
{ "e": 86400000 },
{ "f": 1078016523 },
{ "g": 3600 },
{ "b\"": "NA" }])";

std::string expected_header = std::string(R"("a","b""","c ","d","e","f","g")") + "\n";
std::string expected_header =
std::string(R"("a","b""","c ","d","e","f","g","h")") + "\n";

// Expected output without header when using default QuotingStyle::Needed.
std::string expected_without_header = UtilGetExpectedWithEOL("\n");
Expand All @@ -122,42 +124,42 @@ std::vector<WriterTestParams> GenerateTestCases() {

// Expected output without header when using QuotingStyle::AllValid.
std::string expected_quoting_style_all_valid =
std::string(R"("1",,"-1",,,,)") + "\n" + // line 1
R"("1","abc""efg","2324",,,,)" + "\n" + // line 2
R"(,"abcd","5467",,,,)" + "\n" + // line 3
R"(,,,,,,)" + "\n" + // line 4
R"("546","","517",,,,)" + "\n" + // line 5
R"("124","a""""b""",,,,,)" + "\n" + // line 6
R"(,,,"1970-01-01",,,)" + "\n" + // line 7
R"(,,,,"1970-01-02",,)" + "\n" + // line 8
R"(,,,,,"2004-02-29 01:02:03",)" + "\n" + // line 9
R"(,,,,,,"3600")" + "\n" + // line 10
R"(,"NA",,,,,)" + "\n"; // line 11
std::string(R"("1",,"-1",,,,,)") + "\n" + // line 1
R"("1","abc""efg","2324",,,,,)" + "\n" + // line 2
R"(,"abcd","5467",,,,,"efghi")" + "\n" + // line 3
R"(,,,,,,,)" + "\n" + // line 4
R"("546","","517",,,,,)" + "\n" + // line 5
R"("124","a""""b""",,,,,,)" + "\n" + // line 6
R"(,,,"1970-01-01",,,,"jklm")" + "\n" + // line 7
R"(,,,,"1970-01-02",,,)" + "\n" + // line 8
R"(,,,,,"2004-02-29 01:02:03",,)" + "\n" + // line 9
R"(,,,,,,"3600",)" + "\n" + // line 10
R"(,"NA",,,,,,)" + "\n"; // line 11

// Batch when testing QuotingStyle::None. The values may not contain any quotes for this
// style according to RFC4180.
auto populated_batch_quoting_style_none = R"([{"a": 1, "c ": -1},
{ "a": 1, "b\"": "abcefg", "c ": 2324},
{ "b\"": "abcd", "c ": 5467},
{ "b\"": "abcd", "c ": 5467, "h": "efghi"},
{ },
{ "a": 546, "b\"": "", "c ": 517 },
{ "a": 124, "b\"": "ab" },
{ "d": 0 },
{ "d": 0, "h": "jklm" },
{ "e": 86400000 },
{ "f": 1078016523 },
{ "g": 3600 }])";
// Expected output for QuotingStyle::None.
std::string expected_quoting_style_none = std::string("1,,-1,,,,") + "\n" + // line 1
R"(1,abcefg,2324,,,,)" + "\n" + // line 2
R"(,abcd,5467,,,,)" + "\n" + // line 3
R"(,,,,,,)" + "\n" + // line 4
R"(546,,517,,,,)" + "\n" + // line 5
R"(124,ab,,,,,)" + "\n" + // line 6
R"(,,,1970-01-01,,,)" + "\n" + // line 7
R"(,,,,1970-01-02,,)" + "\n" + // line 8
R"(,,,,,2004-02-29 01:02:03,)" +
"\n" + // line 9
R"(,,,,,,3600)" + "\n"; // line 10
std::string expected_quoting_style_none = std::string("1,,-1,,,,,") + "\n" + // line 1
R"(1,abcefg,2324,,,,,)" + "\n" + // line 2
R"(,abcd,5467,,,,,efghi)" + "\n" + // line 3
R"(,,,,,,,)" + "\n" + // line 4
R"(546,,517,,,,,)" + "\n" + // line 5
R"(124,ab,,,,,,)" + "\n" + // line 6
R"(,,,1970-01-01,,,,jklm)" + "\n" + // line 7
R"(,,,,1970-01-02,,,)" + "\n" + // line 8
R"(,,,,,2004-02-29 01:02:03,,)" +
"\n" + // line 9
R"(,,,,,,3600,)" + "\n"; // line 10

// Schema and data to test custom null value string.
auto schema_custom_na = schema({field("g", uint64()), field("h", utf8())});
Expand Down
Loading