diff --git a/.gitignore b/.gitignore index bbd6bf10..0bf4ea4e 100644 --- a/.gitignore +++ b/.gitignore @@ -30,6 +30,7 @@ sources # Python extras .ipynb_checkpoints *.log +*.ipnyb *.pyc .*.pyc __pycache__ diff --git a/ultraplot/axes/cartesian.py b/ultraplot/axes/cartesian.py index 46685b5d..0678a2a8 100644 --- a/ultraplot/axes/cartesian.py +++ b/ultraplot/axes/cartesian.py @@ -386,9 +386,7 @@ def _apply_axis_sharing(self): # bottommost or to the *right* of the leftmost panel. But the sharing level # used for the leftmost and bottommost is the *figure* sharing level. - # Get border axes once for efficiency border_axes = self.figure._get_border_axes() - # Apply X axis sharing self._apply_axis_sharing_for_axis("x", border_axes) @@ -412,128 +410,31 @@ def _apply_axis_sharing_for_axis( """ if axis_name == "x": axis = self.xaxis - shared_axis = self._sharex - panel_group = self._panel_sharex_group + shared_axis = self._sharex # do we share the xaxis? + panel_group = self._panel_sharex_group # do we have a panel? sharing_level = self.figure._sharex - label_params = ["labeltop", "labelbottom"] - border_sides = ["top", "bottom"] else: # axis_name == 'y' axis = self.yaxis shared_axis = self._sharey panel_group = self._panel_sharey_group sharing_level = self.figure._sharey - label_params = ["labelleft", "labelright"] - border_sides = ["left", "right"] - if shared_axis is None or not axis.get_visible(): + if not axis.get_visible(): return level = 3 if panel_group else sharing_level # Handle axis label sharing (level > 0) - if level > 0: + # If we are a border axis, @shared_axis may be None + # We propagate this through the _determine_tick_label_visiblity() logic + if level > 0 and shared_axis: shared_axis_obj = getattr(shared_axis, f"{axis_name}axis") labels._transfer_label(axis.label, shared_axis_obj.label) axis.label.set_visible(False) - # Handle tick label sharing (level > 2) - if level > 2: - label_visibility = self._determine_tick_label_visibility( - axis, - shared_axis, - axis_name, - label_params, - border_sides, - border_axes, - ) - axis.set_tick_params(which="both", **label_visibility) # Turn minor ticks off axis.set_minor_formatter(mticker.NullFormatter()) - def _determine_tick_label_visibility( - self, - axis: maxis.Axis, - shared_axis: maxis.Axis, - axis_name: str, - label_params: list[str], - border_sides: list[str], - border_axes: dict[str, list[plot.PlotAxes]], - ) -> dict[str, bool]: - """ - Determine which tick labels should be visible based on sharing rules and borders. - - Parameters - ---------- - axis : matplotlib axis - The current axis object - shared_axis : Axes - The axes this one shares with - axis_name : str - Either 'x' or 'y' - label_params : list - List of label parameter names (e.g., ['labeltop', 'labelbottom']) - border_sides : list - List of border side names (e.g., ['top', 'bottom']) - border_axes : dict - Dictionary from _get_border_axes() - - Returns - ------- - dict - Dictionary of label visibility parameters - """ - ticks = axis.get_tick_params() - shared_axis_obj = getattr(shared_axis, f"{axis_name}axis") - sharing_ticks = shared_axis_obj.get_tick_params() - - label_visibility = {} - - def _convert_label_param(label_param: str) -> str: - # Deal with logic not being consistent - # in prior mpl versions - if version.parse(str(_version_mpl)) <= version.parse("3.9"): - if label_param == "labeltop" and axis_name == "x": - label_param = "labelright" - elif label_param == "labelbottom" and axis_name == "x": - label_param = "labelleft" - return label_param - - for label_param, border_side in zip(label_params, border_sides): - # Check if user has explicitly set label location via format() - label_visibility[label_param] = False - has_panel = False - for panel in self._panel_dict[border_side]: - # Check if the panel is a colorbar - colorbars = [ - values - for key, values in self._colorbar_dict.items() - if border_side in key # key is tuple (side, top | center | lower) - ] - if not panel in colorbars: - # Skip colorbar as their - # yaxis is not shared - has_panel = True - break - # When we have a panel, let the panel have - # the labels and turn-off for this axis + side. - if has_panel: - continue - is_border = self in border_axes.get(border_side, []) - is_panel = ( - self in shared_axis._panel_dict[border_side] - and self == shared_axis._panel_dict[border_side][-1] - ) - # Use automatic border detection logic - # if we are a panel we "push" the labels outwards - label_param_trans = _convert_label_param(label_param) - is_this_tick_on = ticks[label_param_trans] - is_parent_tick_on = sharing_ticks[label_param_trans] - if is_panel: - label_visibility[label_param] = is_parent_tick_on - elif is_border: - label_visibility[label_param] = is_this_tick_on - return label_visibility - def _add_alt(self, sx, **kwargs): """ Add an alternate axes. diff --git a/ultraplot/axes/geo.py b/ultraplot/axes/geo.py index 896bc0a6..15c5f9a4 100644 --- a/ultraplot/axes/geo.py +++ b/ultraplot/axes/geo.py @@ -652,27 +652,16 @@ def _apply_axis_sharing(self): or to the *right* of the leftmost panel. But the sharing level used for the leftmost and bottommost is the *figure* sharing level. """ - # Handle X axis sharing - if self._sharex: - self._handle_axis_sharing( - source_axis=self._sharex._lonaxis, - target_axis=self._lonaxis, - ) - # Handle Y axis sharing - if self._sharey: - self._handle_axis_sharing( - source_axis=self._sharey._lataxis, - target_axis=self._lataxis, - ) + # Share interval x + if self._sharex and self.figure._sharex >= 2: + self._lonaxis.set_view_interval(*self._sharex._lonaxis.get_view_interval()) + self._lonaxis.set_minor_locator(self._sharex._lonaxis.get_minor_locator()) - # This block is apart of the draw sequence as the - # gridliner object is created late in the - # build chain. - if not self.stale: - return - if self.figure._get_sharing_level() == 0: - return + # Share interval y + if self._sharey and self.figure._sharey >= 2: + self._lataxis.set_view_interval(*self._sharey._lataxis.get_view_interval()) + self._lataxis.set_minor_locator(self._sharey._lataxis.get_minor_locator()) def _get_gridliner_labels( self, @@ -691,38 +680,36 @@ def _toggle_gridliner_labels( labelright=None, geo=None, ): - # For BasemapAxes the gridlines are dicts with key as the coordinate and keys the line and label - # We override the dict here assuming the labels are mut excl due to the N S E W extra chars + """ + Toggle visibility of gridliner labels for each direction. + + Parameters + ---------- + labeltop, labelbottom, labelleft, labelright : bool or None + Whether to show labels on each side. If None, do not change. + geo : optional + Not used in this method. + """ + # Ensure gridlines_major is fully initialized if any(i is None for i in self.gridlines_major): return + gridlabels = self._get_gridliner_labels( bottom=labelbottom, top=labeltop, left=labelleft, right=labelright ) - bools = [labelbottom, labeltop, labelleft, labelright] - directions = "bottom top left right".split() - for direction, toggle in zip(directions, bools): + + toggles = { + "bottom": labelbottom, + "top": labeltop, + "left": labelleft, + "right": labelright, + } + + for direction, toggle in toggles.items(): if toggle is None: continue for label in gridlabels.get(direction, []): - label.set_visible(toggle) - - def _handle_axis_sharing( - self, - source_axis: "GeoAxes", - target_axis: "GeoAxes", - ): - """ - Helper method to handle axis sharing for both X and Y axes. - - Args: - source_axis: The source axis to share from - target_axis: The target axis to apply sharing to - """ - # Copy view interval and minor locator from source to target - - if self.figure._get_sharing_level() >= 2: - target_axis.set_view_interval(*source_axis.get_view_interval()) - target_axis.set_minor_locator(source_axis.get_minor_locator()) + label.set_visible(bool(toggle) or toggle in ("x", "y")) @override def draw(self, renderer=None, *args, **kwargs): @@ -1441,6 +1428,7 @@ def _is_ticklabel_on(self, side: str) -> bool: """ # Deal with different cartopy versions left_labels, right_labels, bottom_labels, top_labels = self._get_side_labels() + if self.gridlines_major is None: return False elif side == "labelleft": diff --git a/ultraplot/axes/polar.py b/ultraplot/axes/polar.py index d66e3e2e..94950179 100644 --- a/ultraplot/axes/polar.py +++ b/ultraplot/axes/polar.py @@ -4,6 +4,11 @@ """ import inspect +try: + from typing import override +except: + from typing_extensions import override + import matplotlib.projections.polar as mpolar import numpy as np @@ -138,6 +143,11 @@ def __init__(self, *args, **kwargs): for axis in (self.xaxis, self.yaxis): axis.set_tick_params(which="both", size=0) + @override + def _apply_axis_sharing(self): + # Not implemented. Silently pass + return + def _update_formatter(self, x, *, formatter=None, formatter_kw=None): """ Update the gridline label formatter. diff --git a/ultraplot/figure.py b/ultraplot/figure.py index d44f31e6..27062ddd 100644 --- a/ultraplot/figure.py +++ b/ultraplot/figure.py @@ -6,6 +6,7 @@ import inspect import os from numbers import Integral +from packaging import version try: from typing import List @@ -20,6 +21,11 @@ import matplotlib.transforms as mtransforms import numpy as np +try: + from typing import override +except: + from typing_extensions import override + from . import axes as paxes from . import constructor from . import gridspec as pgridspec @@ -477,6 +483,21 @@ def _canvas_preprocess(self, *args, **kwargs): return canvas +def _clear_border_cache(func): + """ + Decorator that clears the border cache after function execution. + """ + + @functools.wraps(func) + def wrapper(self, *args, **kwargs): + result = func(self, *args, **kwargs) + if hasattr(self, "_cache_border_axes"): + delattr(self, "_cache_border_axes") + return result + + return wrapper + + class Figure(mfigure.Figure): """ The `~matplotlib.figure.Figure` subclass used by ultraplot. @@ -801,6 +822,145 @@ def __init__( # NOTE: This ignores user-input rc_mode. self.format(rc_kw=rc_kw, rc_mode=1, skip_axes=True, **kw_format) + @override + def draw(self, renderer): + # implement the tick sharing here + # should be shareable --> either all cartesian or all geographic + # but no mixing (panels can be mixed) + # check which ticks are on for x or y and push the labels to the + # outer most on a given column or row. + # we can use get_border_axes for the outermost plots and then collect their outermost panels that are not colorbars + self._share_ticklabels(axis="x") + self._share_ticklabels(axis="y") + super().draw(renderer) + + def _share_ticklabels(self, *, axis: str) -> None: + """ + Tick label sharing is determined at the figure level. While + each subplot controls the limits, we are dealing with the ticklabels + here as the complexity is easiier to deal with. + axis: str 'x' or 'y', row or columns to update + """ + if not self.stale: + return + + outer_axes = self._get_border_axes() + true_outer = {} + + sides = ("top", "bottom") if axis == "x" else ("left", "right") + # for panels + other_axis = "x" if axis == "y" else "y" + other_sides = ("left", "right") if axis == "x" else ("top", "bottom") + # Outer_axes contains the main grid but we need + # to add the panels that are on these axes potentially + tick_params = {} + + # Check if any of the ticks are set to on for @axis + subplot_types = set() + + from packaging import version + from .internals import _version_mpl + + mpl_version = version.parse(str(_version_mpl)) + use_new_labels = mpl_version >= version.parse("3.10") + + label_map = { + "labeltop": "labeltop" if use_new_labels else "labelright", + "labelbottom": "labelbottom" if use_new_labels else "labelleft", + "labelleft": "labelleft", + "labelright": "labelright", + } + + labelleft = label_map["labelleft"] + labelright = label_map["labelright"] + labeltop = label_map["labeltop"] + labelbottom = label_map["labelbottom"] + + for axi in self._iter_axes(panels=True, hidden=False): + if not type(axi) in ( + paxes.CartesianAxes, + paxes._CartopyAxes, + paxes._BasemapAxes, + ): + warnings._warn_ultraplot( + f"Tick label sharing not implemented for {type(axi)} subplots." + ) + return + if not axi._panel_side: + subplot_types.add(type(axi)) + match axis: + # Handle x + case "x" if isinstance(axi, paxes.CartesianAxes): + tmp = axi.xaxis.get_tick_params() + if tmp.get(labeltop): + tick_params[labeltop] = tmp[labeltop] + if tmp.get(labelbottom): + tick_params[labelbottom] = tmp[labelbottom] + + case "x" if isinstance(axi, paxes.GeoAxes): + if axi._is_ticklabel_on("labeltop"): + tick_params["labeltop"] = axi._is_ticklabel_on("labeltop") + if axi._is_ticklabel_on("labelbottom"): + tick_params["labelbottom"] = axi._is_ticklabel_on("labelbottom") + + # Handle y + case "y" if isinstance(axi, paxes.CartesianAxes): + tmp = axi.yaxis.get_tick_params() + if tmp.get(labelleft): + tick_params[labelleft] = tmp[labelleft] + if tmp.get(labelright): + tick_params[labelright] = tmp[labelright] + + case "y" if isinstance(axi, paxes.GeoAxes): + if axi._is_ticklabel_on("labelleft"): + tick_params["labelleft"] = axi._is_ticklabel_on("labelleft") + if axi._is_ticklabel_on("labelright"): + tick_params["labelright"] = axi._is_ticklabel_on("labelright") + + # We cannot mix types (yet) + if len(subplot_types) > 1: + warnings._warn_ultraplot( + "Tick label sharing not implemented for mixed subplot types." + ) + return + for axi in self._iter_axes(panels=True, hidden=False): + tmp = tick_params.copy() + # For sharing limits and or axis labels we + # can leave the ticks as found + for side in sides: + label = f"label{side}" + if isinstance(axi, paxes.CartesianAxes): + # Ignore for geo as it internally converts + label = label_map[label] + if axi not in outer_axes[side]: + tmp[label] = False + + # Determine sharing level + level = getattr(self, f"_share{axis}") + if axis == "y": + # For panels + if hasattr(axi, "_panel_sharey_group") and axi._panel_sharey_group: + level = 3 + elif axi._panel_side and axi._sharey: + level = 3 + else: # x-axis + # For panels + if hasattr(axi, "_panel_sharex_group") and axi._panel_sharex_group: + level = 3 + elif axi._panel_side and axi._sharex: + level = 3 + + if level < 3: + continue + if isinstance(axi, paxes.GeoAxes): + # TODO: move this to tick_params? + # Tick_params is independent of gridliner objects + # Depending on the backend tick params is useful or not + axi._toggle_gridliner_labels(**tmp) + elif tmp: + getattr(axi, f"{axis}axis").set_tick_params(**tmp) + self.stale = True + def _context_adjusting(self, cache=True): """ Prevent re-running auto layout steps due to draws triggered by figure @@ -928,8 +1088,9 @@ def _get_border_axes( if gs is None: return border_axes - # Skip colorbars or panels etc - all_axes = [axi for axi in self.axes if axi.number is not None] + all_axes = [] + for axi in self._iter_axes(panels=True): + all_axes.append(axi) # Handle empty cases nrows, ncols = gs.nrows, gs.ncols @@ -941,26 +1102,52 @@ def _get_border_axes( # Reconstruct the grid based on axis locations. Note that # spanning axes will fit into one of the boxes. Check # this with unittest to see how empty axes are handles - grid, grid_axis_type, seen_axis_type = _get_subplot_layout( - gs, - all_axes, - same_type=same_type, - ) + + gs = self.axes[0].get_gridspec() + shape = (gs.nrows_total, gs.ncols_total) + grid = np.zeros(shape, dtype=object) + grid.fill(None) + grid_axis_type = np.zeros(shape, dtype=int) + seen_axis_type = dict() + ax_type_mapping = dict() + for axi in self._iter_axes(panels=True, hidden=True): + gs = axi.get_subplotspec() + x, y = np.unravel_index(gs.num1, shape) + span = gs._get_rows_columns() + + xleft, xright, yleft, yright = span + xspan = xright - xleft + 1 + yspan = yright - yleft + 1 + number = axi.number + axis_type = type(axi) + if isinstance(axi, (paxes.GeoAxes)): + axis_type = axi.projection + if axis_type not in seen_axis_type: + seen_axis_type[axis_type] = len(seen_axis_type) + type_number = seen_axis_type[axis_type] + ax_type_mapping[axi] = type_number + if axi.get_visible(): + grid[x : x + xspan, y : y + yspan] = axi + grid_axis_type[x : x + xspan, y : y + yspan] = type_number # We check for all axes is they are a border or not # Note we could also write the crawler in a way where # it find the borders by moving around in the grid, without spawning on each axis point. We may change # this in the future for axi in all_axes: - axis_type = seen_axis_type.get(type(axi), 1) + axis_type = ax_type_mapping[axi] + number = axi.number + if axi.number is None: + number = -axi._panel_parent.number crawler = _Crawler( ax=axi, grid=grid, - target=axi.number, + target=number, axis_type=axis_type, grid_axis_type=grid_axis_type, ) for direction, is_border in crawler.find_edges(): - if is_border: + # print(">>", is_border, direction, axi.number) + if is_border and axi not in border_axes[direction]: border_axes[direction].append(axi) self._cached_border_axes = border_axes return border_axes @@ -1054,12 +1241,7 @@ def _get_renderer(self): renderer = canvas.get_renderer() return renderer - def _get_sharing_level(self): - """ - We take the average here as the sharex and sharey should be the same value. In case this changes in the future we can track down the error easily - """ - return 0.5 * (self.figure._sharex + self.figure._sharey) - + @_clear_border_cache def _add_axes_panel(self, ax, side=None, **kwargs): """ Add an axes panel. @@ -1096,6 +1278,23 @@ def _add_axes_panel(self, ax, side=None, **kwargs): pax = self.add_subplot(ss, **kwargs) pax._panel_side = side pax._panel_share = share + if share: + # When we are sharing we remove the ticks by default + # as we "push" the labels out. See Figure._share_ticklabels. + # If we add the labels here it is more difficult to control + # for some ticks being on. + from packaging import version + from .internals import _version_mpl + + params = {} + if version.parse(str(_version_mpl)) < version.parse("3.10"): + params = dict(labelleft=False, labelright=False) + pax.xaxis.set_tick_params(**params) + pax.yaxis.set_tick_params(**params) + else: + pax.xaxis.set_tick_params(labelbottom=False, labeltop=False) + pax.yaxis.set_tick_params(labelleft=False, labelright=False) + pax._panel_parent = ax ax._panel_dict[side].append(pax) ax._apply_auto_share() @@ -1104,6 +1303,7 @@ def _add_axes_panel(self, ax, side=None, **kwargs): axis.set_label_position(side) # set label position return pax + @_clear_border_cache def _add_figure_panel( self, side=None, span=None, row=None, col=None, rows=None, cols=None, **kwargs ): @@ -1138,6 +1338,7 @@ def _add_figure_panel( pax._panel_parent = None return pax + @_clear_border_cache def _add_subplot(self, *args, **kwargs): """ The driver function for adding single subplots. @@ -1246,9 +1447,6 @@ def _add_subplot(self, *args, **kwargs): if ax.number: self._subplot_dict[ax.number] = ax - # Invalidate border axes cache - if hasattr(self, "_cached_border_axes"): - delattr(self, "_cached_border_axes") return ax def _unshare_axes(self): @@ -1263,56 +1461,6 @@ def _unshare_axes(self): if isinstance(ax, paxes.GeoAxes) and hasattr(ax, "set_global"): ax.set_global() - def _share_labels_with_others(self, *, which="both"): - """ - Helpers function to ensure the labels - are shared for rectilinear GeoAxes. - """ - # Only apply sharing of labels when we are - # actually sharing labels. - if self._get_sharing_level() == 0: - return - # Turn all labels off - # Note: this action performs it for all the axes in - # the figure. We use the stale here to only perform - # it once as it is an expensive action. - # The axis will be a border if it is either - # (a) on the edge - # (b) not next to a subplot - # (c) not next to a subplot of the same kind - border_axes = self._get_border_axes() - # Recode: - recoded = {} - for direction, axes in border_axes.items(): - for axi in axes: - recoded[axi] = recoded.get(axi, []) + [direction] - - are_ticks_on = False - default = dict( - labelleft=are_ticks_on, - labelright=are_ticks_on, - labeltop=are_ticks_on, - labelbottom=are_ticks_on, - ) - for axi in self._iter_axes(hidden=False, panels=False, children=False): - # Turn the ticks on or off depending on the position - sides = recoded.get(axi, []) - turn_on_or_off = default.copy() - - for side in sides: - sidelabel = f"label{side}" - is_label_on = axi._is_ticklabel_on(sidelabel) - if is_label_on: - # When we are a border an the labels are on - # we keep them on - assert sidelabel in turn_on_or_off - turn_on_or_off[sidelabel] = True - - if isinstance(axi, paxes.GeoAxes): - axi._toggle_gridliner_labels(**turn_on_or_off) - else: - axi._apply_axis_sharing() - def _toggle_axis_sharing( self, *, @@ -1728,6 +1876,7 @@ def _update_super_title(self, title, **kwargs): if title is not None: self._suptitle.set_text(title) + @_clear_border_cache @docstring._concatenate_inherited @docstring._snippet_manager def add_axes(self, rect, **kwargs): @@ -1822,7 +1971,6 @@ def _align_content(): # noqa: E306 # subsequent tight layout really weird. Have to resize twice. _draw_content() if not gs: - print("hello") return if aspect: gs._auto_layout_aspect() @@ -1968,12 +2116,6 @@ def format( } ax.format(rc_kw=rc_kw, rc_mode=rc_mode, skip_figure=True, **kw, **kwargs) ax.number = store_old_number - # When we apply formatting to all axes, we need - # to potentially adjust the labels. - - if len(axs) == len(self.axes) and self._get_sharing_level() > 0: - self._share_labels_with_others() - # Warn unused keyword argument(s) kw = { key: value @@ -1985,53 +2127,6 @@ def format( f"Ignoring unused projection-specific format() keyword argument(s): {kw}" # noqa: E501 ) - def _share_labels_with_others(self, *, which="both"): - """ - Helpers function to ensure the labels - are shared for rectilinear GeoAxes. - """ - # Turn all labels off - # Note: this action performs it for all the axes in - # the figure. We use the stale here to only perform - # it once as it is an expensive action. - border_axes = self._get_border_axes(same_type=False) - # Recode: - recoded = {} - for direction, axes in border_axes.items(): - for axi in axes: - recoded[axi] = recoded.get(axi, []) + [direction] - - # We turn off the tick labels when the scale and - # ticks are shared (level > 0) - are_ticks_on = False - default = dict( - labelleft=are_ticks_on, - labelright=are_ticks_on, - labeltop=are_ticks_on, - labelbottom=are_ticks_on, - ) - for axi in self._iter_axes(hidden=False, panels=False, children=False): - # Turn the ticks on or off depending on the position - sides = recoded.get(axi, []) - turn_on_or_off = default.copy() - # The axis will be a border if it is either - # (a) on the edge - # (b) not next to a subplot - # (c) not next to a subplot of the same kind - for side in sides: - sidelabel = f"label{side}" - is_label_on = axi._is_ticklabel_on(sidelabel) - if is_label_on: - # When we are a border an the labels are on - # we keep them on - assert sidelabel in turn_on_or_off - turn_on_or_off[sidelabel] = True - - if isinstance(axi, paxes.GeoAxes): - axi._toggle_gridliner_labels(**turn_on_or_off) - else: - axi.tick_params(which=which, **turn_on_or_off) - @docstring._concatenate_inherited @docstring._snippet_manager def colorbar( diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 159cac2c..029b61c1 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -195,7 +195,7 @@ def _get_rows_columns(self, ncols=None): row2, col2 = divmod(self.num2, ncols) return row1, row2, col1, col2 - def _get_grid_span(self, hidden=False) -> (int, int, int, int): + def _get_grid_span(self, hidden=True) -> (int, int, int, int): """ Retrieve the location of the subplot within the gridspec. When hidden is False we only consider @@ -203,11 +203,12 @@ def _get_grid_span(self, hidden=False) -> (int, int, int, int): """ gs = self.get_gridspec() nrows, ncols = gs.nrows_total, gs.ncols_total - if not hidden: + if hidden: + x, y = np.unravel_index(self.num1, (nrows, ncols)) + else: nrows, ncols = gs.nrows, gs.ncols - # Use num1 or num2 - decoded = gs._decode_indices(self.num1) - x, y = np.unravel_index(decoded, (nrows, ncols)) + decoded = gs._decode_indices(self.num1) + x, y = np.unravel_index(decoded, (nrows, ncols)) span = self._get_rows_columns() xspan = span[1] - span[0] + 1 # inclusive diff --git a/ultraplot/tests/conftest.py b/ultraplot/tests/conftest.py index e6848aba..db2482d9 100644 --- a/ultraplot/tests/conftest.py +++ b/ultraplot/tests/conftest.py @@ -3,7 +3,6 @@ import warnings, logging logging.getLogger("matplotlib").setLevel(logging.ERROR) - SEED = 51423 diff --git a/ultraplot/tests/test_2dplots.py b/ultraplot/tests/test_2dplots.py index 13f084c6..a2b75319 100644 --- a/ultraplot/tests/test_2dplots.py +++ b/ultraplot/tests/test_2dplots.py @@ -30,12 +30,12 @@ def test_auto_diverging1(rng): """ # Test with basic data fig = uplt.figure() - # fig.format(collabels=('Auto sequential', 'Auto diverging'), suptitle='Default') ax = fig.subplot(121) ax.pcolor(rng.random((10, 10)) * 5, colorbar="b") ax = fig.subplot(122) ax.pcolor(rng.random((10, 10)) * 5 - 3.5, colorbar="b") fig.format(toplabels=("Sequential", "Diverging")) + fig.canvas.draw() return fig diff --git a/ultraplot/tests/test_axes.py b/ultraplot/tests/test_axes.py index a04c2233..75ccb3aa 100644 --- a/ultraplot/tests/test_axes.py +++ b/ultraplot/tests/test_axes.py @@ -352,7 +352,7 @@ def test_sharing_labels_top_right(): [3, 4, 5], [3, 4, 0], ], - 3, # default sharing level + True, # default sharing level {"xticklabelloc": "t", "yticklabelloc": "r"}, [1, 3, 4], # y-axis labels visible indices [0, 1, 4], # x-axis labels visible indices @@ -405,6 +405,7 @@ def check_state(ax, numbers, state, which): # Format axes with the specified tick label locations ax.format(**tick_loc) + fig.canvas.draw() # needed for sharing labels # Calculate the indices where labels should be hidden all_indices = list(range(len(ax))) diff --git a/ultraplot/tests/test_figure.py b/ultraplot/tests/test_figure.py index 0e92f8f2..cffa3c7f 100644 --- a/ultraplot/tests/test_figure.py +++ b/ultraplot/tests/test_figure.py @@ -58,7 +58,17 @@ def test_unsharing_different_rectilinear(): """ with pytest.warns(uplt.internals.warnings.UltraPlotWarning): fig, ax = uplt.subplots(ncols=2, proj=("cyl", "merc"), share="all") - uplt.close(fig) + + +def test_get_renderer_basic(): + """ + Test that _get_renderer returns a renderer object. + """ + fig, ax = uplt.subplots() + renderer = fig._get_renderer() + # Renderer should not be None and should have draw_path method + assert renderer is not None + assert hasattr(renderer, "draw_path") def test_figure_sharing_toggle(): diff --git a/ultraplot/tests/test_geographic.py b/ultraplot/tests/test_geographic.py index 35789a54..6eef28fd 100644 --- a/ultraplot/tests/test_geographic.py +++ b/ultraplot/tests/test_geographic.py @@ -296,6 +296,7 @@ def are_labels_on(ax, which=["top", "bottom", "right", "left"]) -> tuple[bool]: settings = dict(land=True, ocean=True, labels="both") fig, ax = uplt.subplots(layout, share="all", proj="cyl") ax.format(**settings) + fig.canvas.draw() # needed for sharing labels for axi in ax: state = are_labels_on(axi) expectation = expectations[axi.number - 1] @@ -491,7 +492,8 @@ def test_get_gridliner_labels_cartopy(): uplt.close(fig) -def test_sharing_levels(): +@pytest.mark.parametrize("level", [0, 1, 2, 3, 4]) +def test_sharing_levels(level): """ We can share limits or labels. We check if we can do both for the GeoAxes. @@ -515,7 +517,6 @@ def test_sharing_levels(): x = np.array([0, 10]) y = np.array([0, 10]) - sharing_levels = [0, 1, 2, 3, 4] lonlim = latlim = np.array((-10, 10)) def assert_views_are_sharing(ax): @@ -551,46 +552,42 @@ def assert_views_are_sharing(ax): l2 = np.linalg.norm( np.asarray(latview) - np.asarray(target_lat), ) - level = ax.figure._get_sharing_level() + level = ax.figure._sharex if level <= 1: share_x = share_y = False assert np.allclose(l1, 0) == share_x assert np.allclose(l2, 0) == share_y - for level in sharing_levels: - fig, ax = uplt.subplots(ncols=2, nrows=2, proj="cyl", share=level) - ax.format(labels="both") - for axi in ax: - axi.format( - lonlim=lonlim * axi.number, - latlim=latlim * axi.number, - ) + fig, ax = uplt.subplots(ncols=2, nrows=2, proj="cyl", share=level) + ax.format(labels="both") + for axi in ax: + axi.format( + lonlim=lonlim * axi.number, + latlim=latlim * axi.number, + ) - fig.canvas.draw() - for idx, axi in enumerate(ax): - axi.plot(x * (idx + 1), y * (idx + 1)) - - fig.canvas.draw() # need this to update the labels - # All the labels should be on - for axi in ax: - side_labels = axi._get_gridliner_labels( - left=True, - right=True, - top=True, - bottom=True, - ) - s = 0 - for dir, labels in side_labels.items(): - s += any([label.get_visible() for label in labels]) - - assert_views_are_sharing(axi) - # When we share the labels but not the limits, - # we expect all ticks to be on - if level == 0: - assert s == 4 - else: - assert s == 2 - uplt.close(fig) + fig.canvas.draw() + for idx, axi in enumerate(ax): + axi.plot(x * (idx + 1), y * (idx + 1)) + + # All the labels should be on + for axi in ax: + + s = sum( + [ + 1 if axi._is_ticklabel_on(side) else 0 + for side in "labeltop labelbottom labelleft labelright".split() + ] + ) + + assert_views_are_sharing(axi) + # When we share the labels but not the limits, + # we expect all ticks to be on + if level > 2: + assert s == 2 + else: + assert s == 4 + uplt.close(fig) @pytest.mark.mpl_image_compare @@ -616,8 +613,10 @@ def test_cartesian_and_geo(rng): ax.format(land=True, lonlim=(-10, 10), latlim=(-10, 10)) ax[0].pcolormesh(rng.random((10, 10))) ax[1].scatter(*rng.random((2, 100))) - ax[0]._apply_axis_sharing() - assert mocked.call_count == 2 + fig.canvas.draw() + assert ( + mocked.call_count > 2 + ) # needs to be called at least twice; one for each axis return fig @@ -678,19 +677,9 @@ def test_panels_geo(): ax.format(labels=True) for dir in "top bottom right left".split(): pax = ax.panel_axes(dir) - match dir: - case "top": - assert len(pax.get_xticklabels()) > 0 - assert len(pax.get_yticklabels()) > 0 - case "bottom": - assert len(pax.get_xticklabels()) > 0 - assert len(pax.get_yticklabels()) > 0 - case "left": - assert len(pax.get_xticklabels()) > 0 - assert len(pax.get_yticklabels()) > 0 - case "right": - assert len(pax.get_xticklabels()) > 0 - assert len(pax.get_yticklabels()) > 0 + fig.canvas.draw() # need this to update the ticks + assert len(pax.get_xticklabels()) > 0 + assert len(pax.get_yticklabels()) > 0 @pytest.mark.mpl_image_compare @@ -807,6 +796,7 @@ def are_labels_on(ax, which=("top", "bottom", "right", "left")) -> tuple[bool]: h = ax.imshow(data)[0] ax.format(land=True, labels="both") # need this otherwise no labels are printed fig.colorbar(h, loc="r") + fig.canvas.draw() # needed to invoke axis sharing expectations = ( [True, False, False, True], diff --git a/ultraplot/tests/test_sharing.py b/ultraplot/tests/test_sharing.py new file mode 100644 index 00000000..620e879f --- /dev/null +++ b/ultraplot/tests/test_sharing.py @@ -0,0 +1,98 @@ +import pytest, ultraplot as uplt + +""" +Sharing levels for subplots determine the visibility of the axis labels and tick labels. + +Axis labels are pushed to the border subplots when the sharing level is greater than 1. + +Ticks are visible only on the border plots when the sharing level is greater than 2. + +Or more verbosely: + sharey = 0: no sharing, all labels and ticks visible + sharey = 1: share axis labels, tick labels are still independent + sharey = 2: share data limits + sharey = 3 or True, share both ticks and labels +A similar story holds for sharex. +""" + + +@pytest.mark.parametrize("share_level", [0, "labels", "labs", 1, True]) +@pytest.mark.mpl_image_compare +def test_sharing_levels_y(share_level): + """ + Test sharing levels for y-axis: left and right ticks/labels. + """ + fig, axs = uplt.subplots(None, 2, 3, sharey=share_level) + axs.format(ylabel="Y") + axs.format(title=f"sharey = {share_level}") + fig.canvas.draw() # needed for checks + + if fig._sharey < 3: + border_axes = set(axs) + else: + # Reduce border_axes to a set of axes for left and right + border_axes = set() + for direction in ["left", "right"]: + axes = fig._get_border_axes().get(direction, []) + if isinstance(axes, (list, tuple, set)): + border_axes.update(axes) + else: + border_axes.add(axes) + for axi in axs: + tick_params = axi.yaxis.get_tick_params() + for direction in ["left", "right"]: + label_key = f"label{direction}" + visible = tick_params.get(label_key, False) + is_border = axi in fig._get_border_axes().get(direction, []) + if direction == "left" and (fig._sharey < 3 or is_border): + assert visible + else: + assert not visible + return fig + + +@pytest.mark.parametrize("share_level", [0, "labels", "labs", 1, True]) +@pytest.mark.mpl_image_compare +def test_sharing_levels_x(share_level): + """ + Test sharing levels for x-axis: top and bottom ticks/labels. + """ + fig, axs = uplt.subplots(None, 2, 3, sharex=share_level) + axs.format(xlabel="X") + axs.format(title=f"sharex = {share_level}") + fig.canvas.draw() # needed for checks + + # Get the border axes + if fig._sharex < 3: + border_axes = set(axs) + else: + # Reduce border_axes to a set of axes for top and bottom + border_axes = set() + for direction in ["top", "bottom"]: + axes = fig._get_border_axes().get(direction, []) + if isinstance(axes, (list, tuple, set)): + border_axes.update(axes) + else: + border_axes.add(axes) + + # Run tests + for axi in axs: + tick_params = axi.xaxis.get_tick_params() + # Get correct directions depending on mpl version + from ultraplot.internals.versions import _version_mpl + from packaging import version + + if version.parse(str(_version_mpl)) >= version.parse("3.10"): + direction_label_map = {"top": "labeltop", "bottom": "labelbottom"} + else: + direction_label_map = {"top": "labelright", "bottom": "labelleft"} + + for direction in ["top", "bottom"]: + label_key = direction_label_map[direction] + visible = tick_params.get(label_key, False) + is_border = axi in fig._get_border_axes().get(direction, []) + if direction == "bottom" and (fig._sharex < 3 or is_border): + assert visible + else: + assert not visible + return fig diff --git a/ultraplot/tests/test_subplots.py b/ultraplot/tests/test_subplots.py index e215a90e..067c0ee1 100644 --- a/ultraplot/tests/test_subplots.py +++ b/ultraplot/tests/test_subplots.py @@ -290,29 +290,42 @@ def test_panel_sharing_top_right(layout): for dir in "left right top bottom".split(): pax = ax[0].panel(dir) fig.canvas.draw() # force redraw tick labels - for dir, paxs in ax[0]._panel_dict.items(): - # Since we are sharing some of the ticks - # should be hidden depending on where the panel is - # in the grid - for pax in paxs: - match dir: - case "left": - assert pax._is_ticklabel_on("labelleft") - assert pax._is_ticklabel_on("labelbottom") - case "top": - assert pax._is_ticklabel_on("labeltop") == False - assert pax._is_ticklabel_on("labelbottom") == False - assert pax._is_ticklabel_on("labelleft") - case "right": - print(pax._is_ticklabel_on("labelright")) - assert pax._is_ticklabel_on("labelright") == False - assert pax._is_ticklabel_on("labelbottom") - case "bottom": - assert pax._is_ticklabel_on("labelleft") - assert pax._is_ticklabel_on("labelbottom") == False - - # The sharing axis is not showing any ticks - assert ax[0]._is_ticklabel_on(dir) == False + border_axes = fig._get_border_axes() + + for axi in fig._iter_axes(panels=True): + assert ( + axi._is_ticklabel_on("labelleft") + if axi in border_axes["left"] + else not axi._is_ticklabel_on("labelleft") + ) + assert ( + axi._is_ticklabel_on("labeltop") + if axi in border_axes["top"] + else not axi._is_ticklabel_on("labeltop") + ) + assert ( + axi._is_ticklabel_on("labelright") + if axi in border_axes["right"] + else not axi._is_ticklabel_on("labelright") + ) + assert ( + axi._is_ticklabel_on("labelbottom") + if axi in border_axes["bottom"] + else not axi._is_ticklabel_on("labelbottom") + ) + return fig + + +@pytest.mark.mpl_image_compare +def test_uneven_span_subplots(rng): + fig = uplt.figure(refwidth=1, refnum=5, span=False) + axs = fig.subplots([[1, 1, 2], [3, 4, 2], [3, 4, 5]], hratios=[2.2, 1, 1]) + axs.format(xlabel="xlabel", ylabel="ylabel", suptitle="Complex SubplotGrid") + axs[0].format(ec="black", fc="gray1", lw=1.4) + axs[1, 1:].format(fc="blush") + axs[1, :1].format(fc="sky blue") + axs[-1, -1].format(fc="gray4", grid=False) + axs[0].plot((rng.random((50, 10)) - 0.5).cumsum(axis=0), cycle="Grays_r", lw=2) return fig diff --git a/ultraplot/utils.py b/ultraplot/utils.py index 1b1b97a9..3ecd9597 100644 --- a/ultraplot/utils.py +++ b/ultraplot/utils.py @@ -918,7 +918,8 @@ def _get_subplot_layout( axis types. This function is used internally to determine the layout of axes in a GridSpec. """ - grid = np.zeros((gs.nrows, gs.ncols)) + grid = np.zeros((gs.nrows_total, gs.ncols_total), dtype=object) + grid.fill(None) grid_axis_type = np.zeros((gs.nrows, gs.ncols)) # Collect grouper based on kinds of axes. This # would allow us to share labels across types @@ -936,7 +937,7 @@ def _get_subplot_layout( grid[ slice(*rowspan), slice(*colspan), - ] = axi.number + ] = axi # Allow grouping of mixed types axis_type = 1 @@ -996,22 +997,28 @@ def find_edge_for( direction: str, d: tuple[int, int], ) -> tuple[str, bool]: - from itertools import product - """ Setup search for a specific direction. """ + from itertools import product + # Retrieve where the axis is in the grid spec = self.ax.get_subplotspec() - spans = spec._get_grid_span() + shape = (spec.get_gridspec().nrows_total, spec.get_gridspec().ncols_total) + x, y = np.unravel_index(spec.num1, shape) + spans = spec._get_rows_columns() rowspan = spans[:2] colspan = spans[-2:] - xs = range(*rowspan) - ys = range(*colspan) + + a = rowspan[1] - rowspan[0] + b = colspan[1] - colspan[0] + xs = range(x, x + a + 1) + ys = range(y, y + b + 1) + is_border = False - for x, y in product(xs, ys): - pos = (x, y) + for xl, yl in product(xs, ys): + pos = (xl, yl) if self.is_border(pos, d): is_border = True break @@ -1026,27 +1033,27 @@ def is_border( Recursively move over the grid by following the direction. """ x, y = pos - # Check if we are at an edge of the grid (out-of-bounds). - if x < 0: - return True - elif x > self.grid.shape[0] - 1: + # Edge of grid (out-of-bounds) + if not (0 <= x < self.grid.shape[0] and 0 <= y < self.grid.shape[1]): return True - if y < 0: - return True - elif y > self.grid.shape[1] - 1: - return True + cell = self.grid[x, y] + dx, dy = direction + if cell is None: + return self.is_border((x + dx, y + dy), direction) - if self.grid[x, y] == 0 or self.grid_axis_type[x, y] != self.axis_type: - return True + if hasattr(cell, "_panel_hidden") and cell._panel_hidden: + return self.is_border((x + dx, y + dy), direction) - # Check if we reached a plot or an internal edge - if self.grid[x, y] != self.target and self.grid[x, y] > 0: - return self._check_ranges(direction, other=self.grid[x, y]) + if self.grid_axis_type[x, y] != self.axis_type: + if cell in self.ax._panel_dict.get(cell._panel_side, []): + return self.is_border((x + dx, y + dy), direction) - dx, dy = direction - pos = (x + dx, y + dy) - return self.is_border(pos, direction) + # Internal edge or plot reached + if cell != self.ax: + return self._check_ranges(direction, other=cell) + + return self.is_border((x + dx, y + dy), direction) def _check_ranges( self, @@ -1065,14 +1072,15 @@ def _check_ranges( can share x. """ this_spec = self.ax.get_subplotspec() - other_spec = self.ax.figure._subplot_dict[other].get_subplotspec() + other_spec = other.get_subplotspec() # Get the row and column spans of both axes - this_span = this_spec._get_grid_span() + this_span = this_spec._get_rows_columns() this_rowspan = this_span[:2] this_colspan = this_span[-2:] other_span = other_spec._get_grid_span() + other_span = other_spec._get_rows_columns() other_rowspan = other_span[:2] other_colspan = other_span[-2:] @@ -1089,7 +1097,43 @@ def _check_ranges( other_start, other_stop = other_rowspan if this_start == other_start and this_stop == other_stop: - return False # not a border + # We may hit an internal border if we are at + # the interface with a panel that is not sharing + dmap = { + (-1, 0): "bottom", + (1, 0): "top", + (0, -1): "left", + (0, 1): "right", + } + side = dmap[direction] + if self.ax.number is None: # panel + parent = self.ax._panel_parent + + panels = parent._panel_dict.get(side, []) + # If we are a panel at the end we are a border + # only if we are not sharing axes + if side in ("left", "right"): + if self.ax._sharey is None: + return True + elif not self.ax._panel_sharey_group: + return True + elif side in ("top", "bottom"): + if self.ax._sharex is None: + return True + elif not self.ax._panel_sharex_group: + return True + + # Only consider when we are interfacing with a panel + # axes on the outside will also not share when they are in top + # or left + elif side in ("left", "right") and self.ax._sharey is None: + if other.number is None: + return True + elif side in ("bottom", "top") and self.ax._sharex is None: + if other.number is None: + return True + + return False return True