Skip to content

Commit

Permalink
Begin merging the UTF8String implementations.
Browse files Browse the repository at this point in the history
  • Loading branch information
JoshRosen committed Apr 22, 2015
1 parent 480a74a commit ab68e08
Show file tree
Hide file tree
Showing 5 changed files with 119 additions and 51 deletions.
5 changes: 5 additions & 0 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@
<artifactId>spark-network-shuffle_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.apache.spark</groupId>
<artifactId>spark-unsafe_${scala.binary.version}</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>net.java.dev.jets3t</groupId>
<artifactId>jets3t</artifactId>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,9 @@ package org.apache.spark.sql.types

import java.util.Arrays

import org.apache.spark.unsafe.PlatformDependent
import org.apache.spark.unsafe.string.{UTF8StringPointer, UTF8StringMethods}

/**
* A UTF-8 String, used as the internal representation of StringType in Spark SQL
*
Expand All @@ -32,45 +35,31 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {

private[this] var bytes: Array[Byte] = _

private val pointer: UTF8StringPointer = new UTF8StringPointer

/**
* Update the UTF8String with String.
*/
def set(str: String): UTF8String = {
bytes = str.getBytes("utf-8")
this
set(str.getBytes("utf-8"))
}

/**
 * Replaces the contents of this UTF8String with the given byte array, which is
 * expected to hold UTF-8 encoded text. The array is stored by reference (it is not
 * copied) and the internal pointer is re-aimed at the array's data.
 *
 * @param bytes UTF-8 encoded bytes that will back this string
 * @return this instance, to allow call chaining
 */
def set(bytes: Array[Byte]): UTF8String = {
  this.bytes = bytes
  val sizeInBytes = bytes.length
  // The pointer addresses the array payload through the platform's byte-array base offset.
  pointer.set(bytes, PlatformDependent.BYTE_ARRAY_OFFSET, sizeInBytes)
  this
}

/**
 * Length in bytes of the UTF-8 sequence whose first byte is `b`.
 *
 * First bytes below 192 (after widening to an unsigned value) count as one byte;
 * anything from 192 upwards is looked up in `UTF8String.bytesOfCodePointInUTF8`.
 *
 * @param b The first byte of a code point
 */
@inline
private[this] def numOfBytes(b: Byte): Int = {
  // Bytes are signed on the JVM; mask to get the 0..255 value.
  val unsigned = b & 0xFF
  if (unsigned < 192) 1 else UTF8String.bytesOfCodePointInUTF8(unsigned - 192)
}

/**
* Return the number of code points in it.
*
* This is only used by Substring() when `start` is negative.
*/
def length(): Int = {
var len = 0
var i: Int = 0
while (i < bytes.length) {
i += numOfBytes(bytes(i))
len += 1
}
len
pointer.getLengthInCodePoints
}

def getBytes: Array[Byte] = {
Expand All @@ -90,12 +79,12 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
var c = 0
var i: Int = 0
while (c < start && i < bytes.length) {
i += numOfBytes(bytes(i))
i += UTF8StringMethods.numOfBytes(bytes(i))
c += 1
}
var j = i
while (c < until && j < bytes.length) {
j += numOfBytes(bytes(j))
j += UTF8StringMethods.numOfBytes(bytes(j))
c += 1
}
UTF8String(Arrays.copyOfRange(bytes, i, j))
Expand Down Expand Up @@ -150,14 +139,14 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
/**
 * Creates a new UTF8String set to this string's byte array. The array itself is
 * passed by reference, so the new instance shares it with this one.
 */
override def clone(): UTF8String = {
  val duplicate = new UTF8String()
  duplicate.set(this.bytes)
}

override def compare(other: UTF8String): Int = {
var i: Int = 0
val b = other.getBytes
while (i < bytes.length && i < b.length) {
val res = bytes(i).compareTo(b(i))
if (res != 0) return res
i += 1
}
bytes.length - b.length
UTF8StringMethods.compare(
pointer.getBaseObject,
pointer.getBaseOffset,
pointer.getLengthInBytes,
other.pointer.getBaseObject,
other.pointer.getBaseOffset,
other.pointer.getLengthInBytes
)
}

override def compareTo(other: UTF8String): Int = {
Expand All @@ -181,14 +170,6 @@ final class UTF8String extends Ordered[UTF8String] with Serializable {
}

object UTF8String {
// Total number of bytes in a UTF-8 sequence for a code point, indexed by
// (first byte - 192); see http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1
private[types] val bytesOfCodePointInUTF8: Array[Int] = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
4, 4, 4, 4, 4, 4, 4, 4,
5, 5, 5, 5,
6, 6, 6, 6)

/**
* Create a UTF-8 String from String
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,44 @@ static long getLengthInBytes(Object baseObject, long baseOffset) {
return PlatformDependent.UNSAFE.getLong(baseObject, baseOffset);
}

public static String toJavaString(Object baseObject, long baseOffset) {
final long lengthInBytes = getLengthInBytes(baseObject, baseOffset);
/**
 * Lexicographically compares two byte sequences, treating every byte as an unsigned
 * value (0..255).
 *
 * Unsigned comparison matters for UTF-8: lead and continuation bytes of non-ASCII
 * characters are >= 0x80, which are negative as signed Java bytes. Comparing them
 * signed would order every non-ASCII character before ASCII ones; comparing them
 * unsigned yields the same order as comparing the strings by code point.
 *
 * @param leftBaseObject base object of the left sequence (may be null for off-heap addresses)
 * @param leftBaseOffset offset (or address) of the first byte of the left sequence
 * @param leftBaseLengthInBytes number of bytes in the left sequence
 * @param rightBaseObject base object of the right sequence
 * @param rightBaseOffset offset (or address) of the first byte of the right sequence
 * @param rightBaseLengthInBytes number of bytes in the right sequence
 * @return a negative value, zero, or a positive value when the left sequence is
 *         respectively less than, equal to, or greater than the right sequence
 */
public static int compare(
    Object leftBaseObject,
    long leftBaseOffset,
    int leftBaseLengthInBytes,
    Object rightBaseObject,
    long rightBaseOffset,
    int rightBaseLengthInBytes) {
  final int commonLength = Math.min(leftBaseLengthInBytes, rightBaseLengthInBytes);
  for (int i = 0; i < commonLength; i++) {
    // Mask to 0..255 so bytes >= 0x80 are not seen as negative numbers.
    final int leftByte =
      PlatformDependent.UNSAFE.getByte(leftBaseObject, leftBaseOffset + i) & 0xFF;
    final int rightByte =
      PlatformDependent.UNSAFE.getByte(rightBaseObject, rightBaseOffset + i) & 0xFF;
    if (leftByte != rightByte) {
      return leftByte - rightByte;
    }
  }
  // All shared bytes are equal: the shorter sequence sorts first.
  return leftBaseLengthInBytes - rightBaseLengthInBytes;
}

/**
 * Counts the code points in a string by hopping from one sequence's first byte to
 * the next, using {@code numOfBytes} to find each sequence's width.
 *
 * This is only used by Substring() when `start` is negative.
 *
 * @param baseObject base object of the string data
 * @param baseOffset offset (or address) of the first byte of the string
 * @param lengthInBytes size of the string in bytes
 * @return the number of code points visited before reaching {@code lengthInBytes}
 */
public static int getLengthInCodePoints(Object baseObject, long baseOffset, int lengthInBytes) {
  int codePoints = 0;
  for (int pos = 0; pos < lengthInBytes; codePoints++) {
    final byte firstByte = PlatformDependent.UNSAFE.getByte(baseObject, baseOffset + pos);
    pos += numOfBytes(firstByte);
  }
  return codePoints;
}

public static String toJavaString(Object baseObject, long baseOffset, int lengthInBytes) {
final byte[] bytes = new byte[(int) lengthInBytes];
PlatformDependent.UNSAFE.copyMemory(
baseObject,
baseOffset + 8, // skip over the length
baseOffset,
bytes,
PlatformDependent.BYTE_ARRAY_OFFSET,
lengthInBytes
Expand All @@ -67,15 +99,40 @@ public static String toJavaString(Object baseObject, long baseOffset) {
public static long createFromJavaString(Object baseObject, long baseOffset, String str) {
final byte[] strBytes = str.getBytes();
final long strLengthInBytes = strBytes.length;
PlatformDependent.UNSAFE.putLong(baseObject, baseOffset, strLengthInBytes);
PlatformDependent.copyMemory(
strBytes,
PlatformDependent.BYTE_ARRAY_OFFSET,
baseObject,
baseOffset + 8,
baseOffset,
strLengthInBytes
);
return (8 + strLengthInBytes);
return strLengthInBytes;
}

/**
 * Length in bytes of the UTF-8 sequence that starts with byte {@code b}.
 *
 * First bytes whose unsigned value is below 192 count as a single byte; larger
 * values are looked up in {@code bytesOfCodePointInUTF8}.
 *
 * @param b The first byte of a code point
 * @return the number of bytes occupied by that code point
 */
public static int numOfBytes(byte b) {
  // Java bytes are signed; mask to recover the 0..255 value.
  final int unsigned = b & 0xFF;
  if (unsigned < 192) {
    return 1;
  }
  return bytesOfCodePointInUTF8[unsigned - 192];
}

/**
 * Total number of bytes in a UTF-8 sequence for a code point, indexed by
 * (first byte - 192). Covers first-byte values 192 through 255.
 * See http://en.wikipedia.org/wiki/UTF-8, 192-256 of Byte 1.
 *
 * Declared final so the shared lookup table reference cannot be reassigned.
 */
private static final int[] bytesOfCodePointInUTF8 = new int[] {
  2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
  2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
  3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3,
  4, 4, 4, 4, 4, 4, 4, 4,
  5, 5, 5, 5,
  6, 6, 6, 6
};

}
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,39 @@

package org.apache.spark.unsafe.string;

import org.apache.spark.unsafe.memory.MemoryLocation;
import javax.annotation.Nullable;

/**
* A pointer to UTF8String data.
*/
public class UTF8StringPointer extends MemoryLocation {
public class UTF8StringPointer {

public long getLengthInBytes() { return UTF8StringMethods.getLengthInBytes(obj, offset); }
@Nullable
protected Object obj;
protected long offset;
protected int lengthInBytes;

public String toJavaString() { return UTF8StringMethods.toJavaString(obj, offset); }
public UTF8StringPointer() { }

public void set(Object obj, long offset, int lengthInBytes) {
this.obj = obj;
this.offset = offset;
this.lengthInBytes = lengthInBytes;
}

public int getLengthInCodePoints() {
return UTF8StringMethods.getLengthInCodePoints(obj, offset, lengthInBytes);
}

public int getLengthInBytes() { return lengthInBytes; }

public Object getBaseObject() { return obj; }

public long getBaseOffset() { return offset; }

public String toJavaString() {
return UTF8StringMethods.toJavaString(obj, offset, lengthInBytes);
}

@Override public String toString() { return toJavaString(); }
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,13 @@ public void toStringTest() {
final byte[] javaStrBytes = javaStr.getBytes();
final int paddedSizeInWords = javaStrBytes.length / 8 + (javaStrBytes.length % 8 == 0 ? 0 : 1);
final MemoryLocation memory = MemoryBlock.fromLongArray(new long[paddedSizeInWords]);
final long bytesWritten =
UTF8StringMethods.createFromJavaString(memory.getBaseObject(), memory.getBaseOffset(), javaStr);
Assert.assertEquals(8 + javaStrBytes.length, bytesWritten);
final long bytesWritten = UTF8StringMethods.createFromJavaString(
memory.getBaseObject(),
memory.getBaseOffset(),
javaStr);
Assert.assertEquals(javaStrBytes.length, bytesWritten);
final UTF8StringPointer utf8String = new UTF8StringPointer();
utf8String.setObjAndOffset(memory.getBaseObject(), memory.getBaseOffset());
utf8String.set(memory.getBaseObject(), memory.getBaseOffset(), bytesWritten);
Assert.assertEquals(javaStr, utf8String.toJavaString());
}
}

0 comments on commit ab68e08

Please sign in to comment.