From b94cb5e52c9554ab888f054cfff100767f5d9a6f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89ric=20Dupuis?= Date: Mon, 10 Jun 2024 18:05:40 -0400 Subject: [PATCH 1/9] Allow FacetGrids use in heatmaps --- CHANGELOG.rst | 3 +- docs/notebooks/figanos_multiplots.ipynb | 29 +++++++++- src/figanos/matplotlib/plot.py | 74 +++++++++++++++++++------ 3 files changed, 85 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 5cf149ac..756478f1 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -4,7 +4,7 @@ Changelog 0.4.0 (unreleased) ------------------ -Contributors to this version: Trevor James Smith (:user:`Zeitsperre`), Marco Braun (:user:`vindelico`), Pascal Bourgault (:user:`aulemahal`), Sarah-Claude Bourdeau-Goulet (:user:`Sarahclaude`) +Contributors to this version: Trevor James Smith (:user:`Zeitsperre`), Marco Braun (:user:`vindelico`), Pascal Bourgault (:user:`aulemahal`), Sarah-Claude Bourdeau-Goulet (:user:`Sarahclaude`), Éric Dupuis (:user:`coxipi`) New features and enhancements ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ @@ -13,6 +13,7 @@ New features and enhancements * Added style sheet ``transparent.mplstyle`` (:issue:`183`, :pull:`185`) * Fix ``NaN`` issues, extreme values in sizes legend and added edgecolors in ``fg.matplotlib.scattermap`` (:pull:`184`). * New function ``fg.data`` for fetching package data and defined `matplotlib` style definitions. (:pull:`211`). +* Heatmap (`fg.matplotlib.heatmap`) now supports `row,col` arguments in `plot_kw`, allowing to plot a grid of heatmaps. Breaking changes ^^^^^^^^^^^^^^^^ diff --git a/docs/notebooks/figanos_multiplots.ipynb b/docs/notebooks/figanos_multiplots.ipynb index 09a94acf..c04e44b9 100644 --- a/docs/notebooks/figanos_multiplots.ipynb +++ b/docs/notebooks/figanos_multiplots.ipynb @@ -92,7 +92,7 @@ "metadata": {}, "source": [ "## Maps\n", - "Create multiple maps plot with figanos wrapped around [xr.plot.facetgrid.FacetGrid](https://docs.xarray.dev/en/latest/generated/xarray.plot.FacetGrid.html) by passing the key row `row` and `col` in the argument `plot_kw`." + "Create multiple maps plot with figanos wrapped around [xr.plot.facetgrid.FacetGrid](https://docs.xarray.dev/en/latest/generated/xarray.plot.FacetGrid.html) by passing the keys `row` and `col` in the argument `plot_kw`." ] }, { @@ -199,6 +199,31 @@ "im.fig.suptitle(\"Multiple hatchmaps\", y=1.08)\n" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Heatmaps\n", + "\n", + "The keys `row` and `col` in the argument `plot_kw` can also be used to create a grid of heatmaps. This is done by wrapping Seaborn's [heatmap](https://seaborn.pydata.org/generated/seaborn.heatmap.html) and [FacetGrid](https://seaborn.pydata.org/generated/seaborn.FacetGrid.html)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "ds_space = opened[['tx_max_p50']].isel(time=[0, 1, 2]).sel(lat=slice(40,65), lon=slice(-90,-55))\n", + "\n", + "# spatial subbomain\n", + "sl = slice(100,100+5)\n", + "da = ds_space.isel(lat=sl, lon=sl).drop(\"horizon\").tx_max_p50\n", + "da[\"lon\"] = np.round(da.lon,2)\n", + "da[\"lat\"] = np.round(da.lat,2)\n", + "fg.heatmap(da, plot_kw = {\"col\": \"time\"})" + ] + }, { "cell_type": "markdown", "metadata": {}, @@ -367,7 +392,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.0" + "version": "3.12.3" } }, "nbformat": 4, diff --git a/src/figanos/matplotlib/plot.py b/src/figanos/matplotlib/plot.py index 858f888e..6514962c 100644 --- a/src/figanos/matplotlib/plot.py +++ b/src/figanos/matplotlib/plot.py @@ -1356,8 +1356,20 @@ def heatmap( raise TypeError("`data` must contain a xr.DataArray or xr.Dataset") # setup fig, axis - if ax is None: + if ax is None and ("row" not in plot_kw.keys() and "col" not in plot_kw.keys()): fig, ax = plt.subplots(**fig_kw) + elif ax is not None and ("col" in plot_kw or "row" in plot_kw): + raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.") + elif ax is None: + plot_kw.setdefault("col", None) + plot_kw.setdefault("row", None) + heatmap_dims = list( + set(da.dims) + - {d for d in [plot_kw["col"], plot_kw["row"]] if d is not None} + ) + if da.name is None: + da = da.to_dataset(name="data").data + da_name = da.name # create cbar label if ( @@ -1389,11 +1401,16 @@ def heatmap( ) # convert data to DataFrame - if len(da.coords) != 2: - raise ValueError("DataArray must have exactly two dimensions") if transpose: da = da.transpose() - df = da.to_pandas() + if "col" not in plot_kw and "row" not in plot_kw: + if len(da.dims) != 2: + raise ValueError("DataArray must have exactly two dimensions") + df = da.to_pandas() + else: + if len(heatmap_dims) != 2: + raise ValueError("DataArray must have exactly two dimensions") + df = da.to_dataframe().reset_index() # set defaults if divergent is not False: @@ -1409,21 +1426,42 @@ def heatmap( plot_kw.setdefault("cmap", cmap) # plot - sns.heatmap(df, ax=ax, **plot_kw) - - # format - plt.xticks(rotation=45, ha="right", rotation_mode="anchor") - ax.tick_params(axis="both", direction="out") - - set_plot_attrs( - use_attrs, - da, - ax, - title_loc="center", - wrap_kw={"min_line_len": 35, "max_line_len": 44}, - ) + def draw_heatmap(*args, **kwargs): + data = kwargs.pop("data") + d = ( + data + if len(args) == 0 + else data.pivot(index=args[1], columns=args[0], values=args[2]) + ) + ax = sns.heatmap(d, **kwargs) + ax.set_xticklabels( + ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor" + ) + ax.tick_params(axis="both", direction="out") + set_plot_attrs( + use_attrs, + da, + ax, + title_loc="center", + wrap_kw={"min_line_len": 35, "max_line_len": 44}, + ) + return ax - return ax + if ax is not None: + ax = draw_heatmap(data=df, ax=ax, **plot_kw) + return ax + elif "col" in plot_kw or "row" in plot_kw: + g = sns.FacetGrid(df, col=plot_kw["col"], row=plot_kw["row"]) + plot_kw.pop("col") + plot_kw.pop("row") + cax = g.fig.add_axes([0.92, 0.12, 0.02, 0.8]) + g.map_dataframe( + draw_heatmap, *heatmap_dims, da_name, **plot_kw, cbar=True, cbar_ax=cax + ) + g.fig.subplots_adjust(right=0.9) + if "figsize" in fig_kw.keys(): + g.fig.set_size_inches(*fig_kw["figsize"]) + return g def scattermap( From 94eb74ed67e7d634fb9c8479c23f5fbf07529ac9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89ric=20Dupuis?= Date: Mon, 10 Jun 2024 18:07:52 -0400 Subject: [PATCH 2/9] update CHANGELOG --- CHANGELOG.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index 756478f1..f8f22b72 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,7 +13,7 @@ New features and enhancements * Added style sheet ``transparent.mplstyle`` (:issue:`183`, :pull:`185`) * Fix ``NaN`` issues, extreme values in sizes legend and added edgecolors in ``fg.matplotlib.scattermap`` (:pull:`184`). * New function ``fg.data`` for fetching package data and defined `matplotlib` style definitions. (:pull:`211`). -* Heatmap (`fg.matplotlib.heatmap`) now supports `row,col` arguments in `plot_kw`, allowing to plot a grid of heatmaps. +* Heatmap (`fg.matplotlib.heatmap`) now supports `row,col` arguments in `plot_kw`, allowing to plot a grid of heatmaps. (:pull:`219`, :issue:`208`). Breaking changes ^^^^^^^^^^^^^^^^ From bb1fde59b28f0a4f7a96486db29f3c6753abdc59 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89ric=20Dupuis?= Date: Mon, 10 Jun 2024 18:08:33 -0400 Subject: [PATCH 3/9] respect convention: issue first, pull after --- CHANGELOG.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CHANGELOG.rst b/CHANGELOG.rst index f8f22b72..2aec138a 100644 --- a/CHANGELOG.rst +++ b/CHANGELOG.rst @@ -13,7 +13,7 @@ New features and enhancements * Added style sheet ``transparent.mplstyle`` (:issue:`183`, :pull:`185`) * Fix ``NaN`` issues, extreme values in sizes legend and added edgecolors in ``fg.matplotlib.scattermap`` (:pull:`184`). * New function ``fg.data`` for fetching package data and defined `matplotlib` style definitions. (:pull:`211`). -* Heatmap (`fg.matplotlib.heatmap`) now supports `row,col` arguments in `plot_kw`, allowing to plot a grid of heatmaps. (:pull:`219`, :issue:`208`). +* Heatmap (`fg.matplotlib.heatmap`) now supports `row,col` arguments in `plot_kw`, allowing to plot a grid of heatmaps. (:issue:`208`, :pull:`219`). Breaking changes ^^^^^^^^^^^^^^^^ From 49afac7890b632b801964e201db7248692ccd8e8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89ric=20Dupuis?= Date: Tue, 11 Jun 2024 11:42:04 -0400 Subject: [PATCH 4/9] warn, only use figsize with facetgrid --- src/figanos/matplotlib/plot.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/figanos/matplotlib/plot.py b/src/figanos/matplotlib/plot.py index 6514962c..c667f874 100644 --- a/src/figanos/matplotlib/plot.py +++ b/src/figanos/matplotlib/plot.py @@ -1361,6 +1361,10 @@ def heatmap( elif ax is not None and ("col" in plot_kw or "row" in plot_kw): raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.") elif ax is None: + if any([k != "figsize" for k in fig_kw.keys()]): + warnings.warn( + "Only figsize arguments can be passed to fig_kw when using facetgrid." + ) plot_kw.setdefault("col", None) plot_kw.setdefault("row", None) heatmap_dims = list( From b9417de3af5bb1b017129066b64d7753c341d8e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89ric=20Dupuis?= Date: Wed, 12 Jun 2024 16:54:23 -0400 Subject: [PATCH 5/9] split `plot_kw` in kws for FacetGrid and heatmap --- src/figanos/matplotlib/plot.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/src/figanos/matplotlib/plot.py b/src/figanos/matplotlib/plot.py index c667f874..163e36aa 100644 --- a/src/figanos/matplotlib/plot.py +++ b/src/figanos/matplotlib/plot.py @@ -5,6 +5,7 @@ import math import warnings from collections.abc import Iterable +from inspect import signature from pathlib import Path from typing import Any @@ -1455,12 +1456,19 @@ def draw_heatmap(*args, **kwargs): ax = draw_heatmap(data=df, ax=ax, **plot_kw) return ax elif "col" in plot_kw or "row" in plot_kw: - g = sns.FacetGrid(df, col=plot_kw["col"], row=plot_kw["row"]) - plot_kw.pop("col") - plot_kw.pop("row") + # When using xarray's FacetGrid, `plot_kw` can be used in the FacetGrid and in the plotting function + # With Seaborn, we need to be more careful and separate keywords. + plot_kw_hm = { + k: v for k, v in plot_kw.items() if k in signature(sns.heatmap).parameters + } + plot_kw_fg = { + k: v for k, v in plot_kw.items() if k in signature(sns.FacetGrid).parameters + } + + g = sns.FacetGrid(df, **plot_kw_fg) cax = g.fig.add_axes([0.92, 0.12, 0.02, 0.8]) g.map_dataframe( - draw_heatmap, *heatmap_dims, da_name, **plot_kw, cbar=True, cbar_ax=cax + draw_heatmap, *heatmap_dims, da_name, **plot_kw_hm, cbar=True, cbar_ax=cax ) g.fig.subplots_adjust(right=0.9) if "figsize" in fig_kw.keys(): From f9a368a3f27af32456beac53c82d91b35defe7c8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89ric=20Dupuis?= Date: Thu, 13 Jun 2024 11:56:53 -0400 Subject: [PATCH 6/9] warning if illegal plot keys are used --- src/figanos/matplotlib/plot.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/figanos/matplotlib/plot.py b/src/figanos/matplotlib/plot.py index 163e36aa..9af70e7a 100644 --- a/src/figanos/matplotlib/plot.py +++ b/src/figanos/matplotlib/plot.py @@ -1464,6 +1464,14 @@ def draw_heatmap(*args, **kwargs): plot_kw_fg = { k: v for k, v in plot_kw.items() if k in signature(sns.FacetGrid).parameters } + unused_keys = ( + set(plot_kw.keys()) - set(plot_kw_fg.keys()) - set(plot_kw_hm.keys()) + ) + if unused_keys != set(): + warnings.warn( + f"`plot_kw` containted extra keywords: {unused_keys} that can't be used with `sns.heatmap` or `sns.FacetGrid`. " + "These keywords will be ignored" + ) g = sns.FacetGrid(df, **plot_kw_fg) cax = g.fig.add_axes([0.92, 0.12, 0.02, 0.8]) From d4dad6715bcb9c1043df21652ee5ed8dffbc8948 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89ric=20Dupuis?= Date: Thu, 13 Jun 2024 12:03:29 -0400 Subject: [PATCH 7/9] warning -> error --- src/figanos/matplotlib/plot.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/figanos/matplotlib/plot.py b/src/figanos/matplotlib/plot.py index 9af70e7a..a50e74dc 100644 --- a/src/figanos/matplotlib/plot.py +++ b/src/figanos/matplotlib/plot.py @@ -1468,9 +1468,9 @@ def draw_heatmap(*args, **kwargs): set(plot_kw.keys()) - set(plot_kw_fg.keys()) - set(plot_kw_hm.keys()) ) if unused_keys != set(): - warnings.warn( - f"`plot_kw` containted extra keywords: {unused_keys} that can't be used with `sns.heatmap` or `sns.FacetGrid`. " - "These keywords will be ignored" + raise ValueError( + f"`heatmap` got unexpected keywords in `plot_kw`: {unused_keys}. Keywords in `plot_kw` should be keywords " + "allowed in `sns.heatmap` or `sns.FacetGrid`. " ) g = sns.FacetGrid(df, **plot_kw_fg) From 20834e383055bebfb017fd26e7193c6c971e974b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89ric=20Dupuis?= Date: Thu, 13 Jun 2024 14:59:14 -0400 Subject: [PATCH 8/9] no sorting in the heatmap function --- src/figanos/matplotlib/plot.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/figanos/matplotlib/plot.py b/src/figanos/matplotlib/plot.py index a50e74dc..51676b93 100644 --- a/src/figanos/matplotlib/plot.py +++ b/src/figanos/matplotlib/plot.py @@ -1368,6 +1368,7 @@ def heatmap( ) plot_kw.setdefault("col", None) plot_kw.setdefault("row", None) + plot_kw.setdefault("margin_titles", True) heatmap_dims = list( set(da.dims) - {d for d in [plot_kw["col"], plot_kw["row"]] if d is not None} @@ -1436,7 +1437,10 @@ def draw_heatmap(*args, **kwargs): d = ( data if len(args) == 0 - else data.pivot(index=args[1], columns=args[0], values=args[2]) + # Any sorting should be performed before sending a DataArray in `fg.heatmap` + else data.pivot_table( + index=args[1], columns=args[0], values=args[2], sort=False + ) ) ax = sns.heatmap(d, **kwargs) ax.set_xticklabels( From f1812b819aa4552cb5994eeaafee36e92245fcb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=89ric=20Dupuis?= Date: Wed, 19 Jun 2024 09:34:57 -0400 Subject: [PATCH 9/9] change cax config --- src/figanos/matplotlib/plot.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/figanos/matplotlib/plot.py b/src/figanos/matplotlib/plot.py index 51676b93..dce58700 100644 --- a/src/figanos/matplotlib/plot.py +++ b/src/figanos/matplotlib/plot.py @@ -1478,7 +1478,7 @@ def draw_heatmap(*args, **kwargs): ) g = sns.FacetGrid(df, **plot_kw_fg) - cax = g.fig.add_axes([0.92, 0.12, 0.02, 0.8]) + cax = g.fig.add_axes([0.95, 0.05, 0.02, 0.9]) g.map_dataframe( draw_heatmap, *heatmap_dims, da_name, **plot_kw_hm, cbar=True, cbar_ax=cax )