Skip to content
27 changes: 18 additions & 9 deletions ultraplot/axes/geo.py
Original file line number Diff line number Diff line change
Expand Up @@ -805,15 +805,18 @@ def _to_label_array(arg, lon=True):
array[4] = True # possibly toggle geo spine labels
elif not any(isinstance(_, str) for _ in array):
if len(array) == 1:
array.append(False) # default is to label bottom or left
array.append(None)
if len(array) == 2:
array = [False, False, *array] if lon else [*array, False, False]
array = [None, None, *array] if lon else [*array, None, None]
if len(array) == 4:
b = any(array) if rc["grid.geolabels"] else False
array.append(b) # possibly toggle geo spine labels
b = (
any(a for a in array if a is not None)
if rc["grid.geolabels"]
else None
)
array.append(b)
if len(array) != 5:
raise ValueError(f"Invald boolean label array length {len(array)}.")
array = list(map(bool, array))
else:
raise ValueError(f"Invalid {which}label spec: {arg}.")
return array
Expand Down Expand Up @@ -934,9 +937,13 @@ def format(
# NOTE: Cartopy 0.18 and 0.19 inline labels require any of
# top, bottom, left, or right to be toggled then ignores them.
# Later versions of cartopy permit both or neither labels.
labels = _not_none(labels, rc.find("grid.labels", context=True))
lonlabels = _not_none(lonlabels, labels)
latlabels = _not_none(latlabels, labels)
if lonlabels is None and latlabels is None:
labels = _not_none(labels, rc.find("grid.labels", context=True))
lonlabels = labels
latlabels = labels
else:
lonlabels = _not_none(lonlabels, labels)
latlabels = _not_none(latlabels, labels)
# Set the ticks
self._toggle_ticks(lonlabels, "x")
self._toggle_ticks(latlabels, "y")
Expand Down Expand Up @@ -1464,8 +1471,9 @@ def _toggle_gridliner_labels(
side_labels = _CartopyAxes._get_side_labels()
togglers = (labelleft, labelright, labelbottom, labeltop)
gl = self.gridlines_major

for toggle, side in zip(togglers, side_labels):
if getattr(gl, side) != toggle:
if toggle is not None:
setattr(gl, side, toggle)
if geo is not None: # only cartopy 0.20 supported but harmless
setattr(gl, "geo_labels", geo)
Expand Down Expand Up @@ -1760,6 +1768,7 @@ def _update_major_gridlines(
for side, lon, lat in zip(
"labelleft labelright labelbottom labeltop geo".split(), lonarray, latarray
):
sides[side] = None
if lon and lat:
sides[side] = True
elif lon:
Expand Down
78 changes: 42 additions & 36 deletions ultraplot/gridspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1536,42 +1536,48 @@ def __getitem__(self, key):
>>> axs[1, 2] # the subplot in the second row, third column
>>> axs[:, 0] # a SubplotGrid containing the subplots in the first column
"""
if isinstance(key, tuple) and len(key) == 1:
key = key[0]
# List-style indexing
if isinstance(key, (Integral, slice)):
slices = isinstance(key, slice)
objs = list.__getitem__(self, key)
# Gridspec-style indexing
elif (
isinstance(key, tuple)
and len(key) == 2
and all(isinstance(ikey, (Integral, slice)) for ikey in key)
):
# WARNING: Permit no-op slicing of empty grids here
slices = any(isinstance(ikey, slice) for ikey in key)
objs = []
if self:
gs = self.gridspec
ss_key = gs._make_subplot_spec(key) # obfuscates panels
row1_key, col1_key = divmod(ss_key.num1, gs.ncols)
row2_key, col2_key = divmod(ss_key.num2, gs.ncols)
for ax in self:
ss = ax._get_topmost_axes().get_subplotspec().get_topmost_subplotspec()
row1, col1 = divmod(ss.num1, gs.ncols)
row2, col2 = divmod(ss.num2, gs.ncols)
inrow = row1_key <= row1 <= row2_key or row1_key <= row2 <= row2_key
incol = col1_key <= col1 <= col2_key or col1_key <= col2 <= col2_key
if inrow and incol:
objs.append(ax)
if not slices and len(objs) == 1: # accounts for overlapping subplots
objs = objs[0]
else:
raise IndexError(f"Invalid index {key!r}.")
if isinstance(objs, list):
return SubplotGrid(objs)
else:
return objs
# Allow 1D list-like indexing
if isinstance(key, int):
return list.__getitem__(self, key)
elif isinstance(key, slice):
return SubplotGrid(list.__getitem__(self, key))

# Allow 2D array-like indexing
# NOTE: We assume this is a 2D array of subplots, because this is
# how it is generated in the first place by ultraplot.figure().
# But it is possible to append subplots manually.
gs = self.gridspec
if gs is None:
raise IndexError(
f"{self.__class__.__name__} has no gridspec, cannot index with {key!r}."
)
# Build grid with None for empty slots
grid = np.full((gs.nrows_total, gs.ncols_total), None, dtype=object)
for ax in self:
spec = ax.get_subplotspec()
x1, x2, y1, y2 = spec._get_rows_columns(ncols=gs.ncols_total)
grid[x1 : x2 + 1, y1 : y2 + 1] = ax

new_key = []
for which, keyi in zip("hw", key):
try:
encoded_keyi = gs._encode_indices(keyi, which=which)
except:
raise IndexError(
f"Attempted to access {key=} for gridspec {grid.shape=}"
)
new_key.append(encoded_keyi)
xs, ys = new_key
objs = grid[xs, ys]
if hasattr(objs, "flat"):
objs = [obj for obj in objs.flat if obj is not None]
elif not isinstance(objs, list):
objs = [objs]

if len(objs) == 1:
return objs[0]
objs = [obj for obj in objs if obj is not None]
return SubplotGrid(objs)

def __setitem__(self, key, value):
"""
Expand Down
29 changes: 26 additions & 3 deletions ultraplot/tests/test_geographic.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,8 +314,8 @@ def test_toggle_gridliner_labels():
gl = ax[0].gridlines_major

assert gl.left_labels == False
assert gl.right_labels == None # initially these are none
assert gl.top_labels == None
assert gl.right_labels == False
assert gl.top_labels == False
assert gl.bottom_labels == False
ax[0]._toggle_gridliner_labels(labeltop=True)
assert gl.top_labels == True
Expand Down Expand Up @@ -617,7 +617,7 @@ def test_cartesian_and_geo(rng):
ax[0].pcolormesh(rng.random((10, 10)))
ax[1].scatter(*rng.random((2, 100)))
ax[0]._apply_axis_sharing()
assert mocked.call_count == 1
assert mocked.call_count == 2
return fig


Expand Down Expand Up @@ -895,3 +895,26 @@ def test_imshow_with_and_without_transform(rng):
ax[2].imshow(data, transform=uplt.axes.geo.ccrs.PlateCarree())
ax.format(title=["LCC", "No transform", "PlateCarree"])
return fig


@pytest.mark.mpl_image_compare
def test_grid_indexing_formatting(rng):
"""
Check if subplotgrid is correctly selecting
the subplots based on non-shared axis formatting
"""
# See https://github.com/Ultraplot/UltraPlot/issues/356
lon = np.arange(0, 360, 10)
lat = np.arange(-60, 60 + 1, 10)
data = rng.random((len(lat), len(lon)))

fig, axs = uplt.subplots(nrows=3, ncols=2, proj="cyl", share=0)
axs.format(coast=True)

for ax in axs:
m = ax.pcolor(lon, lat, data)
ax.colorbar(m)

axs[-1, :].format(lonlabels=True)
axs[:, 0].format(latlabels=True)
return fig
13 changes: 13 additions & 0 deletions ultraplot/tests/test_subplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,3 +314,16 @@ def test_panel_sharing_top_right(layout):
# The sharing axis is not showing any ticks
assert ax[0]._is_ticklabel_on(dir) == False
return fig


@pytest.mark.mpl_image_compare
def test_uneven_span_subplots(rng):
fig = uplt.figure(refwidth=1, refnum=5, span=False)
axs = fig.subplots([[1, 1, 2], [3, 4, 2], [3, 4, 5]], hratios=[2.2, 1, 1])
axs.format(xlabel="xlabel", ylabel="ylabel", suptitle="Complex SubplotGrid")
axs[0].format(ec="black", fc="gray1", lw=1.4)
axs[1, 1:].format(fc="blush")
axs[1, :1].format(fc="sky blue")
axs[-1, -1].format(fc="gray4", grid=False)
axs[0].plot((rng.random((50, 10)) - 0.5).cumsum(axis=0), cycle="Grays_r", lw=2)
return fig