Skip to content

Commit

Permalink
Merge pull request #12 from TUW-GEO/fix-distance-output
Browse files Browse the repository at this point in the history
add support for lists and numpy arrays and probably other iterables
  • Loading branch information
cpaulik committed Apr 16, 2015
2 parents 5e7780f + 4a0b756 commit a041dc1
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 28 deletions.
81 changes: 57 additions & 24 deletions pygeogrids/grids.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,9 +200,9 @@ def __init__(self, lon, lat, gpis=None, subset=None, setup_kdTree=True,

self.kdTree = None
if setup_kdTree:
self._setup_kdTree()
self._setup_kdtree()

def _setup_kdTree(self):
def _setup_kdtree(self):
"""
setup kdTree
"""
Expand Down Expand Up @@ -336,9 +336,9 @@ def find_nearest_gpi(self, lon, lat, max_dist=np.Inf):
Parameters
----------
lon : float
lon : float or iterable
longitude of point
lat : float
lat : float or iterable
latitude of point
Returns
Expand All @@ -349,24 +349,34 @@ def find_nearest_gpi(self, lon, lat, max_dist=np.Inf):
distance of gpi to given lon, lat
At the moment not on a great circle but in spherical cartesian coordinates
"""
# check if input is iterable
try:
lon[0]
iterable = True
except TypeError:
iterable = False

if self.kdTree is None:
self.kdTree = NN.findGeoNN(self.activearrlon, self.activearrlat)
self._setup_kdtree()

d, ind = self.kdTree.find_nearest_index(lon, lat, max_dist=max_dist)

if not iterable:
d = d[0]
ind = ind[0]

if self.gpidirect and self.allpoints:
return ind[0], d
return ind, d

return self.activegpis[ind[0]], d
return self.activegpis[ind], d

def gpi2lonlat(self, gpi):
"""
Longitude and Latitude for given GPI.
Parameters
----------
gpi : int32
gpi : int32 or iterable
Grid Point Index.
Returns
Expand Down Expand Up @@ -400,14 +410,28 @@ def gpi2rowcol(self, gpi):
col : int
column in 2D array
"""
# check if iterable
try:
gpi[0]
iterable = True
except TypeError:
iterable = False
gpi = np.atleast_1d(gpi)
if len(self.shape) == 2:
if self.gpidirect:
index = gpi
else:
index = np.where(self.gpis == gpi)[0][0]
# get the indices that would sort the gpis
gpisorted = np.argsort(self.gpis)
# find the position where the gpis fit in the sorted array
pos = np.searchsorted(self.gpis[gpisorted], gpi)
index = gpisorted[pos]

index_lat = int(index / len(self.londim))
index_lat = (index / len(self.londim)).astype(np.int)
index_lon = index % len(self.londim)
if not iterable:
index_lat = index_lat[0]
index_lon = index_lon[0]
return index_lat, index_lon

else:
Expand Down Expand Up @@ -437,10 +461,10 @@ def calc_lut(self, other, max_dist=np.Inf, into_subset=False):
"""

if self.kdTree is None:
self._setup_kdTree()
self._setup_kdtree()

if other.kdTree is None:
other._setup_kdTree()
other._setup_kdtree()

if self.kdTree.kdtree is not None and other.kdTree.kdtree is not None:
dist, index = other.kdTree.find_nearest_index(
Expand Down Expand Up @@ -626,26 +650,35 @@ def gpi2cell(self, gpi):
Parameters
----------
gpi : int32
gpi : int32 or iterable
Grid Point Index.
Returns
-------
cell : int
cell : int or iterable
Cell number of GPI.
Raises
------
IndexError
if gpi is not found
"""
# check if iterable
try:
gpi[0]
iterable = True
except TypeError:
iterable = False
gpi = np.atleast_1d(gpi)
if self.gpidirect:
return self.arrcell[gpi]
cell = self.arrcell[gpi]
else:
index = np.where(self.activegpis == gpi)[0]
if index.size == 0:
raise IndexError('Not a valid gpi')
return self.activearrcell[index[0]]
# get the indices that would sort the gpis
gpisorted = np.argsort(self.gpis)
# find the position where the gpis fit in the sorted array
pos = np.searchsorted(self.gpis[gpisorted], gpi)
index = gpisorted[pos]
cell = self.activearrcell[index]

if not iterable:
cell = cell[0]

return cell

def get_cells(self):
"""
Expand Down
124 changes: 120 additions & 4 deletions tests/test_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,95 @@ def testlonlat2cell_hist(self):
nptest.assert_allclose(hist, np.zeros_like(hist) + 72)


class TestFindNearestNeighbor(unittest.TestCase):

def setUp(self):
self.grid = grids.genreg_grid(1, 1)

def test_nearest_neighbor(self):
gpi, dist = self.grid.find_nearest_gpi(14.3, 18.5)
assert gpi == 25754
assert len([dist]) == 1
lon, lat = self.grid.gpi2lonlat(gpi)
assert lon == 14.5
assert lat == 18.5

def test_nearest_neighbor_list(self):
gpi, dist = self.grid.find_nearest_gpi([145.1, 90.2], [45.8, -16.3])
assert len(gpi) == 2
assert len(dist) == 2
assert gpi[0] == 16165
assert gpi[1] == 38430
lon, lat = self.grid.gpi2lonlat(gpi)
assert lon[0] == 145.5
assert lon[1] == 90.5
assert lat[0] == 45.5
assert lat[1] == -16.5

def test_nearest_neighbor_ndarray(self):
gpi, dist = self.grid.find_nearest_gpi(
np.array([145.1, 90.2]), np.array([45.8, -16.3]))
assert len(gpi) == 2
assert len(dist) == 2
assert gpi[0] == 16165
assert gpi[1] == 38430
lon, lat = self.grid.gpi2lonlat(gpi)
assert lon[0] == 145.5
assert lon[1] == 90.5
assert lat[0] == 45.5
assert lat[1] == -16.5


class TestCellGrid(unittest.TestCase):

"""
setup simple 2D grid 2.5 degree global grid (144x72)
which starts at the North Western corner of 90 -180
Test for cell specific features
"""

def setUp(self):
self.latdim = np.arange(90, -90, -2.5)
self.londim = np.arange(-180, 180, 2.5)
self.lon, self.lat = np.meshgrid(self.londim, self.latdim)
self.grid = grids.BasicGrid(
self.lon.flatten(), self.lat.flatten(), shape=(len(self.londim), len(self.latdim)))
self.cellgrid = self.grid.to_cell_grid()

def test_gpi2cell(self):
"""
test if gpi to row column lookup works correctly
"""
gpi = 200
cell = self.cellgrid.gpi2cell(gpi)
assert cell == 1043

def test_gpi2cell_iterable(self):
"""
test if gpi to row column lookup works correctly
"""
gpi = [200, 255]
cell = self.cellgrid.gpi2cell(gpi)
assert np.all(cell == [1043, 2015])

def test_gpi2cell_custom_gpis(self):
"""
test if gpi to row column lookup works correctly
"""
self.custom_gpi_grid = grids.BasicGrid(self.lon.flatten(),
self.lat.flatten(),
shape=(len(self.londim),
len(self.latdim)),
gpis=np.arange(len(self.lat.flatten()))[::-1])
self.custom_gpi_cell_grid = self.custom_gpi_grid.to_cell_grid()
gpi = [200, 255]
cell = self.custom_gpi_cell_grid.gpi2cell(gpi)
assert np.all(cell == [1549, 577])
gpi = 200
cell = self.custom_gpi_cell_grid.gpi2cell(gpi)
assert cell == 1549


class Test_2Dgrid(unittest.TestCase):

"""
Expand All @@ -38,11 +127,11 @@ class Test_2Dgrid(unittest.TestCase):
"""

def setUp(self):
lat = np.arange(90, -90, -2.5)
lon = np.arange(-180, 180, 2.5)
self.lon, self.lat = np.meshgrid(lon, lat)
self.latdim = np.arange(90, -90, -2.5)
self.londim = np.arange(-180, 180, 2.5)
self.lon, self.lat = np.meshgrid(self.londim, self.latdim)
self.grid = grids.BasicGrid(
self.lon.flatten(), self.lat.flatten(), shape=(len(lon), len(lat)))
self.lon.flatten(), self.lat.flatten(), shape=(len(self.londim), len(self.latdim)))

def test_gpi2rowcol(self):
"""
Expand All @@ -55,6 +144,33 @@ def test_gpi2rowcol(self):
assert row == row_should
assert column == column_should

def test_gpi2rowcol_iterable(self):
"""
test if gpi to row column lookup works correctly
"""
gpi = [200, 255]
row_should = [1, 1]
column_should = [200 - 144, 255 - 144]
row, column = self.grid.gpi2rowcol(gpi)
assert np.all(row == row_should)
assert np.all(column == column_should)

def test_gpi2rowcol_custom_gpis(self):
"""
test if gpi to row column lookup works correctly
"""
self.custom_gpi_grid = grids.BasicGrid(self.lon.flatten(),
self.lat.flatten(),
shape=(len(self.londim),
len(self.latdim)),
gpis=np.arange(len(self.lat.flatten()))[::-1])
gpi = [200, 255]
row_should = [70, 70]
column_should = [87, 32]
row, column = self.custom_gpi_grid.gpi2rowcol(gpi)
assert np.all(row == row_should)
assert np.all(column == column_should)

def test_gpi2lonlat(self):
"""
test if gpi to longitude latitude lookup works correctly
Expand Down

0 comments on commit a041dc1

Please sign in to comment.