Skip to content

Commit

Permalink
Merge pull request #659 from QuantEcon/fix-gridmake
Browse files Browse the repository at this point in the history
FIX: Fix dtype in `cartesian`
  • Loading branch information
oyamad committed Dec 1, 2022
2 parents 174234c + 1494995 commit fcde8c9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 5 deletions.
10 changes: 5 additions & 5 deletions quantecon/gridtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ def cartesian(nodes, order='C'):
each line corresponds to one point of the product space
'''

nodes = [np.array(e) for e in nodes]
nodes = [np.asarray(e) for e in nodes]
shapes = [e.shape[0] for e in nodes]

dtype = nodes[0].dtype
dtype = np.result_type(*nodes)

n = len(nodes)
l = np.prod(shapes)
Expand Down Expand Up @@ -75,9 +75,9 @@ def mlinspace(a, b, nums, order='C'):
each line corresponds to one point of the product space
'''

a = np.array(a, dtype='float64')
b = np.array(b, dtype='float64')
nums = np.array(nums, dtype='int64')
a = np.asarray(a, dtype='float64')
b = np.asarray(b, dtype='float64')
nums = np.asarray(nums, dtype='int64')
nodes = [np.linspace(a[i], b[i], nums[i]) for i in range(len(nums))]

return cartesian(nodes, order=order)
Expand Down
13 changes: 13 additions & 0 deletions quantecon/tests/test_gridtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,19 @@ def test_cartesian_C_order_int_float():
assert_(abs(prod_int-prod_float).max() == 0)


def test_cartesian_C_order_int_float_mixed():
x_int = [0, 1]
x_float = [2.3, 4.5]
prod_expected = np.array(
[[0., 2.3],
[0., 4.5],
[1., 2.3],
[1., 4.5]]
)
prod_computed = cartesian([x_int, x_float])
assert_array_equal(prod_computed, prod_expected)


def test_cartesian_F_order():
x = np.linspace(0, 9, 10)

Expand Down

0 comments on commit fcde8c9

Please sign in to comment.