Skip to content

Commit

Permalink
EHN: Implement cartesian_nearest_index (#660)
Browse files Browse the repository at this point in the history
* EHN: Implement `cartesian_nearest_index`

* Fixes
  • Loading branch information
oyamad committed Dec 4, 2022
1 parent fcde8c9 commit b83799f
Show file tree
Hide file tree
Showing 3 changed files with 189 additions and 3 deletions.
4 changes: 3 additions & 1 deletion quantecon/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,9 @@
from .estspec import smooth, periodogram, ar_periodogram
# from .game_theory import <objects-here> #Place Holder if we wish to promote any general objects to the qe namespace.
from .graph_tools import DiGraph, random_tournament_graph
from .gridtools import cartesian, mlinspace, simplex_grid, simplex_index
from .gridtools import (
cartesian, mlinspace, cartesian_nearest_index, simplex_grid, simplex_index
)
from .inequality import lorenz_curve, gini_coefficient, shorrocks_index, \
rank_size
from .kalman import Kalman
Expand Down
139 changes: 139 additions & 0 deletions quantecon/gridtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,145 @@ def _repeat_1d(x, K, out):
out[ind] = val


def cartesian_nearest_index(x, nodes, order='C'):
"""
Return the index of the point closest to `x` within the cartesian
product generated by `nodes`. Each array in `nodes` must be sorted
in ascending order.
Parameters
----------
x : array_like(ndim=1 or 2)
Point(s) to search the closest point(s) for.
nodes : array_like(array_like(ndim=1))
Array of sorted arrays.
order : str, optional(default='C')
('C' or 'F') order in which the product is enumerated.
Returns
-------
scalar(int) or ndarray(int, ndim=1)
Index (indices) of the closest point(s) to `x`.
Examples
--------
>>> nodes = (np.arange(3), np.arange(2))
>>> prod = qe.cartesian(nodes)
>>> print(prod)
[[0 0]
[0 1]
[1 0]
[1 1]
[2 0]
[2 1]]
Among the 6 points in the cartesian product `prod`, the closest to
the point (0.6, 0.4) is `prod[2]`:
>>> x = (0.6, 0.4)
>>> qe.cartesian_nearest_index(x, nodes) # Pass `nodes`, not `prod`
2
The closest to (-0.1, 1.2) and (2, 0) are `prod[1]` and `prod[4]`,
respectively:
>>> x = [(-0.1, 1.2), (2, 0)]
>>> qe.cartesian_nearest_index(x, nodes)
array([1, 4])
Internally, the index in each dimension is searched by binary search
and then the index in the cartesian product is calculated (*not* by
constructing the cartesian product and then searching linearly over
it).
"""
x = np.asarray(x)
is_1d = False
shape = x.shape
if len(shape) == 1:
is_1d = True
x = x[np.newaxis]
types = [type(e[0]) for e in nodes]
dtype = np.result_type(*types)
nodes = tuple(np.asarray(e, dtype=dtype) for e in nodes)

n = shape[1-is_1d]
if len(nodes) != n:
msg = 'point `x`' if is_1d else 'points in `x`'
msg += ' must have same length as `nodes`'
raise ValueError(msg)

out = _cartesian_nearest_indices(x, nodes, order=order)
if is_1d:
return out[0]
return out


@njit(cache=True)
def _cartesian_nearest_indices(X, nodes, order='C'):
"""
The main body of `cartesian_nearest_index`, jit-complied by Numba.
Note that `X` must be a 2-dim ndarray, and a Python list is not
accepted for `nodes`.
Parameters
----------
X : ndarray(ndim=2)
Points to search the closest points for.
nodes : tuple(ndarray(ndim=1))
Tuple of sorted ndarrays of same dtype.
order : str, optional(default='C')
('C' or 'F') order in which the product is enumerated.
Returns
-------
ndarray(int, ndim=1)
Indices of the closest points to the points in `X`.
"""
m, n = X.shape # m vectors of length n
nums_grids = np.empty(n, dtype=np.intp)
for i in range(n):
nums_grids[i] = len(nodes[i])

ind = np.empty(n, dtype=np.intp)
out = np.empty(m, dtype=np.intp)

step = -1 if order == 'F' else 1
slice_ = slice(None, None, step)

for t in range(m):
for i in range(n):
if X[t, i] <= nodes[i][0]:
ind[i] = 0
elif X[t, i] >= nodes[i][-1]:
ind[i] = nums_grids[i] - 1
else:
k = np.searchsorted(nodes[i], X[t, i])
ind[i] = (
k if nodes[i][k] - X[t, i] < X[t, i] - nodes[i][k-1]
else k - 1
)
out[t] = _cartesian_index(ind[slice_], nums_grids[slice_])

return out


@njit(cache=True)
def _cartesian_index(indices, nums_grids):
n = len(indices)
idx = 0
de_cumprod = 1
for i in range(1,n+1):
idx += de_cumprod * indices[n-i]
de_cumprod *= nums_grids[n-i]
return idx


_msg_max_size_exceeded = 'Maximum allowed size exceeded'


Expand Down
49 changes: 47 additions & 2 deletions quantecon/tests/test_gridtools.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,13 @@
import numpy as np
import time
import pytest
from numpy.testing import assert_array_equal, assert_, assert_raises
from numpy.testing import (
assert_array_equal, assert_equal, assert_, assert_raises
)

from quantecon.gridtools import (
cartesian, mlinspace, _repeat_1d, simplex_grid, simplex_index,
num_compositions, num_compositions_jit
num_compositions, num_compositions_jit, cartesian_nearest_index
)


Expand Down Expand Up @@ -165,6 +167,49 @@ def test_repeat():
assert_(abs(t_numpy-t_repeat).max())


class TestCartesianNearestIndex:
def setup_method(self):
nums = (5, 6)
self.nodes = [list(range(nums[0])), np.linspace(0, 1, nums[1])]
self.orders = ['C', 'F']
self.prod_dict = \
{order:cartesian(self.nodes, order=order) for order in self.orders}

def linear_search(self, x, order='C'):
x = np.asarray(x)
return ((self.prod_dict[order] - x)**2).sum(1).argmin()

def test_1d(self):
x = (1.2, 0.3)
for order in self.orders:
ind_expected = self.linear_search(x, order)
ind_computed = cartesian_nearest_index(x, self.nodes, order)
assert_equal(ind_computed, ind_expected)

assert_raises(
ValueError, cartesian_nearest_index, x, self.prod_dict['C']
)

def test_2d(self):
T = 10
rng = np.random.default_rng(1234)
X = np.column_stack((
rng.uniform(self.nodes[0][0]-1, self.nodes[0][-1]+1, size=T),
rng.standard_normal(T) + 0.5
))
ind_expected = np.empty(T, dtype=np.intp)

for order in self.orders:
for t in range(T):
ind_expected[t] = self.linear_search(X[t], order)
ind_computed = cartesian_nearest_index(X, self.nodes, order)
assert_array_equal(ind_computed, ind_expected)

assert_raises(
ValueError, cartesian_nearest_index, X, self.prod_dict['C']
)


class TestSimplexGrid:
def setup_method(self):
self.simplex_grid_3_4 = np.array([[0, 0, 4],
Expand Down

0 comments on commit b83799f

Please sign in to comment.