Skip to content

Commit

Permalink
Merge 55da804 into ae96994
Browse files Browse the repository at this point in the history
  • Loading branch information
rhugonnet committed Mar 30, 2024
2 parents ae96994 + 55da804 commit 6a32b79
Show file tree
Hide file tree
Showing 3 changed files with 89 additions and 28 deletions.
76 changes: 52 additions & 24 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,14 +390,29 @@ def _cast_pixel_interpretation(
return area_or_point_out


def _cast_nodata(out_dtype: DTypeLike, nodata: int | float | None) -> int | float | None:
"""
Cast nodata value for output data type to default nodata if incompatible.
:param out_dtype: Dtype of output array.
:param nodata: Nodata value.
:return: Cast nodata value.
"""

if out_dtype == bool:
nodata = None
if nodata is not None and not rio.dtypes.can_cast_dtype(nodata, out_dtype):
nodata = _default_nodata(out_dtype)
else:
nodata = nodata

return nodata


def _cast_numeric_array_raster(
raster: RasterType, other: RasterType | NDArrayNum | Number, operation_name: str
) -> tuple[
MArrayNum,
MArrayNum | NDArrayNum | Number,
float | int | tuple[int, ...] | tuple[float, ...] | None,
Literal["Area", "Point"] | None,
]:
) -> tuple[MArrayNum, MArrayNum | NDArrayNum | Number, float | int | None, Literal["Area", "Point"] | None]:
"""
Cast a raster and another raster or array or number to arrays with proper metadata, or raise an error message.
Expand Down Expand Up @@ -562,10 +577,19 @@ def __init__(

# This is for Raster.from_array to work.
if isinstance(filename_or_dataset, dict):

# To have "area_or_point" user input go through checks of the set() function without shifting the transform
self.set_area_or_point(filename_or_dataset["area_or_point"], shift_area_or_point=False)
# Same things here, and also important to pass the nodata before the data setter, which uses it in turn
self._nodata = filename_or_dataset["nodata"]

# Need to set nodata before the data setter, which uses it
# We trick set_nodata into knowing the data type by setting self._disk_dtype, then unsetting it
# (as a raster created from an array doesn't have a disk dtype)
if np.dtype(filename_or_dataset["data"].dtype) != bool: # Exception for Mask class
self._disk_dtype = filename_or_dataset["data"].dtype
self.set_nodata(filename_or_dataset["nodata"], update_array=False, update_mask=False)
self._disk_dtype = None

# Then, we can set the data, transform and crs
self.data = filename_or_dataset["data"]
self.transform: rio.transform.Affine = filename_or_dataset["transform"]
self.crs: rio.crs.CRS = filename_or_dataset["crs"]
Expand Down Expand Up @@ -1010,9 +1034,10 @@ def from_array(
data: NDArrayNum | MArrayNum | NDArrayBool,
transform: tuple[float, ...] | Affine,
crs: CRS | int | None,
nodata: int | float | tuple[int, ...] | tuple[float, ...] | None = None,
nodata: int | float | None = None,
area_or_point: Literal["Area", "Point"] | None = None,
tags: dict[str, Any] = None,
cast_nodata: bool = True,
) -> RasterType:
"""Create a raster from a numpy array and the georeferencing information.
Expand All @@ -1021,11 +1046,12 @@ def from_array(
:param data: Input array, 2D for single band or 3D for multi-band (bands should be first axis).
:param transform: Affine 2D transform. Either a tuple(x_res, 0.0, top_left_x,
0.0, y_res, top_left_y) or an affine.Affine object.
:param crs: Coordinate reference system. Either a rasterio CRS,
or an EPSG integer.
:param crs: Coordinate reference system. Any CRS supported by Pyproj (e.g., CRS object, EPSG integer).
:param nodata: Nodata value.
:param area_or_point: Pixel interpretation of the raster, will be stored in AREA_OR_POINT metadata.
:param tags: Metadata stored in a dictionary.
:param cast_nodata: Automatically cast nodata value to the default nodata for the new array type if not
compatible. If False, will raise an error when incompatible.
:returns: Raster created from the provided array and georeferencing.
Expand All @@ -1038,20 +1064,14 @@ def from_array(
>>> transform = (30.0, 0.0, 478000.0, 0.0, -30.0, 3108140.0)
>>> myim = Raster.from_array(data, transform, 32645)
"""
if not isinstance(transform, Affine):
if isinstance(transform, tuple):
transform = Affine(*transform)
else:
raise TypeError("The transform argument needs to be Affine or tuple.")

# Enable shortcut to create CRS from an EPSG ID.
if isinstance(crs, int):
crs = CRS.from_epsg(crs)

# Define tags as empty dictionary if not defined
if tags is None:
tags = {}

# Cast nodata if the new array has incompatible type with the old nodata value
if cast_nodata:
nodata = _cast_nodata(data.dtype, nodata)

# If the data was transformed into boolean, re-initialize as a Mask subclass
# Typing: we can specify this behaviour in @overload once we add the NumPy plugin of MyPy
if data.dtype == bool:
Expand Down Expand Up @@ -1729,7 +1749,7 @@ def set_nodata(

if new_nodata is not None:
if not rio.dtypes.can_cast_dtype(new_nodata, self.dtype):
raise ValueError(f"nodata value {new_nodata} incompatible with self.dtype {self.dtype}")
raise ValueError(f"Nodata value {new_nodata} incompatible with self.dtype {self.dtype}.")

# If we update mask or array, get the masked array
if update_array or update_mask:
Expand Down Expand Up @@ -1787,7 +1807,10 @@ def set_nodata(

# Update the nodata value
self._nodata = new_nodata
self.data.fill_value = new_nodata

# Update the fill value only if the data is loaded
if self.is_loaded:
self.data.fill_value = new_nodata

@property
def data(self) -> MArrayNum:
Expand Down Expand Up @@ -2061,26 +2084,31 @@ def info(self, stats: bool = False, verbose: bool = True) -> None | str:
else:
return "".join(as_str)

def copy(self: RasterType, new_array: NDArrayNum | None = None) -> RasterType:
def copy(self: RasterType, new_array: NDArrayNum | None = None, cast_nodata: bool = True) -> RasterType:
    """
    Copy the raster in-memory.

    :param new_array: New array to use in the copied raster.
    :param cast_nodata: Automatically cast nodata value to the default nodata for the new array type if not
        compatible. If False, will raise an error when incompatible.

    :return: Copy of the raster.
    """
    # Use the provided array as is, otherwise duplicate the data of this raster
    copied_data = self.data.copy() if new_array is None else new_array

    # Rebuild a raster through from_array, carrying over georeferencing, nodata, pixel interpretation and tags
    return self.from_array(
        data=copied_data,
        transform=self.transform,
        crs=self.crs,
        nodata=self.nodata,
        area_or_point=self.area_or_point,
        tags=self.tags,
        cast_nodata=cast_nodata,
    )
Expand Down
4 changes: 2 additions & 2 deletions geoutils/raster/satimg.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,8 +428,8 @@ def __parse_metadata_from_file(self, fn_meta: str | None) -> None:

return None

def copy(self, new_array: NDArrayNum | None = None) -> SatelliteImage:
new_satimg = super().copy(new_array=new_array) # type: ignore
def copy(self, new_array: NDArrayNum | None = None, cast_nodata: bool = False) -> SatelliteImage:
new_satimg = super().copy(new_array=new_array, cast_nodata=cast_nodata) # type: ignore
# all objects here are immutable so no need for a copy method (string and datetime)
# satimg_attrs = ['satellite', 'sensor', 'product', 'version', 'tile_name', 'datetime'] #taken outside of class
for attrs in satimg_attrs:
Expand Down
37 changes: 35 additions & 2 deletions tests/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -1061,14 +1061,32 @@ def test_copy(self, example: str) -> None:
assert r.raster_equal(r2)

# -- Fifth test: check that the new_array argument works when providing a new dtype ##
# For an integer dataset cast to float, or opposite (the exploradores dataset will cast from float to int)
if "int" in r.dtype:
new_dtype = "float32"
else:
new_dtype = "uint8"
r2 = r.copy(new_array=r_arr.astype(dtype=new_dtype))

# This should work for all the types by default due to automatic casting
with warnings.catch_warnings():
warnings.filterwarnings("ignore", message="Unmasked values equal to the nodata value*")
r2 = r.copy(new_array=r_arr.astype(dtype=new_dtype))
assert r2.dtype == new_dtype

# However, the new nodata will differ if casting was done
if np.promote_types(r.dtype, new_dtype) != new_dtype:
assert r2.nodata != r.nodata
else:
assert r2.nodata == r.nodata

# The copy should fail if the data type is not compatible
if np.promote_types(r.dtype, new_dtype) != new_dtype:
with pytest.raises(ValueError, match="Nodata value *"):
r.copy(new_array=r_arr.astype(dtype=new_dtype), cast_nodata=False)
else:
r2 = r.copy(new_array=r_arr.astype(dtype=new_dtype), cast_nodata=False)
assert r2.dtype == new_dtype

@pytest.mark.parametrize("example", [landsat_b4_path, aster_dem_path]) # type: ignore
def test_is_modified(self, example: str) -> None:
"""
Expand Down Expand Up @@ -2559,7 +2577,7 @@ def test_set_nodata(self, example: str) -> None:
r.set_nodata(new_nodata="this_should_not_work") # type: ignore

# A ValueError if nodata value is incompatible with dtype
expected_message = r"nodata value .* incompatible with self.dtype .*"
expected_message = r"Nodata value .* incompatible with self.dtype .*"
if "int" in r.dtype:
with pytest.raises(ValueError, match=expected_message):
# Feed a floating numeric to an integer type
Expand Down Expand Up @@ -2927,6 +2945,21 @@ def test_from_array(self, example: str) -> None:
with pytest.raises(TypeError, match="The transform argument needs to be Affine or tuple."):
gu.Raster.from_array(data=img.data, transform="lol", crs=None, nodata=None) # type: ignore

def test_from_array__nodata_casting(self) -> None:
    """Check the nodata casting of from_array, which affects all other functionalities (copy, etc)."""

    rst = gu.Raster(self.landsat_b4_path)

    # Scope the warning filter to this test: calling filterwarnings outside of a catch_warnings context would
    # leave the "ignore" filter installed for every test running afterwards
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", message="New nodata value cells already exist*")
        rst.set_nodata(255)

        # Check that a not-compatible nodata will raise an error if casting is not true
        with pytest.raises(ValueError, match="Nodata value*"):
            rst.from_array(data=rst.data, crs=rst.crs, transform=rst.transform, nodata=-99999, cast_nodata=False)

        # Otherwise it is re-cast automatically
        rst2 = rst.from_array(data=rst.data, crs=rst.crs, transform=rst.transform, nodata=-99999)
        assert rst2.nodata == _default_nodata(rst.data.dtype)

def test_type_hints(self) -> None:
"""Test that pylint doesn't raise errors on valid code."""
# Create a temporary directory and a temporary filename
Expand Down

0 comments on commit 6a32b79

Please sign in to comment.