Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move impl. of plt.subplots to Figure.add_subplots. #5146

Merged
merged 3 commits into from Nov 22, 2015
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
137 changes: 137 additions & 0 deletions lib/matplotlib/figure.py
Expand Up @@ -39,6 +39,7 @@

from matplotlib.axes import Axes, SubplotBase, subplot_class_factory
from matplotlib.blocking_input import BlockingMouseInput, BlockingKeyMouseInput
from matplotlib.gridspec import GridSpec
from matplotlib.legend import Legend
from matplotlib.patches import Rectangle
from matplotlib.projections import (get_projection_names,
Expand Down Expand Up @@ -1001,6 +1002,142 @@ def add_subplot(self, *args, **kwargs):
self.stale = True
return a

def add_subplots(self, nrows=1, ncols=1, sharex=False, sharey=False,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you leave the name as subplots?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a bit like Figure.suptitle: I don't like it because I don't see why it should be fig.add_subplot but fig.subplots, just like I don't see why it should be ax.set_title but fig.suptitle.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

but, the name subplots has been around for a while now and it is better to not rename functionality if we don't have to. From the point of view of the user moving from pyplot -> OO you don't want to make them re-learn the names of things.

squeeze=True, subplot_kw=None, gridspec_kw=None):
"""
Add a set of subplots to this figure.

Parameters
----------
nrows : int, default: 1
Number of rows of the subplot grid.

ncols : int, default: 1
Number of columns of the subplot grid.

sharex : {"none", "all", "row", "col"} or bool, default: False
If *False*, or "none", each subplot has its own X axis.

If *True*, or "all", all subplots will share an X axis, and the x
tick labels on all but the last row of plots will be invisible.

If "col", each subplot column will share an X axis, and the x
tick labels on all but the last row of plots will be invisible.

If "row", each subplot row will share an X axis.

sharey : {"none", "all", "row", "col"} or bool, default: False
If *False*, or "none", each subplot has its own Y axis.

If *True*, or "all", all subplots will share an Y axis, and the x
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here and below: y tick labels, not x.

tick labels on all but the first column of plots will be invisible.

If "row", each subplot row will share an Y axis, and the x tick
labels on all but the first column of plots will be invisible.

If "col", each subplot column will share an Y axis.

squeeze : bool, default: True
If *True*, extra dimensions are squeezed out from the returned axes
array:

- if only one subplot is constructed (nrows=ncols=1), the resulting
single Axes object is returned as a scalar.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The docs build fails because this needs to be indented to be valid RST i.e.

- if only one subplot is constructed (nrows=ncols=1), the resulting
  single Axes object is returned as a scalar.


- for Nx1 or 1xN subplots, the returned object is a 1-d numpy
object array of Axes objects are returned as numpy 1-d arrays.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As above


- for NxM subplots with N>1 and M>1 are returned as a 2d array.

If *False*, no squeezing at all is done: the returned axes object
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Delete 'axes'; -> 'the returned object'

is always a 2-d array of Axes instances, even if it ends up being
1x1.

subplot_kw : dict, default: {}
Dict with keywords passed to the
:meth:`~matplotlib.figure.Figure.add_subplot` call used to create
each subplots.

gridspec_kw : dict, default: {}
Dict with keywords passed to the
:class:`~matplotlib.gridspec.GridSpec` constructor used to create
the grid the subplots are placed on.

Returns
-------
ax : single Axes object or array of Axes objects
The addes axes. The dimensions of the resulting array can be
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

typo -> 'adds'

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed

controlled with the squeeze keyword, see above.

See Also
--------
pyplot.subplots : pyplot API; docstring includes examples.
"""

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we throw a self.clf() in here?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd rather not, IMO if you use the OO interface you know what you're doing and it's anyways easier for the caller to add a call the clf rather than to remove a call to clf.

# for backwards compatibility
if isinstance(sharex, bool):
sharex = "all" if sharex else "none"
if isinstance(sharey, bool):
sharey = "all" if sharey else "none"
share_values = ["all", "row", "col", "none"]
if sharex not in share_values:
# This check was added because it is very easy to type
# `subplots(1, 2, 1)` when `subplot(1, 2, 1)` was intended.
# In most cases, no error will ever occur, but mysterious behavior
# will result because what was intended to be the subplot index is
# instead treated as a bool for sharex.
if isinstance(sharex, int):
warnings.warn(
"sharex argument to add_subplots() was an integer. "
"Did you intend to use add_subplot() (without 's')?")

raise ValueError("sharex [%s] must be one of %s" %
(sharex, share_values))
if sharey not in share_values:
raise ValueError("sharey [%s] must be one of %s" %
(sharey, share_values))
if subplot_kw is None:
subplot_kw = {}
if gridspec_kw is None:
gridspec_kw = {}

gs = GridSpec(nrows, ncols, **gridspec_kw)

# Create array to hold all axes.
axarr = np.empty((nrows, ncols), dtype=object)
for row in range(nrows):
for col in range(ncols):
shared_with = {"none": None, "all": axarr[0, 0],
"row": axarr[row, 0], "col": axarr[0, col]}
subplot_kw["sharex"] = shared_with[sharex]
subplot_kw["sharey"] = shared_with[sharey]
axarr[row, col] = self.add_subplot(gs[row, col], **subplot_kw)

# turn off redundant tick labeling
if sharex in ["col", "all"] and nrows > 1:
# turn off all but the bottom row
for ax in axarr[:-1, :].flat:
for label in ax.get_xticklabels():
label.set_visible(False)
ax.xaxis.offsetText.set_visible(False)

if sharey in ["row", "all"] and ncols > 1:
# turn off all but the first column
for ax in axarr[:, 1:].flat:
for label in ax.get_yticklabels():
label.set_visible(False)
ax.yaxis.offsetText.set_visible(False)

if squeeze:
# Reshape the array to have the final desired dimension (nrow,ncol),
# though discarding unneeded dimensions that equal 1. If we only have
# one subplot, just return it instead of a 1-element array.
return axarr.item() if axarr.size == 1 else axarr.squeeze()
else:
# returned axis array will be always 2-d, even if nrows=ncols=1
return axarr


def clf(self, keep_observers=False):
"""
Clear the figure.
Expand Down
103 changes: 4 additions & 99 deletions lib/matplotlib/pyplot.py
Expand Up @@ -1131,106 +1131,11 @@ def subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
# same as
plt.subplots(2, 2, sharex=True, sharey=True)
"""
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we eliminate a lot of this docstring? Either with @_autogen_docstring or %(Figure.subplots)s or something?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure how to make this work, given that the APIs are subtly different (one returns just the new axes, the other returns a figure, axes pair). Open to suggestions...

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not super worried about the duplication of docstrings at this point. An interesting follow on to this PR would be to generate plt.subplots() via boilerplate.py.

# for backwards compatibility
if isinstance(sharex, bool):
if sharex:
sharex = "all"
else:
sharex = "none"
if isinstance(sharey, bool):
if sharey:
sharey = "all"
else:
sharey = "none"
share_values = ["all", "row", "col", "none"]
if sharex not in share_values:
# This check was added because it is very easy to type
# `subplots(1, 2, 1)` when `subplot(1, 2, 1)` was intended.
# In most cases, no error will ever occur, but mysterious behavior will
# result because what was intended to be the subplot index is instead
# treated as a bool for sharex.
if isinstance(sharex, int):
warnings.warn("sharex argument to subplots() was an integer."
" Did you intend to use subplot() (without 's')?")

raise ValueError("sharex [%s] must be one of %s" %
(sharex, share_values))
if sharey not in share_values:
raise ValueError("sharey [%s] must be one of %s" %
(sharey, share_values))
if subplot_kw is None:
subplot_kw = {}
if gridspec_kw is None:
gridspec_kw = {}

fig = figure(**fig_kw)
gs = GridSpec(nrows, ncols, **gridspec_kw)

# Create empty object array to hold all axes. It's easiest to make it 1-d
# so we can just append subplots upon creation, and then
nplots = nrows*ncols
axarr = np.empty(nplots, dtype=object)

# Create first subplot separately, so we can share it if requested
ax0 = fig.add_subplot(gs[0, 0], **subplot_kw)
axarr[0] = ax0

r, c = np.mgrid[:nrows, :ncols]
r = r.flatten() * ncols
c = c.flatten()
lookup = {
"none": np.arange(nplots),
"all": np.zeros(nplots, dtype=int),
"row": r,
"col": c,
}
sxs = lookup[sharex]
sys = lookup[sharey]

# Note off-by-one counting because add_subplot uses the MATLAB 1-based
# convention.
for i in range(1, nplots):
if sxs[i] == i:
subplot_kw['sharex'] = None
else:
subplot_kw['sharex'] = axarr[sxs[i]]
if sys[i] == i:
subplot_kw['sharey'] = None
else:
subplot_kw['sharey'] = axarr[sys[i]]
axarr[i] = fig.add_subplot(gs[i // ncols, i % ncols], **subplot_kw)

# returned axis array will be always 2-d, even if nrows=ncols=1
axarr = axarr.reshape(nrows, ncols)

# turn off redundant tick labeling
if sharex in ["col", "all"] and nrows > 1:
# turn off all but the bottom row
for ax in axarr[:-1, :].flat:
for label in ax.get_xticklabels():
label.set_visible(False)
ax.xaxis.offsetText.set_visible(False)

if sharey in ["row", "all"] and ncols > 1:
# turn off all but the first column
for ax in axarr[:, 1:].flat:
for label in ax.get_yticklabels():
label.set_visible(False)
ax.yaxis.offsetText.set_visible(False)

if squeeze:
# Reshape the array to have the final desired dimension (nrow,ncol),
# though discarding unneeded dimensions that equal 1. If we only have
# one subplot, just return it instead of a 1-element array.
if nplots == 1:
ret = fig, axarr[0, 0]
else:
ret = fig, axarr.squeeze()
else:
# returned axis array will be always 2-d, even if nrows=ncols=1
ret = fig, axarr.reshape(nrows, ncols)

return ret
axs = fig.add_subplots(
nrows=nrows, ncols=ncols, sharex=sharex, sharey=sharey, squeeze=squeeze,
subplot_kw=subplot_kw, gridspec_kw=gridspec_kw)
return fig, axs


def subplot2grid(shape, loc, rowspan=1, colspan=1, **kwargs):
Expand Down