[SPARK-30423][SQL] Deprecate UserDefinedAggregateFunction
### What changes were proposed in this pull request?
* Annotate UserDefinedAggregateFunction as deprecated, superseded by SPARK-27296
* Update user doc examples to reflect the new ability to register a typed Aggregator[IN, BUF, OUT] as an untyped aggregating UDF (a minimal sketch of the pattern follows below)
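
As a rough sketch of the new registration pattern (illustrative only, not code from this PR: the `MyLongSum` aggregator, the app object, and the view name are invented here):

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions

// A minimal typed Aggregator: sums Long inputs (buffer and output are also Long)
object MyLongSum extends Aggregator[Long, Long, Long] {
  def zero: Long = 0L                                         // neutral element: b + zero == b
  def reduce(buffer: Long, data: Long): Long = buffer + data  // fold one input into the buffer
  def merge(b1: Long, b2: Long): Long = b1 + b2               // combine two partial buffers
  def finish(reduction: Long): Long = reduction               // produce the final output
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

object UdafSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("udaf-sketch").getOrCreate()
    // The SPARK-27296 API: register the typed Aggregator as an untyped aggregating UDF
    spark.udf.register("mySum", functions.udaf(MyLongSum))
    spark.range(1, 5).createOrReplaceTempView("nums")
    spark.sql("SELECT mySum(id) FROM nums").show()  // 1 + 2 + 3 + 4 = 10
    spark.stop()
  }
}
```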
### Why are the changes needed?
UserDefinedAggregateFunction is being deprecated in favor of the typed Aggregator API.

### Does this PR introduce any user-facing change?
Changes are limited to user documentation and deprecation annotations.

### How was this patch tested?
Tested via a package build, verifying doc generation, deprecation warnings, and successful compilation of the examples.

Closes #27193 from erikerlandson/spark-30423.

Authored-by: Erik Erlandson <eerlands@redhat.com>
Signed-off-by: Wenchen Fan <wenchen@databricks.com>
erikerlandson authored and cloud-fan committed Jan 14, 2020
1 parent a2aa966 commit 176b696
Showing 5 changed files with 105 additions and 104 deletions.
docs/sql-getting-started.md: 23 changes (11 additions, 12 deletions)
@@ -358,30 +358,29 @@ While those functions are designed for DataFrames, Spark SQL also has type-safe
 [Java](api/java/org/apache/spark/sql/expressions/javalang/typed.html) to work with strongly typed Datasets.
 Moreover, users are not limited to the predefined aggregate functions and can create their own.
 
-### Untyped User-Defined Aggregate Functions
-Users have to extend the [UserDefinedAggregateFunction](api/scala/index.html#org.apache.spark.sql.expressions.UserDefinedAggregateFunction)
-abstract class to implement a custom untyped aggregate function. For example, a user-defined average
-can look like:
+### Type-Safe User-Defined Aggregate Functions
+
+User-defined aggregations for strongly typed Datasets revolve around the [Aggregator](api/scala/index.html#org.apache.spark.sql.expressions.Aggregator) abstract class.
+For example, a type-safe user-defined average can look like:
 
 <div class="codetabs">
 <div data-lang="scala" markdown="1">
-{% include_example untyped_custom_aggregation scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala%}
+{% include_example typed_custom_aggregation scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala%}
 </div>
 <div data-lang="java" markdown="1">
-{% include_example untyped_custom_aggregation java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java%}
+{% include_example typed_custom_aggregation java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java%}
 </div>
 </div>
 
-### Type-Safe User-Defined Aggregate Functions
-
-User-defined aggregations for strongly typed Datasets revolve around the [Aggregator](api/scala/index.html#org.apache.spark.sql.expressions.Aggregator) abstract class.
-For example, a type-safe user-defined average can look like:
+### Untyped User-Defined Aggregate Functions
+Typed aggregations, as described above, may also be registered as untyped aggregating UDFs for use with DataFrames.
+For example, a user-defined average for untyped DataFrames can look like:
 
 <div class="codetabs">
 <div data-lang="scala" markdown="1">
-{% include_example typed_custom_aggregation scala/org/apache/spark/examples/sql/UserDefinedTypedAggregation.scala%}
+{% include_example untyped_custom_aggregation scala/org/apache/spark/examples/sql/UserDefinedUntypedAggregation.scala%}
 </div>
 <div data-lang="java" markdown="1">
-{% include_example typed_custom_aggregation java/org/apache/spark/examples/sql/JavaUserDefinedTypedAggregation.java%}
+{% include_example untyped_custom_aggregation java/org/apache/spark/examples/sql/JavaUserDefinedUntypedAggregation.java%}
 </div>
 </div>
@@ -17,81 +17,85 @@
 package org.apache.spark.examples.sql;
 
 // $example on:untyped_custom_aggregation$
-import java.util.ArrayList;
-import java.util.List;
+import java.io.Serializable;
 
 import org.apache.spark.sql.Dataset;
+import org.apache.spark.sql.Encoder;
+import org.apache.spark.sql.Encoders;
 import org.apache.spark.sql.Row;
 import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.expressions.MutableAggregationBuffer;
-import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
-import org.apache.spark.sql.types.DataType;
-import org.apache.spark.sql.types.DataTypes;
-import org.apache.spark.sql.types.StructField;
-import org.apache.spark.sql.types.StructType;
+import org.apache.spark.sql.expressions.Aggregator;
+import org.apache.spark.sql.functions;
 // $example off:untyped_custom_aggregation$
 
 public class JavaUserDefinedUntypedAggregation {
 
   // $example on:untyped_custom_aggregation$
-  public static class MyAverage extends UserDefinedAggregateFunction {
+  public static class Average implements Serializable {
+    private long sum;
+    private long count;
 
-    private StructType inputSchema;
-    private StructType bufferSchema;
+    // Constructors, getters, setters...
+    // $example off:untyped_custom_aggregation$
+    public Average() {
+    }
 
+    public Average(long sum, long count) {
+      this.sum = sum;
+      this.count = count;
+    }
+
-    public MyAverage() {
-      List<StructField> inputFields = new ArrayList<>();
-      inputFields.add(DataTypes.createStructField("inputColumn", DataTypes.LongType, true));
-      inputSchema = DataTypes.createStructType(inputFields);
+    public long getSum() {
+      return sum;
+    }
 
-      List<StructField> bufferFields = new ArrayList<>();
-      bufferFields.add(DataTypes.createStructField("sum", DataTypes.LongType, true));
-      bufferFields.add(DataTypes.createStructField("count", DataTypes.LongType, true));
-      bufferSchema = DataTypes.createStructType(bufferFields);
+    public void setSum(long sum) {
+      this.sum = sum;
     }
-    // Data types of input arguments of this aggregate function
-    public StructType inputSchema() {
-      return inputSchema;
+
+    public long getCount() {
+      return count;
     }
-    // Data types of values in the aggregation buffer
-    public StructType bufferSchema() {
-      return bufferSchema;
+
+    public void setCount(long count) {
+      this.count = count;
     }
-    // The data type of the returned value
-    public DataType dataType() {
-      return DataTypes.DoubleType;
+    // $example on:untyped_custom_aggregation$
  }
 
+  public static class MyAverage extends Aggregator<Long, Average, Double> {
+    // A zero value for this aggregation. Should satisfy the property that any b + zero = b
+    public Average zero() {
+      return new Average(0L, 0L);
    }
-    // Whether this function always returns the same output on the identical input
-    public boolean deterministic() {
-      return true;
+    // Combine two values to produce a new value. For performance, the function may modify `buffer`
+    // and return it instead of constructing a new object
+    public Average reduce(Average buffer, Long data) {
+      long newSum = buffer.getSum() + data;
+      long newCount = buffer.getCount() + 1;
+      buffer.setSum(newSum);
+      buffer.setCount(newCount);
+      return buffer;
     }
-    // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
-    // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
-    // the opportunity to update its values. Note that arrays and maps inside the buffer are still
-    // immutable.
-    public void initialize(MutableAggregationBuffer buffer) {
-      buffer.update(0, 0L);
-      buffer.update(1, 0L);
+    // Merge two intermediate values
+    public Average merge(Average b1, Average b2) {
+      long mergedSum = b1.getSum() + b2.getSum();
+      long mergedCount = b1.getCount() + b2.getCount();
+      b1.setSum(mergedSum);
+      b1.setCount(mergedCount);
+      return b1;
     }
-    // Updates the given aggregation buffer `buffer` with new input data from `input`
-    public void update(MutableAggregationBuffer buffer, Row input) {
-      if (!input.isNullAt(0)) {
-        long updatedSum = buffer.getLong(0) + input.getLong(0);
-        long updatedCount = buffer.getLong(1) + 1;
-        buffer.update(0, updatedSum);
-        buffer.update(1, updatedCount);
-      }
+    // Transform the output of the reduction
+    public Double finish(Average reduction) {
+      return ((double) reduction.getSum()) / reduction.getCount();
     }
-    // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
-    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
-      long mergedSum = buffer1.getLong(0) + buffer2.getLong(0);
-      long mergedCount = buffer1.getLong(1) + buffer2.getLong(1);
-      buffer1.update(0, mergedSum);
-      buffer1.update(1, mergedCount);
+    // Specifies the Encoder for the intermediate value type
+    public Encoder<Average> bufferEncoder() {
+      return Encoders.bean(Average.class);
     }
-    // Calculates the final result
-    public Double evaluate(Row buffer) {
-      return ((double) buffer.getLong(0)) / buffer.getLong(1);
+    // Specifies the Encoder for the final output value type
+    public Encoder<Double> outputEncoder() {
+      return Encoders.DOUBLE();
     }
   }
   // $example off:untyped_custom_aggregation$
@@ -104,7 +108,7 @@ public static void main(String[] args) {
 
     // $example on:untyped_custom_aggregation$
     // Register the function to access it
-    spark.udf().register("myAverage", new MyAverage());
+    spark.udf().register("myAverage", functions.udaf(new MyAverage(), Encoders.LONG()));
 
     Dataset<Row> df = spark.read().json("examples/src/main/resources/employees.json");
     df.createOrReplaceTempView("employees");
@@ -17,48 +17,38 @@
 package org.apache.spark.examples.sql
 
 // $example on:untyped_custom_aggregation$
-import org.apache.spark.sql.{Row, SparkSession}
-import org.apache.spark.sql.expressions.MutableAggregationBuffer
-import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
-import org.apache.spark.sql.types._
+import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
+import org.apache.spark.sql.expressions.Aggregator
+import org.apache.spark.sql.functions
 // $example off:untyped_custom_aggregation$
 
 object UserDefinedUntypedAggregation {
 
   // $example on:untyped_custom_aggregation$
-  object MyAverage extends UserDefinedAggregateFunction {
-    // Data types of input arguments of this aggregate function
-    def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)
-    // Data types of values in the aggregation buffer
-    def bufferSchema: StructType = {
-      StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
-    }
-    // The data type of the returned value
-    def dataType: DataType = DoubleType
-    // Whether this function always returns the same output on the identical input
-    def deterministic: Boolean = true
-    // Initializes the given aggregation buffer. The buffer itself is a `Row` that in addition to
-    // standard methods like retrieving a value at an index (e.g., get(), getBoolean()), provides
-    // the opportunity to update its values. Note that arrays and maps inside the buffer are still
-    // immutable.
-    def initialize(buffer: MutableAggregationBuffer): Unit = {
-      buffer(0) = 0L
-      buffer(1) = 0L
-    }
-    // Updates the given aggregation buffer `buffer` with new input data from `input`
-    def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
-      if (!input.isNullAt(0)) {
-        buffer(0) = buffer.getLong(0) + input.getLong(0)
-        buffer(1) = buffer.getLong(1) + 1
-      }
+  case class Average(var sum: Long, var count: Long)
+
+  object MyAverage extends Aggregator[Long, Average, Double] {
+    // A zero value for this aggregation. Should satisfy the property that any b + zero = b
+    def zero: Average = Average(0L, 0L)
+    // Combine two values to produce a new value. For performance, the function may modify `buffer`
+    // and return it instead of constructing a new object
+    def reduce(buffer: Average, data: Long): Average = {
+      buffer.sum += data
+      buffer.count += 1
+      buffer
     }
-    // Merges two aggregation buffers and stores the updated buffer values back to `buffer1`
-    def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
-      buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
-      buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
+    // Merge two intermediate values
+    def merge(b1: Average, b2: Average): Average = {
+      b1.sum += b2.sum
+      b1.count += b2.count
+      b1
     }
-    // Calculates the final result
-    def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
+    // Transform the output of the reduction
+    def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
+    // Specifies the Encoder for the intermediate value type
+    def bufferEncoder: Encoder[Average] = Encoders.product
+    // Specifies the Encoder for the final output value type
+    def outputEncoder: Encoder[Double] = Encoders.scalaDouble
   }
   // $example off:untyped_custom_aggregation$
 
@@ -70,7 +60,7 @@ object UserDefinedUntypedAggregation {
 
     // $example on:untyped_custom_aggregation$
     // Register the function to access it
-    spark.udf.register("myAverage", MyAverage)
+    spark.udf.register("myAverage", functions.udaf(MyAverage))
 
     val df = spark.read.json("examples/src/main/resources/employees.json")
     df.createOrReplaceTempView("employees")
@@ -73,7 +73,11 @@ class UDFRegistration private[sql] (functionRegistry: FunctionRegistry) extends
    * @return the registered UDAF.
    *
    * @since 1.5.0
+   * @deprecated this method and the use of UserDefinedAggregateFunction are deprecated.
+   * Aggregator[IN, BUF, OUT] should now be registered as a UDF via the functions.udaf(agg) method.
    */
+  @deprecated("Aggregator[IN, BUF, OUT] should now be registered as a UDF" +
+    " via the functions.udaf(agg) method.", "3.0.0")
   def register(name: String, udaf: UserDefinedAggregateFunction): UserDefinedAggregateFunction = {
     def builder(children: Seq[Expression]) = ScalaUDAF(children, udaf)
     functionRegistry.createOrReplaceTempFunction(name, builder)
@@ -27,8 +27,12 @@ import org.apache.spark.sql.types._
  * The base class for implementing user-defined aggregate functions (UDAF).
  *
  * @since 1.5.0
+ * @deprecated UserDefinedAggregateFunction is deprecated.
+ * Aggregator[IN, BUF, OUT] should now be registered as a UDF via the functions.udaf(agg) method.
  */
 @Stable
+@deprecated("Aggregator[IN, BUF, OUT] should now be registered as a UDF" +
+  " via the functions.udaf(agg) method.", "3.0.0")
 abstract class UserDefinedAggregateFunction extends Serializable {
 
   /**
