Skip to content

Commit

Permalink
prediction on a different extent finally working?
Browse files Browse the repository at this point in the history
  • Loading branch information
basaks committed Feb 9, 2023
1 parent a055c1d commit 4594183
Show file tree
Hide file tree
Showing 4 changed files with 51 additions and 29 deletions.
14 changes: 7 additions & 7 deletions configs/ref_rf.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,19 @@ features:
- path: configs/data/sirsam/k_15v5.tif
- path: configs/data/sirsam/relief_apsect.tif
transforms:
- standardise
- whiten:
keep_fraction: 0.8
# - standardise
# - whiten:
# keep_fraction: 0.8
imputation: none

preprocessing:
imputation: none
transforms:
- whiten:
keep_fraction: 0.8
# - whiten:
# keep_fraction: 0.8

targets:
file: configs/data/geochem_sites.shp
file: configs/data/geochem_sites_cropped.shp
property: K_ppm_imp
# group_targets:
# groups_eps: 0.09
Expand Down Expand Up @@ -71,7 +71,7 @@ learning:


prediction:
# prediction_template: configs/data/sirsam/dem_foc2.tif
prediction_template: configs/data/sirsam/dem_foc2.tif
quantiles: 0.95
outbands: 4

Expand Down
6 changes: 4 additions & 2 deletions uncoverml/features.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from typing import Optional
from collections import OrderedDict
import numpy as np
import pickle
Expand All @@ -16,10 +17,11 @@
log = logging.getLogger(__name__)


def extract_subchunks(image_source: RasterioImageSource, subchunk_index, n_subchunks, patchsize):
def extract_subchunks(image_source: RasterioImageSource, subchunk_index, n_subchunks, patchsize,
template_source: Optional[RasterioImageSource] = None):
equiv_chunks = n_subchunks * mpiops.chunks
equiv_chunk_index = mpiops.chunks*subchunk_index + mpiops.chunk_index
image = Image(image_source, equiv_chunk_index, equiv_chunks, patchsize)
image = Image(image_source, equiv_chunk_index, equiv_chunks, patchsize, template_source)
x = patch.all_patches(image, patchsize)
return x

Expand Down
22 changes: 13 additions & 9 deletions uncoverml/geoio.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,15 +79,12 @@ def crs(self):

class RasterioImageSource(ImageSource):

def __init__(self, filename, template_filename: Optional[str] = None):
def __init__(self, filename):

self.filename = filename
assert os.path.isfile(filename), '{} does not exist'.format(filename)

template_geotiff = rasterio.open(template_filename, 'r') if template_filename else None

with rasterio.open(self.filename, 'r') as geotiff:
tiff = geotiff if template_geotiff is None else template_geotiff

self._full_res = (geotiff.width, geotiff.height, geotiff.count)
self._nodata_value = geotiff.meta['nodata']
Expand All @@ -99,7 +96,7 @@ def __init__(self, filename, template_filename: Optional[str] = None):
self._dtype = np.dtype(geotiff.dtypes[0])
self._crs = geotiff.crs

A = tiff.transform
A = geotiff.transform
# No shearing or rotation allowed!!
if not ((A[1] == 0) and (A[3] == 0)):
raise RuntimeError("Transform to pixel coordinates"
Expand Down Expand Up @@ -297,7 +294,7 @@ def get_image_spec(model, config: Config):


def get_image_spec_from_nchannels(nchannels, config: Config):
if config.prediction_template:
if config.prediction_template and config.is_prediction:
imagelike = Path(config.prediction_template).absolute()
else:
imagelike = config.feature_sets[0].files[0]
Expand Down Expand Up @@ -447,15 +444,17 @@ def feature_names(config: Config):
return results


def _iterate_sources(f, config):
def _iterate_sources(f, config: Config):

results = []
template_tif = config.prediction_template if config.is_prediction else None
if config.is_prediction:
log.info(f"Using prediction template {config.prediction_template}")
for s in config.feature_sets:
extracted_chunks = {}
for tif in s.files:
name = os.path.abspath(tif)
image_source = RasterioImageSource(tif, template_filename=template_tif)
image_source = RasterioImageSource(tif)
x = f(image_source)
log_missing_percentage(name, x)
extracted_chunks[name] = x
Expand Down Expand Up @@ -494,7 +493,12 @@ def image_subchunks(subchunk_index, config: Config):
"""This is used in prediction only"""

def f(image_source: RasterioImageSource):
r = features.extract_subchunks(image_source, subchunk_index, config.n_subchunks, config.patchsize)
if config.is_prediction and config.prediction_template is not None:
template_source = RasterioImageSource(config.prediction_template)
else:
template_source = None
r = features.extract_subchunks(image_source, subchunk_index, config.n_subchunks, config.patchsize,
template_source=template_source)
return r
result = _iterate_sources(f, config)
return result
Expand Down
38 changes: 27 additions & 11 deletions uncoverml/image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Optional
import numpy as np
import logging

Expand Down Expand Up @@ -28,7 +29,10 @@ def construct_splits(npixels, nchunks, overlap=0):


class Image:
def __init__(self, source: RasterioImageSource, chunk_idx=0, nchunks=1, overlap=0):
def __init__(self, source: RasterioImageSource,
chunk_idx=0, nchunks=1, overlap=0,
t_source: Optional[RasterioImageSource] = None,
):
assert chunk_idx >= 0 and chunk_idx < nchunks

if nchunks == 1 and overlap != 0:
Expand All @@ -49,6 +53,10 @@ def __init__(self, source: RasterioImageSource, chunk_idx=0, nchunks=1, overlap=
self.pixsize_x = source.pixsize_x
self.pixsize_y = source.pixsize_y
self.crs = source.crs
log.debug(f"Image full resolution : {source.full_resolution}")
log.debug(f"Image origin longitude : {source.origin_longitude}")
log.debug(f"Image origin latitude : {source.origin_latitude}")

assert self.pixsize_x > 0
assert self.pixsize_y > 0

Expand All @@ -65,21 +73,29 @@ def __init__(self, source: RasterioImageSource, chunk_idx=0, nchunks=1, overlap=
self._pix_y_to_coords = dict(zip(pix_y, coords_y))

# exclusive y range of this chunk in full image
ymin, ymax = construct_splits(self._full_res[1], nchunks, overlap)[chunk_idx]
self._offset = np.array([0, ymin], dtype=int)
# exclusive x range of this chunk (same for all chunks)
xmin, xmax = 0, self._full_res[0]
if t_source:
self._t_full_res = t_source.full_resolution
self._t_start_lon = t_source.origin_longitude
self._t_start_lat = t_source.origin_latitude
self._t_pixsize_x = t_source.pixsize_x
self._t_pixsize_y = t_source.pixsize_y
tymin, tymax = construct_splits(self._t_full_res[1], nchunks, overlap)[chunk_idx]
t_x_offset = np.searchsorted(self._coords_x, self._t_start_lon)
t_y_offset = np.searchsorted(self._coords_y, self._t_start_lat)
xmin, xmax = t_x_offset, t_x_offset + self._t_full_res[0]
ymin, ymax = t_y_offset + tymin, t_y_offset + tymax
self.resolution = (xmax-xmin, ymax - ymin, self._t_full_res[2])
else:
# exclusive x range of this chunk (same for all chunks)
xmin, xmax = 0, self._full_res[0]
ymin, ymax = construct_splits(self._full_res[1], nchunks, overlap)[chunk_idx]
self.resolution = (xmax - xmin, ymax - ymin, self._full_res[2])

self._offset = np.array([xmin, ymin], dtype=int)
assert(xmin < xmax)
assert(ymin < ymax)

# get resolution of this chunk
xres = self._full_res[0]
yres = ymax - ymin

# Calculate the new values for resolution and bounding box
self.resolution = (xres, yres, self._full_res[2])

start_bound_x, start_bound_y = self._global_pix2lonlat(
np.array([[xmin, ymin]]))[0]
# one past the last pixel
Expand Down

0 comments on commit 4594183

Please sign in to comment.