@@ -52,21 +52,12 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest {
* the types test table.
*/

val arrayFlattenTestColumn: String = "arraystructarraystruct" /* Column used to test flattening of arrays */

//Template: This is the template implementation and shouldn't be modified in any specific test

def doTypesTest(datasourceName: String): Unit = {
for(executionType <- ExecutionType.Spark::ExecutionType.Native::Nil)
datasourceName should s"provide the right types for $executionType execution" in {
assumeEnvironmentIsUpAndRunning
val dframe = sql("SELECT " + typesSet.map(_.colname).mkString(", ") + s" FROM $dataTypesTableName")
for(
(tpe, i) <- typesSet zipWithIndex;
typeCheck <- tpe.typeCheck
) typeCheck(dframe.collect(executionType).head(i))
}

//Multi-level column flat test
protected def multilevelFlattenTests(datasourceName: String): Unit = {
//Multi-level column flatten test

it should "provide flattened column names through the `annotatedCollect` method" in {
val dataFrame = sql("SELECT structofstruct.struct1.structField1 FROM typesCheckTable")
@@ -81,8 +72,24 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest {
rows.length shouldBe 1
}

it should "be able to flatten whole rows" in {
val dataFrame = sql("SELECT * FROM typesCheckTable")
val rows = dataFrame.flattenedCollect()
val hasComposedTypes = rows.head.schema.fields exists { field =>
field.dataType match {
case _: StructType | _: ArrayType => true
case _ => false
}
}
hasComposedTypes shouldBe false
}
}

protected def arrayFlattenTests(datasourceName: String): Unit = {
//Multi-level column, with nested arrays, flatten test

it should "be able to vertically flatten results for array columns" in {
val dataFrame = sql(s"SELECT arraystructarraystruct FROM typesCheckTable")
val dataFrame = sql(s"SELECT $arrayFlattenTestColumn FROM typesCheckTable")
val res = dataFrame.flattenedCollect()

// No array columns should be found in the result schema
@@ -100,7 +107,7 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest {
}

it should "correctly apply user limits to a vertically flattened array column" in {
val dataFrame = sql(s"SELECT arraystructarraystruct FROM typesCheckTable LIMIT 1")
val dataFrame = sql(s"SELECT $arrayFlattenTestColumn FROM typesCheckTable LIMIT 1")
val res = dataFrame.flattenedCollect()
res.length shouldBe 1
}
@@ -113,6 +120,24 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest {

}

def doTypesTest(datasourceName: String): Unit = {

for(executionType <- ExecutionType.Spark::ExecutionType.Native::Nil)
datasourceName should s"provide the right types for $executionType execution" in {
assumeEnvironmentIsUpAndRunning
val dframe = sql("SELECT " + typesSet.map(_.colname).mkString(", ") + s" FROM $dataTypesTableName")
for(
(tpe, i) <- typesSet zipWithIndex;
typeCheck <- tpe.typeCheck
) typeCheck(dframe.collect(executionType).head(i))
}

multilevelFlattenTests(datasourceName)

arrayFlattenTests(datasourceName)

}

abstract override def saveTestData: Unit = {
super.saveTestData
require(saveTypesData > 0, emptyTypesSetError)
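For reference, a connector test suite mixes this trait in, supplies its column definitions, and invokes doTypesTest. A minimal sketch, assuming a hypothetical connector with placeholder options (the real wiring for Elasticsearch is in the new suite further below):

// Hypothetical connector suite; the provider name and the options are
// placeholders, not part of this PR.
trait MyConnectorDataTypes extends SharedXDContextTypesTest {

  override val dataTypesSparkOptions = Map(
    "table" -> dataTypesTableName,   // constant provided by SharedXDContextTypesTest
    "host" -> "localhost:1234"       // assumed connector-specific option
  )

  // Column definitions; the optional third argument asserts the runtime type of each cell.
  override protected def typesSet: Seq[SparkSQLColDef] = Seq(
    SparkSQLColDef("id", "INT", _ shouldBe a[java.lang.Integer]),
    SparkSQLColDef("description", "STRING", _ shouldBe a[java.lang.String])
  )

  doTypesTest("MyDatasource")  // registers the type checks plus the flatten and array-flatten tests
}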
@@ -167,6 +192,7 @@ trait SharedXDContextTypesTest extends SharedXDContextWithDataTest {
}

object SharedXDContextTypesTest {

val dataTypesTableName = "typesCheckTable"
case class SparkSQLColDef(colname: String, sqlType: String, typeCheck: Option[Any => Unit] = None)
object SparkSQLColDef {
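Taken together, the tests above pin down the two behaviours of flattenedCollect: struct columns flatten horizontally into dot-separated names (e.g. structofstruct.struct1.structField1), while array columns flatten vertically, producing one output row per element, with user limits applied to the flattened result. A self-contained sketch of those semantics over plain Scala maps, not Crossdata's actual implementation:

// Toy model: a row is a Map from column name to value; structs are nested Maps,
// arrays are Seqs. Illustrates the asserted semantics only.
def flatten(row: Map[String, Any], prefix: String = ""): Seq[Map[String, Any]] =
  row.foldLeft(Seq(Map.empty[String, Any])) { case (acc, (name, value)) =>
    val col = if (prefix.isEmpty) name else s"$prefix.$name"
    value match {
      case struct: Map[_, _] => // horizontal flattening: dot-separated column names
        for (r <- acc; f <- flatten(struct.asInstanceOf[Map[String, Any]], col)) yield r ++ f
      case array: Seq[_] =>     // vertical flattening: one row per array element
        for (r <- acc; elem <- array; f <- flatten(Map(col -> elem))) yield r ++ f
      case simple =>
        acc.map(_ + (col -> simple))
    }
  }

// flatten(Map("id" -> 1, "s" -> Map("b" -> 2), "arr" -> Seq(10, 20))) ==
//   Seq(Map("id" -> 1, "s.b" -> 2, "arr" -> 10),
//       Map("id" -> 1, "s.b" -> 2, "arr" -> 20))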
@@ -23,9 +23,9 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, Literal}
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.catalyst.plans.logical.{Limit, LogicalPlan}
import org.apache.spark.sql.{Row, sources}
import org.apache.spark.sql.sources.CatalystToCrossdataAdapter.{BaseLogicalPlan, FilterReport, ProjectReport, SimpleLogicalPlan, CrossdataExecutionPlan}
import org.apache.spark.sql.sources.CatalystToCrossdataAdapter.{BaseLogicalPlan, CrossdataExecutionPlan, FilterReport, ProjectReport, SimpleLogicalPlan}
import org.apache.spark.sql.sources.{CatalystToCrossdataAdapter, Filter => SourceFilter}
import org.apache.spark.sql.types.{StructField, StructType}
import org.apache.spark.sql.types.{StructField, StructType, ArrayType}
import org.elasticsearch.action.search.SearchResponse

import scala.util.{Failure, Try}
@@ -101,8 +101,6 @@ class ElasticSearchQueryProcessor(val logicalPlan: LogicalPlan, val parameters:
case sources.StringStartsWith(attribute, value) => prefixQuery(attribute, value.toLowerCase)
}

import scala.collection.JavaConversions._

val searchFilters = sFilters.collect {
case sources.EqualTo(attribute, value) => termQuery(attribute, value)
case sources.GreaterThan(attribute, value) => rangeQuery(attribute).from(value).includeLower(false)
@@ -131,6 +129,7 @@ class ElasticSearchQueryProcessor(val logicalPlan: LogicalPlan, val parameters:
val subDocuments = schemaProvided.toSeq flatMap {
_.fields collect {
case StructField(name, _: StructType, _, _) => name
case StructField(name, ArrayType(_: StructType, _), _, _) => name
}
}
val stringFields: Seq[String] = fields.view map (_.name) filterNot (subDocuments contains _)
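The last hunk widens subdocument detection: a field whose type is an array of structs must also be read from the document source as a subdocument rather than as a flat field. A standalone illustration of that pattern match, using only Spark SQL types:

import org.apache.spark.sql.types._

// Mirrors the collect above: structs and arrays of structs count as subdocuments.
def subDocumentNames(schema: StructType): Seq[String] = schema.fields.toSeq.collect {
  case StructField(name, _: StructType, _, _) => name
  case StructField(name, ArrayType(_: StructType, _), _, _) => name
}

val schema = StructType(Seq(
  StructField("id", IntegerType),
  StructField("subdocument", StructType(Seq(StructField("field1", IntegerType)))),
  StructField("arraystruct", ArrayType(StructType(Seq(StructField("field1", LongType)))))
))

subDocumentNames(schema)  // Seq("subdocument", "arraystruct"); "id" stays a plain field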
@@ -58,7 +58,9 @@ object ElasticSearchRowConverter {
// TODO: Note that if a nested subdocument is targeted, it won't work and this algorithm should be made recursive.
(hitFields.get(name) orElse subDocuments.get(name)).flatMap(Option(_)) map {
((value: Any) => enforceCorrectType(value, schemaMap(name))) compose {
case hitField: SearchHitField => hitField.getValue
case hitField: SearchHitField =>
if (hitField.getValues.size() > 1) hitField.getValues
else hitField.getValue
case other => other
}
} orNull
@@ -85,6 +87,7 @@ object ElasticSearchRowConverter {
case DateType => toDate(value)
case BinaryType => toBinary(value)
case schema: StructType => toRow(value, schema)
case ArrayType(elementType: DataType, _) => toArray(value, elementType)
case _ =>
sys.error(s"Unsupported datatype conversion [${value.getClass}},$desiredType]")
value
@@ -174,4 +177,9 @@ object ElasticSearchRowConverter {
case _ => sys.error(s"Unsupported datatype conversion [${value.getClass}, Row]")
}

def toArray(value: Any, elementType: DataType): Seq[Any] = value match {
  case arr: util.ArrayList[Any] =>
    arr.toArray.map(enforceCorrectType(_, elementType))
  case _ => sys.error(s"Unsupported datatype conversion [${value.getClass}, Array[$elementType]]")
}

}
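Two paths can now feed toArray: a multi-valued SearchHitField (the getValues branch above) and array subdocuments pulled from the hit source, both of which surface as java.util.ArrayList. A standalone analogue of that conversion, with the coercion function standing in for enforceCorrectType:

import java.util
import scala.collection.JavaConverters._

// Coerce each element of a Java list coming from an Elasticsearch hit.
def toSeq(value: Any)(coerce: Any => Any): Seq[Any] = value match {
  case arr: util.ArrayList[_] => arr.asScala.toList.map(coerce)
  case other => sys.error(s"Unsupported datatype conversion [${other.getClass}, Array]")
}

val javaList = new util.ArrayList[Any](util.Arrays.asList(1, 2, 3))
toSeq(javaList)(x => x.asInstanceOf[Int] * 10)  // List(10, 20, 30)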
@@ -0,0 +1,221 @@
/*
* Copyright (C) 2015 Stratio (http://stratio.com)
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.stratio.crossdata.connector.elasticsearch

import java.util.{GregorianCalendar, UUID}

import com.sksamuel.elastic4s.ElasticDsl._
import com.sksamuel.elastic4s.mappings.FieldType._
import com.sksamuel.elastic4s.mappings.{MappingDefinition, TypedFieldDefinition}
import com.stratio.common.utils.components.logger.impl.SparkLoggerComponent
import com.typesafe.config.ConfigFactory
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
import org.apache.spark.sql.crossdata.test.SharedXDContextTypesTest
import org.apache.spark.sql.crossdata.test.SharedXDContextTypesTest.SparkSQLColDef
import org.joda.time.DateTime

trait ElasticDataTypes extends ElasticWithSharedContext
with SharedXDContextTypesTest
with ElasticSearchDataTypesDefaultConstants
with SparkLoggerComponent {

override val dataTypesSparkOptions = Map(
"resource" -> s"$Index/$Type",
"es.nodes" -> s"$ElasticHost",
"es.port" -> s"$ElasticRestPort",
"es.nativePort" -> s"$ElasticNativePort",
"es.cluster" -> s"$ElasticClusterName",
"es.nodes.wan.only" -> "true",
"es.read.field.as.array.include" -> Seq(
"arrayint"
).mkString(",")
)

protected case class ESColumnData(elasticType: Option[TypedFieldDefinition], data: () => Any)
protected object ESColumnData {
def apply(data: () => Any): ESColumnData = ESColumnData(None, data)
def apply(elasticType: TypedFieldDefinition, data: () => Any): ESColumnData = ESColumnData(Some(elasticType), data)
}


override val arrayFlattenTestColumn: String = "arraystruct"

protected val dataTest: Seq[(SparkSQLColDef, ESColumnData)] = Seq(
(SparkSQLColDef("id", "INT", _ shouldBe a[java.lang.Integer]), ESColumnData("id" typed IntegerType, () => 1)),
(SparkSQLColDef("age", "LONG", _ shouldBe a[java.lang.Long]), ESColumnData("age" typed LongType, () => 1)),
(
SparkSQLColDef("description", "STRING", _ shouldBe a[java.lang.String]),
ESColumnData("description" typed StringType, () => "1")
),
(
SparkSQLColDef("name", "STRING", _ shouldBe a[java.lang.String]),
ESColumnData( "name" typed StringType index NotAnalyzed, () => "1")
),
(
SparkSQLColDef("enrolled", "BOOLEAN", _ shouldBe a[java.lang.Boolean]),
ESColumnData("enrolled" typed BooleanType, () => false)
),
(
SparkSQLColDef("birthday", "DATE", _ shouldBe a [java.sql.Date]),
ESColumnData("birthday" typed DateType, () => DateTime.parse(1980 + "-01-01T10:00:00-00:00").toDate)
),
(
SparkSQLColDef("salary", "DOUBLE", _ shouldBe a[java.lang.Double]),
ESColumnData("salary" typed DoubleType, () => 0.15)
),
(
SparkSQLColDef("timecol", "TIMESTAMP", _ shouldBe a[java.sql.Timestamp]),
ESColumnData(
"timecol" typed DateType,
() => new java.sql.Timestamp(new GregorianCalendar(1970, 0, 1, 0, 0, 0).getTimeInMillis)
)
),
(
SparkSQLColDef("float", "FLOAT", _ shouldBe a[java.lang.Float]),
ESColumnData("float" typed FloatType, () => 0.15)
),
(
SparkSQLColDef("binary", "BINARY", x => x.isInstanceOf[Array[Byte]] shouldBe true),
ESColumnData("binary" typed BinaryType, () => "YWE=".getBytes)
),
(
SparkSQLColDef("tinyint", "TINYINT", _ shouldBe a[java.lang.Byte]),
ESColumnData("tinyint" typed ByteType, () => Byte.MinValue)
),
(
SparkSQLColDef("smallint", "SMALLINT", _ shouldBe a[java.lang.Short]),
ESColumnData("smallint" typed ShortType, () => Short.MaxValue)
),
(
SparkSQLColDef("subdocument", "STRUCT<field1: INT>", _ shouldBe a [Row]),
ESColumnData("subdocument" inner ("field1" typed IntegerType), () => Map( "field1" -> 15))
),
(
SparkSQLColDef(
"structofstruct",
"STRUCT<field1: INT, struct1: STRUCT<structField1: INT>>",
{ res =>
res shouldBe a[GenericRowWithSchema]
res.asInstanceOf[GenericRowWithSchema].get(1) shouldBe a[GenericRowWithSchema]
}
),
ESColumnData(
"structofstruct" inner ("field1" typed IntegerType, "struct1" inner("structField1" typed IntegerType)),
() => Map("field1" -> 15, "struct1" -> Map("structField1" -> 42))
)
),
(
SparkSQLColDef("arrayint", "ARRAY<INT>", _ shouldBe a[Seq[_]]),
ESColumnData(() => Seq(1,2,3,4))
),
(
SparkSQLColDef("arraystruct", "ARRAY<STRUCT<field1: LONG, field2: LONG>>", _ shouldBe a[Seq[_]]),
ESColumnData(
"arraystruct" nested(
"field1" typed LongType,
"field2" typed LongType
),
() =>
Array(
Map(
"field1" -> 11,
"field2" -> 12
),
Map(
"field1" -> 21,
"field2" -> 22
),
Map(
"field1" -> 31,
"field2" -> 32
)
)
)
)/*,
(
SparkSQLColDef(
"arraystructarraystruct",
"ARRAY<STRUCT<stringfield: STRING, arrayfield: ARRAY<STRUCT<field1: INT, field2: INT>>>>",
{ res =>
res shouldBe a[Seq[_]]
res.asInstanceOf[Seq[_]].head shouldBe a[Row]
res.asInstanceOf[Seq[_]].head.asInstanceOf[Row].get(1) shouldBe a[Seq[_]]
res.asInstanceOf[Seq[_]].head.asInstanceOf[Row].get(1).asInstanceOf[Seq[_]].head shouldBe a[Row]
}
),
ESColumnData(
"arraystructarraystruct" nested (
"stringfield" typed StringType,
"arrayfield" nested (
"field1" typed IntegerType,
"field2" typed IntegerType
)
),
() => Array(
Map(
"stringfield" -> "hello",
"arrayfield" -> Array(
Map(
"field1" -> 10,
"field2" -> 20
)
)
)
)
)
)*/
)


override protected def typesSet: Seq[SparkSQLColDef] = dataTest.map(_._1)


abstract override def saveTestData: Unit = {
require(saveTypesData > 0, emptyTypesSetError)
}

override def saveTypesData: Int = {
client.get.execute {
val fieldsData = dataTest map {
case (SparkSQLColDef(fieldName, _, _), ESColumnData(_, data)) => (fieldName, data())
}
index into Index / Type fields (fieldsData: _*)
}.await
client.get.execute {
flush index Index
}.await
1
}

override def typeMapping(): MappingDefinition =
Type fields (
dataTest collect {
case (_, ESColumnData(Some(mapping), _)) => mapping
}: _*
)

override val emptyTypesSetError: String = "Couldn't insert Elasticsearch types test data"

}


trait ElasticSearchDataTypesDefaultConstants extends ElasticSearchDefaultConstants {
private lazy val config = ConfigFactory.load()
override val Index = s"idxname${UUID.randomUUID.toString.replaceAll("-", "")}"
override val Type = s"typename${UUID.randomUUID.toString.replaceAll("-", "")}"

}
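One detail worth flagging: elasticsearch-hadoop cannot tell a scalar from a single-element array in _source, so fields that must surface as Spark arrays are listed explicitly in es.read.field.as.array.include (only arrayint here; the nested arraystructarraystruct column stays commented out together with its data and mapping). A hedged sketch of how such a table might be loaded programmatically, with placeholder index, host and port values, assuming an XDContext (a SQLContext) in scope as xdContext:

// Placeholder values throughout; mirrors dataTypesSparkOptions above.
val df = xdContext.read
  .format("com.stratio.crossdata.connector.elasticsearch")
  .options(Map(
    "resource" -> "myindex/mytype",
    "es.nodes" -> "localhost",
    "es.port" -> "9200",
    "es.read.field.as.array.include" -> "arrayint"
  ))
  .load()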