Skip to content

Commit

Permalink
MarkovChain: Add get_index
Browse files Browse the repository at this point in the history
  • Loading branch information
oyamad committed Mar 28, 2016
1 parent ecfbc8f commit 5bc78d5
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 2 deletions.
38 changes: 38 additions & 0 deletions quantecon/markov/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,44 @@ def state_values(self, values):
)
self._state_values = values

def get_index(self, value):
"""
Return the index of the given value in state_values.
Parameters
----------
value
Value to get the index for.
Returns
-------
idx : int
Index of the value.
"""
error_msg = 'value {0} not found'.format(repr(value))

if self.state_values is None:
if isinstance(value, numbers.Integral) and (0 <= value < self.n):
return value
else:
raise ValueError(error_msg)

# if self.state_values is not None:
if self.state_values.ndim == 1:
try:
idx = np.where(self.state_values == value)[0][0]
return idx
except IndexError:
raise ValueError(error_msg)
else:
idx = 0
while idx < self.n:
if np.array_equal(self.state_values[idx], value):
return idx
idx += 1
raise ValueError(error_msg)

@property
def digraph(self):
if self._digraph is None:
Expand Down
20 changes: 18 additions & 2 deletions quantecon/markov/tests/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,6 @@ def test_com_rec_classes(self):
sorted(classes, key=key)
)


def test_cyc_classes(self):
mc = self.mc_periodic_dict['mc']
cycs = self.mc_periodic_dict['cycs']
Expand All @@ -453,7 +452,6 @@ def test_cyc_classes(self):
sorted(classes, key=key)
)


def test_simulate(self):
# Deterministic mc
mc = self.mc_periodic_dict['mc']
Expand All @@ -478,6 +476,24 @@ def test_simulate(self):
assert_array_equal(X, X_expected)


def test_get_index():
P = [[0.4, 0.6], [0.2, 0.8]]
mc = MarkovChain(P)

eq_(mc.get_index(0), 0)
eq_(mc.get_index(1), 1)
assert_raises(ValueError, mc.get_index, 2)

mc.state_values = [1, 2]
eq_(mc.get_index(1), 0)
eq_(mc.get_index(2), 1)
assert_raises(ValueError, mc.get_index, 0)

mc.state_values = [[1, 2], [3, 4]]
eq_(mc.get_index([1, 2]), 0)
assert_raises(ValueError, mc.get_index, 1)


@raises(ValueError)
def test_raises_value_error_non_2dim():
"""Test with non 2dim input"""
Expand Down

0 comments on commit 5bc78d5

Please sign in to comment.