Skip to content

Commit

Permalink
Merge pull request #46 from tclose/fix-tests
Browse files Browse the repository at this point in the history
Fixes unittest fail due to PyYAML upgrade
  • Loading branch information
apdavison committed Apr 25, 2020
2 parents 9e1ab15 + 5a1eda1 commit 062a2ac
Show file tree
Hide file tree
Showing 18 changed files with 119 additions and 147 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,5 @@ cover
__pycache__
*.swp
/build/
.pytest_cache
.history
2 changes: 1 addition & 1 deletion .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ sudo: false
matrix:
include:
- python: 2.7
- python: 3.4
- python: 3.5
- python: 3.6
addons:
apt:
Expand Down
39 changes: 2 additions & 37 deletions nineml/abstraction/expressions/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from sympy.printing import ccode
from sympy.logic.boolalg import BooleanTrue, BooleanFalse
from sympy.functions.elementary.piecewise import ExprCondPair
from sympy.utilities.lambdify import lambdify
import re
from nineml.utils import validate_identifier
# import math_namespace
Expand Down Expand Up @@ -180,43 +181,7 @@ def rhs_atoms(self):
def rhs_as_python_func(self):
""" Returns a python callable which evaluates the expression in
namespace and returns the result """
def nineml_expression(**kwargs):
if isinstance(self.rhs, (bool, int, float, BooleanTrue,
BooleanFalse)):
val = self.rhs
else:
if self.rhs.is_Boolean:
try:
val = self.rhs.subs(kwargs)
except Exception:
raise NineMLUsageError(
"Incorrect arguments provided to expression ('{}')"
": '{}'\n".format(
"', '".join(self.rhs_symbol_names),
"', '".join(list(kwargs.keys()))))
else:
try:
val = self.rhs.evalf(subs=kwargs)
except Exception:
raise NineMLUsageError(
"Incorrect arguments provided to expression '{}'"
": '{}' (expected '{}')\n".format(
self.rhs,
"', '".join(list(kwargs.keys())),
"', '".join(self.rhs_symbol_names)))
try:
val = float(val)
except TypeError:
try:
locals_dict = deepcopy(kwargs)
locals_dict.update(str_to_npfunc_map)
val = eval(str(val), {}, locals_dict)
except Exception:
raise NineMLUsageError(
"Could not evaluate expression: {}"
.format(self.rhs_str))
return val
return nineml_expression
return lambdify(list(self.rhs_symbol_names), self.rhs, 'numpy')

def rhs_suffixed(self, suffix='', prefix='', excludes=[]):
"""
Expand Down
6 changes: 5 additions & 1 deletion nineml/abstraction/expressions/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,13 @@ def _parse_expr(self, expr):
if self._logic_relation_re.search(expr):
expr = self._parse_relationals(expr)
self.escaped_names = set()
# This is a work around for a PY2 bug in Sympy where the inspect module
# can't get the argument spec of a Callable object (instead of a func.)
def parser(tokens, local_dict, global_dict):
return self(tokens, local_dict, global_dict)
try:
expr = sympy_parse(
expr, transformations=([self] + self._sympy_transforms),
expr, transformations=([parser] + self._sympy_transforms),
local_dict=self.inline_randoms_dict)
except Exception as e:
raise NineMLMathParseError(
Expand Down
5 changes: 3 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@ lxml>=3.7.3
pyyaml>=3.1
h5py>=2.7.0
future>=0.16.0
sympy>=1.1
numpydoc >= 0.7.0
sympy>=1.5
numpy>=1.11.0
numpydoc>=0.7.0
6 changes: 3 additions & 3 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
'Programming Language :: Python :: 2',
'Programming Language :: Python :: 2.7',
'Programming Language :: Python :: 3',
'Programming Language :: Python :: 3.4',
'Programming Language :: Python :: 3.5',
'Programming Language :: Python :: 3.6',
'Programming Language :: Python :: 3.7',
Expand All @@ -45,7 +44,8 @@
'future>=0.16.0',
'h5py>=2.7.0',
'PyYAML>=3.1',
'sympy>=1.2'],
'sympy>=1.5.1',
'numpy>=1.11.0'],
python_requires='>=2.7, !=3.0.*, !=3.1.*, !=3.2.*, <4',
tests_require=['nose', 'numpy']
tests_require=['nose']
)
124 changes: 62 additions & 62 deletions test/unittests/abstraction_test/dynamics_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def test_event_ports(self):
]
),
)
self.assertEquals(len(list(c.event_ports)), 2)
self.assertEqual(len(list(c.event_ports)), 2)

# Check inference of output event ports:
c = Dynamics(
Expand All @@ -262,7 +262,7 @@ def test_event_ports(self):
On('V < b', do=OutputEvent('ev_port3'))])
]
)
self.assertEquals(len(list(c.event_ports)), 3)
self.assertEqual(len(list(c.event_ports)), 3)

# Check inference of output event ports:
c = Dynamics(
Expand All @@ -281,7 +281,7 @@ def test_event_ports(self):
to='r1')])
]
)
self.assertEquals(len(list(c.event_ports)), 5)
self.assertEqual(len(list(c.event_ports)), 5)

def test_parameters(self):
# Signature: name
Expand Down Expand Up @@ -525,24 +525,24 @@ def test_transitions(self):
transitions=On('X>X1', do=['X=X0'],
to=None))])

self.assertEquals(len(list(c.all_transitions())), 6)
self.assertEqual(len(list(c.all_transitions())), 6)

r1 = c.regime('r1')
r2 = c.regime('r2')
r3 = c.regime('r3')
r4 = c.regime('r4')

self.assertEquals(len(list(r1.transitions)), 2)
self.assertEquals(len(list(r2.transitions)), 1)
self.assertEquals(len(list(r3.transitions)), 2)
self.assertEquals(len(list(r4.transitions)), 1)
self.assertEqual(len(list(r1.transitions)), 2)
self.assertEqual(len(list(r2.transitions)), 1)
self.assertEqual(len(list(r3.transitions)), 2)
self.assertEqual(len(list(r4.transitions)), 1)

def target_regimes(regime):
return unique_by_id(t.target_regime for t in regime.transitions)
self.assertEquals(target_regimes(r1), [r2, r3])
self.assertEquals(target_regimes(r2), [r3])
self.assertEquals(target_regimes(r3), [r3, r4])
self.assertEquals(target_regimes(r4), [r4])
self.assertEqual(target_regimes(r1), [r2, r3])
self.assertEqual(target_regimes(r2), [r3])
self.assertEqual(target_regimes(r3), [r3, r4])
self.assertEqual(target_regimes(r4), [r4])

def test_all_expressions(self):
a = Dynamics(
Expand Down Expand Up @@ -593,46 +593,46 @@ def test_On(self):

# Test that we are correctly inferring OnEvents and OnConditions.

self.assertEquals(type(On('V>0')), OnCondition)
self.assertEquals(type(On('V<0')), OnCondition)
self.assertEquals(type(On('(V<0) & (K>0)')), OnCondition)
self.assertEquals(type(On('V==0')), OnCondition)
self.assertEqual(type(On('V>0')), OnCondition)
self.assertEqual(type(On('V<0')), OnCondition)
self.assertEqual(type(On('(V<0) & (K>0)')), OnCondition)
self.assertEqual(type(On('V==0')), OnCondition)

self.assertEquals(
self.assertEqual(
type(On("q > 1 / (( 1 + mg_conc * eta * exp ( -1 * gamma*V)))")),
OnCondition)

self.assertEquals(type(On('SP0')), OnEvent)
self.assertEquals(type(On('SP1')), OnEvent)
self.assertEqual(type(On('SP0')), OnEvent)
self.assertEqual(type(On('SP1')), OnEvent)

# Check we can use 'do' with single and multiple values
tr = On('V>0')
self.assertEquals(len(list(tr.output_events)), 0)
self.assertEquals(len(list(tr.state_assignments)), 0)
self.assertEqual(len(list(tr.output_events)), 0)
self.assertEqual(len(list(tr.state_assignments)), 0)
tr = On('SP0')
self.assertEquals(len(list(tr.output_events)), 0)
self.assertEquals(len(list(tr.state_assignments)), 0)
self.assertEqual(len(list(tr.output_events)), 0)
self.assertEqual(len(list(tr.state_assignments)), 0)

tr = On('V>0', do=OutputEvent('spike'))
self.assertEquals(len(list(tr.output_events)), 1)
self.assertEquals(len(list(tr.state_assignments)), 0)
self.assertEqual(len(list(tr.output_events)), 1)
self.assertEqual(len(list(tr.state_assignments)), 0)
tr = On('SP0', do=OutputEvent('spike'))
self.assertEquals(len(list(tr.output_events)), 1)
self.assertEquals(len(list(tr.state_assignments)), 0)
self.assertEqual(len(list(tr.output_events)), 1)
self.assertEqual(len(list(tr.state_assignments)), 0)

tr = On('V>0', do=[OutputEvent('spike')])
self.assertEquals(len(list(tr.output_events)), 1)
self.assertEquals(len(list(tr.state_assignments)), 0)
self.assertEqual(len(list(tr.output_events)), 1)
self.assertEqual(len(list(tr.state_assignments)), 0)
tr = On('SP0', do=[OutputEvent('spike')])
self.assertEquals(len(list(tr.output_events)), 1)
self.assertEquals(len(list(tr.state_assignments)), 0)
self.assertEqual(len(list(tr.output_events)), 1)
self.assertEqual(len(list(tr.state_assignments)), 0)

tr = On('V>0', do=['y=2', OutputEvent('spike'), 'x=1'])
self.assertEquals(len(list(tr.output_events)), 1)
self.assertEquals(len(list(tr.state_assignments)), 2)
self.assertEqual(len(list(tr.output_events)), 1)
self.assertEqual(len(list(tr.state_assignments)), 2)
tr = On('SP0', do=['y=2', OutputEvent('spike'), 'x=1'])
self.assertEquals(len(list(tr.output_events)), 1)
self.assertEquals(len(list(tr.state_assignments)), 2)
self.assertEqual(len(list(tr.output_events)), 1)
self.assertEqual(len(list(tr.state_assignments)), 2)


class OnCondition_test(unittest.TestCase):
Expand All @@ -643,7 +643,7 @@ def test_trigger(self):
'V < (V+10',
'V (< V+10)',
'V (< V+10)',
'1 / ( 1 + mg_conc * eta * exp (( -1 * gamma*V))'
'1 / ( 1 + mg_conc * eta * exp(-1 * gamma*V))'
'1..0'
'..0']
for tr in invalid_triggers:
Expand Down Expand Up @@ -710,7 +710,7 @@ def test_trigger(self):

python_func = c.trigger.rhs_as_python_func
param_dict = dict([(v, namespace[v]) for v in expt_vars])
self.assertEquals(return_values[i], python_func(**param_dict))
self.assertEqual(return_values[i], python_func(**param_dict))

def test_trigger_crossing_time_expr(self):
self.assertEqual(Trigger('t > t_next').crossing_time_expr.rhs,
Expand Down Expand Up @@ -741,8 +741,8 @@ def test_src_port_name(self):
self.assertRaises(NineMLUsageError, OnEvent, 'MyEvent1 2')
self.assertRaises(NineMLUsageError, OnEvent, 'MyEvent1* ')

self.assertEquals(OnEvent(' MyEvent1 ').src_port_name, 'MyEvent1')
self.assertEquals(OnEvent(' MyEvent2').src_port_name, 'MyEvent2')
self.assertEqual(OnEvent(' MyEvent1 ').src_port_name, 'MyEvent1')
self.assertEqual(OnEvent(' MyEvent2').src_port_name, 'MyEvent2')


class Regime_test(unittest.TestCase):
Expand All @@ -762,11 +762,11 @@ def test_add_on_condition(self):
# The source regime for this transition will be set as this regime.

r = Regime(name='R1')
self.assertEquals(unique_by_id(r.on_conditions), [])
self.assertEqual(unique_by_id(r.on_conditions), [])
r.add(OnCondition('sp1>0'))
self.assertEquals(len(unique_by_id(r.on_conditions)), 1)
self.assertEquals(len(unique_by_id(r.on_events)), 0)
self.assertEquals(len(unique_by_id(r.transitions)), 1)
self.assertEqual(len(unique_by_id(r.on_conditions)), 1)
self.assertEqual(len(unique_by_id(r.on_events)), 0)
self.assertEqual(len(unique_by_id(r.transitions)), 1)

def test_add_on_event(self):
# Signature: name(self, on_event)
Expand All @@ -779,11 +779,11 @@ def test_add_on_event(self):
# The source regime for this transition will be set as this regime.
# from nineml.abstraction.component.dynamics import Regime
r = Regime(name='R1')
self.assertEquals(unique_by_id(r.on_events), [])
self.assertEqual(unique_by_id(r.on_events), [])
r.add(OnEvent('sp'))
self.assertEquals(len(unique_by_id(r.on_events)), 1)
self.assertEquals(len(unique_by_id(r.on_conditions)), 0)
self.assertEquals(len(unique_by_id(r.transitions)), 1)
self.assertEqual(len(unique_by_id(r.on_events)), 1)
self.assertEqual(len(unique_by_id(r.on_conditions)), 0)
self.assertEqual(len(unique_by_id(r.transitions)), 1)

def test_get_next_name(self):
# Signature: name(cls)
Expand Down Expand Up @@ -817,7 +817,7 @@ def test_time_derivatives(self):
'dX2/dt=0',
name='r1')

self.assertEquals(
self.assertEqual(
set([td.variable for td in r.time_derivatives]),
set(['X1', 'X2']))

Expand Down Expand Up @@ -864,12 +864,12 @@ def test_event_send_receive_ports(self):
]
),
)
self.assertEquals(len(list(c.event_receive_ports)), 1)
self.assertEquals((list(list(c.event_receive_ports))[0]).name,
self.assertEqual(len(list(c.event_receive_ports)), 1)
self.assertEqual((list(list(c.event_receive_ports))[0]).name,
'in_ev1')

self.assertEquals(len(list(c.event_send_ports)), 2)
self.assertEquals(set(c.event_send_port_names),
self.assertEqual(len(list(c.event_send_ports)), 2)
self.assertEqual(set(c.event_send_port_names),
set(['ev_port1', 'ev_port2']))

# Check inference of output event ports:
Expand All @@ -887,12 +887,12 @@ def test_event_send_receive_ports(self):
On('in_ev2', do=OutputEvent('ev_port3'))])
]
)
self.assertEquals(len(list(c.event_receive_ports)), 2)
self.assertEquals(set(c.event_receive_port_names),
self.assertEqual(len(list(c.event_receive_ports)), 2)
self.assertEqual(set(c.event_receive_port_names),
set(['in_ev1', 'in_ev2']))

self.assertEquals(len(list(c.event_send_ports)), 3)
self.assertEquals(set(c.event_send_port_names),
self.assertEqual(len(list(c.event_send_ports)), 3)
self.assertEqual(set(c.event_send_port_names),
set(['ev_port1', 'ev_port2', 'ev_port3']))

# Check inference of output event ports:
Expand All @@ -913,12 +913,12 @@ def test_event_send_receive_ports(self):
to='r1')])
]
)
self.assertEquals(len(list(c.event_receive_ports)), 3)
self.assertEquals(set(c.event_receive_port_names),
self.assertEqual(len(list(c.event_receive_ports)), 3)
self.assertEqual(set(c.event_receive_port_names),
set(['spikeinput1', 'spikeinput2', 'spikeinput3']))

self.assertEquals(len(list(c.event_send_ports)), 3)
self.assertEquals(set(c.event_send_port_names),
self.assertEqual(len(list(c.event_send_ports)), 3)
self.assertEqual(set(c.event_send_port_names),
set(['ev_port1', 'ev_port2', 'ev_port3']))

def test_ports(self):
Expand Down Expand Up @@ -951,8 +951,8 @@ def test_ports(self):
ports = list(list(c.ports))
port_names = [p.name for p in ports]

self.assertEquals(len(port_names), 8)
self.assertEquals(set(port_names),
self.assertEqual(len(port_names), 8)
self.assertEqual(set(port_names),
set(['A1', 'B', 'C', 'spikeinput1', 'spikeinput2',
'spikeinput3', 'ev_port2', 'ev_port3'])
)
Expand Down
Loading

0 comments on commit 062a2ac

Please sign in to comment.