diff --git a/core/pom.xml b/core/pom.xml index e80829b7a7f3d..317fb3bb879af 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -91,6 +91,11 @@ spark-network-shuffle_${scala.binary.version} ${project.version} + + org.apache.spark + spark-unsafe_${scala.binary.version} + ${project.version} + net.java.dev.jets3t jets3t diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala index fc02ba6c9c43e..770d9a5b28be5 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala @@ -19,6 +19,9 @@ package org.apache.spark.sql.types import java.util.Arrays +import org.apache.spark.unsafe.PlatformDependent +import org.apache.spark.unsafe.string.{UTF8StringPointer, UTF8StringMethods} + /** * A UTF-8 String, as internal representation of StringType in SparkSQL * @@ -32,12 +35,13 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { private[this] var bytes: Array[Byte] = _ + private val pointer: UTF8StringPointer = new UTF8StringPointer + /** * Update the UTF8String with String. */ def set(str: String): UTF8String = { - bytes = str.getBytes("utf-8") - this + set(str.getBytes("utf-8")) } /** @@ -45,32 +49,17 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { */ def set(bytes: Array[Byte]): UTF8String = { this.bytes = bytes + pointer.set(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, bytes.length) this } - /** - * Return the number of bytes for a code point with the first byte as `b` - * @param b The first byte of a code point - */ - @inline - private[this] def numOfBytes(b: Byte): Int = { - val offset = (b & 0xFF) - 192 - if (offset >= 0) UTF8String.bytesOfCodePointInUTF8(offset) else 1 - } - /** * Return the number of code points in it. * * This is only used by Substring() when `start` is negative. 
*/ def length(): Int = { - var len = 0 - var i: Int = 0 - while (i < bytes.length) { - i += numOfBytes(bytes(i)) - len += 1 - } - len + pointer.getLengthInCodePoints } def getBytes: Array[Byte] = { @@ -90,12 +79,12 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { var c = 0 var i: Int = 0 while (c < start && i < bytes.length) { - i += numOfBytes(bytes(i)) + i += UTF8StringMethods.numOfBytes(bytes(i)) c += 1 } var j = i while (c < until && j < bytes.length) { - j += numOfBytes(bytes(j)) + j += UTF8StringMethods.numOfBytes(bytes(j)) c += 1 } UTF8String(Arrays.copyOfRange(bytes, i, j)) @@ -150,14 +139,14 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { override def clone(): UTF8String = new UTF8String().set(this.bytes) override def compare(other: UTF8String): Int = { - var i: Int = 0 - val b = other.getBytes - while (i < bytes.length && i < b.length) { - val res = bytes(i).compareTo(b(i)) - if (res != 0) return res - i += 1 - } - bytes.length - b.length + UTF8StringMethods.compare( + pointer.getBaseObject, + pointer.getBaseOffset, + pointer.getLengthInBytes, + other.pointer.getBaseObject, + other.pointer.getBaseOffset, + other.pointer.getLengthInBytes + ) } override def compareTo(other: UTF8String): Int = { @@ -181,14 +170,6 @@ final class UTF8String extends Ordered[UTF8String] with Serializable { } object UTF8String { - // number of tailing bytes in a UTF8 sequence for a code point - // see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1 - private[types] val bytesOfCodePointInUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, - 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, - 4, 4, 4, 4, 4, 4, 4, 4, - 5, 5, 5, 5, - 6, 6, 6, 6) /** * Create a UTF-8 String from String diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringMethods.java index 
48438e975a4e4..cbbc8713597e3 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringMethods.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringMethods.java @@ -40,12 +40,44 @@ static long getLengthInBytes(Object baseObject, long baseOffset) { return PlatformDependent.UNSAFE.getLong(baseObject, baseOffset); } - public static String toJavaString(Object baseObject, long baseOffset) { - final long lengthInBytes = getLengthInBytes(baseObject, baseOffset); + public static int compare( + Object leftBaseObject, + long leftBaseOffset, + int leftBaseLengthInBytes, + Object rightBaseObject, + long rightBaseOffset, + int rightBaseLengthInBytes) { + int i = 0; + while (i < leftBaseLengthInBytes && i < rightBaseLengthInBytes) { + final byte leftByte = PlatformDependent.UNSAFE.getByte(leftBaseObject, leftBaseOffset + i); + final byte rightByte = PlatformDependent.UNSAFE.getByte(rightBaseObject, rightBaseOffset + i); + final int res = leftByte - rightByte; + if (res != 0) return res; + i += 1; + } + return leftBaseLengthInBytes - rightBaseLengthInBytes; + } + + /** + * Return the number of code points in a string. + * + * This is only used by Substring() when `start` is negative. 
+ */ + public static int getLengthInCodePoints(Object baseObject, long baseOffset, int lengthInBytes) { + int len = 0; + int i = 0; + while (i < lengthInBytes) { + i += numOfBytes(PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + i)); + len += 1; + } + return len; + } + + public static String toJavaString(Object baseObject, long baseOffset, int lengthInBytes) { final byte[] bytes = new byte[(int) lengthInBytes]; PlatformDependent.UNSAFE.copyMemory( baseObject, - baseOffset + 8, // skip over the length + baseOffset, bytes, PlatformDependent.BYTE_ARRAY_OFFSET, lengthInBytes @@ -67,15 +99,40 @@ public static String toJavaString(Object baseObject, long baseOffset) { public static long createFromJavaString(Object baseObject, long baseOffset, String str) { final byte[] strBytes = str.getBytes(); final long strLengthInBytes = strBytes.length; - PlatformDependent.UNSAFE.putLong(baseObject, baseOffset, strLengthInBytes); PlatformDependent.copyMemory( strBytes, PlatformDependent.BYTE_ARRAY_OFFSET, baseObject, - baseOffset + 8, + baseOffset, strLengthInBytes ); - return (8 + strLengthInBytes); + return strLengthInBytes; } + /** + * Return the number of bytes for a code point with the first byte as `b` + * @param b The first byte of a code point + */ + public static int numOfBytes(byte b) { + final int offset = (b & 0xFF) - 192; + if (offset >= 0) { + return bytesOfCodePointInUTF8[offset]; + } else { + return 1; + } + } + + /** + * Total number of bytes in a UTF-8 sequence, indexed by (first byte - 192) + * see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1 + */ + private static int[] bytesOfCodePointInUTF8 = new int[] { + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, + 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, + 6, 6, 6, 6 + }; + } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/string/UTF8StringPointer.java
b/unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringPointer.java index 4a43dc16fd613..3d22ad2fa406c 100644 --- a/unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringPointer.java +++ b/unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringPointer.java @@ -17,16 +17,39 @@ package org.apache.spark.unsafe.string; -import org.apache.spark.unsafe.memory.MemoryLocation; +import javax.annotation.Nullable; /** * A pointer to UTF8String data. */ -public class UTF8StringPointer extends MemoryLocation { +public class UTF8StringPointer { - public long getLengthInBytes() { return UTF8StringMethods.getLengthInBytes(obj, offset); } + @Nullable + protected Object obj; + protected long offset; + protected int lengthInBytes; - public String toJavaString() { return UTF8StringMethods.toJavaString(obj, offset); } + public UTF8StringPointer() { } + + public void set(Object obj, long offset, int lengthInBytes) { + this.obj = obj; + this.offset = offset; + this.lengthInBytes = lengthInBytes; + } + + public int getLengthInCodePoints() { + return UTF8StringMethods.getLengthInCodePoints(obj, offset, lengthInBytes); + } + + public int getLengthInBytes() { return lengthInBytes; } + + public Object getBaseObject() { return obj; } + + public long getBaseOffset() { return offset; } + + public String toJavaString() { + return UTF8StringMethods.toJavaString(obj, offset, lengthInBytes); + } @Override public String toString() { return toJavaString(); } } diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/string/TestUTF8String.java b/unsafe/src/test/java/org/apache/spark/unsafe/string/TestUTF8String.java index 1b607163b2b33..189825864ad39 100644 --- a/unsafe/src/test/java/org/apache/spark/unsafe/string/TestUTF8String.java +++ b/unsafe/src/test/java/org/apache/spark/unsafe/string/TestUTF8String.java @@ -32,11 +32,13 @@ public void toStringTest() { final byte[] javaStrBytes = javaStr.getBytes(); final int paddedSizeInWords = javaStrBytes.length / 8 + 
(javaStrBytes.length % 8 == 0 ? 0 : 1); final MemoryLocation memory = MemoryBlock.fromLongArray(new long[paddedSizeInWords]); - final long bytesWritten = - UTF8StringMethods.createFromJavaString(memory.getBaseObject(), memory.getBaseOffset(), javaStr); - Assert.assertEquals(8 + javaStrBytes.length, bytesWritten); + final long bytesWritten = UTF8StringMethods.createFromJavaString( + memory.getBaseObject(), + memory.getBaseOffset(), + javaStr); + Assert.assertEquals(javaStrBytes.length, bytesWritten); final UTF8StringPointer utf8String = new UTF8StringPointer(); - utf8String.setObjAndOffset(memory.getBaseObject(), memory.getBaseOffset()); + utf8String.set(memory.getBaseObject(), memory.getBaseOffset(), (int) bytesWritten); Assert.assertEquals(javaStr, utf8String.toJavaString()); } }