Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions common/utils/src/main/resources/error/error-conditions.json
Original file line number Diff line number Diff line change
Expand Up @@ -5605,6 +5605,12 @@
],
"sqlState" : "428EK"
},
"THETA_INVALID_FAMILY" : {
"message" : [
"Invalid call to <function>; the `family` parameter must be one of: <validFamilies>. Got: <value>."
],
"sqlState" : "22546"
},
"THETA_INVALID_INPUT_SKETCH_BUFFER" : {
"message" : [
"Invalid call to <function>; only valid Theta sketch buffers are supported as inputs (such as those produced by the `theta_sketch_agg` function)."
Expand Down
9 changes: 7 additions & 2 deletions python/pyspark/sql/connect/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4337,12 +4337,17 @@ def hll_union(
def theta_sketch_agg(
col: "ColumnOrName",
lgNomEntries: Optional[Union[int, Column]] = None,
family: Optional[str] = None,
) -> Column:
fn = "theta_sketch_agg"
if lgNomEntries is None:
if lgNomEntries is None and family is None:
return _invoke_function_over_columns(fn, col)
else:
elif family is None:
return _invoke_function_over_columns(fn, col, lit(lgNomEntries))
else:
if lgNomEntries is None:
lgNomEntries = 12 # default value
return _invoke_function_over_columns(fn, col, lit(lgNomEntries), lit(family))


theta_sketch_agg.__doc__ = pysparkfuncs.theta_sketch_agg.__doc__
Expand Down
51 changes: 33 additions & 18 deletions python/pyspark/sql/functions/builtin.py
Original file line number Diff line number Diff line change
Expand Up @@ -25941,10 +25941,12 @@ def hll_union(
def theta_sketch_agg(
col: "ColumnOrName",
lgNomEntries: Optional[Union[int, Column]] = None,
family: Optional[str] = None,
) -> Column:
"""
Aggregate function: returns the compact binary representation of the Datasketches
ThetaSketch with the values in the input column configured with lgNomEntries nominal entries.
ThetaSketch with the values in the input column configured with lgNomEntries nominal entries
and the specified sketch family.

.. versionadded:: 4.1.0

Expand All @@ -25954,6 +25956,8 @@ def theta_sketch_agg(
lgNomEntries : :class:`~pyspark.sql.Column` or int, optional
The log-base-2 of nominal entries, where nominal entries is the size of the sketch
(must be between 4 and 26, defaults to 12)
family : str, optional
The sketch family: 'QUICKSELECT' or 'ALPHA' (defaults to 'QUICKSELECT').

Returns
-------
Expand All @@ -25974,24 +25978,35 @@ def theta_sketch_agg(
>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([1,2,2,3], "INT")
>>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value"))).show()
+--------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 12))|
+--------------------------------------------------+
| 3|
+--------------------------------------------------+
+---------------------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 12, QUICKSELECT))|
+---------------------------------------------------------------+
| 3|
+---------------------------------------------------------------+

>>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value", 15))).show()
+--------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 15))|
+--------------------------------------------------+
| 3|
+--------------------------------------------------+
+---------------------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 15, QUICKSELECT))|
+---------------------------------------------------------------+
| 3|
+---------------------------------------------------------------+

>>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value", 15, "ALPHA"))).show()
+---------------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 15, ALPHA))|
+---------------------------------------------------------+
| 3|
+---------------------------------------------------------+
"""
fn = "theta_sketch_agg"
if lgNomEntries is None:
if lgNomEntries is None and family is None:
return _invoke_function_over_columns(fn, col)
else:
elif family is None:
return _invoke_function_over_columns(fn, col, lit(lgNomEntries))
else:
if lgNomEntries is None:
lgNomEntries = 12 # default value
return _invoke_function_over_columns(fn, col, lit(lgNomEntries), lit(family))


@_try_remote_functions
Expand Down Expand Up @@ -26118,11 +26133,11 @@ def theta_sketch_estimate(col: "ColumnOrName") -> Column:
>>> from pyspark.sql import functions as sf
>>> df = spark.createDataFrame([1,2,2,3], "INT")
>>> df.agg(sf.theta_sketch_estimate(sf.theta_sketch_agg("value"))).show()
+--------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 12))|
+--------------------------------------------------+
| 3|
+--------------------------------------------------+
+---------------------------------------------------------------+
|theta_sketch_estimate(theta_sketch_agg(value, 12, QUICKSELECT))|
+---------------------------------------------------------------+
| 3|
+---------------------------------------------------------------+
"""

fn = "theta_sketch_estimate"
Expand Down
52 changes: 52 additions & 0 deletions sql/api/src/main/scala/org/apache/spark/sql/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -1198,6 +1198,17 @@ object functions {
def theta_sketch_agg(e: Column, lgNomEntries: Column): Column =
Column.fn("theta_sketch_agg", e, lgNomEntries)

/**
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch
* built with the values in the input column and configured with the `lgNomEntries` nominal
* entries and `family`.
*
* @group agg_funcs
* @since 4.1.0
*/
def theta_sketch_agg(e: Column, lgNomEntries: Column, family: Column): Column =
Column.fn("theta_sketch_agg", e, lgNomEntries, family)

/**
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch
* built with the values in the input column and configured with the `lgNomEntries` nominal
Expand Down Expand Up @@ -1242,6 +1253,47 @@ object functions {
def theta_sketch_agg(columnName: String): Column =
theta_sketch_agg(Column(columnName))

/**
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch
* built with the values in the input column, configured with `lgNomEntries` and `family`.
*
* @group agg_funcs
* @since 4.1.0
*/
def theta_sketch_agg(e: Column, lgNomEntries: Int, family: String): Column =
Column.fn("theta_sketch_agg", e, lit(lgNomEntries), lit(family))

/**
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch
* built with the values in the input column, configured with `lgNomEntries` and `family`.
*
* @group agg_funcs
* @since 4.1.0
*/
def theta_sketch_agg(columnName: String, lgNomEntries: Int, family: String): Column =
theta_sketch_agg(Column(columnName), lgNomEntries, family)

/**
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch
* built with the values in the input column, configured with the specified `family` and default
* lgNomEntries.
*
* @group agg_funcs
* @since 4.1.0
*/
def theta_sketch_agg(e: Column, family: String): Column =
theta_sketch_agg(e, 12, family)

/**
* Aggregate function: returns the compact binary representation of the Datasketches ThetaSketch
* built with the values in the input column, configured with specified `family`.
*
* @group agg_funcs
* @since 4.1.0
*/
def theta_sketch_agg(columnName: String, family: String): Column =
theta_sketch_agg(columnName, 12, family)

/**
* Aggregate function: returns the compact binary representation of the Datasketches
* ThetaSketch, generated by the union of Datasketches ThetaSketch instances in the input column
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,13 @@

package org.apache.spark.sql.catalyst.expressions.aggregate

import org.apache.datasketches.common.Family
import org.apache.datasketches.memory.Memory
import org.apache.datasketches.theta.{CompactSketch, Intersection, SetOperation, Sketch, Union, UpdateSketch, UpdateSketchBuilder}

import org.apache.spark.SparkUnsupportedOperationException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression, ExpressionDescription, Literal}
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.catalyst.trees.{BinaryLike, UnaryLike}
import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, ThetaSketchUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
Expand Down Expand Up @@ -59,10 +59,12 @@ case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState {
*
* See [[https://datasketches.apache.org/docs/Theta/ThetaSketches.html]] for more information.
*
* @param left
* @param child
* child expression against which unique counting will occur
* @param right
* @param lgNomEntriesExpr
* the log-base-2 of nomEntries decides the number of buckets for the sketch
* @param familyExpr
* the family of the sketch (QUICKSELECT or ALPHA)
* @param mutableAggBufferOffset
* offset for mutable aggregation buffer
* @param inputAggBufferOffset
Expand All @@ -71,46 +73,64 @@ case class FinalizedSketch(sketch: CompactSketch) extends ThetaSketchState {
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """
_FUNC_(expr, lgNomEntries) - Returns the ThetaSketch compact binary representation.
_FUNC_(expr, lgNomEntries, family) - Returns the ThetaSketch compact binary representation.
`lgNomEntries` (optional) is the log-base-2 of nominal entries, with nominal entries deciding
the number buckets or slots for the ThetaSketch. """,
the number buckets or slots for the ThetaSketch.
`family` (optional) is the sketch family, either 'QUICKSELECT' or 'ALPHA' (defaults to
'QUICKSELECT').""",
examples = """
Examples:
> SELECT theta_sketch_estimate(_FUNC_(col)) FROM VALUES (1), (1), (2), (2), (3) tab(col);
3
> SELECT theta_sketch_estimate(_FUNC_(col, 12)) FROM VALUES (1), (1), (2), (2), (3) tab(col);
3
> SELECT theta_sketch_estimate(_FUNC_(col, 15, 'ALPHA')) FROM VALUES (1), (1), (2), (2), (3) tab(col);
3
""",
group = "agg_funcs",
since = "4.1.0")
// scalastyle:on line.size.limit
case class ThetaSketchAgg(
left: Expression,
right: Expression,
child: Expression,
lgNomEntriesExpr: Expression,
familyExpr: Expression,
override val mutableAggBufferOffset: Int,
override val inputAggBufferOffset: Int)
extends TypedImperativeAggregate[ThetaSketchState]
with BinaryLike[Expression]
with ExpectsInputTypes {

// ThetaSketch config - mark as lazy so that they're not evaluated during tree transformation.

lazy val lgNomEntries: Int = {
val lgNomEntriesInput = right.eval().asInstanceOf[Int]
private lazy val lgNomEntries: Int = {
val lgNomEntriesInput = lgNomEntriesExpr.eval().asInstanceOf[Int]
ThetaSketchUtils.checkLgNomLongs(lgNomEntriesInput, prettyName)
lgNomEntriesInput
}

// Constructors
private lazy val family: Family =
ThetaSketchUtils.parseFamily(familyExpr.eval().asInstanceOf[UTF8String].toString, prettyName)

// Constructors
def this(child: Expression) = {
this(child, Literal(ThetaSketchUtils.DEFAULT_LG_NOM_LONGS), 0, 0)
this(child,
Literal(ThetaSketchUtils.DEFAULT_LG_NOM_LONGS),
Literal(UTF8String.fromString(ThetaSketchUtils.DEFAULT_FAMILY)),
0, 0)
}

def this(child: Expression, lgNomEntries: Expression) = {
this(child, lgNomEntries, 0, 0)
this(child,
lgNomEntries,
Literal(UTF8String.fromString(ThetaSketchUtils.DEFAULT_FAMILY)),
0, 0)
}

def this(child: Expression, lgNomEntries: Expression, family: Expression) = {
this(child, lgNomEntries, family, 0, 0)
}

def this(child: Expression, lgNomEntries: Int) = {
this(child, Literal(lgNomEntries), 0, 0)
this(child, Literal(lgNomEntries))
}

// Copy constructors required by ImperativeAggregate
Expand All @@ -122,15 +142,16 @@ case class ThetaSketchAgg(
copy(inputAggBufferOffset = newInputAggBufferOffset)

override protected def withNewChildrenInternal(
newLeft: Expression,
newRight: Expression): ThetaSketchAgg =
copy(left = newLeft, right = newRight)
newChildren: IndexedSeq[Expression]): ThetaSketchAgg =
copy(child = newChildren(0), lgNomEntriesExpr = newChildren(1), familyExpr = newChildren(2))

override def children: Seq[Expression] = Seq(child, lgNomEntriesExpr, familyExpr)

// Overrides for TypedImperativeAggregate

override def prettyName: String = "theta_sketch_agg"

override def inputTypes: Seq[AbstractDataType] =
override def inputTypes: Seq[AbstractDataType] = {
Seq(
TypeCollection(
ArrayType(IntegerType),
Expand All @@ -141,21 +162,24 @@ case class ThetaSketchAgg(
IntegerType,
LongType,
StringTypeWithCollation(supportsTrimCollation = true)),
IntegerType)
IntegerType,
StringType)
}

override def dataType: DataType = BinaryType

override def nullable: Boolean = false

/**
* Instantiate an UpdateSketch instance using the lgNomEntries param.
* Instantiate an UpdateSketch instance using the lgNomEntries and family params.
*
* @return
* an UpdateSketch instance wrapped with UpdatableSketchBuffer
*/
override def createAggregationBuffer(): ThetaSketchState = {
val builder = new UpdateSketchBuilder
builder.setLogNominalEntries(lgNomEntries)
builder.setFamily(family)
UpdatableSketchBuffer(builder.build)
}

Expand All @@ -176,7 +200,7 @@ case class ThetaSketchAgg(
*/
override def update(updateBuffer: ThetaSketchState, input: InternalRow): ThetaSketchState = {
// Return early for null values.
val v = left.eval(input)
val v = child.eval(input)
if (v == null) return updateBuffer

// Initialized buffer should be UpdatableSketchBuffer, else error out.
Expand All @@ -186,7 +210,7 @@ case class ThetaSketchAgg(
}

// Handle the different data types for sketch updates.
left.dataType match {
child.dataType match {
case ArrayType(IntegerType, _) =>
val arr = v.asInstanceOf[ArrayData].toIntArray()
sketch.update(arr)
Expand All @@ -213,7 +237,7 @@ case class ThetaSketchAgg(
case _ =>
throw new SparkUnsupportedOperationException(
errorClass = "_LEGACY_ERROR_TEMP_3121",
messageParameters = Map("dataType" -> left.dataType.toString))
messageParameters = Map("dataType" -> child.dataType.toString))
}

UpdatableSketchBuffer(sketch)
Expand Down
Loading