Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add smart axis label visibility to wcsaxes #6774

Merged
merged 8 commits into from Jan 23, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES.rst
Expand Up @@ -467,6 +467,10 @@ astropy.visualization
``WCSPixel2WorldTransform`` classes by setting ``has_inverse`` to ``True``.
In order to implement a unit test, also implement the equality comparison
operator for both classes. [#6531]
- Added automatic hiding of axes labels when no tick labels are drawn on that
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Space was missing here...

Copy link
Member

@eteq eteq Jan 24, 2018

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

FYI, I fixed this directly in master (2ece61c) since it's dead trivial and this got merged already

axis. This parameter can be configured with
``WCSAxes.coords[*].set_axislabel_visibility_rule`` so that labels are automatically
hidden when no ticks are drawn or always shown. [#6774]

astropy.wcs
^^^^^^^^^^^
Expand Down
50 changes: 38 additions & 12 deletions astropy/visualization/wcsaxes/axislabels.py
Expand Up @@ -19,6 +19,7 @@ def __init__(self, frame, minpad=1, *args, **kwargs):
self.set_ha('center')
self.set_va('center')
self._minpad = minpad
self._visibility_rule = 'labels'

def get_minpad(self, axis):
try:
Expand All @@ -38,7 +39,18 @@ def get_visible_axes(self):
def set_minpad(self, minpad):
self._minpad = minpad

def draw(self, renderer, bboxes, ticklabels_bbox_list, visible_ticks):
def set_visibility_rule(self, value):
allowed = ['always', 'labels', 'ticks']
if value not in allowed:
raise ValueError("Axis label visibility rule must be one of{}".format(' / '.join(allowed)))

self._visibility_rule = value

def get_visibility_rule(self):
return self._visibility_rule

def draw(self, renderer, bboxes, ticklabels_bbox,
coord_ticklabels_bbox, ticks_locs, visible_ticks):

if not self.get_visible():
return
Expand All @@ -47,6 +59,20 @@ def draw(self, renderer, bboxes, ticklabels_bbox_list, visible_ticks):

for axis in self.get_visible_axes():

# Flatten the bboxes for all coords and all axes
ticklabels_bbox_list = []
for bbcoord in ticklabels_bbox.values():
for bbaxis in bbcoord.values():
ticklabels_bbox_list += bbaxis

if self.get_visibility_rule() == 'ticks':
if not ticks_locs[axis]:
continue

elif self.get_visibility_rule() == 'labels':
if not coord_ticklabels_bbox:
continue

padding = text_size * self.get_minpad(axis)

# Find position of the axis label. For now we pick the mid-point
Expand Down Expand Up @@ -75,38 +101,38 @@ def draw(self, renderer, bboxes, ticklabels_bbox_list, visible_ticks):

if isinstance(self._frame, RectangularFrame):

if len(ticklabels_bbox_list) > 0:
ticklabels_bbox = mtransforms.Bbox.union(ticklabels_bbox_list)
if len(ticklabels_bbox_list) > 0 and ticklabels_bbox_list[0] is not None:
coord_ticklabels_bbox[axis] = [mtransforms.Bbox.union(ticklabels_bbox_list)]
else:
ticklabels_bbox = None
coord_ticklabels_bbox[axis] = [None]

if axis == 'l':
if axis in visible_ticks and ticklabels_bbox is not None:
left = ticklabels_bbox.xmin
if axis in visible_ticks and coord_ticklabels_bbox[axis][0] is not None:
left = coord_ticklabels_bbox[axis][0].xmin
else:
left = xcen
xpos = left - padding
self.set_position((xpos, ycen))

elif axis == 'r':
if axis in visible_ticks and ticklabels_bbox is not None:
right = ticklabels_bbox.x1
if axis in visible_ticks and coord_ticklabels_bbox[axis][0] is not None:
right = coord_ticklabels_bbox[axis][0].x1
else:
right = xcen
xpos = right + padding
self.set_position((xpos, ycen))

elif axis == 'b':
if axis in visible_ticks and ticklabels_bbox is not None:
bottom = ticklabels_bbox.ymin
if axis in visible_ticks and coord_ticklabels_bbox[axis][0] is not None:
bottom = coord_ticklabels_bbox[axis][0].ymin
else:
bottom = ycen
ypos = bottom - padding
self.set_position((xcen, ypos))

elif axis == 't':
if axis in visible_ticks and ticklabels_bbox is not None:
top = ticklabels_bbox.y1
if axis in visible_ticks and coord_ticklabels_bbox[axis][0] is not None:
top = coord_ticklabels_bbox[axis][0].y1
else:
top = ycen
ypos = top + padding
Expand Down
30 changes: 26 additions & 4 deletions astropy/visualization/wcsaxes/coordinate_helpers.py
Expand Up @@ -417,6 +417,26 @@ def set_axislabel_position(self, position):
"""
self.axislabels.set_visible_axes(position)

def set_axislabel_visibility_rule(self, rule):
"""
Set the rule used to determine when the axis label is drawn.

Parameters
----------
rule : str
If the rule is 'always' axis labels will always be drawn on the
axis. If the rule is 'ticks' the label will only be drawn if ticks
were drawn on that axis. If the rule is 'labels' the axis label
will only be drawn if tick labels were drawn on that axis.
"""
self.axislabels.set_visibility_rule(rule)

def get_axislabel_visibility_rule(self, rule):
"""
Get the rule used to determine when the axis label is drawn.
"""
return self.axislabels.get_visibility_rule()

@property
def locator(self):
return self._formatter_locator.locator
Expand Down Expand Up @@ -454,22 +474,24 @@ def _draw_grid(self, renderer):

renderer.close_group('grid lines')

def _draw_ticks(self, renderer, bboxes, ticklabels_bbox):
def _draw_ticks(self, renderer, bboxes, ticklabels_bbox, ticks_locs):

renderer.open_group('ticks')

self.ticks.draw(renderer)
self.ticks.draw(renderer, ticks_locs)
self.ticklabels.draw(renderer, bboxes=bboxes,
ticklabels_bbox=ticklabels_bbox)

renderer.close_group('ticks')

def _draw_axislabels(self, renderer, bboxes, ticklabels_bbox, visible_ticks):
def _draw_axislabels(self, renderer, bboxes, ticklabels_bbox, ticks_locs, visible_ticks):

renderer.open_group('axis labels')

self.axislabels.draw(renderer, bboxes=bboxes,
ticklabels_bbox_list=ticklabels_bbox,
ticklabels_bbox=ticklabels_bbox,
coord_ticklabels_bbox=ticklabels_bbox[self],
ticks_locs=ticks_locs,
visible_ticks=visible_ticks)

renderer.close_group('axis labels')
Expand Down
13 changes: 10 additions & 3 deletions astropy/visualization/wcsaxes/core.py
@@ -1,5 +1,7 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst

from functools import partial
from collections import defaultdict

import numpy as np

Expand Down Expand Up @@ -328,7 +330,10 @@ def draw(self, renderer, inframe=False):
# each coordinate axis. For now, just assume it covers the whole sky.

self._bboxes = []
self._ticklabels_bbox = []
# This generates a structure like [coords][axis] = [...]
ticklabels_bbox = defaultdict(partial(defaultdict, list))
ticks_locs = defaultdict(partial(defaultdict, list))

visible_ticks = []

for coords in self._all_coords:
Expand All @@ -341,14 +346,16 @@ def draw(self, renderer, inframe=False):

for coord in coords:
coord._draw_ticks(renderer, bboxes=self._bboxes,
ticklabels_bbox=self._ticklabels_bbox)
ticklabels_bbox=ticklabels_bbox[coord],
ticks_locs=ticks_locs[coord])
visible_ticks.extend(coord.ticklabels.get_visible_axes())

for coords in self._all_coords:

for coord in coords:
coord._draw_axislabels(renderer, bboxes=self._bboxes,
ticklabels_bbox=self._ticklabels_bbox,
ticklabels_bbox=ticklabels_bbox,
ticks_locs=ticks_locs[coord],
visible_ticks=visible_ticks)

self.coords.frame.draw(renderer)
Expand Down
60 changes: 60 additions & 0 deletions astropy/visualization/wcsaxes/tests/test_coordinate_helpers.py
@@ -1,7 +1,12 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
from ..core import WCSAxes
from .... import units as u
import matplotlib.pyplot as plt

from unittest.mock import patch

import pytest


def test_getaxislabel():

Expand All @@ -12,3 +17,58 @@ def test_getaxislabel():
ax.coords[1].set_axislabel("Y")
assert ax.coords[0].get_axislabel() == "X"
assert ax.coords[1].get_axislabel() == "Y"


@pytest.fixture
def ax():
fig = plt.figure()
ax = WCSAxes(fig, [0.1, 0.1, 0.8, 0.8], aspect='equal')
fig.add_axes(ax)

return ax


def assert_label_draw(ax, x_label, y_label):
ax.coords[0].set_axislabel("Label 1")
ax.coords[1].set_axislabel("Label 2")

with patch.object(ax.coords[0].axislabels, 'set_position') as pos1:
with patch.object(ax.coords[1].axislabels, 'set_position') as pos2:
ax.figure.canvas.draw()

assert pos1.call_count == x_label
assert pos2.call_count == y_label


def test_label_visibility_rules_default(ax):
assert_label_draw(ax, True, True)


def test_label_visibility_rules_label(ax):

ax.coords[0].set_ticklabel_visible(False)
ax.coords[1].set_ticks(values=[-9999]*u.deg)

assert_label_draw(ax, False, False)


def test_label_visibility_rules_ticks(ax):

ax.coords[0].set_axislabel_visibility_rule('ticks')
ax.coords[1].set_axislabel_visibility_rule('ticks')

ax.coords[0].set_ticklabel_visible(False)
ax.coords[1].set_ticks(values=[-9999]*u.deg)

assert_label_draw(ax, True, False)


def test_label_visibility_rules_always(ax):

ax.coords[0].set_axislabel_visibility_rule('always')
ax.coords[1].set_axislabel_visibility_rule('always')

ax.coords[0].set_ticklabel_visible(False)
ax.coords[1].set_ticks(values=[-9999]*u.deg)

assert_label_draw(ax, True, True)
1 change: 1 addition & 0 deletions astropy/visualization/wcsaxes/tests/test_images.py
Expand Up @@ -439,6 +439,7 @@ def test_axislabels_regression(self):
ax = fig.add_axes([0.25, 0.25, 0.5, 0.5], projection=wcs, aspect='auto')
ax.coords[0].set_axislabel("Label 1")
ax.coords[1].set_axislabel("Label 2")
ax.coords[1].set_axislabel_visibility_rule('always')
ax.coords[1].ticklabels.set_visible(False)
return fig

Expand Down
2 changes: 1 addition & 1 deletion astropy/visualization/wcsaxes/ticklabels.py
Expand Up @@ -207,4 +207,4 @@ def draw(self, renderer, bboxes, ticklabels_bbox):
if not self._exclude_overlapping or bb.count_overlaps(bboxes) == 0:
super().draw(renderer)
bboxes.append(bb)
ticklabels_bbox.append(bb)
ticklabels_bbox[axis].append(bb)
10 changes: 6 additions & 4 deletions astropy/visualization/wcsaxes/ticks.py
Expand Up @@ -128,7 +128,7 @@ def __len__(self):

_tickvert_path = Path([[0., 0.], [1., 0.]])

def draw(self, renderer):
def draw(self, renderer, ticks_locs):
"""
Draw the ticks.
"""
Expand All @@ -137,12 +137,12 @@ def draw(self, renderer):
return

offset = renderer.points_to_pixels(self.get_ticksize())
self._draw_ticks(renderer, self.pixel, self.angle, offset)
self._draw_ticks(renderer, self.pixel, self.angle, offset, ticks_locs)
if self._display_minor_ticks:
offset = offset * 0.5 # for minor ticksize
self._draw_ticks(renderer, self.minor_pixel, self.minor_angle, offset)
self._draw_ticks(renderer, self.minor_pixel, self.minor_angle, offset, ticks_locs)

def _draw_ticks(self, renderer, pixel_array, angle_array, offset):
def _draw_ticks(self, renderer, pixel_array, angle_array, offset, ticks_locs):
"""
Draw the minor ticks.
"""
Expand Down Expand Up @@ -177,4 +177,6 @@ def _draw_ticks(self, renderer, pixel_array, angle_array, offset):
# Reset the tick rotation before moving to the next tick
marker_rotation.clear()

ticks_locs[axis].append(locs)

gc.restore()