Skip to content

Commit

Permalink
TEST: increase test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
ColmTalbot committed Jun 21, 2023
1 parent 3813282 commit 595e70d
Showing 1 changed file with 47 additions and 23 deletions.
70 changes: 47 additions & 23 deletions test/interpolate_test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from itertools import product

import numpy as np
import pytest
from scipy.interpolate import CubicSpline, interp1d

from cached_interpolate import RegularCachingInterpolant as CachingInterpolant
from cached_interpolate import CachingInterpolant, RegularCachingInterpolant

interpolants = [RegularCachingInterpolant, CachingInterpolant]


@pytest.mark.parametrize("bc_type", ["clamped", "natural", "not-a-knot", "periodic"])
Expand All @@ -11,7 +15,7 @@ def test_cubic_matches_scipy(bc_type):
y_values = np.random.uniform(-1, 1, 10)
if bc_type == "periodic":
y_values[0] = y_values[-1]
spl = CachingInterpolant(x_values, y_values, kind="cubic", bc_type=bc_type)
spl = RegularCachingInterpolant(x_values, y_values, kind="cubic", bc_type=bc_type)
test_points = np.random.uniform(0, 1, 10000)
max_diff = 0
for _ in range(100):
Expand All @@ -24,10 +28,19 @@ def test_cubic_matches_scipy(bc_type):
assert max_diff, 1e-10


def test_nearest_matches_scipy():
@pytest.mark.parametrize("interpolant", interpolants)
def test_caching_interpolant_bad_bc_type(interpolant):
x_values = np.linspace(0, 1, 10)
y_values = np.random.uniform(-1, 1, 10)
with pytest.raises(NotImplementedError):
_ = interpolant(x_values, y_values, kind="cubic", bc_type="bad")


@pytest.mark.parametrize("interpolant", interpolants)
def test_nearest_matches_scipy(interpolant):
x_values = np.linspace(0, 1, 10)
y_values = np.random.uniform(-1, 1, 10)
spl = CachingInterpolant(x_values, y_values, kind="nearest")
spl = interpolant(x_values, y_values, kind="nearest")
test_points = np.random.uniform(0, 1, 10000)
max_diff = 0
for _ in range(100):
Expand All @@ -38,10 +51,11 @@ def test_nearest_matches_scipy():
assert max_diff < 1e-10


def test_linear_matches_numpy():
@pytest.mark.parametrize("interpolant", interpolants)
def test_linear_matches_numpy(interpolant):
x_values = np.linspace(0, 1, 10)
y_values = np.random.uniform(-1, 1, 10)
spl = CachingInterpolant(x_values, y_values, kind="linear")
spl = interpolant(x_values, y_values, kind="linear")
test_points = np.random.uniform(0, 1, 10000)
max_diff = 0
for _ in range(100):
Expand All @@ -52,79 +66,89 @@ def test_linear_matches_numpy():
assert max_diff < 1e-10


def test_single_input():
@pytest.mark.parametrize("interpolant", interpolants)
def test_single_input(interpolant):
x_values = np.linspace(0, 1, 10)
y_values = np.random.uniform(-1, 1, 10)
spl = CachingInterpolant(x_values, y_values, kind="cubic")
spl = interpolant(x_values, y_values, kind="cubic")
assert spl(0) == y_values[0]


def test_single_complex():
@pytest.mark.parametrize("interpolant", interpolants)
def test_single_complex(interpolant):
x_values = np.linspace(0, 1, 10)
y_values = np.random.uniform(-1, 1, 10)
y_values = y_values + 1j * (1 - y_values)
spl = CachingInterpolant(x_values, y_values, kind="cubic")
assert spl(0) == y_values[0]


def test_interpolation_at_lower_bound():
@pytest.mark.parametrize("interpolant", interpolants)
def test_interpolation_at_lower_bound(interpolant):
x_values = np.linspace(0, 1, 10)
y_values = np.random.uniform(-1, 1, 10)
spl = CachingInterpolant(x_values, y_values, kind="cubic")
spl = interpolant(x_values, y_values, kind="cubic")
test_point = 0
assert abs(spl(test_point) - y_values[0]) < 1e-5


def test_interpolation_at_upper_bound():
@pytest.mark.parametrize("interpolant", interpolants)
def test_interpolation_at_upper_bound(interpolant):
x_values = np.linspace(0, 1, 10)
y_values = np.random.uniform(-1, 1, 10)
spl = CachingInterpolant(x_values, y_values, kind="cubic")
spl = interpolant(x_values, y_values, kind="cubic")
test_point = 1
assert abs(spl(test_point) - y_values[-1]) < 1e-5


def test_bad_interpolation_method_raises_error():
@pytest.mark.parametrize("interpolant", interpolants)
def test_bad_interpolation_method_raises_error(interpolant):
x_values = np.linspace(0, 1, 10)
y_values = np.random.uniform(-1, 1, 10)
with pytest.raises(ValueError):
_ = CachingInterpolant(x_values, y_values, kind="bad method")
_ = interpolant(x_values, y_values, kind="bad method")


def test_running_without_new_y_values():
@pytest.mark.parametrize("interpolant", interpolants)
def test_running_without_new_y_values(interpolant):
x_values = np.linspace(0, 1, 10)
y_values = np.random.uniform(-1, 1, 10)
spl = CachingInterpolant(x_values, y_values, kind="cubic")
spl = interpolant(x_values, y_values, kind="cubic")
old_values = spl._data
_ = spl(np.array([0, 1]), y=np.random.uniform(-1, 1, 10), use_cache=False)
assert np.max(old_values - spl._data) > 1e-5


def test_running_with_complex_input_linear():
@pytest.mark.parametrize("interpolant", interpolants)
def test_running_with_complex_input_linear(interpolant):
x_values = np.linspace(0, 1, 10)
y_values = np.random.uniform(-1, 1, 10)
y_values = y_values * np.exp(1j * np.random.uniform(0, 2 * np.pi, 10))
spl = CachingInterpolant(x_values, y_values, kind="linear")
spl = interpolant(x_values, y_values, kind="linear")
scs = interp1d(x=x_values, y=y_values, kind="linear")
test_points = np.random.uniform(0, 1, 10)
scs_test = scs(test_points)
diffs = spl(test_points) - scs_test
assert np.max(diffs) < 1e-10


def test_running_with_complex_input_cubic():
@pytest.mark.parametrize("interpolant", interpolants)
def test_running_with_complex_input_cubic(interpolant):
x_values = np.linspace(0, 1, 10)
y_values = np.random.uniform(-1, 1, 10)
y_values = y_values * np.exp(1j * np.random.uniform(0, 2 * np.pi, 10))
spl = CachingInterpolant(x_values, y_values, kind="cubic", bc_type="natural")
spl = interpolant(x_values, y_values, kind="cubic", bc_type="natural")
scs = CubicSpline(x=x_values, y=y_values, bc_type="natural")
test_points = np.random.uniform(0, 1, 10)
scs_test = scs(test_points)
diffs = spl(test_points) - scs_test
assert np.max(diffs) < 1e-10


@pytest.mark.parametrize("kind", ["nearest", "linear", "cubic"])
def test_2d_input(kind):
@pytest.mark.parametrize(
"kind,interpolant", product(["nearest", "linear", "cubic"], interpolants)
)
def test_2d_input(kind, interpolant):
kwargs = dict(x=np.linspace(0, 1, 5), y=np.random.uniform(0, 1, 5), kind=kind)
test_values = np.random.uniform(0, 1, (2, 10000))
spl = CachingInterpolant(**kwargs)
Expand Down

0 comments on commit 595e70d

Please sign in to comment.