Skip to content

Commit 2c91aa6

Browse files
committed
Merge pull request matplotlib#1170 from leejjoon/fix-tight-layout-condition
Uses tight_layout.get_subplotspec_list to check if all axes are compatible w/ tight_layout
2 parents f8368d0 + 1d80e7f commit 2c91aa6

File tree

2 files changed

+48
-23
lines changed

2 files changed

+48
-23
lines changed

lib/matplotlib/figure.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -1487,17 +1487,20 @@ def tight_layout(self, renderer=None, pad=1.08, h_pad=None, w_pad=None, rect=Non
14871487
labels) will fit into. Default is (0, 0, 1, 1).
14881488
"""
14891489

1490-
from tight_layout import get_renderer, get_tight_layout_figure
1490+
from tight_layout import (get_renderer, get_tight_layout_figure,
1491+
get_subplotspec_list)
14911492

1492-
subplot_axes = [ax for ax in self.axes if isinstance(ax, SubplotBase)]
1493-
if len(subplot_axes) < len(self.axes):
1494-
warnings.warn("tight_layout can only process Axes that descend "
1495-
"from SubplotBase; results might be incorrect.")
1493+
subplotspec_list = get_subplotspec_list(self.axes)
1494+
if None in subplotspec_list:
1495+
warnings.warn("This figure includes Axes that are not "
1496+
"compatible with tight_layout, so its "
1497+
"results might be incorrect.")
14961498

14971499
if renderer is None:
14981500
renderer = get_renderer(self)
14991501

1500-
kwargs = get_tight_layout_figure(self, subplot_axes, renderer,
1502+
kwargs = get_tight_layout_figure(self, self.axes, subplotspec_list,
1503+
renderer,
15011504
pad=pad, h_pad=h_pad, w_pad=w_pad,
15021505
rect=rect)
15031506

lib/matplotlib/tight_layout.py

+39-17
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,33 @@ def get_renderer(fig):
209209
return renderer
210210

211211

212-
def get_tight_layout_figure(fig, axes_list, renderer,
212+
def get_subplotspec_list(axes_list):
213+
"""
214+
Return a list of subplotspec from the given list of axes. For an
215+
instance of axes that does not support subplotspec, None is
216+
inserted in the list.
217+
218+
"""
219+
subplotspec_list = []
220+
for ax in axes_list:
221+
axes_or_locator = ax.get_axes_locator()
222+
if axes_or_locator is None:
223+
axes_or_locator = ax
224+
225+
if hasattr(axes_or_locator, "get_subplotspec"):
226+
subplotspec = axes_or_locator.get_subplotspec()
227+
subplotspec = subplotspec.get_topmost_subplotspec()
228+
if subplotspec.get_gridspec().locally_modified_subplot_params():
229+
subplotspec = None
230+
else:
231+
subplotspec = None
232+
233+
subplotspec_list.append(subplotspec)
234+
235+
return subplotspec_list
236+
237+
238+
def get_tight_layout_figure(fig, axes_list, subplotspec_list, renderer,
213239
pad=1.08, h_pad=None, w_pad=None, rect=None):
214240
"""
215241
Return subplot parameters for tight-layouted-figure with specified
@@ -221,6 +247,9 @@ def get_tight_layout_figure(fig, axes_list, renderer,
221247
222248
*axes_list* : a list of axes
223249
250+
*subplotspec_list* : a list of subplotspec associated with each
251+
axes in axes_list
252+
224253
*renderer* : renderer instance
225254
226255
*pad* : float
@@ -238,27 +267,20 @@ def get_tight_layout_figure(fig, axes_list, renderer,
238267
"""
239268

240269

241-
subplotspec_list = []
242270
subplot_list = []
243271
nrows_list = []
244272
ncols_list = []
245273
ax_bbox_list = []
246274

247-
subplot_dict = {} # for axes_grid1, multiple axes can share
248-
# same subplot_interface. Thus we need to
249-
# join them together.
275+
subplot_dict = {} # multiple axes can share
276+
# same subplot_interface (e.g, axes_grid1). Thus
277+
# we need to join them together.
250278

251-
for ax in axes_list:
252-
locator = ax.get_axes_locator()
253-
if hasattr(locator, "get_subplotspec"):
254-
subplotspec = locator.get_subplotspec().get_topmost_subplotspec()
255-
elif hasattr(ax, "get_subplotspec"):
256-
subplotspec = ax.get_subplotspec().get_topmost_subplotspec()
257-
else:
258-
continue
279+
subplotspec_list2 = []
259280

260-
if (subplotspec is None) or \
261-
subplotspec.get_gridspec().locally_modified_subplot_params():
281+
for ax, subplotspec in zip(axes_list,
282+
subplotspec_list):
283+
if subplotspec is None:
262284
continue
263285

264286
subplots = subplot_dict.setdefault(subplotspec, [])
@@ -267,7 +289,7 @@ def get_tight_layout_figure(fig, axes_list, renderer,
267289
myrows, mycols, _, _ = subplotspec.get_geometry()
268290
nrows_list.append(myrows)
269291
ncols_list.append(mycols)
270-
subplotspec_list.append(subplotspec)
292+
subplotspec_list2.append(subplotspec)
271293
subplot_list.append(subplots)
272294
ax_bbox_list.append(subplotspec.get_position(fig))
273295

@@ -277,7 +299,7 @@ def get_tight_layout_figure(fig, axes_list, renderer,
277299
max_ncols = max(ncols_list)
278300

279301
num1num2_list = []
280-
for subplotspec in subplotspec_list:
302+
for subplotspec in subplotspec_list2:
281303
rows, cols, num1, num2 = subplotspec.get_geometry()
282304
div_row, mod_row = divmod(max_nrows, rows)
283305
div_col, mod_col = divmod(max_ncols, cols)

0 commit comments

Comments
 (0)