Skip to content

Commit

Permalink
Allow avoiding PathCollection
Browse files Browse the repository at this point in the history
  • Loading branch information
Phlya committed Mar 6, 2019
1 parent f6f0066 commit a522ef7
Showing 1 changed file with 45 additions and 1 deletion.
46 changes: 45 additions & 1 deletion adjustText/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,51 @@
from itertools import product
import numpy as np
from operator import itemgetter
from matplotlib.path import get_path_collection_extents

if sys.version_info >= (3, 0):
xrange = range

def get_bboxes_pathcollection(sc, ax):
"""Function to return a list of bounding boxes in data coordinates
for a scatter plot
Thank you to ImportanceOfBeingErnest
https://stackoverflow.com/a/55007838/1304161"""
# ax.figure.canvas.draw() # need to draw before the transforms are set.
transform = sc.get_transform()
transOffset = sc.get_offset_transform()
offsets = sc._offsets
paths = sc.get_paths()
transforms = sc.get_transforms()

if not transform.is_affine:
paths = [transform.transform_path_non_affine(p) for p in paths]
transform = transform.get_affine()
if not transOffset.is_affine:
offsets = transOffset.transform_non_affine(offsets)
transOffset = transOffset.get_affine()

if isinstance(offsets, np.ma.MaskedArray):
offsets = offsets.filled(np.nan)

bboxes = []

if len(paths) and len(offsets):
if len(paths) < len(offsets):
# for usual scatters you have one path, but several offsets
paths = [paths[0]]*len(offsets)
if len(transforms) < len(offsets):
# often you may have a single scatter size, but several offsets
transforms = [transforms[0]]*len(offsets)

for p, o, t in zip(paths, offsets, transforms):
result = get_path_collection_extents(
transform.frozen(), [p], [t],
[o], transOffset.frozen())
bboxes.append(result.inverse_transformed(ax.transData))

return bboxes

def get_text_position(text, ax=None):
ax = ax or plt.gca()
x, y = text.get_position()
Expand All @@ -17,8 +58,11 @@ def get_text_position(text, ax=None):
def get_bboxes(objs, r, expand, ax):
if ax is None:
ax = plt.gca()
return [i.get_window_extent(r).expanded(*expand).transformed(ax.\
try:
return [i.get_window_extent(r).expanded(*expand).transformed(ax.\
transData.inverted()) for i in objs]
except TypeError:
return get_bboxes_pathcollection(objs, ax)

def get_midpoint(bbox):
cx = (bbox.x0+bbox.x1)/2
Expand Down

0 comments on commit a522ef7

Please sign in to comment.