Skip to content
Permalink
d03e0af80d
Go to file
 
 
Cannot retrieve contributors at this time
4127 lines (3681 sloc) 139 KB
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import java.util.{Comparator, TimeZone}
import scala.collection.mutable
import scala.reflect.ClassTag
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.expressions.ArraySortLike.NullOrder
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.array.ByteArrayMethods
import org.apache.spark.unsafe.array.ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
import org.apache.spark.unsafe.types.{ByteArray, UTF8String}
import org.apache.spark.unsafe.types.CalendarInterval
import org.apache.spark.util.collection.OpenHashSet
/**
 * Base trait for [[BinaryExpression]]s whose two children are arrays that must end up with the
 * same element type. Implicit casting widens both element types to their tightest common type.
 */
trait BinaryArrayExpressionWithImplicitCast extends BinaryExpression
  with ImplicitCastInputTypes {

  // Element type shared by both array inputs, taken from the first expected input type.
  @transient protected lazy val elementType: DataType =
    inputTypes.head.asInstanceOf[ArrayType].elementType

  /**
   * Both arrays re-typed to the tightest common element type, each keeping its own
   * `containsNull` flag. Empty when either child is not an array or no common type exists.
   */
  override def inputTypes: Seq[AbstractDataType] = (left.dataType, right.dataType) match {
    case (ArrayType(leftElem, leftNulls), ArrayType(rightElem, rightNulls)) =>
      TypeCoercion.findTightestCommonType(leftElem, rightElem)
        .map(common => Seq(ArrayType(common, leftNulls), ArrayType(common, rightNulls)))
        .getOrElse(Seq.empty[AbstractDataType])
    case _ =>
      Seq.empty
  }

  /** Succeeds only for two arrays whose element types are the same (ignoring nullability). */
  override def checkInputDataTypes(): TypeCheckResult = (left.dataType, right.dataType) match {
    case (ArrayType(leftElem, _), ArrayType(rightElem, _)) if leftElem.sameType(rightElem) =>
      TypeCheckResult.TypeCheckSuccess
    case _ =>
      val actual = s"[${left.dataType.catalogString}, ${right.dataType.catalogString}]"
      TypeCheckResult.TypeCheckFailure(s"input to function $prettyName should have been two " +
        s"${ArrayType.simpleString}s with same element type, but it's $actual")
  }
}
/**
 * Returns the number of elements in an array or map. A null input yields -1 when
 * `spark.sql.legacy.sizeOfNull` is enabled, and null otherwise.
 */
@ExpressionDescription(
usage = """
_FUNC_(expr) - Returns the size of an array or a map.
The function returns -1 if its input is null and spark.sql.legacy.sizeOfNull is set to true.
If spark.sql.legacy.sizeOfNull is set to false, the function returns null for null input.
By default, the spark.sql.legacy.sizeOfNull parameter is set to true.
""",
examples = """
Examples:
> SELECT _FUNC_(array('b', 'd', 'c', 'a'));
4
> SELECT _FUNC_(map('a', 1, 'b', 2));
2
> SELECT _FUNC_(NULL);
-1
""")
case class Size(child: Expression) extends UnaryExpression with ExpectsInputTypes {
  // Read once when the expression is constructed; true means null input maps to -1.
  val legacySizeOfNull = SQLConf.get.legacySizeOfNull

  override def dataType: DataType = IntegerType

  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(ArrayType, MapType))

  // In legacy mode a null input becomes -1, so the result can never be null.
  override def nullable: Boolean = !legacySizeOfNull && super.nullable

  override def eval(input: InternalRow): Any = child.eval(input) match {
    case null =>
      if (legacySizeOfNull) -1 else null
    case value =>
      child.dataType match {
        case _: ArrayType => value.asInstanceOf[ArrayData].numElements()
        case _: MapType => value.asInstanceOf[MapData].numElements()
        case other => throw new UnsupportedOperationException(
          s"The size function doesn't support the operand type ${other.getClass.getCanonicalName}")
      }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    if (!legacySizeOfNull) {
      // Standard null propagation: the size is only read from a non-null input.
      defineCodeGen(ctx, ev, c => s"($c).numElements()")
    } else {
      // Legacy mode: a null child is reported as -1. The returned ExprCode marks isNull as
      // FalseLiteral, so consumers never read the declared null flag.
      val childGen = child.genCode(ctx)
      ev.copy(code = code"""
boolean ${ev.isNull} = false;
${childGen.code}
${CodeGenerator.javaType(dataType)} ${ev.value} = ${childGen.isNull} ? -1 :
(${childGen.value}).numElements();""", isNull = FalseLiteral)
    }
  }
}
/**
 * Returns the keys of a map as an array, in the order of the map's underlying key array.
 */
@ExpressionDescription(
usage = "_FUNC_(map) - Returns an unordered array containing the keys of the map.",
examples = """
Examples:
> SELECT _FUNC_(map(1, 'a', 2, 'b'));
[1,2]
""")
case class MapKeys(child: Expression)
  extends UnaryExpression with ExpectsInputTypes {

  override def inputTypes: Seq[AbstractDataType] = Seq(MapType)

  override def dataType: DataType = {
    val mapType = child.dataType.asInstanceOf[MapType]
    ArrayType(mapType.keyType)
  }

  // The map's own key array is returned as-is; no copy is made.
  override def nullSafeEval(map: Any): Any = map.asInstanceOf[MapData].keyArray()

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    nullSafeCodeGen(ctx, ev, mapTerm => s"${ev.value} = ($mapTerm).keyArray();")

  override def prettyName: String = "map_keys"
}
/**
 * Zips the input arrays element-wise into an array of structs.
 *
 * The i-th struct of the result holds the i-th element of every input array; positions past an
 * input's end, and null elements, become null fields. The result has as many elements as the
 * longest input, and is null if any input array is null.
 */
@ExpressionDescription(
usage = """
_FUNC_(a1, a2, ...) - Returns a merged array of structs in which the N-th struct contains all
N-th values of input arrays.
""",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), array(2, 3, 4));
[{"0":1,"1":2},{"0":2,"1":3},{"0":3,"1":4}]
> SELECT _FUNC_(array(1, 2), array(2, 3), array(3, 4));
[{"0":1,"1":2,"2":3},{"0":2,"1":3,"2":4}]
""",
since = "2.4.0")
case class ArraysZip(children: Seq[Expression]) extends Expression with ExpectsInputTypes {
// Every child must be an array; element types may differ between children.
override def inputTypes: Seq[AbstractDataType] = Seq.fill(children.length)(ArrayType)
// Field names: a NamedExpression child contributes its name, any other child its position
// ("0", "1", ...). Every field is nullable because shorter inputs are padded with nulls.
@transient override lazy val dataType: DataType = {
val fields = children.zip(arrayElementTypes).zipWithIndex.map {
case ((expr: NamedExpression, elementType), _) =>
StructField(expr.name, elementType, nullable = true)
case ((_, elementType), idx) =>
StructField(idx.toString, elementType, nullable = true)
}
ArrayType(StructType(fields), containsNull = false)
}
// The result can be null only if some child can be null.
override def nullable: Boolean = children.exists(_.nullable)
// Element type of each child array, in child order.
@transient private lazy val arrayElementTypes =
children.map(_.dataType.asInstanceOf[ArrayType].elementType)
private def genericArrayData = classOf[GenericArrayData].getName
// With no children the generated result is an empty, non-null array.
def emptyInputGenCode(ev: ExprCode): ExprCode = {
ev.copy(code"""
|${CodeGenerator.javaType(dataType)} ${ev.value} = new $genericArrayData(new Object[0]);
|boolean ${ev.isNull} = false;
""".stripMargin)
}
// The generated code works in two phases:
// (1) evaluate each child into arrVals while tracking the largest cardinality, using -1 to
//     mark a null child (children after it are then not evaluated);
// (2) if no child was null, build one GenericInternalRow per index up to that cardinality.
def nonEmptyInputGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val genericInternalRow = classOf[GenericInternalRow].getName
val arrVals = ctx.freshName("arrVals")
val biggestCardinality = ctx.freshName("biggestCardinality")
val currentRow = ctx.freshName("currentRow")
val j = ctx.freshName("j")
val i = ctx.freshName("i")
val args = ctx.freshName("args")
val evals = children.map(_.genCode(ctx))
// Phase 1, one snippet per child; guarded so nothing runs once a null child has been seen.
val getValuesAndCardinalities = evals.zipWithIndex.map { case (eval, index) =>
s"""
|if ($biggestCardinality != -1) {
| ${eval.code}
| if (!${eval.isNull}) {
| $arrVals[$index] = ${eval.value};
| $biggestCardinality = Math.max($biggestCardinality, ${eval.value}.numElements());
| } else {
| $biggestCardinality = -1;
| }
|}
""".stripMargin
}
// The phase-1 snippets may be split into helper methods. Each helper returns the running
// cardinality, which the caller assigns back to its local variable.
val splittedGetValuesAndCardinalities = ctx.splitExpressionsWithCurrentInputs(
expressions = getValuesAndCardinalities,
funcName = "getValuesAndCardinalities",
returnType = "int",
makeSplitFunction = body =>
s"""
|$body
|return $biggestCardinality;
""".stripMargin,
foldFunctions = _.map(funcCall => s"$biggestCardinality = $funcCall;").mkString("\n"),
extraArguments =
("ArrayData[]", arrVals) ::
("int", biggestCardinality) :: Nil)
// Phase 2, one snippet per child: the child's element at index i, or null when the child is
// shorter than i + 1 or holds null there.
val getValueForType = arrayElementTypes.zipWithIndex.map { case (eleType, idx) =>
val g = CodeGenerator.getValue(s"$arrVals[$idx]", eleType, i)
s"""
|if ($i < $arrVals[$idx].numElements() && !$arrVals[$idx].isNullAt($i)) {
| $currentRow[$idx] = $g;
|} else {
| $currentRow[$idx] = null;
|}
""".stripMargin
}
val getValueForTypeSplitted = ctx.splitExpressions(
expressions = getValueForType,
funcName = "extractValue",
arguments =
("int", i) ::
("Object[]", currentRow) ::
("ArrayData[]", arrVals) :: Nil)
val initVariables = s"""
|ArrayData[] $arrVals = new ArrayData[${children.length}];
|int $biggestCardinality = 0;
|${CodeGenerator.javaType(dataType)} ${ev.value} = null;
""".stripMargin
ev.copy(code"""
|$initVariables
|$splittedGetValuesAndCardinalities
|boolean ${ev.isNull} = $biggestCardinality == -1;
|if (!${ev.isNull}) {
| Object[] $args = new Object[$biggestCardinality];
| for (int $i = 0; $i < $biggestCardinality; $i ++) {
| Object[] $currentRow = new Object[${children.length}];
| $getValueForTypeSplitted
| $args[$i] = new $genericInternalRow($currentRow);
| }
| ${ev.value} = new $genericArrayData($args);
|}
""".stripMargin)
}
override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
if (children.length == 0) {
emptyInputGenCode(ev)
} else {
nonEmptyInputGenCode(ctx, ev)
}
}
// Interpreted path mirroring the generated code: null if any input is null, otherwise one row
// per index up to the longest input, padding shorter inputs with nulls.
override def eval(input: InternalRow): Any = {
val inputArrays = children.map(_.eval(input).asInstanceOf[ArrayData])
if (inputArrays.contains(null)) {
null
} else {
val biggestCardinality = if (inputArrays.isEmpty) {
0
} else {
inputArrays.map(_.numElements()).max
}
val result = new Array[InternalRow](biggestCardinality)
val zippedArrs: Seq[(ArrayData, Int)] = inputArrays.zipWithIndex
for (i <- 0 until biggestCardinality) {
val currentLayer: Seq[Object] = zippedArrs.map { case (arr, index) =>
if (i < arr.numElements() && !arr.isNullAt(i)) {
arr.get(i, arrayElementTypes(index))
} else {
null
}
}
result(i) = InternalRow.apply(currentLayer: _*)
}
new GenericArrayData(result)
}
}
override def prettyName: String = "arrays_zip"
}
/**
 * Returns the values of a map as an array, in the order of the map's underlying value array.
 */
@ExpressionDescription(
usage = "_FUNC_(map) - Returns an unordered array containing the values of the map.",
examples = """
Examples:
> SELECT _FUNC_(map(1, 'a', 2, 'b'));
["a","b"]
""")
case class MapValues(child: Expression)
  extends UnaryExpression with ExpectsInputTypes {

  override def inputTypes: Seq[AbstractDataType] = Seq(MapType)

  override def dataType: DataType = {
    val mapType = child.dataType.asInstanceOf[MapType]
    ArrayType(mapType.valueType)
  }

  // The map's own value array is returned as-is; no copy is made.
  override def nullSafeEval(map: Any): Any = map.asInstanceOf[MapData].valueArray()

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    nullSafeCodeGen(ctx, ev, mapTerm => s"${ev.value} = ($mapTerm).valueArray();")

  override def prettyName: String = "map_values"
}
/**
 * Returns an unordered array of all entries in the given map.
 *
 * Each entry becomes a struct with fields "key" and "value". Entries appear in the order of the
 * map's underlying key/value arrays.
 */
@ExpressionDescription(
usage = "_FUNC_(map) - Returns an unordered array of all entries in the given map.",
examples = """
Examples:
> SELECT _FUNC_(map(1, 'a', 2, 'b'));
[{"key":1,"value":"a"},{"key":2,"value":"b"}]
""",
since = "2.4.0")
case class MapEntries(child: Expression) extends UnaryExpression with ExpectsInputTypes {
override def inputTypes: Seq[AbstractDataType] = Seq(MapType)
@transient private lazy val childDataType: MapType = child.dataType.asInstanceOf[MapType]
// array<struct<key, value>>: the key field is non-nullable, the value field is nullable iff the
// map's values may be null, and the array itself never contains null structs.
override def dataType: DataType = {
ArrayType(
StructType(
StructField("key", childDataType.keyType, false) ::
StructField("value", childDataType.valueType, childDataType.valueContainsNull) ::
Nil),
false)
}
// Interpreted path: one GenericInternalRow(key, value) per map entry.
override protected def nullSafeEval(input: Any): Any = {
val childMap = input.asInstanceOf[MapData]
val keys = childMap.keyArray()
val values = childMap.valueArray()
val length = childMap.numElements()
val resultData = new Array[AnyRef](length)
var i = 0
while (i < length) {
val key = keys.get(i, childDataType.keyType)
val value = values.get(i, childDataType.valueType)
val row = new GenericInternalRow(Array[Any](key, value))
resultData.update(i, row)
i += 1
}
new GenericArrayData(resultData)
}
// Generated path. When key and value types are both primitive, the result array is requested
// with a fixed per-element size: one word (the per-element offset/size slot) plus a 2-field
// UnsafeRow. The generated code then checks whether an UnsafeArrayData was actually allocated
// and falls back to the generic GenericInternalRow path otherwise.
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, c => {
val arrayData = ctx.freshName("arrayData")
val numElements = ctx.freshName("numElements")
val keys = ctx.freshName("keys")
val values = ctx.freshName("values")
val isKeyPrimitive = CodeGenerator.isPrimitiveType(childDataType.keyType)
val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
val wordSize = UnsafeRow.WORD_SIZE
// Bytes of one (key, value) UnsafeRow: null bit set for 2 fields plus 2 fixed-width words.
val structSize = UnsafeRow.calculateBitSetWidthInBytes(2) + wordSize * 2
// elementSize -1 is passed for non-primitive entries.
val (isPrimitive, elementSize) = if (isKeyPrimitive && isValuePrimitive) {
(true, structSize + wordSize)
} else {
(false, -1)
}
val allocation =
s"""
|ArrayData $arrayData = ArrayData.allocateArrayData(
| $elementSize, $numElements, " $prettyName failed.");
""".stripMargin
val code = if (isPrimitive) {
val genCodeForPrimitive = genCodeForPrimitiveElements(
ctx, arrayData, keys, values, ev.value, numElements, structSize)
s"""
|if ($arrayData instanceof UnsafeArrayData) {
| $genCodeForPrimitive
|} else {
| ${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, numElements)}
|}
""".stripMargin
} else {
s"${genCodeForAnyElements(ctx, arrayData, keys, values, ev.value, numElements)}"
}
s"""
|final int $numElements = $c.numElements();
|final ArrayData $keys = $c.keyArray();
|final ArrayData $values = $c.valueArray();
|$allocation
|$code
""".stripMargin
})
}
// Java expression reading the key / value at `index` of the given key / value array variable.
private def getKey(varName: String, index: String) =
CodeGenerator.getValue(varName, childDataType.keyType, index)
private def getValue(varName: String, index: String) =
CodeGenerator.getValue(varName, childDataType.valueType, index)
// Writes entries directly into the pre-allocated UnsafeArrayData. The structs are laid out
// contiguously after the header and the per-element slots (structsOffset); slot z stores
// (offset << 32) + struct size. A single UnsafeRow is re-pointed at each struct in turn to
// write its key and value. Assigns the array to `resultArrayData`.
private def genCodeForPrimitiveElements(
ctx: CodegenContext,
arrayData: String,
keys: String,
values: String,
resultArrayData: String,
numElements: String,
structSize: Int): String = {
val unsafeArrayData = ctx.freshName("unsafeArrayData")
val baseObject = ctx.freshName("baseObject")
val unsafeRow = ctx.freshName("unsafeRow")
val structsOffset = ctx.freshName("structsOffset")
val offset = ctx.freshName("offset")
val z = ctx.freshName("z")
val calculateHeader = "UnsafeArrayData.calculateHeaderPortionInBytes"
val baseOffset = Platform.BYTE_ARRAY_OFFSET
val wordSize = UnsafeRow.WORD_SIZE
val structSizeAsLong = s"${structSize}L"
val setKey = CodeGenerator.setColumn(unsafeRow, childDataType.keyType, 0, getKey(keys, z))
// Value assignment into field 1, with a null check when the map's values may be null.
val valueAssignmentChecked = CodeGenerator.createArrayAssignment(
unsafeRow, childDataType.valueType, values, "1", z, childDataType.valueContainsNull)
s"""
|UnsafeArrayData $unsafeArrayData = (UnsafeArrayData)$arrayData;
|Object $baseObject = $unsafeArrayData.getBaseObject();
|final int $structsOffset = $calculateHeader($numElements) + $numElements * $wordSize;
|UnsafeRow $unsafeRow = new UnsafeRow(2);
|for (int $z = 0; $z < $numElements; $z++) {
| long $offset = $structsOffset + $z * $structSizeAsLong;
| $unsafeArrayData.setLong($z, ($offset << 32) + $structSizeAsLong);
| $unsafeRow.pointTo($baseObject, $baseOffset + $offset, $structSize);
| $setKey;
| $valueAssignmentChecked
|}
|$resultArrayData = $arrayData;
""".stripMargin
}
// Fills the (GenericArrayData) result with GenericInternalRow(key, value) objects. A nullable
// primitive value is read through isNullAt so that a null value is stored as null.
// Assigns the array to `resultArrayData`.
private def genCodeForAnyElements(
ctx: CodegenContext,
arrayData: String,
keys: String,
values: String,
resultArrayData: String,
numElements: String): String = {
val z = ctx.freshName("z")
val isValuePrimitive = CodeGenerator.isPrimitiveType(childDataType.valueType)
val getValueWithCheck = if (childDataType.valueContainsNull && isValuePrimitive) {
s"$values.isNullAt($z) ? null : (Object)${getValue(values, z)}"
} else {
getValue(values, z)
}
val rowClass = classOf[GenericInternalRow].getName
val genericArrayDataClass = classOf[GenericArrayData].getName
val genericArrayData = ctx.freshName("genericArrayData")
val rowObject = s"new $rowClass(new Object[]{${getKey(keys, z)}, $getValueWithCheck})"
s"""
|$genericArrayDataClass $genericArrayData = ($genericArrayDataClass)$arrayData;
|for (int $z = 0; $z < $numElements; $z++) {
| $genericArrayData.update($z, $rowObject);
|}
|$resultArrayData = $arrayData;
""".stripMargin
}
override def prettyName: String = "map_entries"
}
/**
 * Returns the union of all the given maps.
 *
 * Entries of all inputs are concatenated in child order; this implementation copies every entry
 * and does not remove duplicate keys itself. The result is null if any input map is null, and
 * a RuntimeException is raised when the total entry count exceeds
 * `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`.
 */
@ExpressionDescription(
usage = "_FUNC_(map, ...) - Returns the union of all the given maps",
examples = """
Examples:
> SELECT _FUNC_(map(1, 'a', 2, 'b'), map(2, 'c', 3, 'd'));
{1:"a",2:"c",3:"d"}
""", since = "2.4.0")
case class MapConcat(children: Seq[Expression]) extends ComplexTypeMergingExpression {

  /** All children must be maps, and all of the same map type. */
  override def checkInputDataTypes(): TypeCheckResult = {
    val funcName = s"function $prettyName"
    if (children.exists(!_.dataType.isInstanceOf[MapType])) {
      TypeCheckResult.TypeCheckFailure(
        s"input to $funcName should all be of type map, but it's " +
          children.map(_.dataType.catalogString).mkString("[", ", ", "]"))
    } else {
      TypeUtils.checkForSameTypeInputExpr(children.map(_.dataType), funcName)
    }
  }

  // With no children there is nothing to merge, so the type falls back to map<string,string>.
  @transient override lazy val dataType: MapType = {
    if (children.isEmpty) {
      MapType(StringType, StringType)
    } else {
      super.dataType.asInstanceOf[MapType]
    }
  }

  override def nullable: Boolean = children.exists(_.nullable)

  override def eval(input: InternalRow): Any = {
    val maps = children.map(_.eval(input))
    if (maps.contains(null)) {
      null
    } else {
      val mapDatas = maps.map(_.asInstanceOf[MapData])
      val keyArrayDatas = mapDatas.map(_.keyArray())
      val valueArrayDatas = mapDatas.map(_.valueArray())
      // Summed as Long so the size-limit check cannot be defeated by Int overflow.
      val numElements = keyArrayDatas.foldLeft(0L)((sum, ad) => sum + ad.numElements())
      if (numElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
        throw new RuntimeException(s"Unsuccessful attempt to concat maps with $numElements " +
          s"elements due to exceeding the map size limit " +
          s"${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
      }
      val finalKeyArray = new Array[AnyRef](numElements.toInt)
      val finalValueArray = new Array[AnyRef](numElements.toInt)
      var position = 0
      // zip instead of positional indexing: `children` may be a List with O(n) apply.
      keyArrayDatas.zip(valueArrayDatas).foreach { case (keys, values) =>
        val keyArray = keys.toObjectArray(dataType.keyType)
        val valueArray = values.toObjectArray(dataType.valueType)
        Array.copy(keyArray, 0, finalKeyArray, position, keyArray.length)
        Array.copy(valueArray, 0, finalValueArray, position, valueArray.length)
        position += keyArray.length
      }
      new ArrayBasedMapData(new GenericArrayData(finalKeyArray),
        new GenericArrayData(finalValueArray))
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val mapCodes = children.map(_.genCode(ctx))
    val keyType = dataType.keyType
    val valueType = dataType.valueType
    val argsName = ctx.freshName("args")
    val hasNullName = ctx.freshName("hasNull")
    val mapDataClass = classOf[MapData].getName
    val arrayBasedMapDataClass = classOf[ArrayBasedMapData].getName
    val arrayDataClass = classOf[ArrayData].getName

    val init =
      s"""
        |$mapDataClass[] $argsName = new $mapDataClass[${mapCodes.size}];
        |boolean ${ev.isNull}, $hasNullName = false;
        |$mapDataClass ${ev.value} = null;
      """.stripMargin

    // Evaluate children in order, skipping the rest once a null map was seen. Only nullable
    // children need the null check that sets the hasNull flag.
    val assignments = mapCodes.zip(children.map(_.nullable)).zipWithIndex.map {
      case ((m, true), i) =>
        s"""
          |if (!$hasNullName) {
          | ${m.code}
          | if (!${m.isNull}) {
          | $argsName[$i] = ${m.value};
          | } else {
          | $hasNullName = true;
          | }
          |}
        """.stripMargin
      case ((m, false), i) =>
        s"""
          |if (!$hasNullName) {
          | ${m.code}
          | $argsName[$i] = ${m.value};
          |}
        """.stripMargin
    }

    // Split helpers return the hasNull flag, which the caller assigns back to its local.
    val codes = ctx.splitExpressionsWithCurrentInputs(
      expressions = assignments,
      funcName = "getMapConcatInputs",
      extraArguments = (s"$mapDataClass[]", argsName) :: ("boolean", hasNullName) :: Nil,
      returnType = "boolean",
      makeSplitFunction = body =>
        s"""
          |$body
          |return $hasNullName;
        """.stripMargin,
      foldFunctions = _.map(funcCall => s"$hasNullName = $funcCall;").mkString("\n")
    )

    val idxName = ctx.freshName("idx")
    val numElementsName = ctx.freshName("numElems")
    val finKeysName = ctx.freshName("finalKeys")
    val finValsName = ctx.freshName("finalValues")

    val keyConcat = genCodeForArrays(ctx, keyType, checkForNull = false)

    // The key helper (generated without null checks) is reused for values of the same type,
    // unless the values are primitives that may be null and thus need a null-checking helper.
    val valueConcat =
      if (valueType.sameType(keyType) &&
          !(CodeGenerator.isPrimitiveType(valueType) && dataType.valueContainsNull)) {
        keyConcat
      } else {
        genCodeForArrays(ctx, valueType, dataType.valueContainsNull)
      }

    val keyArgsName = ctx.freshName("keyArgs")
    val valArgsName = ctx.freshName("valArgs")

    val mapMerge =
      s"""
        |${ev.isNull} = $hasNullName;
        |if (!${ev.isNull}) {
        | $arrayDataClass[] $keyArgsName = new $arrayDataClass[${mapCodes.size}];
        | $arrayDataClass[] $valArgsName = new $arrayDataClass[${mapCodes.size}];
        | long $numElementsName = 0;
        | for (int $idxName = 0; $idxName < $argsName.length; $idxName++) {
        | $keyArgsName[$idxName] = $argsName[$idxName].keyArray();
        | $valArgsName[$idxName] = $argsName[$idxName].valueArray();
        | $numElementsName += $argsName[$idxName].numElements();
        | }
        | if ($numElementsName > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
        | throw new RuntimeException("Unsuccessful attempt to concat maps with " +
        | $numElementsName + " elements due to exceeding the map size limit " +
        | "${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.");
        | }
        | $arrayDataClass $finKeysName = $keyConcat($keyArgsName,
        | (int) $numElementsName);
        | $arrayDataClass $finValsName = $valueConcat($valArgsName,
        | (int) $numElementsName);
        | ${ev.value} = new $arrayBasedMapDataClass($finKeysName, $finValsName);
        |}
      """.stripMargin

    ev.copy(
      code = code"""
        |$init
        |$codes
        |$mapMerge
      """.stripMargin)
  }

  /**
   * Registers a generated helper that concatenates `children.length` ArrayData inputs of
   * `elementType` into one freshly allocated array of the given total size, and returns the
   * name to call it by (the result of `ctx.addNewFunction`).
   *
   * @param checkForNull whether the per-element copy performs a null check
   */
  private def genCodeForArrays(
      ctx: CodegenContext,
      elementType: DataType,
      checkForNull: Boolean): String = {
    val counter = ctx.freshName("counter")
    val arrayData = ctx.freshName("arrayData")
    val argsName = ctx.freshName("args")
    val numElemName = ctx.freshName("numElements")
    val y = ctx.freshName("y")
    val z = ctx.freshName("z")

    val allocation = CodeGenerator.createArrayData(
      arrayData, elementType, numElemName, s" $prettyName failed.")
    val assignment = CodeGenerator.createArrayAssignment(
      arrayData, elementType, s"$argsName[$y]", counter, z, checkForNull)

    val concat = ctx.freshName("concat")
    val concatDef =
      s"""
        |private ArrayData $concat(ArrayData[] $argsName, int $numElemName) {
        | $allocation
        | int $counter = 0;
        | for (int $y = 0; $y < ${children.length}; $y++) {
        | for (int $z = 0; $z < $argsName[$y].numElements(); $z++) {
        | $assignment
        | $counter++;
        | }
        | }
        | return $arrayData;
        |}
      """.stripMargin

    ctx.addNewFunction(concat, concatDef)
  }

  override def prettyName: String = "map_concat"
}
/**
 * Returns a map created from the given array of entries.
 *
 * The input must be an array of two-field structs; the first field is the key and the second
 * the value. A null entry makes the result null, and a null key raises a RuntimeException.
 */
@ExpressionDescription(
usage = "_FUNC_(arrayOfEntries) - Returns a map created from the given array of entries.",
examples = """
Examples:
> SELECT _FUNC_(array(struct(1, 'a'), struct(2, 'b')));
{1:"a",2:"b"}
""",
since = "2.4.0")
case class MapFromEntries(child: Expression) extends UnaryExpression {

  /**
   * For an input of type array<struct<k, v>>, holds three values: the result map type, whether
   * the key field is nullable, and whether the array may contain null entries. None for any
   * other input type.
   */
  @transient
  private lazy val dataTypeDetails: Option[(MapType, Boolean, Boolean)] = child.dataType match {
    case ArrayType(
        StructType(Array(
          StructField(_, keyType, keyNullable, _),
          StructField(_, valueType, valueNullable, _))),
        containsNull) => Some((MapType(keyType, valueType, valueNullable), keyNullable, containsNull))
    case _ => None
  }

  // Whether the input array may contain null entries (only defined for a valid input type).
  @transient private lazy val nullEntries: Boolean = dataTypeDetails.get._3

  override def nullable: Boolean = child.nullable || nullEntries

  @transient override lazy val dataType: MapType = dataTypeDetails.get._1

  override def checkInputDataTypes(): TypeCheckResult = dataTypeDetails match {
    case Some(_) => TypeCheckResult.TypeCheckSuccess
    case None => TypeCheckResult.TypeCheckFailure(s"'${child.sql}' is of " +
      s"${child.dataType.catalogString} type. $prettyName accepts only arrays of pair structs.")
  }

  override protected def nullSafeEval(input: Any): Any = {
    val arrayData = input.asInstanceOf[ArrayData]
    val numEntries = arrayData.numElements()
    var i = 0
    // A null entry makes the whole result null. Only scanned when the array type admits nulls.
    if (nullEntries) {
      while (i < numEntries) {
        if (arrayData.isNullAt(i)) return null
        i += 1
      }
    }
    val keyArray = new Array[AnyRef](numEntries)
    val valueArray = new Array[AnyRef](numEntries)
    i = 0
    while (i < numEntries) {
      val entry = arrayData.getStruct(i, 2)
      val key = entry.get(0, dataType.keyType)
      if (key == null) {
        throw new RuntimeException("The first field from a struct (key) can't be null.")
      }
      keyArray.update(i, key)
      val value = entry.get(1, dataType.valueType)
      valueArray.update(i, value)
      i += 1
    }
    ArrayBasedMapData(keyArray, valueArray)
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, c => {
      val numEntries = ctx.freshName("numEntries")
      val isKeyPrimitive = CodeGenerator.isPrimitiveType(dataType.keyType)
      val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType)
      val code = if (isKeyPrimitive && isValuePrimitive) {
        genCodeForPrimitiveElements(ctx, c, ev.value, numEntries)
      } else {
        genCodeForAnyElements(ctx, c, ev.value, numEntries)
      }
      // NOTE(review): presumably sets ev.isNull when any entry is null (mirroring nullSafeEval);
      // see CodegenContext.nullArrayElementsSaveExec.
      ctx.nullArrayElementsSaveExec(nullEntries, ev.isNull, c) {
        s"""
          |final int $numEntries = $c.numElements();
          |$code
        """.stripMargin
      }
    })
  }

  /**
   * Generates the loop over all entries of `childVariable`. For each entry it applies
   * `keyAssignment(keyExpr, index)` and `valueAssignment(entryVar, index)`; when the key field
   * is nullable, it first throws on a null key.
   */
  private def genCodeForAssignmentLoop(
      ctx: CodegenContext,
      childVariable: String,
      numEntries: String,
      keyAssignment: (String, String) => String,
      valueAssignment: (String, String) => String): String = {
    val entry = ctx.freshName("entry")
    val i = ctx.freshName("idx")

    val nullKeyCheck = if (dataTypeDetails.get._2) {
      s"""
        |if ($entry.isNullAt(0)) {
        | throw new RuntimeException("The first field from a struct (key) can't be null.");
        |}
      """.stripMargin
    } else {
      ""
    }

    s"""
      |for (int $i = 0; $i < $numEntries; $i++) {
      | InternalRow $entry = $childVariable.getStruct($i, 2);
      | $nullKeyCheck
      | ${keyAssignment(CodeGenerator.getValue(entry, dataType.keyType, "0"), i)}
      | ${valueAssignment(entry, i)}
      |}
    """.stripMargin
  }

  /**
   * Builds an UnsafeMapData in a single byte array for primitive key and value types. The first
   * long holds the key section size, and each section's first long is set to the entry count
   * before the sections are filled. Falls back to the generic path when the total byte size
   * would exceed `ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH`.
   */
  private def genCodeForPrimitiveElements(
      ctx: CodegenContext,
      childVariable: String,
      mapData: String,
      numEntries: String): String = {
    val byteArraySize = ctx.freshName("byteArraySize")
    val keySectionSize = ctx.freshName("keySectionSize")
    val valueSectionSize = ctx.freshName("valueSectionSize")
    val data = ctx.freshName("byteArray")
    val unsafeMapData = ctx.freshName("unsafeMapData")
    val keyArrayData = ctx.freshName("keyArrayData")
    val valueArrayData = ctx.freshName("valueArrayData")

    val baseOffset = Platform.BYTE_ARRAY_OFFSET
    val keySize = dataType.keyType.defaultSize
    val valueSize = dataType.valueType.defaultSize
    val kByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $keySize)"
    val vByteSize = s"UnsafeArrayData.calculateSizeOfUnderlyingByteArray($numEntries, $valueSize)"

    val keyAssignment = (key: String, idx: String) =>
      CodeGenerator.setArrayElement(keyArrayData, dataType.keyType, idx, key)
    val valueAssignment = (entry: String, idx: String) =>
      CodeGenerator.createArrayAssignment(
        valueArrayData, dataType.valueType, entry, idx, "1", dataType.valueContainsNull)
    val assignmentLoop = genCodeForAssignmentLoop(
      ctx,
      childVariable,
      numEntries,
      keyAssignment,
      valueAssignment
    )

    s"""
      |final long $keySectionSize = $kByteSize;
      |final long $valueSectionSize = $vByteSize;
      |final long $byteArraySize = 8 + $keySectionSize + $valueSectionSize;
      |if ($byteArraySize > ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}) {
      | ${genCodeForAnyElements(ctx, childVariable, mapData, numEntries)}
      |} else {
      | final byte[] $data = new byte[(int)$byteArraySize];
      | UnsafeMapData $unsafeMapData = new UnsafeMapData();
      | Platform.putLong($data, $baseOffset, $keySectionSize);
      | Platform.putLong($data, ${baseOffset + 8}, $numEntries);
      | Platform.putLong($data, ${baseOffset + 8} + $keySectionSize, $numEntries);
      | $unsafeMapData.pointTo($data, $baseOffset, (int)$byteArraySize);
      | ArrayData $keyArrayData = $unsafeMapData.keyArray();
      | ArrayData $valueArrayData = $unsafeMapData.valueArray();
      | $assignmentLoop
      | $mapData = $unsafeMapData;
      |}
    """.stripMargin
  }

  /**
   * Builds an ArrayBasedMapData from two Object[] arrays. A nullable primitive value is read
   * through isNullAt so that a null value is stored as null.
   */
  private def genCodeForAnyElements(
      ctx: CodegenContext,
      childVariable: String,
      mapData: String,
      numEntries: String): String = {
    val keys = ctx.freshName("keys")
    val values = ctx.freshName("values")
    val mapDataClass = classOf[ArrayBasedMapData].getName()

    val isValuePrimitive = CodeGenerator.isPrimitiveType(dataType.valueType)
    val valueAssignment = (entry: String, idx: String) => {
      val value = CodeGenerator.getValue(entry, dataType.valueType, "1")
      if (dataType.valueContainsNull && isValuePrimitive) {
        s"$values[$idx] = $entry.isNullAt(1) ? null : (Object)$value;"
      } else {
        s"$values[$idx] = $value;"
      }
    }
    val keyAssignment = (key: String, idx: String) => s"$keys[$idx] = $key;"
    val assignmentLoop = genCodeForAssignmentLoop(
      ctx,
      childVariable,
      numEntries,
      keyAssignment,
      valueAssignment)

    s"""
      |final Object[] $keys = new Object[$numEntries];
      |final Object[] $values = new Object[$numEntries];
      |$assignmentLoop
      |$mapData = $mapDataClass.apply($keys, $values);
    """.stripMargin
  }

  override def prettyName: String = "map_from_entries"
}
/**
 * Common base class for [[SortArray]] and [[ArraySort]].
 *
 * Implementations supply the array expression and a [[NullOrder]] deciding where null elements
 * land in an ascending sort. This trait provides the interpreted (`sortEval`) and generated
 * (`sortCodegen`) sorting logic.
 */
trait ArraySortLike extends ExpectsInputTypes {
  protected def arrayExpression: Expression

  protected def nullOrder: NullOrder

  /**
   * Interpreted ordering of non-null array elements, shared by both comparators. Only atomic,
   * array and struct element types are handled; any other element type fails with a MatchError
   * when this is first forced.
   */
  @transient private lazy val elementOrdering: Ordering[Any] = arrayExpression.dataType match {
    case ArrayType(n: AtomicType, _) => n.ordering.asInstanceOf[Ordering[Any]]
    case ArrayType(a: ArrayType, _) => a.interpretedOrdering.asInstanceOf[Ordering[Any]]
    case ArrayType(s: StructType, _) => s.interpretedOrdering.asInstanceOf[Ordering[Any]]
  }

  // Ascending comparator; nulls are placed according to `nullOrder`.
  @transient private lazy val lt: Comparator[Any] = {
    // Captured locally so compare() does not go through the lazy-val accessor on every call.
    val ordering = elementOrdering
    new Comparator[Any]() {
      override def compare(o1: Any, o2: Any): Int = {
        if (o1 == null && o2 == null) {
          0
        } else if (o1 == null) {
          nullOrder
        } else if (o2 == null) {
          -nullOrder
        } else {
          ordering.compare(o1, o2)
        }
      }
    }
  }

  // Descending comparator: the exact reverse of `lt`, including the placement of nulls.
  @transient private lazy val gt: Comparator[Any] = {
    val ordering = elementOrdering
    new Comparator[Any]() {
      override def compare(o1: Any, o2: Any): Int = {
        if (o1 == null && o2 == null) {
          0
        } else if (o1 == null) {
          -nullOrder
        } else if (o2 == null) {
          nullOrder
        } else {
          ordering.compare(o2, o1)
        }
      }
    }
  }

  @transient lazy val elementType: DataType =
    arrayExpression.dataType.asInstanceOf[ArrayType].elementType

  def containsNull: Boolean = arrayExpression.dataType.asInstanceOf[ArrayType].containsNull

  /** Interpreted sort: returns a new GenericArrayData holding the sorted elements. */
  def sortEval(array: Any, ascending: Boolean): Any = {
    val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
    // Sorting is skipped for NullType elements (every element is null).
    if (elementType != NullType) {
      java.util.Arrays.sort(data, if (ascending) lt else gt)
    }
    new GenericArrayData(data.asInstanceOf[Array[Any]])
  }

  /**
   * Returns Java code that sorts the ArrayData held in `base` and assigns the result to
   * `ev.value`. `order` is a Java boolean expression: true for ascending, false for descending.
   */
  def sortCodegen(ctx: CodegenContext, ev: ExprCode, base: String, order: String): String = {
    val genericArrayData = classOf[GenericArrayData].getName
    val unsafeArrayData = classOf[UnsafeArrayData].getName
    val array = ctx.freshName("array")
    val c = ctx.freshName("c")
    if (elementType == NullType) {
      s"${ev.value} = $base.copy();"
    } else {
      val elementTypeTerm = ctx.addReferenceObj("elementTypeTerm", elementType)
      val sortOrder = ctx.freshName("sortOrder")
      val o1 = ctx.freshName("o1")
      val o2 = ctx.freshName("o2")
      val jt = CodeGenerator.javaType(elementType)
      // Snippet leaving the comparison of two non-null boxed elements in `c`; primitives are
      // unboxed first so genComp can compare them directly.
      val comp = if (CodeGenerator.isPrimitiveType(elementType)) {
        val bt = CodeGenerator.boxedType(elementType)
        val v1 = ctx.freshName("v1")
        val v2 = ctx.freshName("v2")
        s"""
           |$jt $v1 = (($bt) $o1).${jt}Value();
           |$jt $v2 = (($bt) $o2).${jt}Value();
           |int $c = ${ctx.genComp(elementType, v1, v2)};
         """.stripMargin
      } else {
        s"int $c = ${ctx.genComp(elementType, s"(($jt) $o1)", s"(($jt) $o2)")};"
      }
      // Fast path for ascending sorts of null-free primitive arrays: java.util.Arrays.sort on a
      // primitive array. Booleans are excluded because there is no boolean[] overload.
      val canPerformFastSort =
        CodeGenerator.isPrimitiveType(elementType) && elementType != BooleanType && !containsNull
      val nonNullPrimitiveAscendingSort = if (canPerformFastSort) {
        val javaType = CodeGenerator.javaType(elementType)
        val primitiveTypeName = CodeGenerator.primitiveTypeName(elementType)
        s"""
           |if ($order) {
           | $javaType[] $array = $base.to${primitiveTypeName}Array();
           | java.util.Arrays.sort($array);
           | ${ev.value} = $unsafeArrayData.fromPrimitiveArray($array);
           |} else
         """.stripMargin
      } else {
        ""
      }
      // General path: sort boxed objects with a comparator whose sign is flipped for
      // descending order; nulls are placed according to `nullOrder`.
      s"""
         |$nonNullPrimitiveAscendingSort
         |{
         | Object[] $array = $base.toObjectArray($elementTypeTerm);
         | final int $sortOrder = $order ? 1 : -1;
         | java.util.Arrays.sort($array, new java.util.Comparator() {
         | @Override public int compare(Object $o1, Object $o2) {
         | if ($o1 == null && $o2 == null) {
         | return 0;
         | } else if ($o1 == null) {
         | return $sortOrder * $nullOrder;
         | } else if ($o2 == null) {
         | return -$sortOrder * $nullOrder;
         | }
         | $comp
         | return $sortOrder * $c;
         | }
         | });
         | ${ev.value} = new $genericArrayData($array);
         |}
       """.stripMargin
    }
  }
}
object ArraySortLike {

  /**
   * Result returned when a null element is compared, as the first argument, against a non-null
   * one in an ascending sort. Negative sorts nulls first; positive sorts them last.
   */
  type NullOrder = Int

  object NullOrder {
    /** Nulls come first in an ascending sort. */
    val Least: NullOrder = -1

    /** Nulls come last in an ascending sort. */
    val Greatest: NullOrder = 1
  }
}
/**
 * Sorts the input array by the natural ordering of its elements, ascending or descending as
 * selected by the boolean literal given as the second argument, and returns the sorted array.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """
_FUNC_(array[, ascendingOrder]) - Sorts the input array in ascending or descending order
according to the natural ordering of the array elements. Null elements will be placed
at the beginning of the returned array in ascending order or at the end of the returned
array in descending order.
""",
  examples = """
Examples:
> SELECT _FUNC_(array('b', 'd', null, 'c', 'a'), true);
[null,"a","b","c","d"]
""")
// scalastyle:on line.size.limit
case class SortArray(base: Expression, ascendingOrder: Expression)
  extends BinaryExpression with ArraySortLike {

  /** One-argument form: sorts in ascending order. */
  def this(e: Expression) = this(e, Literal(true))

  override def prettyName: String = "sort_array"

  override def left: Expression = base
  override def right: Expression = ascendingOrder
  override def arrayExpression: Expression = base

  override def dataType: DataType = base.dataType
  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, BooleanType)

  // Nulls rank below every value: first when ascending, last when descending.
  override def nullOrder: NullOrder = NullOrder.Least

  override def checkInputDataTypes(): TypeCheckResult = base.dataType match {
    case ArrayType(dt, _) if !RowOrdering.isOrderable(dt) =>
      TypeCheckResult.TypeCheckFailure(s"$prettyName does not support sorting array of type " +
        s"${dt.catalogString} which is not orderable")
    case _: ArrayType =>
      // The sort direction has to be supplied as a boolean literal.
      ascendingOrder match {
        case Literal(_: Boolean, BooleanType) => TypeCheckResult.TypeCheckSuccess
        case _ => TypeCheckResult.TypeCheckFailure(
          "Sort order in second argument requires a boolean literal.")
      }
    case _ =>
      TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
  }

  override def nullSafeEval(array: Any, ascending: Any): Any =
    sortEval(array, ascending.asInstanceOf[Boolean])

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    nullSafeCodeGen(ctx, ev, (arr, asc) => sortCodegen(ctx, ev, arr, asc))
}
/**
 * Sorts the input array in ascending order by the natural ordering of its elements and returns
 * it, with null elements placed last.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = """
_FUNC_(array) - Sorts the input array in ascending order. The elements of the input array must
be orderable. Null elements will be placed at the end of the returned array.
""",
  examples = """
Examples:
> SELECT _FUNC_(array('b', 'd', null, 'c', 'a'));
["a","b","c","d",null]
""",
  since = "2.4.0")
// scalastyle:on line.size.limit
case class ArraySort(child: Expression) extends UnaryExpression with ArraySortLike {

  override def prettyName: String = "array_sort"

  override def arrayExpression: Expression = child
  override def dataType: DataType = child.dataType
  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

  // Nulls rank above every value, so the (always ascending) sort puts them last.
  override def nullOrder: NullOrder = NullOrder.Greatest

  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
    case ArrayType(dt, _) if !RowOrdering.isOrderable(dt) =>
      TypeCheckResult.TypeCheckFailure(s"$prettyName does not support sorting array of type " +
        s"${dt.catalogString} which is not orderable")
    case _: ArrayType =>
      TypeCheckResult.TypeCheckSuccess
    case _ =>
      TypeCheckResult.TypeCheckFailure(s"$prettyName only supports array input.")
  }

  // There is no direction argument: array_sort always sorts ascending.
  override def nullSafeEval(array: Any): Any = sortEval(array, true)

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    nullSafeCodeGen(ctx, ev, arr => sortCodegen(ctx, ev, arr, "true"))
}
/**
 * Returns a random permutation of the given array. The permutation comes from a generator
 * seeded with `randomSeed` plus the partition index.
 */
@ExpressionDescription(
  usage = "_FUNC_(array) - Returns a random permutation of the given array.",
  examples = """
Examples:
> SELECT _FUNC_(array(1, 20, 3, 5));
[3,1,5,20]
> SELECT _FUNC_(array(1, 20, null, 3));
[20,null,3,1]
""",
  note = "The function is non-deterministic.",
  since = "2.4.0")
case class Shuffle(child: Expression, randomSeed: Option[Long] = None)
  extends UnaryExpression with ExpectsInputTypes with Stateful with ExpressionWithRandomSeed {

  def this(child: Expression) = this(child, None)

  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)
  override def dataType: DataType = child.dataType

  @transient lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

  // The expression is not resolved until a seed is present.
  override lazy val resolved: Boolean =
    childrenResolved && checkInputDataTypes().isSuccess && randomSeed.isDefined

  override def withNewSeed(seed: Long): Shuffle = copy(randomSeed = Some(seed))

  override def freshCopy(): Shuffle = Shuffle(child, randomSeed)

  // Interpreted-mode generator, created per partition in initializeInternal.
  @transient private[this] var random: RandomIndicesGenerator = _

  override protected def initializeInternal(partitionIndex: Int): Unit = {
    random = RandomIndicesGenerator(randomSeed.get + partitionIndex)
  }

  override protected def evalInternal(input: InternalRow): Any = child.eval(input) match {
    case null => null
    case value =>
      val source = value.asInstanceOf[ArrayData]
      val permutation = random.getNextIndices(source.numElements())
      new GenericArrayData(permutation.map(idx => source.get(idx, elementType)))
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    nullSafeCodeGen(ctx, ev, arr => shuffleArrayCodeGen(ctx, ev, arr))

  /** Emits Java code copying `inputArray` into a new array in a randomly permuted order. */
  private def shuffleArrayCodeGen(
      ctx: CodegenContext,
      ev: ExprCode,
      inputArray: String): String = {
    val generatorClass = classOf[RandomIndicesGenerator].getName
    val rand = ctx.addMutableState(generatorClass, "rand", forceInline = true)
    // (Re)create the generator at partition start from randomSeed + partitionIndex.
    ctx.addPartitionInitializationStatement(
      s"$rand = new $generatorClass(${randomSeed.get}L + partitionIndex);")
    val numElements = ctx.freshName("numElements")
    val arrayData = ctx.freshName("arrayData")
    val indices = ctx.freshName("indices")
    val i = ctx.freshName("i")
    val containsNull = dataType.asInstanceOf[ArrayType].containsNull
    val initialization = CodeGenerator.createArrayData(
      arrayData, elementType, numElements, s" $prettyName failed.")
    val assignment = CodeGenerator.createArrayAssignment(
      arrayData, elementType, inputArray, i, s"$indices[$i]", containsNull)
    s"""
       |int $numElements = $inputArray.numElements();
       |int[] $indices = $rand.getNextIndices($numElements);
       |$initialization
       |for (int $i = 0; $i < $numElements; $i++) {
       |  $assignment
       |}
       |${ev.value} = $arrayData;
     """.stripMargin
  }
}
/**
 * Reverses a string (character order) or an array (element order).
 */
@ExpressionDescription(
  usage = "_FUNC_(array) - Returns a reversed string or an array with reverse order of elements.",
  examples = """
Examples:
> SELECT _FUNC_('Spark SQL');
LQS krapS
> SELECT _FUNC_(array(2, 1, 4, 3));
[3,4,1,2]
""",
  since = "1.5.0",
  note = "Reverse logic for arrays is available since 2.4.0."
)
case class Reverse(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

  // Input types are utilized by type coercion in ImplicitTypeCasts.
  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(StringType, ArrayType))

  override def dataType: DataType = child.dataType

  override def prettyName: String = "reverse"

  override def nullSafeEval(input: Any): Any = doReverse(input)

  // Interpreted reversal, chosen once from the (string or array) data type.
  @transient private lazy val doReverse: Any => Any = dataType match {
    case StringType =>
      (input: Any) => input.asInstanceOf[UTF8String].reverse()
    case ArrayType(elementType, _) =>
      (input: Any) =>
        new GenericArrayData(input.asInstanceOf[ArrayData].toObjectArray(elementType).reverse)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    nullSafeCodeGen(ctx, ev, input => dataType match {
      case _: ArrayType => arrayCodeGen(ctx, ev, input)
      case _: StringType => stringCodeGen(ev, input)
    })

  private def stringCodeGen(ev: ExprCode, childName: String): String =
    s"${ev.value} = ($childName).reverse();"

  /** Emits Java code that copies the input array into a new array, last element first. */
  private def arrayCodeGen(ctx: CodegenContext, ev: ExprCode, childName: String): String = {
    val arrayType = dataType.asInstanceOf[ArrayType]
    val numElements = ctx.freshName("numElements")
    val arrayData = ctx.freshName("arrayData")
    val i = ctx.freshName("i")
    val j = ctx.freshName("j")
    val initialization = CodeGenerator.createArrayData(
      arrayData, arrayType.elementType, numElements, s" $prettyName failed.")
    val assignment = CodeGenerator.createArrayAssignment(
      arrayData, arrayType.elementType, childName, i, j, arrayType.containsNull)
    s"""
       |final int $numElements = $childName.numElements();
       |$initialization
       |for (int $i = 0; $i < $numElements; $i++) {
       |  int $j = $numElements - $i - 1;
       |  $assignment
       |}
       |${ev.value} = $arrayData;
     """.stripMargin
  }
}
/**
 * Checks whether the array (left) contains the value (right). When no element matches but the
 * array holds a null element, the result is null rather than false.
 */
@ExpressionDescription(
  usage = "_FUNC_(array, value) - Returns true if the array contains the value.",
  examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), 2);
true
""")
case class ArrayContains(left: Expression, right: Expression)
  extends BinaryExpression with ImplicitCastInputTypes {

  override def prettyName: String = "array_contains"

  override def dataType: DataType = BooleanType

  // Null when either input is null, or when nothing matches and the array holds a null.
  override def nullable: Boolean =
    left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull

  @transient private lazy val ordering: Ordering[Any] =
    TypeUtils.getInterpretedOrdering(right.dataType)

  // Coerce the element type and the value to their tightest common type, when one exists.
  override def inputTypes: Seq[AbstractDataType] = (left.dataType, right.dataType) match {
    case (_, NullType) => Seq.empty
    case (ArrayType(e1, hasNull), e2) =>
      TypeCoercion.findTightestCommonType(e1, e2)
        .map(dt => Seq(ArrayType(dt, hasNull), dt))
        .getOrElse(Seq.empty)
    case _ => Seq.empty
  }

  override def checkInputDataTypes(): TypeCheckResult = (left.dataType, right.dataType) match {
    case (_, NullType) =>
      TypeCheckResult.TypeCheckFailure("Null typed values cannot be used as arguments")
    case (ArrayType(e1, _), e2) if e1.sameType(e2) =>
      TypeUtils.checkForOrderingExpr(e2, s"function $prettyName")
    case _ =>
      TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
        s"been ${ArrayType.simpleString} followed by a value with same element type, but it's " +
        s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
  }

  override def nullSafeEval(arr: Any, value: Any): Any = {
    var sawNullElement = false
    arr.asInstanceOf[ArrayData].foreach(right.dataType, (_, elem) =>
      if (elem != null) {
        if (ordering.equiv(elem, value)) return true
      } else {
        sawNullElement = true
      })
    // No match: unknown (null) if any element was null, otherwise definitely false.
    if (sawNullElement) null else false
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    nullSafeCodeGen(ctx, ev, (arr, value) => {
      val idx = ctx.freshName("i")
      val element = CodeGenerator.getValue(arr, right.dataType, idx)
      val isMatch = ctx.genEqual(right.dataType, value, element)
      // A null element marks the result null; a later match clears that again.
      val loopBody = if (!nullable) {
        s"""
           |if ($isMatch) {
           |  ${ev.value} = true;
           |  break;
           |}
         """.stripMargin
      } else {
        s"""
           |if ($arr.isNullAt($idx)) {
           |  ${ev.isNull} = true;
           |} else if ($isMatch) {
           |  ${ev.isNull} = false;
           |  ${ev.value} = true;
           |  break;
           |}
         """.stripMargin
      }
      s"""
         |for (int $idx = 0; $idx < $arr.numElements(); $idx ++) {
         |  $loopBody
         |}
       """.stripMargin
    })
}
/**
 * Checks if the two arrays contain at least one common element.
 *
 * Returns true when a non-null element of one array equals an element of the other. If there is
 * no common element, the result is null when both arrays are non-empty and either of them
 * contains a null element, and false otherwise.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(a1, a2) - Returns true if a1 contains at least a non-null element present also in a2. If the arrays have no common element and they are both non-empty and either of them contains a null element null is returned, false otherwise.",
  examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), array(3, 4, 5));
true
""", since = "2.4.0")
// scalastyle:on line.size.limit
case class ArraysOverlap(left: Expression, right: Expression)
  extends BinaryArrayExpressionWithImplicitCast {

  override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
    case TypeCheckResult.TypeCheckSuccess =>
      TypeUtils.checkForOrderingExpr(elementType, s"function $prettyName")
    case failure => failure
  }

  @transient private lazy val ordering: Ordering[Any] =
    TypeUtils.getInterpretedOrdering(elementType)

  // The hash-set based evaluation is only used for element types with a reliable equals.
  @transient private lazy val doEvaluation = if (TypeUtils.typeWithProperEquals(elementType)) {
    fastEval _
  } else {
    bruteForceEval _
  }

  override def dataType: DataType = BooleanType

  override def nullable: Boolean = {
    left.nullable || right.nullable || left.dataType.asInstanceOf[ArrayType].containsNull ||
      right.dataType.asInstanceOf[ArrayType].containsNull
  }

  override def nullSafeEval(a1: Any, a2: Any): Any = {
    doEvaluation(a1.asInstanceOf[ArrayData], a2.asInstanceOf[ArrayData])
  }

  /**
   * A fast implementation which puts all the elements from the smaller array in a set
   * and then performs a lookup on it for each element of the bigger one.
   * This eval mode works only for data types which properly implement the equals method.
   */
  private def fastEval(arr1: ArrayData, arr2: ArrayData): Any = {
    var hasNull = false
    val (bigger, smaller) = if (arr1.numElements() > arr2.numElements()) {
      (arr1, arr2)
    } else {
      (arr2, arr1)
    }
    if (smaller.numElements() > 0) {
      val smallestSet = new mutable.HashSet[Any]
      smaller.foreach(elementType, (_, v) =>
        if (v == null) {
          hasNull = true
        } else {
          smallestSet += v
        })
      bigger.foreach(elementType, (_, v1) =>
        if (v1 == null) {
          hasNull = true
        } else if (smallestSet.contains(v1)) {
          return true
        }
      )
    }
    if (hasNull) {
      null
    } else {
      false
    }
  }

  /**
   * A slower evaluation which performs a nested loop and supports all the data types.
   */
  private def bruteForceEval(arr1: ArrayData, arr2: ArrayData): Any = {
    var hasNull = false
    if (arr1.numElements() > 0 && arr2.numElements() > 0) {
      arr1.foreach(elementType, (_, v1) =>
        if (v1 == null) {
          hasNull = true
        } else {
          arr2.foreach(elementType, (_, v2) =>
            if (v2 == null) {
              hasNull = true
            } else if (ordering.equiv(v1, v2)) {
              return true
            }
          )
        })
    }
    if (hasNull) {
      null
    } else {
      false
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, (a1, a2) => {
      val smaller = ctx.freshName("smallerArray")
      val bigger = ctx.freshName("biggerArray")
      val comparisonCode = if (TypeUtils.typeWithProperEquals(elementType)) {
        fastCodegen(ctx, ev, smaller, bigger)
      } else {
        bruteForceCodegen(ctx, ev, smaller, bigger)
      }
      s"""
         |ArrayData $smaller;
         |ArrayData $bigger;
         |if ($a1.numElements() > $a2.numElements()) {
         |  $bigger = $a1;
         |  $smaller = $a2;
         |} else {
         |  $smaller = $a1;
         |  $bigger = $a2;
         |}
         |if ($smaller.numElements() > 0) {
         |  $comparisonCode
         |}
       """.stripMargin
    })
  }

  /**
   * Code generation for a fast implementation which puts all the elements from the smaller array
   * in a set and then performs a lookup on it for each element of the bigger one.
   * It works only for data types which properly implement the equals method.
   */
  private def fastCodegen(
      ctx: CodegenContext,
      ev: ExprCode,
      smaller: String,
      bigger: String): String = {
    val i = ctx.freshName("i")
    val getFromSmaller = CodeGenerator.getValue(smaller, elementType, i)
    val getFromBigger = CodeGenerator.getValue(bigger, elementType, i)
    val javaElementClass = CodeGenerator.boxedType(elementType)
    val javaSet = classOf[java.util.HashSet[_]].getName
    val set = ctx.freshName("set")
    val addToSetFromSmallerCode = nullSafeElementCodegen(
      smaller, i, s"$set.add($getFromSmaller);", s"${ev.isNull} = true;")
    val setIsNullCode = if (nullable) s"${ev.isNull} = false;" else ""
    val elementIsInSetCode = nullSafeElementCodegen(
      bigger,
      i,
      s"""
         |if ($set.contains($getFromBigger)) {
         |  $setIsNullCode
         |  ${ev.value} = true;
         |  break;
         |}
       """.stripMargin,
      s"${ev.isNull} = true;")
    s"""
       |$javaSet<$javaElementClass> $set = new $javaSet<$javaElementClass>();
       |for (int $i = 0; $i < $smaller.numElements(); $i ++) {
       |  $addToSetFromSmallerCode
       |}
       |for (int $i = 0; $i < $bigger.numElements(); $i ++) {
       |  $elementIsInSetCode
       |}
     """.stripMargin
  }

  /**
   * Code generation for a slower evaluation which performs a nested loop and supports all the
   * data types.
   */
  private def bruteForceCodegen(
      ctx: CodegenContext,
      ev: ExprCode,
      smaller: String,
      bigger: String): String = {
    val i = ctx.freshName("i")
    val j = ctx.freshName("j")
    val getFromSmaller = CodeGenerator.getValue(smaller, elementType, j)
    val getFromBigger = CodeGenerator.getValue(bigger, elementType, i)
    val setIsNullCode = if (nullable) s"${ev.isNull} = false;" else ""
    val compareValues = nullSafeElementCodegen(
      smaller,
      j,
      s"""
         |if (${ctx.genEqual(elementType, getFromSmaller, getFromBigger)}) {
         |  $setIsNullCode
         |  ${ev.value} = true;
         |}
       """.stripMargin,
      s"${ev.isNull} = true;")
    val isInSmaller = nullSafeElementCodegen(
      bigger,
      i,
      s"""
         |for (int $j = 0; $j < $smaller.numElements() && !${ev.value}; $j ++) {
         |  $compareValues
         |}
       """.stripMargin,
      s"${ev.isNull} = true;")
    s"""
       |for (int $i = 0; $i < $bigger.numElements() && !${ev.value}; $i ++) {
       |  $isInSmaller
       |}
     """.stripMargin
  }

  /**
   * Wraps `code` so that it only runs when element `index` of `arrayVar` is not null, running
   * `isNullCode` instead when it is. The null check is omitted entirely when none of the input
   * array types can contain nulls.
   */
  def nullSafeElementCodegen(
      arrayVar: String,
      index: String,
      code: String,
      isNullCode: String): String = {
    if (inputTypes.exists(_.asInstanceOf[ArrayType].containsNull)) {
      s"""
         |if ($arrayVar.isNullAt($index)) {
         |  $isNullCode
         |} else {
         |  $code
         |}
       """.stripMargin
    } else {
      code
    }
  }

  override def prettyName: String = "arrays_overlap"
}
/**
 * Slices an array according to the requested start index and length.
 *
 * `start` is 1-based; a negative `start` counts from the end of the array. A `start` of 0 or a
 * negative `length` raises a RuntimeException. A `start` outside the array yields an empty array,
 * and a `length` running past the end is truncated to the remaining elements.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "_FUNC_(x, start, length) - Subsets array x starting from index start (or starting from the end if start is negative) with the specified length.",
  examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3, 4), 2, 2);
[2,3]
> SELECT _FUNC_(array(1, 2, 3, 4), -2, 2);
[3,4]
""", since = "2.4.0")
// scalastyle:on line.size.limit
case class Slice(x: Expression, start: Expression, length: Expression)
  extends TernaryExpression with ImplicitCastInputTypes {

  override def dataType: DataType = x.dataType

  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType, IntegerType, IntegerType)

  // called from eval
  @transient override lazy val children: Seq[Expression] = Seq(x, start, length)

  @transient private lazy val elementType: DataType =
    x.dataType.asInstanceOf[ArrayType].elementType

  override def nullSafeEval(xVal: Any, startVal: Any, lengthVal: Any): Any = {
    val startInt = startVal.asInstanceOf[Int]
    val lengthInt = lengthVal.asInstanceOf[Int]
    val arr = xVal.asInstanceOf[ArrayData]
    val numElements = arr.numElements()
    val startIndex = if (startInt == 0) {
      throw new RuntimeException(
        s"Unexpected value for start in function $prettyName: SQL array indices start at 1.")
    } else if (startInt < 0) {
      startInt + numElements
    } else {
      startInt - 1
    }
    if (lengthInt < 0) {
      throw new RuntimeException(s"Unexpected value for length in function $prettyName: " +
        "length must be greater than or equal to 0.")
    }
    // startIndex can be negative if start is negative and its absolute value is greater than the
    // number of elements in the array
    if (startIndex < 0 || startIndex >= numElements) {
      new GenericArrayData(Array.empty[AnyRef])
    } else {
      // Clamp the length to the elements remaining after startIndex, as the generated code does.
      // Computing `startIndex + lengthInt` directly overflows for large lengths (e.g.
      // Int.MaxValue), which made the slice come back empty.
      val resLength = math.min(lengthInt, numElements - startIndex)
      val data = arr.toSeq[AnyRef](elementType)
      new GenericArrayData(data.slice(startIndex, startIndex + resLength))
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, (x, start, length) => {
      val startIdx = ctx.freshName("startIdx")
      val resLength = ctx.freshName("resLength")
      val defaultIntValue = CodeGenerator.defaultValue(CodeGenerator.JAVA_INT, false)
      s"""
         |${CodeGenerator.JAVA_INT} $startIdx = $defaultIntValue;
         |${CodeGenerator.JAVA_INT} $resLength = $defaultIntValue;
         |if ($start == 0) {
         |  throw new RuntimeException("Unexpected value for start in function $prettyName: "
         |    + "SQL array indices start at 1.");
         |} else if ($start < 0) {
         |  $startIdx = $start + $x.numElements();
         |} else {
         |  // arrays in SQL are 1-based instead of 0-based
         |  $startIdx = $start - 1;
         |}
         |if ($length < 0) {
         |  throw new RuntimeException("Unexpected value for length in function $prettyName: "
         |    + "length must be greater than or equal to 0.");
         |} else if ($length > $x.numElements() - $startIdx) {
         |  $resLength = $x.numElements() - $startIdx;
         |} else {
         |  $resLength = $length;
         |}
         |${genCodeForResult(ctx, ev, x, startIdx, resLength)}
       """.stripMargin
    })
  }

  /**
   * Emits Java code that copies `resLength` elements of `inputArray`, starting at the 0-based
   * `startIdx`, into a new array assigned to `ev.value`. It produces an empty array when
   * `startIdx` lies outside the input.
   */
  def genCodeForResult(
      ctx: CodegenContext,
      ev: ExprCode,
      inputArray: String,
      startIdx: String,
      resLength: String): String = {
    val values = ctx.freshName("values")
    val i = ctx.freshName("i")
    val genericArrayData = classOf[GenericArrayData].getName
    val allocation = CodeGenerator.createArrayData(
      values, elementType, resLength, s" $prettyName failed.")
    val assignment = CodeGenerator.createArrayAssignment(values, elementType, inputArray,
      i, s"$i + $startIdx", dataType.asInstanceOf[ArrayType].containsNull)
    s"""
       |if ($startIdx < 0 || $startIdx >= $inputArray.numElements()) {
       |  ${ev.value} = new $genericArrayData(new Object[0]);
       |} else {
       |  $allocation
       |  for (int $i = 0; $i < $resLength; $i ++) {
       |    $assignment
       |  }
       |  ${ev.value} = $values;
       |}
     """.stripMargin
  }
}
/**
 * Creates a String containing all the elements of the input array separated by the delimiter.
 * Null elements are skipped unless `nullReplacement` is given, in which case the replacement is
 * written in their place. The result is null if the array, the delimiter or the replacement
 * evaluates to null.
 */
@ExpressionDescription(
  usage = """
_FUNC_(array, delimiter[, nullReplacement]) - Concatenates the elements of the given array
using the delimiter and an optional string to replace nulls. If no value is set for
nullReplacement, any null value is filtered.""",
  examples = """
Examples:
> SELECT _FUNC_(array('hello', 'world'), ' ');
hello world
> SELECT _FUNC_(array('hello', null ,'world'), ' ');
hello world
> SELECT _FUNC_(array('hello', null ,'world'), ' ', ',');
hello , world
""", since = "2.4.0")
case class ArrayJoin(
    array: Expression,
    delimiter: Expression,
    nullReplacement: Option[Expression]) extends Expression with ExpectsInputTypes {

  def this(array: Expression, delimiter: Expression) = this(array, delimiter, None)

  def this(array: Expression, delimiter: Expression, nullReplacement: Expression) =
    this(array, delimiter, Some(nullReplacement))

  override def inputTypes: Seq[AbstractDataType] = if (nullReplacement.isDefined) {
    Seq(ArrayType(StringType), StringType, StringType)
  } else {
    Seq(ArrayType(StringType), StringType)
  }

  override def children: Seq[Expression] = if (nullReplacement.isDefined) {
    Seq(array, delimiter, nullReplacement.get)
  } else {
    Seq(array, delimiter)
  }

  override def nullable: Boolean = children.exists(_.nullable)

  override def foldable: Boolean = children.forall(_.foldable)

  override def eval(input: InternalRow): Any = {
    val arrayEval = array.eval(input)
    if (arrayEval == null) return null
    val delimiterEval = delimiter.eval(input)
    if (delimiterEval == null) return null
    val nullReplacementEval = nullReplacement.map(_.eval(input))
    if (nullReplacementEval.contains(null)) return null

    val buffer = new UTF8StringBuilder()
    var firstItem = true
    // Handles a null element; returns true if it appended something to the buffer.
    val nullHandling = nullReplacementEval match {
      case Some(rep) => (isFirstItem: Boolean) => {
        if (!isFirstItem) {
          buffer.append(delimiterEval.asInstanceOf[UTF8String])
        }
        buffer.append(rep.asInstanceOf[UTF8String])
        true
      }
      case None => (_: Boolean) => false
    }
    arrayEval.asInstanceOf[ArrayData].foreach(StringType, (_, item) => {
      if (item == null) {
        if (nullHandling(firstItem)) {
          firstItem = false
        }
      } else {
        if (!firstItem) {
          buffer.append(delimiterEval.asInstanceOf[UTF8String])
        }
        buffer.append(item.asInstanceOf[UTF8String])
        firstItem = false
      }
    })
    buffer.build()
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val code = nullReplacement match {
      case Some(replacement) =>
        val replacementGen = replacement.genCode(ctx)
        val nullHandling = (buffer: String, delimiter: String, firstItem: String) => {
          s"""
             |if (!$firstItem) {
             |  $buffer.append($delimiter);
             |}
             |$buffer.append(${replacementGen.value});
             |$firstItem = false;
           """.stripMargin
        }
        val execCode = if (replacement.nullable) {
          ctx.nullSafeExec(replacement.nullable, replacementGen.isNull) {
            genCodeForArrayAndDelimiter(ctx, ev, nullHandling)
          }
        } else {
          genCodeForArrayAndDelimiter(ctx, ev, nullHandling)
        }
        s"""
           |${replacementGen.code}
           |$execCode
         """.stripMargin
      case None => genCodeForArrayAndDelimiter(ctx, ev,
        (_: String, _: String, _: String) => "// nulls are ignored")
    }
    if (nullable) {
      ev.copy(
        code"""
           |boolean ${ev.isNull} = true;
           |UTF8String ${ev.value} = null;
           |$code
         """.stripMargin)
    } else {
      ev.copy(
        code"""
           |UTF8String ${ev.value} = null;
           |$code
         """.stripMargin, FalseLiteral)
    }
  }

  /**
   * Emits Java code that evaluates the array and the delimiter and, when both are non-null,
   * joins the array elements into `ev.value` and marks the result as non-null.
   *
   * @param nullEval builds the code run for a null element, given the buffer, delimiter and
   *                 first-item flag variable names
   */
  private def genCodeForArrayAndDelimiter(
      ctx: CodegenContext,
      ev: ExprCode,
      nullEval: (String, String, String) => String): String = {
    val arrayGen = array.genCode(ctx)
    val delimiterGen = delimiter.genCode(ctx)
    val buffer = ctx.freshName("buffer")
    val bufferClass = classOf[UTF8StringBuilder].getName
    val i = ctx.freshName("i")
    val firstItem = ctx.freshName("firstItem")
    val resultCode =
      s"""
         |$bufferClass $buffer = new $bufferClass();
         |boolean $firstItem = true;
         |for (int $i = 0; $i < ${arrayGen.value}.numElements(); $i ++) {
         |  if (${arrayGen.value}.isNullAt($i)) {
         |    ${nullEval(buffer, delimiterGen.value, firstItem)}
         |  } else {
         |    if (!$firstItem) {
         |      $buffer.append(${delimiterGen.value});
         |    }
         |    $buffer.append(${CodeGenerator.getValue(arrayGen.value, StringType, i)});
         |    $firstItem = false;
         |  }
         |}
         |${ev.value} = $buffer.build();""".stripMargin

    if (array.nullable || delimiter.nullable) {
      arrayGen.code + ctx.nullSafeExec(array.nullable, arrayGen.isNull) {
        delimiterGen.code + ctx.nullSafeExec(delimiter.nullable, delimiterGen.isNull) {
          s"""
             |${ev.isNull} = false;
             |$resultCode""".stripMargin
        }
      }
    } else {
      // The expression can still be nullable here solely because of a nullable
      // `nullReplacement`. doGenCode then initializes `isNull` to true and only reaches this
      // code once the replacement is known to be non-null, so the flag must be cleared.
      val setNotNull = if (nullable) s"${ev.isNull} = false;" else ""
      s"""
         |${arrayGen.code}
         |${delimiterGen.code}
         |$setNotNull
         |$resultCode""".stripMargin
    }
  }

  override def dataType: DataType = StringType

  override def prettyName: String = "array_join"
}
/**
 * Returns the smallest non-null element of the array, or null when there is none.
 */
@ExpressionDescription(
  usage = "_FUNC_(array) - Returns the minimum value in the array. NULL elements are skipped.",
  examples = """
Examples:
> SELECT _FUNC_(array(1, 20, null, 3));
1
""", since = "2.4.0")
case class ArrayMin(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

  override def prettyName: String = "array_min"

  // The result type is the element type of the input array.
  @transient override lazy val dataType: DataType = child.dataType match {
    case ArrayType(dt, _) => dt
    case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
  }

  // Null for an empty array or one holding only nulls, even if the input itself is non-null.
  override def nullable: Boolean = true

  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

  @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

  override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
    case TypeCheckResult.TypeCheckSuccess =>
      TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
    case failure => failure
  }

  override protected def nullSafeEval(input: Any): Any = {
    var smallest: Any = null
    input.asInstanceOf[ArrayData].foreach(dataType, (_, elem) =>
      if (elem != null && (smallest == null || ordering.lt(elem, smallest))) {
        smallest = elem
      })
    smallest
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childGen = child.genCode(ctx)
    val javaType = CodeGenerator.javaType(dataType)
    val idx = ctx.freshName("i")
    // The current element, wrapped so it can be compared against the running minimum.
    val elementValue = CodeGenerator.getValue(childGen.value, dataType, idx)
    val current = ExprCode(
      EmptyBlock,
      isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($idx)"),
      value = JavaCode.expression(elementValue, dataType))
    ev.copy(code =
      code"""
         |${childGen.code}
         |boolean ${ev.isNull} = true;
         |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
         |if (!${childGen.isNull}) {
         |  for (int $idx = 0; $idx < ${childGen.value}.numElements(); $idx ++) {
         |    ${ctx.reassignIfSmaller(dataType, ev, current)}
         |  }
         |}
       """.stripMargin)
  }
}
/**
 * Returns the largest non-null element of the array, or null when there is none.
 */
@ExpressionDescription(
  usage = "_FUNC_(array) - Returns the maximum value in the array. NULL elements are skipped.",
  examples = """
Examples:
> SELECT _FUNC_(array(1, 20, null, 3));
20
""", since = "2.4.0")
case class ArrayMax(child: Expression) extends UnaryExpression with ImplicitCastInputTypes {

  override def prettyName: String = "array_max"

  // The result type is the element type of the input array.
  @transient override lazy val dataType: DataType = child.dataType match {
    case ArrayType(dt, _) => dt
    case _ => throw new IllegalStateException(s"$prettyName accepts only arrays.")
  }

  // Null for an empty array or one holding only nulls, even if the input itself is non-null.
  override def nullable: Boolean = true

  override def inputTypes: Seq[AbstractDataType] = Seq(ArrayType)

  @transient private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

  override def checkInputDataTypes(): TypeCheckResult = super.checkInputDataTypes() match {
    case TypeCheckResult.TypeCheckSuccess =>
      TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
    case failure => failure
  }

  override protected def nullSafeEval(input: Any): Any = {
    var largest: Any = null
    input.asInstanceOf[ArrayData].foreach(dataType, (_, elem) =>
      if (elem != null && (largest == null || ordering.gt(elem, largest))) {
        largest = elem
      })
    largest
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childGen = child.genCode(ctx)
    val javaType = CodeGenerator.javaType(dataType)
    val idx = ctx.freshName("i")
    // The current element, wrapped so it can be compared against the running maximum.
    val elementValue = CodeGenerator.getValue(childGen.value, dataType, idx)
    val current = ExprCode(
      EmptyBlock,
      isNull = JavaCode.isNullExpression(s"${childGen.value}.isNullAt($idx)"),
      value = JavaCode.expression(elementValue, dataType))
    ev.copy(code =
      code"""
         |${childGen.code}
         |boolean ${ev.isNull} = true;
         |$javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
         |if (!${childGen.isNull}) {
         |  for (int $idx = 0; $idx < ${childGen.value}.numElements(); $idx ++) {
         |    ${ctx.reassignIfGreater(dataType, ev, current)}
         |  }
         |}
       """.stripMargin)
  }
}
/**
 * Returns the position of the first occurrence of element in the given array as long.
 * Returns 0 if the given value could not be found in the array. Returns null if either of
 * the arguments is null. Null elements of the array never match.
 *
 * NOTE: this is not a zero-based but a 1-based index. The first element in the array has
 * index 1.
 */
@ExpressionDescription(
  usage = """
_FUNC_(array, element) - Returns the (1-based) index of the first occurrence of element in the
array as long, or 0 if element is not found.
""",
  examples = """
Examples:
> SELECT _FUNC_(array(3, 2, 1), 1);
3
""",
  since = "2.4.0")
case class ArrayPosition(left: Expression, right: Expression)
  extends BinaryExpression with ImplicitCastInputTypes {

  @transient private lazy val ordering: Ordering[Any] =
    TypeUtils.getInterpretedOrdering(right.dataType)

  override def dataType: DataType = LongType

  // Coerce the element type and the searched value to their tightest common type, if any.
  override def inputTypes: Seq[AbstractDataType] = {
    (left.dataType, right.dataType) match {
      case (ArrayType(e1, hasNull), e2) =>
        TypeCoercion.findTightestCommonType(e1, e2) match {
          case Some(dt) => Seq(ArrayType(dt, hasNull), dt)
          case _ => Seq.empty
        }
      case _ => Seq.empty
    }
  }

  override def checkInputDataTypes(): TypeCheckResult = {
    (left.dataType, right.dataType) match {
      case (ArrayType(e1, _), e2) if e1.sameType(e2) =>
        TypeUtils.checkForOrderingExpr(e2, s"function $prettyName")
      case _ => TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
        s"been ${ArrayType.simpleString} followed by a value with same element type, but it's " +
        s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
    }
  }

  override def nullSafeEval(arr: Any, value: Any): Any = {
    arr.asInstanceOf[ArrayData].foreach(right.dataType, (i, v) =>
      if (v != null && ordering.equiv(v, value)) {
        return (i + 1).toLong
      }
    )
    0L
  }

  override def prettyName: String = "array_position"

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, (arr, value) => {
      val pos = ctx.freshName("arrayPosition")
      val i = ctx.freshName("i")
      val getValue = CodeGenerator.getValue(arr, right.dataType, i)
      s"""
         |int $pos = 0;
         |for (int $i = 0; $i < $arr.numElements(); $i ++) {
         |  if (!$arr.isNullAt($i) && ${ctx.genEqual(right.dataType, value, getValue)}) {
         |    $pos = $i + 1;
         |    break;
         |  }
         |}
         |${ev.value} = (long) $pos;
       """.stripMargin
    })
  }
}
/**
 * Returns the element at 1-based index `right` of array `left` (negative indices count from
 * the end, index 0 is an error), or the value for key `right` in map `left`.
 * Out-of-range array indices and missing map keys yield null.
 */
@ExpressionDescription(
  usage = """
_FUNC_(array, index) - Returns element of array at given (1-based) index. If index < 0,
accesses elements from the last to the first. Returns NULL if the index exceeds the length
of the array.
_FUNC_(map, key) - Returns value for given key, or NULL if the key is not contained in the map
""",
  examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3), 2);
2
> SELECT _FUNC_(map(1, 'a', 2, 'b'), 2);
b
""",
  since = "2.4.0")
case class ElementAt(left: Expression, right: Expression) extends GetMapValueUtil {

  // Each of these is only valid for one kind of `left` input; being lazy, only the one that
  // matches the actual input type is ever evaluated.
  @transient private lazy val mapKeyType = left.dataType.asInstanceOf[MapType].keyType

  @transient private lazy val arrayContainsNull = left.dataType.asInstanceOf[ArrayType].containsNull

  @transient private lazy val ordering: Ordering[Any] = TypeUtils.getInterpretedOrdering(mapKeyType)

  @transient override lazy val dataType: DataType = left.dataType match {
    case ArrayType(elementType, _) => elementType
    case MapType(_, valueType, _) => valueType
  }

  /**
   * Arrays expect an Int index (any integral type other than Long is mapped to IntegerType);
   * maps expect a key coercible, together with the map's key type, to a common type.
   */
  override def inputTypes: Seq[AbstractDataType] = {
    (left.dataType, right.dataType) match {
      case (arr: ArrayType, e2: IntegralType) if (e2 != LongType) =>
        Seq(arr, IntegerType)
      case (MapType(keyType, valueType, hasNull), e2) =>
        TypeCoercion.findTightestCommonType(keyType, e2) match {
          case Some(dt) => Seq(MapType(dt, valueType, hasNull), dt)
          case _ => Seq.empty
        }
      case _ => Seq.empty
    }
  }

  override def checkInputDataTypes(): TypeCheckResult = {
    (left.dataType, right.dataType) match {
      case (_: ArrayType, e2) if e2 != IntegerType =>
        TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
          s"been ${ArrayType.simpleString} followed by a ${IntegerType.simpleString}, but it's " +
          s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
      case (MapType(e1, _, _), e2) if (!e2.sameType(e1)) =>
        TypeCheckResult.TypeCheckFailure(s"Input to function $prettyName should have " +
          s"been ${MapType.simpleString} followed by a value of same key type, but it's " +
          s"[${left.dataType.catalogString}, ${right.dataType.catalogString}].")
      case (e1, _) if (!e1.isInstanceOf[MapType] && !e1.isInstanceOf[ArrayType]) =>
        TypeCheckResult.TypeCheckFailure(s"The first argument to function $prettyName should " +
          s"have been ${ArrayType.simpleString} or ${MapType.simpleString} type, but its " +
          s"${left.dataType.catalogString} type.")
      case _ => TypeCheckResult.TypeCheckSuccess
    }
  }

  // Out-of-range indices and missing keys produce null even for non-null inputs.
  override def nullable: Boolean = true

  override def nullSafeEval(value: Any, ordinal: Any): Any = doElementAt(value, ordinal)

  @transient private lazy val doElementAt: (Any, Any) => Any = left.dataType match {
    case _: ArrayType =>
      (value, ordinal) => {
        val array = value.asInstanceOf[ArrayData]
        val index = ordinal.asInstanceOf[Int]
        // Compare as Long: math.abs(Int.MinValue) overflows back to Int.MinValue, which would
        // slip past this bounds check and produce a negative element offset below.
        if (array.numElements() < math.abs(index.toLong)) {
          null
        } else {
          val idx = if (index == 0) {
            throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
          } else if (index > 0) {
            index - 1
          } else {
            array.numElements() + index
          }
          if (arrayContainsNull && array.isNullAt(idx)) {
            null
          } else {
            array.get(idx, dataType)
          }
        }
      }
    case _: MapType =>
      (value, ordinal) => getValueEval(value, ordinal, mapKeyType, ordering)
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    left.dataType match {
      case _: ArrayType =>
        nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
          val index = ctx.freshName("elementAtIndex")
          // Only emit the per-element null check when the array type admits null elements.
          val nullCheck = if (arrayContainsNull) {
            s"""
               |if ($eval1.isNullAt($index)) {
               | ${ev.isNull} = true;
               |} else
             """.stripMargin
          } else {
            ""
          }
          // The bounds check widens to long for the same Integer.MIN_VALUE reason as above.
          s"""
             |int $index = (int) $eval2;
             |if ($eval1.numElements() < Math.abs((long) $index)) {
             | ${ev.isNull} = true;
             |} else {
             | if ($index == 0) {
             | throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1");
             | } else if ($index > 0) {
             | $index--;
             | } else {
             | $index += $eval1.numElements();
             | }
             | $nullCheck
             | {
             | ${ev.value} = ${CodeGenerator.getValue(eval1, dataType, index)};
             | }
             |}
           """.stripMargin
        })
      case _: MapType =>
        doGetValueGenCode(ctx, ev, left.dataType.asInstanceOf[MapType])
    }
  }

  override def prettyName: String = "element_at"
}
/**
* Concatenates multiple input columns together into a single column.
* The function works with strings, binary and compatible array columns.
*
* All children must share one type (string, binary or array). With no children the result
* type is `StringType`. In generated code the result is null as soon as any child is null.
*/
@ExpressionDescription(
usage = "_FUNC_(col1, col2, ..., colN) - Returns the concatenation of col1, col2, ..., colN.",
examples = """
Examples:
> SELECT _FUNC_('Spark', 'SQL');
SparkSQL
> SELECT _FUNC_(array(1, 2, 3), array(4, 5), array(6));
[1,2,3,4,5,6]
""",
note = "Concat logic for arrays is available since 2.4.0.")
case class Concat(children: Seq[Expression]) extends ComplexTypeMergingExpression {
// Abstract types a child is allowed to have.
private def allowedTypes: Seq[AbstractDataType] = Seq(StringType, BinaryType, ArrayType)
// An empty child list is accepted; otherwise every child must be a string, binary or array,
// and all children must have the same type.
override def checkInputDataTypes(): TypeCheckResult = {
if (children.isEmpty) {
TypeCheckResult.TypeCheckSuccess
} else {
val childTypes = children.map(_.dataType)
if (childTypes.exists(tpe => !allowedTypes.exists(_.acceptsType(tpe)))) {
return TypeCheckResult.TypeCheckFailure(
s"input to function $prettyName should have been ${StringType.simpleString}," +
s" ${BinaryType.simpleString} or ${ArrayType.simpleString}, but it's " +
childTypes.map(_.catalogString).mkString("[", ", ", "]"))
}
TypeUtils.checkForSameTypeInputExpr(childTypes, s"function $prettyName")
}
}
// concat() with no arguments is typed as a string; otherwise the type computed by
// ComplexTypeMergingExpression (presumably the merged type of the children).
@transient override lazy val dataType: DataType = {
if (children.isEmpty) {
StringType
} else {
super.dataType
}
}
// Java type name of the result; used to declare the generated argument array.
private def javaType: String = CodeGenerator.javaType(dataType)
override def nullable: Boolean = children.exists(_.nullable)
override def foldable: Boolean = children.forall(_.foldable)
override def eval(input: InternalRow): Any = doConcat(input)
// Interpreted implementation, chosen once per result type.
@transient private lazy val doConcat: InternalRow => Any = dataType match {
case BinaryType =>
input => {
val inputs = children.map(_.eval(input).asInstanceOf[Array[Byte]])
ByteArray.concat(inputs: _*)
}
case StringType =>
input => {
val inputs = children.map(_.eval(input).asInstanceOf[UTF8String])
UTF8String.concat(inputs: _*)
}
case ArrayType(elementType, _) =>
input => {
// `toStream` evaluates children lazily, so `contains(null)` stops evaluating at the
// first null input.
val inputs = children.toStream.map(_.eval(input))
if (inputs.contains(null)) {
null
} else {
val arrayData = inputs.map(_.asInstanceOf[ArrayData])
// Summed as Long so the limit check below cannot itself overflow.
val numberOfElements = arrayData.foldLeft(0L)((sum, ad) => sum + ad.numElements())
if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
throw new RuntimeException(s"Unsuccessful try to concat arrays with $numberOfElements" +
" elements due to exceeding the array size limit " +
ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".")
}
val finalData = new Array[AnyRef](numberOfElements.toInt)
var position = 0
for (ad <- arrayData) {
val arr = ad.toObjectArray(elementType)
Array.copy(arr, 0, finalData, position, arr.length)
position += arr.length
}
new GenericArrayData(finalData)
}
}
}
// Generated code evaluates the children in order into `args`, skipping the remaining children
// once a null has been seen, and only calls the type-specific concat routine when no child
// was null.
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val evals = children.map(_.genCode(ctx))
val args = ctx.freshName("args")
val hasNull = ctx.freshName("hasNull")
val inputs = evals.zip(children.map(_.nullable)).zipWithIndex.map {
// Nullable child: set `hasNull` instead of storing a null value.
case ((eval, true), index) =>
s"""
|if (!$hasNull) {
| ${eval.code}
| if (!${eval.isNull}) {
| $args[$index] = ${eval.value};
| } else {
| $hasNull = true;
| }
|}
""".stripMargin
// Non-nullable child: store its value directly.
case ((eval, false), index) =>
s"""
|if (!$hasNull) {
| ${eval.code}
| $args[$index] = ${eval.value};
|}
""".stripMargin
}
// The per-child snippets may be split into several `valueConcat` methods; each one receives
// and returns `hasNull` so the null flag is carried across the split.
val codes = ctx.splitExpressionsWithCurrentInputs(
expressions = inputs,
funcName = "valueConcat",
extraArguments = (s"$javaType[]", args) :: ("boolean", hasNull) :: Nil,
returnType = "boolean",
makeSplitFunction = body =>
s"""
|$body
|return $hasNull;
""".stripMargin,
foldFunctions = _.map(funcCall => s"$hasNull = $funcCall;").mkString("\n")
)
// Pick the concat routine and the matching declaration of the `args` array.
val (concat, initCode) = dataType match {
case BinaryType =>
(s"${classOf[ByteArray].getName}.concat", s"byte[][] $args = new byte[${evals.length}][];")
case StringType =>
("UTF8String.concat", s"UTF8String[] $args = new UTF8String[${evals.length}];")
case ArrayType(elementType, containsNull) =>
val concat = genCodeForArrays(ctx, elementType, containsNull)
(concat, s"ArrayData[] $args = new ArrayData[${evals.length}];")
}
ev.copy(code =
code"""
|boolean $hasNull = false;
|$initCode
|$codes
|$javaType ${ev.value} = null;
|if (!$hasNull) {
| ${ev.value} = $concat($args);
|}
|boolean ${ev.isNull} = ${ev.value} == null;
""".stripMargin)
}
// Emits code summing the lengths of all inputs into a fresh `long` local and returns
// (code, local name). The generated code reads a variable literally named `args`, so it is only
// valid inside the helper method built by `genCodeForArrays`, whose parameter has that name.
private def genCodeForNumberOfElements(ctx: CodegenContext) : (String, String) = {
val numElements = ctx.freshName("numElements")
val z = ctx.freshName("z")
val code = s"""
|long $numElements = 0L;
|for (int $z = 0; $z < ${children.length}; $z++) {
| $numElements += args[$z].numElements();
|}
""".stripMargin
(code, numElements)
}
// Registers a private helper method that allocates the result array (sized by the summed
// input lengths) and copies every input element into it in order; returns the helper's name.
// NOTE(review): `checkForNull` is unused; the result type's `containsNull` is used instead.
private def genCodeForArrays(
ctx: CodegenContext,
elementType: DataType,
checkForNull: Boolean): String = {
val counter = ctx.freshName("counter")
val arrayData = ctx.freshName("arrayData")
val y = ctx.freshName("y")
val z = ctx.freshName("z")
val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx)
val initialization = CodeGenerator.createArrayData(
arrayData, elementType, numElemName, s" $prettyName failed.")
val assignment = CodeGenerator.createArrayAssignment(
arrayData, elementType, s"args[$y]", counter, z,
dataType.asInstanceOf[ArrayType].containsNull)
val concat = ctx.freshName("concat")
val concatDef =
s"""
|private ArrayData $concat(ArrayData[] args) {
| $numElemCode
| $initialization
| int $counter = 0;
| for (int $y = 0; $y < ${children.length}; $y++) {
| for (int $z = 0; $z < args[$y].numElements(); $z++) {
| $assignment
| $counter++;
| }
| }
| return $arrayData;
|}
""".stripMargin
ctx.addNewFunction(concat, concatDef)
}
// Rendered as a plain function call over the children.
override def toString: String = s"concat(${children.mkString(", ")})"
override def sql: String = s"concat(${children.map(_.sql).mkString(", ")})"
}
/**
 * Transforms an array of arrays into a single array.
 *
 * In the interpreted path the result is null if any inner array is null, which is why the
 * expression is nullable whenever the input array may contain nulls.
 */
@ExpressionDescription(
  usage = "_FUNC_(arrayOfArrays) - Transforms an array of arrays into a single array.",
  examples = """
Examples:
> SELECT _FUNC_(array(array(1, 2), array(3, 4)));
[1,2,3,4]
""",
  since = "2.4.0")
case class Flatten(child: Expression) extends UnaryExpression {

  private def childDataType: ArrayType = child.dataType.asInstanceOf[ArrayType]

  override def nullable: Boolean = child.nullable || childDataType.containsNull

  // The result type is the inner array type of the input.
  @transient override lazy val dataType: DataType = childDataType.elementType

  // Element type of the inner arrays (and of the result).
  @transient private lazy val elementType: DataType = dataType.asInstanceOf[ArrayType].elementType

  override def checkInputDataTypes(): TypeCheckResult = child.dataType match {
    case ArrayType(_: ArrayType, _) =>
      TypeCheckResult.TypeCheckSuccess
    case _ =>
      TypeCheckResult.TypeCheckFailure(
        s"The argument should be an array of arrays, " +
        s"but '${child.sql}' is of ${child.dataType.catalogString} type."
      )
  }

  override def nullSafeEval(child: Any): Any = {
    val elements = child.asInstanceOf[ArrayData].toObjectArray(dataType)

    if (elements.contains(null)) {
      null
    } else {
      val arrayData = elements.map(_.asInstanceOf[ArrayData])
      // Summed as Long so the limit check below cannot itself overflow.
      val numberOfElements = arrayData.foldLeft(0L)((sum, e) => sum + e.numElements())
      if (numberOfElements > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
        throw new RuntimeException("Unsuccessful try to flatten an array of arrays with " +
          s"$numberOfElements elements due to exceeding the array size limit " +
          ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH + ".")
      }
      val flattenedData = new Array(numberOfElements.toInt)
      var position = 0
      for (ad <- arrayData) {
        val arr = ad.toObjectArray(elementType)
        Array.copy(arr, 0, flattenedData, position, arr.length)
        position += arr.length
      }
      new GenericArrayData(flattenedData)
    }
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    nullSafeCodeGen(ctx, ev, c => {
      val code = genCodeForFlatten(ctx, c, ev.value)
      ctx.nullArrayElementsSaveExec(childDataType.containsNull, ev.isNull, c)(code)
    })
  }

  /**
   * Emits code summing the sizes of all inner arrays into a fresh `long` local.
   * Returns (code, local variable name).
   */
  private def genCodeForNumberOfElements(
      ctx: CodegenContext,
      childVariableName: String) : (String, String) = {
    val variableName = ctx.freshName("numElements")
    // A fresh loop index instead of a hard-coded `z`: Java rejects redeclaring a local that
    // is already in scope, so a fixed name could clash with an enclosing generated variable.
    val z = ctx.freshName("z")
    val code = s"""
      |long $variableName = 0;
      |for (int $z = 0; $z < $childVariableName.numElements(); $z++) {
      | $variableName += $childVariableName.getArray($z).numElements();
      |}
      """.stripMargin
    (code, variableName)
  }

  /** Emits code that allocates the result and copies every inner element into it in order. */
  private def genCodeForFlatten(
      ctx: CodegenContext,
      childVariableName: String,
      arrayDataName: String): String = {
    val counter = ctx.freshName("counter")
    val tempArrayDataName = ctx.freshName("tempArrayData")
    val k = ctx.freshName("k")
    val l = ctx.freshName("l")
    val arr = ctx.freshName("arr")

    val (numElemCode, numElemName) = genCodeForNumberOfElements(ctx, childVariableName)

    val allocation = CodeGenerator.createArrayData(
      tempArrayDataName, elementType, numElemName, s" $prettyName failed.")
    val assignment = CodeGenerator.createArrayAssignment(
      tempArrayDataName, elementType, arr, counter, l,
      dataType.asInstanceOf[ArrayType].containsNull)

    s"""
    |$numElemCode
    |$allocation
    |int $counter = 0;
    |for (int $k = 0; $k < $childVariableName.numElements(); $k++) {
    | ArrayData $arr = $childVariableName.getArray($k);
    | for (int $l = 0; $l < $arr.numElements(); $l++) {
    | $assignment
    | $counter++;
    | }
    |}
    |$arrayDataName = $tempArrayDataName;
    """.stripMargin
  }

  override def prettyName: String = "flatten"
}
@ExpressionDescription(
  usage = """
_FUNC_(start, stop, step) - Generates an array of elements from start to stop (inclusive),
incrementing by step. The type of the returned elements is the same as the type of argument
expressions.
Supported types are: byte, short, integer, long, date, timestamp.
The start and stop expressions must resolve to the same type.
If start and stop expressions resolve to the 'date' or 'timestamp' type
then the step expression must resolve to the 'interval' type, otherwise to the same type
as the start and stop expressions.
""",
  arguments = """
Arguments:
* start - an expression. The start of the range.
* stop - an expression. The end of the range (inclusive).
* step - an optional expression. The step of the range.
By default step is 1 if start is less than or equal to stop, otherwise -1.
For the temporal sequences it's 1 day and -1 day respectively.
If start is greater than stop then the step must be negative, and vice versa.
""",
  examples = """
Examples:
> SELECT _FUNC_(1, 5);
[1,2,3,4,5]
> SELECT _FUNC_(5, 1);
[5,4,3,2,1]
> SELECT _FUNC_(to_date('2018-01-01'), to_date('2018-03-01'), interval 1 month);
[2018-01-01,2018-02-01,2018-03-01]
""",
  since = "2.4.0"
)
case class Sequence(
    start: Expression,
    stop: Expression,
    stepOpt: Option[Expression],
    timeZoneId: Option[String] = None)
  extends Expression
  with TimeZoneAwareExpression {

  import Sequence._

  def this(start: Expression, stop: Expression) =
    this(start, stop, None, None)

  def this(start: Expression, stop: Expression, step: Expression) =
    this(start, stop, Some(step), None)

  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression =
    copy(timeZoneId = Some(timeZoneId))

  override def children: Seq[Expression] = Seq(start, stop) ++ stepOpt

  override def foldable: Boolean = children.forall(_.foldable)

  override def nullable: Boolean = children.exists(_.nullable)

  // Generated elements are never null.
  override def dataType: ArrayType = ArrayType(start.dataType, containsNull = false)

  /**
   * start and stop must have the same type. For date/timestamp the optional step must be an
   * interval; for integral types it must have the start type.
   */
  override def checkInputDataTypes(): TypeCheckResult = {
    val startType = start.dataType
    // Only evaluated when stepOpt is non-empty (guarded by `stepOpt.isEmpty ||`).
    def stepType = stepOpt.get.dataType
    val typesCorrect =
      startType.sameType(stop.dataType) &&
        (startType match {
          case TimestampType | DateType =>
            stepOpt.isEmpty || CalendarIntervalType.acceptsType(stepType)
          case _: IntegralType =>
            stepOpt.isEmpty || stepType.sameType(startType)
          case _ => false
        })

    if (typesCorrect) {
      TypeCheckResult.TypeCheckSuccess
    } else {
      TypeCheckResult.TypeCheckFailure(
        s"$prettyName only supports integral, timestamp or date types")
    }
  }

  // Children whose type may be widened during coercion (interval steps are left alone).
  def coercibleChildren: Seq[Expression] = children.filter(_.dataType != CalendarIntervalType)

  /** Rebuilds this expression with every non-interval child cast to `widerType`. */
  def castChildrenTo(widerType: DataType): Expression = Sequence(
    Cast(start, widerType),
    Cast(stop, widerType),
    stepOpt.map(step => if (step.dataType != CalendarIntervalType) Cast(step, widerType) else step),
    timeZoneId)

  @transient private lazy val impl: SequenceImpl = dataType.elementType match {
    case iType: IntegralType =>
      type T = iType.InternalType
      val ct = ClassTag[T](iType.tag.mirror.runtimeClass(iType.tag.tpe))
      new IntegralSequenceImpl(iType)(ct, iType.integral)

    case TimestampType =>
      new TemporalSequenceImpl[Long](LongType, 1, identity, timeZone)

    case DateType =>
      new TemporalSequenceImpl[Int](IntegerType, MICROS_PER_DAY, _.toInt, timeZone)
  }

  override def eval(input: InternalRow): Any = {
    val startVal = start.eval(input)
    if (startVal == null) return null
    val stopVal = stop.eval(input)
    if (stopVal == null) return null
    val stepVal = stepOpt.map(_.eval(input)).getOrElse(impl.defaultStep(startVal, stopVal))
    if (stepVal == null) return null

    ArrayData.toArrayData(impl.eval(startVal, stopVal, stepVal))
  }

  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val startGen = start.genCode(ctx)
    val stopGen = stop.genCode(ctx)
    val stepGen = stepOpt.map(_.genCode(ctx)).getOrElse(
      impl.defaultStep.genCode(ctx, startGen, stopGen))

    val resultType = CodeGenerator.javaType(dataType)
    val resultCode = {
      val arr = ctx.freshName("arr")
      val arrElemType = CodeGenerator.javaType(dataType.elementType)
      // `arr` is assigned by `impl.genCode` (and the temporal implementation may reassign it to
      // truncate), so it must not be declared `final`; that would not be valid Java.
      s"""
         |$arrElemType[] $arr = null;
         |${impl.genCode(ctx, startGen.value, stopGen.value, stepGen.value, arr, arrElemType)}
         |${ev.value} = UnsafeArrayData.fromPrimitiveArray($arr);
       """.stripMargin
    }

    if (nullable) {
      // Evaluate start, stop and step in order, stopping at the first null.
      val nullSafeEval =
        startGen.code + ctx.nullSafeExec(start.nullable, startGen.isNull) {
          stopGen.code + ctx.nullSafeExec(stop.nullable, stopGen.isNull) {
            stepGen.code + ctx.nullSafeExec(stepOpt.exists(_.nullable), stepGen.isNull) {
              s"""
                 |${ev.isNull} = false;
                 |$resultCode
               """.stripMargin
            }
          }
        }
      ev.copy(code =
        code"""
           |boolean ${ev.isNull} = true;
           |$resultType ${ev.value} = null;
           |$nullSafeEval
         """.stripMargin)

    } else {
      ev.copy(code =
        code"""
           |${startGen.code}
           |${stopGen.code}
           |${stepGen.code}
           |$resultType ${ev.value} = null;
           |$resultCode
         """.stripMargin,
        isNull = FalseLiteral)
    }
  }
}
object Sequence {

  private type LessThanOrEqualFn = (Any, Any) => Boolean

  /**
   * Step used when no step expression is given: `one` when `start <= stop` according to
   * `lteq`, otherwise the negation of `one`.
   */
  private class DefaultStep(lteq: LessThanOrEqualFn, stepType: DataType, one: Any) {
    private val negativeOne = UnaryMinus(Literal(one)).eval()

    def apply(start: Any, stop: Any): Any = {
      if (lteq(start, stop)) one else negativeOne
    }

    def genCode(ctx: CodegenContext, startGen: ExprCode, stopGen: ExprCode): ExprCode = {
      val Seq(oneVal, negativeOneVal) = Seq(one, negativeOne).map(Literal(_).genCode(ctx).value)
      ExprCode.forNonNullValue(JavaCode.expression(
        s"${startGen.value} <= ${stopGen.value} ? $oneVal : $negativeOneVal",
        stepType))
    }
  }

  /** Type-specific sequence generation, both interpreted (`eval`) and generated (`genCode`). */
  private trait SequenceImpl {
    def eval(start: Any, stop: Any, step: Any): Any

    def genCode(
        ctx: CodegenContext,
        start: String,
        stop: String,
        step: String,
        arr: String,
        elemType: String): String

    val defaultStep: DefaultStep
  }

  /** Arithmetic sequences over an integral element type. */
  private class IntegralSequenceImpl[T: ClassTag]
    (elemType: IntegralType)(implicit num: Integral[T]) extends SequenceImpl {

    override val defaultStep: DefaultStep = new DefaultStep(
      (elemType.ordering.lteq _).asInstanceOf[LessThanOrEqualFn],
      elemType,
      num.one)

    override def eval(input1: Any, input2: Any, input3: Any): Array[T] = {
      import num._

      val start = input1.asInstanceOf[T]
      val stop = input2.asInstanceOf[T]
      val step = input3.asInstanceOf[T]

      var i: Int = getSequenceLength(start, stop, step)
      val arr = new Array[T](i)
      while (i > 0) {
        i -= 1
        arr(i) = start + step * num.fromInt(i)
      }
      arr
    }

    override def genCode(
        ctx: CodegenContext,
        start: String,
        stop: String,
        step: String,
        arr: String,
        elemType: String): String = {
      val i = ctx.freshName("i")
      s"""
         |${genSequenceLengthCode(ctx, start, stop, step, i)}
         |$arr = new $elemType[$i];
         |while ($i > 0) {
         | $i--;
         | $arr[$i] = ($elemType) ($start + $step * $i);
         |}
       """.stripMargin
    }
  }

  /**
   * Date/timestamp sequences. Values are converted to microseconds (`scale` micros per unit)
   * and advanced with `timestampAddInterval` in `timeZone`.
   */
  private class TemporalSequenceImpl[T: ClassTag]
      (dt: IntegralType, scale: Long, fromLong: Long => T, timeZone: TimeZone)
      (implicit num: Integral[T]) extends SequenceImpl {

    override val defaultStep: DefaultStep = new DefaultStep(
      (dt.ordering.lteq _).asInstanceOf[LessThanOrEqualFn],
      CalendarIntervalType,
      new CalendarInterval(0, MICROS_PER_DAY))

    private val backedSequenceImpl = new IntegralSequenceImpl[T](dt)
    // A month is approximated by its shortest length when estimating the result size; the
    // array is truncated to the real length afterwards.
    private val microsPerMonth = 28 * CalendarInterval.MICROS_PER_DAY

    override def eval(input1: Any, input2: Any, input3: Any): Array[T] = {
      val start = input1.asInstanceOf[T]
      val stop = input2.asInstanceOf[T]
      val step = input3.asInstanceOf[CalendarInterval]
      val stepMonths = step.months
      val stepMicros = step.microseconds

      if (stepMonths == 0) {
        backedSequenceImpl.eval(start, stop, fromLong(stepMicros / scale))

      } else {
        // To estimate the resulted array length we need to make assumptions
        // about a month length in microseconds
        val intervalStepInMicros = stepMicros + stepMonths * microsPerMonth
        val startMicros: Long = num.toLong(start) * scale
        val stopMicros: Long = num.toLong(stop) * scale
        val maxEstimatedArrayLength =
          getSequenceLength(startMicros, stopMicros, intervalStepInMicros)

        // Take the direction from the step rather than from `stopMicros > startMicros`: when
        // start == stop that comparison picked -1 for a positive step, the loop condition
        // below stayed true and the loop overran the 1-element array (SPARK-31980).
        val stepSign = if (intervalStepInMicros > 0) +1 else -1
        val exclusiveItem = stopMicros + stepSign
        val arr = new Array[T](maxEstimatedArrayLength)
        var t = startMicros
        var i = 0

        // Continue while t is strictly before exclusiveItem in the step direction.
        while (t < exclusiveItem ^ stepSign < 0) {
          arr(i) = fromLong(t / scale)
          t = timestampAddInterval(t, stepMonths, stepMicros, timeZone)
          i += 1
        }

        // truncate array to the correct length
        if (arr.length == i) arr else arr.slice(0, i)
      }
    }

    override def genCode(
        ctx: CodegenContext,
        start: String,
        stop: String,
        step: String,
        arr: String,
        elemType: String): String = {
      val stepMonths = ctx.freshName("stepMonths")
      val stepMicros = ctx.freshName("stepMicros")
      val stepScaled = ctx.freshName("stepScaled")
      val intervalInMicros = ctx.freshName("intervalInMicros")
      val startMicros = ctx.freshName("startMicros")
      val stopMicros = ctx.freshName("stopMicros")
      val arrLength = ctx.freshName("arrLength")
      val stepSign = ctx.freshName("stepSign")
      val exclusiveItem = ctx.freshName("exclusiveItem")
      val t = ctx.freshName("t")
      val i = ctx.freshName("i")
      val genTimeZone = ctx.addReferenceObj("timeZone", timeZone, classOf[TimeZone].getName)

      val sequenceLengthCode =
        s"""
           |final long $intervalInMicros = $stepMicros + $stepMonths * ${microsPerMonth}L;
           |${genSequenceLengthCode(ctx, startMicros, stopMicros, intervalInMicros, arrLength)}
         """.stripMargin

      val timestampAddIntervalCode =
        s"""
           |$t = org.apache.spark.sql.catalyst.util.DateTimeUtils.timestampAddInterval(
           | $t, $stepMonths, $stepMicros, $genTimeZone);
         """.stripMargin

      // As in `eval`, the direction comes from the step (SPARK-31980).
      s"""
         |final int $stepMonths = $step.months;
         |final long $stepMicros = $step.microseconds;
         |
         |if ($stepMonths == 0) {
         | final $elemType $stepScaled = ($elemType) ($stepMicros / ${scale}L);
         | ${backedSequenceImpl.genCode(ctx, start, stop, stepScaled, arr, elemType)};
         |
         |} else {
         | final long $startMicros = $start * ${scale}L;
         | final long $stopMicros = $stop * ${scale}L;
         |
         | $sequenceLengthCode
         |
         | final int $stepSign = $intervalInMicros > 0 ? +1 : -1;
         | final long $exclusiveItem = $stopMicros + $stepSign;
         |
         | $arr = new $elemType[$arrLength];
         | long $t = $startMicros;
         | int $i = 0;
         |
         | while ($t < $exclusiveItem ^ $stepSign < 0) {
         | $arr[$i] = ($elemType) ($t / ${scale}L);
         | $timestampAddIntervalCode
         | $i += 1;
         | }
         |
         | if ($arr.length > $i) {
         | $arr = java.util.Arrays.copyOf($arr, $i);
         | }
         |}
       """.stripMargin
    }
  }

  /**
   * Number of elements from `start` to `stop` (inclusive) by `step`.
   * Throws IllegalArgumentException (via `require`) when the step points away from `stop`
   * or the length exceeds MAX_ROUNDED_ARRAY_LENGTH.
   */
  private def getSequenceLength[U](start: U, stop: U, step: U)(implicit num: Integral[U]): Int = {
    import num._
    require(
      (step > num.zero && start <= stop)
        || (step < num.zero && start >= stop)
        || (step == num.zero && start == stop),
      s"Illegal sequence boundaries: $start to $stop by $step")

    val len = if (start == stop) 1L else 1L + (stop.toLong - start.toLong) / step.toLong

    require(
      len <= MAX_ROUNDED_ARRAY_LENGTH,
      s"Too long sequence: $len. Should be <= $MAX_ROUNDED_ARRAY_LENGTH")

    len.toInt
  }

  /** Generated-code counterpart of `getSequenceLength`, declaring `int $len`. */
  private def genSequenceLengthCode(
      ctx: CodegenContext,
      start: String,
      stop: String,
      step: String,
      len: String): String = {
    val longLen = ctx.freshName("longLen")
    s"""
       |if (!(($step > 0 && $start <= $stop) ||
       | ($step < 0 && $start >= $stop) ||
       | ($step == 0 && $start == $stop))) {
       | throw new IllegalArgumentException(
       | "Illegal sequence boundaries: " + $start + " to " + $stop + " by " + $step);
       |}
       |long $longLen = $stop == $start ? 1L : 1L + ((long) $stop - $start) / $step;
       |if ($longLen > $MAX_ROUNDED_ARRAY_LENGTH) {
       | throw new IllegalArgumentException(
       | "Too long sequence: " + $longLen + ". Should be <= $MAX_ROUNDED_ARRAY_LENGTH");
       |}
       |int $len = (int) $longLen;
     """.stripMargin
  }
}
/**
 * Builds an array that holds the value `left` repeated `right` times.
 * A null count yields null, a non-positive count yields an empty array, and a null element
 * yields an array of nulls.
 */
@ExpressionDescription(
  usage = "_FUNC_(element, count) - Returns the array containing element count times.",
  examples = """
Examples:
> SELECT _FUNC_('123', 2);
["123","123"]
""",
  since = "2.4.0")
case class ArrayRepeat(left: Expression, right: Expression)
  extends BinaryExpression with ExpectsInputTypes {

  // Elements can be null exactly when the repeated value can be.
  override def dataType: ArrayType = ArrayType(left.dataType, left.nullable)

  override def inputTypes: Seq[AbstractDataType] = Seq(AnyDataType, IntegerType)

  // Only a null count makes the result null.
  override def nullable: Boolean = right.nullable

  override def eval(input: InternalRow): Any = right.eval(input) match {
    case null => null
    case rawCount =>
      val times = rawCount.asInstanceOf[Int]
      if (times > ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH) {
        throw new RuntimeException(s"Unsuccessful try to create array with $rawCount elements " +
          s"due to exceeding the array size limit ${ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH}.")
      }
      val repeated = left.eval(input)
      new GenericArrayData(Array.fill(times)(repeated))
  }

  override def prettyName: String = "array_repeat"

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val elementGen = left.genCode(ctx)
    val countGen = right.genCode(ctx)
    val fillCode = genCodeForElement(
      ctx, dataType.elementType, elementGen.value, countGen.value, elementGen.isNull, ev.value)
    val guarded = nullElementsProtection(ev, countGen.isNull, fillCode)

    ev.copy(code =
      code"""
         |boolean ${ev.isNull} = false;
         |${elementGen.code}
         |${countGen.code}
         |${CodeGenerator.javaType(dataType)} ${ev.value} =
         | ${CodeGenerator.defaultValue(dataType)};
         |$guarded
       """.stripMargin)
  }

  /** Wraps `coreLogic` in a null check on the count, but only when the count is nullable. */
  private def nullElementsProtection(
      ev: ExprCode,
      rightIsNull: String,
      coreLogic: String): String =
    if (!nullable) {
      coreLogic
    } else {
      s"""
         |if ($rightIsNull) {
         | ${ev.isNull} = true;
         |} else {
         | ${coreLogic}
         |}
       """.stripMargin
    }

  /** Emits code clamping a negative count to zero; returns (variable name, code). */
  private def genCodeForNumberOfElements(ctx: CodegenContext, count: String): (String, String) = {
    val sizeVar = ctx.freshName("numElements")
    val clampCode =
      s"""
         |int $sizeVar = 0;
         |if ($count > 0) {
         | $sizeVar = $count;
         |}
       """.stripMargin
    (sizeVar, clampCode)
  }

  /** Emits code allocating the result and filling every slot with the element (or null). */
  private def genCodeForElement(
      ctx: CodegenContext,
      elementType: DataType,
      element: String,
      count: String,
      leftIsNull: String,
      arrayDataName: String): String = {
    val buffer = ctx.freshName("tempArrayData")
    val idx = ctx.freshName("k")
    val (sizeVar, sizeCode) = genCodeForNumberOfElements(ctx, count)
    val allocate = CodeGenerator.createArrayData(
      buffer, elementType, sizeVar, s" $prettyName failed.")
    val store = CodeGenerator.setArrayElement(buffer, elementType, idx, element)

    s"""
       |$sizeCode
       |$allocate
       |if (!$leftIsNull) {
       | for (int $idx = 0; $idx < $buffer.numElements(); $idx++) {
       | $store
       | }
       |} else {
       | for (int $idx = 0; $idx < $buffer.numElements(); $idx++) {
       | $buffer.setNullAt($idx);
       | }
       |}
       |$arrayDataName = $buffer;
     """.stripMargin
  }
}
/**
 * Removes every element of the array `left` that is equal to `right`. Null elements are never
 * considered equal to the value and are therefore always kept.
 */
@ExpressionDescription(
  usage = "_FUNC_(array, element) - Remove all elements that equal to element from array.",
  examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3, null, 3), 3);
[1,2,null]
""", since = "2.4.0")
case class ArrayRemove(left: Expression, right: Expression)
  extends BinaryExpression with ImplicitCastInputTypes {

  override def dataType: DataType = left.dataType

  /** Coerces the array elements and the value to their tightest common type, when one exists. */
  override def inputTypes: Seq[AbstractDataType] = left.dataType match {
    case ArrayType(elemType, hasNull) =>
      TypeCoercion.findTightestCommonType(elemType, right.dataType)
        .map(common => Seq(ArrayType(common, hasNull), common))
        .getOrElse(Seq.empty)
    case _ => Seq.empty
  }

  override def checkInputDataTypes(): TypeCheckResult = left.dataType match {
    case ArrayType(elemType, _) if elemType.sameType(right.dataType) =>
      TypeUtils.checkForOrderingExpr(right.dataType, s"function $prettyName")
    case _ =>
      val actualTypes = s"[${left.dataType.catalogString}, ${right.dataType.catalogString}]."
      TypeCheckResult.TypeCheckFailure(
        s"Input to function $prettyName should have been ${ArrayType.simpleString} followed " +
          s"by a value with same element type, but it's " + actualTypes)
  }

  private def elementType: DataType = left.dataType.asInstanceOf[ArrayType].elementType

  @transient private lazy val ordering: Ordering[Any] =
    TypeUtils.getInterpretedOrdering(right.dataType)

  override def nullSafeEval(arr: Any, value: Any): Any = {
    val input = arr.asInstanceOf[ArrayData]
    val survivors = new Array[Any](input.numElements())
    var kept = 0
    input.foreach(right.dataType, (_, elem) => {
      val keep = elem == null || !ordering.equiv(elem, value)
      if (keep) {
        survivors(kept) = elem
        kept += 1
      }
    })
    new GenericArrayData(survivors.take(kept))
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    nullSafeCodeGen(ctx, ev, (arr, value) => {
      val numsToRemove = ctx.freshName("numsToRemove")
      val newArraySize = ctx.freshName("newArraySize")
      val i = ctx.freshName("i")
      val isEqual = ctx.genEqual(elementType, value, CodeGenerator.getValue(arr, elementType, i))
      // First pass counts matches so the result can be allocated at its exact size.
      val copyCode = genCodeForResult(ctx, ev, arr, value, newArraySize)
      s"""
         |int $numsToRemove = 0;
         |for (int $i = 0; $i < $arr.numElements(); $i ++) {
         | if (!$arr.isNullAt($i) && $isEqual) {
         | $numsToRemove = $numsToRemove + 1;
         | }
         |}
         |int $newArraySize = $arr.numElements() - $numsToRemove;
         |$copyCode
       """.stripMargin
    })

  /** Emits the second pass: copies nulls and non-matching elements into a new array. */
  def genCodeForResult(
      ctx: CodegenContext,
      ev: ExprCode,
      inputArray: String,
      value: String,
      newArraySize: String): String = {
    val result = ctx.freshName("values")
    val idx = ctx.freshName("i")
    val out = ctx.freshName("pos")
    val isEqual = ctx.genEqual(elementType, value, CodeGenerator.getValue(inputArray, elementType, idx))
    val allocate = CodeGenerator.createArrayData(
      result, elementType, newArraySize, s" $prettyName failed.")
    val copy = CodeGenerator.createArrayAssignment(
      result, elementType, inputArray, out, idx, false)

    s"""
       |$allocate
       |int $out = 0;
       |for (int $idx = 0; $idx < $inputArray.numElements(); $idx ++) {
       | if ($inputArray.isNullAt($idx)) {
       | $result.setNullAt($out);
       | $out = $out + 1;
       | }
       | else {
       | if (!($isEqual)) {
       | $copy
       | $out = $out + 1;
       | }
       | }
       |}
       |${ev.value} = $result;
     """.stripMargin
  }

  override def prettyName: String = "array_remove"
}
/**
 * Shared helpers for array set operations. Will become common base class for
 * [[ArrayDistinct]], [[ArrayUnion]], [[ArrayIntersect]], and [[ArrayExcept]].
 */
trait ArraySetLike {
  /** Array type whose `containsNull` decides whether results get a null slot patched in. */
  protected def dt: DataType

  /** Element type being processed. */
  protected def et: DataType

  // Element types for which a primitive-specialized hash set can be used.
  @transient protected lazy val canUseSpecializedHashSet: Boolean = et match {
    case ByteType | ShortType | IntegerType => true
    case LongType | FloatType | DoubleType => true
    case _ => false
  }

  @transient protected lazy val ordering: Ordering[Any] =
    TypeUtils.getInterpretedOrdering(et)

  protected def genGetValue(array: String, i: String): String =
    CodeGenerator.getValue(array, et, i)

  // Specialized-method suffix and type name of the hash set for the element type.
  @transient protected lazy val (hsPostFix, hsTypeName) = et match {
    // byte/short values are widened to int before they are written to the hash set.
    case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int")
    case LongType => ("$mcJ$sp", CodeGenerator.primitiveTypeName(et))
    case FloatType => ("$mcF$sp", CodeGenerator.primitiveTypeName(et))
    case DoubleType => ("$mcD$sp", CodeGenerator.primitiveTypeName(et))
  }

  // Cast prefix applied when writing byte/short values into the (int) hash set.
  @transient protected lazy val hsValueCast: String = et match {
    case ByteType | ShortType => "(int) "
    case _ => ""
  }

  // Placeholder written into the ArrayBuilder for a null element; the real null is restored
  // afterwards via setNullAt on the recorded slot.
  @transient protected lazy val nullValueHolder: String = et match {
    case ByteType => "(byte) 0"
    case ShortType => "(short) 0"
    case _ => "0"
  }

  /** Appends a `setNullAt` fix-up to `body` when the result type may contain nulls. */
  protected def withResultArrayNullCheck(
      body: String,
      value: String,
      nullElementIndex: String): String = {
    val resultMayContainNull = dt.asInstanceOf[ArrayType].containsNull
    if (!resultMayContainNull) {
      body
    } else {
      s"""
         |$body
         |if ($nullElementIndex >= 0) {
         | // result has null element
         | $value.setNullAt($nullElementIndex);
         |}
       """.stripMargin
    }
  }

  /**
   * Emits code turning the primitive ArrayBuilder `builder` into ArrayData stored in `value`,
   * failing when `size` exceeds the array limit.
   */
  def buildResultArray(
      builder: String,
      value : String,
      size : String,
      nullElementIndex : String): String = {
    val limit = ByteArrayMethods.MAX_ROUNDED_ARRAY_LENGTH
    val buildCode =
      s"""
         |if ($size > ${limit}) {
         | throw new RuntimeException("Cannot create array with " + $size +
         | " elements of data due to exceeding the limit " +
         | "${limit} elements for ArrayData.");
         |}
         |
         |if (!UnsafeArrayData.shouldUseGenericArrayData(${et.defaultSize}, $size)) {
         | $value = UnsafeArrayData.fromPrimitiveArray($builder.result());
         |} else {
         | $value = new ${classOf[GenericArrayData].getName}($builder.result());
         |}
       """.stripMargin
    withResultArrayNullCheck(buildCode, value, nullElementIndex)
  }
}
/**
* Removes duplicate values from the array.
*/
@ExpressionDescription(
usage = "_FUNC_(array) - Removes duplicate values from the array.",
examples = """
Examples:
> SELECT _FUNC_(array(1, 2, 3, null, 3));
[1,2,3,null]
""", since = "2.4.0")
case class ArrayDistinct(child: Expression)
extends UnaryExpression with ArraySetLike with ExpectsInputTypes {