Commit
Merge pull request #66 from TUDelftGeodesy/fix_63
Add a function to enrich STM using data from another dataset
rogerkuou committed May 6, 2024
2 parents 72cf071 + 90ef7be commit 2c62a0f
Showing 6 changed files with 720 additions and 2 deletions.
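
For context, a minimal usage sketch of the new enrichment method (not part of this diff; the ".stm" accessor name, the file paths, and the "temperature" field are illustrative assumptions based on the code below):

import xarray as xr
import stmtools  # assumed to register the SpaceTimeMatrix accessor on xarray objects

# Hypothetical inputs: an STM with "space"/"time" dimensions and lon/lat
# coordinates, plus a raster dataset (e.g. temperature on lon/lat/time).
stm = xr.open_zarr("stm.zarr")          # placeholder path
meteo = xr.open_dataset("meteo.nc")     # placeholder path

# Sample "temperature" at every STM point and epoch ("nearest" by default).
stm_enriched = stm.stm.enrich_from_dataset(meteo, "temperature")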
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -8,7 +8,7 @@ version = "0.1.1"
requires-python = ">=3.10"
dependencies = [
"dask[complete]",
"xarray",
"xarray[io]",
"numpy",
"rasterio",
"geopandas",
1 change: 1 addition & 0 deletions stmtools/_io.py
@@ -66,6 +66,7 @@ def from_csv(
Returns:
-------
xr.Dataset: Output STM instance
"""
# Load csv as Dask DataFrame
ddf = dd.read_csv(file, blocksize=blocksize)
201 changes: 200 additions & 1 deletion stmtools/stm.py
@@ -9,6 +9,7 @@
import numpy as np
import pymorton as pm
import xarray as xr
from scipy.spatial import KDTree
from shapely.geometry import Point
from shapely.strtree import STRtree

@@ -23,7 +24,7 @@ class SpaceTimeMatrix:
"""Space-Time Matrix."""

def __init__(self, xarray_obj):
"""init."""
"""Init."""
self._obj = xarray_obj

def add_metadata(self, metadata):
@@ -38,6 +39,7 @@ def add_metadata(self, metadata):
-------
xarray.Dataset
STM with assigned attributes.
"""
self._obj = self._obj.assign_attrs(metadata)
return self._obj
@@ -69,6 +71,7 @@ def regulate_dims(self, space_label=None, time_label=None):
-------
xarray.Dataset
Regulated STM.
"""
if (
(space_label is None)
@@ -128,6 +131,7 @@ def subset(self, method: str, **kwargs):
-------
xarray.Dataset
A subset of the original STM.
"""
# Check if both "space" and "time" dimension exists
for dim in ["space", "time"]:
@@ -203,6 +207,7 @@ def enrich_from_polygon(self, polygon, fields, xlabel="lon", ylabel="lat"):
-------
xarray.Dataset
Enriched STM.
"""
_ = _validate_coords(self._obj, xlabel, ylabel)

@@ -266,6 +271,7 @@ def _in_polygon(self, polygon, xlabel="lon", ylabel="lat"):
-------
Dask.array
A boolean Dask array. True where a space entry is inside the (multi-)polygon.
"""
# Check if coords exists
_ = _validate_coords(self._obj, xlabel, ylabel)
@@ -311,6 +317,7 @@ def register_metadata(self, dict_meta: STMMetaData):
-------
xarray.Dataset
STM with registered metadata.
"""
ds_updated = self._obj.assign_attrs(dict_meta)

@@ -330,6 +337,7 @@ def register_datatype(self, keys: str | Iterable, datatype: DataVarTypes):
-------
xarray.Dataset
STM with registered metadata.
"""
ds_updated = self._obj

@@ -363,6 +371,7 @@ def get_order(self, xlabel="azimuth", ylabel="range", xscale=1.0, yscale=1.0):
Scaling multiplier to the x coordinates before truncating them to integer values.
yscale : float
Scaling multiplier to the y coordinates before truncating them to integer values.
"""
meta_arr = np.array((), dtype=np.int64)
order = da.apply_gufunc(
@@ -395,6 +404,7 @@ def reorder(self, xlabel="azimuth", ylabel="range", xscale=1.0, yscale=1.0):
Scaling multiplier to the x coordinates before truncating them to integer values.
yscale : float
Scaling multiplier to the y coordinates before truncating them to integer values.
"""
self._obj = self.get_order(xlabel, ylabel, xscale, yscale)

@@ -410,6 +420,93 @@ def reorder(xlabel="azimuth", ylabel="range", xscale=1.0, yscale=1.0):

return self._obj

    def enrich_from_dataset(self,
                            dataset: xr.Dataset | xr.DataArray,
                            fields: str | Iterable,
                            method="nearest",
                            ) -> xr.Dataset:
        """Enrich the SpaceTimeMatrix from one or more fields of a dataset.

        scipy is required. If the dataset is a raster, it uses
        _enrich_from_raster_block to interpolate with the given method. If the
        dataset is point-wise, it uses _enrich_from_points_block to find the
        nearest points in space and time using Euclidean distance.

        Parameters
        ----------
        dataset : xarray.Dataset | xarray.DataArray
            Input data for enrichment
        fields : str or list of str
            Field name(s) in the dataset for enrichment
        method : str, optional
            Method of interpolation, by default "nearest", see
            https://docs.xarray.dev/en/stable/generated/xarray.Dataset.interp.html

        Returns
        -------
        xarray.Dataset
            Enriched STM.
        """
        # Check if fields is an Iterable or a str
        if isinstance(fields, str):
            fields = [fields]
        elif not isinstance(fields, Iterable):
            raise ValueError("fields needs to be an Iterable or a string")

        # If dataset is a DataArray, convert it to a Dataset
        if isinstance(dataset, xr.DataArray):
            dataset = dataset.to_dataset()

        ds = self._obj
        # Check that every coordinate label of the STM also exists in the input dataset
        for coord_label in ds.coords.keys():
            if coord_label not in dataset.coords.keys():
                raise ValueError(
                    f'Coordinate label "{coord_label}" was not found in the input dataset.'
                )

        # Check if the dataset is a point or a raster dataset
        if "space" in dataset.dims:
            approach = "point"
        elif "lat" in dataset.dims and "lon" in dataset.dims:
            approach = "raster"
        elif "y" in dataset.dims and "x" in dataset.dims:
            approach = "raster"
        else:
            raise ValueError(
                "The input dataset is not a point or raster dataset. "
                "The dataset should have either 'space' or 'lat/y' and 'lon/x' dimensions. "
                "Consider renaming using "
                "https://docs.xarray.dev/en/latest/generated/xarray.Dataset.rename.html#xarray-dataset-rename"
            )

        # Check if the dataset has a time dimension
        if "time" not in dataset.dims:
            raise ValueError('Missing dimension: "time" in the input dataset.')

        # Check if the dtype of time is the same
        if dataset.time.dtype != ds.time.dtype:
            raise ValueError("The input dataset and the STM have different time dtype.")

        # TODO: check if both ds and dataset have the same coordinate system

        for field in fields:
            # Check if the dataset has the field
            if field not in dataset.data_vars.keys():
                raise ValueError(f'Field "{field}" not found in the input dataset')

            # Check if the STM already has the field
            if field in ds.data_vars.keys():
                raise ValueError(f'Field "{field}" already exists in the STM.')
                # TODO: overwrite the field in the STM

        if approach == "raster":
            return _enrich_from_raster_block(ds, dataset, fields, method)
        elif approach == "point":
            return _enrich_from_points_block(ds, dataset, fields)
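
An illustrative preparation sketch (not part of this diff; the dataset, variable, and dimension names below are assumptions): inputs whose dimensions do not use the expected "lat"/"lon" (or "x"/"y") names can be renamed before enrichment, as the error message above suggests, and the time dtype can be cast to match the STM.

import numpy as np
import xarray as xr

# Hypothetical raster whose dimensions do not use the expected names.
meteo = xr.Dataset(
    {"temperature": (("time", "latitude", "longitude"), np.zeros((2, 3, 3)))},
    coords={
        "time": np.array(["2020-01-01", "2020-01-02"], dtype="datetime64[s]"),
        "latitude": np.linspace(52.0, 52.2, 3),
        "longitude": np.linspace(4.0, 4.2, 3),
    },
)

# Rename to the "lat"/"lon" convention expected by enrich_from_dataset and
# align the time dtype with the STM (e.g. datetime64[ns]).
meteo = meteo.rename({"latitude": "lat", "longitude": "lon"})
meteo = meteo.assign_coords(time=meteo.time.astype("datetime64[ns]"))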

@property
def num_points(self):
"""Get number of space entry of the stm.
@@ -418,6 +515,7 @@ def num_points(self):
-------
int
Number of space entry.
"""
return self._obj.dims["space"]

@@ -429,6 +527,7 @@ def num_epochs(self):
-------
int
Number of epochs.
"""
return self._obj.dims["time"]

@@ -482,6 +581,7 @@ def _ml_str_query(xx, yy, polygon, type_polygon):
An array with two columns. The first column is the positional index into the list of
polygons being used to query the tree. The second column is the positional index into
the list of space entries for which the tree was constructed.
"""
# Crop the polygon to the bounding box of the block
xmin, ymin, xmax, ymax = [
@@ -547,6 +647,7 @@ def _validate_coords(ds, xlabel, ylabel):
------
ValueError
If xlabel or ylabel neither exists in coordinates, raise ValueError
"""
for clabel in [xlabel, ylabel]:
if clabel not in ds.coords.keys():
@@ -579,6 +680,104 @@ def _compute_morton_code(xx, yy):
-------
array_like
An array with Morton codes per coordinate pair.
"""
code = [pm.interleave(int(xi), int(yi)) for xi, yi in zip(xx, yy, strict=True)]
return code


def _enrich_from_raster_block(ds, dataraster, fields, method):
    """Enrich the ds (SpaceTimeMatrix) from one or more fields of a raster dataset.

    scipy is required. It uses xarray.Dataset.interp to interpolate the raster
    dataset to the coordinates of ds, see
    https://docs.xarray.dev/en/stable/generated/xarray.Dataset.interp.html

    Parameters
    ----------
    ds : xarray.Dataset
        SpaceTimeMatrix to enrich
    dataraster : xarray.Dataset | xarray.DataArray
        Input data for enrichment
    fields : str or list of str
        Field name(s) in the dataset for enrichment
    method : str, optional
        Method of interpolation, by default "nearest"

    Returns
    -------
    xarray.Dataset
        Enriched STM.
    """
    # Interpolate the raster dataset to the coordinates of ds
    interpolated = dataraster.interp(ds.coords, method=method)

    # Assign these values to the corresponding points in ds
    for field in fields:
        ds[field] = xr.DataArray(interpolated[field].data, dims=ds.dims, coords=ds.coords)
    return ds
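
A toy sketch (not part of this diff; all data below is made up) of what the interpolation step does: the raster is sampled at the STM coordinates, yielding one value per space entry and epoch.

import numpy as np
import xarray as xr

# Toy raster: a field on a 4x4 lon/lat grid over 3 epochs.
raster = xr.Dataset(
    {"temperature": (("time", "lat", "lon"), np.random.rand(3, 4, 4))},
    coords={"time": np.arange(3), "lat": np.linspace(0, 3, 4), "lon": np.linspace(0, 3, 4)},
)

# Toy STM coordinates: 5 scatterers observed at the same 3 epochs.
stm = xr.Dataset(
    coords={
        "lon": ("space", np.random.uniform(0, 3, 5)),
        "lat": ("space", np.random.uniform(0, 3, 5)),
        "time": ("time", np.arange(3)),
    },
)

# Nearest-neighbour sampling at the STM coordinates; the lon/lat indexers share
# the "space" dimension, so the result has one value per scatterer and epoch.
sampled = raster.interp(stm.coords, method="nearest")
print(sampled["temperature"].sizes)  # -> space: 5, time: 3 (order may vary)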


def _enrich_from_points_block(ds, datapoints, fields):
    """Enrich the ds (SpaceTimeMatrix) from one or more fields of a point dataset.

    The assumption is that the dimensions of the data are space and time.
    https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.KDTree.html#scipy.spatial.KDTree

    Parameters
    ----------
    ds : xarray.Dataset
        SpaceTimeMatrix to enrich
    datapoints : xarray.Dataset | xarray.DataArray
        Input data for enrichment
    fields : str or list of str
        Field name(s) in the dataset for enrichment

    Returns
    -------
    xarray.Dataset
        Enriched STM.
    """
    # The reason we use KDTree instead of xarray.unstack is that the latter
    # is slow for large datasets

    # Check the dimensions
    indexer = {}
    for dim in ["space", "time"]:
        if dim not in datapoints.coords:
            indexer[dim] = [
                coord for coord in datapoints.coords if dim in datapoints[coord].dims
            ]
        else:
            indexer[dim] = [dim]

    # datapoints
    indexes = [datapoints[coord] for coord in indexer["space"]]
    dataset_points_coords = np.column_stack(indexes)

    # ds
    indexes = [ds[coord] for coord in indexer["space"]]
    ds_coords = np.column_stack(indexes)

    # Create a KDTree object for the spatial coordinates of datapoints.
    # Find the indices of the nearest points in space in datapoints for each
    # point in ds; the query uses Euclidean distance.
    tree = KDTree(dataset_points_coords)
    _, indices_space = tree.query(ds_coords)

    # Create a KDTree object for the temporal coordinates of datapoints.
    # Find the indices of the nearest points in time in datapoints for each epoch in ds.
    datapoints_times = datapoints.time.values.reshape(-1, 1)
    ds_times = ds.time.values.reshape(-1, 1)
    tree = KDTree(datapoints_times)
    _, indices_time = tree.query(ds_times)

    selections = datapoints.isel(space=indices_space, time=indices_time)

    # Assign these values to the corresponding points in ds
    for field in fields:
        ds[field] = xr.DataArray(
            selections[field].data, dims=ds.dims, coords=ds.coords
        )

    return ds
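
A self-contained toy sketch (not part of this diff; the coordinates below are made up) of the nearest-neighbour matching this helper performs with scipy's KDTree:

import numpy as np
from scipy.spatial import KDTree

# Four reference points (the enrichment dataset) and three query points (the STM).
# KDTree.query returns, for each query point, the index of the closest reference
# point by Euclidean distance.
ref_coords = np.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # (lon, lat)
stm_coords = np.array([[0.1, 0.1], [0.9, 0.2], [0.4, 0.95]])

tree = KDTree(ref_coords)
_, idx_space = tree.query(stm_coords)
print(idx_space)  # [0 1 2]

# The same pattern is applied to the one-dimensional time axis, after which
# datapoints.isel(space=..., time=...) picks the matched values for each field.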
