Skip to content

Commit

Permalink
FutureWarnings based on Dataset.dims in stm.py fixed.
Browse files Browse the repository at this point in the history
  • Loading branch information
thijsvl committed Jun 20, 2024
1 parent 6784d9e commit 68de6ab
Showing 1 changed file with 13 additions and 13 deletions.
26 changes: 13 additions & 13 deletions stmtools/stm.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def regulate_dims(self, space_label=None, time_label=None):
if (
(space_label is None)
and (time_label is None)
and all([k not in self._obj.dims.keys() for k in ["space", "time"]])
and all([k not in self._obj.sizes.keys() for k in ["space", "time"]])
):
raise ValueError(
'No "space" nor "time" dimension found. \
Expand All @@ -86,7 +86,7 @@ def regulate_dims(self, space_label=None, time_label=None):
# Check time dimension
ds_reg = self._obj
for key, label in zip(["space", "time"], [space_label, time_label], strict=True):
if key not in self._obj.dims.keys():
if key not in self._obj.sizes.keys():
if label is None:
ds_reg = ds_reg.expand_dims({key: 1})
elif isinstance(label, str):
Expand All @@ -100,7 +100,7 @@ def regulate_dims(self, space_label=None, time_label=None):
# Squeeze the time dimension for all point attibutes, if exists
pnt_vars = [var for var in ds_reg.data_vars.keys() if var.startswith("pnt_")]
for var in pnt_vars:
if "time" in ds_reg[var].dims:
if "time" in ds_reg[var].sizes:
ds_reg[var] = ds_reg[var].squeeze(dim="time")

return ds_reg
Expand Down Expand Up @@ -467,12 +467,12 @@ def enrich_from_dataset(self,
f'Coordinate label "{coord_label}" was not found in the input dataset.'
)

# check if dataset is point or raster if 'space' in dataset.dims:
if "space" in dataset.dims:
# check if dataset is point or raster if 'space' in dataset.sizes:
if "space" in dataset.sizes:
approch = "point"
elif "lat" in dataset.dims and "lon" in dataset.dims:
elif "lat" in dataset.sizes and "lon" in dataset.sizes:
approch = "raster"
elif "y" in dataset.dims and "x" in dataset.dims:
elif "y" in dataset.sizes and "x" in dataset.sizes:
approch = "raster"
else:
raise ValueError(
Expand All @@ -483,7 +483,7 @@ def enrich_from_dataset(self,
)

# check if dataset has time dimensions
if "time" not in dataset.dims:
if "time" not in dataset.sizes:
raise ValueError('Missing dimension: "time" in the input dataset.')

# check if dtype of time is the same
Expand Down Expand Up @@ -518,7 +518,7 @@ def num_points(self):
Number of space entry.
"""
return self._obj.dims["space"]
return self._obj.sizes["space"]

@property
def num_epochs(self):
Expand All @@ -530,7 +530,7 @@ def num_epochs(self):
Number of epochs.
"""
return self._obj.dims["time"]
return self._obj.sizes["time"]


def _in_polygon_block(mask, polygon, xlabel, ylabel, type_polygon):
Expand Down Expand Up @@ -715,7 +715,7 @@ def _enrich_from_raster_block(ds, dataraster, fields, method):

# Assign these values to the corresponding points in ds
for field in fields:
ds[field] = xr.DataArray(interpolated[field].data, dims=ds.dims, coords=ds.coords)
ds[field] = xr.DataArray(interpolated[field].data, dims=ds.sizes, coords=ds.coords)
return ds


Expand Down Expand Up @@ -747,7 +747,7 @@ def _enrich_from_points_block(ds, datapoints, fields):
for dim in ["space", "time"]:
if dim not in datapoints.coords:
indexer[dim]= [
coord for coord in datapoints.coords if dim in datapoints[coord].dims
coord for coord in datapoints.coords if dim in datapoints[coord].sizes
]
else:
indexer[dim] = [dim]
Expand Down Expand Up @@ -778,7 +778,7 @@ def _enrich_from_points_block(ds, datapoints, fields):
# Assign these values to the corresponding points in ds
for field in fields:
ds[field] = xr.DataArray(
selections[field].data, dims=ds.dims, coords=ds.coords
selections[field].data, dims=ds.sizes, coords=ds.coords
)

return ds

0 comments on commit 68de6ab

Please sign in to comment.