Skip to content

Commit

Permalink
Merge a270b0d into d0afadc
Browse files Browse the repository at this point in the history
  • Loading branch information
burnpanck committed Jun 18, 2016
2 parents d0afadc + a270b0d commit 4f80670
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 7 deletions.
18 changes: 11 additions & 7 deletions numtraits.py
Expand Up @@ -25,7 +25,9 @@

from __future__ import print_function

from traitlets import TraitType, TraitError
import warnings

from traitlets import TraitType, TraitError, Undefined

import numpy as np

Expand All @@ -38,16 +40,18 @@
class NumericalTrait(TraitType):
info_text = 'a numerical trait, either a scalar or a vector'
def __init__(self, ndim=None, shape=None, domain=None,
default=None, convertible_to=None):
super(NumericalTrait, self).__init__()
default_value=Undefined, convertible_to=None, default=Undefined):
if default is not Undefined:
if default_value is not Undefined:
raise TypeError('Cannot set default and default_value simultaneously')
warnings.warn(DeprecationWarning('`default` has been renamed to `default_value`'))
default_value = default
super(NumericalTrait, self).__init__(default_value=default_value)

# Just store all the construction arguments.
self.ndim = ndim
self.shape = shape
self.domain = domain
# TODO: traitlets supports a `default` argument in __init__(), we should
# probably link them together once we start using this.
self.default = default
self.target_unit = convertible_to

if self.target_unit is not None:
Expand Down Expand Up @@ -97,7 +101,7 @@ def validate(self, obj, value):
if self.ndim is not None:

if self.ndim == 0:
if not is_scalar:
if not is_scalar and num_value.ndim:
raise TraitError("{0} should be a scalar value".format(self.name))

if self.ndim > 0:
Expand Down
10 changes: 10 additions & 0 deletions test_numtraits.py
Expand Up @@ -13,6 +13,7 @@ class ScalarProperties(HasTraits):
d = NumericalTrait(ndim=0, domain='negative')
e = NumericalTrait(ndim=0, domain='strictly-negative')
f = NumericalTrait(ndim=0, domain=(3, 4))
g = NumericalTrait(ndim=0, default_value=2)

class TestScalar(object):

Expand Down Expand Up @@ -79,6 +80,14 @@ def test_range(self):
self.sp.f = 7
assert exc.value.args[0] == "f should be in the range [3:4]"

def test_scalar_quantities(self):
""" Tests for issue #14.
"""
quantities = pytest.importorskip("quantities")
self.sp.a = 1*quantities.m

def test_default_value(self):
assert self.sp.g == 2

class ArrayProperties(HasTraits):

Expand All @@ -90,6 +99,7 @@ class ArrayProperties(HasTraits):
f = NumericalTrait(domain=(3, 4), ndim=1)
g = NumericalTrait(shape=(3, 4))


class TestArray(object):

def setup_method(self, method):
Expand Down

0 comments on commit 4f80670

Please sign in to comment.