Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix cast_string_to_float with trailing whitespaces for inf and nan string #2063

Merged
merged 14 commits into from
Jun 5, 2024
Merged
35 changes: 26 additions & 9 deletions src/main/cpp/src/cast_string_to_float.cu
Original file line number Diff line number Diff line change
Expand Up @@ -241,15 +241,24 @@ class string_to_float {
(_warp_lane == 1 && (_c == 'A' || _c == 'a')) ||
(_warp_lane == 2 && (_c == 'N' || _c == 'n')));
if (nan_mask == 0x7) {
// if we start with 'nan', then even if we have other garbage character, this is a null row.
//
// if we're in ansi mode and this is not -precisely- nan, report that so that we can throw
// an exception later.
if (_len != 3) {
_valid = false;
_except = _len != 3;
}
return true;
// if we start with 'nan', then any other garbage characters (excluding whitespace) make
// this a null row. however, for cases such as "nan ", Spark treats the string as "nan":
// when the trailing characters are all whitespace, it is still a valid string. if we're
// in ansi mode and this is not -precisely- nan, report that so that we can throw an exception
// later.

// move forward the current position by 3
_bpos += 3;
_c = __shfl_down_sync(0xffffffff, _c, 3);

// remove the trailing whitespace, if any exists
remove_leading_whitespace();
Feng-Jiang28 marked this conversation as resolved.
Show resolved Hide resolved

// if we're at the end
if (_bpos == _len) { return true; }
// if we reach here, it means that we have other garbage characters.
_valid = false;
_except = true;
}
return false;
}
Expand Down Expand Up @@ -297,11 +306,19 @@ class string_to_float {
_bpos += 5;
// if we're at the end
if (_bpos == _len) { return true; }
_c = __shfl_down_sync(0xffffffff, _c, 5);
}

// remove the remaining whitespace, if any exists
Feng-Jiang28 marked this conversation as resolved.
Show resolved Hide resolved
remove_leading_whitespace();

// if we're at the end
if (_bpos == _len) { return true; }

// if we reach here for any reason, it means we have "inf" or "infinity" at the start of the
// string but also have additional characters, making this whole thing bogus/null
_valid = false;

return true;
}
return false;
Expand Down
50 changes: 50 additions & 0 deletions src/test/java/com/nvidia/spark/rapids/jni/CastStringsTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,56 @@ void castToFloatsTrimTest() {
}
}

/**
 * Casts a string column of "nan" spellings (mixed case, with leading and/or trailing
 * whitespace) to float and compares against the expected column: whitespace-padded
 * variants yield NaN, while "nan" followed by other non-whitespace text yields null.
 */
@Test
void castToFloatNanTest() {
  // Expected values, position-for-position with the input strings below.
  Table.TestBuilder expectedBuilder = new Table.TestBuilder();
  expectedBuilder.column(Float.NaN, Float.NaN, Float.NaN, Float.NaN, Float.NaN, Float.NaN, null, null);

  // Input strings; the last two carry trailing garbage after "nan".
  Table.TestBuilder inputBuilder = new Table.TestBuilder();
  inputBuilder.column("nan", "nan ", " nan ", "NAN", "nAn ", " NAn ", "Nan 0", "nan nan");

  try (Table expected = expectedBuilder.build()) {
    List<ColumnVector> castColumns = new ArrayList<>();
    try (Table input = inputBuilder.build()) {
      int columnCount = input.getNumberOfColumns();
      for (int col = 0; col < columnCount; col++) {
        // NOTE(review): the boolean flag is presumably the ANSI-mode switch - confirm.
        ColumnVector casted =
            CastStrings.toFloat(input.getColumn(col), false, expected.getColumn(col).getType());
        castColumns.add(casted);
      }
      ColumnVector[] castArray = castColumns.toArray(new ColumnVector[castColumns.size()]);
      try (Table actual = new Table(castArray)) {
        AssertUtils.assertTablesAreEqual(expected, actual);
      }
    } finally {
      // Release the cast results whether or not the comparison succeeded.
      for (ColumnVector casted : castColumns) {
        casted.close();
      }
    }
  }
}

/**
 * Casts a string column of "inf" / "infinity" spellings (mixed case, optional sign, with
 * leading and/or trailing whitespace) to float and compares against the expected column:
 * whitespace-padded variants yield the matching signed infinity, while "INFINITY" followed
 * by other non-whitespace text yields null.
 */
@Test
void castToFloatsInfTest(){
  // Input strings: valid infinity spellings plus one entry with trailing garbage.
  Table.TestBuilder tb2 = new Table.TestBuilder();
  tb2.column("INFINITY ", "inf", "+inf ", " -INF ", "INFINITY AND BEYOND", "INF");

  // Expected values, position-for-position with the input strings above.
  Table.TestBuilder tb = new Table.TestBuilder();
  tb.column(Float.POSITIVE_INFINITY, Float.POSITIVE_INFINITY, Float.POSITIVE_INFINITY, Float.NEGATIVE_INFINITY, null, Float.POSITIVE_INFINITY);

  try (Table expected = tb.build()) {
    List<ColumnVector> result = new ArrayList<>();
    try (Table origTable = tb2.build()) {
      for (int i = 0; i < origTable.getNumberOfColumns(); i++) {
        ColumnVector string_col = origTable.getColumn(i);
        // NOTE(review): the boolean flag is presumably the ANSI-mode switch - confirm.
        result.add(CastStrings.toFloat(string_col, false, expected.getColumn(i).getType()));
      }
      // The leftover debug print of the result list was removed; the assertion below
      // already reports any mismatch.
      try (Table result_tbl = new Table(result.toArray(new ColumnVector[result.size()]))) {
        AssertUtils.assertTablesAreEqual(expected, result_tbl);
      }
    } finally {
      // Release the cast results whether or not the comparison succeeded.
      result.forEach(ColumnVector::close);
    }
  }
}

@Test
void castToDecimalTest() {
Table.TestBuilder tb = new Table.TestBuilder();
Expand Down
Loading