Skip to content

Commit

Permalink
Refactor STRtree implementation to store geoms/idx sequences instead …
Browse files Browse the repository at this point in the history
…of reverse mapping (#1177)
  • Loading branch information
jorisvandenbossche committed Aug 21, 2021
1 parent 9ff8c32 commit c047f49
Showing 1 changed file with 82 additions and 83 deletions.
165 changes: 82 additions & 83 deletions shapely/strtree.py
Expand Up @@ -104,44 +104,34 @@ def __init__(
ShapelyDeprecationWarning,
stacklevel=2,
)
self._tree = None
self.node_capacity = node_capacity
self._rev = {
item: geom
for geom, item in self._iterinitdata(geoms, items)
if not geom.is_empty
}
if self._rev:
self._init_tree(self._rev.items())

def _iterinitdata(
self,
geoms: Iterable[BaseGeometry], items: Optional[Iterable[BaseGeometry]],
) -> Iterator[Tuple[BaseGeometry, Any]]:
if items is not None:
for geom, item in zip(geoms, items):
if isinstance(geom, BaseGeometry):
yield (geom, item)

# Keep references to geoms
self._geoms = list(geoms)
# Default enumeration index to store in the tree
self._idxs = list(range(len(self._geoms)))

# handle items
self._has_custom_items = items is not None
if not self._has_custom_items:
items = self._idxs
self._items = items

# initialize GEOS STRtree
self._tree = lgeos.GEOSSTRtree_create(self.node_capacity)
i = 0
for idx, geom in zip(self._idxs, self._geoms):
# filter empty geometries out of the input
if geom is not None and not geom.is_empty:
lgeos.GEOSSTRtree_insert(self._tree, geom._geom, ctypes.py_object(idx))
i += 1
self._n_geoms = i

def __reduce__(self):
if self._has_custom_items:
return STRtree, (self._geoms, self._items)
else:
for enum_idx, geom in enumerate(geoms):
if isinstance(geom, BaseGeometry):
yield (geom, enum_idx)

def _init_tree(self, rev_initdata: ItemsView[Any, BaseGeometry]):
if rev_initdata:
self._tree = lgeos.GEOSSTRtree_create(self.node_capacity)
for item, geom in rev_initdata:
lgeos.GEOSSTRtree_insert(self._tree, geom._geom, ctypes.py_object(item))

def __getstate__(self):
state = self.__dict__.copy()
del state["_tree"]
return state

def __setstate__(self, state):
self.__dict__.update(state)
if self._rev:
self._init_tree(self._rev.items())
return STRtree, (self._geoms, )

def __del__(self):
if self._tree is not None:
Expand All @@ -152,6 +142,19 @@ def __del__(self):

self._tree = None

def _query(self, geom):
if self._n_geoms == 0:
return []

result = []

def callback(item, userdata):
idx = ctypes.cast(item, ctypes.py_object).value
result.append(idx)

lgeos.GEOSSTRtree_query(self._tree, geom._geom, lgeos.GEOSQueryCallback(callback), None)
return result

def query_items(self, geom: BaseGeometry) -> Sequence[Any]:
"""Query for nodes which intersect the geom's envelope to get
stored items.
Expand Down Expand Up @@ -197,19 +200,11 @@ def query_items(self, geom: BaseGeometry) -> Sequence[Any]:
['POINT (2 2)']
"""
if self._tree is None or not self._rev:
return []

result = []

def callback(item, userdata):
idx = ctypes.cast(item, ctypes.py_object).value
result.append(idx)

lgeos.GEOSSTRtree_query(
self._tree, geom._geom, lgeos.GEOSQueryCallback(callback), None
)
return result
result = self._query(geom)
if self._has_custom_items:
return [self._items[i] for i in result]
else:
return result

def query_geoms(self, geom: BaseGeometry) -> Sequence[BaseGeometry]:
"""Query for nodes which intersect the geom's envelope to get
Expand All @@ -225,8 +220,8 @@ def query_geoms(self, geom: BaseGeometry) -> Sequence[BaseGeometry]:
An array or list of geometry objects.
"""
items = self.query_items(geom)
return [self._rev[idx] for idx in items]
result = self._query(geom)
return [self._geoms[i] for i in result]

def query(self, geom: BaseGeometry) -> Sequence[BaseGeometry]:
"""Query for nodes which intersect the geom's envelope to get
Expand All @@ -247,6 +242,34 @@ def query(self, geom: BaseGeometry) -> Sequence[BaseGeometry]:
"""
return self.query_geoms(geom)

def _nearest(self, geom, exclusive):
envelope = geom.envelope

def callback(item1, item2, distance, userdata):
try:
callback_userdata = ctypes.cast(userdata, ctypes.py_object).value
idx = ctypes.cast(item1, ctypes.py_object).value
geom2 = ctypes.cast(item2, ctypes.py_object).value
dist = ctypes.cast(distance, ctypes.POINTER(ctypes.c_double))
if callback_userdata["exclusive"] and self._geoms[idx].equals(geom2):
dist[0] = sys.float_info.max
else:
lgeos.GEOSDistance(self._geoms[idx]._geom, geom2._geom, dist)

return 1
except Exception:
log.exception("Caught exception")
return 0

item = lgeos.GEOSSTRtree_nearest_generic(
self._tree,
ctypes.py_object(geom),
envelope._geom,
lgeos.GEOSDistanceCallback(callback),
ctypes.py_object({"exclusive": exclusive}),
)
return ctypes.cast(item, ctypes.py_object).value

def nearest_item(
self, geom: BaseGeometry, exclusive: bool = False
) -> Union[Any, None]:
Expand Down Expand Up @@ -285,35 +308,14 @@ def nearest_item(
'POINT (0 0)'
"""
if self._tree is None or not self._rev:
if self._n_geoms == 0:
return None

envelope = geom.envelope

def callback(item1, item2, distance, userdata):
try:
callback_userdata = ctypes.cast(userdata, ctypes.py_object).value
idx = ctypes.cast(item1, ctypes.py_object).value
geom2 = ctypes.cast(item2, ctypes.py_object).value
dist = ctypes.cast(distance, ctypes.POINTER(ctypes.c_double))
if callback_userdata["exclusive"] and self._rev[idx].equals(geom2):
dist[0] = sys.float_info.max
else:
lgeos.GEOSDistance(self._rev[idx]._geom, geom2._geom, dist)
return 1
except Exception:
log.exception("Caught exception")
return 0

item = lgeos.GEOSSTRtree_nearest_generic(
self._tree,
ctypes.py_object(geom),
envelope._geom,
lgeos.GEOSDistanceCallback(callback),
ctypes.py_object({"exclusive": exclusive}),
)
result = ctypes.cast(item, ctypes.py_object).value
return result
result = self._nearest(geom, exclusive)
if self._has_custom_items:
return self._items[result]
else:
return result

def nearest_geom(
self, geom: BaseGeometry, exclusive: bool = False
Expand All @@ -337,11 +339,8 @@ def nearest_geom(
version 2.0.
"""
item = self.nearest_item(geom, exclusive=exclusive)
if item is None:
return None
else:
return self._rev[item]
result = self._nearest(geom, exclusive)
return self._geoms[result]

def nearest(
self, geom: BaseGeometry, exclusive: bool = False
Expand Down

0 comments on commit c047f49

Please sign in to comment.