Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
import org.apache.flink.table.api.dataview.MapView;
import org.apache.flink.table.functions.AggregateFunction;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/** Test aggregator functions. */
public class JavaUserDefinedAggFunctions {
Expand Down Expand Up @@ -421,4 +425,80 @@ public Tuple1<Long> createAccumulator() {
return Tuple1.of(0L);
}
}

/** User defined pojo object. */
public static class TestObject {
private final String a;

public TestObject(String a) {
this.a = a;
}

public String getA() {
return a;
}
}

/** User defined object. */
public static class UserDefinedObject {
// List with user defined pojo object.
public List<TestObject> testObjectList = new ArrayList<>();
// Map with user defined pojo object.
public Map<String, TestObject> testObjectMap = new HashMap<>();
}

/** User defined UDAF whose value and acc is user defined complex pojo object. */
public static class UserDefinedObjectUDAF
extends AggregateFunction<UserDefinedObject, UserDefinedObject> {
private static final String KEY = "key";

@Override
public UserDefinedObject getValue(UserDefinedObject accumulator) {
return accumulator;
}

@Override
public UserDefinedObject createAccumulator() {
return new UserDefinedObject();
}

public void accumulate(UserDefinedObject acc, String a) {
if (a != null) {
acc.testObjectList.add(new TestObject(a));
acc.testObjectMap.put(KEY, new TestObject(a));
}
}

public void retract(UserDefinedObject acc, UserDefinedObject a) {
// do nothing.
}
}

/** User defined UDAF whose value and acc is user defined complex pojo object. */
public static class UserDefinedObjectUDAF2
extends AggregateFunction<String, UserDefinedObject> {
private static final String KEY = "key";

@Override
public String getValue(UserDefinedObject accumulator) {
if (accumulator.testObjectMap.containsKey(KEY)) {
return accumulator.testObjectMap.get(KEY).getA();
}
return null;
}

@Override
public UserDefinedObject createAccumulator() {
return new UserDefinedObject();
}

public void accumulate(UserDefinedObject acc, UserDefinedObject a) {
acc.testObjectList = a.testObjectList;
acc.testObjectMap = a.testObjectMap;
}

public void retract(UserDefinedObject acc, UserDefinedObject a) {
// do nothing
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ import org.apache.flink.table.api.bridge.scala._
import org.apache.flink.table.api.internal.TableEnvironmentInternal
import org.apache.flink.table.planner.factories.TestValuesTableFactory
import org.apache.flink.table.planner.factories.TestValuesTableFactory.{changelogRow, registerData}
import org.apache.flink.table.planner.plan.utils.JavaUserDefinedAggFunctions.VarSumAggFunction
import org.apache.flink.table.planner.plan.utils.JavaUserDefinedAggFunctions.{UserDefinedObjectUDAF, UserDefinedObjectUDAF2, VarSumAggFunction}
import org.apache.flink.table.planner.runtime.batch.sql.agg.{MyPojoAggFunction, VarArgsAggFunction}
import org.apache.flink.table.planner.runtime.utils._
import org.apache.flink.table.planner.runtime.utils.JavaUserDefinedAggFunctions.OverloadedMaxFunction
Expand Down Expand Up @@ -1359,6 +1359,34 @@ class AggregateITCase(aggMode: AggMode, miniBatch: MiniBatchMode, backend: State
assertEquals(expected.sorted, sink.getRetractResults.sorted)
}

@Test
def testUserDefinedObjectAgg(): Unit = {
tEnv.createTemporaryFunction("user_define_object", new UserDefinedObjectUDAF)
tEnv.createTemporaryFunction("user_define_object2", new UserDefinedObjectUDAF2)
val sqlQuery =
s"""
|select t1.a, user_define_object2(t1.d) from
|(SELECT a, user_define_object(b) as d
|FROM MyTable GROUP BY a) t1
|group by t1.a
|""".stripMargin
val data = new mutable.MutableList[(Int, String)]
data.+=((1, "Sam"))
data.+=((1, "Jerry"))
data.+=((2, "Ali"))
data.+=((3, "Grace"))
data.+=((3, "Lucas"))

val t = failingDataSource(data).toTable(tEnv, 'a, 'b)
tEnv.createTemporaryView("MyTable", t)

val sink = new TestingRetractSink
tEnv.sqlQuery(sqlQuery).toRetractStream[Row].addSink(sink)
env.execute()
val expected = List("1,Jerry", "2,Ali", "3,Lucas")
assertEquals(expected.sorted, sink.getRetractResults.sorted)
}

@Test
def testSTDDEV(): Unit = {
val sqlQuery = "SELECT STDDEV_SAMP(a), STDDEV_POP(a) FROM MyTable GROUP BY c"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import static org.apache.flink.table.types.logical.LogicalTypeRoot.MULTISET;
import static org.apache.flink.table.types.logical.LogicalTypeRoot.RAW;
import static org.apache.flink.table.types.logical.LogicalTypeRoot.ROW;
import static org.apache.flink.table.types.logical.LogicalTypeRoot.STRUCTURED_TYPE;
import static org.apache.flink.table.types.logical.LogicalTypeRoot.TIMESTAMP_WITHOUT_TIME_ZONE;
import static org.apache.flink.table.types.logical.LogicalTypeRoot.TIMESTAMP_WITH_LOCAL_TIME_ZONE;
import static org.apache.flink.table.types.logical.utils.LogicalTypeChecks.isRowtimeAttribute;
Expand Down Expand Up @@ -126,8 +127,17 @@ public static boolean isRow(LogicalType type) {
return type.getTypeRoot() == ROW;
}

public static boolean isStructuredType(LogicalType type) {
return type.getTypeRoot() == STRUCTURED_TYPE;
}

public static boolean isComparable(LogicalType type) {
return !isRaw(type) && !isMap(type) && !isMultiset(type) && !isRow(type) && !isArray(type);
return !isRaw(type)
&& !isMap(type)
&& !isMultiset(type)
&& !isRow(type)
&& !isArray(type)
&& !isStructuredType(type);
}

public static boolean isMutable(LogicalType type) {
Expand Down