Skip to content

Commit

Permalink
Make shift not inplace
Browse files Browse the repository at this point in the history
  • Loading branch information
rhugonnet committed Jan 27, 2024
1 parent 8a425d9 commit 4f07d9e
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 7 deletions.
54 changes: 50 additions & 4 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -2364,17 +2364,55 @@ def reproject(
else:
return self.from_array(data, transformed, crs, nodata)

@overload
def shift(
self: RasterType,
xoff: float,
yoff: float,
distance_unit: Literal["georeferenced"] | Literal["pixel"] = "georeferenced",
*,
inplace: Literal[False] = False,
) -> RasterType:
...

@overload
def shift(
self, xoff: float, yoff: float, distance_unit: Literal["georeferenced"] | Literal["pixel"] = "georeferenced"
self: RasterType,
xoff: float,
yoff: float,
distance_unit: Literal["georeferenced"] | Literal["pixel"] = "georeferenced",
*,
inplace: Literal[True],
) -> None:
...

@overload
def shift(
self: RasterType,
xoff: float,
yoff: float,
distance_unit: Literal["georeferenced"] | Literal["pixel"] = "georeferenced",
*,
inplace: bool = False,
) -> RasterType | None:
...

def shift(
self: RasterType,
xoff: float,
yoff: float,
distance_unit: Literal["georeferenced"] | Literal["pixel"] = "georeferenced",
inplace: bool = False,
) -> RasterType | None:
"""
Shift the raster in-place by a (x,y) offset.
Shift a raster by a (x,y) offset.
The shifting only updates the geotransform (no resampling is performed).
:param xoff: Translation x offset.
:param yoff: Translation y offset.
:param distance_unit: Distance unit, either 'georeferenced' (default) or 'pixel'.
:param inplace: Whether to modify the raster in-place.
"""
if distance_unit not in ["georeferenced", "pixel"]:
raise ValueError("Argument 'distance_unit' should be either 'pixel' or 'georeferenced'.")
Expand All @@ -2387,8 +2425,16 @@ def shift(
xoff *= self.res[0]
yoff *= self.res[1]

# Overwrite transform by shifted transform
self.transform = rio.transform.Affine(dx, b, xmin + xoff, d, dy, ymax + yoff)
shifted_transform = rio.transform.Affine(dx, b, xmin + xoff, d, dy, ymax + yoff)

if inplace:
# Overwrite transform by shifted transform
self.transform = shifted_transform
return None
else:
raster_copy = self.copy()
raster_copy.transform = shifted_transform
return raster_copy

def save(
self,
Expand Down
6 changes: 3 additions & 3 deletions tests/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -950,7 +950,7 @@ def test_getitem_setitem(self, example: str) -> None:
rst[arr.astype("uint8")]
rst[arr.astype("uint8")] = 1
# An error when the georeferencing of the Mask does not match
mask.shift(1, 1)
mask.shift(1, 1, inplace=True)
with pytest.raises(
ValueError, match="Indexing a raster with a mask requires the two being on the same georeferenced grid."
):
Expand Down Expand Up @@ -1189,7 +1189,7 @@ def test_shift(self, example: str) -> None:
orig_bounds = r.bounds

# Shift raster by georeferenced units (default)
r.shift(xoff=1, yoff=1)
r.shift(xoff=1, yoff=1, inplace=True)

# Only bounds should change
assert orig_transform.c + 1 == r.transform.c
Expand All @@ -1206,7 +1206,7 @@ def test_shift(self, example: str) -> None:
orig_transform = r.transform
orig_bounds = r.bounds
orig_res = r.res
r.shift(xoff=1, yoff=1, distance_unit="pixel")
r.shift(xoff=1, yoff=1, distance_unit="pixel", inplace=True)

# Only bounds should change
assert orig_transform.c + 1 * orig_res[0] == r.transform.c
Expand Down

0 comments on commit 4f07d9e

Please sign in to comment.