Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -4352,6 +4352,18 @@ public void approxPercentileTest() {
"2024-09-24T06:15:55.000Z,shanghai,55,null,",
},
DATABASE_NAME);

tableResultSetEqualTest(
"select approx_percentile(s1,null,0.5) from table1",
new String[] {"_col0"},
new String[] {"null,"},
DATABASE_NAME);
Copy link

Copilot AI Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The new NULL-weight test only covers the non-grouped execution path. Because GROUP BY uses a separate (grouped) accumulator, please add an integration test (IT) that exercises approx_percentile(..., null, ...) — and an invalid weight such as -1 — under GROUP BY, so that the grouped implementation is verified to match the intended semantics and does not regress.

Suggested change
- DATABASE_NAME);
+ DATABASE_NAME);
+ tableResultSetEqualTest(
+ "select 1 as g, approx_percentile(s1,null,0.5) from table1 group by 1",
+ new String[] {"g", "_col1"},
+ new String[] {"1,null,"},
+ DATABASE_NAME);

Copilot uses AI. Check for mistakes.

tableResultSetEqualTest(
"select 1 as g, approx_percentile(s1,null,0.5) from table1 group by 1",
new String[] {"g", "_col1"},
new String[] {"1,null,"},
DATABASE_NAME);
}

@Test
Expand Down Expand Up @@ -4432,6 +4444,18 @@ public void exceptionTest() {
"select approx_percentile(s5,0.5) from table1",
"701: Aggregation functions [approx_percentile] should have value column as numeric type [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]",
DATABASE_NAME);
tableAssertTestFail(
"select approx_percentile(s1,-1,0.5) from table1",
"701: weight must be >= 1, was -1",
DATABASE_NAME);
tableAssertTestFail(
"select approx_percentile(s1,s2,0.5) from table1",
"701: Aggregation functions [approx_percentile] do not support weight as INT64 type",
DATABASE_NAME);
tableAssertTestFail(
"select 1 as g, approx_percentile(s1,s2,0.5) from table1 group by 1",
"701: Aggregation functions [approx_percentile] do not support weight as INT64 type",
DATABASE_NAME);
}

// ==================================================================
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@

package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;

import org.apache.iotdb.db.exception.sql.SemanticException;

import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.enums.TSDataType;

Expand All @@ -32,6 +34,12 @@ public void addIntInput(Column[] arguments, AggregationMask mask) {

if (mask.isSelectAll()) {
for (int i = 0; i < valueColumn.getPositionCount(); i++) {
if (weightColumn.isNull(i)) {
continue;
}
Comment on lines +37 to +39
Copy link

Copilot AI Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Null weights are currently silently skipped (continue), which contradicts the PR description that says weights are validated to be non-null. If the intended behavior is to reject NULL weights (at least when the value is non-null), this should throw a SemanticException instead of skipping; otherwise the PR description/tests should be aligned to explicitly define the semantics for NULL weights.

Copilot uses AI. Check for mistakes.
if (weightColumn.getInt(i) < 1) {
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i));
}
if (!valueColumn.isNull(i)) {
tDigest.add(valueColumn.getInt(i), weightColumn.getInt(i));
}
Comment on lines +37 to 45
Copy link

Copilot AI Mar 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Weight validation happens before checking whether the corresponding value is null. As a result, a row whose value is NULL can still throw on an invalid weight (< 1), even though the aggregation would otherwise ignore that row. Consider reading and validating the weight only inside the !valueColumn.isNull(...) branch, and handling both the NULL-weight and weight < 1 cases there.

Copilot uses AI. Check for mistakes.
Expand All @@ -41,6 +49,12 @@ public void addIntInput(Column[] arguments, AggregationMask mask) {
int position;
for (int i = 0; i < positionCount; i++) {
position = selectedPositions[i];
if (weightColumn.isNull(position)) {
continue;
}
if (weightColumn.getInt(position) < 1) {
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position));
}
if (!valueColumn.isNull(position)) {
tDigest.add(valueColumn.getInt(position), weightColumn.getInt(position));
}
Expand All @@ -57,6 +71,12 @@ public void addLongInput(Column[] arguments, AggregationMask mask) {

if (mask.isSelectAll()) {
for (int i = 0; i < valueColumn.getPositionCount(); i++) {
if (weightColumn.isNull(i)) {
continue;
}
if (weightColumn.getInt(i) < 1) {
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i));
}
if (!valueColumn.isNull(i)) {
tDigest.add(toDoubleExact(valueColumn.getLong(i)), weightColumn.getInt(i));
}
Expand All @@ -66,6 +86,12 @@ public void addLongInput(Column[] arguments, AggregationMask mask) {
int position;
for (int i = 0; i < positionCount; i++) {
position = selectedPositions[i];
if (weightColumn.isNull(position)) {
continue;
}
if (weightColumn.getInt(position) < 1) {
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position));
}
if (!valueColumn.isNull(position)) {
tDigest.add(toDoubleExact(valueColumn.getLong(position)), weightColumn.getInt(position));
}
Expand All @@ -82,6 +108,12 @@ public void addFloatInput(Column[] arguments, AggregationMask mask) {

if (mask.isSelectAll()) {
for (int i = 0; i < valueColumn.getPositionCount(); i++) {
if (weightColumn.isNull(i)) {
continue;
}
if (weightColumn.getInt(i) < 1) {
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i));
}
if (!valueColumn.isNull(i)) {
tDigest.add(valueColumn.getFloat(i), weightColumn.getInt(i));
}
Expand All @@ -91,6 +123,12 @@ public void addFloatInput(Column[] arguments, AggregationMask mask) {
int position;
for (int i = 0; i < positionCount; i++) {
position = selectedPositions[i];
if (weightColumn.isNull(position)) {
continue;
}
if (weightColumn.getInt(position) < 1) {
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position));
}
if (!valueColumn.isNull(position)) {
tDigest.add(valueColumn.getFloat(position), weightColumn.getInt(position));
}
Expand All @@ -107,6 +145,12 @@ public void addDoubleInput(Column[] arguments, AggregationMask mask) {

if (mask.isSelectAll()) {
for (int i = 0; i < valueColumn.getPositionCount(); i++) {
if (weightColumn.isNull(i)) {
continue;
}
if (weightColumn.getInt(i) < 1) {
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i));
}
if (!valueColumn.isNull(i)) {
tDigest.add(valueColumn.getDouble(i), weightColumn.getInt(i));
}
Expand All @@ -116,6 +160,12 @@ public void addDoubleInput(Column[] arguments, AggregationMask mask) {
int position;
for (int i = 0; i < positionCount; i++) {
position = selectedPositions[i];
if (weightColumn.isNull(position)) {
continue;
}
if (weightColumn.getInt(position) < 1) {
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position));
}
if (!valueColumn.isNull(position)) {
tDigest.add(valueColumn.getDouble(position), weightColumn.getInt(position));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped;

import org.apache.iotdb.db.exception.sql.SemanticException;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.approximate.TDigest;

Expand All @@ -36,6 +37,12 @@ public void addIntInput(int[] groupIds, Column[] arguments, AggregationMask mask

if (mask.isSelectAll()) {
for (int i = 0; i < positionCount; i++) {
if (weightColumn.isNull(i)) {
continue;
}
if (weightColumn.getInt(i) < 1) {
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i));
}
int groupId = groupIds[i];
TDigest tDigest = array.get(groupId);
if (!valueColumn.isNull(i)) {
Expand All @@ -48,6 +55,12 @@ public void addIntInput(int[] groupIds, Column[] arguments, AggregationMask mask
int groupId;
for (int i = 0; i < positionCount; i++) {
position = selectedPositions[i];
if (weightColumn.isNull(position)) {
continue;
}
if (weightColumn.getInt(position) < 1) {
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position));
}
groupId = groupIds[position];
TDigest tDigest = array.get(groupId);
if (!valueColumn.isNull(position)) {
Expand All @@ -66,6 +79,12 @@ public void addLongInput(int[] groupIds, Column[] arguments, AggregationMask mas

if (mask.isSelectAll()) {
for (int i = 0; i < positionCount; i++) {
if (weightColumn.isNull(i)) {
continue;
}
if (weightColumn.getInt(i) < 1) {
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i));
}
int groupId = groupIds[i];
TDigest tDigest = array.get(groupId);
if (!valueColumn.isNull(i)) {
Expand All @@ -78,6 +97,12 @@ public void addLongInput(int[] groupIds, Column[] arguments, AggregationMask mas
int groupId;
for (int i = 0; i < positionCount; i++) {
position = selectedPositions[i];
if (weightColumn.isNull(position)) {
continue;
}
if (weightColumn.getInt(position) < 1) {
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position));
}
groupId = groupIds[position];
TDigest tDigest = array.get(groupId);
if (!valueColumn.isNull(position)) {
Expand All @@ -96,6 +121,12 @@ public void addFloatInput(int[] groupIds, Column[] arguments, AggregationMask ma

if (mask.isSelectAll()) {
for (int i = 0; i < positionCount; i++) {
if (weightColumn.isNull(i)) {
continue;
}
if (weightColumn.getInt(i) < 1) {
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i));
}
int groupId = groupIds[i];
TDigest tDigest = array.get(groupId);
if (!valueColumn.isNull(i)) {
Expand All @@ -108,6 +139,12 @@ public void addFloatInput(int[] groupIds, Column[] arguments, AggregationMask ma
int groupId;
for (int i = 0; i < positionCount; i++) {
position = selectedPositions[i];
if (weightColumn.isNull(position)) {
continue;
}
if (weightColumn.getInt(position) < 1) {
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position));
}
groupId = groupIds[position];
TDigest tDigest = array.get(groupId);
if (!valueColumn.isNull(position)) {
Expand All @@ -126,6 +163,12 @@ public void addDoubleInput(int[] groupIds, Column[] arguments, AggregationMask m

if (mask.isSelectAll()) {
for (int i = 0; i < positionCount; i++) {
if (weightColumn.isNull(i)) {
continue;
}
if (weightColumn.getInt(i) < 1) {
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i));
}
int groupId = groupIds[i];
TDigest tDigest = array.get(groupId);
if (!valueColumn.isNull(i)) {
Expand All @@ -138,6 +181,12 @@ public void addDoubleInput(int[] groupIds, Column[] arguments, AggregationMask m
int groupId;
for (int i = 0; i < positionCount; i++) {
position = selectedPositions[i];
if (weightColumn.isNull(position)) {
continue;
}
if (weightColumn.getInt(position) < 1) {
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position));
}
groupId = groupIds[position];
TDigest tDigest = array.get(groupId);
if (!valueColumn.isNull(position)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1194,19 +1194,22 @@ && isIntegerNumber(argumentTypes.get(2)))) {
functionName));
}

// Validate percentage and weight parameters
boolean hasInvalidTypes =
(argumentSize == 2 && !isDecimalType(argumentTypes.get(1)))
|| (argumentSize == 3
&& (!isIntegerNumber(argumentTypes.get(1))
|| !isDecimalType(argumentTypes.get(2))));

if (hasInvalidTypes) {
Type percentageType = argumentTypes.get(argumentSize - 1);
if (!isDecimalType(percentageType)) {
throw new SemanticException(
String.format(
"Aggregation functions [%s] should have weight as integer type and percentage as decimal type",
"Aggregation functions [%s] should have percentage as decimal type",
functionName));
}
if (argumentSize == 3) {
Type weightType = argumentTypes.get(1);
if (!INT32.equals(weightType) && !isUnknownType(weightType)) {
throw new SemanticException(
String.format(
"Aggregation functions [%s] do not support weight as %s type",
functionName, weightType.getDisplayName()));
}
}

break;
case SqlConstant.COUNT:
Expand Down
Loading