[SPARK] Add field extraction for case classes and JavaBeans
relates #384
costin committed Mar 2, 2015
1 parent 77e348a commit ed7f175
Showing 8 changed files with 106 additions and 92 deletions.
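
In practical terms, this commit lets field extractors (e.g. es.mapping.id) read values out of Scala case classes and JavaBeans instead of only Maps. A minimal usage sketch, assuming elasticsearch-spark is on the classpath and an Elasticsearch node is reachable at localhost:9200; Trip here is an illustrative case class, not one of the test fixtures below:

import org.apache.spark.{SparkConf, SparkContext}
import org.elasticsearch.spark._ // adds saveToEs to RDDs

// Illustrative case class: the `id` field will supply the document id.
case class Trip(id: Integer, departure: String, arrival: String)

object CaseClassWriteSketch {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local").setAppName("estest")
      .set("es.nodes", "localhost:9200") // assumed local Elasticsearch instance
    val sc = new SparkContext(conf)

    val trips = Seq(Trip(1, "OTP", "SFO"), Trip(2, "OTP", "MUC"))
    // With this change the extractor can pull `id` out of the case class,
    // mirroring the es.mapping.id test added in this diff.
    sc.makeRDD(trips).saveToEs("spark-test/scala-basic-write-objects", Map("es.mapping.id" -> "id"))

    sc.stop()
  }
}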
@@ -207,14 +207,13 @@ public BitSet tryFlush() {
// double check data - it might be a false flush (called on clean-up)
if (data.length() > 0) {
bulkResult = client.bulk(resourceW, data);
executedBulkWrite = true;
}
} catch (EsHadoopException ex) {
hadWriteErrors = true;
throw ex;
}

executedBulkWrite = true;

// discard the data buffer, only if it was properly sent/processed
//if (bulkResult.isEmpty()) {
// always discard data since there's no code path that uses the in flight data
3 changes: 3 additions & 0 deletions spark/src/itest/resources/simple.json1
@@ -0,0 +1,3 @@


{ "firstName": "John", "isAlive": true, "age": 25, "children": ["Alex", "Joe"], "address": { "streetAddress": "21 2nd Street" } }
@@ -46,9 +46,10 @@ import org.junit.BeforeClass
import java.awt.Polygon
import org.elasticsearch.spark.rdd.EsSpark
import org.junit.Test
import org.elasticsearch.spark.Bean
import org.elasticsearch.hadoop.serialization.EsHadoopSerializationException
import org.apache.spark.SparkException
import org.elasticsearch.spark.serialization.ReflectionUtils
import org.elasticsearch.spark.serialization.Bean

object AbstractScalaEsScalaSpark {
@transient val conf = new SparkConf().setAll(TestSettings.TESTING_PROPS).setMaster("local").setAppName("estest");
@@ -69,7 +70,7 @@ object AbstractScalaEsScalaSpark {
}
}

case class ModuleCaseClass(departure: String, var arrival: String) {
case class ModuleCaseClass(id: Integer, departure: String, var arrival: String) {
var l = math.Pi
}
}
@@ -110,10 +111,13 @@ class AbstractScalaEsScalaSpark extends Serializable {
def testEsRDDWriteCaseClass() {
val javaBean = new Bean("bar", 1, true)
val caseClass1 = Trip("OTP", "SFO")
val caseClass2 = AbstractScalaEsScalaSpark.ModuleCaseClass("OTP", "MUC")
val caseClass2 = AbstractScalaEsScalaSpark.ModuleCaseClass(1, "OTP", "MUC")

val vals = ReflectionUtils.caseClassValues(caseClass2)

sc.makeRDD(Seq(javaBean, caseClass1)).saveToEs("spark-test/scala-basic-write-objects")
sc.makeRDD(Seq(javaBean, caseClass2)).saveToEs("spark-test/scala-basic-write-objects")
sc.makeRDD(Seq(javaBean, caseClass2)).saveToEs("spark-test/scala-basic-write-objects", Map("es.mapping.id"->"id"))

assertTrue(RestUtils.exists("spark-test/scala-basic-write-objects"))
assertThat(RestUtils.get("spark-test/scala-basic-write-objects/_search?"), containsString(""))
}
@@ -5,38 +5,33 @@ import java.lang.reflect.Method
import scala.reflect.runtime.{ universe => ru }
import scala.reflect.runtime.universe._
import scala.reflect.ClassTag
import scala.collection.mutable.HashMap
import org.apache.commons.logging.LogFactory

private[spark] object ReflectionUtils {

val caseClassCache = new HashMap[Class[_], (Boolean, Iterable[String])]
val javaBeanCache = new HashMap[Class[_], Array[(String, Method)]]

//SI-6240
protected[spark] object ReflectionLock

def javaBeansInfo(clazz: Class[_]) = {
Introspector.getBeanInfo(clazz).getPropertyDescriptors().collect {
case pd if (pd.getName != "class" && pd.getReadMethod() != null) => (pd.getName, pd.getReadMethod)
}.sortBy(_._1)
}

def javaBeansValues(target: AnyRef, info: Array[(String, Method)]) = {
info.map(in => (in._1, in._2.invoke(target))).toMap
}

def isCaseClass(clazz: Class[_]): Boolean = {
private def checkCaseClass(clazz: Class[_]): Boolean = {
ReflectionLock.synchronized {
// reliable case class identifier only happens through class symbols...
runtimeMirror(clazz.getClassLoader()).classSymbol(clazz).isCaseClass
}
}

def caseClassInfo(clazz: Class[_]): Iterable[String] = {
private def doGetCaseClassInfo(clazz: Class[_]): Iterable[String] = {
ReflectionLock.synchronized {
runtimeMirror(clazz.getClassLoader()).classSymbol(clazz).toType.declarations.collect {
case m: MethodSymbol if m.isCaseAccessor => m.name.toString()
}
}
}

def isCaseClassInsideACompanionModule(clazz: Class[_], arity: Int): Boolean = {
private def isCaseClassInsideACompanionModule(clazz: Class[_], arity: Int): Boolean = {
if (!classOf[Serializable].isAssignableFrom(clazz)) {
false
}
@@ -50,17 +45,68 @@ private[spark] object ReflectionUtils {
}

// TODO: this is a hack since we expect the field declaration order to be according to the source but there's no guarantee
def caseClassInfoInsideACompanionModule(clazz: Class[_], arity: Int): Iterable[String] = {
private def caseClassInfoInsideACompanionModule(clazz: Class[_], arity: Int): Iterable[String] = {
// fields are private so use the 'declared' variant
var counter: Int = 0
clazz.getDeclaredFields.collect {
case field if (counter < arity) => counter += 1; field.getName
}
}

def caseClassValues(target: AnyRef, props: Iterable[String]) = {
private def doGetCaseClassValues(target: AnyRef, props: Iterable[String]) = {
val product = target.asInstanceOf[Product].productIterator
val tuples = for (y <- props) yield (y, product.next)
tuples.toMap
}

private def checkCaseClassCache(p: Product) = {
caseClassCache.getOrElseUpdate(p.getClass, {
var isCaseClazz = checkCaseClass(p.getClass)
var info = if (isCaseClazz) doGetCaseClassInfo(p.getClass) else null
if (!isCaseClazz) {
isCaseClazz = isCaseClassInsideACompanionModule(p.getClass, p.productArity)
if (isCaseClazz) {
LogFactory.getLog(classOf[ScalaValueWriter]).warn(
String.format("[%s] is detected as a case class in Java but not in Scala and thus " +
"its properties might be detected incorrectly - make sure the @ScalaSignature is available within the class bytecode " +
"and/or consider moving the case class from its companion object/module", p.getClass))
}
info = if (isCaseClazz) caseClassInfoInsideACompanionModule(p.getClass(), p.productArity) else null
}

(isCaseClazz, info)
})
}

def isCaseClass(p: Product) = {
checkCaseClassCache(p)._1
}

def caseClassValues(p: Product) = {
doGetCaseClassValues(p.asInstanceOf[AnyRef], checkCaseClassCache(p)._2)
}

private def checkJavaBeansCache(o: AnyRef) = {
javaBeanCache.getOrElseUpdate(o.getClass, {
javaBeansInfo(o.getClass)
})
}

def isJavaBean(value: AnyRef) = {
!checkJavaBeansCache(value).isEmpty
}

def javaBeanAsMap(value: AnyRef) = {
javaBeansValues(value, checkJavaBeansCache(value))
}

private def javaBeansInfo(clazz: Class[_]) = {
Introspector.getBeanInfo(clazz).getPropertyDescriptors().collect {
case pd if (pd.getName != "class" && pd.getReadMethod() != null) => (pd.getName, pd.getReadMethod)
}.sortBy(_._1)
}

private def javaBeansValues(target: AnyRef, info: Array[(String, Method)]) = {
info.map(in => (in._1, in._2.invoke(target))).toMap
}
}
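
The case-class half of the refactored ReflectionUtils boils down to pairing the case-accessor names reported by the runtime mirror with productIterator. A standalone sketch of that idea, without the per-class cache and the SI-6240 lock the real code adds; SimpleCaseClass is a local stand-in for the test fixture used further below:

import scala.reflect.runtime.{universe => ru}

// Local stand-in for the SimpleCaseClass test fixture.
case class SimpleCaseClass(i: Int, s: String)

object CaseClassIntrospectionSketch {
  // Case-accessor names via the runtime mirror (decls is the 2.11+ name for declarations).
  def caseClassFields(clazz: Class[_]): Seq[String] =
    ru.runtimeMirror(clazz.getClassLoader).classSymbol(clazz).toType.decls.collect {
      case m: ru.MethodSymbol if m.isCaseAccessor => m.name.toString
    }.toSeq

  // Pair each accessor name with the matching element of productIterator.
  def caseClassValues(p: Product): Map[String, Any] =
    caseClassFields(p.getClass).zip(p.productIterator.toSeq).toMap

  def main(args: Array[String]): Unit = {
    // Prints e.g. Map(i -> 1, s -> simpleClass), matching the assertion in the test diff below.
    println(caseClassValues(SimpleCaseClass(1, "simpleClass")))
  }
}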
@@ -4,15 +4,24 @@ import org.elasticsearch.hadoop.serialization.field.ConstantFieldExtractor
import org.elasticsearch.hadoop.serialization.MapFieldExtractor
import scala.collection.GenMapLike
import scala.collection.Map
import org.elasticsearch.hadoop.serialization.field.FieldExtractor
import org.elasticsearch.hadoop.serialization.field.FieldExtractor._
import org.elasticsearch.spark.serialization.{ ReflectionUtils => RU }

class ScalaMapFieldExtractor extends MapFieldExtractor {

override protected def extractField(target: AnyRef): AnyRef = {
target match {
case m: Map[AnyRef, AnyRef] => m.getOrElse(getFieldName(), FieldExtractor.NOT_FOUND)
case _ => super.extractField(target)
case m: Map[AnyRef, AnyRef] => m.getOrElse(getFieldName, NOT_FOUND)
case p: Product if RU.isCaseClass(p) => RU.caseClassValues(p).getOrElse(getFieldName, NOT_FOUND).asInstanceOf[AnyRef]
case _ => {
val result = super.extractField(target)

if (result == NOT_FOUND && RU.isJavaBean(target)) {
return RU.javaBeanAsMap(target).getOrElse(getFieldName, NOT_FOUND)
}

result
}
}
}

}
}
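
The extractor above now falls through from Scala Maps to case classes and finally to JavaBeans. The JavaBean side rests on java.beans.Introspector, as in the javaBeansInfo helper earlier in this diff; a self-contained sketch, where BeanSketch is an illustrative stand-in for the Bean test fixture (foo/id/bool properties):

import java.beans.Introspector

// Illustrative JavaBean-style class with standard getters/setters.
class BeanSketch(private var foo: String, private var id: Number, private var bool: Boolean) {
  def this() = this(null, null, false)
  def getFoo: String = foo
  def setFoo(v: String): Unit = { foo = v }
  def getId: Number = id
  def setId(v: Number): Unit = { id = v }
  def isBool: Boolean = bool
  def setBool(v: Boolean): Unit = { bool = v }
}

object JavaBeanIntrospectionSketch {
  // Every readable property except "class", as in javaBeansInfo/javaBeansValues above;
  // the real code also sorts and caches the descriptors per class.
  def javaBeanAsMap(target: AnyRef): Map[String, AnyRef] =
    Introspector.getBeanInfo(target.getClass).getPropertyDescriptors.collect {
      case pd if pd.getName != "class" && pd.getReadMethod != null =>
        pd.getName -> pd.getReadMethod.invoke(target)
    }.toMap

  def main(args: Array[String]): Unit = {
    // Prints e.g. Map(bool -> true, foo -> 1, id -> 1)
    println(javaBeanAsMap(new BeanSketch("1", Integer.valueOf(1), true)))
  }
}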
@@ -13,9 +13,6 @@ import org.apache.commons.logging.LogFactory

class ScalaValueWriter(writeUnknownTypes: Boolean = false) extends JdkValueWriter(writeUnknownTypes) {

val caseClassCache = new HashMap[Class[_], (Boolean, Iterable[String])]
val javaBeanCache = new HashMap[Class[_], Array[(String, Method)]]

def this() {
this(false)
}
@@ -58,8 +55,8 @@ class ScalaValueWriter(writeUnknownTypes: Boolean = false) extends JdkValueWrite

case p: Product => {
// handle case class
if (isCaseClass(p)) {
val result = doWrite(caseClassValues(p), generator, false)
if (RU.isCaseClass(p)) {
val result = doWrite(RU.caseClassValues(p), generator, false)
if (!result.isSuccesful()) {
return result
}
@@ -80,8 +77,8 @@ class ScalaValueWriter(writeUnknownTypes: Boolean = false) extends JdkValueWrite
// normal JDK types failed, try the JavaBean last
val result = super.write(value, generator)
if (!result.isSuccesful()) {
if (acceptsJavaBeans && isJavaBean(value)) {
return doWrite(javaBeanAsMap(value), generator, false)
if (acceptsJavaBeans && RU.isJavaBean(value)) {
return doWrite(RU.javaBeanAsMap(value), generator, false)
} else
return result
}
@@ -90,37 +87,4 @@ class ScalaValueWriter(writeUnknownTypes: Boolean = false) extends JdkValueWrite

Result.SUCCESFUL()
}

def isCaseClass(p: Product) = {
caseClassCache.getOrElseUpdate(p.getClass, {
var isCaseClazz = RU.isCaseClass(p.getClass)
var info = if (isCaseClazz) RU.caseClassInfo(p.getClass) else null
if (!isCaseClazz) {
isCaseClazz = RU.isCaseClassInsideACompanionModule(p.getClass, p.productArity)
if (isCaseClazz) {
LogFactory.getLog(classOf[ScalaValueWriter]).warn(
String.format("[%s] is detected as a case class in Java but not in Scala and thus " +
"its properties might be detected incorrectly - make sure the @ScalaSignature is available within the class bytecode " +
"and/or consider moving the case class from its companion object/module", p.getClass))
}
info = if (isCaseClazz) RU.caseClassInfoInsideACompanionModule(p.getClass(), p.productArity) else null
}

(isCaseClazz, info)
})._1
}

def caseClassValues(p: Product) = {
RU.caseClassValues(p.asInstanceOf[AnyRef], caseClassCache.get(p.getClass).get._2)
}

def isJavaBean(value: AnyRef) = {
!javaBeanCache.getOrElseUpdate(value.getClass, {
RU.javaBeansInfo(value.getClass)
}).isEmpty
}

def javaBeanAsMap(value: AnyRef) = {
RU.javaBeansValues(value, javaBeanCache.get(value.getClass()).get)
}
}
@@ -1,18 +1,18 @@
package org.elasticsearch.spark;
package org.elasticsearch.spark.serialization;

import java.io.Serializable;

public class Bean implements Serializable {

private String foo;
private Number bar;
private Number id;
private boolean bool;

public Bean() {}

public Bean(String foo, Number bar, boolean bool) {
this.foo = foo;
this.bar = bar;
this.id = bar;
this.bool = bool;
}
public String getFoo() {
@@ -21,11 +21,11 @@ public String getFoo() {
public void setFoo(String foo) {
this.foo = foo;
}
public Number getBar() {
return bar;
public Number getId() {
return id;
}
public void setBar(Number bar) {
this.bar = bar;
this.id = bar;
}
public boolean isBool() {
return bool;
@@ -1,7 +1,6 @@
package org.elasticsearch.spark
package org.elasticsearch.spark.serialization

import org.elasticsearch.spark.serialization.ReflectionUtils._
import org.junit.BeforeClass
import org.junit.Test
import org.junit.Assert._
import org.hamcrest.Matchers._
@@ -10,34 +9,24 @@ class ScalaReflectionUtilsTest {

@Test
def testJavaBean() {
val info = javaBeansInfo(classOf[Bean])
val values = javaBeansValues(new Bean("1", Integer.valueOf(1), true), info)
val values = javaBeanAsMap(new Bean("1", Integer.valueOf(1), true))
assertEquals(Map("bar" -> 1, "bool" -> true, "foo" -> "1"), values)
}

@Test
def testCaseClassIdentify() {
assertFalse(isCaseClass(classOf[Bean]))
assertTrue(isCaseClass(classOf[SimpleCaseClass]))
assertTrue(isCaseClass(classOf[CaseClassWithValue]))
}

@Test
def testCaseClassValues() {
val cc = SimpleCaseClass(1, "simpleClass")
val info = caseClassInfo(cc.getClass())
assertEquals(Seq("i", "s"), info)
val values = caseClassValues(cc, info)
assertTrue(isCaseClass(cc))
val values = caseClassValues(cc)

println(values)
//println(values)
assertEquals(Map("i" -> 1, "s" -> "simpleClass"), values)

val ccv = CaseClassWithValue(2, "caseClassWithVal")
val infoccv = caseClassInfo(ccv.getClass())
assertEquals(Seq("first", "second"), infoccv)
val valuesccv = caseClassValues(ccv, infoccv)
assertTrue(isCaseClass(ccv))
val valuesccv = caseClassValues(ccv)

println(valuesccv)
//println(valuesccv)
assertEquals(Map("first" -> 2, "second" -> "caseClassWithVal"), valuesccv)
}
}
