From a64bb06430887f29502b4553445429fc2d3b23ed Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Tue, 12 Mar 2024 15:21:36 +0100 Subject: [PATCH 01/33] draft implementation of querying data --- stmtools/stm.py | 101 ++++++++++++++++++++++++++++++++++++++++++++++ tests/test_stm.py | 76 ++++++++++++++++++++++++++++++++++ 2 files changed, 177 insertions(+) diff --git a/stmtools/stm.py b/stmtools/stm.py index ea0b8e1..fe35df2 100644 --- a/stmtools/stm.py +++ b/stmtools/stm.py @@ -400,6 +400,99 @@ def reorder(self, xlabel="azimuth", ylabel="range", xscale=1.0, yscale=1.0): self._obj = self._obj.sortby(self._obj.order) return self._obj + def enrich_from_dataset(self, dataset: xr.Dataset | xr.DataArray, fields: str | Iterable, method="linear") -> xr.Dataset: + """Enrich the SpaceTimeMatrix from one or more fields of a dataset. + + scipy is required. Each field will be assigned as a data variable to the + STM using interpolation in time and space. + + Parameters + ---------- + dataset : xarray.Dataset | xarray.DataArray + Input data for enrichment + fields : str or list of str + Field name(s) in the dataset for enrichment + method : str, optional + Method of interpolation, by default "linear", see + https://docs.xarray.dev/en/stable/generated/xarray.Dataset.interp_like.html#xarray-dataset-interp-like + + Returns + ------- + xarray.Dataset + Enriched STM. 
# Check if fields is an Iterable or a str
+ ) + field = f"{field}_other" + + ds = ds.assign( + { + field: ( + ["space", "time"], + da.from_array(np.full(ds.space.shape + ds.time.shape, None), chunks=chunks), + ) + } + ) + # spatial interpolation and map_blocks does not work if coordinates are not same + # ds = xr.map_blocks( + # _enrich_from_dataset_block, + # ds, + # args=(dataset, fields, method), + # template=ds, + # ) + _ds = ds.copy(deep=True) + for field in fields: + _ds[field].data = dataset[field].interp_like(ds, method=method) + ds = _ds + + return ds + @property def num_points(self): """Get number of space entry of the stm. @@ -572,3 +665,11 @@ def _compute_morton_code(xx, yy): """ code = [pm.interleave(int(xi), int(yi)) for xi, yi in zip(xx, yy, strict=True)] return code + + +def _enrich_from_dataset_block(ds, dataset, fields, method): + """Block-wise function for "enrich_from_dataset".""" + _ds = ds.copy(deep=True) + for field in fields: + _ds[field].data = dataset[field].interp_like(ds, method=method) + return _ds diff --git a/tests/test_stm.py b/tests/test_stm.py index 11f89f9..e207ff8 100644 --- a/tests/test_stm.py +++ b/tests/test_stm.py @@ -253,6 +253,22 @@ def stmat_lonlat_morton(): ), ).unify_chunks() +@pytest.fixture +def meteo(): + lon_values = np.array([0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5]) + lat_values = np.array([0.25, 1.25, 2.25, 3.25, 4.25, 5.25, 6.25, 7.25, 8.25, 9.25]) + + return xr.Dataset( + data_vars=dict( + temperature=(["space", "time"], da.arange(10 * 5).reshape((10, 5))), + humidity=(["space", "time"], da.arange(10 * 5).reshape((10, 5))), + ), + coords=dict( + lon=(["space"], lon_values), + lat=(["space"], lat_values), + time=(["time"], np.arange(5)), + ), + ).unify_chunks() class TestRegulateDims: def test_time_dim_exists(self, stmat_only_point): @@ -460,3 +476,63 @@ def test_reorder_lonlat(self, stmat_lonlat, stmat_lonlat_morton): assert not stmat_naive.range.equals(stmat_lonlat_morton.range) assert stmat.azimuth.equals(stmat_lonlat_morton.azimuth) assert 
stmat.range.equals(stmat_lonlat_morton.range) + + +class TestEnrichmentFromDataset: + def test_enrich_from_dataset_one_filed(self, stmat, meteo): + stmat_enriched = stmat.stm.enrich_from_dataset(meteo, "temperature") + assert "temperature" in stmat_enriched.data_vars + + # check if the linear interpolation is correct + assert stmat_enriched.temperature[0, 0] == meteo.temperature[0, 0] + + # check if coordinates are correct + assert stmat_enriched.lon.equals(stmat.lon) + + def test_enrich_from_dataset_multi_filed(self, stmat, meteo): + stmat_enriched = stmat.stm.enrich_from_dataset(meteo, ["temperature", "humidity"]) + assert "temperature" in stmat_enriched.data_vars + assert "humidity" in stmat_enriched.data_vars + + # check if the linear interpolation is correct + assert stmat_enriched.temperature[0, 0] == meteo.temperature[0, 0] + assert stmat_enriched.humidity[0, 0] == meteo.humidity[0, 0] + + def test_enrich_from_dataset_exceptions(self, stmat, meteo): + # valid fileds + with pytest.raises(ValueError) as excinfo: + field = "non_exist_field" + stmat.stm.enrich_from_dataset(meteo, field) + assert f'Field "{field}" not found' in str(excinfo.value) + + # valid dtype of "time" + meteo["time"] = meteo["time"].astype("float64") + with pytest.raises(ValueError) as excinfo: + stmat.stm.enrich_from_dataset(meteo, "temperature") + assert "different time dtype" in str(excinfo.value) + + # "time" dimension should exist in the meteo + meteo = meteo.drop_vars("time") + with pytest.raises(ValueError) as excinfo: + stmat.stm.enrich_from_dataset(meteo, "temperature") + assert 'Missing dimension: "time"' in str(excinfo.value) + + # shapes of "space" and "time" should be the same + meteo = meteo.sel(space=range(5)) + with pytest.raises(ValueError) as excinfo: + stmat.stm.enrich_from_dataset(meteo, "temperature") + assert "different space shapes" in str(excinfo.value) + + # keys of coordinates should be the same + meteo = meteo.rename({"lon": "long"}) + with 
pytest.raises(ValueError) as excinfo: + stmat.stm.enrich_from_dataset(meteo, "temperature") + assert 'Coordinate label "long" was not found' in str(excinfo.value) + + + def test_enrich_from_dataarray_one_filed(self, stmat, meteo): + stmat_enriched = stmat.stm.enrich_from_dataset(meteo.temperature, "temperature") + assert "temperature" in stmat_enriched.data_vars + + # check if the linear interpolation is correct + assert stmat_enriched.temperature[0, 0] == meteo.temperature[0, 0] From 9a5feb46c597450dc428e64d6a0ccf1946d32404 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Tue, 12 Mar 2024 15:22:20 +0100 Subject: [PATCH 02/33] add scipy and xarray io to dependencies --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 272aa68..58eb8d9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ version = "0.1.1" requires-python = ">=3.10" dependencies = [ "dask[complete]", - "xarray", + "xarray[io]", "numpy", "rasterio", "geopandas", @@ -16,6 +16,7 @@ dependencies = [ "zarr", "distributed", "pymorton", + "scipy", # required for `enrich_from_dataset` method ] description = "space-time matrix for PS-InSAR application" readme = "README.md" From ba411743fec448145ca870d17d5118df74f922a6 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Fri, 15 Mar 2024 16:58:52 +0100 Subject: [PATCH 03/33] refactor enrich_from_dataset to two approaches of point and raster input data --- stmtools/stm.py | 173 +++++++++++++++++++++++++++++++++++------------- 1 file changed, 126 insertions(+), 47 deletions(-) diff --git a/stmtools/stm.py b/stmtools/stm.py index fe35df2..bc41538 100644 --- a/stmtools/stm.py +++ b/stmtools/stm.py @@ -11,6 +11,7 @@ import xarray as xr from shapely.geometry import Point from shapely.strtree import STRtree +from scipy.spatial import cKDTree from stmtools.metadata import DataVarTypes, STMMetaData from stmtools.utils import _has_property @@ -400,20 +401,25 @@ def reorder(self, 
xlabel="azimuth", ylabel="range", xscale=1.0, yscale=1.0): self._obj = self._obj.sortby(self._obj.order) return self._obj - def enrich_from_dataset(self, dataset: xr.Dataset | xr.DataArray, fields: str | Iterable, method="linear") -> xr.Dataset: + def enrich_from_dataset(self, + dataset: xr.Dataset | xr.DataArray, + fields: str | Iterable, + method="nearest") -> xr.Dataset: """Enrich the SpaceTimeMatrix from one or more fields of a dataset. - scipy is required. Each field will be assigned as a data variable to the - STM using interpolation in time and space. + scipy is required. if dataset is raster, it uses + _enrich_from_raster_block to do interpolation using method. if dataset + is point, it uses _enrich_from_points_block to find the nearest points + in space and time using Euclidean distance. Parameters ---------- - dataset : xarray.Dataset | xarray.DataArray + dataset : xarray.Dataset | xarray.DataArray Input data for enrichment fields : str or list of str Field name(s) in the dataset for enrichment method : str, optional - Method of interpolation, by default "linear", see + Method of interpolation, by default "nearest", see https://docs.xarray.dev/en/stable/generated/xarray.Dataset.interp_like.html#xarray-dataset-interp-like Returns @@ -432,11 +438,27 @@ def enrich_from_dataset(self, dataset: xr.Dataset | xr.DataArray, fields: str | dataset = dataset.to_dataset() ds = self._obj + # check if both dataset and ds have coords_labels keys + for coord_label in ds.coords.keys(): + if coord_label not in dataset.coords.keys(): + raise ValueError( + f'Coordinate label "{coord_label}" was not found in the input dataset.' 
+ ) + + # check if dataset is point or raster if 'space' in dataset.dims: + if "space" in dataset.dims: + approch = "point" + elif "lat" in dataset.dims and "lon" in dataset.dims: + approch = "raster" + elif "y" in dataset.dims and "x" in dataset.dims: + approch = "raster" + else: + raise ValueError( + "The input dataset is not a point or raster dataset." + "The dataset should have either 'space' or 'lat/y' and 'lon/x' dimensions." + ) - # TODO: add utility to preprocess the dataset - # check if dataset has space and time dimensions - if "space" not in dataset.dims: - raise ValueError('Missing dimension: "space" in the input dataset.') + # check if dataset has time dimensions if "time" not in dataset.dims: raise ValueError('Missing dimension: "time" in the input dataset.') @@ -444,20 +466,9 @@ def enrich_from_dataset(self, dataset: xr.Dataset | xr.DataArray, fields: str | if dataset.time.dtype != ds.time.dtype: raise ValueError("The input dataset and the STM have different time dtype.") - # check if dataset and ds has the same space and time shapes, required - # for interpolation - if dataset.space.shape != ds.space.shape: - raise ValueError("The input dataset and the STM have different space shapes.") - if dataset.time.shape != ds.time.shape: - raise ValueError("The input dataset and the STM have different time shapes.") - - # check if the keys of dataset coordinates are the same as the STM - for key in ds.coords.keys(): - if key not in dataset.coords.keys(): - raise ValueError(f'Coordinate label "{key}" was not found in the input dataset.') + # TODO: check if both ds and dataset has same coordinate system - chunks = (ds.chunksizes["space"][0], ds.chunksizes["time"][0]) - for field in fields: + for i, field in enumerate(fields): # check if dataset has the fields if field not in dataset.data_vars.keys(): @@ -469,29 +480,22 @@ def enrich_from_dataset(self, dataset: xr.Dataset | xr.DataArray, fields: str | f'"{field}" was found in the data variables of the STM. 
' f'"We will proceed with the data variable from the input dataset as "{field}_other".' ) - field = f"{field}_other" - - ds = ds.assign( - { - field: ( - ["space", "time"], - da.from_array(np.full(ds.space.shape + ds.time.shape, None), chunks=chunks), - ) - } + fields[i] = f"{field}_other" + + if approch == "raster": + return xr.map_blocks( + _enrich_from_raster_block, + ds, + args=(fields, method), + kwargs={"dataset": dataset}, #TODD: block still not working, refactor + ) + elif approch == "point": + return xr.map_blocks( + _enrich_from_points_block, + ds, + args=(fields), + kwargs={"dataset": dataset}, ) - # spatial interpolation and map_blocks does not work if coordinates are not same - # ds = xr.map_blocks( - # _enrich_from_dataset_block, - # ds, - # args=(dataset, fields, method), - # template=ds, - # ) - _ds = ds.copy(deep=True) - for field in fields: - _ds[field].data = dataset[field].interp_like(ds, method=method) - ds = _ds - - return ds @property def num_points(self): @@ -667,9 +671,84 @@ def _compute_morton_code(xx, yy): return code -def _enrich_from_dataset_block(ds, dataset, fields, method): - """Block-wise function for "enrich_from_dataset".""" +def _enrich_from_raster_block(ds, dataraster, fields, method): + """Enrich the ds (SpaceTimeMatrix) from one or more fields of a raster dataset. + + scipy is required. It uses xarray.Dataset.interp_like to interpolate the + raster dataset to the coordinates of ds. 
+ https://docs.xarray.dev/en/stable/generated/xarray.Dataset.interp_like.html#xarray-dataset-interp-like + + Parameters + ---------- + ds : xarray.Dataset + + dataset : xarray.Dataset | xarray.DataArray + Input data for enrichment + fields : str or list of str + Field name(s) in the dataset for enrichment + method : str, optional + Method of interpolation, by default "linear", see + + Returns + ------- + xarray.Dataset + """ + # interpolate the raster dataset to the coordinates of ds + interpolated = dataraster.interp(ds.coords, method=method) + + # Assign these values to the corresponding points in ds _ds = ds.copy(deep=True) for field in fields: - _ds[field].data = dataset[field].interp_like(ds, method=method) + _ds[field] = xr.DataArray(interpolated[field].data, dims=ds.dims, coords=ds.coords) + return _ds + + +def _enrich_from_points_block(ds, datapoints, fields): + """Enrich the ds (SpaceTimeMatrix) from one or more fields of a point dataset. + + scipy is required. It uses cKDTree to find the nearest points in space and + time using Euclidean distance. 
+ https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.cKDTree.html#scipy-spatial-ckdtree + + Parameters + ---------- + ds : xarray.Dataset + + datapoints : xarray.Dataset | xarray.DataArray + Input data for enrichment + fields : str or list of str + Field name(s) in the dataset for enrichment + + Returns + ------- + xarray.Dataset + """ + _ds = ds.copy(deep=True) + + # create tuple of spatial coordinates + spatial_coords = list(_ds.coords.keys())[:-1] # assuming the last coordinate is time + ds_coords = np.column_stack([_ds[coord].values.flatten() for coord in spatial_coords]) + + spatial_coords = list(datapoints.coords.keys())[:-1] # assuming the last coordinate is time + dataset_points_coords = np.column_stack([datapoints[coord].values.flatten() for coord in spatial_coords]) + + # Create a cKDTree object for the spatial coordinates of datapoints + # Find the indices of the nearest points in space in datapoints for each point in _ds + # it uses Euclidean distance + tree = cKDTree(dataset_points_coords) + _, indices_space = tree.query(ds_coords) + + # Create a cKDTree object for the temporal coordinates of datapoints + # Find the indices of the nearest points in time in datapoints for each point in _ds + datapoints_times = datapoints.time.values.reshape(-1, 1) + ds_times = _ds.time.values.reshape(-1, 1) + tree = cKDTree(datapoints_times) + _, indices_time = tree.query(ds_times) + + selections = datapoints.isel(time=indices_time, space=indices_space) + + # Assign these values to the corresponding points in _ds + for field in fields: + _ds[field] = xr.DataArray(selections[field].data, dims=ds.dims, coords=ds.coords) + return _ds From e133f5fa4b7be6b7cbb68966e1ad824afce28e56 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Mon, 18 Mar 2024 14:58:39 +0100 Subject: [PATCH 04/33] use KDTree instead of cKDTree --- stmtools/stm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/stmtools/stm.py b/stmtools/stm.py index 
bc41538..531936a 100644 --- a/stmtools/stm.py +++ b/stmtools/stm.py @@ -11,7 +11,7 @@ import xarray as xr from shapely.geometry import Point from shapely.strtree import STRtree -from scipy.spatial import cKDTree +from scipy.spatial import KDTree from stmtools.metadata import DataVarTypes, STMMetaData from stmtools.utils import _has_property @@ -735,14 +735,14 @@ def _enrich_from_points_block(ds, datapoints, fields): # Create a cKDTree object for the spatial coordinates of datapoints # Find the indices of the nearest points in space in datapoints for each point in _ds # it uses Euclidean distance - tree = cKDTree(dataset_points_coords) + tree = KDTree(dataset_points_coords) _, indices_space = tree.query(ds_coords) # Create a cKDTree object for the temporal coordinates of datapoints # Find the indices of the nearest points in time in datapoints for each point in _ds datapoints_times = datapoints.time.values.reshape(-1, 1) ds_times = _ds.time.values.reshape(-1, 1) - tree = cKDTree(datapoints_times) + tree = KDTree(datapoints_times) _, indices_time = tree.query(ds_times) selections = datapoints.isel(time=indices_time, space=indices_space) From 90ef61d8feafb68ad7ff89a0208e7764637eac95 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Mon, 18 Mar 2024 17:29:24 +0100 Subject: [PATCH 05/33] replace KDTree with sel method of xarray, fix a bug in fields --- stmtools/stm.py | 69 +++++++++++++++++++------------------------------ 1 file changed, 26 insertions(+), 43 deletions(-) diff --git a/stmtools/stm.py b/stmtools/stm.py index 531936a..19cacae 100644 --- a/stmtools/stm.py +++ b/stmtools/stm.py @@ -11,7 +11,6 @@ import xarray as xr from shapely.geometry import Point from shapely.strtree import STRtree -from scipy.spatial import KDTree from stmtools.metadata import DataVarTypes, STMMetaData from stmtools.utils import _has_property @@ -468,7 +467,7 @@ def enrich_from_dataset(self, # TODO: check if both ds and dataset has same coordinate system - for i, field in 
enumerate(fields): + for field in fields: # check if dataset has the fields if field not in dataset.data_vars.keys(): @@ -476,26 +475,26 @@ def enrich_from_dataset(self, # check STM has the filed already if field in ds.data_vars.keys(): - logger.warning( - f'"{field}" was found in the data variables of the STM. ' - f'"We will proceed with the data variable from the input dataset as "{field}_other".' - ) - fields[i] = f"{field}_other" + raise ValueError(f'Field "{field}" already exists in the STM.') + + # if dataset is a dask collection, compute it first if approch == "raster": - return xr.map_blocks( - _enrich_from_raster_block, - ds, - args=(fields, method), - kwargs={"dataset": dataset}, #TODD: block still not working, refactor - ) + return _enrich_from_raster_block(ds, dataset, fields, method) + # return xr.map_blocks( + # _enrich_from_raster_block, + # ds, + # args=(fields, method), + # kwargs={"dataset": dataset}, #TODD: block still not working, refactor + # ) elif approch == "point": - return xr.map_blocks( - _enrich_from_points_block, - ds, - args=(fields), - kwargs={"dataset": dataset}, - ) + return _enrich_from_points_block(ds, dataset, fields) + # return xr.map_blocks( + # _enrich_from_points_block, + # ds, + # args=(fields), + # kwargs={"dataset": dataset}, + # ) @property def num_points(self): @@ -706,9 +705,7 @@ def _enrich_from_raster_block(ds, dataraster, fields, method): def _enrich_from_points_block(ds, datapoints, fields): """Enrich the ds (SpaceTimeMatrix) from one or more fields of a point dataset. - scipy is required. It uses cKDTree to find the nearest points in space and - time using Euclidean distance. 
- https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.cKDTree.html#scipy-spatial-ckdtree + https://docs.xarray.dev/en/latest/generated/xarray.DataArray.sel.html#xarray.DataArray.sel Parameters ---------- @@ -725,30 +722,16 @@ def _enrich_from_points_block(ds, datapoints, fields): """ _ds = ds.copy(deep=True) - # create tuple of spatial coordinates - spatial_coords = list(_ds.coords.keys())[:-1] # assuming the last coordinate is time - ds_coords = np.column_stack([_ds[coord].values.flatten() for coord in spatial_coords]) - - spatial_coords = list(datapoints.coords.keys())[:-1] # assuming the last coordinate is time - dataset_points_coords = np.column_stack([datapoints[coord].values.flatten() for coord in spatial_coords]) - - # Create a cKDTree object for the spatial coordinates of datapoints - # Find the indices of the nearest points in space in datapoints for each point in _ds - # it uses Euclidean distance - tree = KDTree(dataset_points_coords) - _, indices_space = tree.query(ds_coords) - - # Create a cKDTree object for the temporal coordinates of datapoints - # Find the indices of the nearest points in time in datapoints for each point in _ds - datapoints_times = datapoints.time.values.reshape(-1, 1) - ds_times = _ds.time.values.reshape(-1, 1) - tree = KDTree(datapoints_times) - _, indices_time = tree.query(ds_times) + # add spatial coordinates to dims + datapoints_coords = list(datapoints.coords.keys()) + datapoints = datapoints.set_index(space=datapoints_coords[:-1]) # assuming the last coordinate is time + datapoints = datapoints.unstack("space") # after this, the order of coordinates changes, so we use transpose later - selections = datapoints.isel(time=indices_time, space=indices_space) + indexers = {coord: _ds[coord] for coord in datapoints_coords} + selections = datapoints.sel(indexers, method="nearest") # Assign these values to the corresponding points in _ds for field in fields: - _ds[field] = xr.DataArray(selections[field].data, 
dims=ds.dims, coords=ds.coords) + _ds[field] = xr.DataArray(selections[field].data.transpose(), dims=ds.dims, coords=ds.coords) return _ds From 99844c8506b7ec640173299b46a03d6564735017 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Mon, 18 Mar 2024 17:29:37 +0100 Subject: [PATCH 06/33] fix tests --- tests/test_stm.py | 188 ++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 157 insertions(+), 31 deletions(-) diff --git a/tests/test_stm.py b/tests/test_stm.py index e207ff8..3f98c32 100644 --- a/tests/test_stm.py +++ b/tests/test_stm.py @@ -5,6 +5,7 @@ import numpy as np import pytest import xarray as xr +import pandas as pd from shapely import geometry from stmtools.stm import _validate_coords @@ -254,22 +255,69 @@ def stmat_lonlat_morton(): ).unify_chunks() @pytest.fixture -def meteo(): - lon_values = np.array([0.5, 1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5, 9.5]) - lat_values = np.array([0.25, 1.25, 2.25, 3.25, 4.25, 5.25, 6.25, 7.25, 8.25, 9.25]) +def meteo_points(): + lon_values = np.array([0.5, 1.5, 2.5, 3.5, 4.5]) + lat_values = np.array([0.25, 1.25, 2.25, 3.25, 4.25]) + time_values = pd.date_range(start='2021-01-01', periods=6) return xr.Dataset( data_vars=dict( - temperature=(["space", "time"], da.arange(10 * 5).reshape((10, 5))), - humidity=(["space", "time"], da.arange(10 * 5).reshape((10, 5))), + temperature=(["space", "time"], da.arange(5 * 6).reshape((5, 6))), + humidity=(["space", "time"], da.arange(5 * 6).reshape((5, 6))), ), coords=dict( lon=(["space"], lon_values), lat=(["space"], lat_values), - time=(["time"], np.arange(5)), + time=(["time"], time_values), ), ).unify_chunks() +@pytest.fixture +def meteo_raster(): + # create a raster with 5x5 grid + lon_values = np.array([0, 1, 2, 3, 4]) + lat_values = np.array([0, 1, 2, 3, 4]) + time_values = pd.date_range(start='2021-01-01', periods=6) + + return xr.Dataset( + data_vars=dict( + temperature=(["lon", "lat", "time"], da.arange(5 * 5 * 6).reshape((5, 5, 6))), + humidity=(["lon", "lat", 
"time"], da.arange(5 * 5 * 6).reshape((5, 5, 6))), + ), + coords=dict( + lon=(["lon"], lon_values), + lat=(["lat"], lat_values), + time=(["time"], time_values), + ), + ).unify_chunks() + +@pytest.fixture +def stmat(): + npoints = 10 + ntime = 5 + return xr.Dataset( + data_vars=dict( + amplitude=( + ["space", "time"], + da.arange(npoints * ntime).reshape((npoints, ntime)), + ), + phase=( + ["space", "time"], + da.arange(npoints * ntime).reshape((npoints, ntime)), + ), + pnt_height=( + ["space"], + da.arange(npoints), + ), + ), + coords=dict( + lon=(["space"], da.arange(npoints)), + lat=(["space"], da.arange(npoints)), + time=(["time"], pd.date_range(start='2021-01-02', periods=ntime)), + ), + ).unify_chunks() + + class TestRegulateDims: def test_time_dim_exists(self, stmat_only_point): stm_reg = stmat_only_point.stm.regulate_dims() @@ -478,61 +526,139 @@ def test_reorder_lonlat(self, stmat_lonlat, stmat_lonlat_morton): assert stmat.range.equals(stmat_lonlat_morton.range) -class TestEnrichmentFromDataset: - def test_enrich_from_dataset_one_filed(self, stmat, meteo): - stmat_enriched = stmat.stm.enrich_from_dataset(meteo, "temperature") +class TestEnrichmentFromPointDataset: + def test_enrich_from_dataset_one_filed(self, stmat, meteo_points): + stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature") assert "temperature" in stmat_enriched.data_vars - # check if the linear interpolation is correct - assert stmat_enriched.temperature[0, 0] == meteo.temperature[0, 0] + # check if the nearest method is correct + assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 1] # check if coordinates are correct assert stmat_enriched.lon.equals(stmat.lon) - def test_enrich_from_dataset_multi_filed(self, stmat, meteo): - stmat_enriched = stmat.stm.enrich_from_dataset(meteo, ["temperature", "humidity"]) + def test_enrich_from_dataset_multi_filed(self, stmat, meteo_points): + stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, 
["temperature", "humidity"]) assert "temperature" in stmat_enriched.data_vars assert "humidity" in stmat_enriched.data_vars # check if the linear interpolation is correct - assert stmat_enriched.temperature[0, 0] == meteo.temperature[0, 0] - assert stmat_enriched.humidity[0, 0] == meteo.humidity[0, 0] + assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 1] + assert stmat_enriched.humidity[0, 0] == meteo_points.humidity[0, 1] - def test_enrich_from_dataset_exceptions(self, stmat, meteo): + def test_enrich_from_dataset_exceptions(self, stmat, meteo_points): # valid fileds with pytest.raises(ValueError) as excinfo: field = "non_exist_field" - stmat.stm.enrich_from_dataset(meteo, field) + stmat.stm.enrich_from_dataset(meteo_points, field) assert f'Field "{field}" not found' in str(excinfo.value) # valid dtype of "time" - meteo["time"] = meteo["time"].astype("float64") + another_meteo_points = meteo_points.copy(deep=True) + another_meteo_points["time"] = another_meteo_points["time"].astype("float64") with pytest.raises(ValueError) as excinfo: - stmat.stm.enrich_from_dataset(meteo, "temperature") + stmat.stm.enrich_from_dataset(another_meteo_points, "temperature") assert "different time dtype" in str(excinfo.value) - # "time" dimension should exist in the meteo - meteo = meteo.drop_vars("time") + # "time" dimension should exist in the meteo_points + another_meteo_points = meteo_points.copy(deep=True) + another_meteo_points = another_meteo_points.drop_vars("time") with pytest.raises(ValueError) as excinfo: - stmat.stm.enrich_from_dataset(meteo, "temperature") + stmat.stm.enrich_from_dataset(another_meteo_points, "temperature") assert 'Missing dimension: "time"' in str(excinfo.value) - # shapes of "space" and "time" should be the same - meteo = meteo.sel(space=range(5)) + # keys of coordinates should be the same + another_meteo_points = meteo_points.copy(deep=True) + another_meteo_points = another_meteo_points.rename({"lon": "long"}) + with 
pytest.raises(ValueError) as excinfo: + stmat.stm.enrich_from_dataset(another_meteo_points, "temperature") + assert 'Coordinate label "long" was not found' in str(excinfo.value) + + # dimensions either space or lon/lat should exist + another_meteo_points = meteo_points.copy(deep=True) + another_meteo_points = another_meteo_points.drop_dims("space") + with pytest.raises(ValueError) as excinfo: + stmat.stm.enrich_from_dataset(another_meteo_points, "temperature") + assert 'Missing dimension: "space" or "lon" and "lat"' in str(excinfo.value) + + # field already exists + another_stmat = stmat.stm.enrich_from_dataset(meteo_points, "temperature") + with pytest.raises(ValueError) as excinfo: + another_stmat.stm.enrich_from_dataset(meteo_points, "temperature") + assert 'Field "temperature" already exists' in str(excinfo.value) + + def test_enrich_from_dataarray_one_filed(self, stmat, meteo_points): + stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points.temperature, "temperature") + assert "temperature" in stmat_enriched.data_vars + + # check if the linear interpolation is correct + assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 1] + + +class TestEnrichmentFromRasterDataset: + def test_enrich_from_dataset_one_filed(self, stmat, meteo_raster): + stmat_enriched = stmat.stm.enrich_from_dataset(meteo_raster, "temperature") + assert "temperature" in stmat_enriched.data_vars + + # check if the nearest method is correct + assert stmat_enriched.temperature[0, 0] == meteo_raster.temperature[0, 0, 1] + + # check if coordinates are correct + assert stmat_enriched.lon.equals(stmat.lon) + + def test_enrich_from_dataset_multi_filed(self, stmat, meteo_raster): + stmat_enriched = stmat.stm.enrich_from_dataset(meteo_raster, ["temperature", "humidity"]) + assert "temperature" in stmat_enriched.data_vars + assert "humidity" in stmat_enriched.data_vars + + # check if the linear interpolation is correct + assert stmat_enriched.temperature[0, 0] == 
meteo_raster.temperature[0, 0, 1] + assert stmat_enriched.humidity[0, 0] == meteo_raster.humidity[0, 0, 1] + + def test_enrich_from_dataset_exceptions(self, stmat, meteo_raster): + # valid fileds with pytest.raises(ValueError) as excinfo: - stmat.stm.enrich_from_dataset(meteo, "temperature") - assert "different space shapes" in str(excinfo.value) + field = "non_exist_field" + stmat.stm.enrich_from_dataset(meteo_raster, field) + assert f'Field "{field}" not found' in str(excinfo.value) + + # valid dtype of "time" + another_meteo_raster = meteo_raster.copy(deep=True) + another_meteo_raster["time"] = another_meteo_raster["time"].astype("float64") + with pytest.raises(ValueError) as excinfo: + stmat.stm.enrich_from_dataset(another_meteo_raster, "temperature") + assert "different time dtype" in str(excinfo.value) + + # "time" dimension should exist in the meteo_raster + another_meteo_raster = meteo_raster.copy(deep=True) + another_meteo_raster = another_meteo_raster.drop_vars("time") + with pytest.raises(ValueError) as excinfo: + stmat.stm.enrich_from_dataset(another_meteo_raster, "temperature") + assert 'Missing dimension: "time"' in str(excinfo.value) # keys of coordinates should be the same - meteo = meteo.rename({"lon": "long"}) + another_meteo_raster = meteo_raster.copy(deep=True) + another_meteo_raster = another_meteo_raster.rename({"lon": "long"}) with pytest.raises(ValueError) as excinfo: - stmat.stm.enrich_from_dataset(meteo, "temperature") + stmat.stm.enrich_from_dataset(another_meteo_raster, "temperature") assert 'Coordinate label "long" was not found' in str(excinfo.value) + # dimensions either space or lon/lat should exist + another_meteo_raster = meteo_raster.copy(deep=True) + another_meteo_raster = another_meteo_raster.drop_dims("lat") + with pytest.raises(ValueError) as excinfo: + stmat.stm.enrich_from_dataset(another_meteo_raster, "temperature") + assert 'Missing dimension: "space" or "lon" and "lat"' in str(excinfo.value) + + # field already exists + 
another_stmat = stmat.stm.enrich_from_dataset(meteo_raster, "temperature") + with pytest.raises(ValueError) as excinfo: + another_stmat.stm.enrich_from_dataset(meteo_raster, "temperature") + assert 'Field "temperature" already exists' in str(excinfo.value) - def test_enrich_from_dataarray_one_filed(self, stmat, meteo): - stmat_enriched = stmat.stm.enrich_from_dataset(meteo.temperature, "temperature") + def test_enrich_from_dataarray_one_filed(self, stmat, meteo_raster): + stmat_enriched = stmat.stm.enrich_from_dataset(meteo_raster.temperature, "temperature") assert "temperature" in stmat_enriched.data_vars # check if the linear interpolation is correct - assert stmat_enriched.temperature[0, 0] == meteo.temperature[0, 0] + assert stmat_enriched.temperature[0, 0] == meteo_raster.temperature[0, 0, 1] \ No newline at end of file From 6926a4bcf496972b4739fdc35f4fbadfb783f042 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Wed, 20 Mar 2024 09:42:18 +0100 Subject: [PATCH 07/33] remove ds copy --- stmtools/stm.py | 14 ++++++-------- 1 file changed, 6 insertions(+), 8 deletions(-) diff --git a/stmtools/stm.py b/stmtools/stm.py index 19cacae..fad5815 100644 --- a/stmtools/stm.py +++ b/stmtools/stm.py @@ -696,10 +696,9 @@ def _enrich_from_raster_block(ds, dataraster, fields, method): interpolated = dataraster.interp(ds.coords, method=method) # Assign these values to the corresponding points in ds - _ds = ds.copy(deep=True) for field in fields: - _ds[field] = xr.DataArray(interpolated[field].data, dims=ds.dims, coords=ds.coords) - return _ds + ds[field] = xr.DataArray(interpolated[field].data, dims=ds.dims, coords=ds.coords) + return ds def _enrich_from_points_block(ds, datapoints, fields): @@ -720,18 +719,17 @@ def _enrich_from_points_block(ds, datapoints, fields): ------- xarray.Dataset """ - _ds = ds.copy(deep=True) # add spatial coordinates to dims datapoints_coords = list(datapoints.coords.keys()) datapoints = datapoints.set_index(space=datapoints_coords[:-1]) # 
assuming the last coordinate is time datapoints = datapoints.unstack("space") # after this, the order of coordinates changes, so we use transpose later - indexers = {coord: _ds[coord] for coord in datapoints_coords} + indexers = {coord: ds[coord] for coord in datapoints_coords} selections = datapoints.sel(indexers, method="nearest") - # Assign these values to the corresponding points in _ds + # Assign these values to the corresponding points in ds for field in fields: - _ds[field] = xr.DataArray(selections[field].data.transpose(), dims=ds.dims, coords=ds.coords) + ds[field] = xr.DataArray(selections[field].data.transpose(), dims=ds.dims, coords=ds.coords) - return _ds + return ds From f61f355366bdc8984f76815ef960fbd0a81fa458 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Wed, 20 Mar 2024 11:43:24 +0100 Subject: [PATCH 08/33] add test if operations are lazy --- tests/test_stm.py | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/tests/test_stm.py b/tests/test_stm.py index 3f98c32..8ba5a95 100644 --- a/tests/test_stm.py +++ b/tests/test_stm.py @@ -534,8 +534,13 @@ def test_enrich_from_dataset_one_filed(self, stmat, meteo_points): # check if the nearest method is correct assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 1] + # check dimensions of stmat_enriched are the same as stmat + assert stmat_enriched.dims == stmat.dims + # check if coordinates are correct assert stmat_enriched.lon.equals(stmat.lon) + assert stmat_enriched.lat.equals(stmat.lat) + assert stmat_enriched.time.equals(stmat.time) def test_enrich_from_dataset_multi_filed(self, stmat, meteo_points): stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, ["temperature", "humidity"]) @@ -594,6 +599,12 @@ def test_enrich_from_dataarray_one_filed(self, stmat, meteo_points): # check if the linear interpolation is correct assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 1] + def test_all_operations_lazy(self, stmat, 
meteo_points): + stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature") + assert stmat_enriched.temperature.data.dask is not None + assert stmat_enriched.lon.data.dask is not None + assert stmat_enriched.lat.data.dask is not None + # dont check time because it is not a dask array class TestEnrichmentFromRasterDataset: def test_enrich_from_dataset_one_filed(self, stmat, meteo_raster): @@ -603,8 +614,13 @@ def test_enrich_from_dataset_one_filed(self, stmat, meteo_raster): # check if the nearest method is correct assert stmat_enriched.temperature[0, 0] == meteo_raster.temperature[0, 0, 1] + # check dimensions of stmat_enriched are the same as stmat + assert stmat_enriched.dims == stmat.dims + # check if coordinates are correct assert stmat_enriched.lon.equals(stmat.lon) + assert stmat_enriched.lat.equals(stmat.lat) + assert stmat_enriched.time.equals(stmat.time) def test_enrich_from_dataset_multi_filed(self, stmat, meteo_raster): stmat_enriched = stmat.stm.enrich_from_dataset(meteo_raster, ["temperature", "humidity"]) @@ -661,4 +677,11 @@ def test_enrich_from_dataarray_one_filed(self, stmat, meteo_raster): assert "temperature" in stmat_enriched.data_vars # check if the linear interpolation is correct - assert stmat_enriched.temperature[0, 0] == meteo_raster.temperature[0, 0, 1] \ No newline at end of file + assert stmat_enriched.temperature[0, 0] == meteo_raster.temperature[0, 0, 1] + + def test_all_operations_lazy(self, stmat, meteo_raster): + stmat_enriched = stmat.stm.enrich_from_dataset(meteo_raster, "temperature") + assert stmat_enriched.temperature.data.dask is not None + assert stmat_enriched.lon.data.dask is not None + assert stmat_enriched.lat.data.dask is not None + # dont check time because it is not a dask array \ No newline at end of file From b3101c83b5ddc474193c7945ccb732046223db73 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Thu, 21 Mar 2024 13:17:24 +0100 Subject: [PATCH 09/33] add util functions for cropping and unstack 
operations --- stmtools/utils.py | 49 +++++++++++++++++++++ tests/test_util.py | 103 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 152 insertions(+) diff --git a/stmtools/utils.py b/stmtools/utils.py index 6ef682f..e39e93f 100644 --- a/stmtools/utils.py +++ b/stmtools/utils.py @@ -1,3 +1,4 @@ +import xarray as xr from collections.abc import Iterable @@ -8,3 +9,51 @@ def _has_property(ds, keys: str | Iterable): return set(keys).issubset(ds.data_vars.keys()) else: raise ValueError(f"Invalid type of keys: {type(keys)}.") + + +def crop(ds, other, buffer): + """Crop the other to a given buffer around ds. + + Parameters + ---------- + ds : xarray.Dataset | xarray.DataArray + Dataset to crop to. + other : xarray.Dataset | xarray.DataArray + Dataset to crop. + buffer : dict + A dictionary with the buffer values for each dimension. + + Returns + ------- + xarray.Dataset + Cropped dataset. + """ + if isinstance(ds, xr.DataArray): + ds = ds.to_dataset() + + if isinstance(other, xr.DataArray): + other = other.to_dataset() + + if not isinstance(buffer, dict): + raise ValueError(f"Invalid type of buffer: {type(buffer)}.") + for coord in buffer.keys(): + if coord not in ds.coords.keys(): + raise ValueError(f"coordinate '{coord}' not found in ds.") + if coord not in other.coords.keys(): + raise ValueError(f"coordinate '{coord}' not found in other.") + + other = unstack(other) + for coord in buffer.keys(): + coord_min = ds[coord].min() - buffer[coord] + coord_max = ds[coord].max() + buffer[coord] + other = other.sel({coord: slice(coord_min, coord_max)}) + return other + + +def unstack(ds): + for dim in ds.dims: + if dim not in ds.coords: + indexer = {dim: [coord for coord in ds.coords if dim in ds[coord].dims]} + ds = ds.set_index(indexer) + ds = ds.unstack(dim) + return ds \ No newline at end of file diff --git a/tests/test_util.py b/tests/test_util.py index c78ce84..81b2a66 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,4 +1,8 @@ import pytest 
+import dask.array as da +import numpy as np +import pandas as pd +import xarray as xr from stmtools import utils @@ -19,3 +23,102 @@ def test_has_not(self, stmat): def test_incorrect_type(self, stmat): with pytest.raises(ValueError): utils._has_property(stmat, 1) + +@pytest.fixture +def meteo_points(): + n_times = 20 + n_locations = 50 + lon_values = np.arange(n_locations) + lat_values = np.arange(n_locations) + time_values = pd.date_range(start='2021-01-01', periods=n_times) + + return xr.Dataset( + data_vars=dict( + temperature=(["space", "time"], da.arange(n_locations * n_times).reshape((n_locations, n_times))), + ), + coords=dict( + lon=(["space"], lon_values), + lat=(["space"], lat_values), + time=(["time"], time_values), + ), + ).unify_chunks() + + + +@pytest.fixture +def meteo_raster(): + n_times = 20 + n_locations = 50 + lon_values = np.arange(n_locations) + lat_values = np.arange(n_locations) + time_values = pd.date_range(start='2021-01-01', periods=n_times) + # add x and y values + x_values = np.arange(n_locations) + y_values = np.arange(n_locations) + + return xr.Dataset( + data_vars=dict( + temperature=(["lon", "lat", "time"], da.arange(n_locations * n_locations * n_times).reshape((n_locations, n_locations, n_times))), + ), + coords=dict( + lon=(["lon"], lon_values), + lat=(["lat"], lat_values), + x=(["lon"], x_values), + y=(["lat"], y_values), + time=(["time"], time_values), + ), + ).unify_chunks() + +@pytest.fixture +def stmat(): + npoints = 10 + ntime = 5 + return xr.Dataset( + data_vars=dict( + amplitude=( + ["space", "time"], + da.arange(npoints * ntime).reshape((npoints, ntime)), + ), + phase=( + ["space", "time"], + da.arange(npoints * ntime).reshape((npoints, ntime)), + ), + pnt_height=( + ["space"], + da.arange(npoints), + ), + ), + coords=dict( + lon=(["space"], da.arange(npoints)), + lat=(["space"], da.arange(npoints)), + time=(["time"], pd.date_range(start='2021-01-02', periods=ntime)), + ), + ).unify_chunks() + +class TestCrop: + def 
test_crop_points(self, stmat, meteo_points): + buffer = {"lon": 1, "lat": 1, "time": 1} + cropped = utils.crop(stmat, meteo_points, buffer) + # check min and max values of coordinates + assert cropped.lon.min() == 0 + assert cropped.lon.max() == 10 + assert cropped.lat.min() == 0 + assert cropped.lat.max() == 10 + assert cropped.time.min() == pd.Timestamp("2021-01-02") + assert cropped.time.max() == pd.Timestamp("2021-01-06") + + def test_crop_raster(self, stmat, meteo_raster): + buffer = {"lon": 1, "lat": 1, "time": 1} + cropped = utils.crop(stmat, meteo_raster, buffer) + # check min and max values of coordinates + assert cropped.lon.min() == 0 + assert cropped.lon.max() == 10 + assert cropped.lat.min() == 0 + assert cropped.lat.max() == 10 + assert cropped.time.min() == pd.Timestamp("2021-01-02") + assert cropped.time.max() == pd.Timestamp("2021-01-06") + + def test_all_operations_lazy(self, stmat, meteo_raster): + buffer = {"lon": 1, "lat": 1, "time": 1} + cropped = utils.crop(stmat, meteo_raster, buffer) + assert isinstance(cropped.temperature.data, da.Array) From 605db7e82e7606b3483b5d076edda773abe9f7a5 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Fri, 22 Mar 2024 17:50:38 +0100 Subject: [PATCH 10/33] fix stm enrich function --- stmtools/stm.py | 42 ++++++++++++++++------------------------- tests/test_stm.py | 48 +++++++++++++++++++++++++++++++++++------------ 2 files changed, 52 insertions(+), 38 deletions(-) diff --git a/stmtools/stm.py b/stmtools/stm.py index fad5815..3ae0346 100644 --- a/stmtools/stm.py +++ b/stmtools/stm.py @@ -12,6 +12,7 @@ from shapely.geometry import Point from shapely.strtree import STRtree +from stmtools import utils from stmtools.metadata import DataVarTypes, STMMetaData from stmtools.utils import _has_property @@ -403,7 +404,8 @@ def reorder(self, xlabel="azimuth", ylabel="range", xscale=1.0, yscale=1.0): def enrich_from_dataset(self, dataset: xr.Dataset | xr.DataArray, fields: str | Iterable, - method="nearest") -> 
xr.Dataset: + method="nearest", + ) -> xr.Dataset: """Enrich the SpaceTimeMatrix from one or more fields of a dataset. scipy is required. if dataset is raster, it uses @@ -420,7 +422,6 @@ def enrich_from_dataset(self, method : str, optional Method of interpolation, by default "nearest", see https://docs.xarray.dev/en/stable/generated/xarray.Dataset.interp_like.html#xarray-dataset-interp-like - Returns ------- xarray.Dataset @@ -454,7 +455,7 @@ def enrich_from_dataset(self, else: raise ValueError( "The input dataset is not a point or raster dataset." - "The dataset should have either 'space' or 'lat/y' and 'lon/x' dimensions." + "The dataset should have either 'space' or 'lat/y' and 'lon/x' dimensions." # give help on renaming ) # check if dataset has time dimensions @@ -476,25 +477,12 @@ def enrich_from_dataset(self, # check STM has the filed already if field in ds.data_vars.keys(): raise ValueError(f'Field "{field}" already exists in the STM.') - - # if dataset is a dask collection, compute it first + # TODO: overwrite the field in the STM if approch == "raster": return _enrich_from_raster_block(ds, dataset, fields, method) - # return xr.map_blocks( - # _enrich_from_raster_block, - # ds, - # args=(fields, method), - # kwargs={"dataset": dataset}, #TODD: block still not working, refactor - # ) elif approch == "point": return _enrich_from_points_block(ds, dataset, fields) - # return xr.map_blocks( - # _enrich_from_points_block, - # ds, - # args=(fields), - # kwargs={"dataset": dataset}, - # ) @property def num_points(self): @@ -675,7 +663,7 @@ def _enrich_from_raster_block(ds, dataraster, fields, method): scipy is required. It uses xarray.Dataset.interp_like to interpolate the raster dataset to the coordinates of ds. 
- https://docs.xarray.dev/en/stable/generated/xarray.Dataset.interp_like.html#xarray-dataset-interp-like + https://docs.xarray.dev/en/stable/generated/xarray.Dataset.interp.html Parameters ---------- @@ -686,7 +674,7 @@ def _enrich_from_raster_block(ds, dataraster, fields, method): fields : str or list of str Field name(s) in the dataset for enrichment method : str, optional - Method of interpolation, by default "linear", see + Method of interpolation, by default "nearest", see Returns ------- @@ -719,13 +707,15 @@ def _enrich_from_points_block(ds, datapoints, fields): ------- xarray.Dataset """ - - # add spatial coordinates to dims - datapoints_coords = list(datapoints.coords.keys()) - datapoints = datapoints.set_index(space=datapoints_coords[:-1]) # assuming the last coordinate is time - datapoints = datapoints.unstack("space") # after this, the order of coordinates changes, so we use transpose later - - indexers = {coord: ds[coord] for coord in datapoints_coords} + # unstak the dimensions + for dim in datapoints.dims: + if dim not in datapoints.coords: + indexer = {dim: [coord for coord in datapoints.coords if dim in datapoints[coord].dims]} + datapoints = datapoints.set_index(indexer) + datapoints = datapoints.unstack(dim) + + # do selection + indexers = {coord: ds[coord] for coord in list(datapoints.coords.keys())} selections = datapoints.sel(indexers, method="nearest") # Assign these values to the corresponding points in ds diff --git a/tests/test_stm.py b/tests/test_stm.py index 8ba5a95..1322f8c 100644 --- a/tests/test_stm.py +++ b/tests/test_stm.py @@ -9,6 +9,7 @@ from shapely import geometry from stmtools.stm import _validate_coords +from stmtools.utils import crop path_multi_polygon = Path(__file__).parent / "./data/multi_polygon.gpkg" @@ -255,15 +256,18 @@ def stmat_lonlat_morton(): ).unify_chunks() @pytest.fixture + def meteo_points(): - lon_values = np.array([0.5, 1.5, 2.5, 3.5, 4.5]) - lat_values = np.array([0.25, 1.25, 2.25, 3.25, 4.25]) - 
time_values = pd.date_range(start='2021-01-01', periods=6) + n_times = 20 + n_locations = 50 + lon_values = np.arange(0, n_locations/2, 0.5) + lat_values = np.arange(0, n_locations/2, 0.5) + time_values = pd.date_range(start='2021-01-01', periods=n_times) return xr.Dataset( data_vars=dict( - temperature=(["space", "time"], da.arange(5 * 6).reshape((5, 6))), - humidity=(["space", "time"], da.arange(5 * 6).reshape((5, 6))), + temperature=(["space", "time"], da.arange(n_locations * n_times).reshape((n_locations, n_times))), + humidity=(["space", "time"], da.arange(n_locations * n_times).reshape((n_locations, n_times))), ), coords=dict( lon=(["space"], lon_values), @@ -274,19 +278,25 @@ def meteo_points(): @pytest.fixture def meteo_raster(): - # create a raster with 5x5 grid - lon_values = np.array([0, 1, 2, 3, 4]) - lat_values = np.array([0, 1, 2, 3, 4]) - time_values = pd.date_range(start='2021-01-01', periods=6) + n_times = 20 + n_locations = 50 + lon_values = np.arange(n_locations) + lat_values = np.arange(n_locations) + time_values = pd.date_range(start='2021-01-01', periods=n_times) + # add x and y values + x_values = np.arange(n_locations) + y_values = np.arange(n_locations) return xr.Dataset( data_vars=dict( - temperature=(["lon", "lat", "time"], da.arange(5 * 5 * 6).reshape((5, 5, 6))), - humidity=(["lon", "lat", "time"], da.arange(5 * 5 * 6).reshape((5, 5, 6))), + temperature=(["lon", "lat", "time"], da.arange(n_locations * n_locations * n_times).reshape((n_locations, n_locations, n_times))), + humidity=(["lon", "lat", "time"], da.arange(n_locations * n_locations * n_times).reshape((n_locations, n_locations, n_times))), ), coords=dict( lon=(["lon"], lon_values), lat=(["lat"], lat_values), + x=(["lon"], x_values), + y=(["lat"], y_values), time=(["time"], time_values), ), ).unify_chunks() @@ -606,6 +616,14 @@ def test_all_operations_lazy(self, stmat, meteo_points): assert stmat_enriched.lat.data.dask is not None # dont check time because it is not a dask array 
+ def test_enrich_from_point_cropped(self, stmat, meteo_points): + buffer = {"lon": 1, "lat": 1, "time": pd.Timedelta("1D")} + print(stmat.lat.values) + meteo_points_cropped = crop(stmat, meteo_points, buffer) + stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points_cropped, "temperature") + assert stmat_enriched.temperature[0, 0] == meteo_points_cropped.temperature[0, 1] + + class TestEnrichmentFromRasterDataset: def test_enrich_from_dataset_one_filed(self, stmat, meteo_raster): stmat_enriched = stmat.stm.enrich_from_dataset(meteo_raster, "temperature") @@ -684,4 +702,10 @@ def test_all_operations_lazy(self, stmat, meteo_raster): assert stmat_enriched.temperature.data.dask is not None assert stmat_enriched.lon.data.dask is not None assert stmat_enriched.lat.data.dask is not None - # dont check time because it is not a dask array \ No newline at end of file + # dont check time because it is not a dask array + + def test_enrich_from_raste_cropped(self, stmat, meteo_raster): + buffer = {"lon": 1, "lat": 1, "time": pd.Timedelta("1D")} + meteo_raster_cropped = crop(stmat, meteo_raster, buffer) + stmat_enriched = stmat.stm.enrich_from_dataset(meteo_raster_cropped, "temperature") + assert stmat_enriched.temperature[0, 0] == meteo_raster_cropped.temperature[0, 0, 1] From dc33ecb675da8def086cb45255cd7a10436f54cc Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Fri, 22 Mar 2024 17:51:08 +0100 Subject: [PATCH 11/33] fix and refactor util function for cropping --- stmtools/utils.py | 32 ++++++++++++++++++++++---------- tests/test_util.py | 18 ++++++++++++++---- 2 files changed, 36 insertions(+), 14 deletions(-) diff --git a/stmtools/utils.py b/stmtools/utils.py index e39e93f..95bdf3f 100644 --- a/stmtools/utils.py +++ b/stmtools/utils.py @@ -42,18 +42,30 @@ def crop(ds, other, buffer): if coord not in other.coords.keys(): raise ValueError(f"coordinate '{coord}' not found in other.") - other = unstack(other) + original_dims_order = other.dims + + # for dims that are not 
in coords, unstack the data + indexer = {} + for dim in other.dims: + if dim not in other.coords.keys(): + indexer = {dim: [coord for coord in other.coords.keys() if dim in other.coords[coord].dims]} + other = other.set_index(indexer) + other = other.unstack(indexer) + + # do the slicing for coord in buffer.keys(): coord_min = ds[coord].min() - buffer[coord] coord_max = ds[coord].max() + buffer[coord] - other = other.sel({coord: slice(coord_min, coord_max)}) - return other + if coord in other.dims: + other = other.sel({coord: slice(coord_min, coord_max)}) + # stack back + for dim, coords in indexer.items(): + for coord in coords: + coord_value = xr.DataArray(other.coords[coord].values, dims=dim) + other = other.sel({coord: coord_value}) -def unstack(ds): - for dim in ds.dims: - if dim not in ds.coords: - indexer = {dim: [coord for coord in ds.coords if dim in ds[coord].dims]} - ds = ds.set_index(indexer) - ds = ds.unstack(dim) - return ds \ No newline at end of file + # transpose the dimensions back to the original order + other = other.transpose(*original_dims_order) + + return other diff --git a/tests/test_util.py b/tests/test_util.py index 81b2a66..cf0e845 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -44,7 +44,6 @@ def meteo_points(): ).unify_chunks() - @pytest.fixture def meteo_raster(): n_times = 20 @@ -97,15 +96,19 @@ def stmat(): class TestCrop: def test_crop_points(self, stmat, meteo_points): - buffer = {"lon": 1, "lat": 1, "time": 1} + buffer = {"lon": 1, "lat": 1, "time": pd.Timedelta("1D")} cropped = utils.crop(stmat, meteo_points, buffer) # check min and max values of coordinates assert cropped.lon.min() == 0 assert cropped.lon.max() == 10 assert cropped.lat.min() == 0 assert cropped.lat.max() == 10 - assert cropped.time.min() == pd.Timestamp("2021-01-02") - assert cropped.time.max() == pd.Timestamp("2021-01-06") + assert cropped.time.min() == pd.Timestamp("2021-01-01") + assert cropped.time.max() == pd.Timestamp("2021-01-07") + assert 
tuple(cropped.temperature.dims) == ("space", "time") + assert "lon" in cropped.coords + assert "lat" in cropped.coords + assert "time" in cropped.coords def test_crop_raster(self, stmat, meteo_raster): buffer = {"lon": 1, "lat": 1, "time": 1} @@ -117,6 +120,13 @@ def test_crop_raster(self, stmat, meteo_raster): assert cropped.lat.max() == 10 assert cropped.time.min() == pd.Timestamp("2021-01-02") assert cropped.time.max() == pd.Timestamp("2021-01-06") + assert tuple(cropped.dims) == ("lon", "lat", "time") + assert "lon" in cropped.coords + assert "lat" in cropped.coords + assert "time" in cropped.coords + assert "x" in cropped.coords + assert "y" in cropped.coords + def test_all_operations_lazy(self, stmat, meteo_raster): buffer = {"lon": 1, "lat": 1, "time": 1} From 200285401c2ed9ce7856d4fbf9e2e28847c3e6e2 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Fri, 22 Mar 2024 17:52:59 +0100 Subject: [PATCH 12/33] fix an error msg --- stmtools/stm.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/stmtools/stm.py b/stmtools/stm.py index 3ae0346..70462f4 100644 --- a/stmtools/stm.py +++ b/stmtools/stm.py @@ -456,6 +456,8 @@ def enrich_from_dataset(self, raise ValueError( "The input dataset is not a point or raster dataset." "The dataset should have either 'space' or 'lat/y' and 'lon/x' dimensions." 
# give help on renaming + "Consider renaming using " + "https://docs.xarray.dev/en/latest/generated/xarray.Dataset.rename.html#xarray-dataset-rename" ) # check if dataset has time dimensions From 9113b0237f7bf6119699abd36d624f5a429ad851 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Fri, 22 Mar 2024 17:56:43 +0100 Subject: [PATCH 13/33] fix linter errors --- stmtools/stm.py | 37 ++++++++++++++++++++++++++++++------- stmtools/utils.py | 10 ++++++++-- 2 files changed, 38 insertions(+), 9 deletions(-) diff --git a/stmtools/stm.py b/stmtools/stm.py index 70462f4..8c42cae 100644 --- a/stmtools/stm.py +++ b/stmtools/stm.py @@ -12,7 +12,6 @@ from shapely.geometry import Point from shapely.strtree import STRtree -from stmtools import utils from stmtools.metadata import DataVarTypes, STMMetaData from stmtools.utils import _has_property @@ -39,6 +38,7 @@ def add_metadata(self, metadata): ------- xarray.Dataset STM with assigned attributes. + """ self._obj = self._obj.assign_attrs(metadata) return self._obj @@ -70,6 +70,7 @@ def regulate_dims(self, space_label=None, time_label=None): ------- xarray.Dataset Regulated STM. + """ if ( (space_label is None) @@ -129,6 +130,7 @@ def subset(self, method: str, **kwargs): ------- xarray.Dataset A subset of the original STM. + """ # Check if both "space" and "time" dimension exists for dim in ["space", "time"]: @@ -204,6 +206,7 @@ def enrich_from_polygon(self, polygon, fields, xlabel="lon", ylabel="lat"): ------- xarray.Dataset Enriched STM. + """ _ = _validate_coords(self._obj, xlabel, ylabel) @@ -267,6 +270,7 @@ def _in_polygon(self, polygon, xlabel="lon", ylabel="lat"): ------- Dask.array A boolean Dask array. True where a space entry is inside the (multi-)polygon. + """ # Check if coords exists _ = _validate_coords(self._obj, xlabel, ylabel) @@ -312,6 +316,7 @@ def register_metadata(self, dict_meta: STMMetaData): ------- xarray.Dataset STM with registered metadata. 
+ """ ds_updated = self._obj.assign_attrs(dict_meta) @@ -331,6 +336,7 @@ def register_datatype(self, keys: str | Iterable, datatype: DataVarTypes): ------- xarray.Dataset STM with registered metadata. + """ ds_updated = self._obj @@ -364,6 +370,7 @@ def get_order(self, xlabel="azimuth", ylabel="range", xscale=1.0, yscale=1.0): Scaling multiplier to the x coordinates before truncating them to integer values. yscale : float Scaling multiplier to the y coordinates before truncating them to integer values. + """ meta_arr = np.array((), dtype=np.int64) order = da.apply_gufunc( @@ -396,6 +403,7 @@ def reorder(self, xlabel="azimuth", ylabel="range", xscale=1.0, yscale=1.0): Scaling multiplier to the x coordinates before truncating them to integer values. yscale : float Scaling multiplier to the y coordinates before truncating them to integer values. + """ self._obj = self.get_order(xlabel, ylabel, xscale, yscale) self._obj = self._obj.sortby(self._obj.order) @@ -422,10 +430,12 @@ def enrich_from_dataset(self, method : str, optional Method of interpolation, by default "nearest", see https://docs.xarray.dev/en/stable/generated/xarray.Dataset.interp_like.html#xarray-dataset-interp-like + Returns ------- xarray.Dataset Enriched STM. + """ # Check if fields is a Iterable or a str if isinstance(fields, str): @@ -455,7 +465,7 @@ def enrich_from_dataset(self, else: raise ValueError( "The input dataset is not a point or raster dataset." - "The dataset should have either 'space' or 'lat/y' and 'lon/x' dimensions." # give help on renaming + "The dataset should have either 'space' or 'lat/y' and 'lon/x' dimensions." "Consider renaming using " "https://docs.xarray.dev/en/latest/generated/xarray.Dataset.rename.html#xarray-dataset-rename" ) @@ -494,6 +504,7 @@ def num_points(self): ------- int Number of space entry. + """ return self._obj.dims["space"] @@ -505,6 +516,7 @@ def num_epochs(self): ------- int Number of epochs. 
+ """ return self._obj.dims["time"] @@ -558,6 +570,7 @@ def _ml_str_query(xx, yy, polygon, type_polygon): An array with two columns. The first column is the positional index into the list of polygons being used to query the tree. The second column is the positional index into the list of space entries for which the tree was constructed. + """ # Crop the polygon to the bounding box of the block xmin, ymin, xmax, ymax = [ @@ -623,6 +636,7 @@ def _validate_coords(ds, xlabel, ylabel): ------ ValueError If xlabel or ylabel neither exists in coordinates, raise ValueError + """ for clabel in [xlabel, ylabel]: if clabel not in ds.coords.keys(): @@ -655,6 +669,7 @@ def _compute_morton_code(xx, yy): ------- array_like An array with Morton codes per coordinate pair. + """ code = [pm.interleave(int(xi), int(yi)) for xi, yi in zip(xx, yy, strict=True)] return code @@ -670,8 +685,8 @@ def _enrich_from_raster_block(ds, dataraster, fields, method): Parameters ---------- ds : xarray.Dataset - - dataset : xarray.Dataset | xarray.DataArray + SpaceTimeMatrix to enrich + dataraster : xarray.Dataset | xarray.DataArray Input data for enrichment fields : str or list of str Field name(s) in the dataset for enrichment @@ -681,6 +696,7 @@ def _enrich_from_raster_block(ds, dataraster, fields, method): Returns ------- xarray.Dataset + """ # interpolate the raster dataset to the coordinates of ds interpolated = dataraster.interp(ds.coords, method=method) @@ -699,7 +715,7 @@ def _enrich_from_points_block(ds, datapoints, fields): Parameters ---------- ds : xarray.Dataset - + SpaceTimeMatrix to enrich datapoints : xarray.Dataset | xarray.DataArray Input data for enrichment fields : str or list of str @@ -708,11 +724,16 @@ def _enrich_from_points_block(ds, datapoints, fields): Returns ------- xarray.Dataset + """ # unstak the dimensions for dim in datapoints.dims: if dim not in datapoints.coords: - indexer = {dim: [coord for coord in datapoints.coords if dim in datapoints[coord].dims]} + indexer = 
{ + dim: [ + coord for coord in datapoints.coords if dim in datapoints[coord].dims + ] + } datapoints = datapoints.set_index(indexer) datapoints = datapoints.unstack(dim) @@ -722,6 +743,8 @@ def _enrich_from_points_block(ds, datapoints, fields): # Assign these values to the corresponding points in ds for field in fields: - ds[field] = xr.DataArray(selections[field].data.transpose(), dims=ds.dims, coords=ds.coords) + ds[field] = xr.DataArray( + selections[field].data.transpose(), dims=ds.dims, coords=ds.coords + ) return ds diff --git a/stmtools/utils.py b/stmtools/utils.py index 95bdf3f..c41e261 100644 --- a/stmtools/utils.py +++ b/stmtools/utils.py @@ -1,6 +1,7 @@ -import xarray as xr from collections.abc import Iterable +import xarray as xr + def _has_property(ds, keys: str | Iterable): if isinstance(keys, str): @@ -27,6 +28,7 @@ def crop(ds, other, buffer): ------- xarray.Dataset Cropped dataset. + """ if isinstance(ds, xr.DataArray): ds = ds.to_dataset() @@ -48,7 +50,11 @@ def crop(ds, other, buffer): indexer = {} for dim in other.dims: if dim not in other.coords.keys(): - indexer = {dim: [coord for coord in other.coords.keys() if dim in other.coords[coord].dims]} + indexer = { + dim: [ + coord for coord in other.coords.keys() if dim in other.coords[coord].dims + ] + } other = other.set_index(indexer) other = other.unstack(indexer) From 86d7b6d864be496a9676f5976c0730b13831eead Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Mon, 25 Mar 2024 09:12:40 +0100 Subject: [PATCH 14/33] fix linters --- tests/test_stm.py | 15 ++++++++++----- tests/test_util.py | 10 +++++++--- 2 files changed, 17 insertions(+), 8 deletions(-) diff --git a/tests/test_stm.py b/tests/test_stm.py index 1322f8c..5e8911d 100644 --- a/tests/test_stm.py +++ b/tests/test_stm.py @@ -3,9 +3,9 @@ import dask.array as da import geopandas as gpd import numpy as np +import pandas as pd import pytest import xarray as xr -import pandas as pd from shapely import geometry from stmtools.stm import 
_validate_coords @@ -263,11 +263,12 @@ def meteo_points(): lon_values = np.arange(0, n_locations/2, 0.5) lat_values = np.arange(0, n_locations/2, 0.5) time_values = pd.date_range(start='2021-01-01', periods=n_times) + data = da.arange(n_locations * n_times).reshape((n_locations, n_times)) return xr.Dataset( data_vars=dict( - temperature=(["space", "time"], da.arange(n_locations * n_times).reshape((n_locations, n_times))), - humidity=(["space", "time"], da.arange(n_locations * n_times).reshape((n_locations, n_times))), + temperature=(["space", "time"], data), + humidity=(["space", "time"], data), ), coords=dict( lon=(["space"], lon_values), @@ -287,10 +288,14 @@ def meteo_raster(): x_values = np.arange(n_locations) y_values = np.arange(n_locations) + data = da.arange(n_locations * n_locations * n_times).reshape( + (n_locations, n_locations, n_times) + ) + return xr.Dataset( data_vars=dict( - temperature=(["lon", "lat", "time"], da.arange(n_locations * n_locations * n_times).reshape((n_locations, n_locations, n_times))), - humidity=(["lon", "lat", "time"], da.arange(n_locations * n_locations * n_times).reshape((n_locations, n_locations, n_times))), + temperature=(["lon", "lat", "time"], data), + humidity=(["lon", "lat", "time"], data), ), coords=dict( lon=(["lon"], lon_values), diff --git a/tests/test_util.py b/tests/test_util.py index cf0e845..c914545 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -1,7 +1,7 @@ -import pytest import dask.array as da import numpy as np import pandas as pd +import pytest import xarray as xr from stmtools import utils @@ -31,10 +31,11 @@ def meteo_points(): lon_values = np.arange(n_locations) lat_values = np.arange(n_locations) time_values = pd.date_range(start='2021-01-01', periods=n_times) + data = da.arange(n_locations * n_times).reshape((n_locations, n_times)) return xr.Dataset( data_vars=dict( - temperature=(["space", "time"], da.arange(n_locations * n_times).reshape((n_locations, n_times))), + temperature=(["space", 
"time"], data), ), coords=dict( lon=(["space"], lon_values), @@ -54,10 +55,13 @@ def meteo_raster(): # add x and y values x_values = np.arange(n_locations) y_values = np.arange(n_locations) + data = da.arange(n_locations * n_locations * n_times).reshape( + (n_locations, n_locations, n_times) + ) return xr.Dataset( data_vars=dict( - temperature=(["lon", "lat", "time"], da.arange(n_locations * n_locations * n_times).reshape((n_locations, n_locations, n_times))), + temperature=(["lon", "lat", "time"], data), ), coords=dict( lon=(["lon"], lon_values), From da020a4ec5c1f85d8b390ad7e8cc9ee413f98c02 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Mon, 25 Mar 2024 09:18:03 +0100 Subject: [PATCH 15/33] remove scipy because it is included in xarray io --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 58eb8d9..9b30193 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,7 +16,6 @@ dependencies = [ "zarr", "distributed", "pymorton", - "scipy", # required for `enrich_from_dataset` method ] description = "space-time matrix for PS-InSAR application" readme = "README.md" From 3dc17562823e18a9900ecf7fb3799334bd2d2222 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Mon, 25 Mar 2024 09:23:44 +0100 Subject: [PATCH 16/33] fix linter errors in _io --- stmtools/_io.py | 1 + 1 file changed, 1 insertion(+) diff --git a/stmtools/_io.py b/stmtools/_io.py index 2736b24..b0f5940 100644 --- a/stmtools/_io.py +++ b/stmtools/_io.py @@ -66,6 +66,7 @@ def from_csv( Returns: ------- xr.Dataset: Output STM instance + """ # Load csv as Dask DataFrame ddf = dd.read_csv(file, blocksize=blocksize) From 0691161e7cf327847f4704b7cd4a7cb3ee8d11d5 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Mon, 25 Mar 2024 09:36:09 +0100 Subject: [PATCH 17/33] fix minor things --- stmtools/stm.py | 2 +- tests/test_stm.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/stmtools/stm.py b/stmtools/stm.py index 8c42cae..5b0afb4 100644 --- 
a/stmtools/stm.py +++ b/stmtools/stm.py @@ -429,7 +429,7 @@ def enrich_from_dataset(self, Field name(s) in the dataset for enrichment method : str, optional Method of interpolation, by default "nearest", see - https://docs.xarray.dev/en/stable/generated/xarray.Dataset.interp_like.html#xarray-dataset-interp-like + https://docs.xarray.dev/en/stable/generated/xarray.Dataset.interp.html Returns ------- diff --git a/tests/test_stm.py b/tests/test_stm.py index 5e8911d..1ee7e6f 100644 --- a/tests/test_stm.py +++ b/tests/test_stm.py @@ -623,7 +623,6 @@ def test_all_operations_lazy(self, stmat, meteo_points): def test_enrich_from_point_cropped(self, stmat, meteo_points): buffer = {"lon": 1, "lat": 1, "time": pd.Timedelta("1D")} - print(stmat.lat.values) meteo_points_cropped = crop(stmat, meteo_points, buffer) stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points_cropped, "temperature") assert stmat_enriched.temperature[0, 0] == meteo_points_cropped.temperature[0, 1] From 8c473a7db15daf1e19c80590ac4768d8feeebd7b Mon Sep 17 00:00:00 2001 From: SarahAlidoost <55081872+SarahAlidoost@users.noreply.github.com> Date: Mon, 8 Apr 2024 11:27:06 +0200 Subject: [PATCH 18/33] Update stmtools/stm.py Co-authored-by: Ou Ku --- stmtools/stm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stmtools/stm.py b/stmtools/stm.py index 5b0afb4..3c62ef1 100644 --- a/stmtools/stm.py +++ b/stmtools/stm.py @@ -726,7 +726,7 @@ def _enrich_from_points_block(ds, datapoints, fields): xarray.Dataset """ - # unstak the dimensions + # unstack the dimensions for dim in datapoints.dims: if dim not in datapoints.coords: indexer = { From 92d53016e43c990037cad934878a17f535ee45ed Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Mon, 8 Apr 2024 13:49:20 +0200 Subject: [PATCH 19/33] add two utils functions for checking coordinates --- stmtools/utils.py | 41 +++++++++++++++++++++++++++++++++++++++++ tests/test_util.py | 40 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 81 
insertions(+) diff --git a/stmtools/utils.py b/stmtools/utils.py index c41e261..186cb6a 100644 --- a/stmtools/utils.py +++ b/stmtools/utils.py @@ -75,3 +75,44 @@ def crop(ds, other, buffer): other = other.transpose(*original_dims_order) return other + + +def monotonic_coords(ds, dim: str): + """Check if the dataset is monotonic in the given dimension. + + Parameters + ---------- + ds : xarray.Dataset + Dataset to check. + dim : str + Dimension to check. + + Returns + ------- + bool + True if the dataset is monotonic, False otherwise. + + """ + return bool( + ds[dim].to_index().is_monotonic_increasing + or ds[dim].to_index().is_monotonic_decreasing + ) + + +def unique_coords(ds, dim: str ): + """Check if the dataset has unique coordinates in the given dimension. + + Parameters + ---------- + ds : xarray.Dataset + Dataset to check. + dim : str + Dimension to check. + + Returns + ------- + bool + True if the dataset has unique coordinates, False otherwise. + + """ + return bool(ds[dim].to_index().is_unique) \ No newline at end of file diff --git a/tests/test_util.py b/tests/test_util.py index c914545..821d7c3 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -136,3 +136,43 @@ def test_all_operations_lazy(self, stmat, meteo_raster): buffer = {"lon": 1, "lat": 1, "time": 1} cropped = utils.crop(stmat, meteo_raster, buffer) assert isinstance(cropped.temperature.data, da.Array) + + +class TestMonotonicCoords: + def test_monotonic_coords(self, stmat): + assert utils.monotonic_coords(stmat, "lon") + assert utils.monotonic_coords(stmat, "lat") + assert utils.monotonic_coords(stmat, "time") + + def test_non_monotonic_coords_lon(self, stmat): + stmat["lon"][0] = 100 + assert not utils.monotonic_coords(stmat, "lon") + + def test_non_monotonic_coords_lat(self, stmat): + stmat["lat"][0] = 100 + stmat["lat"][1] = 50 + assert not utils.monotonic_coords(stmat, "lat") + + def test_non_monotonic_coords_time(self, stmat): + stmat["time"].values[0] = 
'2022-01-02T00:00:00.000000000' + stmat["time"].values[1] = '2022-01-01T00:00:00.000000000' + assert not utils.monotonic_coords(stmat, "time") + + +class TestUniqueCoords: + def test_unique_coords(self, stmat): + assert utils.unique_coords(stmat, "lon") + assert utils.unique_coords(stmat, "lat") + assert utils.unique_coords(stmat, "time") + + def test_non_unique_coords_lon(self, stmat): + stmat["lon"][0] = 1 + assert not utils.unique_coords(stmat, "lon") + + def test_non_unique_coords_lat(self, stmat): + stmat["lon"][0] = 1 + assert not utils.unique_coords(stmat, "lat") + + def test_non_unique_coords_time(self, stmat): + stmat["time"].values[0] = '2021-01-03T00:00:00.000000000' + assert not utils.unique_coords(stmat, "time") From 70089d1aebd05cb40c3ea85b3aac1bb3a6317439 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Mon, 8 Apr 2024 13:54:25 +0200 Subject: [PATCH 20/33] fix test unique coords in test_util --- tests/test_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_util.py b/tests/test_util.py index 821d7c3..b76f82b 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -170,7 +170,7 @@ def test_non_unique_coords_lon(self, stmat): assert not utils.unique_coords(stmat, "lon") def test_non_unique_coords_lat(self, stmat): - stmat["lon"][0] = 1 + stmat["lat"][0] = 1 assert not utils.unique_coords(stmat, "lat") def test_non_unique_coords_time(self, stmat): From b85415aedd0a582b2b956c3a2410b24023c647bd Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Mon, 8 Apr 2024 13:57:11 +0200 Subject: [PATCH 21/33] add a check if coords are monotonic and unigue is stm --- stmtools/stm.py | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/stmtools/stm.py b/stmtools/stm.py index 3c62ef1..5841d7e 100644 --- a/stmtools/stm.py +++ b/stmtools/stm.py @@ -13,7 +13,7 @@ from shapely.strtree import STRtree from stmtools.metadata import DataVarTypes, STMMetaData -from stmtools.utils import _has_property +from 
stmtools.utils import _has_property, monotonic_coords, unique_coords logger = logging.getLogger(__name__) @@ -739,6 +739,14 @@ def _enrich_from_points_block(ds, datapoints, fields): # do selection indexers = {coord: ds[coord] for coord in list(datapoints.coords.keys())} + + # check if coords in indexers are monotonic and unique + for coord in indexers: + if not monotonic_coords(datapoints, coord): + raise ValueError(f"Coordinate {coord} is not monotonic.") + if not unique_coords(datapoints, coord): + raise ValueError(f"Coordinate {coord} is not unique.") + selections = datapoints.sel(indexers, method="nearest") # Assign these values to the corresponding points in ds From c030cf96589898896755f93b6591a4428045076c Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Wed, 10 Apr 2024 12:21:53 +0200 Subject: [PATCH 22/33] use scipy KDTree instead of xarray unstack and sel functions --- stmtools/stm.py | 53 ++++++++++++++++++++++++++++++------------------- 1 file changed, 33 insertions(+), 20 deletions(-) diff --git a/stmtools/stm.py b/stmtools/stm.py index 5841d7e..31c7ed3 100644 --- a/stmtools/stm.py +++ b/stmtools/stm.py @@ -9,6 +9,7 @@ import numpy as np import pymorton as pm import xarray as xr +from scipy.spatial import KDTree from shapely.geometry import Point from shapely.strtree import STRtree @@ -710,7 +711,8 @@ def _enrich_from_raster_block(ds, dataraster, fields, method): def _enrich_from_points_block(ds, datapoints, fields): """Enrich the ds (SpaceTimeMatrix) from one or more fields of a point dataset. - https://docs.xarray.dev/en/latest/generated/xarray.DataArray.sel.html#xarray.DataArray.sel + Assumption is that dimensions of data are space and time. 
+ https://docs.scipy.org/doc/scipy/reference/generated/scipy.spatial.KDTree.html#scipy.spatial.KDTree Parameters ---------- @@ -726,33 +728,44 @@ def _enrich_from_points_block(ds, datapoints, fields): xarray.Dataset """ - # unstack the dimensions - for dim in datapoints.dims: + ## The reason that we use KDTRee instead of xarray.unstack is that the latter is slow for large datasets + # check the dimensions + indexer = {} + for dim in ["space", "time"]: if dim not in datapoints.coords: - indexer = { - dim: [ - coord for coord in datapoints.coords if dim in datapoints[coord].dims - ] - } - datapoints = datapoints.set_index(indexer) - datapoints = datapoints.unstack(dim) + indexer[dim]= [ + coord for coord in datapoints.coords if dim in datapoints[coord].dims + ] + else: + indexer[dim] = [dim] + + ## datapoints + indexes = [datapoints[coord] for coord in indexer["space"]] + dataset_points_coords = np.column_stack(indexes) + + # ds + indexes = [ds[coord] for coord in indexer["space"]] + ds_coords = np.column_stack(indexes) - # do selection - indexers = {coord: ds[coord] for coord in list(datapoints.coords.keys())} + # Create a KDTree object for the spatial coordinates of datapoints + # Find the indices of the nearest points in space in datapoints for each point in ds + # it uses Euclidean distance + tree = KDTree(dataset_points_coords) + _, indices_space = tree.query(ds_coords) - # check if coords in indexers are monotonic and unique - for coord in indexers: - if not monotonic_coords(datapoints, coord): - raise ValueError(f"Coordinate {coord} is not monotonic.") - if not unique_coords(datapoints, coord): - raise ValueError(f"Coordinate {coord} is not unique.") + # Create a KDTree object for the temporal coordinates of datapoints + # Find the indices of the nearest points in time in datapoints for each point in ds + datapoints_times = datapoints.time.values.reshape(-1, 1) + ds_times = ds.time.values.reshape(-1, 1) + tree = KDTree(datapoints_times) + _, indices_time = 
tree.query(ds_times) - selections = datapoints.sel(indexers, method="nearest") + selections = datapoints.isel(space=indices_space, time=indices_time) # Assign these values to the corresponding points in ds for field in fields: ds[field] = xr.DataArray( - selections[field].data.transpose(), dims=ds.dims, coords=ds.coords + selections[field].data, dims=ds.dims, coords=ds.coords ) return ds From f12333603500f26d657543e5bb979cedf20f6f62 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Wed, 10 Apr 2024 12:23:13 +0200 Subject: [PATCH 23/33] fix linter errors --- stmtools/stm.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/stmtools/stm.py b/stmtools/stm.py index 31c7ed3..d13df47 100644 --- a/stmtools/stm.py +++ b/stmtools/stm.py @@ -14,7 +14,7 @@ from shapely.strtree import STRtree from stmtools.metadata import DataVarTypes, STMMetaData -from stmtools.utils import _has_property, monotonic_coords, unique_coords +from stmtools.utils import _has_property logger = logging.getLogger(__name__) @@ -728,7 +728,9 @@ def _enrich_from_points_block(ds, datapoints, fields): xarray.Dataset """ - ## The reason that we use KDTRee instead of xarray.unstack is that the latter is slow for large datasets + # The reason that we use KDTRee instead of xarray.unstack is that the latter + # is slow for large datasets + # check the dimensions indexer = {} for dim in ["space", "time"]: From 73acbbd99b692d79e70478e6a51fd7bd5ec0a14e Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Thu, 11 Apr 2024 15:43:14 +0200 Subject: [PATCH 24/33] add test for non monotonic an dduplicates coords --- tests/test_stm.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/tests/test_stm.py b/tests/test_stm.py index 1ee7e6f..4a6e1b5 100644 --- a/tests/test_stm.py +++ b/tests/test_stm.py @@ -627,6 +627,30 @@ def test_enrich_from_point_cropped(self, stmat, meteo_points): stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points_cropped, "temperature") assert 
stmat_enriched.temperature[0, 0] == meteo_points_cropped.temperature[0, 1] + def test_enrich_from_point_nanmonotonic_coords(self, stmat, meteo_points): + # make the coordinates non-monotonic + meteo_points["lon"][0] = 25 + meteo_points["lat"][0] = 25 + + stmat["lon"][0] = 25 + stmat["lat"][0] = 25 + + stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature") + assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 1] + + def test_enrich_from_point_duplicate_coords(self, stmat, meteo_points): + # make the coordinates duplicates, + # now both locations 0 and 1 have the same coords but different + # temperature values + meteo_points["lon"][0] = 0.5 + meteo_points["lat"][0] = 0.5 + + stmat["lon"][0] = 0.5 + stmat["lat"][0] = 0.5 + + stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature") + assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[1, 1] + class TestEnrichmentFromRasterDataset: def test_enrich_from_dataset_one_filed(self, stmat, meteo_raster): From 3d01481f1f6949b3dd9e7a2e6be96cc3cac4391f Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Thu, 11 Apr 2024 16:23:14 +0200 Subject: [PATCH 25/33] add a test for non monotonic time --- tests/test_stm.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/test_stm.py b/tests/test_stm.py index 4a6e1b5..3a95c66 100644 --- a/tests/test_stm.py +++ b/tests/test_stm.py @@ -651,6 +651,14 @@ def test_enrich_from_point_duplicate_coords(self, stmat, meteo_points): stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature") assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[1, 1] + def test_enrichfrom_point_nanmonotonic_times(self, stmat, meteo_points): + # make the time non-monotonic + meteo_points["time"].values[0] = pd.Timestamp("2022-01-01") + stmat["time"].values[0] = pd.Timestamp("2022-01-01") + + stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature") + assert 
stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 0] + class TestEnrichmentFromRasterDataset: def test_enrich_from_dataset_one_filed(self, stmat, meteo_raster): From 9b3ddd8791f3d253d3bf968bc71c06663b876766 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Fri, 12 Apr 2024 10:08:21 +0200 Subject: [PATCH 26/33] add type to coordinates in tests --- tests/test_stm.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/tests/test_stm.py b/tests/test_stm.py index 3a95c66..8e2681c 100644 --- a/tests/test_stm.py +++ b/tests/test_stm.py @@ -326,8 +326,8 @@ def stmat(): ), ), coords=dict( - lon=(["space"], da.arange(npoints)), - lat=(["space"], da.arange(npoints)), + lon=(["space"], da.arange(npoints).astype('float64')), + lat=(["space"], da.arange(npoints).astype('float64')), time=(["time"], pd.date_range(start='2021-01-02', periods=ntime)), ), ).unify_chunks() @@ -629,11 +629,11 @@ def test_enrich_from_point_cropped(self, stmat, meteo_points): def test_enrich_from_point_nanmonotonic_coords(self, stmat, meteo_points): # make the coordinates non-monotonic - meteo_points["lon"][0] = 25 - meteo_points["lat"][0] = 25 + meteo_points["lon"][0] = 25.0 + meteo_points["lat"][0] = 25.0 - stmat["lon"][0] = 25 - stmat["lat"][0] = 25 + stmat["lon"][0] = 25.0 + stmat["lat"][0] = 25.0 stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature") assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 1] From 45e1900477e0202f3bf5fb86b50bb83996009802 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Fri, 12 Apr 2024 10:38:09 +0200 Subject: [PATCH 27/33] fix a linter error --- stmtools/stm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/stmtools/stm.py b/stmtools/stm.py index d13df47..b0160df 100644 --- a/stmtools/stm.py +++ b/stmtools/stm.py @@ -24,7 +24,7 @@ class SpaceTimeMatrix: """Space-Time Matrix.""" def __init__(self, xarray_obj): - """init.""" + """Init.""" self._obj = xarray_obj def 
add_metadata(self, metadata): From 0412f10d2e00e49e9e0bdee040f222605c275325 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Fri, 12 Apr 2024 11:05:07 +0200 Subject: [PATCH 28/33] debug: add debuging to pytest in workflow, and comment the test to check on macos --- .github/workflows/build.yml | 2 +- tests/test_stm.py | 24 ++++++++++++------------ 2 files changed, 13 insertions(+), 13 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index f1140ce..751dba7 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -33,7 +33,7 @@ jobs: - name: Build the package run: python -m build - name: Test with pytest - run: python -m pytest + run: python -m pytest -vv build_doc: runs-on: ubuntu-latest diff --git a/tests/test_stm.py b/tests/test_stm.py index 8e2681c..2dee808 100644 --- a/tests/test_stm.py +++ b/tests/test_stm.py @@ -632,24 +632,24 @@ def test_enrich_from_point_nanmonotonic_coords(self, stmat, meteo_points): meteo_points["lon"][0] = 25.0 meteo_points["lat"][0] = 25.0 - stmat["lon"][0] = 25.0 - stmat["lat"][0] = 25.0 + stmat["lon"][0] = 25.5 + stmat["lat"][0] = 25.5 stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature") assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 1] - def test_enrich_from_point_duplicate_coords(self, stmat, meteo_points): - # make the coordinates duplicates, - # now both locations 0 and 1 have the same coords but different - # temperature values - meteo_points["lon"][0] = 0.5 - meteo_points["lat"][0] = 0.5 + # def test_enrich_from_point_duplicate_coords(self, stmat, meteo_points): + # # make the coordinates duplicates, + # # now both locations 0 and 1 have the same coords but different + # # temperature values + # meteo_points["lon"][0] = 0.5 + # meteo_points["lat"][0] = 0.5 - stmat["lon"][0] = 0.5 - stmat["lat"][0] = 0.5 + # stmat["lon"][0] = 0.5 + # stmat["lat"][0] = 0.5 - stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, 
"temperature") - assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[1, 1] + # stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature") + # assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[1, 1] def test_enrichfrom_point_nanmonotonic_times(self, stmat, meteo_points): # make the time non-monotonic From ea580a24a26cf55787d4610ead87e83e1925eb7f Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Fri, 12 Apr 2024 11:12:18 +0200 Subject: [PATCH 29/33] debug comment the test to check on macos --- tests/test_stm.py | 36 ++++++++++++++++++------------------ 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/tests/test_stm.py b/tests/test_stm.py index 2dee808..d2e86de 100644 --- a/tests/test_stm.py +++ b/tests/test_stm.py @@ -627,29 +627,29 @@ def test_enrich_from_point_cropped(self, stmat, meteo_points): stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points_cropped, "temperature") assert stmat_enriched.temperature[0, 0] == meteo_points_cropped.temperature[0, 1] - def test_enrich_from_point_nanmonotonic_coords(self, stmat, meteo_points): - # make the coordinates non-monotonic - meteo_points["lon"][0] = 25.0 - meteo_points["lat"][0] = 25.0 + # def test_enrich_from_point_nanmonotonic_coords(self, stmat, meteo_points): + # # make the coordinates non-monotonic + # meteo_points["lon"][0] = 25.0 + # meteo_points["lat"][0] = 25.0 - stmat["lon"][0] = 25.5 - stmat["lat"][0] = 25.5 + # stmat["lon"][0] = 25.5 + # stmat["lat"][0] = 25.5 - stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature") - assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 1] + # stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature") + # assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 1] - # def test_enrich_from_point_duplicate_coords(self, stmat, meteo_points): - # # make the coordinates duplicates, - # # now both locations 0 and 1 have the same coords but 
different - # # temperature values - # meteo_points["lon"][0] = 0.5 - # meteo_points["lat"][0] = 0.5 + def test_enrich_from_point_duplicate_coords(self, stmat, meteo_points): + # make the coordinates duplicates, + # now both locations 0 and 1 have the same coords but different + # temperature values + meteo_points["lon"][0] = 0.5 + meteo_points["lat"][0] = 0.5 - # stmat["lon"][0] = 0.5 - # stmat["lat"][0] = 0.5 + stmat["lon"][0] = 0.5 + stmat["lat"][0] = 0.5 - # stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature") - # assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[1, 1] + stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature") + assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[1, 1] def test_enrichfrom_point_nanmonotonic_times(self, stmat, meteo_points): # make the time non-monotonic From 9dc7804e44e8e4f0a3a03a4845b458cf449ea356 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Fri, 12 Apr 2024 11:27:02 +0200 Subject: [PATCH 30/33] fix tests comparing values instead of data arrays --- tests/test_stm.py | 46 ++++++++++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 20 deletions(-) diff --git a/tests/test_stm.py b/tests/test_stm.py index d2e86de..bbb4d0e 100644 --- a/tests/test_stm.py +++ b/tests/test_stm.py @@ -547,7 +547,7 @@ def test_enrich_from_dataset_one_filed(self, stmat, meteo_points): assert "temperature" in stmat_enriched.data_vars # check if the nearest method is correct - assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 1] + assert stmat_enriched.temperature[0, 0].values == meteo_points.temperature[0, 1].values # check dimensions of stmat_enriched are the same as stmat assert stmat_enriched.dims == stmat.dims @@ -563,8 +563,8 @@ def test_enrich_from_dataset_multi_filed(self, stmat, meteo_points): assert "humidity" in stmat_enriched.data_vars # check if the linear interpolation is correct - assert stmat_enriched.temperature[0, 0] 
== meteo_points.temperature[0, 1] - assert stmat_enriched.humidity[0, 0] == meteo_points.humidity[0, 1] + assert stmat_enriched.temperature[0, 0].values == meteo_points.temperature[0, 1].values + assert stmat_enriched.humidity[0, 0].values == meteo_points.humidity[0, 1].values def test_enrich_from_dataset_exceptions(self, stmat, meteo_points): # valid fileds @@ -612,7 +612,7 @@ def test_enrich_from_dataarray_one_filed(self, stmat, meteo_points): assert "temperature" in stmat_enriched.data_vars # check if the linear interpolation is correct - assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 1] + assert stmat_enriched.temperature[0, 0].values == meteo_points.temperature[0, 1].values def test_all_operations_lazy(self, stmat, meteo_points): stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature") @@ -625,18 +625,21 @@ def test_enrich_from_point_cropped(self, stmat, meteo_points): buffer = {"lon": 1, "lat": 1, "time": pd.Timedelta("1D")} meteo_points_cropped = crop(stmat, meteo_points, buffer) stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points_cropped, "temperature") - assert stmat_enriched.temperature[0, 0] == meteo_points_cropped.temperature[0, 1] + assert ( + stmat_enriched.temperature[0, 0].values + == meteo_points_cropped.temperature[0, 1].values + ) - # def test_enrich_from_point_nanmonotonic_coords(self, stmat, meteo_points): - # # make the coordinates non-monotonic - # meteo_points["lon"][0] = 25.0 - # meteo_points["lat"][0] = 25.0 + def test_enrich_from_point_nanmonotonic_coords(self, stmat, meteo_points): + # make the coordinates non-monotonic + meteo_points["lon"][0] = 25.0 + meteo_points["lat"][0] = 25.0 - # stmat["lon"][0] = 25.5 - # stmat["lat"][0] = 25.5 + stmat["lon"][0] = 25.5 + stmat["lat"][0] = 25.5 - # stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature") - # assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 1] + stmat_enriched = 
stmat.stm.enrich_from_dataset(meteo_points, "temperature") + assert stmat_enriched.temperature[0, 0].values == meteo_points.temperature[0, 1].values def test_enrich_from_point_duplicate_coords(self, stmat, meteo_points): # make the coordinates duplicates, @@ -649,7 +652,7 @@ def test_enrich_from_point_duplicate_coords(self, stmat, meteo_points): stmat["lat"][0] = 0.5 stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature") - assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[1, 1] + assert stmat_enriched.temperature[0, 0].values == meteo_points.temperature[1, 1].values def test_enrichfrom_point_nanmonotonic_times(self, stmat, meteo_points): # make the time non-monotonic @@ -657,7 +660,7 @@ def test_enrichfrom_point_nanmonotonic_times(self, stmat, meteo_points): stmat["time"].values[0] = pd.Timestamp("2022-01-01") stmat_enriched = stmat.stm.enrich_from_dataset(meteo_points, "temperature") - assert stmat_enriched.temperature[0, 0] == meteo_points.temperature[0, 0] + assert stmat_enriched.temperature[0, 0].values == meteo_points.temperature[0, 0].values class TestEnrichmentFromRasterDataset: @@ -666,7 +669,7 @@ def test_enrich_from_dataset_one_filed(self, stmat, meteo_raster): assert "temperature" in stmat_enriched.data_vars # check if the nearest method is correct - assert stmat_enriched.temperature[0, 0] == meteo_raster.temperature[0, 0, 1] + assert stmat_enriched.temperature[0, 0].values == meteo_raster.temperature[0, 0, 1].values # check dimensions of stmat_enriched are the same as stmat assert stmat_enriched.dims == stmat.dims @@ -682,8 +685,8 @@ def test_enrich_from_dataset_multi_filed(self, stmat, meteo_raster): assert "humidity" in stmat_enriched.data_vars # check if the linear interpolation is correct - assert stmat_enriched.temperature[0, 0] == meteo_raster.temperature[0, 0, 1] - assert stmat_enriched.humidity[0, 0] == meteo_raster.humidity[0, 0, 1] + assert stmat_enriched.temperature[0, 0].values == 
meteo_raster.temperature[0, 0, 1].values + assert stmat_enriched.humidity[0, 0].values == meteo_raster.humidity[0, 0, 1].values def test_enrich_from_dataset_exceptions(self, stmat, meteo_raster): # valid fileds @@ -731,7 +734,7 @@ def test_enrich_from_dataarray_one_filed(self, stmat, meteo_raster): assert "temperature" in stmat_enriched.data_vars # check if the linear interpolation is correct - assert stmat_enriched.temperature[0, 0] == meteo_raster.temperature[0, 0, 1] + assert stmat_enriched.temperature[0, 0].values == meteo_raster.temperature[0, 0, 1].values def test_all_operations_lazy(self, stmat, meteo_raster): stmat_enriched = stmat.stm.enrich_from_dataset(meteo_raster, "temperature") @@ -744,4 +747,7 @@ def test_enrich_from_raste_cropped(self, stmat, meteo_raster): buffer = {"lon": 1, "lat": 1, "time": pd.Timedelta("1D")} meteo_raster_cropped = crop(stmat, meteo_raster, buffer) stmat_enriched = stmat.stm.enrich_from_dataset(meteo_raster_cropped, "temperature") - assert stmat_enriched.temperature[0, 0] == meteo_raster_cropped.temperature[0, 0, 1] + assert ( + stmat_enriched.temperature[0, 0].values + == meteo_raster_cropped.temperature[0, 0, 1].values + ) From 0666878d26e9717d8c3ce2158a37c855608735c7 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Fri, 12 Apr 2024 13:18:13 +0200 Subject: [PATCH 31/33] remove util function for checking unique values --- stmtools/utils.py | 19 ------------------- tests/test_util.py | 19 ------------------- 2 files changed, 38 deletions(-) diff --git a/stmtools/utils.py b/stmtools/utils.py index 186cb6a..58d30e8 100644 --- a/stmtools/utils.py +++ b/stmtools/utils.py @@ -97,22 +97,3 @@ def monotonic_coords(ds, dim: str): ds[dim].to_index().is_monotonic_increasing or ds[dim].to_index().is_monotonic_decreasing ) - - -def unique_coords(ds, dim: str ): - """Check if the dataset has unique coordinates in the given dimension. - - Parameters - ---------- - ds : xarray.Dataset - Dataset to check. - dim : str - Dimension to check. 
- - Returns - ------- - bool - True if the dataset has unique coordinates, False otherwise. - - """ - return bool(ds[dim].to_index().is_unique) \ No newline at end of file diff --git a/tests/test_util.py b/tests/test_util.py index b76f82b..1464252 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -157,22 +157,3 @@ def test_non_monotonic_coords_time(self, stmat): stmat["time"].values[0] = '2022-01-02T00:00:00.000000000' stmat["time"].values[1] = '2022-01-01T00:00:00.000000000' assert not utils.monotonic_coords(stmat, "time") - - -class TestUniqueCoords: - def test_unique_coords(self, stmat): - assert utils.unique_coords(stmat, "lon") - assert utils.unique_coords(stmat, "lat") - assert utils.unique_coords(stmat, "time") - - def test_non_unique_coords_lon(self, stmat): - stmat["lon"][0] = 1 - assert not utils.unique_coords(stmat, "lon") - - def test_non_unique_coords_lat(self, stmat): - stmat["lat"][0] = 1 - assert not utils.unique_coords(stmat, "lat") - - def test_non_unique_coords_time(self, stmat): - stmat["time"].values[0] = '2021-01-03T00:00:00.000000000' - assert not utils.unique_coords(stmat, "time") From fc784c82dbdd5423465b7c8ee7277d6ac5c5ef72 Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Fri, 12 Apr 2024 13:18:52 +0200 Subject: [PATCH 32/33] fix the test --- tests/test_stm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_stm.py b/tests/test_stm.py index bbb4d0e..b89a619 100644 --- a/tests/test_stm.py +++ b/tests/test_stm.py @@ -643,10 +643,10 @@ def test_enrich_from_point_nanmonotonic_coords(self, stmat, meteo_points): def test_enrich_from_point_duplicate_coords(self, stmat, meteo_points): # make the coordinates duplicates, - # now both locations 0 and 1 have the same coords but different - # temperature values + # now both locations 0 and 1 have the same coords and temperature values meteo_points["lon"][0] = 0.5 meteo_points["lat"][0] = 0.5 + meteo_points.temperature[1, :] = meteo_points.temperature[0, :] 
stmat["lon"][0] = 0.5 stmat["lat"][0] = 0.5 From 90ef7bec4ab432c8bf5578afafa913d01fe15cbf Mon Sep 17 00:00:00 2001 From: SarahAlidoost Date: Fri, 12 Apr 2024 13:44:41 +0200 Subject: [PATCH 33/33] remove -vv from action build --- .github/workflows/build.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 751dba7..f1140ce 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -33,7 +33,7 @@ jobs: - name: Build the package run: python -m build - name: Test with pytest - run: python -m pytest -vv + run: python -m pytest build_doc: runs-on: ubuntu-latest