Skip to content

Commit

Permalink
Merge cedd621 into 1ec32c0
Browse files Browse the repository at this point in the history
  • Loading branch information
coxipi committed Jun 19, 2024
2 parents 1ec32c0 + cedd621 commit d3df931
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 18 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ New features and enhancements
* New function ``fg.data`` for fetching package data and defined `matplotlib` style definitions. (:pull:`211`).
* ``fg.taylordiagram`` can now accept datasets with many dimensions (not only `taylor_params`), provided that they all share the same `ref_std` (e.g. normalized taylor diagrams) (:pull:`214`).
* A new optional way to organize points in a `fg.taylordiagram` with `colors_key`, `markers_key` : DataArrays with a common dimension value or a common attrtibute are grouped with the same color/marker (:pull:`214`).
* Heatmap (`fg.matplotlib.heatmap`) now supports `row,col` arguments in `plot_kw`, allowing to plot a grid of heatmaps. (:issue:`208`, :pull:`219`).

Breaking changes
^^^^^^^^^^^^^^^^
Expand Down
27 changes: 26 additions & 1 deletion docs/notebooks/figanos_multiplots.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@
"metadata": {},
"source": [
"## Maps\n",
"Create multiple maps plot with figanos wrapped around [xr.plot.facetgrid.FacetGrid](https://docs.xarray.dev/en/latest/generated/xarray.plot.FacetGrid.html) by passing the key row `row` and `col` in the argument `plot_kw`."
"Create multiple maps plot with figanos wrapped around [xr.plot.facetgrid.FacetGrid](https://docs.xarray.dev/en/latest/generated/xarray.plot.FacetGrid.html) by passing the keys `row` and `col` in the argument `plot_kw`."
]
},
{
Expand Down Expand Up @@ -199,6 +199,31 @@
"im.fig.suptitle(\"Multiple hatchmaps\", y=1.08)\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Heatmaps\n",
"\n",
"The keys `row` and `col` in the argument `plot_kw` can also be used to create a grid of heatmaps. This is done by wrapping Seaborn's [heatmap](https://seaborn.pydata.org/generated/seaborn.heatmap.html) and [FacetGrid](https://seaborn.pydata.org/generated/seaborn.FacetGrid.html)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"ds_space = opened[['tx_max_p50']].isel(time=[0, 1, 2]).sel(lat=slice(40,65), lon=slice(-90,-55))\n",
"\n",
"# spatial subbomain\n",
"sl = slice(100,100+5)\n",
"da = ds_space.isel(lat=sl, lon=sl).drop(\"horizon\").tx_max_p50\n",
"da[\"lon\"] = np.round(da.lon,2)\n",
"da[\"lat\"] = np.round(da.lat,2)\n",
"fg.heatmap(da, plot_kw = {\"col\": \"time\"})"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand Down
96 changes: 79 additions & 17 deletions src/figanos/matplotlib/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import math
import warnings
from collections.abc import Iterable
from inspect import signature
from pathlib import Path
from typing import Any

Expand Down Expand Up @@ -1356,8 +1357,25 @@ def heatmap(
raise TypeError("`data` must contain a xr.DataArray or xr.Dataset")

# setup fig, axis
if ax is None:
if ax is None and ("row" not in plot_kw.keys() and "col" not in plot_kw.keys()):
fig, ax = plt.subplots(**fig_kw)
elif ax is not None and ("col" in plot_kw or "row" in plot_kw):
raise ValueError("Cannot use 'ax' and 'col'/'row' at the same time.")
elif ax is None:
if any([k != "figsize" for k in fig_kw.keys()]):
warnings.warn(
"Only figsize arguments can be passed to fig_kw when using facetgrid."
)
plot_kw.setdefault("col", None)
plot_kw.setdefault("row", None)
plot_kw.setdefault("margin_titles", True)
heatmap_dims = list(
set(da.dims)
- {d for d in [plot_kw["col"], plot_kw["row"]] if d is not None}
)
if da.name is None:
da = da.to_dataset(name="data").data
da_name = da.name

# create cbar label
if (
Expand Down Expand Up @@ -1389,11 +1407,16 @@ def heatmap(
)

# convert data to DataFrame
if len(da.coords) != 2:
raise ValueError("DataArray must have exactly two dimensions")
if transpose:
da = da.transpose()
df = da.to_pandas()
if "col" not in plot_kw and "row" not in plot_kw:
if len(da.dims) != 2:
raise ValueError("DataArray must have exactly two dimensions")
df = da.to_pandas()
else:
if len(heatmap_dims) != 2:
raise ValueError("DataArray must have exactly two dimensions")
df = da.to_dataframe().reset_index()

# set defaults
if divergent is not False:
Expand All @@ -1409,21 +1432,60 @@ def heatmap(
plot_kw.setdefault("cmap", cmap)

# plot
sns.heatmap(df, ax=ax, **plot_kw)

# format
plt.xticks(rotation=45, ha="right", rotation_mode="anchor")
ax.tick_params(axis="both", direction="out")
def draw_heatmap(*args, **kwargs):
data = kwargs.pop("data")
d = (
data
if len(args) == 0
# Any sorting should be performed before sending a DataArray in `fg.heatmap`
else data.pivot_table(
index=args[1], columns=args[0], values=args[2], sort=False
)
)
ax = sns.heatmap(d, **kwargs)
ax.set_xticklabels(
ax.get_xticklabels(), rotation=45, ha="right", rotation_mode="anchor"
)
ax.tick_params(axis="both", direction="out")
set_plot_attrs(
use_attrs,
da,
ax,
title_loc="center",
wrap_kw={"min_line_len": 35, "max_line_len": 44},
)
return ax

set_plot_attrs(
use_attrs,
da,
ax,
title_loc="center",
wrap_kw={"min_line_len": 35, "max_line_len": 44},
)
if ax is not None:
ax = draw_heatmap(data=df, ax=ax, **plot_kw)
return ax
elif "col" in plot_kw or "row" in plot_kw:
# When using xarray's FacetGrid, `plot_kw` can be used in the FacetGrid and in the plotting function
# With Seaborn, we need to be more careful and separate keywords.
plot_kw_hm = {
k: v for k, v in plot_kw.items() if k in signature(sns.heatmap).parameters
}
plot_kw_fg = {
k: v for k, v in plot_kw.items() if k in signature(sns.FacetGrid).parameters
}
unused_keys = (
set(plot_kw.keys()) - set(plot_kw_fg.keys()) - set(plot_kw_hm.keys())
)
if unused_keys != set():
raise ValueError(
f"`heatmap` got unexpected keywords in `plot_kw`: {unused_keys}. Keywords in `plot_kw` should be keywords "
"allowed in `sns.heatmap` or `sns.FacetGrid`. "
)

return ax
g = sns.FacetGrid(df, **plot_kw_fg)
cax = g.fig.add_axes([0.95, 0.05, 0.02, 0.9])
g.map_dataframe(
draw_heatmap, *heatmap_dims, da_name, **plot_kw_hm, cbar=True, cbar_ax=cax
)
g.fig.subplots_adjust(right=0.9)
if "figsize" in fig_kw.keys():
g.fig.set_size_inches(*fig_kw["figsize"])
return g


def scattermap(
Expand Down

0 comments on commit d3df931

Please sign in to comment.