diff --git a/ultraplot/axes/geo.py b/ultraplot/axes/geo.py index ab95a661..896bc0a6 100644 --- a/ultraplot/axes/geo.py +++ b/ultraplot/axes/geo.py @@ -805,15 +805,18 @@ def _to_label_array(arg, lon=True): array[4] = True # possibly toggle geo spine labels elif not any(isinstance(_, str) for _ in array): if len(array) == 1: - array.append(False) # default is to label bottom or left + array.append(None) if len(array) == 2: - array = [False, False, *array] if lon else [*array, False, False] + array = [None, None, *array] if lon else [*array, None, None] if len(array) == 4: - b = any(array) if rc["grid.geolabels"] else False - array.append(b) # possibly toggle geo spine labels + b = ( + any(a for a in array if a is not None) + if rc["grid.geolabels"] + else None + ) + array.append(b) if len(array) != 5: raise ValueError(f"Invald boolean label array length {len(array)}.") - array = list(map(bool, array)) else: raise ValueError(f"Invalid {which}label spec: {arg}.") return array @@ -934,9 +937,13 @@ def format( # NOTE: Cartopy 0.18 and 0.19 inline labels require any of # top, bottom, left, or right to be toggled then ignores them. # Later versions of cartopy permit both or neither labels. - labels = _not_none(labels, rc.find("grid.labels", context=True)) - lonlabels = _not_none(lonlabels, labels) - latlabels = _not_none(latlabels, labels) + if lonlabels is None and latlabels is None: + labels = _not_none(labels, rc.find("grid.labels", context=True)) + lonlabels = labels + latlabels = labels + else: + lonlabels = _not_none(lonlabels, labels) + latlabels = _not_none(latlabels, labels) # Set the ticks self._toggle_ticks(lonlabels, "x") self._toggle_ticks(latlabels, "y") @@ -1464,8 +1471,9 @@ def _toggle_gridliner_labels( side_labels = _CartopyAxes._get_side_labels() togglers = (labelleft, labelright, labelbottom, labeltop) gl = self.gridlines_major + for toggle, side in zip(togglers, side_labels): - if getattr(gl, side) != toggle: + if toggle is not None: setattr(gl, side, toggle) if geo is not None: # only cartopy 0.20 supported but harmless setattr(gl, "geo_labels", geo) @@ -1760,6 +1768,7 @@ def _update_major_gridlines( for side, lon, lat in zip( "labelleft labelright labelbottom labeltop geo".split(), lonarray, latarray ): + sides[side] = None if lon and lat: sides[side] = True elif lon: diff --git a/ultraplot/gridspec.py b/ultraplot/gridspec.py index 0d642505..159cac2c 100644 --- a/ultraplot/gridspec.py +++ b/ultraplot/gridspec.py @@ -1536,42 +1536,48 @@ def __getitem__(self, key): >>> axs[1, 2] # the subplot in the second row, third column >>> axs[:, 0] # a SubplotGrid containing the subplots in the first column """ - if isinstance(key, tuple) and len(key) == 1: - key = key[0] - # List-style indexing - if isinstance(key, (Integral, slice)): - slices = isinstance(key, slice) - objs = list.__getitem__(self, key) - # Gridspec-style indexing - elif ( - isinstance(key, tuple) - and len(key) == 2 - and all(isinstance(ikey, (Integral, slice)) for ikey in key) - ): - # WARNING: Permit no-op slicing of empty grids here - slices = any(isinstance(ikey, slice) for ikey in key) - objs = [] - if self: - gs = self.gridspec - ss_key = gs._make_subplot_spec(key) # obfuscates panels - row1_key, col1_key = divmod(ss_key.num1, gs.ncols) - row2_key, col2_key = divmod(ss_key.num2, gs.ncols) - for ax in self: - ss = ax._get_topmost_axes().get_subplotspec().get_topmost_subplotspec() - row1, col1 = divmod(ss.num1, gs.ncols) - row2, col2 = divmod(ss.num2, gs.ncols) - inrow = row1_key <= row1 <= row2_key or row1_key <= row2 <= row2_key - incol = col1_key <= col1 <= col2_key or col1_key <= col2 <= col2_key - if inrow and incol: - objs.append(ax) - if not slices and len(objs) == 1: # accounts for overlapping subplots - objs = objs[0] - else: - raise IndexError(f"Invalid index {key!r}.") - if isinstance(objs, list): - return SubplotGrid(objs) - else: - return objs + # Allow 1D list-like indexing + if isinstance(key, int): + return list.__getitem__(self, key) + elif isinstance(key, slice): + return SubplotGrid(list.__getitem__(self, key)) + + # Allow 2D array-like indexing + # NOTE: We assume this is a 2D array of subplots, because this is + # how it is generated in the first place by ultraplot.figure(). + # But it is possible to append subplots manually. + gs = self.gridspec + if gs is None: + raise IndexError( + f"{self.__class__.__name__} has no gridspec, cannot index with {key!r}." + ) + # Build grid with None for empty slots + grid = np.full((gs.nrows_total, gs.ncols_total), None, dtype=object) + for ax in self: + spec = ax.get_subplotspec() + x1, x2, y1, y2 = spec._get_rows_columns(ncols=gs.ncols_total) + grid[x1 : x2 + 1, y1 : y2 + 1] = ax + + new_key = [] + for which, keyi in zip("hw", key): + try: + encoded_keyi = gs._encode_indices(keyi, which=which) + except: + raise IndexError( + f"Attempted to access {key=} for gridspec {grid.shape=}" + ) + new_key.append(encoded_keyi) + xs, ys = new_key + objs = grid[xs, ys] + if hasattr(objs, "flat"): + objs = [obj for obj in objs.flat if obj is not None] + elif not isinstance(objs, list): + objs = [objs] + + if len(objs) == 1: + return objs[0] + objs = [obj for obj in objs if obj is not None] + return SubplotGrid(objs) def __setitem__(self, key, value): """ diff --git a/ultraplot/tests/test_geographic.py b/ultraplot/tests/test_geographic.py index 4b95a938..35789a54 100644 --- a/ultraplot/tests/test_geographic.py +++ b/ultraplot/tests/test_geographic.py @@ -314,8 +314,8 @@ def test_toggle_gridliner_labels(): gl = ax[0].gridlines_major assert gl.left_labels == False - assert gl.right_labels == None # initially these are none - assert gl.top_labels == None + assert gl.right_labels == False + assert gl.top_labels == False assert gl.bottom_labels == False ax[0]._toggle_gridliner_labels(labeltop=True) assert gl.top_labels == True @@ -617,7 +617,7 @@ def test_cartesian_and_geo(rng): ax[0].pcolormesh(rng.random((10, 10))) ax[1].scatter(*rng.random((2, 100))) ax[0]._apply_axis_sharing() - assert mocked.call_count == 1 + assert mocked.call_count == 2 return fig @@ -895,3 +895,26 @@ def test_imshow_with_and_without_transform(rng): ax[2].imshow(data, transform=uplt.axes.geo.ccrs.PlateCarree()) ax.format(title=["LCC", "No transform", "PlateCarree"]) return fig + + +@pytest.mark.mpl_image_compare +def test_grid_indexing_formatting(rng): + """ + Check if subplotgrid is correctly selecting + the subplots based on non-shared axis formatting + """ + # See https://github.com/Ultraplot/UltraPlot/issues/356 + lon = np.arange(0, 360, 10) + lat = np.arange(-60, 60 + 1, 10) + data = rng.random((len(lat), len(lon))) + + fig, axs = uplt.subplots(nrows=3, ncols=2, proj="cyl", share=0) + axs.format(coast=True) + + for ax in axs: + m = ax.pcolor(lon, lat, data) + ax.colorbar(m) + + axs[-1, :].format(lonlabels=True) + axs[:, 0].format(latlabels=True) + return fig diff --git a/ultraplot/tests/test_subplots.py b/ultraplot/tests/test_subplots.py index dead27f3..e215a90e 100644 --- a/ultraplot/tests/test_subplots.py +++ b/ultraplot/tests/test_subplots.py @@ -314,3 +314,16 @@ def test_panel_sharing_top_right(layout): # The sharing axis is not showing any ticks assert ax[0]._is_ticklabel_on(dir) == False return fig + + +@pytest.mark.mpl_image_compare +def test_uneven_span_subplots(rng): + fig = uplt.figure(refwidth=1, refnum=5, span=False) + axs = fig.subplots([[1, 1, 2], [3, 4, 2], [3, 4, 5]], hratios=[2.2, 1, 1]) + axs.format(xlabel="xlabel", ylabel="ylabel", suptitle="Complex SubplotGrid") + axs[0].format(ec="black", fc="gray1", lw=1.4) + axs[1, 1:].format(fc="blush") + axs[1, :1].format(fc="sky blue") + axs[-1, -1].format(fc="gray4", grid=False) + axs[0].plot((rng.random((50, 10)) - 0.5).cumsum(axis=0), cycle="Grays_r", lw=2) + return fig