Skip to content

Commit

Permalink
Make raster_equal accept False mask values of masked arrays (#468)
Browse files Browse the repository at this point in the history
  • Loading branch information
rhugonnet committed Feb 4, 2024
1 parent cde8ea6 commit 4411f1b
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 15 deletions.
23 changes: 17 additions & 6 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ def __setitem__(self, index: Mask | NDArrayBool | Any, assign: NDArrayNum | Numb
self._data[:, ind] = assign # type: ignore
return None

def raster_equal(self, other: object) -> bool:
def raster_equal(self, other: RasterType) -> bool:
"""
Check if two rasters are equal.
Expand All @@ -986,12 +986,23 @@ def raster_equal(self, other: object) -> bool:
- The raster's transform, crs and nodata values.
"""

# If the mask is just "False", it is equivalent to being equal to an array of False
if isinstance(self.data.mask, np.bool_):
self_mask = np.zeros(np.shape(self.data), dtype=bool)
else:
self_mask = self.data.mask

if isinstance(other.data.mask, np.bool_):
other_mask = np.zeros(np.shape(other.data), dtype=bool)
else:
other_mask = other.data.mask

if not isinstance(other, Raster): # TODO: Possibly add equals to SatelliteImage?
raise NotImplementedError("Equality with other object than Raster not supported by raster_equal.")
return all(
[
np.array_equal(self.data.data, other.data.data, equal_nan=True),
np.array_equal(self.data.mask, other.data.mask),
np.array_equal(self_mask, other_mask),
self.data.fill_value == other.data.fill_value,
self.data.dtype == other.data.dtype,
self.transform == other.transform,
Expand Down Expand Up @@ -3585,14 +3596,14 @@ def __init__(
)
self._data = self.data[0, :, :]

# Convert masked array to boolean
self._data = self.data.astype(bool) # type: ignore
# Force dtypes
self._dtypes = (bool,)

# Fix nodata to None
self._nodata = None

# Define in dtypes
self._dtypes = (bool,)
# Convert masked array to boolean
self._data = self.data.astype(bool) # type: ignore

def __repr__(self) -> str:
"""Convert mask to string representation."""
Expand Down
28 changes: 19 additions & 9 deletions tests/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2967,12 +2967,15 @@ def test_reproject(self, mask: gu.Mask) -> None:
# Test 1: with a classic resampling (bilinear)

# Reproject mask - resample to 100 x 100 grid
mask_orig = mask.copy()
mask_reproj = mask.reproject(grid_size=(100, 100), force_source_nodata=2)

# Check instance is respected
assert isinstance(mask_reproj, gu.Mask)
# Check the dtype of the original mask was properly reconverted
assert mask.data.dtype == bool
# Check the original mask was not modified during reprojection
assert mask_orig.raster_equal(mask)

# Check inplace behaviour works
mask_tmp = mask.copy()
Expand All @@ -2998,6 +3001,8 @@ def test_reproject(self, mask: gu.Mask) -> None:
@pytest.mark.parametrize("mask", [mask_landsat_b4, mask_aster_dem, mask_everest]) # type: ignore
def test_crop(self, mask: gu.Mask) -> None:
# Test with same bounds -> should be the same #

mask_orig = mask.copy()
crop_geom = mask.bounds
mask_cropped = mask.crop(crop_geom)
assert mask_cropped.raster_equal(mask)
Expand All @@ -3006,6 +3011,8 @@ def test_crop(self, mask: gu.Mask) -> None:
assert isinstance(mask_cropped, gu.Mask)
# Check the dtype of the original mask was properly reconverted
assert mask.data.dtype == bool
# Check the original mask was not modified during cropping
assert mask_orig.raster_equal(mask)

# Check inplace behaviour works
mask_tmp = mask.copy()
Expand Down Expand Up @@ -3061,10 +3068,14 @@ def test_crop(self, mask: gu.Mask) -> None:

@pytest.mark.parametrize("mask", [mask_landsat_b4, mask_aster_dem, mask_everest]) # type: ignore
def test_polygonize(self, mask: gu.Mask) -> None:

mask_orig = mask.copy()
# Run default
vect = mask.polygonize()
# Check the dtype of the original mask was properly reconverted
assert mask.data.dtype == bool
# Check the original mask was not modified during polygonizing
assert mask_orig.raster_equal(mask)

# Check the output is cast into a vector
assert isinstance(vect, gu.Vector)
Expand All @@ -3079,10 +3090,14 @@ def test_polygonize(self, mask: gu.Mask) -> None:

@pytest.mark.parametrize("mask", [mask_landsat_b4, mask_aster_dem, mask_everest]) # type: ignore
def test_proximity(self, mask: gu.Mask) -> None:

mask_orig = mask.copy()
# Run default
rast = mask.proximity()
# Check the dtype of the original mask was properly reconverted
assert mask.data.dtype == bool
# Check the original mask was not modified during reprojection
assert mask_orig.raster_equal(mask)

# Check that output is cast back into a raster
assert isinstance(rast, gu.Raster)
Expand All @@ -3101,17 +3116,12 @@ def test_save(self, mask: gu.Mask) -> None:
mask.save(temp_file)
saved = gu.Mask(temp_file)

# TODO: Generalize raster_equal for masks?
# A raster (or mask) in-memory has more information than on disk, we need to update it before checking equality
# The values in its .data.data that are masked in .data.mask are not necessarily equal to the nodata value
mask.data.data[mask.data.mask] = True # The default nodata 255 is converted to boolean True on masked values

# Check all attributes are equal
assert all(
[
np.ma.allequal(saved.data, mask.data),
saved.transform == mask.transform,
saved.crs == mask.crs,
saved.nodata == mask.nodata,
]
)
assert mask.raster_equal(saved)

# Clean up temporary folder - fails on Windows
try:
Expand Down

0 comments on commit 4411f1b

Please sign in to comment.