Skip to content

Commit

Permalink
Merge pull request #360 from michaelbynum/value_bug
Browse files Browse the repository at this point in the history
Ensure variables and parameters keep their values when constraints are removed
  • Loading branch information
kaklise committed Aug 17, 2023
2 parents 040bdb8 + 8c6d4c2 commit 3d42408
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 0 deletions.
2 changes: 2 additions & 0 deletions wntr/sim/aml/aml.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def _decrement_var(self, var):
if self._refcounts[var] == 0:
cvar = self._var_cvar_map[var]
var._c_obj = None
var._value = cvar.value
del self._refcounts[var]
del self._var_cvar_map[var]
self._evaluator.remove_var(cvar)
Expand All @@ -162,6 +163,7 @@ def _decrement_param(self, p):
if self._refcounts[p] == 0:
cparam = self._param_cparam_map[p]
p._c_obj = None
p._value = cparam.value
del self._refcounts[p]
del self._param_cparam_map[p]
self._evaluator.remove_param(cparam)
Expand Down
25 changes: 25 additions & 0 deletions wntr/tests/test_aml.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import numpy as np
import wntr.sim.aml as aml
from wntr.sim.solvers import NewtonSolver, SolverStatus


def compare_evaluation(self, m, true_r, true_j):
Expand All @@ -17,6 +18,30 @@ def compare_evaluation(self, m, true_r, true_j):
self.assertAlmostEqual(true_j[c][v], j[c.index, v.index], 10)


class TestModel(unittest.TestCase):
def test_var_value_with_decrement(self):
m = aml.Model()
m.x = aml.Var()
m.p = aml.Param(val=1)
m.c = aml.Constraint(m.x - m.p)
m.set_structure()
opt = NewtonSolver({'TOL': 1e-8})
status, msg, num_iter = opt.solve(m)
self.assertEqual(status, SolverStatus.converged)
self.assertAlmostEqual(m.x.value, 1)
m.p.value = 2
status, msg, num_iter = opt.solve(m)
self.assertEqual(status, SolverStatus.converged)
self.assertAlmostEqual(m.x.value, 2)
del m.c
m.c = aml.Constraint(m.x**0.5 - m.p)
m.set_structure()
self.assertAlmostEqual(m.x.value, 2) # this is the real test here
status, msg, num_iter = opt.solve(m)
self.assertEqual(status, SolverStatus.converged)
self.assertAlmostEqual(m.x.value, 4)


class TestExpression(unittest.TestCase):
def test_add(self):
m = aml.Model()
Expand Down

0 comments on commit 3d42408

Please sign in to comment.