Skip to content

Commit bc955f9

Browse files
committed
axes_grid1: ImageGrid respect the aspect ratio of axes. This addresses matplotlib#955
1 parent 083c788 commit bc955f9

File tree

3 files changed

+79
-21
lines changed

3 files changed

+79
-21
lines changed
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
import matplotlib.pyplot as plt
2+
3+
from mpl_toolkits.axes_grid1 import ImageGrid
4+
fig = plt.figure(1)
5+
6+
grid1 = ImageGrid(fig, 121, (2,2), axes_pad=0.1,
7+
aspect=True, share_all=True)
8+
9+
for i in [0, 1]:
10+
grid1[i].set_aspect(2)
11+
12+
13+
grid2 = ImageGrid(fig, 122, (2,2), axes_pad=0.1,
14+
aspect=True, share_all=True)
15+
16+
17+
for i in [1, 3]:
18+
grid2[i].set_aspect(2)
19+
20+
plt.show()

lib/mpl_toolkits/axes_grid1/axes_grid.py

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,9 @@ def __init__(self, fig,
544544
else:
545545
axes_class, axes_class_args = axes_class
546546

547+
adjustable = axes_class_args.setdefault("adjustable", "box-forced")
548+
if adjustable != "box-forced":
549+
raise RuntimeError("adjustable parameter must not be set, or set to box-forced")
547550

548551

549552
self.axes_all = []
@@ -582,28 +585,31 @@ def __init__(self, fig,
582585
col, row = self._get_col_row(i)
583586

584587
if share_all:
585-
sharex = self._refax
586-
sharey = self._refax
588+
if self.axes_all:
589+
sharex = self.axes_all[0]
590+
sharey = self.axes_all[0]
591+
else:
592+
sharex = None
593+
sharey = None
587594
else:
588595
sharex = self._column_refax[col]
589596
sharey = self._row_refax[row]
590597

591598
ax = axes_class(fig, rect, sharex=sharex, sharey=sharey,
592599
**axes_class_args)
593600

594-
if share_all:
595-
if self._refax is None:
596-
self._refax = ax
597-
else:
598-
if sharex is None:
599-
self._column_refax[col] = ax
600-
if sharey is None:
601-
self._row_refax[row] = ax
602-
603601
self.axes_all.append(ax)
604602
self.axes_column[col].append(ax)
605603
self.axes_row[row].append(ax)
606604

605+
if share_all:
606+
if self._refax is None:
607+
self._refax = ax
608+
if sharex is None:
609+
self._column_refax[col] = ax
610+
if sharey is None:
611+
self._row_refax[row] = ax
612+
607613
cax = self._defaultCbarAxesClass(fig, rect,
608614
orientation=self._colorbar_location)
609615
self.cbar_axes.append(cax)
@@ -653,13 +659,14 @@ def _update_locators(self):
653659
self.cbar_axes[0].set_axes_locator(locator)
654660
self.cbar_axes[0].set_visible(True)
655661

656-
for col,ax in enumerate(self._column_refax):
662+
for col,ax in enumerate(self.axes_row[0]):
657663
if h: h.append(self._horiz_pad_size) #Size.Fixed(self._axes_pad))
658664

659665
if ax:
660-
sz = Size.AxesX(ax)
666+
sz = Size.AxesX(ax, aspect="axes", ref_ax=self.axes_all[0])
661667
else:
662-
sz = Size.AxesX(self.axes_llc)
668+
sz = Size.AxesX(self.axes_all[0],
669+
aspect="axes", ref_ax=self.axes_all[0])
663670

664671
if (self._colorbar_mode == "each" or
665672
(self._colorbar_mode == 'edge' and
@@ -682,13 +689,14 @@ def _update_locators(self):
682689

683690
v_ax_pos = []
684691
v_cb_pos = []
685-
for row,ax in enumerate(self._row_refax[::-1]):
692+
for row,ax in enumerate(self.axes_column[0][::-1]):
686693
if v: v.append(self._horiz_pad_size) #Size.Fixed(self._axes_pad))
687694

688695
if ax:
689-
sz = Size.AxesY(ax)
696+
sz = Size.AxesY(ax, aspect="axes", ref_ax=self.axes_all[0])
690697
else:
691-
sz = Size.AxesY(self.axes_llc)
698+
sz = Size.AxesY(self.axes_all[0],
699+
aspect="axes", ref_ax=self.axes_all[0])
692700

693701
if (self._colorbar_mode == "each" or
694702
(self._colorbar_mode == 'edge' and

lib/mpl_toolkits/axes_grid1/axes_size.py

Lines changed: 34 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -73,19 +73,39 @@ def get_size(self, renderer):
7373

7474
Scalable=Scaled
7575

76+
def _get_axes_aspect(ax):
77+
aspect = ax.get_aspect()
78+
# when aspec is "auto", consider it as 1.
79+
if aspect in ('normal', 'auto'):
80+
aspect = 1.
81+
elif aspect == "equal":
82+
aspect = 1
83+
else:
84+
aspect = float(aspect)
85+
86+
return aspect
7687

7788
class AxesX(_Base):
7889
"""
7990
Scaled size whose relative part corresponds to the data width
8091
of the *axes* multiplied by the *aspect*.
8192
"""
82-
def __init__(self, axes, aspect=1.):
93+
def __init__(self, axes, aspect=1., ref_ax=None):
8394
self._axes = axes
8495
self._aspect = aspect
96+
if aspect == "axes" and ref_ax is None:
97+
raise ValueError("ref_ax must be set when aspect='axes'")
98+
self._ref_ax = ref_ax
8599

86100
def get_size(self, renderer):
87101
l1, l2 = self._axes.get_xlim()
88-
rel_size = abs(l2-l1)*self._aspect
102+
if self._aspect == "axes":
103+
ref_aspect = _get_axes_aspect(self._ref_ax)
104+
aspect = ref_aspect/_get_axes_aspect(self._axes)
105+
else:
106+
aspect = self._aspect
107+
108+
rel_size = abs(l2-l1)*aspect
89109
abs_size = 0.
90110
return rel_size, abs_size
91111

@@ -94,13 +114,23 @@ class AxesY(_Base):
94114
Scaled size whose relative part corresponds to the data height
95115
of the *axes* multiplied by the *aspect*.
96116
"""
97-
def __init__(self, axes, aspect=1.):
117+
def __init__(self, axes, aspect=1., ref_ax=None):
98118
self._axes = axes
99119
self._aspect = aspect
120+
if aspect == "axes" and ref_ax is None:
121+
raise ValueError("ref_ax must be set when aspect='axes'")
122+
self._ref_ax = ref_ax
100123

101124
def get_size(self, renderer):
102125
l1, l2 = self._axes.get_ylim()
103-
rel_size = abs(l2-l1)*self._aspect
126+
127+
if self._aspect == "axes":
128+
ref_aspect = _get_axes_aspect(self._ref_ax)
129+
aspect = _get_axes_aspect(self._axes)
130+
else:
131+
aspect = self._aspect
132+
133+
rel_size = abs(l2-l1)*aspect
104134
abs_size = 0.
105135
return rel_size, abs_size
106136

0 commit comments

Comments
 (0)