Skip to content

Commit

Permalink
Implement explicit API for categorize() (#837)
Browse files Browse the repository at this point in the history
* prototype implementation for fast categorize with linting

* relint

* update release notes

* apply ruff

* Sort imports

* Mark top-level method as deprecated

* Fix a typo in test names

* Fix a typo

* Extend `categorize()` signature similar to `validate()`

* Refactor the tests to check old and new signature

* Update release notes

---------

Co-authored-by: Matthew Gidden <matthew.gidden@gmail.com>
  • Loading branch information
danielhuppmann and gidden committed Mar 15, 2024
1 parent 56467f3 commit 90f7073
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 38 deletions.
8 changes: 5 additions & 3 deletions RELEASE_NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ Bumped minimum version of pandas and numpy to fit **ixmp4**'s requirement.

## Individual updates

- [#837](https://github.com/IAMconsortium/pyam/pull/837) Support filters as direct keyword arguments for `categorize()`,
similar to the `validate()` signature (see [#804](https://github.com/IAMconsortium/pyam/pull/804))
- [#832](https://github.com/IAMconsortium/pyam/pull/832) Improve the test-suite for the ixmp4 integration
- [#827](https://github.com/IAMconsortium/pyam/pull/827) Migrate to poetry for project management
- [#830](https://github.com/IAMconsortium/pyam/pull/830) Implement more consistent logging behavior with **ixmp4**
- [#829](https://github.com/IAMconsortium/pyam/pull/829) Add a `pyam.iiasa.platforms()` function for a list of available platforms
- [#829](https://github.com/IAMconsortium/pyam/pull/829) Add `pyam.iiasa.platforms()` for a list of available platforms
- [#826](https://github.com/IAMconsortium/pyam/pull/826) Add `read_ixmp4()` function and extend integration test
- [#825](https://github.com/IAMconsortium/pyam/pull/825) Add support for Python 3.12
- [#824](https://github.com/IAMconsortium/pyam/pull/824) Update ixmp4 requirement to >=0.7.1
Expand All @@ -28,9 +30,9 @@ Bumped minimum version of pandas and numpy to fit **ixmp4**'s requirement.

## Individual updates

- [#804](https://github.com/IAMconsortium/pyam/pull/804) Support filters as direct keyword arguments for `validate()` method
- [#804](https://github.com/IAMconsortium/pyam/pull/804) Support filters as direct keyword arguments for `validate()`
- [#801](https://github.com/IAMconsortium/pyam/pull/801) Support initializing with `meta` dataframe in long format
- [#796](https://github.com/IAMconsortium/pyam/pull/796) Raise explicit error message if no connection to IIASA manager service
- [#796](https://github.com/IAMconsortium/pyam/pull/796) Raise explicit error if no connection to IIASA manager service
- [#794](https://github.com/IAMconsortium/pyam/pull/794) Fix wrong color codes for AR6 Illustrative Pathways
- [#792](https://github.com/IAMconsortium/pyam/pull/792) Support region-aggregation with weights-index >> data-index

Expand Down
85 changes: 57 additions & 28 deletions pyam/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@
import pandas as pd
from pandas.api.types import is_integer

from pyam.filter import filter_by_dt_arg, filter_by_time_domain, filter_by_year
from pyam.ixmp4 import write_to_ixmp4
from pyam.slice import IamSlice

try:
from datapackage import Package

Expand All @@ -35,6 +31,9 @@
from pyam.compute import IamComputeAccessor
from pyam.filter import (
datetime_match,
filter_by_dt_arg,
filter_by_time_domain,
filter_by_year,
)
from pyam.index import (
append_index_col,
Expand All @@ -44,9 +43,11 @@
replace_index_values,
verify_index_integrity,
)
from pyam.ixmp4 import write_to_ixmp4
from pyam.logging import deprecation_warning, format_log_message, raise_data_error
from pyam.plotting import PlotAccessor
from pyam.run_control import run_control
from pyam.slice import IamSlice
from pyam.str import find_depth, is_str
from pyam.time import swap_time_for_year, swap_year_for_time
from pyam.units import convert_unit
Expand All @@ -68,7 +69,7 @@
to_list,
write_sheet,
)
from pyam.validation import _apply_criteria, _exclude_on_fail, _validate
from pyam.validation import _exclude_on_fail, _validate

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -919,39 +920,69 @@ def set_meta_from_data(self, name, method=None, column="value", **kwargs):
self.set_meta(meta, name)

def categorize(
self, name, value, criteria, color=None, marker=None, linestyle=None
self,
name,
value,
criteria: dict = None,
*,
upper_bound: float = None,
lower_bound: float = None,
color=None,
marker=None,
linestyle=None,
**kwargs,
):
"""Assign scenarios to a category according to specific criteria
"""Assign meta indicator to all scenarios that meet given validation criteria
Parameters
----------
name : str
column name of the 'meta' table
Name of the meta indicator
value : str
category identifier
criteria : dict
dictionary with variables mapped to applicable checks
('up' and 'lo' for respective bounds, 'year' for years - optional)
Value of the meta indicator
criteria : dict, optional, deprecated
This option is deprecated; dictionary with variable keys and validation
mappings ('up' and 'lo' for respective bounds, 'year' for years).
upper_bound, lower_bound : float, optional
Upper and lower bounds for validation criteria of timeseries :attr:`data`.
color : str, optional
assign a color to this category for plotting
Assign a color to this category for plotting
marker : str, optional
assign a marker to this category for plotting
Assign a marker to this category for plotting
linestyle : str, optional
assign a linestyle to this category for plotting
Assign a linestyle to this category for plotting
**kwargs
Passed to :meth:`slice` to downselect datapoints for validation.
See Also
--------
validate
"""
# add plotting run control

for kind, arg in [
("color", color),
("marker", marker),
("linestyle", linestyle),
]:
if arg:
run_control().update({kind: {name: {value: arg}}})
# find all data that matches categorization
rows = _apply_criteria(self._data, criteria, in_range=True, return_test="all")
idx = make_index(rows, cols=self.index.names)

if len(idx) == 0:
# find all data that satisfies the validation criteria
# TODO: if validate returned an empty index, this check would be easier
not_valid = self.validate(
criteria=criteria,
upper_bound=upper_bound,
lower_bound=lower_bound,
**kwargs,
)
if not_valid is None:
idx = self.index
elif len(not_valid) < len(self.index):
idx = self.index.difference(
not_valid.set_index(["model", "scenario"]).index.unique()
)
else:
logger.info("No scenarios satisfy the criteria")
return

Expand Down Expand Up @@ -1074,6 +1105,10 @@ def validate(
-------
:class:`pandas.DataFrame` or None
All data points that do not satisfy the criteria.
See Also
--------
categorize
"""
return _validate(
self,
Expand Down Expand Up @@ -2573,15 +2608,9 @@ def require_variable(*args, **kwargs):
def categorize(
df, name, value, criteria, color=None, marker=None, linestyle=None, **kwargs
):
"""Assign scenarios to a category according to specific criteria.
Parameters
----------
df : IamDataFrame
args : passed to :meth:`IamDataFrame.categorize`
kwargs : used for downselecting IamDataFrame
passed to :meth:`IamDataFrame.filter`
"""
"""This method is deprecated, use `IamDataFrame.categorize()` instead."""
# TODO: method is deprecated, remove for release >= 3.0
deprecation_warning("Use `IamDataFrame.categorize()` instead.")
fdf = df.filter(**kwargs)
fdf.categorize(
name=name,
Expand Down
2 changes: 1 addition & 1 deletion pyam/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ def _validate(df, criteria, upper_bound, lower_bound, exclude_on_fail, **kwargs)
raise NotImplementedError(
"Using `criteria` and other arguments simultaneously is not supported."
)
# translate legcy `criteria` argument to explicit kwargs
# translate legacy `criteria` argument to explicit kwargs
if len(criteria) == 1:
key, value = list(criteria.items())[0]
kwargs = dict(variable=key)
Expand Down
28 changes: 22 additions & 6 deletions tests/test_feature_validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def test_validate_both(test_df, args):
dict(criteria={"Primary Energy": {"up": 6, "year": 2005}}),
),
)
def test_validate_year_2010(test_df, args):
def test_validate_year_2005(test_df, args):
# checking that the year filter works as expected
obs = test_df.validate(**args)
assert obs is None
Expand All @@ -197,7 +197,7 @@ def test_validate_year_2010(test_df, args):
dict(criteria={"Primary Energy": {"up": 6, "year": 2010}}),
),
)
def test_validate_year_201ß(test_df, args):
def test_validate_year_2010(test_df, args):
# checking that the return-type is correct
obs = test_df.validate(**args)
pdt.assert_frame_equal(obs, test_df.data[5:6].reset_index(drop=True))
Expand Down Expand Up @@ -238,20 +238,36 @@ def test_validate_top_level(test_df):
assert list(test_df.exclude) == [False, True]


def test_category_none(test_df):
test_df.categorize("category", "Testing", {"Primary Energy": {"up": 0.8}})
# include args for deprecated legacy signature
@pytest.mark.parametrize(
"args",
(
dict(variable="Primary Energy", upper_bound=0),
dict(criteria={"Primary Energy": {"up": 0}}),
),
)
def test_category_no_match(test_df, args):
test_df.categorize("category", "foo", **args)
assert "category" not in test_df.meta.columns


def test_category_pass(test_df):
# include args for deprecated legacy signature
@pytest.mark.parametrize(
"args",
(
dict(variable="Primary Energy", upper_bound=6),
dict(criteria={"Primary Energy": {"up": 6}}),
),
)
def test_category_match(test_df, args):
dct = {
"model": ["model_a", "model_a"],
"scenario": ["scen_a", "scen_b"],
"category": ["foo", None],
}
exp = pd.DataFrame(dct).set_index(["model", "scenario"])["category"]

test_df.categorize("category", "foo", {"Primary Energy": {"up": 6, "year": 2010}})
test_df.categorize("category", "foo", **args)
obs = test_df["category"]
pd.testing.assert_series_equal(obs, exp)

Expand Down

0 comments on commit 90f7073

Please sign in to comment.