Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARK-11149] [SQL] Improve cache performance for primitive types #9145

Closed
wants to merge 7 commits into from
Closed
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ private class CodeFormatter {
private var indentLevel = 0
private val indentSize = 2
private var indentString = ""
private var currentLine = 1

private def addLine(line: String): Unit = {
val indentChange =
Expand All @@ -44,11 +45,13 @@ private class CodeFormatter {
} else {
indentString
}
code.append(f"/* ${currentLine}%03d */ ")
code.append(thisLineIndent)
code.append(line)
code.append("\n")
indentLevel = newIndentLevel
indentString = " " * (indentSize * newIndentLevel)
currentLine += 1
}

private def addLines(code: String): CodeFormatter = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -391,26 +391,24 @@ abstract class CodeGenerator[InType <: AnyRef, OutType <: AnyRef] extends Loggin
classOf[ArrayData].getName,
classOf[UnsafeArrayData].getName,
classOf[MapData].getName,
classOf[UnsafeMapData].getName
classOf[UnsafeMapData].getName,
classOf[MutableRow].getName
))
evaluator.setExtendedClass(classOf[GeneratedClass])

def formatted = CodeFormatter.format(code)
def withLineNums = formatted.split("\n").zipWithIndex.map {
case (l, n) => f"${n + 1}%03d $l"
}.mkString("\n")

logDebug({
// Only add extra debugging info to byte code when we are going to print the source code.
evaluator.setDebuggingInformation(true, true, false)
withLineNums
formatted
})

try {
evaluator.cook("generated.java", code)
} catch {
case e: Exception =>
val msg = s"failed to compile: $e\n$withLineNums"
val msg = s"failed to compile: $e\n$formatted"
logError(msg, e)
throw new Exception(msg, e)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,78 +29,68 @@ class CodeFormatterSuite extends SparkFunSuite {
}

testCase("basic example") {
"""
|class A {
"""class A {
|blahblah;
|}
""".stripMargin
|}""".stripMargin
}{
"""
|class A {
| blahblah;
|}
|/* 001 */ class A {
|/* 002 */ blahblah;
|/* 003 */ }
""".stripMargin
}

testCase("nested example") {
"""
|class A {
"""class A {
| if (c) {
|duh;
|}
|}
""".stripMargin
|}""".stripMargin
} {
"""
|class A {
| if (c) {
| duh;
| }
|}
|/* 001 */ class A {
|/* 002 */ if (c) {
|/* 003 */ duh;
|/* 004 */ }
|/* 005 */ }
""".stripMargin
}

testCase("single line") {
"""
|class A {
"""class A {
| if (c) {duh;}
|}
""".stripMargin
|}""".stripMargin
}{
"""
|class A {
| if (c) {duh;}
|}
|/* 001 */ class A {
|/* 002 */ if (c) {duh;}
|/* 003 */ }
""".stripMargin
}

testCase("if else on the same line") {
"""
|class A {
"""class A {
| if (c) {duh;} else {boo;}
|}
""".stripMargin
|}""".stripMargin
}{
"""
|class A {
| if (c) {duh;} else {boo;}
|}
|/* 001 */ class A {
|/* 002 */ if (c) {duh;} else {boo;}
|/* 003 */ }
""".stripMargin
}

testCase("function calls") {
"""
|foo(
"""foo(
|a,
|b,
|c)
""".stripMargin
|c)""".stripMargin
}{
"""
|foo(
| a,
| b,
| c)
|/* 001 */ foo(
|/* 002 */ a,
|/* 003 */ b,
|/* 004 */ c)
""".stripMargin
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,36 @@ import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.types.UTF8String


/**
* A help class for fast reading Int/Long/Float/Double from ByteBuffer in native order.
*/
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

put a big warning here that this only works with HeapByteBuffer.

object ByteBufferHelper {
def getInt(buffer: ByteBuffer): Int = {
val pos = buffer.position()
buffer.position(pos + 4)
Platform.getInt(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
}

def getLong(buffer: ByteBuffer): Long = {
val pos = buffer.position()
buffer.position(pos + 8)
Platform.getLong(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
}

def getFloat(buffer: ByteBuffer): Float = {
val pos = buffer.position()
buffer.position(pos + 4)
Platform.getFloat(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
}

def getDouble(buffer: ByteBuffer): Double = {
val pos = buffer.position()
buffer.position(pos + 8)
Platform.getDouble(buffer.array(), Platform.BYTE_ARRAY_OFFSET + pos)
}
}

/**
* An abstract class that represents type of a column. Used to append/extract Java objects into/from
* the underlying [[ByteBuffer]] of a column.
Expand Down Expand Up @@ -134,11 +164,11 @@ private[sql] object INT extends NativeColumnType(IntegerType, 4) {
}

override def extract(buffer: ByteBuffer): Int = {
buffer.getInt()
ByteBufferHelper.getInt(buffer)
}

override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
row.setInt(ordinal, buffer.getInt())
row.setInt(ordinal, ByteBufferHelper.getInt(buffer))
}

override def setField(row: MutableRow, ordinal: Int, value: Int): Unit = {
Expand All @@ -163,11 +193,11 @@ private[sql] object LONG extends NativeColumnType(LongType, 8) {
}

override def extract(buffer: ByteBuffer): Long = {
buffer.getLong()
ByteBufferHelper.getLong(buffer)
}

override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
row.setLong(ordinal, buffer.getLong())
row.setLong(ordinal, ByteBufferHelper.getLong(buffer))
}

override def setField(row: MutableRow, ordinal: Int, value: Long): Unit = {
Expand All @@ -191,11 +221,11 @@ private[sql] object FLOAT extends NativeColumnType(FloatType, 4) {
}

override def extract(buffer: ByteBuffer): Float = {
buffer.getFloat()
ByteBufferHelper.getFloat(buffer)
}

override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
row.setFloat(ordinal, buffer.getFloat())
row.setFloat(ordinal, ByteBufferHelper.getFloat(buffer))
}

override def setField(row: MutableRow, ordinal: Int, value: Float): Unit = {
Expand All @@ -219,11 +249,11 @@ private[sql] object DOUBLE extends NativeColumnType(DoubleType, 8) {
}

override def extract(buffer: ByteBuffer): Double = {
buffer.getDouble()
ByteBufferHelper.getDouble(buffer)
}

override def extract(buffer: ByteBuffer, row: MutableRow, ordinal: Int): Unit = {
row.setDouble(ordinal, buffer.getDouble())
row.setDouble(ordinal, ByteBufferHelper.getDouble(buffer))
}

override def setField(row: MutableRow, ordinal: Int, value: Double): Unit = {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Around line 332, there is call to buffer.getShort()

Is it worth adding corresponding method to ByteBufferHelper ?

If so, I can send a PR.

Thanks

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it does not worth it.

Expand Down Expand Up @@ -330,7 +360,7 @@ private[sql] object STRING extends NativeColumnType(StringType, 8) {
}

override def extract(buffer: ByteBuffer): UTF8String = {
val length = buffer.getInt()
val length = ByteBufferHelper.getInt(buffer)
assert(buffer.hasArray)
val base = buffer.array()
val offset = buffer.arrayOffset()
Expand Down Expand Up @@ -358,7 +388,7 @@ private[sql] case class COMPACT_DECIMAL(precision: Int, scale: Int)
extends NativeColumnType(DecimalType(precision, scale), 8) {

override def extract(buffer: ByteBuffer): Decimal = {
Decimal(buffer.getLong(), precision, scale)
Decimal(ByteBufferHelper.getLong(buffer), precision, scale)
}

override def append(v: Decimal, buffer: ByteBuffer): Unit = {
Expand Down Expand Up @@ -396,7 +426,7 @@ private[sql] sealed abstract class ByteArrayColumnType[JvmType](val defaultSize:
}

override def extract(buffer: ByteBuffer): JvmType = {
val length = buffer.getInt()
val length = ByteBufferHelper.getInt(buffer)
val bytes = new Array[Byte](length)
buffer.get(bytes, 0, length)
deserialize(bytes)
Expand Down Expand Up @@ -480,7 +510,7 @@ private[sql] case class STRUCT(dataType: StructType) extends ColumnType[UnsafeRo
}

override def extract(buffer: ByteBuffer): UnsafeRow = {
val sizeInBytes = buffer.getInt()
val sizeInBytes = ByteBufferHelper.getInt(buffer)
assert(buffer.hasArray)
val base = buffer.array()
val offset = buffer.arrayOffset()
Expand Down
Loading