Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -3410,6 +3410,28 @@ case class ArrayDistinct(child: Expression)
case _ => false
}

@transient protected lazy val canUseSpecializedHashSet = elementType match {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we extract those common methods?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can do so. However, to keep this change small given the limited time remaining before the cut, I would like to do that in a separate PR, #21912.

case ByteType | ShortType | IntegerType | LongType | FloatType | DoubleType => true
case _ => false
}

/**
 * Pair of strings used when generating code against the hash set:
 *  - `hsPostFix`: suffix appended to the hash set's `add` / `contains` method names.
 *  - `hsTypeName`: name of the `ClassTag` factory method used to construct the hash set.
 *
 * The match has no default case, so it only yields a value for byte, short, int, long,
 * float and double element types.
 */
@transient protected lazy val (hsPostFix, hsTypeName) = {
  val primitiveName = CodeGenerator.primitiveTypeName(elementType)
  elementType match {
    case LongType => ("$mcJ$sp", primitiveName)
    case FloatType => ("$mcF$sp", primitiveName)
    case DoubleType => ("$mcD$sp", primitiveName)
    // byte and short values are cast to int when written to the hash set, so all three
    // integral types narrower than long share the Int variant.
    case ByteType | ShortType | IntegerType => ("$mcI$sp", "Int")
  }
}

// Java cast prefix placed in front of a value written to the hash set: byte and short
// elements are cast to int, every other element type gets no cast at all.
@transient protected lazy val hsValueCast =
  if (ByteType == elementType || ShortType == elementType) "(int) " else ""

override def nullSafeEval(array: Any): Any = {
val data = array.asInstanceOf[ArrayData].toArray[AnyRef](elementType)
if (elementTypeSupportEquals) {
Expand Down Expand Up @@ -3442,17 +3464,15 @@ case class ArrayDistinct(child: Expression)
}

override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
nullSafeCodeGen(ctx, ev, (array) => {
val i = ctx.freshName("i")
val j = ctx.freshName("j")
val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray")
val getValue1 = CodeGenerator.getValue(array, elementType, i)
val getValue2 = CodeGenerator.getValue(array, elementType, j)
val foundNullElement = ctx.freshName("foundNullElement")
val openHashSet = classOf[OpenHashSet[_]].getName
val hs = ctx.freshName("hs")
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
if (elementTypeSupportEquals) {
if (canUseSpecializedHashSet) {
nullSafeCodeGen(ctx, ev, (array) => {
val i = ctx.freshName("i")
val sizeOfDistinctArray = ctx.freshName("sizeOfDistinctArray")
val foundNullElement = ctx.freshName("foundNullElement")
val openHashSet = classOf[OpenHashSet[_]].getName
val hs = ctx.freshName("hs")
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"
val getValue = CodeGenerator.getValue(array, elementType, i)
s"""
|int $sizeOfDistinctArray = 0;
|boolean $foundNullElement = false;
Expand All @@ -3461,53 +3481,26 @@ case class ArrayDistinct(child: Expression)
| if ($array.isNullAt($i)) {
| $foundNullElement = true;
| } else {
| $hs.add($getValue1);
| $hs.add$hsPostFix($hsValueCast$getValue);
| }
|}
|$sizeOfDistinctArray = $hs.size() + ($foundNullElement ? 1 : 0);
|${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)}
""".stripMargin
} else {
s"""
|int $sizeOfDistinctArray = 0;
|boolean $foundNullElement = false;
|for (int $i = 0; $i < $array.numElements(); $i ++) {
| if ($array.isNullAt($i)) {
| if (!($foundNullElement)) {
| $sizeOfDistinctArray = $sizeOfDistinctArray + 1;
| $foundNullElement = true;
| }
| } else {
| int $j;
| for ($j = 0; $j < $i; $j ++) {
| if (!$array.isNullAt($j) && ${ctx.genEqual(elementType, getValue1, getValue2)}) {
| break;
| }
| }
| if ($i == $j) {
| $sizeOfDistinctArray = $sizeOfDistinctArray + 1;
| }
| }
|}
|
|${genCodeForResult(ctx, ev, array, sizeOfDistinctArray)}
""".stripMargin
}
})
})
} else {
nullSafeCodeGen(ctx, ev, (array) => {
val expr = ctx.addReferenceObj("arrayDistinctExpr", this)
s"${ev.value} = (ArrayData)$expr.nullSafeEval($array);"
})
}
}

private def setNull(
isPrimitive: Boolean,
foundNullElement: String,
distinctArray: String,
pos: String): String = {
val setNullValue =
if (!isPrimitive) {
s"$distinctArray[$pos] = null";
} else {
s"$distinctArray.setNullAt($pos)";
}

val setNullValue = s"$distinctArray.setNullAt($pos)"
s"""
|if (!($foundNullElement)) {
| $setNullValue;
Expand All @@ -3517,57 +3510,16 @@ case class ArrayDistinct(child: Expression)
""".stripMargin
}

private def setNotNullValue(isPrimitive: Boolean,
distinctArray: String,
pos: String,
getValue1: String,
primitiveValueTypeName: String): String = {
if (!isPrimitive) {
s"$distinctArray[$pos] = $getValue1";
} else {
s"$distinctArray.set$primitiveValueTypeName($pos, $getValue1)";
}
}

private def setValueForFastEval(
isPrimitive: Boolean,
private def setValue(
hs: String,
distinctArray: String,
pos: String,
getValue1: String,
primitiveValueTypeName: String): String = {
val setValue = setNotNullValue(isPrimitive,
distinctArray, pos, getValue1, primitiveValueTypeName)
s"""
|if (!($hs.contains($getValue1))) {
| $hs.add($getValue1);
| $setValue;
| $pos = $pos + 1;
|}
""".stripMargin
}

private def setValueForBruteForceEval(
isPrimitive: Boolean,
i: String,
j: String,
inputArray: String,
distinctArray: String,
pos: String,
getValue1: String,
isEqual: String,
primitiveValueTypeName: String): String = {
val setValue = setNotNullValue(isPrimitive,
distinctArray, pos, getValue1, primitiveValueTypeName)
s"""
|int $j;
|for ($j = 0; $j < $i; $j ++) {
| if (!$inputArray.isNullAt($j) && $isEqual) {
| break;
| }
|}
|if ($i == $j) {
| $setValue;
|if (!($hs.contains$hsPostFix($hsValueCast$getValue1))) {
| $hs.add$hsPostFix($hsValueCast$getValue1);
| $distinctArray.set$primitiveValueTypeName($pos, $getValue1);
| $pos = $pos + 1;
|}
""".stripMargin
Expand All @@ -3580,73 +3532,28 @@ case class ArrayDistinct(child: Expression)
size: String): String = {
val distinctArray = ctx.freshName("distinctArray")
val i = ctx.freshName("i")
val j = ctx.freshName("j")
val pos = ctx.freshName("pos")
val getValue1 = CodeGenerator.getValue(inputArray, elementType, i)
val getValue2 = CodeGenerator.getValue(inputArray, elementType, j)
val isEqual = ctx.genEqual(elementType, getValue1, getValue2)
val foundNullElement = ctx.freshName("foundNullElement")
val hs = ctx.freshName("hs")
val openHashSet = classOf[OpenHashSet[_]].getName
if (!CodeGenerator.isPrimitiveType(elementType)) {
val arrayClass = classOf[GenericArrayData].getName
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.Object()"
val setNullForNonPrimitive =
setNull(false, foundNullElement, distinctArray, pos)
if (elementTypeSupportEquals) {
val setValueForFast = setValueForFastEval(false, hs, distinctArray, pos, getValue1, "")
s"""
|int $pos = 0;
|Object[] $distinctArray = new Object[$size];
|boolean $foundNullElement = false;
|$openHashSet $hs = new $openHashSet($classTag);
|for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
| if ($inputArray.isNullAt($i)) {
| $setNullForNonPrimitive;
| } else {
| $setValueForFast;
| }
|}
|${ev.value} = new $arrayClass($distinctArray);
""".stripMargin
} else {
val setValueForBruteForce = setValueForBruteForceEval(
false, i, j, inputArray, distinctArray, pos, getValue1, isEqual, "")
s"""
|int $pos = 0;
|Object[] $distinctArray = new Object[$size];
|boolean $foundNullElement = false;
|for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
| if ($inputArray.isNullAt($i)) {
| $setNullForNonPrimitive;
| } else {
| $setValueForBruteForce;
| }
|}
|${ev.value} = new $arrayClass($distinctArray);
""".stripMargin
}
} else {
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
val setNullForPrimitive = setNull(true, foundNullElement, distinctArray, pos)
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$primitiveValueTypeName()"
val setValueForFast =
setValueForFastEval(true, hs, distinctArray, pos, getValue1, primitiveValueTypeName)
s"""
|${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")}
|int $pos = 0;
|boolean $foundNullElement = false;
|$openHashSet $hs = new $openHashSet($classTag);
|for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
| if ($inputArray.isNullAt($i)) {
| $setNullForPrimitive;
| } else {
| $setValueForFast;
| }
|}
|${ev.value} = $distinctArray;
""".stripMargin
}
val primitiveValueTypeName = CodeGenerator.primitiveTypeName(elementType)
val classTag = s"scala.reflect.ClassTag$$.MODULE$$.$hsTypeName()"

s"""
|${ctx.createUnsafeArray(distinctArray, size, elementType, s" $prettyName failed.")}
|int $pos = 0;
|boolean $foundNullElement = false;
|$openHashSet $hs = new $openHashSet($classTag);
|for (int $i = 0; $i < $inputArray.numElements(); $i ++) {
| if ($inputArray.isNullAt($i)) {
| ${setNull(foundNullElement, distinctArray, pos)}
| } else {
| ${setValue(hs, distinctArray, pos, getValue1, primitiveValueTypeName)}
| }
|}
|${ev.value} = $distinctArray;
""".stripMargin
}

override def prettyName: String = "array_distinct"
Expand Down