Skip to content

Commit

Permalink
Merge a40fd9e into e2c9ef6
Browse files Browse the repository at this point in the history
  • Loading branch information
juliettelavoie committed Dec 11, 2023
2 parents e2c9ef6 + a40fd9e commit 6b53d52
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 0 deletions.
1 change: 1 addition & 0 deletions figanos/matplotlib/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
gridmap,
hatchmap,
heatmap,
partition,
scattermap,
stripes,
taylordiagram,
Expand Down
120 changes: 120 additions & 0 deletions figanos/matplotlib/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2037,3 +2037,123 @@ def hatchmap(
set_plot_attrs(use_attrs, dattrs, ax, wrap_kw={"max_line_len": 60})

return ax


def _add_lead_time_coord(da, ref):
"""Add a lead time coordinate to the data. Modifies da in-place."""
lead_time = da.time.dt.year - int(ref)
da["Lead time"] = lead_time
da["Lead time"].attrs["units"] = f"years from {ref}"
return lead_time


def partition(
data: xr.DataArray | xr.Dataset,
ax: matplotlib.axes.Axes | None = None,
start_year: str | None = None,
show_num: bool = True,
fill_kw: dict[str, Any] | None = None,
line_kw: dict[str, Any] | None = None,
fig_kw: dict[str, Any] | None = None,
legend_kw: dict[str, Any] | None = None,
) -> matplotlib.axes.Axes:
"""Figure of the partition of total uncertainty by components.
See Hawkins and Sutton (2009) and Lafferty and Sriver (2023) for example.
Parameters
----------
data: xr.DataArray or xr.Dataset
Variance over time of the different components of uncertainty.
Output of a `xclim.ensembles._partitioning` function.
ax : matplotlib axis, optional
Matplotlib axis on which to plot
start_year: str
If None, the x-axis will be the time in year.
If str, the x-axis will show the number of year since start_year.
show_num: bool
If True, show the number of components in parenthesis in the legend.
`variance` should have a coordinate `num`.
fill_kw: dict
Keyword arguments passed to `ax.fill_between`.
It is possible to pass a dictionary of keywords for each component (uncertainty coordinates).
line_kw: dict
Keyword arguments passed to `ax.plot` for the lines in between the components.
The default is {color="k", lw=2}.
fig_kw: dict
Keyword arguments passed to `plt.subplots`.
legend_kw: dict
Keyword arguments passed to `ax.legend`.
Returns
-------
mpl.axes.Axes
"""
fill_kw = empty_dict(fill_kw)
line_kw = empty_dict(line_kw)
fig_kw = empty_dict(fig_kw)
legend_kw = empty_dict(legend_kw)

# select data to plot
if isinstance(data, xr.DataArray):
data = data.squeeze()
elif isinstance(data, xr.Dataset): # in case, it was save to disk before plotting.
if len(data.data_vars) > 1:
warnings.warn(
"data is xr.Dataset; only the first variable will be used in plot"
)
data = data[list(data.keys())[0]].squeeze()
else:
raise TypeError("`data` must contain a xr.DataArray or xr.Dataset")

if ax is None:
fig, ax = plt.subplots(**fig_kw)

# Select data from reference year onward
if start_year:
data = data.sel(time=slice(start_year, None))

# Lead time coordinate
time = _add_lead_time_coord(data, start_year)
ax.set_xlabel(f"Lead time [years from {start_year}]")
else:
time = data.time.dt.year

# fill_kw that are direct (not with uncertainty as key)
fk_direct = {k: v for k, v in fill_kw.items() if (k not in data.uncertainty.values)}

# Draw areas
past_y = 0
black_lines = []
for u in data.uncertainty.values:
if u not in ["total", "variability"]:
present_y = past_y + data.sel(uncertainty=u)
num = len(data.sel(uncertainty=u).elements.values.tolist())
label = f"{u} ({num})" if show_num else u
ax.fill_between(
time, past_y, present_y, label=label, **fill_kw.get(u, fk_direct)
)
black_lines.append(present_y)
past_y = present_y
ax.fill_between(
time, past_y, 100, label="variability", **fill_kw.get("variability", fk_direct)
)

# Draw black lines
line_kw.setdefault("color", "k")
line_kw.setdefault("lw", 2)
ax.plot(time, np.array(black_lines).T, **line_kw)

# TODO: think if this needs to be accessible
ax.xaxis.set_major_locator(matplotlib.ticker.MultipleLocator(20))
ax.xaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator(n=5))

ax.yaxis.set_major_locator(matplotlib.ticker.MultipleLocator(10))
ax.yaxis.set_minor_locator(matplotlib.ticker.AutoMinorLocator(n=2))

ax.set_ylabel(f"{data.attrs['long_name']} ({data.attrs['units']})") #

ax.set_ylim(0, 100)
ax.legend(**legend_kw)

return ax

0 comments on commit 6b53d52

Please sign in to comment.