Skip to content
This repository has been archived by the owner on Nov 20, 2019. It is now read-only.

Commit

Permalink
Merge branch 'master' into feature/postgresql-testAT
Browse files Browse the repository at this point in the history
  • Loading branch information
pmadrigal authored Feb 6, 2017
2 parents af2493b + 6225af2 commit b1bb1d6
Show file tree
Hide file tree
Showing 6 changed files with 275 additions and 191 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,12 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest {
* the types test table.
*/

val arrayFlattenTestColumn: String = "arraystructarraystruct" /* Column used to test flattening of arrays */

//Template: This is the template implementation and shouldn't be modified in any specific test

def doTypesTest(datasourceName: String): Unit = {
for(executionType <- ExecutionType.Spark::ExecutionType.Native::Nil)
datasourceName should s"provide the right types for $executionType execution" in {
assumeEnvironmentIsUpAndRunning
val dframe = sql("SELECT " + typesSet.map(_.colname).mkString(", ") + s" FROM $dataTypesTableName")
for(
(tpe, i) <- typesSet zipWithIndex;
typeCheck <- tpe.typeCheck
) typeCheck(dframe.collect(executionType).head(i))
}

//Multi-level column flat test
protected def multilevelFlattenTests(datasourceName: String): Unit = {
//Multi-level column flatten test

it should "provide flattened column names through the `annotatedCollect` method" in {
val dataFrame = sql("SELECT structofstruct.struct1.structField1 FROM typesCheckTable")
Expand All @@ -81,8 +72,24 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest {
rows.length shouldBe 1
}

it should "be able to flatten whole rows" in {
val dataFrame = sql("SELECT * FROM typesCheckTable")
val rows = dataFrame.flattenedCollect()
val hasComposedTypes = rows.head.schema.fields exists { field =>
field.dataType match {
case _: StructType | _: ArrayType => true
case _ => false
}
}
hasComposedTypes shouldBe false
}
}

protected def arrayFlattenTests(datasourceName: String): Unit = {
//Multi-level column, with nested arrays, flatten test

it should "be able to vertically flatten results for array columns" in {
val dataFrame = sql(s"SELECT arraystructarraystruct FROM typesCheckTable")
val dataFrame = sql(s"SELECT $arrayFlattenTestColumn FROM typesCheckTable")
val res = dataFrame.flattenedCollect()

// No array columns should be found in the result schema
Expand All @@ -100,7 +107,7 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest {
}

it should "correctly apply user limits to a vertically flattened array column" in {
val dataFrame = sql(s"SELECT arraystructarraystruct FROM typesCheckTable LIMIT 1")
val dataFrame = sql(s"SELECT $arrayFlattenTestColumn FROM typesCheckTable LIMIT 1")
val res = dataFrame.flattenedCollect()
res.length shouldBe 1
}
Expand All @@ -113,6 +120,24 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest {

}

def doTypesTest(datasourceName: String): Unit = {

for(executionType <- ExecutionType.Spark::ExecutionType.Native::Nil)
datasourceName should s"provide the right types for $executionType execution" in {
assumeEnvironmentIsUpAndRunning
val dframe = sql("SELECT " + typesSet.map(_.colname).mkString(", ") + s" FROM $dataTypesTableName")
for(
(tpe, i) <- typesSet zipWithIndex;
typeCheck <- tpe.typeCheck
) typeCheck(dframe.collect(executionType).head(i))
}

multilevelFlattenTests(datasourceName)

arrayFlattenTests(datasourceName)

}

abstract override def saveTestData: Unit = {
super.saveTestData
require(saveTypesData > 0, emptyTypesSetError)
Expand Down Expand Up @@ -167,6 +192,7 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest {
}

object SharedXDContextTypesTest {

val dataTypesTableName = "typesCheckTable"
case class SparkSQLColDef(colname: String, sqlType: String, typeCheck: Option[Any => Unit] = None)
object SparkSQLColDef {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Limit, LogicalPlan}
import org.apache.spark.sql.{Row, sources}
import org.apache.spark.sql.sources.CatalystToCrossdataAdapter.{BaseLogicalPlan, FilterReport, ProjectReport, SimpleLogicalPlan, CrossdataExecutionPlan}
import org.apache.spark.sql.sources.CatalystToCrossdataAdapter.{BaseLogicalPlan, CrossdataExecutionPlan, FilterReport, ProjectReport, SimpleLogicalPlan}
import org.apache.spark.sql.sources.{CatalystToCrossdataAdapter, Filter => SourceFilter}
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.types.{StructField, StructType, ArrayType}
import org.elasticsearch.action.search.SearchResponse

import scala.util.{Failure, Try}
Expand Down Expand Up @@ -101,8 +101,6 @@ class ElasticSearchQueryProcessor(val logicalPlan: LogicalPlan, val parameters:
case sources.StringStartsWith(attribute, value) => prefixQuery(attribute, value.toLowerCase)
}

import scala.collection.JavaConversions._

val searchFilters = sFilters.collect {
case sources.EqualTo(attribute, value) => termQuery(attribute, value)
case sources.GreaterThan(attribute, value) => rangeQuery(attribute).from(value).includeLower(false)
Expand Down Expand Up @@ -131,6 +129,7 @@ class ElasticSearchQueryProcessor(val logicalPlan: LogicalPlan, val parameters:
val subDocuments = schemaProvided.toSeq flatMap {
_.fields collect {
case StructField(name, _: StructType, _, _) => name
case StructField(name, ArrayType(_: StructType, _), _, _) => name
}
}
val stringFields: Seq[String] = fields.view map (_.name) filterNot (subDocuments contains _)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ object ElasticSearchRowConverter {
// TODO: Note that if a nested subdocument is targeted, it won't work and this algorithm should be made recursive.
(hitFields.get(name) orElse subDocuments.get(name)).flatMap(Option(_)) map {
((value: Any) => enforceCorrectType(value, schemaMap(name))) compose {
case hitField: SearchHitField => hitField.getValue
case hitField: SearchHitField =>
if(hitField.getValues.size()>1) hitField.getValues
else hitField.getValue
case other => other
}
} orNull
Expand All @@ -85,6 +87,7 @@ object ElasticSearchRowConverter {
case DateType => toDate(value)
case BinaryType => toBinary(value)
case schema: StructType => toRow(value, schema)
case ArrayType(elementType: DataType, _) => toArray(value, elementType)
case _ =>
sys.error(s"Unsupported datatype conversion [${value.getClass}},$desiredType]")
value
Expand Down Expand Up @@ -174,4 +177,9 @@ object ElasticSearchRowConverter {
case _ => sys.error(s"Unsupported datatype conversion [${value.getClass}},Row")
}

def toArray(value: Any, elementType: DataType): Seq[Any] = value match {
case arr: util.ArrayList[Any] =>
arr.toArray.map(enforceCorrectType(_, elementType))
}

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,221 @@
/*
* Copyright (C) 2015 Stratio (http://stratio.com)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.stratio.crossdata.connector.elasticsearch

import java.util.{GregorianCalendar, UUID}

import com.sksamuel.elastic4s.ElasticDsl._
import com.sksamuel.elastic4s.mappings.FieldType._
import com.sksamuel.elastic4s.mappings.{MappingDefinition, TypedFieldDefinition}
import com.stratio.common.utils.components.logger.impl.SparkLoggerComponent
import com.typesafe.config.ConfigFactory
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.crossdata.test.SharedXDContextTypesTest
import org.apache.spark.sql.crossdata.test.SharedXDContextTypesTest.SparkSQLColDef
import org.joda.time.DateTime

trait ElasticDataTypes extends ElasticWithSharedContext
with SharedXDContextTypesTest
with ElasticSearchDataTypesDefaultConstants
with SparkLoggerComponent {

override val dataTypesSparkOptions = Map(
"resource" -> s"$Index/$Type",
"es.nodes" -> s"$ElasticHost",
"es.port" -> s"$ElasticRestPort",
"es.nativePort" -> s"$ElasticNativePort",
"es.cluster" -> s"$ElasticClusterName",
"es.nodes.wan.only" -> "true",
"es.read.field.as.array.include" -> Seq(
"arrayint"
).mkString(",")
)

protected case class ESColumnData(elasticType: Option[TypedFieldDefinition], data: () => Any)
protected object ESColumnData {
def apply(data: () => Any): ESColumnData = ESColumnData(None, data)
def apply(elasticType: TypedFieldDefinition, data: () => Any): ESColumnData = ESColumnData(Some(elasticType), data)
}


override val arrayFlattenTestColumn: String = "arraystruct"

protected val dataTest: Seq[(SparkSQLColDef, ESColumnData)] = Seq(
(SparkSQLColDef("id", "INT", _ shouldBe a[java.lang.Integer]), ESColumnData("id" typed IntegerType, () => 1)),
(SparkSQLColDef("age", "LONG", _ shouldBe a[java.lang.Long]), ESColumnData("age" typed LongType, () => 1)),
(
SparkSQLColDef("description", "STRING", _ shouldBe a[java.lang.String]),
ESColumnData("description" typed StringType, () => "1")
),
(
SparkSQLColDef("name", "STRING", _ shouldBe a[java.lang.String]),
ESColumnData( "name" typed StringType index NotAnalyzed, () => "1")
),
(
SparkSQLColDef("enrolled", "BOOLEAN", _ shouldBe a[java.lang.Boolean]),
ESColumnData("enrolled" typed BooleanType, () => false)
),
(
SparkSQLColDef("birthday", "DATE", _ shouldBe a [java.sql.Date]),
ESColumnData("birthday" typed DateType, () => DateTime.parse(1980 + "-01-01T10:00:00-00:00").toDate)
),
(
SparkSQLColDef("salary", "DOUBLE", _ shouldBe a[java.lang.Double]),
ESColumnData("salary" typed DoubleType, () => 0.15)
),
(
SparkSQLColDef("timecol", "TIMESTAMP", _ shouldBe a[java.sql.Timestamp]),
ESColumnData(
"timecol" typed DateType,
() => new java.sql.Timestamp(new GregorianCalendar(1970, 0, 1, 0, 0, 0).getTimeInMillis)
)
),
(
SparkSQLColDef("float", "FLOAT", _ shouldBe a[java.lang.Float]),
ESColumnData("float" typed FloatType, () => 0.15)
),
(
SparkSQLColDef("binary", "BINARY", x => x.isInstanceOf[Array[Byte]] shouldBe true),
ESColumnData("binary" typed BinaryType, () => "YWE=".getBytes)
),
(
SparkSQLColDef("tinyint", "TINYINT", _ shouldBe a[java.lang.Byte]),
ESColumnData("tinyint" typed ByteType, () => Byte.MinValue)
),
(
SparkSQLColDef("smallint", "SMALLINT", _ shouldBe a[java.lang.Short]),
ESColumnData("smallint" typed ShortType, () => Short.MaxValue)
),
(
SparkSQLColDef("subdocument", "STRUCT<field1: INT>", _ shouldBe a [Row]),
ESColumnData("subdocument" inner ("field1" typed IntegerType), () => Map( "field1" -> 15))
),
(
SparkSQLColDef(
"structofstruct",
"STRUCT<field1: INT, struct1: STRUCT<structField1: INT>>",
{ res =>
res shouldBe a[GenericRowWithSchema]
res.asInstanceOf[GenericRowWithSchema].get(1) shouldBe a[GenericRowWithSchema]
}
),
ESColumnData(
"structofstruct" inner ("field1" typed IntegerType, "struct1" inner("structField1" typed IntegerType)),
() => Map("field1" -> 15, "struct1" -> Map("structField1" -> 42))
)
),
(
SparkSQLColDef("arrayint", "ARRAY<INT>", _ shouldBe a[Seq[_]]),
ESColumnData(() => Seq(1,2,3,4))
),
(
SparkSQLColDef("arraystruct", "ARRAY<STRUCT<field1: LONG, field2: LONG>>", _ shouldBe a[Seq[_]]),
ESColumnData(
"arraystruct" nested(
"field1" typed LongType,
"field2" typed LongType
),
() =>
Array(
Map(
"field1" -> 11,
"field2" -> 12
),
Map(
"field1" -> 21,
"field2" -> 22
),
Map(
"field1" -> 31,
"field2" -> 32
)
)
)
)/*,
(
SparkSQLColDef(
"arraystructarraystruct",
"ARRAY<STRUCT<stringfield: STRING, arrayfield: ARRAY<STRUCT<field1: INT, field2: INT>>>>",
{ res =>
res shouldBe a[Seq[_]]
res.asInstanceOf[Seq[_]].head shouldBe a[Row]
res.asInstanceOf[Seq[_]].head.asInstanceOf[Row].get(1) shouldBe a[Seq[_]]
res.asInstanceOf[Seq[_]].head.asInstanceOf[Row].get(1).asInstanceOf[Seq[_]].head shouldBe a[Row]
}
),
ESColumnData(
"arraystructarraystruct" nested (
"stringfield" typed StringType,
"arrayfield" nested (
"field1" typed IntegerType,
"field2" typed IntegerType
)
),
() => Array(
Map(
"stringfield" -> "hello",
"arrayfield" -> Array(
Map(
"field1" -> 10,
"field2" -> 20
)
)
)
)
)
)*/
)


override protected def typesSet: Seq[SparkSQLColDef] = dataTest.map(_._1)


abstract override def saveTestData: Unit = {
require(saveTypesData > 0, emptyTypesSetError)
}

override def saveTypesData: Int = {
client.get.execute {
val fieldsData = dataTest map {
case (SparkSQLColDef(fieldName, _, _), ESColumnData(_, data)) => (fieldName, data())
}
index into Index / Type fields (fieldsData: _*)
}.await
client.get.execute {
flush index Index
}.await
1
}

override def typeMapping(): MappingDefinition =
Type fields (
dataTest collect {
case (_, ESColumnData(Some(mapping), _)) => mapping
}: _*
)

override val emptyTypesSetError: String = "Couldn't insert Elasticsearch types test data"

}


trait ElasticSearchDataTypesDefaultConstants extends ElasticSearchDefaultConstants{
private lazy val config = ConfigFactory.load()
override val Index = s"idxname${UUID.randomUUID.toString.replaceAll("-", "")}"
override val Type = s"typename${UUID.randomUUID.toString.replaceAll("-", "")}"

}
Loading

0 comments on commit b1bb1d6

Please sign in to comment.