Skip to content

Commit

Permalink
Replace indexes to bands for all functions, duplicate bands propertie…
Browse files Browse the repository at this point in the history
…s into indexes for inter-operability
  • Loading branch information
rhugonnet committed Jan 26, 2024
1 parent 30ebf1b commit 5523d8b
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 42 deletions.
54 changes: 31 additions & 23 deletions geoutils/raster/raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -409,15 +409,15 @@ def __init__(
self._transform: affine.Affine | None = None
self._crs: CRS | None = None
self._nodata: int | float | None = nodata
self._indexes = bands
self._indexes_loaded: int | tuple[int, ...] | None = None
self._bands = bands
self._bands_loaded: int | tuple[int, ...] | None = None
self._masked = masked
self._out_count: int | None = None
self._out_shape: tuple[int, int] | None = None
self._disk_hash: int | None = None
self._is_modified = True
self._disk_shape: tuple[int, int, int] | None = None
self._disk_indexes: tuple[int] | None = None
self._disk_bands: tuple[int] | None = None
self._disk_dtypes: tuple[str] | None = None

# This is for Raster.from_array to work.
Expand Down Expand Up @@ -467,7 +467,7 @@ def __init__(
self.tags.update(ds.tags())

self._disk_shape = (ds.count, ds.height, ds.width)
self._disk_indexes = ds.indexes
self._disk_bands = ds.indexes
self._disk_dtypes = ds.dtypes

# Check number of bands to be loaded
Expand Down Expand Up @@ -599,26 +599,34 @@ def dtypes(self) -> tuple[str, ...]:
return (str(self.data.dtype),) * self.count

@property
def indexes_on_disk(self) -> None | tuple[int, ...]:
"""Indexes of bands on disk if it exists."""
if self._disk_indexes is not None:
return self._disk_indexes
def bands_on_disk(self) -> None | tuple[int, ...]:
"""Band indexes on disk if a file exists."""
if self._disk_bands is not None:
return self._disk_bands
return None

@property
def indexes(self) -> tuple[int, ...]:
"""Indexes of bands loaded in memory if they are, otherwise on disk."""
if self._indexes is not None and not self.is_loaded:
if isinstance(self._indexes, int):
return (self._indexes,)
return tuple(self._indexes)
def bands(self) -> tuple[int, ...]:
"""Band indexes loaded in memory if they are, otherwise on disk."""
if self._bands is not None and not self.is_loaded:
if isinstance(self._bands, int):
return (self._bands,)
return tuple(self._bands)
# if self._indexes_loaded is not None:
# if isinstance(self._indexes_loaded, int):
# return (self._indexes_loaded, )
# return tuple(self._indexes_loaded)
if self.is_loaded:
return tuple(range(1, self.count + 1))
return self.indexes_on_disk # type: ignore
return self.bands_on_disk # type: ignore

@property
def indexes(self) -> tuple[int, ...]:
"""
Band indexes (duplicate of .bands attribute, mirroring Rasterio naming "indexes").
Loaded in memory if they are, otherwise on disk.
"""
return self.bands

@property
def name(self) -> str | None:
Expand Down Expand Up @@ -650,26 +658,26 @@ def load(self, bands: int | list[int] | None = None, **kwargs: Any) -> None:

# If no index is passed, use all of them
if bands is None:
valid_indexes = self.indexes
valid_bands = self.bands
# If a new index was pass, redefine out_shape
elif isinstance(bands, (int, list)):
# Rewrite properly as a tuple
if isinstance(bands, int):
valid_indexes = (bands,)
valid_bands = (bands,)
else:
valid_indexes = tuple(bands)
valid_bands = tuple(bands)
# Update out_count if out_shape exists (when a downsampling has been passed)
if self._out_shape is not None:
self._out_count = len(valid_indexes)
self._out_count = len(valid_bands)

# Save which indexes are loaded
self._indexes_loaded = valid_indexes
# Save which bands are loaded
self._bands_loaded = valid_bands

# If a downsampled out_shape was defined during instantiation
with rio.open(self.filename) as dataset:
self.data = _load_rio(
dataset,
indexes=list(valid_indexes),
indexes=list(valid_bands),
masked=self._masked,
transform=self.transform,
shape=self.shape,
Expand Down Expand Up @@ -1991,7 +1999,7 @@ def crop(
else:
with rio.open(self.filename) as raster:
crop_img = raster.read(
indexes=self._indexes,
indexes=self._bands,
masked=self._masked,
window=final_window,
)
Expand Down
38 changes: 19 additions & 19 deletions tests/test_raster.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,8 +229,8 @@ def test_load(self) -> None:
assert r.shape == (r.height, r.width)
assert r.count == 1
assert r.count_on_disk == 1
assert r.indexes == (1,)
assert r.indexes_on_disk == (1,)
assert r.bands == (1,)
assert r.bands_on_disk == (1,)
assert np.array_equal(r.dtypes, ["uint8"])
assert r.transform == rio.transform.Affine(30.0, 0.0, 478000.0, 0.0, -30.0, 3108140.0)
assert np.array_equal(r.res, [30.0, 30.0])
Expand All @@ -247,8 +247,8 @@ def test_load(self) -> None:
assert r2.shape == (r2.height, r2.width)
assert r2.count == 1
assert r.count_on_disk == 1
assert r.indexes == (1,)
assert r.indexes_on_disk == (1,)
assert r.bands == (1,)
assert r.bands_on_disk == (1,)
assert np.array_equal(r2.dtypes, ["float32"])
assert r2.transform == rio.transform.Affine(30.0, 0.0, 627175.0, 0.0, -30.0, 4852085.0)
assert np.array_equal(r2.res, [30.0, 30.0])
Expand All @@ -260,29 +260,29 @@ def test_load(self) -> None:
assert r.is_loaded
assert r.count == 1
assert r.count_on_disk == 1
assert r.indexes == (1,)
assert r.indexes_on_disk == (1,)
assert r.bands == (1,)
assert r.bands_on_disk == (1,)
assert r.data.shape == (r.height, r.width)

# Test 3 - single band, loading data
r = gu.Raster(self.landsat_b4_path, load_data=True)
assert r.is_loaded
assert r.count == 1
assert r.count_on_disk == 1
assert r.indexes == (1,)
assert r.indexes_on_disk == (1,)
assert r.bands == (1,)
assert r.bands_on_disk == (1,)
assert r.data.shape == (r.height, r.width)

# Test 4 - multiple bands, load all bands
r = gu.Raster(self.landsat_rgb_path, load_data=True)
assert r.count == 3
assert r.count_on_disk == 3
assert r.indexes == (
assert r.bands == (
1,
2,
3,
)
assert r.indexes_on_disk == (
assert r.bands_on_disk == (
1,
2,
3,
Expand All @@ -293,34 +293,34 @@ def test_load(self) -> None:
r = gu.Raster(self.landsat_rgb_path, load_data=True, bands=1)
assert r.count == 1
assert r.count_on_disk == 3
assert r.indexes == (1,)
assert r.indexes_on_disk == (1, 2, 3)
assert r.bands == (1,)
assert r.bands_on_disk == (1, 2, 3)
assert r.data.shape == (r.height, r.width)

# Test 6 - multiple bands, load a list of bands
r = gu.Raster(self.landsat_rgb_path, load_data=True, bands=[2, 3])
assert r.count == 2
assert r.count_on_disk == 3
assert r.indexes == (1, 2)
assert r.indexes_on_disk == (1, 2, 3)
assert r.bands == (1, 2)
assert r.bands_on_disk == (1, 2, 3)
assert r.data.shape == (r.count, r.height, r.width)

# Test 7 - load a single band a posteriori calling load()
r = gu.Raster(self.landsat_rgb_path)
r.load(bands=1)
assert r.count == 1
assert r.count_on_disk == 3
assert r.indexes == (1,)
assert r.indexes_on_disk == (1, 2, 3)
assert r.bands == (1,)
assert r.bands_on_disk == (1, 2, 3)
assert r.data.shape == (r.height, r.width)

# Test 8 - load a list of band a posteriori calling load()
r = gu.Raster(self.landsat_rgb_path)
r.load(bands=[2, 3])
assert r.count == 2
assert r.count_on_disk == 3
assert r.indexes == (1, 2)
assert r.indexes_on_disk == (1, 2, 3)
assert r.bands == (1, 2)
assert r.bands_on_disk == (1, 2, 3)
assert r.data.shape == (r.count, r.height, r.width)

# Check that errors are raised when appropriate
Expand Down Expand Up @@ -1776,7 +1776,7 @@ def test_value_at_coords(self) -> None:
assert z_val == z_val_2

# 2/ Band argument
# Get the indexes for the multi-band Raster
# Get the band indexes for the multi-band Raster
r_multi = gu.Raster(self.landsat_rgb_path)
itest, jtest = r_multi.xy2ij(xtest0, ytest0)
itest = int(itest[0])
Expand Down

0 comments on commit 5523d8b

Please sign in to comment.