Skip to content

Commit 5c1e64d

Browse files
authored
Merge pull request #6689 from story645/category
ENH: Str Categorical Axis Support
2 parents 94f2fb8 + 8a96281 commit 5c1e64d

File tree

10 files changed

+428
-566
lines changed

10 files changed

+428
-566
lines changed

lib/matplotlib/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1489,6 +1489,7 @@ def _jupyter_nbextension_paths():
14891489
'matplotlib.tests.test_backend_svg',
14901490
'matplotlib.tests.test_basic',
14911491
'matplotlib.tests.test_bbox_tight',
1492+
'matplotlib.tests.test_category',
14921493
'matplotlib.tests.test_cbook',
14931494
'matplotlib.tests.test_coding_standards',
14941495
'matplotlib.tests.test_collections',

lib/matplotlib/axes/_axes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import matplotlib.collections as mcoll
2222
import matplotlib.colors as mcolors
2323
import matplotlib.contour as mcontour
24+
import matplotlib.category as _ # <-registers a category unit converter
2425
import matplotlib.dates as _ # <-registers a date unit converter
2526
from matplotlib import docstring
2627
import matplotlib.image as mimage

lib/matplotlib/axis.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -662,6 +662,7 @@ def __init__(self, axes, pickradius=15):
662662
self.offsetText = self._get_offset_text()
663663
self.majorTicks = []
664664
self.minorTicks = []
665+
self.unit_data = []
665666
self.pickradius = pickradius
666667

667668
# Initialize here for testing; later add API
@@ -712,6 +713,17 @@ def _set_scale(self, value, **kwargs):
712713
def limit_range_for_scale(self, vmin, vmax):
    """Delegate range limiting to the scale attached to this axis.

    Forwards ``vmin``, ``vmax`` and the value of ``self.get_minpos()`` to
    ``self._scale.limit_range_for_scale`` and returns whatever it returns.
    """
    limiter = self._scale.limit_range_for_scale
    return limiter(vmin, vmax, self.get_minpos())
714715

716+
@property
def unit_data(self):
    """Data that a ConversionInterface subclass relies on
    to convert between labels and indexes.
    """
    return self._unit_data

@unit_data.setter
def unit_data(self, data):
    # Stored as-is; no validation or copying is performed here.
    self._unit_data = data
726+
715727
def get_children(self):
716728
children = [self.label, self.offsetText]
717729
majorticks = self.get_major_ticks()

lib/matplotlib/category.py

Lines changed: 151 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,151 @@
1+
# -*- coding: utf-8 -*-
2+
"""
3+
catch all for categorical functions
4+
"""
5+
from __future__ import (absolute_import, division, print_function,
6+
unicode_literals)
7+
8+
import six
9+
10+
import numpy as np
11+
12+
import matplotlib.units as units
13+
import matplotlib.ticker as ticker
14+
15+
16+
# pure hack for numpy 1.6 support
17+
from distutils.version import LooseVersion
18+
19+
NP_NEW = (LooseVersion(np.version.version) >= LooseVersion('1.7'))
20+
21+
22+
def to_array(data, maxlen=100):
    """Convert ``data`` to a numpy array of strings.

    Parameters
    ----------
    data : iterable or scalar
        values to convert
    maxlen : int, optional
        width of the byte-string dtype tried first on the old-numpy
        (``NP_NEW`` is False) code path; unused otherwise

    Returns
    -------
    numpy.ndarray
        array of text values
    """
    if NP_NEW:
        # ``np.unicode`` was only an alias of the builtin text type and is
        # no longer available in recent numpy releases; ``six.text_type``
        # names the very same type on both Python 2 and Python 3.
        return np.array(data, dtype=six.text_type)
    try:
        vals = np.array(data, dtype=('|S', maxlen))
    except UnicodeEncodeError:
        # pure hack: fall back to converting the values one by one
        vals = np.array([convert_to_string(d) for d in data])
    return vals
31+
32+
33+
class StrCategoryConverter(units.ConversionInterface):
    """Unit converter that encodes string categories as numbers using the
    ``(label, location)`` pairs stored in ``axis.unit_data``.
    """

    @staticmethod
    def convert(value, unit, axis):
        """Uses axis.unit_data map to encode
        data as floats

        Parameters
        ----------
        value : string or iterable
            label(s) to encode
        unit : object
            not used by this converter
        axis : object
            must expose ``unit_data``, an iterable of (label, location)

        Returns
        -------
        The stored location for a single string, otherwise a float
        ndarray with the same shape as ``to_array(value)``.

        Raises
        ------
        KeyError
            if ``value`` is a single string missing from ``unit_data``
        ValueError
            if an unmapped array entry cannot be parsed as a float
        """
        vmap = dict(axis.unit_data)

        if isinstance(value, six.string_types):
            return vmap[value]

        vals = to_array(value)

        # Encode into a separate float array.  Writing the locations back
        # into the string array would truncate them to the string width
        # (10.0 -> '1' in a one-character array) and would let an already
        # written location be matched again by a later label such as '0.0'.
        encoded = np.empty(vals.shape, dtype='float')
        seen = np.zeros(vals.shape, dtype=bool)
        for lab, loc in axis.unit_data:
            hits = vals == lab
            encoded[hits] = loc
            seen |= hits

        # entries without a mapping are parsed as plain numbers
        if not seen.all():
            encoded[~seen] = vals[~seen].astype('float')

        return encoded

    @staticmethod
    def axisinfo(unit, axis):
        """Build the tick locator and formatter from ``axis.unit_data``.

        ``axis.unit_data`` must hold at least one (label, location) pair,
        otherwise the unpacking below fails.
        """
        seq, locs = zip(*axis.unit_data)
        majloc = StrCategoryLocator(locs)
        majfmt = StrCategoryFormatter(seq)
        return units.AxisInfo(majloc=majloc, majfmt=majfmt)

    @staticmethod
    def default_units(data, axis):
        """Merge the categories found in ``data`` into ``axis.unit_data``.

        Always returns None.
        """
        # the conversion call stack is:
        # default_units->axis_info->convert
        axis.unit_data = map_categories(data, axis.unit_data)
        return None
63+
64+
65+
class StrCategoryLocator(ticker.FixedLocator):
    """Fixed locator placing ticks at the given category locations."""

    def __init__(self, locs):
        # no limit on the number of ticks
        super(StrCategoryLocator, self).__init__(locs, nbins=None)
68+
69+
70+
class StrCategoryFormatter(ticker.FixedFormatter):
    """Fixed formatter labelling ticks with the given category labels."""

    def __init__(self, seq):
        base = super(StrCategoryFormatter, self)
        base.__init__(seq)
73+
74+
75+
def convert_to_string(value):
    """Helper function for numpy 1.6, can be replaced with
    np.array(...,dtype=unicode) for all later versions of numpy

    Strings are returned untouched, finite numbers are turned into their
    numpy string form, and nan / +inf / -inf become the labels
    'nan' / 'inf' / '-inf'.
    """
    if isinstance(value, six.string_types):
        return value

    if np.isfinite(value):
        return np.asarray(value, dtype=str)[np.newaxis][0]
    if np.isnan(value):
        return 'nan'
    if np.isposinf(value):
        return 'inf'
    if np.isneginf(value):
        return '-inf'

    raise ValueError("Unconvertable {}".format(value))
92+
93+
94+
def map_categories(data, old_map=None):
    """Create mapping between unique categorical
    values and numerical identifier.

    Parameters
    ----------
    data: iterable
        sequence of values
    old_map: list of tuple, optional
        if not `None`, then old_map will be extended with new values and
        previous mappings will remain unchanged

    Returns
    -------
    list of tuple
        [(label, ticklocation),...]

    """

    # code typical missing data in the negative range because
    # everything else will always have positive encoding
    # questionable if it even makes sense
    spdict = {'nan': -1.0, 'inf': -2.0, '-inf': -3.0}

    if isinstance(data, six.string_types):
        data = [data]

    # will update this post cbook/dict support
    strdata = to_array(data)
    uniq = np.unique(strdata)

    if old_map:
        olabs, okeys = zip(*old_map)
        # Never start below zero: if old_map only holds special values
        # (e.g. just 'inf' -> -2.0) then max + 1 is negative and the new
        # labels would reuse a location reserved in spdict.
        svalue = max(max(okeys) + 1, 0)
    else:
        old_map, olabs, okeys = [], [], []
        svalue = 0

    # work on a copy so the caller's list is left untouched
    category_map = old_map[:]

    new_labs = [u for u in uniq if u not in olabs]

    # special labels get their reserved negative locations ...
    missing = [nl for nl in new_labs if nl in spdict]
    category_map.extend([(m, spdict[m]) for m in missing])

    # ... and every other new label gets the next free location
    new_labs = [nl for nl in new_labs if nl not in missing]
    new_locs = np.arange(svalue, svalue + len(new_labs), dtype='float')
    category_map.extend(list(zip(new_labs, new_locs)))
    return category_map
146+
147+
148+
# Connects the converter to matplotlib: one converter instance is
# registered for each string-like type, in the order str, bytes, text.
for _string_type in (str, bytes, six.text_type):
    units.registry[_string_type] = StrCategoryConverter()
del _string_type
Binary file not shown.
Binary file not shown.

0 commit comments

Comments
 (0)