Skip to content

Commit

Permalink
MNT: improve image array argument checking in to_rgba. Closes #2499.
Browse files Browse the repository at this point in the history
  • Loading branch information
efiring committed Mar 7, 2016
1 parent 8433eec commit c2f91c5
Showing 1 changed file with 16 additions and 11 deletions.
27 changes: 16 additions & 11 deletions lib/matplotlib/cm.py
Expand Up @@ -221,6 +221,9 @@ def to_rgba(self, x, alpha=None, bytes=False, norm=True):
If *x* is an ndarray with 3 dimensions,
and the last dimension is either 3 or 4, then it will be
treated as an rgb or rgba array, and no mapping will be done.
The array can be uint8, or it can be floating point with
values in the 0-1 range; otherwise a ValueError will be raised.
If it is a masked array, the mask will be ignored.
If the last dimension is 3, the *alpha* kwarg (defaulting to 1)
will be used to fill in the transparency. If the last dimension
is 4, the *alpha* kwarg is ignored; it does not
Expand All @@ -232,12 +235,8 @@ def to_rgba(self, x, alpha=None, bytes=False, norm=True):
the returned rgba array will be uint8 in the 0 to 255 range.
If norm is False, no normalization of the input data is
performed, and it is assumed to already be in the range (0-1).
performed, and it is assumed to be in the range (0-1).
Note: this method assumes the input is well-behaved; it does
not check for anomalies such as *x* being a masked rgba
array, or being an integer type other than uint8, or being
a floating point rgba array with values outside the 0-1 range.
"""
# First check for special case, image input:
try:
Expand All @@ -255,12 +254,18 @@ def to_rgba(self, x, alpha=None, bytes=False, norm=True):
xx = x
else:
raise ValueError("third dimension must be 3 or 4")
if xx.dtype != np.uint8 and xx.max() > 1:
raise ValueError("image must be either uint8 or in the 0..1 range")
if bytes and xx.dtype != np.uint8:
xx = (xx * 255).astype(np.uint8)
if not bytes and xx.dtype == np.uint8:
xx = xx.astype(float) / 255
if xx.dtype.kind == 'f':
if xx.max() > 1 or xx.min() < 0:
raise ValueError("Floating point image RGB values "
"must be in the 0..1 range.")
if bytes:
xx = (xx * 255).astype(np.uint8)
elif xx.dtype == np.uint8:
if not bytes:
xx = xx.astype(float) / 255
else:
raise ValueError("Image RGB array must be uint8 or "
"floating point; found %s" % xx.dtype)
return xx
except AttributeError:
# e.g., x is not an ndarray; so try mapping it
Expand Down

0 comments on commit c2f91c5

Please sign in to comment.