Skip to content

Commit

Permalink
Merge 29eaa7e into 105b22e
Browse files Browse the repository at this point in the history
  • Loading branch information
sarahclaude authored Apr 26, 2024
2 parents 105b22e + 29eaa7e commit 260369c
Show file tree
Hide file tree
Showing 6 changed files with 550 additions and 1 deletion.
1 change: 1 addition & 0 deletions environment-dev.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ dependencies:
- scikit-image
- xarray
- xclim >=0.47
- hvplot
# To make the package and notebooks usable
- dask
- h5py
Expand Down
1 change: 1 addition & 0 deletions figanos/hvplot/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Figanos hvplot plotting module."""
293 changes: 293 additions & 0 deletions figanos/hvplot/plot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,293 @@
"""Hvplot figanos plotting functions."""

import warnings
from pathlib import Path
from typing import Any

import holoviews as hv
import hvplot.xarray # noqa: F401
import xarray as xr
from utils import defaults_curves, get_all_values, get_glyph_param

from figanos.matplotlib.utils import (
check_timeindex,
convert_scen_name,
empty_dict,
fill_between_label,
get_array_categ,
get_scen_color,
process_keys,
sort_lines,
)


def _plot_ens_reals(
name: str,
array_categ: dict(str, str),
arr: xr.DataArray,
non_dict_data: bool,
cplot_kw: dict[str, Any],
copts_kw: dict[str, Any],
) -> dict:
"""Plot realizations ensembles"""
hv_fig = {}
if array_categ[name] == "ENS_REALS_DS":
if len(arr.data_vars) >= 2:
raise TypeError(
"To plot multiple ensembles containing realizations, use DataArrays outside a Dataset"
)
else:
arr = arr[list(arr.data_vars)[0]]

if non_dict_data:
cplot_kw[name] = {"by": "realization", "x": "time"} | cplot_kw[name]
hv_fig = arr.hvplot.line(cplot_kw[name]).opts(**copts_kw[name])
else:
cplot_kw[name].setdefault("label", name)
for r in arr.realization:
hv_fig[f"realization_{r.values.item()}"] = (
arr.sel(realization=r)
.hvplot.line(**cplot_kw[name])
.opts(**copts_kw[name])
)
return hv_fig


def _plot_ens_pct_stats(
name: str,
arr: xr.DataArray,
array_categ: dict[str, str],
array_data: dict[str, xr.DataArray],
cplot_kw: dict[str, Any],
copts_kw: dict[str, Any],
legend: str,
) -> dict:
"""Plot ensembles with percentiles and statistics (min/moy/max)"""
hv_fig = {}

# create a dictionary labeling the middle, upper and lower line
sorted_lines = sort_lines(array_data)
# plot
hv_fig["line"] = (
array_data[sorted_lines["middle"]]
.hvplot.line(label=name, **cplot_kw[name])
.opts(**copts_kw[name])
)
c = get_glyph_param(hv_fig["line"], "line_color")
lab_area = fill_between_label(sorted_lines, name, array_categ, legend)
if "ENS_PCT_DIM" in array_categ[name]:
arr = arr.to_dataset(dim="percentiles")
arr = arr.rename({k: str(k) for k in arr.keys()})
hv_fig["area"] = arr.hvplot.area(
y=sorted_lines["lower"],
y2=sorted_lines["upper"],
label=lab_area,
color=c,
linewidth=0.0,
alpha=0.2,
**cplot_kw[name],
).opts(**copts_kw[name])
return hv_fig


def _plot_timeseries(
plot_kw: dict[str, Any],
name: str,
arr: xr.DataArray | xr.Dataset,
array_categ: dict[str, str],
cplot_kw: dict[str, Any],
copts_kw: dict[str, Any],
non_dict_data: bool,
legend: str,
) -> dict | hv.element.chart.Curve | hv.core.overlay.Overlay:
"""Plot time series from 1D Xarray Datasets or DataArrays as line plots."""
hv_fig = {}

if (
array_categ[name] == "ENS_REALS_DA" or array_categ[name] == "ENS_REALS_DS"
): # ensemble with 'realization' dim, as DataArray or Dataset
return _plot_ens_reals(
name, array_categ, arr, non_dict_data, cplot_kw, copts_kw
)
elif (
array_categ[name] == "ENS_PCT_DIM_DS"
): # ensemble percentiles stored as dimension coordinates, DataSet
for k, sub_arr in arr.data_vars.items():
sub_name = (
sub_arr.name if non_dict_data is True else (name + "_" + sub_arr.name)
)
hv_fig[sub_name] = {}
# extract each percentile array from the dims
array_data = {}
for pct in sub_arr.percentiles.values:
array_data[str(pct)] = sub_arr.sel(percentiles=pct)

hv_fig[sub_name] = _plot_ens_pct_stats(
name, arr, array_categ, array_data, cplot_kw, copts_kw, legend
)
elif array_categ[name] in [
"ENS_PCT_VAR_DS", # ensemble statistics (min, mean, max) stored as variables
"ENS_STATS_VAR_DS", # ensemble percentiles stored as variables
"ENS_PCT_DIM_DA", # ensemble percentiles stored as dimension coordinates, DataArray
]:
# extract each array from the datasets
array_data = {}
if array_categ[name] == "ENS_PCT_DIM_DA":
for pct in arr.percentiles:
array_data[str(int(pct))] = arr.sel(percentiles=int(pct))
else:
for k, v in arr.data_vars.items():
array_data[k] = v

return _plot_ens_pct_stats(
name, arr, array_categ, array_data, cplot_kw, copts_kw, legend
)
# non-ensemble Datasets
elif array_categ[name] == "DS":
ignore_label = False
for k, sub_arr in arr.data_vars.items():
sub_name = (
sub_arr.name if non_dict_data is True else (name + "_" + sub_arr.name)
)
# if kwargs are specified by user, all lines are the same and we want one legend entry
if plot_kw[name]:
label = name if not ignore_label else ""
ignore_label = True
else:
label = sub_name

hv_fig[sub_name] = sub_arr.hvplot.line(
x="time", label=label, **plot_kw[name]
).opts(**copts_kw[name])

# non-ensemble DataArrays
elif array_categ[name] in ["DA"]:
return arr.hvplot.line(label=name, **plot_kw[name]).opts(**copts_kw[name])
else:
raise ValueError(
"Data structure not supported"
) # can probably be removed along with elif logic above,
# given that get_array_categ() also does this check
if hv_fig:
return hv_fig


def timeseries(
data: dict[str, Any] | xr.DataArray | xr.Dataset,
use_attrs: dict[str, Any] | None = None,
plot_kw: dict[str, Any] | None = None,
opts_kw: dict[str, Any] | None = None,
legend: str = "lines",
show_lat_lon: bool | str | int | tuple[float, float] = True,
) -> hv.element.chart.Curve | hv.core.overlay.Overlay:
"""Plot time series from 1D Xarray Datasets or DataArrays as line plots.
Parameters
----------
data : dict or Dataset/DataArray
Input data to plot. It can be a DataArray, Dataset or a dictionary of DataArrays and/or Datasets.
use_attrs : dict, optional
A dict linking a plot element (key, e.g. 'title') to a DataArray attribute (value, e.g. 'Description').
Default value is {'title': 'description', 'ylabel': 'long_name', 'yunits': 'units'}.
Only the keys found in the default dict can be used.
plot_kw : dict, optional
Arguments to pass to the `hvplot.line()` or hvplot.area() function. Changes how the line looks.
If 'data' is a dictionary, must be a nested dictionary with the same keys as 'data'.
legend : str (default 'lines') or dict
'full' (lines and shading), 'lines' (lines only), 'in_plot' (end of lines),
'edge' (out of plot), 'none' (no legend).
show_lat_lon : bool, tuple, str or int
If True, show latitude and longitude at the bottom right of the figure.
Can be a tuple of axis coordinates (from 0 to 1, as a fraction of the axis length) representing
the location of the text. If a string or an int, the same values as those of the 'loc' parameter
of matplotlib's legends are accepted.
Returns
-------
hvplot.Overlay
"""
# create empty dicts if None
use_attrs = empty_dict(use_attrs)
copts_kw = empty_dict(opts_kw)
cplot_kw = empty_dict(plot_kw)

# convert SSP, RCP, CMIP formats in keys
if isinstance(data, dict):
data = process_keys(data, convert_scen_name)
if isinstance(plot_kw, dict):
cplot_kw = process_keys(cplot_kw, convert_scen_name)

# add ouranos default cycler colors
defaults_curves()

# if only one data input, insert in dict.
non_dict_data = False
if not isinstance(data, dict):
non_dict_data = True
data = {"_no_label": data} # mpl excludes labels starting with "_" from legend
cplot_kw = {"_no_label": cplot_kw}

# assign keys to plot_kw if not there
if non_dict_data is False:
for name in data:
if name not in cplot_kw:
cplot_kw[name] = {}
warnings.warn(
f"Key {name} not found in plot_kw. Using empty dict instead."
)
for key in plot_kw:
if key not in data:
raise KeyError(
'plot_kw must be a nested dictionary with keys corresponding to the keys in "data"'
)

# check: type
for name, arr in data.items():
if not isinstance(arr, (xr.Dataset, xr.DataArray)):
raise TypeError(
'"data" must be a xr.Dataset, a xr.DataArray or a dictionary of such objects.'
)

# check: 'time' dimension and calendar format
data = check_timeindex(data)

# add use attributes defaults ToDo: Adapt use_attrs to hvplot (already an option in hvplot with xarray)
use_attrs = {
"title": "description",
"ylabel": "long_name",
"yunits": "units",
} | use_attrs

# dict of array 'categories'
array_categ = {name: get_array_categ(array) for name, array in data.items()}

# dictionary of hvplots plots
figs = {}

# get data and plot
for name, arr in data.items():
# add defaults to plot_kw if not present (GReY BACKGORUDN LINES + USER ATTRS
# ToDo: Add user_attrs here and grey backgrounds lines
# ToDo: if legend = 'edge' add hook in opts_kw

# remove 'label' to avoid error due to double 'label' args
if "label" in plot_kw[name]:
del plot_kw[name]["label"]
warnings.warn(f'"label" entry in plot_kw[{name}] will be ignored.')

# SSP, RCP, CMIP model colors
cat_colors = (
Path(__file__).parents[1] / "data/ipcc_colors/categorical_colors.json"
)
if get_scen_color(name, cat_colors):
cplot_kw[name].setdefault("color", get_scen_color(name, cat_colors))

figs[name] = _plot_timeseries(
plot_kw, name, arr, array_categ, cplot_kw, copts_kw, non_dict_data, legend
)

if not legend:
return hv.Overlay(list(get_all_values(figs))).opts(show_legend=False)
else:
return hv.Overlay(list(get_all_values(figs)))
93 changes: 93 additions & 0 deletions figanos/hvplot/style/bokeh_theme.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
#Ouranos bokeh theme
#caliber matplotlib version: https://github.com/bokeh/bokeh/blob/8a5a35eb078e386c14b5580a51e8a6673d57d197/src/bokeh/themes/_caliber.py
#refer to IPCC style guide

attrs:
Plot:
background_fill_color: white
border_fill_color: white #color around plot - could be null to match behind plot
outline_line_color: !!null

Axis:
major_tick_in: 4
major_tick_out: 0
major_tick_line_alpha: 1
major_tick_line_color: black

minor_tick_line_alpha: 0
minor_tick_line_color: !!null

axis_line_alpha: 1
axis_line_color: black

major_label_text_color: black
major_label_text_font: DejaVu Sans
major_label_text_font_size: 10pt
major_label_text_font_style: normal

axis_label_standoff: 10
axis_label_text_color: black
axis_label_text_font: DejaVu Sans
axis_label_text_font_size: 11pt
axis_label_text_font_style: normal

Legend:
spacing: 8
glyph_width: 15

label_standoff: 8
label_text_color: black
label_text_font: DejaVu Sans
label_text_font_size: 10pt
label_text_font_style: normal

title_text_font: DejaVu Sans
title_text_font_style: normal
title_text_font_size: 11pt
title_text_color: black

border_line_alpha: 0
background_fill_alpha: 1
background_fill_color: white

BaseColorBar:
title_text_color: black
title_text_font: DejaVu Sans
title_text_font_size: 11pt
title_text_font_style: normal

major_label_text_color: black
major_label_text_font: DejaVu Sans
major_label_text_font_size: 10pt
major_label_text_font_style: normal

major_tick_line_alpha: 0
bar_line_alpha: 0

Grid:
grid_line_width: 0
grid_line_color: black
grid_line_alpha: 0.4

Title:
text_color: black
text_font: DejaVu Sans
text_font_style: normal
text_font_size: 14pt

Toolbar:
logo: !!null
autohide: True

figure:
toolbar_location: below

CategoricalColorMapper:
palette:
- '#052946'
- '#ce153d'
- '#18bbbb'
- '#fdc414'
- '#6850af'
- '#196a5e'
- '#7a5315'
Loading

0 comments on commit 260369c

Please sign in to comment.