Skip to content

Commit

Permalink
Merge pull request #123 from OpenDataAnalytics/crop-rgb
Browse files Browse the repository at this point in the history
Crop rgb
  • Loading branch information
johnkit committed Mar 14, 2019
2 parents 6cbd268 + 7f0ec92 commit cabd474
Show file tree
Hide file tree
Showing 7 changed files with 128 additions and 66 deletions.
19 changes: 14 additions & 5 deletions gaia/geo/gdal_functions.py
Expand Up @@ -231,8 +231,9 @@ def world_to_pixel(geoMatrix, x, y):
ulX = geoMatrix[0]
ulY = geoMatrix[3]
xDist = geoMatrix[1]
yDist = geoMatrix[5]
pixel = int((x - ulX) / xDist)
line = int((ulY - y) / xDist)
line = int((y - ulY) / yDist)
return (pixel, line)

src_image = get_dataset(raster_input)
Expand Down Expand Up @@ -264,7 +265,10 @@ def world_to_pixel(geoMatrix, x, y):
px_width = int(lr_x - ul_x)
px_height = int(lr_y - ul_y)

clip = src_array[ul_y:lr_y, ul_x:lr_x]
if raster_input.RasterCount == 1:
clip = src_array[ul_y:lr_y, ul_x:lr_x]
else:
clip = src_array[:, ul_y:lr_y, ul_x:lr_x]

# create pixel offset to pass to new image Projection info
xoffset = ul_x
Expand Down Expand Up @@ -295,14 +299,19 @@ def world_to_pixel(geoMatrix, x, y):
mask = image_to_array(raster_poly)

# Clip the image using the mask
clip = gdalnumeric.numpy.choose(
mask, (clip, nodata_value)).astype(src_dtype)
if raster_input.RasterCount == 1:
clip = gdalnumeric.numpy.choose(
mask, (clip, nodata_value)).astype(src_dtype)
else:
for i in range(raster_input.RasterCount):
clip[i] = gdalnumeric.numpy.choose(
mask, (clip[i], nodata_value)).astype(src_dtype)

# create output raster
raster_band = raster_input.GetRasterBand(1)
output_driver = gdal.GetDriverByName('MEM')
output_dataset = output_driver.Create(
'', clip.shape[1], clip.shape[0],
'', clip.shape[-1], clip.shape[-2],
raster_input.RasterCount, raster_band.DataType)
output_dataset.SetGeoTransform(geo_trans)
output_dataset.SetProjection(raster_input.GetProjection())
Expand Down
41 changes: 40 additions & 1 deletion gaia/io/gdal_reader.py
Expand Up @@ -87,7 +87,46 @@ def read(self, format=formats.RASTER, epsg=None, as_numpy_array=False,
return o

def load_metadata(self, dataObject):
self.__read_internal(dataObject)
# self.__read_internal(dataObject)
data = dataObject.get_data()

# Get corner points
gt = data.GetGeoTransform()
if gt is None:
raise Exception(
'Cannot compute corners - dataset has no geo transform')
num_cols = data.RasterXSize
num_rows = data.RasterYSize
corners = list()
for px in [0, num_cols]:
for py in [0, num_rows]:
x = gt[0] + px*gt[1] + py*gt[2]
y = gt[3] + px*gt[4] + py*gt[5]
corners.append([x, y])

# if as_lonlat:
# spatial_ref = osr.SpatialReference()
# spatial_ref.ImportFromWkt(self.get_wkt_string())
# corners = self._convert_to_lonlat(corners, spatial_ref)

xvals = [c[0] for c in corners]
yvals = [c[1] for c in corners]
xmin = min(xvals)
ymin = min(yvals)
xmax = max(xvals)
ymax = max(yvals)
coords = [[
[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]
]]
metadata = {
'bounds': {
'coordinates': coords
},
'height': data.RasterYSize,
'width': data.RasterXSize
}
# print('metadata: {}'.format(metadata))
dataObject.set_metadata(metadata)

def load_data(self, dataObject):
self.__read_internal(dataObject)
Expand Down
63 changes: 41 additions & 22 deletions gaia/io/geojson_reader.py
Expand Up @@ -5,6 +5,7 @@

import re
from six import string_types
import geojson
import geopandas

from gaia.io.readers import GaiaReader
Expand All @@ -26,22 +27,29 @@ class GaiaGeoJSONReader(GaiaReader):
"""
epsgRegex = re.compile('epsg:([\d]+)')

def __init__(self, url, *args, **kwargs):
def __init__(self, data_source, *args, **kwargs):
super(GaiaGeoJSONReader, self).__init__(*args, **kwargs)

self.uri = url
self.ext = '.%s' % get_uri_extension(self.uri)
self.geojson_object = None
self.uri = None
self.ext = None

if isinstance(data_source, string_types):
self.uri = data_source
self.ext = '.%s' % get_uri_extension(self.uri)
elif isinstance(data_source, geojson.GeoJSON):
self.geojson_object = data_source

@staticmethod
def can_read(url, *args, **kwargs):
# Todo update for girder-hosted files
if not isinstance(url, string_types):
def can_read(data_source, *args, **kwargs):
if isinstance(data_source, string_types):
# Check string for a supported filename/url
extension = '.{}'.format(get_uri_extension(data_source))
if extension in formats.VECTOR:
return True
return False

extension = '.{}'.format(get_uri_extension(url))
if extension in formats.VECTOR:
elif isinstance(data_source, geojson.GeoJSON):
return True
return False

def read(self, format=None, epsg=None):
return super().read(format, epsg)
Expand All @@ -57,15 +65,28 @@ def __read_internal(self, dataObject):
# if not self.format:
# self.format = self.default_output

# FIXME: Should this check actually go into the can_read method?
if self.ext not in formats.VECTOR:
raise UnsupportedFormatException(
"Only the following vector formats are supported: {}".format(
','.join(formats.VECTOR)
)
)

data = geopandas.read_file(self.uri)
if self.uri:
if self.ext not in formats.VECTOR:
tpl = "Only the following vector formats are supported: {}"
msg = tpl.format(','.join(formats.VECTOR))
raise UnsupportedFormatException(msg)
data = geopandas.read_file(self.uri)

elif self.geojson_object:
if isinstance(self.geojson_object, geojson.geometry.Geometry):
feature = geojson.Feature(geometry=self.geojson_object)
features = geojson.FeatureCollection([feature])
elif isinstance(self.geojson_object, geojson.Feature):
features = geojson.FeatureCollection([self.geojson_object])
elif isinstance(self.geojson_object, geojson.FeatureCollection):
features = self.geojson_object
else:
raise UnsupportedFormatException(
'Unrecognized geojson object {}'.self.geojson_object)

# For now, hard code crs to lat-lon
data = geopandas.GeoDataFrame.from_features(
features, crs=dict(init='epsg:4326'))

# FIXME: still need to handle filtering
# if self.filters:
Expand All @@ -74,9 +95,6 @@ def __read_internal(self, dataObject):
# FIXME: skipped the transformation step for now
# return self.transform_data(format, epsg)

# do the actual reading and set both data and metadata
# on the dataObject parameter

# Initialize metadata
metadata = dict()

Expand All @@ -97,6 +115,7 @@ def __read_internal(self, dataObject):

dataObject.set_data(data)
epsgString = data.crs['init']

m = self.epsgRegex.search(epsgString)
if m:
dataObject._epsg = int(m.group(1))
Expand Down
43 changes: 5 additions & 38 deletions gaia/preprocess/gdal_processes.py
Expand Up @@ -8,6 +8,7 @@
from gaia.validators import validate_subset
from gaia.process_registry import register_process
from gaia.geo.gdal_functions import gdal_clip
from gaia.io.gdal_reader import GaiaGDALReader
import gaia.types


Expand Down Expand Up @@ -41,47 +42,13 @@ def compute_subset_gdal(inputs=[], args=[]):
# not to write the output dataset to a tiff file on disk
output_dataset = gdal_clip(raster_img, None, clip_json)

# Copy data to new GDALDataObject
outputDataObject = GDALDataObject()
outputDataObject.set_data(output_dataset)
outputDataObject._datatype = gaia.types.RASTER

# print('input meta: {}'.format(raster.get_metadata()))

# Get corner points
gt = output_dataset.GetGeoTransform()
if gt is None:
raise Exception('Cannot compute corners - dataset has no geo transform')
num_cols = output_dataset.RasterXSize
num_rows = output_dataset.RasterYSize
corners = list()
for px in [0, num_cols]:
for py in [0, num_rows]:
x = gt[0] + px*gt[1] + py*gt[2]
y = gt[3] + px*gt[4] + py*gt[5]
corners.append([x, y])

# if as_lonlat:
# spatial_ref = osr.SpatialReference()
# spatial_ref.ImportFromWkt(self.get_wkt_string())
# corners = self._convert_to_lonlat(corners, spatial_ref)

xvals = [c[0] for c in corners]
yvals = [c[1] for c in corners]
xmin = min(xvals)
ymin = min(yvals)
xmax = max(xvals)
ymax = max(yvals)
coords = [[
[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]
]]
metadata = {
'bounds': {
'coordinates': coords
},
'height': output_dataset.RasterYSize,
'width': output_dataset.RasterXSize
}
print('metadata: {}'.format(metadata))
outputDataObject.set_metadata(metadata)
# Instantiate temporary reader to (only) parse metadata
reader = GaiaGDALReader('internal.tif')
reader.load_metadata(outputDataObject)

return outputDataObject
1 change: 1 addition & 0 deletions requirements.txt
Expand Up @@ -26,3 +26,4 @@ geoalchemy2>=0.2.6
rasterio>=0.36.0
pyOpenSSL>=17.0.0
girder-client>=2.4.0
geojson>=2.0.0
27 changes: 27 additions & 0 deletions tests/cases/test_processes.py
Expand Up @@ -20,6 +20,9 @@
import json
import unittest
from zipfile import ZipFile

import geojson

import gaia
from gaia.preprocess import crop
from gaia.io import readers
Expand Down Expand Up @@ -72,3 +75,27 @@ def test_crop_gdal(self):
testfile = os.path.join(testfile_path, '2states.geojson')
if os.path.exists(testfile):
os.remove(testfile)

def test_crop_rgb(self):
"""
Test cropping raster data with RGB bands
"""
input_path = os.path.join(testfile_path, 'simplergb.tif')
input_raster = gaia.create(input_path)

# Generate crop geometry from raster bounds
bounds = input_raster.get_metadata().get('bounds').get('coordinates')
bounds = bounds[0]
x = (bounds[0][0] + bounds[2][0]) / 2.0
y = (bounds[0][1] + bounds[2][1]) / 2.0

dx = 0.12 * (bounds[2][0] - bounds[0][0])
dy = 0.16 * (bounds[2][1] - bounds[0][1])
poly = [[
[x, y], [x+dx, y+dy], [x-dx, y+dy], [x-dx, y-dy], [x+dx, y-dy]
]]
geometry = geojson.Polygon(poly)
crop_geom = gaia.create(geometry)

cropped_raster = crop(input_raster, crop_geom)
self.assertIsNotNone(cropped_raster)
Binary file added tests/data/simplergb.tif
Binary file not shown.

0 comments on commit cabd474

Please sign in to comment.