diff --git a/tests/test_krige.py b/tests/test_krige.py index 6d4d935d..af7984b4 100644 --- a/tests/test_krige.py +++ b/tests/test_krige.py @@ -5,16 +5,20 @@ import numpy as np import unittest -from gstools import Gaussian, Exponential, Spherical, krige, SRF +import gstools as gs def trend(*xyz): return xyz[0] +def mean_func(*xyz): + return 2 * xyz[0] + + class TestKrige(unittest.TestCase): def setUp(self): - self.cov_models = [Gaussian, Exponential, Spherical] + self.cov_models = [gs.Gaussian, gs.Exponential, gs.Spherical] self.dims = range(1, 4) self.data = np.array( [ @@ -58,7 +62,7 @@ def test_simple(self): anis=[0.9, 0.8], angles=[2, 1, 0.5], ) - simple = krige.Simple( + simple = gs.krige.Simple( model, self.cond_pos[:dim], self.cond_val, self.mean ) field_1, __ = simple.unstructured(self.grids[dim - 1]) @@ -83,7 +87,7 @@ def test_ordinary(self): anis=[0.9, 0.8], angles=[2, 1, 0.5], ) - ordinary = krige.Ordinary( + ordinary = gs.krige.Ordinary( model, self.cond_pos[:dim], self.cond_val, @@ -112,7 +116,7 @@ def test_universal(self): anis=[0.9, 0.8], angles=[2, 1, 0.5], ) - universal = krige.Universal( + universal = gs.krige.Universal( model, self.cond_pos[:dim], self.cond_val, drift ) field_1, __ = universal.unstructured(self.grids[dim - 1]) @@ -137,7 +141,7 @@ def test_detrended(self): anis=[0.5, 0.2], angles=[0.4, 0.2, 0.1], ) - detrended = krige.Detrended( + detrended = gs.krige.Detrended( model, self.cond_pos[:dim], self.cond_val, trend ) field_1, __ = detrended.unstructured(self.grids[dim - 1]) @@ -157,14 +161,14 @@ def test_extdrift(self): cond_drift = [] for i, grid in enumerate(self.grids): dim = i + 1 - model = Exponential( + model = gs.Exponential( dim=dim, var=2, len_scale=10, anis=[0.9, 0.8], angles=[2, 1, 0.5], ) - srf = SRF(model) + srf = gs.SRF(model) field = srf(grid) ext_drift.append(field) field = field.reshape(self.grid_shape[:dim]) @@ -179,7 +183,7 @@ def test_extdrift(self): anis=[0.5, 0.2], angles=[0.4, 0.2, 0.1], ) - extdrift = krige.ExtDrift( + extdrift = gs.krige.ExtDrift( model, self.cond_pos[:dim], self.cond_val, @@ -202,7 +206,6 @@ def test_extdrift(self): ) def test_pseudo(self): - for Model in self.cov_models: for dim in self.dims: model = Model( @@ -213,7 +216,7 @@ def test_pseudo(self): angles=[0.4, 0.2, 0.1], ) for meth in self.p_meth: - krig = krige.Krige( + krig = gs.krige.Krige( model, self.p_data[:dim], self.p_vals, unbiased=False ) field, __ = krig([0, 0, 0][:dim]) @@ -224,7 +227,6 @@ def test_pseudo(self): ) def test_error(self): - for Model in self.cov_models: for dim in self.dims: model = Model( @@ -235,7 +237,7 @@ def test_error(self): anis=[0.9, 0.8], angles=[2, 1, 0.5], ) - ordinary = krige.Ordinary( + ordinary = gs.krige.Ordinary( model, self.cond_pos[:dim], self.cond_val, @@ -248,6 +250,42 @@ def test_error(self): self.assertAlmostEqual(err[1], model.nugget, places=2) self.assertAlmostEqual(err[4], model.nugget, places=2) + def test_raise(self): + # no cond_pos/cond_val given + self.assertRaises(ValueError, gs.krige.Krige, gs.Stable(), None, None) + + def test_krige_mean(self): + # check for constant mean (simple kriging) + krige = gs.krige.Simple(gs.Gaussian(), self.cond_pos, self.cond_val) + mean_f = krige.structured(self.pos, only_mean=True) + self.assertTrue(np.all(np.isclose(mean_f, 0))) + krige = gs.krige.Simple( + gs.Gaussian(), + self.cond_pos, + self.cond_val, + mean=mean_func, + normalizer=gs.normalizer.YeoJohnson, + trend=trend, + ) + # check applying mean, norm, trend + mean_f1 = krige.structured(self.pos, only_mean=True) + mean_f2 = gs.normalizer.tools.apply_mean_norm_trend( + self.pos, + np.zeros(tuple(map(len, self.pos))), + mean=mean_func, + normalizer=gs.normalizer.YeoJohnson, + trend=trend, + mesh_type="structured", + ) + self.assertTrue(np.all(np.isclose(mean_f1, mean_f2))) + krige = gs.krige.Simple(gs.Gaussian(), self.cond_pos, self.cond_val) + mean_f = krige.structured(self.pos, only_mean=True) + self.assertTrue(np.all(np.isclose(mean_f, 0))) + # check for constant mean (ordinary kriging) + krige = gs.krige.Ordinary(gs.Gaussian(), self.cond_pos, self.cond_val) + mean_f = krige.structured(self.pos, only_mean=True) + self.assertTrue(np.all(np.isclose(mean_f, krige.get_mean()))) + if __name__ == "__main__": unittest.main()