[SPARK] Improve array handling in SQL
Refactor logic to use absolute field names, as relative names are prone to overwrites
Improve field filtering to allow partial or strict matching

relates #482 #484
costin committed Oct 18, 2015
1 parent 53e10a3 commit 6222363
Showing 12 changed files with 240 additions and 68 deletions.
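For context, a minimal standalone sketch (plain Java, not es-hadoop code) of why relative field names are prone to overwrites when nested fields reuse the same name, and why dotted absolute names stay distinct:

import java.util.LinkedHashMap;
import java.util.Map;

public class RelativeVsAbsoluteNames {
    public static void main(String[] args) {
        // Three distinct fields all carry the relative name "bar" (as in the
        // testMultiFieldsWithSameName document added below); keying state by
        // the relative name silently overwrites two of them.
        Map<String, String> byRelativeName = new LinkedHashMap<String, String>();
        byRelativeName.put("bar", "top-level object");
        byRelativeName.put("bar", "nested object");
        byRelativeName.put("bar", "nested array");
        System.out.println(byRelativeName.size());   // 1, two fields lost

        // The absolute (dotted) names remain unique, so nothing is clobbered.
        Map<String, String> byAbsoluteName = new LinkedHashMap<String, String>();
        byAbsoluteName.put("bar", "top-level object");
        byAbsoluteName.put("bar.bar", "nested object");
        byAbsoluteName.put("bar.bar.bar", "nested array");
        System.out.println(byAbsoluteName.size());   // 3
    }
}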
@@ -61,7 +61,7 @@
import org.junit.runners.Parameterized;
import org.junit.runners.Parameterized.Parameters;

import static org.junit.Assume.assumeFalse;
import static org.junit.Assume.*;

@FixMethodOrder(MethodSorters.NAME_ASCENDING)
@RunWith(Parameterized.class)
@@ -107,6 +107,8 @@ enum NumberType {

void skipChildren();

String absoluteName();

String currentName();

Object currentValue();
@@ -133,6 +135,7 @@ enum NumberType {

byte[] binaryValue();

@Override
void close();

// Fairly experimental methods
@@ -30,6 +30,7 @@
import org.apache.commons.logging.Log;
import org.elasticsearch.hadoop.EsHadoopIllegalArgumentException;
import org.elasticsearch.hadoop.cfg.FieldPresenceValidation;
import org.elasticsearch.hadoop.serialization.FieldType;
import org.elasticsearch.hadoop.serialization.field.FieldFilter;
import org.elasticsearch.hadoop.util.StringUtils;

@@ -126,16 +127,28 @@ private static String removeDoubleBrackets(List col) {
}
return col.toString();
}

public static Field filter(Field field, Collection<String> includes, Collection<String> excludes) {
List<Field> filtered = new ArrayList<Field>();

for (Field fl : field.properties()) {
if (FieldFilter.filter(fl.name(), includes, excludes)) {
filtered.add(fl);
}
}

return new Field(field.name(), field.type(), filtered);
List<Field> filtered = new ArrayList<Field>();

for (Field fl : field.skipHeaders().properties()) {
processField(fl, null, filtered, includes, excludes);
}

return new Field(field.name(), field.type(), filtered);
}

private static void processField(Field field, String parentName, List<Field> filtered, Collection<String> includes, Collection<String> excludes) {
String fieldName = (parentName != null ? parentName + "." + field.name() : field.name());

if (FieldFilter.filter(fieldName, includes, excludes)) {
filtered.add(field);
}

if (FieldType.OBJECT == field.type()) {
for (Field nestedField : field.properties()) {
processField(nestedField, fieldName, filtered, includes, excludes);
}
}
}
}
@@ -33,7 +33,7 @@ public abstract class FieldFilter {
* @param excludes
* @return
*/
public static boolean filter(String path, Collection<String> includes, Collection<String> excludes) {
public static boolean filter(String path, Collection<String> includes, Collection<String> excludes, boolean allowPartialMatches) {
includes = (includes == null ? Collections.<String> emptyList() : includes);
excludes = (excludes == null ? Collections.<String> emptyList() : excludes);

@@ -81,11 +81,15 @@ else if (include.length() > path.length() && include.charAt(path.length()) == '.
}
}

if (pathIsPrefixOfAnInclude || exactIncludeMatch) {
// if match or part of the path
// if match or part of the path (based on the passed param)
if (exactIncludeMatch || (allowPartialMatches && pathIsPrefixOfAnInclude)) {
return true;
}

return false;
}

public static boolean filter(String path, Collection<String> includes, Collection<String> excludes) {
return filter(path, includes, excludes, true);
}
}
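A short usage sketch of the new overload (assuming the semantics shown in the hunk above: an exact include match always passes, while a parent path that is only a prefix of an include passes only when allowPartialMatches is true):

import java.util.Arrays;
import java.util.Collections;
import java.util.List;

import org.elasticsearch.hadoop.serialization.field.FieldFilter;

public class FieldFilterMatchingDemo {
    public static void main(String[] args) {
        List<String> includes = Arrays.asList("bar.bar.bar");
        List<String> excludes = Collections.emptyList();

        // Partial matching (the previous, still-default behaviour): "bar.bar" is a
        // prefix of the include, so it is kept and traversal can reach the leaf.
        System.out.println(FieldFilter.filter("bar.bar", includes, excludes, true));      // true

        // Strict matching: only the include itself qualifies, so the intermediate
        // object path is rejected while the full path still passes.
        System.out.println(FieldFilter.filter("bar.bar", includes, excludes, false));     // false
        System.out.println(FieldFilter.filter("bar.bar.bar", includes, excludes, false)); // true
    }
}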
@@ -20,9 +20,12 @@

import java.io.IOException;
import java.io.InputStream;
import java.util.ArrayList;
import java.util.List;

import org.codehaus.jackson.JsonFactory;
import org.codehaus.jackson.JsonParser;
import org.codehaus.jackson.JsonStreamContext;
import org.codehaus.jackson.JsonToken;
import org.codehaus.jackson.impl.JsonParserBase;
import org.elasticsearch.hadoop.serialization.EsHadoopSerializationException;
@@ -135,6 +138,25 @@ public void skipChildren() {
}
}

@Override
public String absoluteName() {
List<String> tree = new ArrayList<String>();
for (JsonStreamContext ctx = parser.getParsingContext(); ctx != null; ctx = ctx.getParent()) {
if (ctx.inObject()) {
tree.add(ctx.getCurrentName());
}
}
StringBuilder sb = new StringBuilder();
for (int index = tree.size(); index > 0; index--) {
sb.append(tree.get(index - 1));
sb.append(".");
}

// remove the last .
sb.setLength(sb.length() - 1);
return sb.toString();
}

@Override
public String currentName() {
try {
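To illustrate the new absoluteName() contract, a hypothetical standalone helper (using the same Jackson 1.x APIs as the parser above) that rebuilds the dotted path from the parsing-context chain; for {"bar":{"bar":{"bar":1}}} it yields "bar.bar.bar" rather than the ambiguous relative name "bar":

import java.util.ArrayList;
import java.util.List;

import org.codehaus.jackson.JsonFactory;
import org.codehaus.jackson.JsonParser;
import org.codehaus.jackson.JsonStreamContext;
import org.codehaus.jackson.JsonToken;

public class AbsoluteNameSketch {

    // Walk the context chain from the current position up to the root and join
    // the enclosing field names with dots (outermost first).
    static String absoluteName(JsonParser parser) {
        List<String> tree = new ArrayList<String>();
        for (JsonStreamContext ctx = parser.getParsingContext(); ctx != null; ctx = ctx.getParent()) {
            if (ctx.inObject() && ctx.getCurrentName() != null) {
                tree.add(ctx.getCurrentName());
            }
        }
        StringBuilder sb = new StringBuilder();
        for (int index = tree.size(); index > 0; index--) {
            sb.append(tree.get(index - 1)).append('.');
        }
        if (sb.length() > 0) {
            // drop the trailing dot
            sb.setLength(sb.length() - 1);
        }
        return sb.toString();
    }

    public static void main(String[] args) throws Exception {
        JsonParser parser = new JsonFactory().createJsonParser("{\"bar\":{\"bar\":{\"bar\":1}}}");
        while (parser.nextToken() != null) {
            if (parser.getCurrentToken() == JsonToken.VALUE_NUMBER_INT) {
                System.out.println(absoluteName(parser));   // bar.bar.bar
            }
        }
    }
}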
@@ -22,13 +22,11 @@ import java.{lang => jl}
import java.sql.Timestamp
import java.{util => ju}
import java.util.concurrent.TimeUnit
import javax.xml.bind.DatatypeConverter

import scala.collection.JavaConversions.propertiesAsScalaMap
import scala.collection.JavaConverters.asScalaBufferConverter
import scala.collection.JavaConverters.mapAsJavaMapConverter
import scala.collection.Map
import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkConf
import org.apache.spark.SparkContext
import org.apache.spark.SparkException
@@ -42,15 +40,14 @@ import org.apache.spark.sql.types.StructField
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.types.TimestampType
import org.apache.spark.storage.StorageLevel._

import org.elasticsearch.hadoop.EsHadoopIllegalArgumentException
import org.elasticsearch.hadoop.EsHadoopIllegalStateException
import org.elasticsearch.hadoop.cfg.ConfigurationOptions._
import org.elasticsearch.hadoop.cfg.PropertiesSettings
import org.elasticsearch.hadoop.mr.RestUtils
import org.elasticsearch.hadoop.util.StringUtils
import org.elasticsearch.hadoop.util.TestSettings
import org.elasticsearch.hadoop.util.TestUtils
import org.elasticsearch.spark._
import org.elasticsearch.spark.cfg._
import org.elasticsearch.spark.sql._
import org.elasticsearch.spark.sql.api.java.JavaEsSparkSQL
import org.elasticsearch.spark.sql.sqlContextFunctions
@@ -64,11 +61,13 @@ import org.junit.BeforeClass
import org.junit.FixMethodOrder
import org.junit.Test
import org.junit.runner.RunWith
import org.junit.runners.MethodSorters
import org.junit.runners.Parameterized
import org.junit.runners.Parameterized.Parameters
import org.junit.runners.MethodSorters
import com.esotericsoftware.kryo.io.{Input => KryoInput}
import com.esotericsoftware.kryo.io.{Output => KryoOutput}
import javax.xml.bind.DatatypeConverter
import org.elasticsearch.hadoop.EsHadoopIllegalArgumentException

object AbstractScalaEsScalaSparkSQL {
@transient val conf = new SparkConf().set("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
@@ -170,8 +169,8 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
}

@Test
def testArrayMapping() {
val mapping = """{ "array-mapping": {
def testArrayMappingFirstLevel() {
val mapping = """{ "array-mapping-top-level": {
| "properties" : {
| "arr" : {
| "properties" : {
@@ -185,7 +184,7 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
}""".stripMargin

val index = wrapIndex("sparksql-test")
val indexAndType = s"$index/array-mapping"
val indexAndType = s"$index/array-mapping-top-level"
RestUtils.touch(index)
RestUtils.putMapping(indexAndType, mapping.getBytes(StringUtils.UTF_8))

@@ -203,6 +202,87 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
assertEquals(1, df.count())
}

@Test
def testMultiFieldsWithSameName {
val index = wrapIndex("sparksql-test")
val indexAndType = s"$index/array-mapping-nested"
RestUtils.touch(index)

// add some data
val jsonDoc = """{
| "bar" : {
| "bar" : {
| "bar" : [{
| "bar" : 1
| }, {
| "bar" : 2
| }
| ],
| "level" : 2,
| "level3" : true
| },
| "foo" : 10,
| "level" : 1,
| "level2" : 2
| },
| "foo" : "text",
| "level" : 0,
| "level1" : "string"
|}
""".stripMargin
RestUtils.postData(indexAndType, jsonDoc.getBytes(StringUtils.UTF_8))
RestUtils.refresh(index)

val newCfg = collection.mutable.Map(cfg.toSeq: _*) += ("es.field.read.as.array.include" -> "bar.bar.bar", "es.resource" -> indexAndType)
val cfgSettings = new SparkSettingsManager().load(sc.getConf).copy().merge(newCfg.asJava)
val schema = SchemaUtilsTestable.discoverMapping(cfgSettings)
val mapping = SchemaUtilsTestable.rowInfo(cfgSettings)

val df = sqc.read.options(newCfg).format("org.elasticsearch.spark.sql").load(indexAndType)
df.printSchema()
df.take(1).foreach(println)
assertEquals(1, df.count())
}

@Test
def testNestedFieldArray {
val index = wrapIndex("sparksql-test")
val indexAndType = s"$index/nested-same-name-fields"
RestUtils.touch(index)

// add some data
val jsonDoc = """{"foo" : 5, "nested": { "bar" : [{"date":"2015-01-01", "age":20},{"date":"2015-01-01", "age":20}], "what": "now" } }"""
sc.makeRDD(Seq(jsonDoc)).saveJsonToEs(indexAndType)
RestUtils.refresh(index)

val newCfg = collection.mutable.Map(cfg.toSeq: _*) += ("es.field.read.as.array.include" -> "nested.bar")

val df = sqc.read.options(newCfg).format("org.elasticsearch.spark.sql").load(indexAndType)
df.printSchema()
df.take(1).foreach(println)
assertEquals(1, df.count())
}

@Test
def testArrayValue {
val index = wrapIndex("sparksql-test")
val indexAndType = s"$index/array-value"
RestUtils.touch(index)

// add some data
val jsonDoc = """{"array" : [1, 2, 4, 5] }"""
sc.makeRDD(Seq(jsonDoc)).saveJsonToEs(indexAndType)
RestUtils.refresh(index)

val newCfg = collection.mutable.Map(cfg.toSeq: _*) += ("es.field.read.as.array.include" -> "array")

val df = sqc.read.options(newCfg).format("org.elasticsearch.spark.sql").load(indexAndType)
df.printSchema()
df.take(1).foreach(println)
assertEquals(1, df.count())
}


@Test
def testBasicRead() {
val dataFrame = artistsAsDataFrame
@@ -224,6 +304,14 @@ class AbstractScalaEsScalaSparkSQL(prefix: String, readMetadata: jl.Boolean, pus
assertThat(RestUtils.get(target + "/_search?"), containsString("345"))
}

@Test
def testEsDataFrame1WriteCount() {
val target = wrapIndex("sparksql-test/scala-basic-write")

val dataFrame = sqc.esDF(target, cfg)
assertEquals(345, dataFrame.count())
}

@Test
def testEsDataFrame1WriteWithMapping() {
val dataFrame = artistsAsDataFrame
@@ -0,0 +1,14 @@
package org.elasticsearch.spark.sql

import org.elasticsearch.hadoop.cfg.Settings

object SchemaUtilsTestable {

def discoverMapping(cfg: Settings) = SchemaUtils.discoverMapping(cfg)

def rowInfo(cfg: Settings) = {
val schema = SchemaUtils.discoverMapping(cfg)
SchemaUtils.setRowInfo(cfg, schema.struct)
SchemaUtils.getRowInfo(cfg)
}
}
@@ -4,7 +4,9 @@ import java.util.Locale
import scala.None
import scala.Null
import scala.collection.JavaConverters.mapAsJavaMapConverter
import scala.collection.mutable.ArrayOps
import scala.collection.mutable.LinkedHashMap
import scala.collection.mutable.LinkedHashSet
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Row
@@ -46,9 +48,7 @@ import org.elasticsearch.hadoop.util.IOUtils
import org.elasticsearch.hadoop.util.StringUtils
import org.elasticsearch.spark.cfg.SparkSettingsManager
import org.elasticsearch.spark.serialization.ScalaValueWriter
import scala.collection.mutable.ListBuffer
import scala.collection.mutable.LinkedHashSet
import scala.collection.mutable.ArrayOps
import org.apache.commons.logging.LogFactory

private[sql] class DefaultSource extends RelationProvider with SchemaRelationProvider with CreatableRelationProvider {

Expand Down Expand Up @@ -122,7 +122,7 @@ private[sql] case class ElasticsearchRelation(parameters: Map[String, String], @
}

if (filters != null && filters.size > 0 && Utils.isPushDown(cfg)) {
val log = Utils.logger("org.elasticsearch.spark.sql.DataSource")
val log = logger
if (log.isDebugEnabled()) {
log.debug(s"Pushing down filters ${filters.mkString("[", ",", "]")}")
}
Expand Down Expand Up @@ -297,7 +297,7 @@ private[sql] case class ElasticsearchRelation(parameters: Map[String, String], @

def insert(data: DataFrame, overwrite: Boolean) {
if (overwrite) {
Utils.logger("org.elasticsearch.spark.sql.DataSource").info(s"Overwriting data for ${cfg.getResourceWrite}")
logger.info(s"Overwriting data for ${cfg.getResourceWrite}")

// perform a scan-scroll delete
val cfgCopy = cfg.copy()
@@ -318,4 +318,6 @@ private[sql] case class ElasticsearchRelation(parameters: Map[String, String], @
rr.close()
empty
}

private def logger = LogFactory.getLog("org.elasticsearch.spark.sql.DataSource")
}
