Skip to content

Commit

Permalink
add more tests for stm
Browse files Browse the repository at this point in the history
  • Loading branch information
rogerkuou committed Nov 8, 2023
1 parent f88087f commit cd071a1
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 6 deletions.
3 changes: 0 additions & 3 deletions stmtools/stm.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
"""space-time matrix module."""

import logging
import math
from collections.abc import Iterable
from pathlib import Path

import affine
import dask.array as da
import geopandas as gpd
import numpy as np
import xarray as xr
from rasterio import features
from shapely.geometry import Point
from shapely.strtree import STRtree

Expand Down
4 changes: 4 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def stmat():
["space", "time"],
da.arange(npoints * ntime).reshape((npoints, ntime)),
),
pnt_height=(
["space"],
da.arange(npoints),
),
),
coords=dict(
lon=(["space"], da.arange(npoints)),
Expand Down
Binary file added tests/data/multi_polygon.gpkg
Binary file not shown.
37 changes: 34 additions & 3 deletions tests/test_stm.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from pathlib import Path

import dask.array as da
import geopandas as gpd
import numpy as np
import pytest
import xarray as xr
from shapely import geometry

path_multi_polygon = Path(__file__).parent / "./data/multi_polygon.gpkg"

@pytest.fixture
def stmat_rd():
Expand Down Expand Up @@ -42,6 +45,18 @@ def stmat_only_point():
coords=dict(lon=(["space"], da.arange(npoints)), lat=(["space"], da.arange(npoints))),
).unify_chunks()

@pytest.fixture
def stmat_wrong_space_label():
npoints = 10
return xr.Dataset(
data_vars=dict(
amplitude=(["space2"], da.arange(npoints)),
phase=(["space2"], da.arange(npoints)),
pnt_height=(["space2"], da.arange(npoints)),
),
coords=dict(lon=(["space2"], da.arange(npoints)), lat=(["space2"], da.arange(npoints))),
).unify_chunks()


@pytest.fixture
def polygon():
Expand Down Expand Up @@ -95,6 +110,11 @@ def test_time_dim_exists(self, stmat_only_point):
def test_time_dim_size_one(self, stmat_only_point):
stm_reg = stmat_only_point.stm.regulate_dims()
assert stm_reg.dims["time"] == 1

def test_time_dim_customed_label(self, stmat_wrong_space_label):
stm_reg = stmat_wrong_space_label.stm.regulate_dims(space_label="space2")
assert stm_reg.dims["time"] == 1
assert stm_reg.dims["space"] == 10

def test_pnt_time_dim_nonexists(self, stmat_only_point):
"""
Expand Down Expand Up @@ -122,6 +142,11 @@ def test_check_missing_dimension(self, stmat_only_point):
with pytest.raises(KeyError):
stmat_only_point.stm.subset(method="threshold", var="pnt_height", threshold=">5")

def test_check_missing_value(self, stmat):
with pytest.raises(ValueError):
stmat.stm.subset(method="threshold", var="pnt_height", threshold=">")
stmat.stm.subset(method="threshold", var="pnt_height", threshold="<")

def test_method_not_implemented(self, stmat):
with pytest.raises(NotImplementedError):
stmat.stm.subset(method="something_else")
Expand All @@ -132,7 +157,7 @@ def test_subset_with_threshold(self, stmat):
v_thres = np.ones(
stmat.space.shape,
)
v_thres[0:3] = 2
v_thres[0:3] = 3
stmat = stmat.assign(
{
"thres": (
Expand All @@ -141,8 +166,10 @@ def test_subset_with_threshold(self, stmat):
)
}
)
stmat_subset = stmat.stm.subset(method="threshold", var="thres", threshold=">1")
assert stmat_subset.equals(stmat.sel(space=[0, 1, 2]))
stmat_subset_larger = stmat.stm.subset(method="threshold", var="thres", threshold=">2")
stmat_subset_lower = stmat.stm.subset(method="threshold", var="thres", threshold="<2")
assert stmat_subset_larger.equals(stmat.sel(space=[0, 1, 2]))
assert stmat_subset_lower.equals(stmat.sel(space=range(3, 10, 1)))

def test_subset_with_polygons(self, stmat, polygon):
stmat_subset = stmat.stm.subset(method="polygon", polygon=polygon)
Expand All @@ -157,6 +184,10 @@ def test_subset_with_polygons_rd(self, stmat_rd, polygon):
def test_subset_with_multi_polygons(self, stmat, multi_polygon):
stmat_subset = stmat.stm.subset(method="polygon", polygon=multi_polygon)
assert stmat_subset.equals(stmat.sel(space=[2, 6]))

def test_subset_with_multi_polygons_file(self, stmat):
stmat_subset = stmat.stm.subset(method="polygon", polygon=path_multi_polygon)
assert stmat_subset.equals(stmat.sel(space=[2, 6]))


class TestEnrichment:
Expand Down

0 comments on commit cd071a1

Please sign in to comment.