Skip to content

Commit

Permalink
GridModule: use parallel r.patch as alternative backend when overlap=…
Browse files Browse the repository at this point in the history
…0 to potentially speed up patching tiles (#2249)
  • Loading branch information
petrasovaa committed Mar 28, 2022
1 parent 7d79d82 commit 1f48487
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 11 deletions.
55 changes: 44 additions & 11 deletions python/grass/pygrass/modules/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from grass.pygrass.utils import get_mapset_raster, findmaps

from grass.pygrass.modules.grid.split import split_region_tiles
from grass.pygrass.modules.grid.patch import rpatch_map
from grass.pygrass.modules.grid.patch import rpatch_map, rpatch_map_r_patch_backend


def select(parms, ptype):
Expand Down Expand Up @@ -424,11 +424,17 @@ class GridModule(object):
:type split: bool
:param mapset_prefix: if specified created mapsets start with this prefix
:type mapset_prefix: str
:param patch_backend: "r.patch", "RasterRow", or None for default
:type patch_backend: None or str
:param run_: if False only instantiate the object
:type run_: bool
:param args: give all the parameters to the command
:param kargs: give all the parameters to the command
When patch_backend is None, the RasterRow method is used for patching the result.
When patch_backend is "r.patch", r.patch is used with nprocs=processes.
r.patch can only be used when overlap is 0.
>>> grd = GridModule('r.slope.aspect',
... width=500, height=500, overlap=2,
... processes=None, split=False,
Expand Down Expand Up @@ -458,6 +464,7 @@ def __init__(
start_col=0,
out_prefix="",
mapset_prefix=None,
patch_backend=None,
*args,
**kargs,
):
Expand All @@ -474,6 +481,20 @@ def __init__(
self.out_prefix = out_prefix
self.log = log
self.move = move
# by default RasterRow is used as previously
# if overlap > 0, r.patch won't work properly
if not patch_backend:
self.patch_backend = "RasterRow"
elif patch_backend not in ("r.patch", "RasterRow"):
raise RuntimeError(
_("Parameter patch_backend must be 'r.patch' or 'RasterRow'")
)
elif patch_backend == "r.patch" and self.overlap:
raise RuntimeError(
_("Patching backend 'r.patch' doesn't work for overlap > 0")
)
else:
self.patch_backend = patch_backend
self.gisrc_src = os.environ["GISRC"]
self.n_mset, self.gisrc_dst = None, None
if self.move:
Expand Down Expand Up @@ -665,16 +686,28 @@ def patch(self):
for otmap in self.module.outputs:
otm = self.module.outputs[otmap]
if otm.typedesc == "raster" and otm.value:
rpatch_map(
otm.value,
self.mset.name,
self.msetstr,
bboxes,
self.module.flags.overwrite,
self.start_row,
self.start_col,
self.out_prefix,
)
if self.patch_backend == "RasterRow":
rpatch_map(
raster=otm.value,
mapset=self.mset.name,
mset_str=self.msetstr,
bbox_list=bboxes,
overwrite=self.module.flags.overwrite,
start_row=self.start_row,
start_col=self.start_col,
prefix=self.out_prefix,
)
else:
rpatch_map_r_patch_backend(
raster=otm.value,
mset_str=self.msetstr,
bbox_list=bboxes,
overwrite=self.module.flags.overwrite,
start_row=self.start_row,
start_col=self.start_col,
prefix=self.out_prefix,
processes=self.processes,
)
noutputs += 1
if noutputs < 1:
msg = "No raster output option defined for <{}>".format(self.module.name)
Expand Down
45 changes: 45 additions & 0 deletions python/grass/pygrass/modules/grid/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from grass.pygrass.gis.region import Region
from grass.pygrass.raster import RasterRow
from grass.pygrass.utils import coor2pixel
from grass.pygrass.modules import Module


def get_start_end_index(bbox_list):
Expand Down Expand Up @@ -111,3 +112,47 @@ def rpatch_map(
del rst

rast.close()


def rpatch_map_r_patch_backend(
    raster,
    mset_str,
    bbox_list,
    overwrite=False,
    start_row=0,
    start_col=0,
    prefix="",
    processes=1,
):
    """Patch per-tile rasters into one map by calling r.patch.

    One input name of the form ``<raster>@<mapset>`` is built for every
    tile in ``bbox_list``; the mapset name is obtained by formatting
    ``mset_str`` with the tile's (row, column) pair, each shifted by
    ``start_row`` / ``start_col``. All inputs are then handed to a single
    r.patch call which writes ``prefix + raster``.

    Only use with overlap=0. Intended as a faster alternative to
    rpatch_map, because r.patch is run with ``nprocs=processes``.

    :param raster: the name of output raster
    :type raster: str
    :param mset_str: format string producing a tile mapset name from
                     a (row, column) pair
    :type mset_str: str
    :param bbox_list: rows of tile bounding boxes; only the number of rows
                      and the length of each row are used here
    :type bbox_list: list of lists of BBox objects
    :param overwrite: overwrite existing raster
    :type overwrite: bool
    :param start_row: the starting row of original raster
    :type start_row: int
    :param start_col: the starting column of original raster
    :type start_col: int
    :param prefix: the prefix of output raster
    :type prefix: str
    :param processes: number of parallel processes for r.patch
    :type processes: int
    """
    # Tiles are listed row by row, columns in increasing order within a row.
    tile_inputs = [
        "{name}@{mapset}".format(
            name=raster,
            mapset=mset_str % (start_row + row_idx, start_col + col_idx),
        )
        for row_idx, row_boxes in enumerate(bbox_list)
        for col_idx in range(len(row_boxes))
    ]
    patched_name = prefix + raster
    Module(
        "r.patch",
        input=tile_inputs,
        output=patched_name,
        overwrite=overwrite,
        nprocs=processes,
    )
41 changes: 41 additions & 0 deletions python/grass/pygrass/modules/tests/grass_pygrass_grid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,3 +171,44 @@ def run_grid_module():
prefixed += int(item.name.startswith(mapset_prefix))
if not clean:
assert prefixed, "Not even one prefixed mapset"


@pytest.mark.parametrize("patch_backend", [None, "r.patch", "RasterRow"])
def test_patching_backend(tmp_path, patch_backend):
    """Check that each patching backend reproduces the untiled result.

    Random points are rasterized once directly (the reference) and once
    through GridModule with overlap=0 and the parametrized backend; the
    means of both rasters as reported by r.univar must agree.
    """
    location = "test"
    gs.core._create_location_xy(tmp_path, location)  # pylint: disable=protected-access
    with grass_setup.init(tmp_path / location):
        gs.run_command("g.region", s=0, n=50, w=0, e=50, res=1)

        points = "points"
        reference = "reference"
        gs.run_command("v.random", output=points, npoints=100)
        # Untiled rasterization used as the expected result.
        gs.run_command(
            "v.to.rast", input=points, output=reference, type="point", use="cat"
        )

        def run_grid_module():
            # modules/shortcuts calls get_commands which requires GISBASE.
            # pylint: disable=import-outside-toplevel
            from grass.pygrass.modules.grid import GridModule

            tiling_options = {
                "width": 10,
                "height": 5,
                "overlap": 0,
                "patch_backend": patch_backend,
                "processes": max_processes(),
            }
            module_options = {
                "input": points,
                "output": "output",
                "type": "point",
                "use": "cat",
            }
            grid = GridModule("v.to.rast", **tiling_options, **module_options)
            grid.run()

        run_in_subprocess(run_grid_module)

        def raster_mean(name):
            """Return the mean of raster *name* as reported by r.univar -g."""
            stats = gs.parse_command("r.univar", map=name, flags="g")
            return float(stats["mean"])

        mean_ref = raster_mean(reference)
        mean = raster_mean("output")
        assert abs(mean - mean_ref) < 0.0001

0 comments on commit 1f48487

Please sign in to comment.