Skip to content

Commit

Permalink
grass.pygrass: GridModule: fix no-data rows at tile borders (#2736)
Browse files Browse the repository at this point in the history
Fixes #2678 by avoiding unnecessary nsew <-> row,col conversions
  • Loading branch information
petrasovaa committed Feb 16, 2023
1 parent 77f6459 commit 1bb3943
Show file tree
Hide file tree
Showing 4 changed files with 97 additions and 28 deletions.
7 changes: 5 additions & 2 deletions python/grass/pygrass/modules/grid/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@
from grass.pygrass.modules import Module
from grass.pygrass.utils import get_mapset_raster, findmaps

from grass.pygrass.modules.grid.split import split_region_tiles
from grass.pygrass.modules.grid.split import (
split_region_tiles,
split_region_in_overlapping_tiles,
)
from grass.pygrass.modules.grid.patch import rpatch_map, rpatch_map_r_patch_backend


Expand Down Expand Up @@ -515,7 +518,7 @@ def __init__(
groups = [g for g in select(self.module.inputs, "group")]
if groups:
copy_groups(groups, self.gisrc_src, self.gisrc_dst, region=self.region)
self.bboxes = split_region_tiles(
self.bboxes = split_region_in_overlapping_tiles(
region=region, width=self.width, height=self.height, overlap=overlap
)
if mapset_prefix:
Expand Down
26 changes: 4 additions & 22 deletions python/grass/pygrass/modules/grid/patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,23 +18,6 @@
from grass.pygrass.modules import Module


def get_start_end_index(bbox_list):
"""Convert a Bounding Box to a list of the index of
column start, end, row start and end
:param bbox_list: a list of BBox object to convert
:type bbox_list: list of BBox object
"""
ss_list = []
reg = Region()
for bbox in bbox_list:
r_start, c_start = coor2pixel((bbox.west, bbox.north), reg)
r_end, c_end = coor2pixel((bbox.east, bbox.south), reg)
ss_list.append((int(r_start), int(r_end), int(c_start), int(c_end)))
return ss_list


def rpatch_row(rast, rasts, bboxes):
"""Patch a row of bound boxes.
Expand All @@ -45,16 +28,15 @@ def rpatch_row(rast, rasts, bboxes):
:param bboxes: a list of BBox object
:type bboxes: list of BBox object
"""
sei = get_start_end_index(bboxes)
# instantiate two buffer
buff = rasts[0][0]
rbuff = rasts[0][0]
r_start, r_end, c_start, c_end = sei[0]
for row in range(r_start, r_end):
r_start, r_end, c_start, c_end = bboxes[0]
for row in range(r_start, r_end + 1):
for col, ras in enumerate(rasts):
r_start, r_end, c_start, c_end = sei[col]
r_start, r_end, c_start, c_end = bboxes[col]
buff = ras.get_row(row, buff)
rbuff[c_start:c_end] = buff[c_start:c_end]
rbuff[c_start : c_end + 1] = buff[c_start : c_end + 1]
rast.put_row(rbuff)


Expand Down
53 changes: 49 additions & 4 deletions python/grass/pygrass/modules/grid/split.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,31 @@ def get_bbox(reg, row, col, width, height, overlap):
)


def split_region_tiles(region=None, width=100, height=100, overlap=0):
"""Spit a region into a list of Bbox.
def get_tile_start_end_row_col(reg, row, col, width, height):
"""Return a tile's starting and ending row and col
:param reg: a Region object to split
:type reg: Region object
:param row: the number of tiles in a row
:type row: int
:param col: the number of tiles in a col
:type col: int
:param width: the width of tiles
:type width: int
:param height: the width of tiles
:type height: int
"""
start_row = row * height
end_row = (row + 1) * height - 1
start_col = col * width
end_col = (col + 1) * width - 1
end_row = reg.rows - 1 if end_row >= reg.rows else end_row
end_col = reg.cols - 1 if end_col >= reg.cols else end_col
return (start_row, end_row, start_col, end_col)


def split_region_in_overlapping_tiles(region=None, width=100, height=100, overlap=0):
"""Split a region into a list of overlapping tiles defined as (N, S, E, W).
:param region: a Region object to split
:type region: Region object
Expand All @@ -67,10 +90,10 @@ def split_region_tiles(region=None, width=100, height=100, overlap=0):
1500
>>> reg.rows
1350
>>> split_region_tiles(region=reg, width=1000, height=700, overlap=0) # doctest: +NORMALIZE_WHITESPACE
>>> split_region_in_overlapping_tiles(region=reg, width=1000, height=700, overlap=0) # doctest: +NORMALIZE_WHITESPACE
[[Bbox(1350.0, 650.0, 1000.0, 0.0), Bbox(1350.0, 650.0, 1500.0, 1000.0)],
[Bbox(650.0, 0.0, 1000.0, 0.0), Bbox(650.0, 0.0, 1500.0, 1000.0)]]
>>> split_region_tiles(region=reg, width=1000, height=700, overlap=10) # doctest: +NORMALIZE_WHITESPACE
>>> split_region_in_overlapping_tiles(region=reg, width=1000, height=700, overlap=10) # doctest: +NORMALIZE_WHITESPACE
[[Bbox(1350.0, 640.0, 1010.0, 0.0), Bbox(1350.0, 640.0, 1500.0, 990.0)],
[Bbox(660.0, 0.0, 1010.0, 0.0), Bbox(660.0, 0.0, 1500.0, 990.0)]]
"""
Expand All @@ -88,6 +111,28 @@ def split_region_tiles(region=None, width=100, height=100, overlap=0):
return box_list


def split_region_tiles(region=None, width=100, height=100):
"""Split a region into a list of tiles defined as (start_row, end_row, start_col, end_col).
:param region: a Region object to split
:type region: Region object
:param width: the width of tiles
:type width: int
:param height: the width of tiles
:type height: int
"""
reg = region if region else Region()
ncols = (reg.cols + width - 1) // width
nrows = (reg.rows + height - 1) // height
box_list = []
for row in range(nrows):
row_list = []
for col in range(ncols):
row_list.append(get_tile_start_end_row_col(reg, row, col, width, height))
box_list.append(row_list)
return box_list


def get_overlap_region_tiles(region=None, width=100, height=100, overlap=0):
"""Get the Bbox of the overlapped region.
Expand Down
39 changes: 39 additions & 0 deletions python/grass/pygrass/modules/tests/grass_pygrass_grid_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,3 +253,42 @@ def run_grid_module():

info = gs.raster_info("slope")
assert info["min"] > 0


@pytest.mark.parametrize(
"processes, backend",
[
(1, "RasterRow"),
(9, "RasterRow"),
(9, "r.patch"),
(10, "RasterRow"),
(10, "r.patch"),
],
)
def test_patching_error(tmp_path, processes, backend):
"""Check auto adjusted tile size based on processes"""
location = "test"
gs.core._create_location_xy(tmp_path, location) # pylint: disable=protected-access
with gs.setup.init(tmp_path / location):
gs.run_command("g.region", s=0, n=10, w=0, e=10, res=0.1)
surface = "fractal"

def run_grid_module():
# modules/shortcuts calls get_commands which requires GISBASE.
# pylint: disable=import-outside-toplevel
from grass.pygrass.modules.grid import GridModule

grid = GridModule(
"r.surf.fractal",
overlap=0,
processes=processes,
output=surface,
patch_backend=backend,
debug=True,
)
grid.run()

run_in_subprocess(run_grid_module)

info = gs.parse_command("r.univar", flags="g", map=surface)
assert int(info["null_cells"]) == 0

0 comments on commit 1bb3943

Please sign in to comment.