Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,14 @@ public enum BuiltinAggregationFunctionEnum {
AVG("avg"),
SUM("sum"),
MAX_BY("max_by"),
MIN_BY("min_by");
MIN_BY("min_by"),
CORR("corr"),
COVAR_POP("covar_pop"),
COVAR_SAMP("covar_samp"),
REGR_SLOPE("regr_slope"),
REGR_INTERCEPT("regr_intercept"),
SKEWNESS("skewness"),
KURTOSIS("kurtosis");

private final String functionName;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,11 @@ public static boolean isMultiInputAggregation(TAggregationType aggregationType)
switch (aggregationType) {
case MAX_BY:
case MIN_BY:
case CORR:
case COVAR_POP:
case COVAR_SAMP:
case REGR_SLOPE:
case REGR_INTERCEPT:
return true;
default:
return false;
Expand All @@ -84,6 +89,31 @@ public static Accumulator createBuiltinMultiInputAccumulator(
case MIN_BY:
checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size.");
return new MinByAccumulator(inputDataTypes.get(0), inputDataTypes.get(1));
case CORR:
checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size.");
return new CorrelationAccumulator(
new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)},
CorrelationAccumulator.CorrelationType.CORR);
case COVAR_POP:
checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size.");
return new CorrelationAccumulator(
new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)},
CorrelationAccumulator.CorrelationType.COVAR_POP);
case COVAR_SAMP:
checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size.");
return new CorrelationAccumulator(
new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)},
CorrelationAccumulator.CorrelationType.COVAR_SAMP);
case REGR_SLOPE:
checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size.");
return new RegressionAccumulator(
new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)},
RegressionAccumulator.RegressionType.REGR_SLOPE);
case REGR_INTERCEPT:
checkState(inputDataTypes.size() == 2, "Wrong inputDataTypes size.");
return new RegressionAccumulator(
new TSDataType[] {inputDataTypes.get(0), inputDataTypes.get(1)},
RegressionAccumulator.RegressionType.REGR_INTERCEPT);
default:
throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType);
}
Expand Down Expand Up @@ -140,6 +170,12 @@ private static Accumulator createBuiltinSingleInputAccumulator(
return new VarianceAccumulator(tsDataType, VarianceAccumulator.VarianceType.VAR_SAMP);
case VAR_POP:
return new VarianceAccumulator(tsDataType, VarianceAccumulator.VarianceType.VAR_POP);
case SKEWNESS:
return new CentralMomentAccumulator(
tsDataType, CentralMomentAccumulator.MomentType.SKEWNESS);
case KURTOSIS:
return new CentralMomentAccumulator(
tsDataType, CentralMomentAccumulator.MomentType.KURTOSIS);
default:
throw new IllegalArgumentException("Invalid Aggregation function: " + aggregationType);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,238 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.iotdb.db.queryengine.execution.aggregation;

import org.apache.tsfile.block.column.Column;
import org.apache.tsfile.block.column.ColumnBuilder;
import org.apache.tsfile.enums.TSDataType;
import org.apache.tsfile.file.metadata.statistics.Statistics;
import org.apache.tsfile.utils.Binary;
import org.apache.tsfile.utils.BitMap;

import java.nio.ByteBuffer;

import static com.google.common.base.Preconditions.checkArgument;

/**
 * Accumulator that computes SKEWNESS or KURTOSIS over a numeric series.
 *
 * <p>The running state is the observation count, the mean, and the second, third and fourth
 * central moment sums ({@code m2 = sum((x - mean)^2)}, {@code m3}, {@code m4}). Values are folded
 * in one at a time by {@link #update(double)} and partial states are combined by {@link
 * #merge(long, double, double, double, double)}.
 *
 * <p>All arithmetic that depends on the count is carried out in {@code double}. Products such as
 * {@code n * n * n} or {@code (n - 1) * (n - 2) * (n - 3)} overflow a {@code long} as soon as
 * {@code n} exceeds roughly 2.1 million, which would silently corrupt the result.
 *
 * <p>The intermediate result is a 40-byte binary: the count as a {@code long} followed by mean,
 * m2, m3 and m4 as {@code double}s, in {@link ByteBuffer}'s default byte order.
 */
public class CentralMomentAccumulator implements Accumulator {

  /** Selects which statistic {@link #outputFinal(ColumnBuilder)} emits. */
  public enum MomentType {
    SKEWNESS,
    KURTOSIS
  }

  /** Size in bytes of the serialized state: one long (count) plus four doubles. */
  private static final int INTERMEDIATE_SIZE = Long.BYTES + 4 * Double.BYTES;

  private final TSDataType seriesDataType;
  private final MomentType momentType;

  private long count;
  private double mean;
  private double m2;
  private double m3;
  private double m4;

  /**
   * @param seriesDataType data type of the input value column; INT32, DATE, INT64, TIMESTAMP,
   *     FLOAT and DOUBLE are supported
   * @param momentType the statistic to produce as the final result
   */
  public CentralMomentAccumulator(TSDataType seriesDataType, MomentType momentType) {
    this.seriesDataType = seriesDataType;
    this.momentType = momentType;
  }

  /**
   * Folds every selected, non-null value of {@code columns[1]} into the running state.
   *
   * @param columns input columns; values are read from index 1 (index 0 is presumably the time
   *     column - confirm against callers)
   * @param bitMap optional row selector; when non-null, rows whose bit is not marked are skipped
   * @throws UnsupportedOperationException if the series data type is not numeric
   */
  @Override
  public void addInput(Column[] columns, BitMap bitMap) {
    Column valueColumn = columns[1];
    int size = valueColumn.getPositionCount();
    for (int i = 0; i < size; i++) {
      if (bitMap != null && !bitMap.isMarked(i)) {
        continue;
      }
      if (valueColumn.isNull(i)) {
        continue;
      }
      update(getDoubleValue(valueColumn, i));
    }
  }

  /** Reads the value at {@code position} according to the series data type, widened to double. */
  private double getDoubleValue(Column column, int position) {
    switch (seriesDataType) {
      case INT32:
      case DATE:
        return column.getInt(position);
      case INT64:
      case TIMESTAMP:
        return column.getLong(position);
      case FLOAT:
        return column.getFloat(position);
      case DOUBLE:
        return column.getDouble(position);
      default:
        throw new UnsupportedOperationException(
            "Unsupported data type in CentralMoment Aggregation: " + seriesDataType);
    }
  }

  /**
   * Single-pass update of mean, m2, m3 and m4 with one new observation.
   *
   * <p>The statement order matters: m4 must be updated with the previous m2 and m3, and m3 with
   * the previous m2.
   */
  private void update(double value) {
    // Counts are widened to double up front so that n * n cannot overflow a long.
    double previousCount = count;
    count++;
    double n = count;

    double delta = value - mean;
    double deltaN = delta / n;
    double deltaN2 = deltaN * deltaN;
    double term1 = delta * deltaN * previousCount;

    mean += deltaN;

    m4 += term1 * deltaN2 * (n * n - 3 * n + 3) + 6 * deltaN2 * m2 - 4 * deltaN * m3;

    m3 += term1 * deltaN * (n - 2) - 3 * deltaN * m2;

    m2 += term1;
  }

  /**
   * Merges a serialized partial state into this accumulator.
   *
   * @param partialResult exactly one column whose first row is either null (ignored) or a binary
   *     in the layout written by {@link #outputIntermediate(ColumnBuilder[])}
   * @throws IllegalArgumentException if {@code partialResult} does not contain exactly one column
   */
  @Override
  public void addIntermediate(Column[] partialResult) {
    checkArgument(partialResult.length == 1, "partialResult of CentralMoment should be 1");
    if (partialResult[0].isNull(0)) {
      return;
    }
    byte[] bytes = partialResult[0].getBinary(0).getValues();
    ByteBuffer buffer = ByteBuffer.wrap(bytes);

    long otherCount = buffer.getLong();
    double otherMean = buffer.getDouble();
    double otherM2 = buffer.getDouble();
    double otherM3 = buffer.getDouble();
    double otherM4 = buffer.getDouble();

    merge(otherCount, otherMean, otherM2, otherM3, otherM4);
  }

  /**
   * Combines the state of another partition (B) into this one (A) using the pairwise formulas for
   * central moment sums.
   *
   * <p>The statement order matters: m4 is combined using A's previous m2 and m3, then m3 using
   * A's previous m2, then m2, and the mean last.
   */
  private void merge(long nB, double meanB, double m2B, double m3B, double m4B) {
    if (nB == 0) {
      return;
    }
    if (count == 0) {
      count = nB;
      mean = meanB;
      m2 = m2B;
      m3 = m3B;
      m4 = m4B;
      return;
    }

    // Work in double: nTotal^3 overflows a long once nTotal exceeds 2^21.
    double nA = count;
    double nBd = nB;
    double nTotal = nA + nBd;

    double delta = meanB - mean;
    double delta2 = delta * delta;
    double delta3 = delta * delta2;
    double delta4 = delta2 * delta2;

    m4 +=
        m4B
            + delta4 * nA * nBd * (nA * nA - nA * nBd + nBd * nBd) / (nTotal * nTotal * nTotal)
            + 6.0 * delta2 * (nA * nA * m2B + nBd * nBd * m2) / (nTotal * nTotal)
            + 4.0 * delta * (nA * m3B - nBd * m3) / nTotal;

    m3 +=
        m3B
            + delta3 * nA * nBd * (nA - nBd) / (nTotal * nTotal)
            + 3.0 * delta * (nA * m2B - nBd * m2) / nTotal;

    m2 += m2B + delta2 * nA * nBd / nTotal;

    mean += delta * nBd / nTotal;
    count += nB;
  }

  /**
   * Writes the running state as a single binary value, or null when no value has been seen.
   *
   * @throws IllegalArgumentException if {@code columnBuilders} does not contain exactly one
   *     builder
   */
  @Override
  public void outputIntermediate(ColumnBuilder[] columnBuilders) {
    checkArgument(columnBuilders.length == 1, "partialResult should be 1");
    if (count == 0) {
      columnBuilders[0].appendNull();
      return;
    }

    byte[] bytes = new byte[INTERMEDIATE_SIZE];
    ByteBuffer buffer = ByteBuffer.wrap(bytes);
    buffer.putLong(count);
    buffer.putDouble(mean);
    buffer.putDouble(m2);
    buffer.putDouble(m3);
    buffer.putDouble(m4);
    columnBuilders[0].writeBinary(new Binary(bytes));
  }

  /**
   * Writes the final statistic as a double, or null when it is undefined.
   *
   * <p>Null is produced when no value was seen, when m2 is zero (all values equal), when fewer
   * than 3 values were seen for SKEWNESS, or fewer than 4 for KURTOSIS.
   *
   * <p>With {@code s^2 = m2 / (n - 1)}:
   *
   * <ul>
   *   <li>SKEWNESS = {@code n * m3 / ((n - 1) * (n - 2) * s^3)}
   *   <li>KURTOSIS = {@code n * (n + 1) * m4 / ((n - 1) * (n - 2) * (n - 3) * s^4) - 3 * (n - 1)^2
   *       / ((n - 2) * (n - 3))}
   * </ul>
   */
  @Override
  public void outputFinal(ColumnBuilder columnBuilder) {
    if (count == 0 || m2 == 0) {
      columnBuilder.appendNull();
      return;
    }

    // Widen once; (n - 1) * (n - 2) * (n - 3) overflows a long for n above roughly 2.1 million.
    double n = count;
    double variance = m2 / (n - 1);

    if (momentType == MomentType.SKEWNESS) {
      if (count < 3) {
        columnBuilder.appendNull();
        return;
      }
      double stdev = Math.sqrt(variance);
      double result = (n * m3) / ((n - 1) * (n - 2) * stdev * stdev * stdev);
      columnBuilder.writeDouble(result);
    } else {
      if (count < 4) {
        columnBuilder.appendNull();
        return;
      }
      double term1 = (n * (n + 1) * m4) / ((n - 1) * (n - 2) * (n - 3) * variance * variance);
      double term2 = (3 * Math.pow(n - 1, 2)) / ((n - 2) * (n - 3));
      columnBuilder.writeDouble(term1 - term2);
    }
  }

  /** Not supported: a partial state cannot be subtracted from this accumulator. */
  @Override
  public void removeIntermediate(Column[] input) {
    throw new UnsupportedOperationException();
  }

  /** Not supported: this accumulator cannot be fed from statistics. */
  @Override
  public void addStatistics(Statistics statistics) {
    throw new UnsupportedOperationException();
  }

  /** Intentionally a no-op; {@link #hasFinalResult()} always reports {@code false}. */
  @Override
  public void setFinal(Column finalResult) {}

  /** Clears the running state so the accumulator can be reused. */
  @Override
  public void reset() {
    count = 0;
    mean = 0;
    m2 = 0;
    m3 = 0;
    m4 = 0;
  }

  /** Always {@code false}: every input value has to be consumed. */
  @Override
  public boolean hasFinalResult() {
    return false;
  }

  /** The intermediate state is a single TEXT column carrying the serialized binary. */
  @Override
  public TSDataType[] getIntermediateType() {
    return new TSDataType[] {TSDataType.TEXT};
  }

  /** The final statistic is always a DOUBLE. */
  @Override
  public TSDataType getFinalType() {
    return TSDataType.DOUBLE;
  }
}
Loading