diff --git a/core/pom.xml b/core/pom.xml
index e80829b7a7f3d..317fb3bb879af 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -91,6 +91,11 @@
spark-network-shuffle_${scala.binary.version}
${project.version}
+
+ org.apache.spark
+ spark-unsafe_${scala.binary.version}
+ ${project.version}
+
net.java.dev.jets3t
jets3t
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
index fc02ba6c9c43e..770d9a5b28be5 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/types/UTF8String.scala
@@ -19,6 +19,9 @@ package org.apache.spark.sql.types
import java.util.Arrays
+import org.apache.spark.unsafe.PlatformDependent
+import org.apache.spark.unsafe.string.{UTF8StringPointer, UTF8StringMethods}
+
/**
* A UTF-8 String, as internal representation of StringType in SparkSQL
*
@@ -32,12 +35,13 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
private[this] var bytes: Array[Byte] = _
+ private val pointer: UTF8StringPointer = new UTF8StringPointer
+
/**
* Update the UTF8String with String.
*/
def set(str: String): UTF8String = {
- bytes = str.getBytes("utf-8")
- this
+ set(str.getBytes("utf-8"))
}
/**
@@ -45,32 +49,17 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
*/
def set(bytes: Array[Byte]): UTF8String = {
this.bytes = bytes
+ pointer.set(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, bytes.length)
this
}
- /**
- * Return the number of bytes for a code point with the first byte as `b`
- * @param b The first byte of a code point
- */
- @inline
- private[this] def numOfBytes(b: Byte): Int = {
- val offset = (b & 0xFF) - 192
- if (offset >= 0) UTF8String.bytesOfCodePointInUTF8(offset) else 1
- }
-
/**
* Return the number of code points in it.
*
* This is only used by Substring() when `start` is negative.
*/
def length(): Int = {
- var len = 0
- var i: Int = 0
- while (i < bytes.length) {
- i += numOfBytes(bytes(i))
- len += 1
- }
- len
+ pointer.getLengthInCodePoints
}
def getBytes: Array[Byte] = {
@@ -90,12 +79,12 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
var c = 0
var i: Int = 0
while (c < start && i < bytes.length) {
- i += numOfBytes(bytes(i))
+ i += UTF8StringMethods.numOfBytes(bytes(i))
c += 1
}
var j = i
while (c < until && j < bytes.length) {
- j += numOfBytes(bytes(j))
+ j += UTF8StringMethods.numOfBytes(bytes(j))
c += 1
}
UTF8String(Arrays.copyOfRange(bytes, i, j))
@@ -150,14 +139,14 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
override def clone(): UTF8String = new UTF8String().set(this.bytes)
override def compare(other: UTF8String): Int = {
- var i: Int = 0
- val b = other.getBytes
- while (i < bytes.length && i < b.length) {
- val res = bytes(i).compareTo(b(i))
- if (res != 0) return res
- i += 1
- }
- bytes.length - b.length
+ UTF8StringMethods.compare(
+ pointer.getBaseObject,
+ pointer.getBaseOffset,
+ pointer.getLengthInBytes,
+ other.pointer.getBaseObject,
+ other.pointer.getBaseOffset,
+ other.pointer.getLengthInBytes
+ )
}
override def compareTo(other: UTF8String): Int = {
@@ -181,14 +170,6 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
}
object UTF8String {
- // number of tailing bytes in a UTF8 sequence for a code point
- // see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1
- private[types] val bytesOfCodePointInUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
- 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
- 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
- 4, 4, 4, 4, 4, 4, 4, 4,
- 5, 5, 5, 5,
- 6, 6, 6, 6)
/**
* Create a UTF-8 String from String
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringMethods.java b/unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringMethods.java
index 48438e975a4e4..cbbc8713597e3 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringMethods.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringMethods.java
@@ -40,12 +40,44 @@ static long getLengthInBytes(Object baseObject, long baseOffset) {
return PlatformDependent.UNSAFE.getLong(baseObject, baseOffset);
}
- public static String toJavaString(Object baseObject, long baseOffset) {
- final long lengthInBytes = getLengthInBytes(baseObject, baseOffset);
+ public static int compare(
+ Object leftBaseObject,
+ long leftBaseOffset,
+ int leftBaseLengthInBytes,
+ Object rightBaseObject,
+ long rightBaseOffset,
+ int rightBaseLengthInBytes) {
+ int i = 0;
+ while (i < leftBaseLengthInBytes && i < rightBaseLengthInBytes) {
+ final byte leftByte = PlatformDependent.UNSAFE.getByte(leftBaseObject, leftBaseOffset + i);
+ final byte rightByte = PlatformDependent.UNSAFE.getByte(rightBaseObject, rightBaseOffset + i);
+ final int res = leftByte - rightByte;
+ if (res != 0) return res;
+ i += 1;
+ }
+ return leftBaseLengthInBytes - rightBaseLengthInBytes;
+ }
+
+ /**
+ * Return the number of code points in a string.
+ *
+ * This is only used by Substring() when `start` is negative.
+ */
+ public static int getLengthInCodePoints(Object baseObject, long baseOffset, int lengthInBytes) {
+ int len = 0;
+ int i = 0;
+ while (i < lengthInBytes) {
+ i += numOfBytes(PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + i));
+ len += 1;
+ }
+ return len;
+ }
+
+ public static String toJavaString(Object baseObject, long baseOffset, int lengthInBytes) {
final byte[] bytes = new byte[(int) lengthInBytes];
PlatformDependent.UNSAFE.copyMemory(
baseObject,
- baseOffset + 8, // skip over the length
+ baseOffset,
bytes,
PlatformDependent.BYTE_ARRAY_OFFSET,
lengthInBytes
@@ -67,15 +99,40 @@ public static String toJavaString(Object baseObject, long baseOffset) {
public static long createFromJavaString(Object baseObject, long baseOffset, String str) {
final byte[] strBytes = str.getBytes();
final long strLengthInBytes = strBytes.length;
- PlatformDependent.UNSAFE.putLong(baseObject, baseOffset, strLengthInBytes);
PlatformDependent.copyMemory(
strBytes,
PlatformDependent.BYTE_ARRAY_OFFSET,
baseObject,
- baseOffset + 8,
+ baseOffset,
strLengthInBytes
);
- return (8 + strLengthInBytes);
+ return strLengthInBytes;
}
+ /**
+ * Return the number of bytes for a code point with the first byte as `b`
+ * @param b The first byte of a code point
+ */
+ public static int numOfBytes(byte b) {
+ final int offset = (b & 0xFF) - 192;
+ if (offset >= 0) {
+ return bytesOfCodePointInUTF8[offset];
+ } else {
+ return 1;
+ }
+ }
+
+ /**
+ * number of tailing bytes in a UTF8 sequence for a code point
+ * see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1
+ */
+ private static int[] bytesOfCodePointInUTF8 = new int[] {
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
+ 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
+ 4, 4, 4, 4, 4, 4, 4, 4,
+ 5, 5, 5, 5,
+ 6, 6, 6, 6
+ };
+
}
diff --git a/unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringPointer.java b/unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringPointer.java
index 4a43dc16fd613..3d22ad2fa406c 100644
--- a/unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringPointer.java
+++ b/unsafe/src/main/java/org/apache/spark/unsafe/string/UTF8StringPointer.java
@@ -17,16 +17,39 @@
package org.apache.spark.unsafe.string;
-import org.apache.spark.unsafe.memory.MemoryLocation;
+import javax.annotation.Nullable;
/**
* A pointer to UTF8String data.
*/
-public class UTF8StringPointer extends MemoryLocation {
+public class UTF8StringPointer {
- public long getLengthInBytes() { return UTF8StringMethods.getLengthInBytes(obj, offset); }
+ @Nullable
+ protected Object obj;
+ protected long offset;
+ protected int lengthInBytes;
- public String toJavaString() { return UTF8StringMethods.toJavaString(obj, offset); }
+ public UTF8StringPointer() { }
+
+ public void set(Object obj, long offset, int lengthInBytes) {
+ this.obj = obj;
+ this.offset = offset;
+ this.lengthInBytes = lengthInBytes;
+ }
+
+ public int getLengthInCodePoints() {
+ return UTF8StringMethods.getLengthInCodePoints(obj, offset, lengthInBytes);
+ }
+
+ public int getLengthInBytes() { return lengthInBytes; }
+
+ public Object getBaseObject() { return obj; }
+
+ public long getBaseOffset() { return offset; }
+
+ public String toJavaString() {
+ return UTF8StringMethods.toJavaString(obj, offset, lengthInBytes);
+ }
@Override public String toString() { return toJavaString(); }
}
diff --git a/unsafe/src/test/java/org/apache/spark/unsafe/string/TestUTF8String.java b/unsafe/src/test/java/org/apache/spark/unsafe/string/TestUTF8String.java
index 1b607163b2b33..189825864ad39 100644
--- a/unsafe/src/test/java/org/apache/spark/unsafe/string/TestUTF8String.java
+++ b/unsafe/src/test/java/org/apache/spark/unsafe/string/TestUTF8String.java
@@ -32,11 +32,13 @@ public void toStringTest() {
final byte[] javaStrBytes = javaStr.getBytes();
final int paddedSizeInWords = javaStrBytes.length / 8 + (javaStrBytes.length % 8 == 0 ? 0 : 1);
final MemoryLocation memory = MemoryBlock.fromLongArray(new long[paddedSizeInWords]);
- final long bytesWritten =
- UTF8StringMethods.createFromJavaString(memory.getBaseObject(), memory.getBaseOffset(), javaStr);
- Assert.assertEquals(8 + javaStrBytes.length, bytesWritten);
+ final long bytesWritten = UTF8StringMethods.createFromJavaString(
+ memory.getBaseObject(),
+ memory.getBaseOffset(),
+ javaStr);
+ Assert.assertEquals(javaStrBytes.length, bytesWritten);
final UTF8StringPointer utf8String = new UTF8StringPointer();
- utf8String.setObjAndOffset(memory.getBaseObject(), memory.getBaseOffset());
+ utf8String.set(memory.getBaseObject(), memory.getBaseOffset(), bytesWritten);
Assert.assertEquals(javaStr, utf8String.toJavaString());
}
}