-
Notifications
You must be signed in to change notification settings - Fork 30
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* fix: get fresh sensor instance to avoid sqlalchemy.exc.InvalidRequestError Signed-off-by: Victor Garcia Reolid <victor@seita.nl> * fx: handle the case of not finding the sensor in the datbase Signed-off-by: Victor Garcia Reolid <victor@seita.nl> * feat: add AggregatorReporter Signed-off-by: Victor Garcia Reolid <victor@seita.nl> * feat: add fixture Signed-off-by: Victor Garcia Reolid <victor@seita.nl> * fix: typo in file name Signed-off-by: Victor Garcia Reolid <victor@seita.nl> * Set author in PandasReporter and AggregatorReporter Signed-off-by: Victor Garcia Reolid <victor@seita.nl> * style: remove unnecesay class property Signed-off-by: Victor Garcia Reolid <victor@seita.nl> * test: add description of test Signed-off-by: Victor Garcia Reolid <victor@seita.nl> * fix: vectorized bdf creation Signed-off-by: Victor Garcia Reolid <victor@seita.nl> * style: improve Exception message Signed-off-by: Victor Garcia Reolid <victor@seita.nl> * fix: lowercase enum value Signed-off-by: Victor Garcia Reolid <victor@seita.nl> * feat: make method and weights optional Signed-off-by: Victor Garcia Reolid <victor@seita.nl> * feat: set SUM as the feault aggregation method Signed-off-by: Victor Garcia Reolid <victor@seita.nl> * typo Signed-off-by: F.N. Claessen <felix@seita.nl> * refactor: simplify column value assignments Signed-off-by: F.N. Claessen <felix@seita.nl> * typo Signed-off-by: F.N. Claessen <felix@seita.nl> * fix: use aggregate function instead of getattr Signed-off-by: Victor Garcia Reolid <victor@seita.nl> * feat: allow users to pass any string Signed-off-by: Victor Garcia Reolid <victor@seita.nl> * test: add more test cases Signed-off-by: Victor Garcia Reolid <victor@seita.nl> * feat: weights and method as class attribute Signed-off-by: Victor Garcia Reolid <victor@seita.nl> * simplify float to int Signed-off-by: F.N. Claessen <felix@seita.nl> * changelog entry Signed-off-by: F.N. Claessen <felix@seita.nl> * black Signed-off-by: F.N. Claessen <felix@seita.nl> --------- Signed-off-by: Victor Garcia Reolid <victor@seita.nl> Signed-off-by: F.N. Claessen <felix@seita.nl> Co-authored-by: F.N. Claessen <felix@seita.nl>
- Loading branch information
1 parent
cfb662d
commit 583fc87
Showing
6 changed files
with
211 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,77 @@ | ||
from __future__ import annotations | ||
|
||
from datetime import datetime, timedelta | ||
|
||
import timely_beliefs as tb | ||
import pandas as pd | ||
|
||
from flexmeasures.data.models.reporting import Reporter | ||
from flexmeasures.data.schemas.reporting.aggregation import AggregatorSchema | ||
|
||
from flexmeasures.utils.time_utils import server_now | ||
|
||
|
||
class AggregatorReporter(Reporter): | ||
"""This reporter applies an aggregation function to multiple sensors""" | ||
|
||
__version__ = "1" | ||
__author__ = "Seita" | ||
schema = AggregatorSchema() | ||
weights: dict | ||
method: str | ||
|
||
def deserialize_config(self): | ||
# call Reporter deserialize_config | ||
super().deserialize_config() | ||
|
||
# extract AggregatorReporter specific fields | ||
self.method = self.reporter_config.get("method") | ||
self.weights = self.reporter_config.get("weights", dict()) | ||
|
||
def _compute( | ||
self, | ||
start: datetime, | ||
end: datetime, | ||
input_resolution: timedelta | None = None, | ||
belief_time: datetime | None = None, | ||
) -> tb.BeliefsDataFrame: | ||
""" | ||
This method merges all the BeliefDataFrames into a single one, dropping | ||
all indexes but event_start, and applies an aggregation function over the | ||
columns. | ||
""" | ||
|
||
dataframes = [] | ||
|
||
if belief_time is None: | ||
belief_time = server_now() | ||
|
||
for belief_search_config in self.beliefs_search_configs: | ||
# if alias is not in belief_search_config, using the Sensor id instead | ||
column_name = belief_search_config.get( | ||
"alias", f"sensor_{belief_search_config['sensor'].id}" | ||
) | ||
data = self.data[column_name].droplevel([1, 2, 3]) | ||
|
||
# apply weight | ||
if column_name in self.weights: | ||
data *= self.weights[column_name] | ||
|
||
dataframes.append(data) | ||
|
||
output_df = pd.concat(dataframes, axis=1) | ||
|
||
# apply aggregation method | ||
output_df = output_df.aggregate(self.method, axis=1) | ||
|
||
# convert BeliefsSeries into a BeliefsDataFrame | ||
output_df = output_df.to_frame("event_value") | ||
output_df["belief_time"] = belief_time | ||
output_df["cumulative_probability"] = 0.5 | ||
output_df["source"] = self.data_source | ||
|
||
output_df = output_df.set_index( | ||
["belief_time", "source", "cumulative_probability"], append=True | ||
) | ||
|
||
return output_df |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
60 changes: 60 additions & 0 deletions
60
flexmeasures/data/models/reporting/tests/test_aggregator.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
import pytest | ||
|
||
from flexmeasures.data.models.reporting.aggregator import AggregatorReporter | ||
|
||
from datetime import datetime | ||
from pytz import utc | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"aggregation_method, expected_value", | ||
[ | ||
("sum", 0), | ||
("mean", 0), | ||
("var", 2), | ||
("std", 2**0.5), | ||
("max", 1), | ||
("min", -1), | ||
("prod", -1), | ||
("median", 0), | ||
], | ||
) | ||
def test_aggregator(setup_dummy_data, aggregation_method, expected_value): | ||
""" | ||
This test computes the aggregation of two sensors containing 24 entries | ||
with value 1 and -1, respectively, for sensors 1 and 2. | ||
Test cases: | ||
1) sum: 0 = 1 + (-1) | ||
2) mean: 0 = ((1) + (-1))/2 | ||
3) var: 2 = (1)^2 + (-1)^2 | ||
4) std: sqrt(2) = sqrt((1)^2 + (-1)^2) | ||
5) max: 1 = max(1, -1) | ||
6) min: -1 = min(1, -1) | ||
7) prod: -1 = (1) * (-1) | ||
8) median: even number of elements, mean of the most central elements, 0 = ((1) + (-1))/2 | ||
""" | ||
s1, s2, reporter_sensor = setup_dummy_data | ||
|
||
reporter_config_raw = dict( | ||
beliefs_search_configs=[ | ||
dict(sensor=s1.id, source=1), | ||
dict(sensor=s2.id, source=2), | ||
], | ||
method=aggregation_method, | ||
) | ||
|
||
agg_reporter = AggregatorReporter( | ||
reporter_sensor, reporter_config_raw=reporter_config_raw | ||
) | ||
|
||
result = agg_reporter.compute( | ||
start=datetime(2023, 5, 10, tzinfo=utc), | ||
end=datetime(2023, 5, 11, tzinfo=utc), | ||
) | ||
|
||
# check that we got a result for 24 hours | ||
assert len(result) == 24 | ||
|
||
# check that the value is equal to expected_value | ||
assert (result == expected_value).all().event_value |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
from marshmallow import fields, ValidationError, validates_schema | ||
|
||
from flexmeasures.data.schemas.reporting import ReporterConfigSchema | ||
|
||
|
||
class AggregatorSchema(ReporterConfigSchema): | ||
"""Schema for the reporter_config of the AggregatorReporter | ||
Example: | ||
.. code-block:: json | ||
{ | ||
"beliefs_search_configs": [ | ||
{ | ||
"sensor": 1, | ||
"source" : 1, | ||
"alias" : "pv" | ||
}, | ||
{ | ||
"sensor": 1, | ||
"source" : 2, | ||
"alias" : "consumption" | ||
} | ||
], | ||
"method" : "sum", | ||
"weights" : { | ||
"pv" : 1.0, | ||
"consumption" : -1.0 | ||
} | ||
} | ||
""" | ||
|
||
method = fields.Str(required=False, dump_default="sum") | ||
weights = fields.Dict(fields.Str(), fields.Float(), required=False) | ||
|
||
@validates_schema | ||
def validate_source(self, data, **kwargs): | ||
|
||
for beliefs_search_config in data["beliefs_search_configs"]: | ||
if "source" not in beliefs_search_config: | ||
raise ValidationError("`source` is a required field.") | ||
|
||
@validates_schema | ||
def validate_weights(self, data, **kwargs): | ||
if "weights" not in data: | ||
return | ||
|
||
# get aliases | ||
aliases = [] | ||
for beliefs_search_config in data["beliefs_search_configs"]: | ||
if "alias" in beliefs_search_config: | ||
aliases.append(beliefs_search_config.get("alias")) | ||
|
||
# check that the aliases in weights are defined | ||
for alias in data.get("weights").keys(): | ||
if alias not in aliases: | ||
raise ValidationError( | ||
f"alias `{alias}` in `weights` is not defined in `beliefs_search_config`" | ||
) |