Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cast_string_to_float with trailing whitespaces for inf and nan string #2063

Merged
merged 14 commits into from
Jun 5, 2024
Merged
35 changes: 26 additions & 9 deletions src/main/cpp/src/cast_string_to_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -241,15 +241,24 @@ class string_to_float {
(_warp_lane == 1 && (_c == 'A' || _c == 'a')) ||
(_warp_lane == 2 && (_c == 'N' || _c == 'n')));
if (nan_mask == 0x7) {
// if we start with 'nan', then even if we have other garbage character, this is a null row.
//
// if we're in ansi mode and this is not -precisely- nan, report that so that we can throw
// an exception later.
if (_len != 3) {
_valid = false;
_except = _len != 3;
}
return true;
// if we start with 'nan' and it is followed by any garbage characters (other than
// whitespace), this is a null row. however, Spark treats inputs such as "nan " as "nan":
// when the trailing characters are all whitespace, it is still a valid string. if we're
// in ansi mode and this is not -precisely- nan, report that so that we can throw an exception
// later.

// move forward the current position by 3
thirtiseven marked this conversation as resolved.
Show resolved Hide resolved
_bpos += 3;
_c = __shfl_down_sync(0xffffffff, _c, 3);

// remove the trailing whitespace, if any exists
remove_leading_whitespace();
Feng-Jiang28 marked this conversation as resolved.
Show resolved Hide resolved

// if we're at the end
if (_bpos == _len) { return true; }
// if we reach here, it means that we have other garbage characters.
_valid = false;
_except = true;
}
return false;
}
Expand Down Expand Up @@ -295,13 +304,21 @@ class string_to_float {
(_warp_lane == 4 && (_c == 'Y' || _c == 'y')));
if (infinity_mask == 0x1f) {
_bpos += 5;
_c = __shfl_down_sync(0xffffffff, _c, 5);
// if we're at the end
if (_bpos == _len) { return true; }
Feng-Jiang28 marked this conversation as resolved.
Show resolved Hide resolved
}

// remove the remaining whitespace, if any exists
Feng-Jiang28 marked this conversation as resolved.
Show resolved Hide resolved
remove_leading_whitespace();

// if we're at the end
if (_bpos == _len) { return true; }

// if we reach here for any reason, it means we have "inf" or "infinity" at the start of the
// string but also have additional characters, making this whole thing bogus/null
_valid = false;

return true;
}
return false;
Expand Down
54 changes: 54 additions & 0 deletions src/test/java/com/nvidia/spark/rapids/jni/CastStringsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,60 @@ void castToIntegerTest() {
}
}


// Tests CastStrings.toFloat with "inf"/"infinity" input strings, with and without surrounding whitespace
Feng-Jiang28 marked this conversation as resolved.
Show resolved Hide resolved
@Test
void castToFloatsInfTest(){
Feng-Jiang28 marked this conversation as resolved.
Show resolved Hide resolved
// The test data: a Table.TestBuilder column of "inf"/"infinity" strings in varying case, sign, and surrounding whitespace
Table.TestBuilder tb2 = new Table.TestBuilder();
tb2.column("INFINITY ", "inf", "+inf ", " -INF ");

Table.TestBuilder tb = new Table.TestBuilder();
tb.column(Float.POSITIVE_INFINITY, Float.POSITIVE_INFINITY, Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's test double type too.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Some negative tests would be nice. Some strings like INFINITY AND BEYOND.

Copy link
Collaborator Author

@Feng-Jiang28 Feng-Jiang28 May 29, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

INFINITY AND BEYOND added.
along with another test string, INF.
More test cases suggestions are appreciated.


try (Table expected = tb.build()) {
List<ColumnVector> result = new ArrayList<>();
try (Table origTable = tb2.build()) {
for (int i = 0; i < origTable.getNumberOfColumns(); i++) {
ColumnVector string_col = origTable.getColumn(i);
result.add(CastStrings.toFloat(string_col, false, expected.getColumn(i).getType()));
}
try (Table result_tbl = new Table(result.toArray(new ColumnVector[result.size()]))) {
AssertUtils.assertTablesAreEqual(expected, result_tbl);
}
} finally {
result.forEach(ColumnVector::close);
}
}
}

// Tests CastStrings.toFloat with "nan" input strings, with and without surrounding whitespace
@Test
void castToFloatNanTest(){
// The test data: a Table.TestBuilder column of "nan" strings in varying case, with and without surrounding whitespace
Table.TestBuilder tb2 = new Table.TestBuilder();
tb2.column("nan", "nan ", " nan ", "NAN", "nAn ", " NAn ");
hyperbolic2346 marked this conversation as resolved.
Show resolved Hide resolved

Table.TestBuilder tb = new Table.TestBuilder();
tb.column(Float.NaN, Float.NaN, Float.NaN, Float.NaN, Float.NaN, Float.NaN);

try (Table expected = tb.build()) {
List<ColumnVector> result = new ArrayList<>();
try (Table origTable = tb2.build()) {
for (int i = 0; i < origTable.getNumberOfColumns(); i++) {
ColumnVector string_col = origTable.getColumn(i);
result.add(CastStrings.toFloat(string_col, false, expected.getColumn(i).getType()));
}
try (Table result_tbl = new Table(result.toArray(new ColumnVector[result.size()]))) {
AssertUtils.assertTablesAreEqual(expected, result_tbl);
}
} finally {
result.forEach(ColumnVector::close);
}
}
}


@Test
void castToIntegerNoStripTest() {
Table.TestBuilder tb = new Table.TestBuilder();
Expand Down
Loading