[SPARK-19843] [SQL] UTF8String => (int / long) conversion expensive for invalid inputs #17184

Closed
@@ -850,26 +850,27 @@ public UTF8String translate(Map<Character, Character> dict) {
return fromString(sb.toString());
}

private int getDigit(byte b) {
if (b >= '0' && b <= '9') {
return b - '0';
}
throw new NumberFormatException(toString());
public static class LongWrapper {
Contributor

@tejasapatil can you submit a small follow-up PR to add classdoc for this? It would be great to explain why we have this.

Contributor Author

sure

Contributor Author

Follow-up PR: #17205

public long value = 0;
}
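
The wrapper exists because the new `toLong` reports success through its boolean return value and so needs somewhere else to put the parsed number. A minimal sketch of that out-parameter pattern follows. `OutParamSketch`, `tryParse`, and `sumValid` are hypothetical names for illustration and are not part of this PR; the `LongWrapper` here is a local re-declaration, and the stand-in parser only accepts short, unsigned, all-digit input.

```java
final class OutParamSketch {
  /**
   * Mutable holder for a parsed long (local re-declaration for this sketch).
   * The parser returns a boolean for success or failure, so the value itself
   * is handed back through this object instead of the return slot.
   */
  static final class LongWrapper {
    long value = 0;
  }

  /**
   * Hypothetical stand-in parser: accepts 1 to 18 ASCII digits, so it can
   * never overflow a long. Invalid input yields false, with no exception.
   */
  static boolean tryParse(String s, LongWrapper out) {
    if (s.isEmpty() || s.length() > 18) {
      return false;
    }
    long acc = 0;
    for (int i = 0; i < s.length(); i++) {
      char c = s.charAt(i);
      if (c < '0' || c > '9') {
        return false;
      }
      acc = acc * 10 + (c - '0');
    }
    out.value = acc;
    return true;
  }

  /**
   * Sums the valid rows and counts the invalid ones in invalidCount[0].
   * One wrapper is allocated up front and reused for every row.
   */
  static long sumValid(String[] rows, int[] invalidCount) {
    LongWrapper wrapper = new LongWrapper();
    long sum = 0;
    for (String row : rows) {
      if (tryParse(row, wrapper)) {
        sum += wrapper.value;
      } else {
        invalidCount[0]++;
      }
    }
    return sum;
  }
}
```

With this shape, an invalid row costs a comparison and a branch, whereas the exception-based version had to construct a `NumberFormatException` (including the `toString()` call for its message) on every failure.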

/**
* Parses this UTF8String to long.
*
* Note that, in this method we accumulate the result in negative format, and convert it to
* positive format at the end, if this string is not started with '-'. This is because min value
* is bigger than max value in digits, e.g. Integer.MAX_VALUE is '2147483647' and
* Integer.MIN_VALUE is '-2147483648'.
* is bigger than max value in digits, e.g. Long.MAX_VALUE is '9223372036854775807' and
* Long.MIN_VALUE is '-9223372036854775808'.
*
* This code is mostly copied from LazyLong.parseLong in Hive.
*
* @param toLongResult If a valid `long` was parsed from this UTF8String, then its value would
* be set in `toLongResult`
* @return true if the parsing was successful else false
*/
public long toLong() {
public boolean toLong(LongWrapper toLongResult) {
if (numBytes == 0) {
throw new NumberFormatException("Empty string");
return false;
}

byte b = getByte(0);
@@ -878,7 +879,7 @@ public long toLong() {
if (negative || b == '+') {
offset++;
if (numBytes == 1) {
throw new NumberFormatException(toString());
return false;
}
}

@@ -897,41 +898,52 @@ public long toLong() {
break;
}

int digit = getDigit(b);
int digit;
if (b >= '0' && b <= '9') {
digit = b - '0';
} else {
return false;
}

// We are going to process the new digit and accumulate the result. However, before doing
// this, if the result is already smaller than the stopValue(Long.MIN_VALUE / radix), then
// result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
// result * 10 will definitely be smaller than minValue, and we can stop.
if (result < stopValue) {
throw new NumberFormatException(toString());
return false;
}

result = result * radix - digit;
// Since the previous result is less than or equal to stopValue(Long.MIN_VALUE / radix), we
// can just use `result > 0` to check overflow. If result overflows, we should stop and throw
// exception.
// can just use `result > 0` to check overflow. If result overflows, we should stop.
if (result > 0) {
throw new NumberFormatException(toString());
return false;
}
}

// This is the case when we've encountered a decimal separator. The fractional
// part will not change the number, but we will verify that the fractional part
// is well formed.
while (offset < numBytes) {
if (getDigit(getByte(offset)) == -1) {
throw new NumberFormatException(toString());
byte currentByte = getByte(offset);
if (currentByte < '0' || currentByte > '9') {
return false;
}
offset++;
}

if (!negative) {
result = -result;
if (result < 0) {
throw new NumberFormatException(toString());
return false;
}
}

return result;
toLongResult.value = result;
return true;
}
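
The negative-accumulation trick described in the Javadoc can be tried in isolation. The following is a standalone sketch under stated assumptions, not the code from this PR: `NegativeAccumulationSketch` and `LongHolder` are hypothetical names, the input is a plain `byte[]` rather than a `UTF8String`, and the loop header and the `'.'` separator check are reconstructed from the comments because the diff does not show those lines.

```java
final class NegativeAccumulationSketch {
  /** Mutable result holder, playing the role LongWrapper plays in the diff. */
  static final class LongHolder {
    long value = 0;
  }

  /** Convenience overload for trying the parser on string literals. */
  static boolean parseLong(String s, LongHolder out) {
    return parseLong(s.getBytes(java.nio.charset.StandardCharsets.UTF_8), out);
  }

  /** Returns false on any malformed or out-of-range input instead of throwing. */
  static boolean parseLong(byte[] bytes, LongHolder out) {
    int n = bytes.length;
    if (n == 0) {
      return false;
    }
    boolean negative = bytes[0] == '-';
    int offset = 0;
    if (negative || bytes[0] == '+') {
      offset++;
      if (n == 1) {
        return false; // a lone sign is not a number
      }
    }

    final int radix = 10;
    final long stopValue = Long.MIN_VALUE / radix;
    long result = 0; // accumulated as a negative number

    while (offset < n) {
      byte b = bytes[offset++];
      if (b == '.') {
        break; // assumed separator handling: the fraction is checked below
      }
      if (b < '0' || b > '9') {
        return false;
      }
      int digit = b - '0';
      // Already below MIN_VALUE / 10: multiplying by 10 must overflow.
      if (result < stopValue) {
        return false;
      }
      result = result * radix - digit;
      // A negative accumulator that turns positive has wrapped around.
      if (result > 0) {
        return false;
      }
    }

    // Digits after the separator do not change the value but must be digits.
    while (offset < n) {
      byte b = bytes[offset++];
      if (b < '0' || b > '9') {
        return false;
      }
    }

    if (!negative) {
      result = -result;
      // Only -Long.MIN_VALUE stays negative after negation: out of range.
      if (result < 0) {
        return false;
      }
    }
    out.value = result;
    return true;
  }
}
```

Accumulating downward means `"-9223372036854775808"` fits without special-casing, while `"9223372036854775808"` is rejected by the final sign check because negating `Long.MIN_VALUE` yields `Long.MIN_VALUE` again.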

public static class IntWrapper {
Contributor

and perhaps move this closer to LongWrapper

Contributor

actually why bother having an IntWrapper? Can't you use LongWrapper?

Contributor

Yeah, the calculation still uses int/long respectively; the wrapper is only used for holding the result. There should not be much of a performance penalty in always using LongWrapper.

Contributor Author

Using IntWrapper for integers gives better perf. If we use LongWrapper for integers, a conversion from long -> int would be needed. It's not that big of a difference, but given that ints are used heavily in workloads, I don't want to leave that behind.

Here is the microbenchmark result for the current approach vs. using LongWrapper everywhere:

conversion to int:                       Best/Avg Time(ms)    Rate(M/s)   Per Row(ns)   Relative
------------------------------------------------------------------------------------------------
IntWrapper                                  20397 / 20564         26.3          38.0       1.0X
LongWrapper                                 20855 / 21530         25.7          38.8       1.0X

Contributor

OK, that's fair. In that case let's document this (otherwise somebody might come in and remove it in the future).

public int value = 0;
}
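
For scale, the two Per Row figures quoted in the benchmark (38.0 ns for IntWrapper, 38.8 ns for LongWrapper) differ by roughly 2%. A minimal sketch of that arithmetic, using only those two numbers; the class and method names are hypothetical:

```java
final class BenchmarkGapSketch {
  /** Fractional slowdown of candidate relative to baseline, e.g. 0.02 for 2%. */
  static double relativeSlowdown(double baselineNs, double candidateNs) {
    return (candidateNs - baselineNs) / baselineNs;
  }

  public static void main(String[] args) {
    // Per Row (ns) values as reported in the review comment.
    double gap = relativeSlowdown(38.0, 38.8);
    System.out.printf("LongWrapper is about %.1f%% slower per row%n", gap * 100);
  }
}
```

A gap of this size rounds to the same `1.0X` in the Relative column, which is consistent with the author's remark that the difference is small.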

/**
@@ -946,10 +958,14 @@ public long toLong() {
*
* Note that, this method is almost same as `toLong`, but we leave it duplicated for performance
* reasons, like Hive does.
*
* @param intWrapper If a valid `int` was parsed from this UTF8String, then its value would
* be set in `intWrapper`
* @return true if the parsing was successful else false
*/
public int toInt() {
public boolean toInt(IntWrapper intWrapper) {
if (numBytes == 0) {
throw new NumberFormatException("Empty string");
return false;
}

byte b = getByte(0);
@@ -958,7 +974,7 @@ public int toInt() {
if (negative || b == '+') {
offset++;
if (numBytes == 1) {
throw new NumberFormatException(toString());
return false;
}
}

@@ -977,61 +993,69 @@ public int toInt() {
break;
}

int digit = getDigit(b);
int digit;
if (b >= '0' && b <= '9') {
digit = b - '0';
} else {
return false;
}

// We are going to process the new digit and accumulate the result. However, before doing
// this, if the result is already smaller than the stopValue(Integer.MIN_VALUE / radix), then
// result * 10 will definitely be smaller than minValue, and we can stop and throw exception.
// result * 10 will definitely be smaller than minValue, and we can stop
if (result < stopValue) {
throw new NumberFormatException(toString());
return false;
}

result = result * radix - digit;
// Since the previous result is less than or equal to stopValue(Integer.MIN_VALUE / radix),
// we can just use `result > 0` to check overflow. If result overflows, we should stop and
// throw exception.
// we can just use `result > 0` to check overflow. If result overflows, we should stop
if (result > 0) {
throw new NumberFormatException(toString());
return false;
}
}

// This is the case when we've encountered a decimal separator. The fractional
// part will not change the number, but we will verify that the fractional part
// is well formed.
while (offset < numBytes) {
if (getDigit(getByte(offset)) == -1) {
throw new NumberFormatException(toString());
byte currentByte = getByte(offset);
if (currentByte < '0' || currentByte > '9') {
return false;
}
offset++;
}

if (!negative) {
result = -result;
if (result < 0) {
throw new NumberFormatException(toString());
return false;
}
}

return result;
intWrapper.value = result;
return true;
}

public short toShort() {
int intValue = toInt();
short result = (short) intValue;
if (result != intValue) {
throw new NumberFormatException(toString());
public boolean toShort(IntWrapper intWrapper) {
if (toInt(intWrapper)) {
int intValue = intWrapper.value;
short result = (short) intValue;
if (result == intValue) {
return true;
}
}

return result;
return false;
}

public byte toByte() {
int intValue = toInt();
byte result = (byte) intValue;
if (result != intValue) {
throw new NumberFormatException(toString());
public boolean toByte(IntWrapper intWrapper) {
if (toInt(intWrapper)) {
int intValue = intWrapper.value;
byte result = (byte) intValue;
if (result == intValue) {
return true;
}
}

return result;
return false;
}
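
Both `toShort` and `toByte` rely on the same range test: narrow the parsed int, then compare it with the original, since the cast silently discards the high bits when the value does not fit. A small sketch of that check on its own; `NarrowingCheckSketch`, `fitsInShort`, and `fitsInByte` are hypothetical helper names, not methods from this PR:

```java
final class NarrowingCheckSketch {
  /** True when v survives a round trip through short, i.e. lies in [-32768, 32767]. */
  static boolean fitsInShort(int v) {
    return (short) v == v;
  }

  /** True when v survives a round trip through byte, i.e. lies in [-128, 127]. */
  static boolean fitsInByte(int v) {
    return (byte) v == v;
  }

  public static void main(String[] args) {
    System.out.println(fitsInShort(Short.MAX_VALUE)); // in range
    System.out.println(fitsInShort(3276700));         // out of range, as in testToShort
    System.out.println(fitsInByte(Byte.MIN_VALUE));   // in range
  }
}
```

Note that, as the diff is written, a `false` return from `toShort` or `toByte` caused by the range check leaves the full int in `intWrapper.value`, so callers should read `value` only after a `true` return.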

@Override
@@ -22,9 +22,7 @@
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.*;

import com.google.common.collect.ImmutableMap;
import org.apache.spark.unsafe.Platform;
@@ -608,4 +606,128 @@ public void writeToOutputStreamIntArray() throws IOException {
.writeTo(outputStream);
assertEquals("大千世界", outputStream.toString("UTF-8"));
}

@Test
public void testToShort() throws IOException {
Map<String, Short> inputToExpectedOutput = new HashMap<>();
inputToExpectedOutput.put("1", (short) 1);
inputToExpectedOutput.put("+1", (short) 1);
inputToExpectedOutput.put("-1", (short) -1);
inputToExpectedOutput.put("0", (short) 0);
inputToExpectedOutput.put("1111.12345678901234567890", (short) 1111);
inputToExpectedOutput.put(String.valueOf(Short.MAX_VALUE), Short.MAX_VALUE);
inputToExpectedOutput.put(String.valueOf(Short.MIN_VALUE), Short.MIN_VALUE);

Random rand = new Random();
for (int i = 0; i < 10; i++) {
short value = (short) rand.nextInt();
inputToExpectedOutput.put(String.valueOf(value), value);
}

IntWrapper wrapper = new IntWrapper();
for (Map.Entry<String, Short> entry : inputToExpectedOutput.entrySet()) {
assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toShort(wrapper));
assertEquals((short) entry.getValue(), wrapper.value);
}

List<String> negativeInputs =
Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "3276700");

for (String negativeInput : negativeInputs) {
assertFalse(negativeInput, UTF8String.fromString(negativeInput).toShort(wrapper));
}
}

@Test
public void testToByte() throws IOException {
Map<String, Byte> inputToExpectedOutput = new HashMap<>();
inputToExpectedOutput.put("1", (byte) 1);
inputToExpectedOutput.put("+1",(byte) 1);
inputToExpectedOutput.put("-1", (byte) -1);
inputToExpectedOutput.put("0", (byte) 0);
inputToExpectedOutput.put("111.12345678901234567890", (byte) 111);
inputToExpectedOutput.put(String.valueOf(Byte.MAX_VALUE), Byte.MAX_VALUE);
inputToExpectedOutput.put(String.valueOf(Byte.MIN_VALUE), Byte.MIN_VALUE);

Random rand = new Random();
for (int i = 0; i < 10; i++) {
byte value = (byte) rand.nextInt();
inputToExpectedOutput.put(String.valueOf(value), value);
}

IntWrapper intWrapper = new IntWrapper();
for (Map.Entry<String, Byte> entry : inputToExpectedOutput.entrySet()) {
assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toByte(intWrapper));
assertEquals((byte) entry.getValue(), intWrapper.value);
}

List<String> negativeInputs =
Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890");

for (String negativeInput : negativeInputs) {
assertFalse(negativeInput, UTF8String.fromString(negativeInput).toByte(intWrapper));
}
}

@Test
public void testToInt() throws IOException {
Map<String, Integer> inputToExpectedOutput = new HashMap<>();
inputToExpectedOutput.put("1", 1);
inputToExpectedOutput.put("+1", 1);
inputToExpectedOutput.put("-1", -1);
inputToExpectedOutput.put("0", 0);
inputToExpectedOutput.put("11111.1234567", 11111);
inputToExpectedOutput.put(String.valueOf(Integer.MAX_VALUE), Integer.MAX_VALUE);
inputToExpectedOutput.put(String.valueOf(Integer.MIN_VALUE), Integer.MIN_VALUE);

Random rand = new Random();
for (int i = 0; i < 10; i++) {
int value = rand.nextInt();
inputToExpectedOutput.put(String.valueOf(value), value);
}

IntWrapper intWrapper = new IntWrapper();
for (Map.Entry<String, Integer> entry : inputToExpectedOutput.entrySet()) {
assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toInt(intWrapper));
assertEquals((int) entry.getValue(), intWrapper.value);
}

List<String> negativeInputs =
Arrays.asList("", " ", "null", "NULL", "\n", "~1212121", "12345678901234567890");

for (String negativeInput : negativeInputs) {
assertFalse(negativeInput, UTF8String.fromString(negativeInput).toInt(intWrapper));
}
}

@Test
public void testToLong() throws IOException {
Map<String, Long> inputToExpectedOutput = new HashMap<>();
inputToExpectedOutput.put("1", 1L);
inputToExpectedOutput.put("+1", 1L);
inputToExpectedOutput.put("-1", -1L);
inputToExpectedOutput.put("0", 0L);
inputToExpectedOutput.put("1076753423.12345678901234567890", 1076753423L);
inputToExpectedOutput.put(String.valueOf(Long.MAX_VALUE), Long.MAX_VALUE);
inputToExpectedOutput.put(String.valueOf(Long.MIN_VALUE), Long.MIN_VALUE);

Random rand = new Random();
for (int i = 0; i < 10; i++) {
long value = rand.nextLong();
inputToExpectedOutput.put(String.valueOf(value), value);
}

LongWrapper wrapper = new LongWrapper();
for (Map.Entry<String, Long> entry : inputToExpectedOutput.entrySet()) {
assertTrue(entry.getKey(), UTF8String.fromString(entry.getKey()).toLong(wrapper));
assertEquals((long) entry.getValue(), wrapper.value);
}

List<String> negativeInputs = Arrays.asList("", " ", "null", "NULL", "\n", "~1212121",
"1234567890123456789012345678901234");

for (String negativeInput : negativeInputs) {
assertFalse(negativeInput, UTF8String.fromString(negativeInput).toLong(wrapper));
}
}
}