diff --git a/solaris/tile/raster_tile.py b/solaris/tile/raster_tile.py index 56363603..73f643db 100644 --- a/solaris/tile/raster_tile.py +++ b/solaris/tile/raster_tile.py @@ -1,6 +1,6 @@ import os import rasterio -from rasterio.warp import transform_bounds, Resampling +from rasterio.warp import Resampling, calculate_default_transform from rasterio.io import DatasetReader from rasterio.vrt import WarpedVRT from rasterio.crs import CRS @@ -26,7 +26,7 @@ class RasterTiler(object): Path to save output files to. If not specified here, this must be provided when ``Tiler.tile_generator()`` is called. src_tile_size : `tuple` of `int`s, optional - The size of the output tiles in ``(y, x)`` coordinates. By default, + The size of the input tiles in ``(y, x)`` coordinates. By default, this is in pixel units; this can be changed to metric units using the `src_metric_size` argument. src_metric_size : bool, optional @@ -55,9 +55,9 @@ class RasterTiler(object): initialization or when an image is loaded, the image bounds will be used; if provided, this value will override image metadata. tile_bounds : `list`-like - A `list`-like of ``[left, bottom, right, top]`` coordinates defining - the boundaries of the tiles to create. If not provided, they will be - generated from the `aoi_bounds` based on `src_tile_size`. + A `list`-like of ``[left, bottom, right, top]`` lists of coordinates + defining the boundaries of the tiles to create. If not provided, they + will be generated from the `aoi_bounds` based on `src_tile_size`. verbose : bool, optional Verbose text output. By default, verbose text is not printed. @@ -275,28 +275,48 @@ def tile_generator(self, src, dest_dir=None, channel_idxs=None, for tb in self.tile_bounds: # removing the following line until COG functionality implemented if True: # not self.is_cog or self.force_load_cog: - vrt = self.load_src_vrt() - window = vrt.window(*tb) + window = rasterio.windows.from_bounds( + *tb, transform=self.src.transform, + width=self.src_tile_size[1], + height=self.src_tile_size[0]) + if self.src.count != 1: - tile_data = vrt.read(window=window, - resampling=getattr(Resampling, - self.resampling), - indexes=channel_idxs) + src_data = self.src.read( + window=window, + resampling=getattr(Resampling, + self.resampling), + indexes=channel_idxs, boundless=True) else: - tile_data = vrt.read(window=window, - resampling=getattr(Resampling, - self.resampling)) - # get the affine xform between src and dest for the tile - aff_xform = transform.from_bounds(*tb, - self.dest_tile_size[1], - self.dest_tile_size[0]) + src_data = self.src.read( + window=window, + resampling=getattr(Resampling, + self.resampling), + boundless=True) + + dst_transform, width, height = calculate_default_transform( + self.src.crs, CRS.from_epsg(self.dest_crs), + self.src.width, self.src.height, *tb, + dst_height=self.dest_tile_size[0], + dst_width=self.dest_tile_size[1]) + + tile_data = np.zeros(shape=(src_data.shape[0], height, width), + dtype=src_data.dtype) + rasterio.warp.reproject( + source=src_data, + destination=tile_data, + src_transform=self.src.window_transform(window), + src_crs=self.src.crs, + dst_transform=dst_transform, + dst_crs=CRS.from_epsg(self.dest_crs), + resampling=getattr(Resampling, self.resampling)) + if self.nodata: mask = np.all(tile_data != nodata, axis=0).astype(np.uint8) * 255 elif self.alpha: - mask = vrt.read(self.alpha, window=window, - resampling=getattr(Resampling, - self.resampling)) + mask = self.src.read(self.alpha, window=window, + resampling=getattr(Resampling, + self.resampling)) else: mask = None # placeholder @@ -313,7 +333,7 @@ def tile_generator(self, src, dest_dir=None, channel_idxs=None, profile.update(width=self.dest_tile_size[1], height=self.dest_tile_size[0], crs=CRS.from_epsg(self.dest_crs), - transform=aff_xform) + transform=dst_transform) if len(tile_data.shape) == 2: # if there's no channel band profile.update(count=1) else: