Skip to content

Commit

Permalink
Typed Parquet Tuple twitter#1198
Browse files Browse the repository at this point in the history
   * Improve tuple converter macro (delete unnecessary boxing)
  • Loading branch information
JiJiTang committed Apr 12, 2015
1 parent 2782288 commit f9af7de
Show file tree
Hide file tree
Showing 5 changed files with 166 additions and 76 deletions.
5 changes: 3 additions & 2 deletions project/Build.scala
Expand Up @@ -33,6 +33,7 @@ object ScaldingBuild extends Build {
val hravenVersion = "0.9.13"
val jacksonVersion = "2.4.2"
val json4SVersion = "3.2.11"
val paradiseVersion = "2.0.1"
val parquetVersion = "1.6.0rc4"
val protobufVersion = "2.4.1"
val quasiquotesVersion = "2.0.1"
Expand Down Expand Up @@ -316,7 +317,7 @@ object ScaldingBuild extends Build {
"org.scala-lang" % "scala-reflect" % scalaVersion,
"com.twitter" %% "bijection-macros" % bijectionVersion
) ++ (if(isScala210x(scalaVersion)) Seq("org.scalamacros" %% "quasiquotes" % quasiquotesVersion) else Seq())
}, addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full))
}, addCompilerPlugin("org.scalamacros" % "paradise" % paradiseVersion cross CrossVersion.full))
.dependsOn(scaldingCore, scaldingHadoopTest)

def scaldingParquetScroogeDeps(version: String) = {
Expand Down Expand Up @@ -429,7 +430,7 @@ object ScaldingBuild extends Build {
"com.twitter" %% "bijection-macros" % bijectionVersion
) ++ (if(isScala210x(scalaVersion)) Seq("org.scalamacros" %% "quasiquotes" % quasiquotesVersion) else Seq())
},
addCompilerPlugin("org.scalamacros" % "paradise" % "2.0.1" cross CrossVersion.full)
addCompilerPlugin("org.scalamacros" % "paradise" % paradiseVersion cross CrossVersion.full)
).dependsOn(scaldingCore, scaldingHadoopTest)

// This one uses a different naming convention
Expand Down
@@ -1,7 +1,7 @@
package com.twitter.scalding.parquet.tuple.macros.impl

import com.twitter.bijection.macros.impl.IsCaseClassImpl
import com.twitter.scalding.parquet.tuple.scheme.ParquetTupleConverter
import com.twitter.scalding.parquet.tuple.scheme._

import scala.reflect.macros.Context

Expand All @@ -15,57 +15,90 @@ object ParquetTupleConverterProvider {
either it is not a case class or this macro call is possibly enclosed in a class.
This will mean the macro is operating on a non-resolved type.""")

def buildGroupConverter(tpe: Type, parentTree: Tree, isOption: Boolean, idx: Int, converterBodyTree: Tree): Tree = {
q"""new _root_.com.twitter.scalding.parquet.tuple.scheme.ParquetTupleConverter($parentTree, $isOption, $idx){
override def newConverter(i: Int): _root_.parquet.io.api.Converter = {
def buildGroupConverter(tpe: Type, parentTree: Tree, isOption: Boolean, idx: Int, converterBodyTree: Tree,
valueBuilder: Tree): Tree = {
q"""new _root_.com.twitter.scalding.parquet.tuple.scheme.ParquetTupleConverter($parentTree){
override def newConverter(i: Int): _root_.com.twitter.scalding.parquet.tuple.scheme.TupleFieldConverter[Any] = {
$converterBodyTree
throw new RuntimeException("invalid index: " + i)
}

override def createValue(): Any = {
if(fieldValues.isEmpty) null
else classOf[$tpe].getConstructors()(0).newInstance(fieldValues.toSeq.map(_.asInstanceOf[AnyRef]): _*)
$valueBuilder
}
}"""
}

def matchField(idx: Int, fieldType: Type, isOption: Boolean): List[Tree] = {
def matchField(idx: Int, fieldType: Type, isOption: Boolean): (Tree, Tree) = {

def createConverter(converter: Tree): Tree = q"if($idx == i) return $converter"

def primitiveFieldValue(converterType: Type): Tree = if (isOption) {
val cachedRes = newTermName(ctx.fresh(s"fieldValue"))
q"""
{
val $cachedRes = converters($idx).asInstanceOf[$converterType]
if($cachedRes.hasValue) Some($cachedRes.currentValue) else _root_.scala.Option.empty[$fieldType]
}
"""
} else {
q"converters($idx).asInstanceOf[$converterType].currentValue"
}

def primitiveFieldConverterAndFieldValue(converterType: Type): (Tree, Tree) = {
val companion = converterType.typeSymbol.companionSymbol
(createConverter(q"$companion(this)"), primitiveFieldValue(converterType))
}

def caseClassFieldValue: Tree = if (isOption) {
val cachedRes = newTermName(ctx.fresh(s"fieldValue"))
q"""
{
val $cachedRes = converters($idx)
if($cachedRes.hasValue) Some($cachedRes.currentValue.asInstanceOf[$fieldType])
else _root_.scala.Option.empty[$fieldType]
}
"""
} else {
q"converters($idx).currentValue.asInstanceOf[$fieldType]"
}

fieldType match {
case tpe if tpe =:= typeOf[String] =>
List(createConverter(q"new _root_.com.twitter.scalding.parquet.tuple.scheme.StringConverter($idx, this, $isOption)"))
primitiveFieldConverterAndFieldValue(typeOf[StringConverter])
case tpe if tpe =:= typeOf[Boolean] =>
List(createConverter(q"new _root_.com.twitter.scalding.parquet.tuple.scheme.BooleanConverter($idx, this, $isOption)"))
primitiveFieldConverterAndFieldValue(typeOf[BooleanConverter])
case tpe if tpe =:= typeOf[Byte] =>
List(createConverter(q"new _root_.com.twitter.scalding.parquet.tuple.scheme.ByteConverter($idx, this, $isOption)"))
primitiveFieldConverterAndFieldValue(typeOf[ByteConverter])
case tpe if tpe =:= typeOf[Short] =>
List(createConverter(q"new _root_.com.twitter.scalding.parquet.tuple.scheme.ShortConverter($idx, this, $isOption)"))
primitiveFieldConverterAndFieldValue(typeOf[ShortConverter])
case tpe if tpe =:= typeOf[Int] =>
List(createConverter(q"new _root_.com.twitter.scalding.parquet.tuple.scheme.IntConverter($idx, this, $isOption)"))
primitiveFieldConverterAndFieldValue(typeOf[IntConverter])
case tpe if tpe =:= typeOf[Long] =>
List(createConverter(q"new _root_.com.twitter.scalding.parquet.tuple.scheme.LongConverter($idx, this, $isOption)"))
primitiveFieldConverterAndFieldValue(typeOf[LongConverter])
case tpe if tpe =:= typeOf[Float] =>
List(createConverter(q"new _root_.com.twitter.scalding.parquet.tuple.scheme.FloatConverter($idx, this, $isOption)"))
primitiveFieldConverterAndFieldValue(typeOf[FloatConverter])
case tpe if tpe =:= typeOf[Double] =>
List(createConverter(q"new _root_.com.twitter.scalding.parquet.tuple.scheme.DoubleConverter($idx, this, $isOption)"))
primitiveFieldConverterAndFieldValue(typeOf[DoubleConverter])
case tpe if tpe.erasure =:= typeOf[Option[Any]] =>
val innerType = tpe.asInstanceOf[TypeRefApi].args.head
matchField(idx, innerType, isOption = true)
case tpe if IsCaseClassImpl.isCaseClassType(ctx)(tpe) =>
val innerConverterTrees = buildConverterBody(tpe, expandMethod(tpe))
List(createConverter(buildGroupConverter(tpe, q"Option(this)", isOption, idx, innerConverterTrees)))
val (innerConverters, innerFieldValues) = expandMethod(tpe).unzip
val innerConverterTree = buildConverterBody(tpe, innerConverters)
val innerValueBuilderTree = buildTupleValue(tpe, innerFieldValues)
val innerGroupConverter = createConverter(buildGroupConverter(tpe, q"Option(this)", isOption, idx, innerConverterTree, innerValueBuilderTree))
(innerGroupConverter, caseClassFieldValue)
case _ => ctx.abort(ctx.enclosingPosition, s"Case class $T is not pure primitives or nested case classes")
}
}

def expandMethod(outerTpe: Type): List[Tree] = {
def expandMethod(outerTpe: Type): List[(Tree, Tree)] = {
outerTpe
.declarations
.collect { case m: MethodSymbol if m.isCaseAccessor => m }
.zipWithIndex
.flatMap {
.map {
case (accessorMethod, idx) =>
val fieldType = accessorMethod.returnType
matchField(idx, fieldType, isOption = false)
Expand All @@ -82,7 +115,16 @@ object ParquetTupleConverterProvider {
}
}

val groupConverter = buildGroupConverter(T.tpe, q"None", isOption = false, -1, buildConverterBody(T.tpe, expandMethod(T.tpe)))
/**
 * Builds the tree that reconstructs a case class instance of `tpe` from the
 * already-extracted per-field value trees, by applying the companion object.
 * Aborts macro expansion when no usable fields were found.
 */
def buildTupleValue(tpe: Type, fieldValueBuilders: List[Tree]): Tree =
  fieldValueBuilders match {
    case Nil =>
      ctx.abort(ctx.enclosingPosition, s"Case class $tpe has no primitive types we were able to extract")
    case fieldValues =>
      q"${tpe.typeSymbol.companionSymbol}(..$fieldValues)"
  }

val (converters, fieldValues) = expandMethod(T.tpe).unzip
val groupConverter = buildGroupConverter(T.tpe, q"None", isOption = false, -1, buildConverterBody(T.tpe, converters),
buildTupleValue(T.tpe, fieldValues))

ctx.Expr[ParquetTupleConverter](q"""
$groupConverter
Expand Down
@@ -1,26 +1,29 @@
package com.twitter.scalding.parquet.tuple.scheme

import parquet.io.api.{ Binary, Converter, GroupConverter, PrimitiveConverter }

import scala.collection.mutable
import scala.util.Try

/**
 * A parquet Converter that accumulates the value of a single tuple field of type T.
 * Covariant in T so a converter for a specific type can be used where a
 * TupleFieldConverter[Any] is expected (see ParquetTupleConverter.converters).
 */
trait TupleFieldConverter[+T] extends Converter {
// The value read for the current record; only meaningful when hasValue is true.
def currentValue: T
// Flipped to true once a value has been written for the current record;
// optional fields with no value for the row leave it false.
var hasValue: Boolean = false
// Clear accumulated state before reading the next record.
def reset(): Unit
}

/**
* Parquet tuple converter used to create user defined tuple value from parquet column values
* @param parent parent parquet tuple converter
* @param isOption is the field optional
* @param outerIndex field index in parent tuple schema
*/
abstract class ParquetTupleConverter(val parent: Option[ParquetTupleConverter] = None, val isOption: Boolean = false,
val outerIndex: Int = -1) extends GroupConverter {
var converters: Map[Int, Converter] = Map()
val fieldValues: mutable.ArrayBuffer[Any] = mutable.ArrayBuffer()
abstract class ParquetTupleConverter(val parent: Option[ParquetTupleConverter] = None) extends GroupConverter
with TupleFieldConverter[Any] {
var converters: Map[Int, TupleFieldConverter[Any]] = Map()

var value: Any = null

var currentValue: Any = null
override def currentValue: Any = value

def createValue(): Any

def newConverter(i: Int): Converter
def newConverter(i: Int): TupleFieldConverter[Any]

override def getConverter(i: Int) = {
val converter = converters.get(i)
Expand All @@ -33,62 +36,107 @@ abstract class ParquetTupleConverter(val parent: Option[ParquetTupleConverter] =
}

override def end(): Unit = {
currentValue = createValue()
fieldValues.remove(0, fieldValues.size)
parent.map(_.addFieldValue(outerIndex, currentValue, isOption))
if (hasValue) {
value = createValue()
parent.map(p => p.hasValue = true)
}
}

override def start(): Unit = ()

def addFieldValue(index: Int, value: Any, isOpt: Boolean) = {
val currentSize = fieldValues.size
// insert None for those optional fields that have no value written for the given row
(currentSize until index).map(fieldValues.insert(_, None))
if (isOpt) fieldValues.insert(index, Option(value)) else fieldValues.insert(index, value)
override def reset(): Unit = {
value = null
converters.values.map(v => v.reset())
hasValue = false
}

override def start(): Unit = reset()
}

class PrimitiveTypeConverter(val index: Int, val parent: ParquetTupleConverter, val isOption: Boolean)
extends PrimitiveConverter {
def appendValue(value: Any) = parent.addFieldValue(index, value, isOption)
/**
 * Shared plumbing for the primitive field converters below: holds the current
 * value unboxed in a typed var (avoiding the old ArrayBuffer[Any] boxing) and
 * propagates the "a value was seen" flag up to the parent tuple converter.
 * Sealed: all implementations live in this file.
 */
sealed trait PrimitiveTupleFieldConverter[T] extends TupleFieldConverter[T] {
// The enclosing tuple converter, notified when this field receives a value.
val parent: ParquetTupleConverter
// Placeholder used before any value is read and after reset().
val defaultValue: T
var value: T = defaultValue

override def currentValue: T = value

// Called by subclasses after storing a value: marks both this field and the
// parent tuple as populated for the current record.
protected def valueAdded(): Unit = {
hasValue = true
parent.hasValue = true
}

override def reset(): Unit = {
value = defaultValue
hasValue = false
}
}

class StringConverter(index: Int, parent: ParquetTupleConverter, isOption: Boolean = false)
extends PrimitiveTypeConverter(index, parent, isOption) {
override def addBinary(value: Binary): Unit = appendValue(value.toStringUsingUTF8)
/**
 * Collects a parquet BINARY field as a String. No sensible sentinel exists for
 * a missing string, so the default stays null and callers must consult
 * hasValue before trusting currentValue.
 */
case class StringConverter(parent: ParquetTupleConverter) extends PrimitiveConverter with PrimitiveTupleFieldConverter[String] {
  override val defaultValue: String = null

  // Parquet hands strings over as raw Binary; decode as UTF-8, then flag the
  // field (and the enclosing tuple) as populated.
  override def addBinary(bytes: Binary): Unit = {
    value = bytes.toStringUsingUTF8
    valueAdded()
  }
}

class DoubleConverter(index: Int, parent: ParquetTupleConverter, isOption: Boolean = false)
extends PrimitiveTypeConverter(index, parent, isOption) {
override def addDouble(value: Double): Unit = appendValue(value)
/** Collects a parquet DOUBLE field; 0D stands in until a real value arrives. */
case class DoubleConverter(parent: ParquetTupleConverter) extends PrimitiveConverter with PrimitiveTupleFieldConverter[Double] {
  override val defaultValue: Double = 0D

  // Store the raw double and mark this field (and the enclosing tuple) populated.
  override def addDouble(d: Double): Unit = {
    value = d
    valueAdded()
  }
}

class FloatConverter(index: Int, parent: ParquetTupleConverter, isOption: Boolean = false)
extends PrimitiveTypeConverter(index, parent, isOption) {
override def addFloat(value: Float): Unit = appendValue(value)
/** Collects a parquet FLOAT field; 0F stands in until a real value arrives. */
case class FloatConverter(parent: ParquetTupleConverter) extends PrimitiveConverter with PrimitiveTupleFieldConverter[Float] {
  override val defaultValue: Float = 0F

  // Store the raw float and mark this field (and the enclosing tuple) populated.
  override def addFloat(f: Float): Unit = {
    value = f
    valueAdded()
  }
}

class LongConverter(index: Int, parent: ParquetTupleConverter, isOption: Boolean = false)
extends PrimitiveTypeConverter(index, parent, isOption) {
override def addLong(value: Long) = appendValue(value)
/** Collects a parquet INT64 field; 0L stands in until a real value arrives. */
case class LongConverter(parent: ParquetTupleConverter) extends PrimitiveConverter with PrimitiveTupleFieldConverter[Long] {
  override val defaultValue: Long = 0L

  // Store the raw long and mark this field (and the enclosing tuple) populated.
  override def addLong(n: Long): Unit = {
    value = n
    valueAdded()
  }
}

class IntConverter(index: Int, parent: ParquetTupleConverter, isOption: Boolean = false)
extends PrimitiveTypeConverter(index, parent, isOption) {
override def addInt(value: Int) = appendValue(value)
/** Collects a parquet INT32 field; 0 stands in until a real value arrives. */
case class IntConverter(parent: ParquetTupleConverter) extends PrimitiveConverter with PrimitiveTupleFieldConverter[Int] {
  override val defaultValue: Int = 0

  // Store the raw int and mark this field (and the enclosing tuple) populated.
  override def addInt(n: Int): Unit = {
    value = n
    valueAdded()
  }
}

class ShortConverter(index: Int, parent: ParquetTupleConverter, isOption: Boolean = false)
extends PrimitiveTypeConverter(index, parent, isOption) {
override def addInt(value: Int) = appendValue(Try(value.toShort).getOrElse(null))
/**
 * Collects a Short stored in a parquet INT32 field (parquet has no 16-bit
 * physical type, so shorts arrive widened to Int); 0 stands in until a real
 * value arrives.
 */
case class ShortConverter(parent: ParquetTupleConverter) extends PrimitiveConverter with PrimitiveTupleFieldConverter[Short] {
  override val defaultValue: Short = 0

  override def addInt(v: Int): Unit = {
    // Int.toShort is a plain two's-complement truncation and can never throw,
    // so the previous Try(v.toShort).getOrElse(0) wrapper was dead code with
    // identical results; out-of-range ints are truncated exactly as before.
    value = v.toShort
    valueAdded()
  }
}

class ByteConverter(index: Int, parent: ParquetTupleConverter, isOption: Boolean = false)
extends PrimitiveTypeConverter(index, parent, isOption) {
override def addInt(value: Int) = appendValue(Try(value.toByte).getOrElse(null))
/**
 * Collects a Byte stored in a parquet INT32 field (parquet has no 8-bit
 * physical type, so bytes arrive widened to Int); 0 stands in until a real
 * value arrives.
 */
case class ByteConverter(parent: ParquetTupleConverter) extends PrimitiveConverter with PrimitiveTupleFieldConverter[Byte] {
  override val defaultValue: Byte = 0

  override def addInt(v: Int): Unit = {
    // Int.toByte is a plain two's-complement truncation and can never throw,
    // so the previous Try(v.toByte).getOrElse(0) wrapper was dead code with
    // identical results; out-of-range ints are truncated exactly as before.
    value = v.toByte
    valueAdded()
  }
}

class BooleanConverter(index: Int, parent: ParquetTupleConverter, isOption: Boolean = false)
extends PrimitiveTypeConverter(index, parent, isOption) {
override def addBoolean(value: Boolean) = appendValue(value)
}
/** Collects a parquet BOOLEAN field; false stands in until a real value arrives. */
case class BooleanConverter(parent: ParquetTupleConverter) extends PrimitiveConverter with PrimitiveTupleFieldConverter[Boolean] {
  override val defaultValue: Boolean = false

  // Store the raw boolean and mark this field (and the enclosing tuple) populated.
  override def addBoolean(b: Boolean): Unit = {
    value = b
    valueAdded()
  }
}
Expand Up @@ -9,8 +9,8 @@ import com.twitter.scalding.typed.TypedPipe
import com.twitter.scalding.{ Args, Job, TypedTsv }
import org.scalatest.{ Matchers, WordSpec }
import parquet.filter2.predicate.FilterApi.binaryColumn
import parquet.filter2.predicate.{FilterApi, FilterPredicate}
import parquet.io.api.{RecordConsumer, Binary}
import parquet.filter2.predicate.{ FilterApi, FilterPredicate }
import parquet.io.api.{ RecordConsumer, Binary }
import parquet.schema.MessageType

class TypedParquetTupleTest extends WordSpec with Matchers with HadoopPlatformTest {
Expand Down Expand Up @@ -77,7 +77,7 @@ class CReadSupport extends ParquetReadSupport[SampleClassC] {

class WriteSupport extends ParquetWriteSupport[SampleClassB] {
override val rootSchema: String = SampleClassB.schema
override def writeRecord(r: SampleClassB, rc: RecordConsumer, schema: MessageType):Unit =
override def writeRecord(r: SampleClassB, rc: RecordConsumer, schema: MessageType): Unit =
Macros.caseClassWriteSupport[SampleClassB](r, rc, schema)
}

Expand Down
@@ -1,8 +1,8 @@
package com.twitter.scalding.parquet.tuple.macros

import org.scalatest.mock.MockitoSugar
import org.scalatest.{Matchers, WordSpec}
import parquet.io.api.{Binary, RecordConsumer}
import org.scalatest.{ Matchers, WordSpec }
import parquet.io.api.{ Binary, RecordConsumer }
import parquet.schema.MessageTypeParser

case class SampleClassA(x: Int, y: String)
Expand Down Expand Up @@ -182,7 +182,6 @@ class MacroUnitTests extends WordSpec with Matchers with MockitoSugar {
}
}


"Macro case class parquet write support generator" should {
"Generate write support for class with all the primitive type fields" in {
val writeSupportFn = Macros.caseClassWriteSupport[SampleClassE]
Expand Down

0 comments on commit f9af7de

Please sign in to comment.