diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5cf149a..2aec138 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,7 +4,7 @@ Changelog 0.4.0 (unreleased) ------------------ -Contributors to this version: Trevor James Smith (:user:`Zeitsperre`), Marco Braun (:user:`vindelico`), Pascal Bourgault (:user:`aulemahal`), Sarah-Claude Bourdeau-Goulet (:user:`Sarahclaude`) +Contributors to this version: Trevor James Smith (:user:`Zeitsperre`), Marco Braun (:user:`vindelico`), Pascal Bourgault (:user:`aulemahal`), Sarah-Claude Bourdeau-Goulet (:user:`Sarahclaude`), Éric Dupuis (:user:`coxipi`) New features and enhancements ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -13,6 +13,7 @@ New features and enhancements * Added style sheet ``transparent.mplstyle`` (:issue:`183`, :pull:`185`) * Fix ``NaN`` issues, extreme values in sizes legend and added edgecolors in ``fg.matplotlib.scattermap`` (:pull:`184`). * New function ``fg.data`` for fetching package data and defined `matplotlib` style definitions. (:pull:`211`). +* Heatmap (`fg.matplotlib.heatmap`) now supports `row,col` arguments in `plot_kw`, allowing to plot a grid of heatmaps. (:issue:`208`, :pull:`219`). Breaking changes ^^^^^^^^^^^^^^^^ diff --git a/docs/notebooks/figanos_multiplots.ipynb b/docs/notebooks/figanos_multiplots.ipynb index 09a94ac..c04e44b 100644 --- a/docs/notebooks/figanos_multiplots.ipynb +++ b/docs/notebooks/figanos_multiplots.ipynb @@ -92,7 +92,7 @@ "metadata": {}, "source": [ "## Maps\n", - "Create multiple maps plot with figanos wrapped around [xr.plot.facetgrid.FacetGrid](https://docs.xarray.dev/en/latest/generated/xarray.plot.FacetGrid.html) by passing the key row `row` and `col` in the argument `plot_kw`." + "Create multiple maps plot with figanos wrapped around [xr.plot.facetgrid.FacetGrid](https://docs.xarray.dev/en/latest/generated/xarray.plot.FacetGrid.html) by passing the keys `row` and `col` in the argument `plot_kw`." ] }, { @@ -199,6 +199,31 @@ "im.fig.suptitle(\"Multiple hatchmaps\", y=1.08)\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Heatmaps\n", + "\n", + "The keys `row` and `col` in the argument `plot_kw` can also be used to create a grid of heatmaps. This is done by wrapping Seaborn's [heatmap](https://seaborn.pydata.org/generated/seaborn.heatmap.html) and [FacetGrid](https://seaborn.pydata.org/generated/seaborn.FacetGrid.html)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds_space = opened[['tx_max_p50']].isel(time=[0, 1, 2]).sel(lat=slice(40,65), lon=slice(-90,-55))\n", + "\n", + "# spatial subbomain\n", + "sl = slice(100,100+5)\n", + "da = ds_space.isel(lat=sl, lon=sl).drop(\"horizon\").tx_max_p50\n", + "da[\"lon\"] = np.round(da.lon,2)\n", + "da[\"lat\"] = np.round(da.lat,2)\n", + "fg.heatmap(da, plot_kw = {\"col\": \"time\"})" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -367,7 +392,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.0" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/src/figanos/matplotlib/plot.py b/src/figanos/matplotlib/plot.py index 858f888..dce5870 100644 --- a/src/figanos/matplotlib/plot.py +++ b/src/figanos/matplotlib/plot.py @@ -5,6 +5,7 @@ import math import warnings from collections.abc import Iterable +from inspect import signature from pathlib import Path from typing import Any @@ -1356,8 +1357,25 @@ def heatmap( raise TypeError("`data` must contain a xr.DataArray or xr.Dataset") # setup fig, axis - if ax is None: + if ax is None and ("row" not in plot_kw.keys() and "col" not in plot_kw.keys()): fig, ax = plt.subplots(**fig_kw) + elif ax is not None and ("col" in plot_kw or "row" in plot_kw): + raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.") + elif ax is None: + if any([k != "figsize" for k in fig_kw.keys()]): + warnings.warn( + "Only figsize arguments can be passed to fig_kw when using facetgrid." + ) + plot_kw.setdefault("col", None) + plot_kw.setdefault("row", None) + plot_kw.setdefault("margin_titles", True) + heatmap_dims = list( + set(da.dims) + - {d for d in [plot_kw["col"], plot_kw["row"]] if d is not None} + ) + if da.name is None: + da = da.to_dataset(name="data").data + da_name = da.name # create cbar label if ( @@ -1389,11 +1407,16 @@ def heatmap( ) # convert data to DataFrame - if len(da.coords) != 2: - raise ValueError("DataArray must have exactly two dimensions") if transpose: da = da.transpose() - df = da.to_pandas() + if "col" not in plot_kw and "row" not in plot_kw: + if len(da.dims) != 2: + raise ValueError("DataArray must have exactly two dimensions") + df = da.to_pandas() + else: + if len(heatmap_dims) != 2: + raise ValueError("DataArray must have exactly two dimensions") + df = da.to_dataframe().reset_index() # set defaults if divergent is not False: @@ -1409,21 +1432,60 @@ def heatmap( plot_kw.setdefault("cmap", cmap) # plot - sns.heatmap(df, ax=ax, **plot_kw) - - # format - plt.xticks(rotation=45, ha="right", rotation_mode="anchor") - ax.tick_params(axis="both", direction="out") + def draw_heatmap(*args, **kwargs): + data = kwargs.pop("data") + d = ( + data + if len(args) == 0 + # Any sorting should be performed before sending a DataArray in `fg.heatmap` + else data.pivot_table( + index=args[1], columns=args[0], values=args[2], sort=False + ) + ) + ax = sns.heatmap(d, **kwargs) + ax.set_xticklabels( + ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor" + ) + ax.tick_params(axis="both", direction="out") + set_plot_attrs( + use_attrs, + da, + ax, + title_loc="center", + wrap_kw={"min_line_len": 35, "max_line_len": 44}, + ) + return ax - set_plot_attrs( - use_attrs, - da, - ax, - title_loc="center", - wrap_kw={"min_line_len": 35, "max_line_len": 44}, - ) + if ax is not None: + ax = draw_heatmap(data=df, ax=ax, **plot_kw) + return ax + elif "col" in plot_kw or "row" in plot_kw: + # When using xarray's FacetGrid, `plot_kw` can be used in the FacetGrid and in the plotting function + # With Seaborn, we need to be more careful and separate keywords. + plot_kw_hm = { + k: v for k, v in plot_kw.items() if k in signature(sns.heatmap).parameters + } + plot_kw_fg = { + k: v for k, v in plot_kw.items() if k in signature(sns.FacetGrid).parameters + } + unused_keys = ( + set(plot_kw.keys()) - set(plot_kw_fg.keys()) - set(plot_kw_hm.keys()) + ) + if unused_keys != set(): + raise ValueError( + f"`heatmap` got unexpected keywords in `plot_kw`: {unused_keys}. Keywords in `plot_kw` should be keywords " + "allowed in `sns.heatmap` or `sns.FacetGrid`. " + ) - return ax + g = sns.FacetGrid(df, **plot_kw_fg) + cax = g.fig.add_axes([0.95, 0.05, 0.02, 0.9]) + g.map_dataframe( + draw_heatmap, *heatmap_dims, da_name, **plot_kw_hm, cbar=True, cbar_ax=cax + ) + g.fig.subplots_adjust(right=0.9) + if "figsize" in fig_kw.keys(): + g.fig.set_size_inches(*fig_kw["figsize"]) + return g def scattermap(