Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions be/src/exprs/aggregate/aggregate_function_ema.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

#include "exprs/aggregate/aggregate_function_ema.h"

#include "exprs/aggregate/aggregate_function_simple_factory.h"
#include "exprs/aggregate/helpers.h"

namespace doris {

// Registers the exponential moving average aggregate with `factory` under the
// SQL name "exponential_moving_average".
// NOTE(review): register_function_both presumably registers both the nullable
// and the non-nullable variant - confirm against AggregateFunctionSimpleFactory.
void register_aggregate_function_ema(AggregateFunctionSimpleFactory& factory) {
    using EmaFunction = AggregateFunctionExponentialMovingAverage;
    factory.register_function_both("exponential_moving_average",
                                   creator_without_type::creator<EmaFunction>);
}

} // namespace doris
173 changes: 173 additions & 0 deletions be/src/exprs/aggregate/aggregate_function_ema.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// This file is adapted from
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionExponentialMovingAverage.cpp

#pragma once

#include <cmath>
#include <memory>

#include "core/assert_cast.h"
#include "core/column/column_vector.h"
#include "core/data_type/data_type_number.h"
#include "core/types.h"
#include "exprs/aggregate/aggregate_function.h"

namespace doris {
class Arena;
class BufferReadable;
class BufferWritable;
class IColumn;

/**
 * Aggregation state for an exponentially smoothed moving average over time.
 *
 * Each value corresponds to a timeunit index. The half_decay parameter is the
 * time lag at which exponential weights decay by one-half.
 *
 * The state is a (value, time) pair: `value` is the exponentially weighted sum
 * of all points seen so far, expressed at the reference time `time`. To get the
 * average, the sum is divided by sum_weights(half_decay).
 *
 * Formula:
 *   scale(dt, x)   = 2^(-dt/x)
 *   sum_weights(x) = 1 / (1 - 2^(-1/x))
 *   add(v, t):       merge the current state with the point (v, t)
 *   merge(a, b):     move both to the later time, then sum the values
 *   get():           value / sum_weights(half_decay)
 *
 * A state whose half_decay is 0.0 holds no points (this is the default and the
 * post-reset() state). Such a state has no reference time of its own: the first
 * point or non-empty state folded into it defines the reference time, so the
 * result does not depend on where time index 0 happens to lie.
 *
 * Usage: exponential_moving_average(half_decay, value, timeunit)
 *   - half_decay: constant double, the half-life period in timeunit units
 *   - value:      numeric column to average
 *   - timeunit:   numeric time index (not a raw timestamp; bucket timestamps
 *                 into integer intervals first if needed)
 * Returns DOUBLE.
 */
struct ExponentialMovingAverageData {
    // Exponentially weighted sum, expressed at reference time `time`.
    double value = 0.0;
    // Reference time of `value`; only meaningful once a point has been added.
    double time = 0.0;
    // Half-life in timeunit units. 0.0 doubles as the "no data yet" marker.
    double half_decay = 0.0;

    /// Weight multiplier for a point that is `time_passed` units old: 2^(-time_passed/hd).
    static double scale(double time_passed, double hd) { return std::exp2(-time_passed / hd); }

    /// Sum of the geometric weight series 1 + 2^(-1/hd) + 2^(-2/hd) + ...
    static double sum_weights(double hd) { return 1.0 / (1.0 - std::exp2(-1.0 / hd)); }

    /// Folds the single point (new_value, current_time) into the state.
    /// @param new_value    sample value
    /// @param current_time time index of the sample
    /// @param hd           half-life; stored as the state's half_decay
    void add(double new_value, double current_time, double hd) {
        if (half_decay == 0.0) {
            // No point seen yet: anchor the state at the incoming point's time.
            // Otherwise the default time of 0 would act as a phantom sample and
            // a negative time index would be decayed forward to t = 0.
            time = current_time;
        }
        half_decay = hd;

        ExponentialMovingAverageData point;
        point.value = new_value;
        point.time = current_time;
        merge_point(point, hd);
    }

    /// Combines `other` into this state: the older side is decayed to the newer
    /// side's reference time, then the two sums are added. Does not look at
    /// either side's half_decay; `hd` is the half-life used for decaying.
    void merge_point(const ExponentialMovingAverageData& other, double hd) {
        if (time > other.time) {
            value = value + other.value * scale(time - other.time, hd);
        } else if (time < other.time) {
            value = other.value + value * scale(other.time - time, hd);
            time = other.time;
        } else {
            value = value + other.value;
        }
    }

    /// Merges another partial state into this one.
    /// An empty `rhs` is a no-op; an empty left side adopts rhs's half_decay and
    /// reference time before merging, so empty partial states never shift the
    /// reference time of real data.
    void merge(const ExponentialMovingAverageData& rhs) {
        if (rhs.half_decay == 0.0) {
            // rhs holds no points: nothing to fold in.
            return;
        }
        if (half_decay == 0.0) {
            // This side holds no points: take over rhs's parameters so that
            // merge_point below sees equal times and simply adds the values.
            half_decay = rhs.half_decay;
            time = rhs.time;
        }
        merge_point(rhs, half_decay);
    }

    /// Returns the moving average at the state's reference time, or 0.0 when
    /// the state holds no points.
    double get() const {
        if (half_decay == 0.0) {
            return 0.0;
        }
        return value / sum_weights(half_decay);
    }

    /// Serializes the state as value, time, half_decay (in this order).
    void write(BufferWritable& buf) const {
        buf.write_binary(value);
        buf.write_binary(time);
        buf.write_binary(half_decay);
    }

    /// Deserializes the state in the order written by write().
    void read(BufferReadable& buf) {
        buf.read_binary(value);
        buf.read_binary(time);
        buf.read_binary(half_decay);
    }

    /// Returns the state to "no data".
    void reset() {
        value = 0.0;
        time = 0.0;
        half_decay = 0.0;
    }
};

/// Aggregate function `exponential_moving_average(half_decay, value, timeunit)`.
///
/// Every argument column is read as Float64 and the per-group state is an
/// ExponentialMovingAverageData. The result type is Float64.
class AggregateFunctionExponentialMovingAverage final
        : public IAggregateFunctionDataHelper<ExponentialMovingAverageData,
                                              AggregateFunctionExponentialMovingAverage>,
          MultiExpression,
          NullableAggregateFunction {
    using HelperBase = IAggregateFunctionDataHelper<ExponentialMovingAverageData,
                                                    AggregateFunctionExponentialMovingAverage>;

    /// Reads row `row` of `column`, which is cast to a Float64 column.
    static double float64_at(const IColumn& column, ssize_t row) {
        return assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(column)
                .get_data()[row];
    }

public:
    AggregateFunctionExponentialMovingAverage(const DataTypes& argument_types_)
            : HelperBase(argument_types_) {}

    String get_name() const override { return "exponential_moving_average"; }

    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }

    /// Puts the state at `place` back into its initial (no data) form.
    void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }

    /// Folds row `row_num` into the state at `place`.
    /// Column order follows the SQL signature: (half_decay, value, timeunit).
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
             Arena&) const override {
        const double half_decay = float64_at(*columns[0], row_num);
        const double sample = float64_at(*columns[1], row_num);
        const double sample_time = float64_at(*columns[2], row_num);
        this->data(place).add(sample, sample_time, half_decay);
    }

    /// Merges the partial state `rhs` into the state at `place`.
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
               Arena&) const override {
        const auto& partial = this->data(rhs);
        this->data(place).merge(partial);
    }

    /// Writes the state at `place` to `buf`.
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
        this->data(place).write(buf);
    }

    /// Restores the state at `place` from `buf`.
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
                     Arena&) const override {
        this->data(place).read(buf);
    }

    /// Appends the final average of the state at `place` to the Float64 column `to`.
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
        auto& result_column = assert_cast<ColumnFloat64&>(to);
        result_column.get_data().push_back(this->data(place).get());
    }
};

} // namespace doris
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,7 @@ void register_aggregate_function_percentile_reservoir(AggregateFunctionSimpleFac
void register_aggregate_function_ai_agg(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_bool_union(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_sem(AggregateFunctionSimpleFactory& factory);
void register_aggregate_function_ema(AggregateFunctionSimpleFactory& factory);

AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
static std::once_flag oc;
Expand Down Expand Up @@ -137,6 +138,7 @@ AggregateFunctionSimpleFactory& AggregateFunctionSimpleFactory::instance() {
register_aggregate_function_ai_agg(instance);
register_aggregate_function_bool_union(instance);
register_aggregate_function_sem(instance);
register_aggregate_function_ema(instance);
// Register foreach and foreachv2 functions
register_aggregate_function_combinator_foreach(instance);
register_aggregate_function_combinator_foreachv2(instance);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum;
import org.apache.doris.nereids.trees.expressions.functions.agg.Covar;
import org.apache.doris.nereids.trees.expressions.functions.agg.CovarSamp;
import org.apache.doris.nereids.trees.expressions.functions.agg.ExponentialMovingAverage;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupArrayIntersect;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupArrayUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
Expand Down Expand Up @@ -132,6 +133,7 @@ private BuiltinAggregateFunctions() {
agg(CollectSet.class, "collect_set", "group_uniq_array"),
agg(Corr.class, "corr"),
agg(CorrWelford.class, "corr_welford"),
agg(ExponentialMovingAverage.class, "exponential_moving_average"),
agg(Count.class, "count"),
agg(CountByEnum.class, "count_by_enum"),
agg(Covar.class, "covar", "covar_pop"),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.trees.expressions.functions.agg;

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.DoubleType;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;

import java.util.List;

/**
 * Nereids expression for the {@code exponential_moving_average} aggregate function.
 *
 * <p>Computes the exponentially smoothed moving average over time-indexed values.
 * The half_decay parameter is the half-life period: the time after which the
 * exponential weight of a past value has decayed by a factor of 1/2.
 *
 * <p>Signature: {@code exponential_moving_average(half_decay DOUBLE, value DOUBLE,
 * timeunit DOUBLE) -> DOUBLE}
 *
 * <p>The timeunit argument is a numeric time index, not a raw timestamp. Timestamp
 * columns should first be converted to an interval index (for example, seconds since
 * the epoch divided by the interval length in seconds).
 */
public class ExponentialMovingAverage extends NullableAggregateFunction
        implements ExplicitlyCastableSignature {

    public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
            FunctionSignature.ret(DoubleType.INSTANCE)
                    .args(DoubleType.INSTANCE, DoubleType.INSTANCE, DoubleType.INSTANCE)
    );

    /**
     * Non-distinct form with the 3 arguments (half_decay, value, timeunit).
     */
    public ExponentialMovingAverage(Expression halfDecay, Expression value, Expression timeunit) {
        this(false, false, halfDecay, value, timeunit);
    }

    /**
     * Form with an explicit distinct flag; alwaysNullable is false.
     */
    public ExponentialMovingAverage(boolean distinct, Expression halfDecay,
            Expression value, Expression timeunit) {
        this(distinct, false, halfDecay, value, timeunit);
    }

    /**
     * Form with every flag spelled out.
     */
    public ExponentialMovingAverage(boolean distinct, boolean alwaysNullable,
            Expression halfDecay, Expression value, Expression timeunit) {
        super("exponential_moving_average", distinct, alwaysNullable, halfDecay, value, timeunit);
    }

    /** Used by withDistinctAndChildren and withAlwaysNullable to rebuild from params. */
    private ExponentialMovingAverage(NullableAggregateFunctionParams functionParams) {
        super(functionParams);
    }

    /**
     * Rejects a non-constant half_decay first, then any non-numeric argument,
     * checked in the order half_decay, value, timeunit.
     */
    @Override
    public void checkLegalityBeforeTypeCoercion() {
        if (!getArgument(0).isConstant()) {
            throw new AnalysisException("The half_decay argument of "
                    + getName() + " must be a constant");
        }
        checkNumericArgument(0, "The half_decay argument of ");
        checkNumericArgument(1, "The value argument of ");
        checkNumericArgument(2, "The timeunit argument of ");
    }

    /** Throws when the argument at {@code index} is not of a numeric type. */
    private void checkNumericArgument(int index, String messagePrefix) {
        if (!getArgumentType(index).isNumericType()) {
            throw new AnalysisException(messagePrefix
                    + getName() + " must be numeric");
        }
    }

    @Override
    public ExponentialMovingAverage withDistinctAndChildren(boolean distinct,
            List<Expression> children) {
        Preconditions.checkArgument(children.size() == 3);
        NullableAggregateFunctionParams params = getFunctionParams(distinct, children);
        return new ExponentialMovingAverage(params);
    }

    @Override
    public ExponentialMovingAverage withAlwaysNullable(boolean alwaysNullable) {
        NullableAggregateFunctionParams params = getAlwaysNullableFunctionParams(alwaysNullable);
        return new ExponentialMovingAverage(params);
    }

    @Override
    public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
        return visitor.visitExponentialMovingAverage(this, context);
    }

    @Override
    public List<FunctionSignature> getSignatures() {
        return SIGNATURES;
    }
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
import org.apache.doris.nereids.trees.expressions.functions.agg.CountByEnum;
import org.apache.doris.nereids.trees.expressions.functions.agg.Covar;
import org.apache.doris.nereids.trees.expressions.functions.agg.CovarSamp;
import org.apache.doris.nereids.trees.expressions.functions.agg.ExponentialMovingAverage;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupArrayIntersect;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupArrayUnion;
import org.apache.doris.nereids.trees.expressions.functions.agg.GroupBitAnd;
Expand Down Expand Up @@ -125,6 +126,10 @@ default R visitBitmapAgg(BitmapAgg bitmapAgg, C context) {
return visitAggregateFunction(bitmapAgg, context);
}

default R visitExponentialMovingAverage(ExponentialMovingAverage ema, C context) {
return visitNullableAggregateFunction(ema, context);
}

default R visitBitmapIntersect(BitmapIntersect bitmapIntersect, C context) {
return visitAggregateFunction(bitmapIntersect, context);
}
Expand Down
Loading
Loading