Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove deprecated selection in skill, score, etc. #442

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
180 changes: 18 additions & 162 deletions modelskill/comparison/_collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
from ..skill import SkillTable
from ..skill_grid import SkillGrid

from ..utils import _get_idx, _get_name
from ..utils import _get_name
from ._comparison import Comparer, Scoreable
from ..metrics import _parse_metric
from ._utils import (
Expand All @@ -40,31 +40,6 @@
IdxOrNameTypes,
TimeTypes,
)
from ._comparison import _get_deprecated_args # TODO remove in v 1.1


def _get_deprecated_obs_var_args(kwargs): # type: ignore
observation, variable = None, None

# Don't bother refactoring this, it will be removed in v1.1
if "observation" in kwargs:
observation = kwargs.pop("observation")
if observation is not None:
warnings.warn(
f"The 'observation' argument is deprecated, use 'sel(observation='{observation}') instead",
FutureWarning,
)

if "variable" in kwargs:
variable = kwargs.pop("variable")

if variable is not None:
warnings.warn(
f"The 'variable' argument is deprecated, use 'sel(quantity='{variable}') instead",
FutureWarning,
)

return observation, variable


class ComparerCollection(Mapping, Scoreable):
Expand Down Expand Up @@ -446,7 +421,6 @@ def skill(
by: str | Iterable[str] | None = None,
metrics: Iterable[str] | Iterable[Callable] | str | Callable | None = None,
observed: bool = False,
**kwargs: Any,
) -> SkillTable:
"""Aggregated skill assessment of model(s)

Expand Down Expand Up @@ -505,36 +479,18 @@ def skill(
2017-10-29 163 -0.21 0.52 0.47 0.42 0.79 0.11 0.99
"""

# TODO remove in v1.1 ----------
model, start, end, area = _get_deprecated_args(kwargs) # type: ignore
observation, variable = _get_deprecated_obs_var_args(kwargs) # type: ignore
assert kwargs == {}, f"Unknown keyword arguments: {kwargs}"

cc = self.sel(
model=model,
observation=observation,
quantity=variable,
start=start,
end=end,
area=area,
)
if cc.n_points == 0:
raise ValueError("Dataset is empty, no data to compare.")

## ---- end of deprecated code ----

pmetrics = _parse_metric(metrics)

agg_cols = _parse_groupby(by, n_mod=cc.n_models, n_qnt=cc.n_quantities)
agg_cols = _parse_groupby(by, n_mod=self.n_models, n_qnt=self.n_quantities)
agg_cols, attrs_keys = self._attrs_keys_in_by(agg_cols)

df = cc._to_long_dataframe(attrs_keys=attrs_keys, observed=observed)
df = self._to_long_dataframe(attrs_keys=attrs_keys, observed=observed)

res = _groupby_df(df, by=agg_cols, metrics=pmetrics)
mtr_cols = [m.__name__ for m in pmetrics] # type: ignore
res = res.dropna(subset=mtr_cols, how="all") # TODO: ok to remove empty?
res = self._append_xy_to_res(res, cc)
res = cc._add_as_col_if_not_in_index(df, skilldf=res) # type: ignore
res = self._append_xy_to_res(res, self)
res = self._add_as_col_if_not_in_index(df, skilldf=res) # type: ignore
return SkillTable(res)

def _to_long_dataframe(
Expand Down Expand Up @@ -674,30 +630,12 @@ def gridded_skill(
* y (y) float64 51.5 52.5 53.5 54.5 55.5 56.5
"""

model, start, end, area = _get_deprecated_args(kwargs) # type: ignore
observation, variable = _get_deprecated_obs_var_args(kwargs) # type: ignore
assert kwargs == {}, f"Unknown keyword arguments: {kwargs}"

cmp = self.sel(
model=model,
observation=observation,
quantity=variable,
start=start,
end=end,
area=area,
)

if cmp.n_points == 0:
raise ValueError("Dataset is empty, no data to compare.")

## ---- end of deprecated code ----

metrics = _parse_metric(metrics)

df = cmp._to_long_dataframe()
df = self._to_long_dataframe()
df = _add_spatial_grid_to_df(df=df, bins=bins, binsize=binsize)

agg_cols = _parse_groupby(by, n_mod=cmp.n_models, n_qnt=cmp.n_quantities)
agg_cols = _parse_groupby(by, n_mod=self.n_models, n_qnt=self.n_quantities)
if "x" not in agg_cols:
agg_cols.insert(0, "x")
if "y" not in agg_cols:
Expand Down Expand Up @@ -764,39 +702,21 @@ def mean_skill(
>>> sk = cc.mean_skill(weights={"EPL": 2.0}) # more weight on EPL, others=1.0
"""

# TODO remove in v1.1
model, start, end, area = _get_deprecated_args(kwargs) # type: ignore
observation, variable = _get_deprecated_obs_var_args(kwargs) # type: ignore
assert kwargs == {}, f"Unknown keyword arguments: {kwargs}"

# filter data
cc = self.sel(
model=model, # deprecated
observation=observation, # deprecated
quantity=variable, # deprecated
start=start, # deprecated
end=end, # deprecated
area=area, # deprecated
)
if cc.n_points == 0:
raise ValueError("Dataset is empty, no data to compare.")

## ---- end of deprecated code ----

df = cc._to_long_dataframe() # TODO: remove
mod_names = cc.mod_names
df = self._to_long_dataframe() # TODO: remove
mod_names = self.mod_names
# obs_names = cmp.obs_names # df.observation.unique()
qnt_names = cc.quantity_names
qnt_names = self.quantity_names

# skill assessment
pmetrics = _parse_metric(metrics)
sk = cc.skill(metrics=pmetrics)
sk = self.skill(metrics=pmetrics)
if sk is None:
# TODO don't return None
return None
skilldf = sk.to_dataframe()

# weights
weights = cc._parse_weights(weights, sk.obs_names)
weights = self._parse_weights(weights, sk.obs_names)
skilldf["weights"] = (
skilldf.n if weights is None else np.tile(weights, len(mod_names)) # type: ignore
)
Expand All @@ -805,7 +725,7 @@ def weighted_mean(x: Any) -> Any:
return np.average(x, weights=skilldf.loc[x.index, "weights"])

# group by
by = cc._mean_skill_by(skilldf, mod_names, qnt_names) # type: ignore
by = self._mean_skill_by(skilldf, mod_names, qnt_names) # type: ignore
agg = {"n": "sum"}
for metric in pmetrics: # type: ignore
agg[metric.__name__] = weighted_mean # type: ignore
Expand All @@ -815,7 +735,7 @@ def weighted_mean(x: Any) -> Any:
res.index.name = "model"

# output
res = cc._add_as_col_if_not_in_index(df, res, fields=["model", "quantity"]) # type: ignore
res = self._add_as_col_if_not_in_index(df, res, fields=["model", "quantity"]) # type: ignore
return SkillTable(res.astype({"n": int}))

# def mean_skill_points(
Expand Down Expand Up @@ -1004,32 +924,7 @@ def score(
if not (callable(metric) or isinstance(metric, str)):
raise ValueError("metric must be a string or a function")

model, start, end, area = _get_deprecated_args(kwargs) # type: ignore
observation, variable = _get_deprecated_obs_var_args(kwargs) # type: ignore
assert kwargs == {}, f"Unknown keyword arguments: {kwargs}"

if model is None:
models = self.mod_names
else:
# TODO: these two lines looks familiar, extract to function
models = [model] if np.isscalar(model) else model # type: ignore
models = [_get_name(m, self.mod_names) for m in models] # type: ignore

cmp = self.sel(
model=models, # deprecated
observation=observation, # deprecated
quantity=variable, # deprecated
start=start, # deprecated
end=end, # deprecated
area=area, # deprecated
)

if cmp.n_points == 0:
raise ValueError("Dataset is empty, no data to compare.")

## ---- end of deprecated code ----

sk = cmp.mean_skill(weights=weights, metrics=[metric])
sk = self.mean_skill(weights=weights, metrics=[metric])
df = sk.to_dataframe()

metric_name = metric if isinstance(metric, str) else metric.__name__
Expand Down Expand Up @@ -1153,29 +1048,7 @@ def scatter(
):
warnings.warn("scatter is deprecated, use plot.scatter instead", FutureWarning)

# TODO remove in v1.1
model, start, end, area = _get_deprecated_args(kwargs)
observation, variable = _get_deprecated_obs_var_args(kwargs)

# select model
mod_idx = _get_idx(model, self.mod_names)
mod_name = self.mod_names[mod_idx]

# select variable
qnt_idx = _get_idx(variable, self.quantity_names)
qnt_name = self.quantity_names[qnt_idx]

# filter data
cmp = self.sel(
model=mod_name,
observation=observation,
quantity=qnt_name,
start=start,
end=end,
area=area,
)

return cmp.plot.scatter(
return self.plot.scatter(
bins=bins,
quantiles=quantiles,
fit_to_quantiles=fit_to_quantiles,
Expand Down Expand Up @@ -1206,30 +1079,13 @@ def taylor(
):
warnings.warn("taylor is deprecated, use plot.taylor instead", FutureWarning)

model, start, end, area = _get_deprecated_args(kwargs)
observation, variable = _get_deprecated_obs_var_args(kwargs)
assert kwargs == {}, f"Unknown keyword arguments: {kwargs}"

cmp = self.sel(
model=model,
observation=observation,
quantity=variable,
start=start,
end=end,
area=area,
)

if cmp.n_points == 0:
warnings.warn("No data!")
return

if (not aggregate_observations) and (not normalize_std):
raise ValueError(
"aggregate_observations=False is only possible if normalize_std=True!"
)

metrics = [mtr._std_obs, mtr._std_mod, mtr.cc]
skill_func = cmp.mean_skill if aggregate_observations else cmp.skill
skill_func = self.mean_skill if aggregate_observations else self.skill
sk = skill_func(metrics=metrics)

df = sk.to_dataframe()
Expand Down
Loading
Loading