Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 140 additions & 0 deletions tests/unit/mathutils/test_function_grid.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Unit tests for Function.from_grid() method and grid interpolation."""

import warnings

import numpy as np
import pytest

Expand Down Expand Up @@ -137,3 +139,141 @@ def test_from_grid_backward_compatibility():
# Test callable function
func3 = Function(lambda x: x**2)
assert func3(2) == 4


def test_shepard_fallback_warning():
"""Test that shepard_fallback is triggered and emits a warning.

When linear_grid interpolation is set but no grid interpolator is available,
the Function class should fall back to shepard interpolation and emit a warning.
"""
# Create a 2D function with scattered points (not structured grid)
source = [(0, 0, 0), (1, 0, 1), (0, 1, 2), (1, 1, 3)]
func = Function(
source=source, inputs=["x", "y"], outputs="z", interpolation="shepard"
)

# Now manually change interpolation to linear_grid without setting up the grid
# This simulates the fallback scenario
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
func.set_interpolation("linear_grid")

# Check that a warning was issued
assert len(w) == 1
assert "falling back to shepard interpolation" in str(w[0].message)


def test_shepard_fallback_2d_interpolation():
"""Test that shepard_fallback produces correct interpolation for 2D data.

This test verifies the fallback interpolation works correctly when
linear_grid is set without a grid interpolator.
"""
# Create a 2D function: z = x + y
source = [
(0, 0, 0), # f(0, 0) = 0
(1, 0, 1), # f(1, 0) = 1
(0, 1, 1), # f(0, 1) = 1
(1, 1, 2), # f(1, 1) = 2
]

# First, create with shepard to get baseline results
func_shepard = Function(
source=source, inputs=["x", "y"], outputs="z", interpolation="shepard"
)

# Create another function and trigger the fallback
func_fallback = Function(
source=source, inputs=["x", "y"], outputs="z", interpolation="shepard"
)

# Trigger fallback
with warnings.catch_warnings():
warnings.simplefilter("ignore") # Suppress warnings for this test
func_fallback.set_interpolation("linear_grid")

# Test that both produce the same results at exact points
assert func_fallback(0, 0) == func_shepard(0, 0)
assert func_fallback(1, 1) == func_shepard(1, 1)

# Test interpolation at an intermediate point
result_fallback = func_fallback(0.5, 0.5)
result_shepard = func_shepard(0.5, 0.5)
assert np.isclose(result_fallback, result_shepard, atol=1e-6)


def test_shepard_fallback_3d_interpolation():
"""Test that shepard_fallback produces correct interpolation for 3D data.

This test verifies the fallback interpolation works correctly for
3-dimensional input data.
"""
# Create a 3D function: w = x + y + z
source = [
(0, 0, 0, 0), # f(0, 0, 0) = 0
(1, 0, 0, 1), # f(1, 0, 0) = 1
(0, 1, 0, 1), # f(0, 1, 0) = 1
(0, 0, 1, 1), # f(0, 0, 1) = 1
(1, 1, 1, 3), # f(1, 1, 1) = 3
]

# Create with shepard to get baseline results
func_shepard = Function(
source=source,
inputs=["x", "y", "z"],
outputs="w",
interpolation="shepard",
)

# Create another function and trigger the fallback
func_fallback = Function(
source=source,
inputs=["x", "y", "z"],
outputs="w",
interpolation="shepard",
)

# Trigger fallback
with warnings.catch_warnings():
warnings.simplefilter("ignore")
func_fallback.set_interpolation("linear_grid")

# Test that both produce the same results at exact points
assert func_fallback(0, 0, 0) == func_shepard(0, 0, 0)
assert func_fallback(1, 1, 1) == func_shepard(1, 1, 1)

# Test interpolation at an intermediate point
result_fallback = func_fallback(0.5, 0.5, 0.5)
result_shepard = func_shepard(0.5, 0.5, 0.5)
assert np.isclose(result_fallback, result_shepard, atol=1e-6)


def test_shepard_fallback_at_exact_data_points():
"""Test that shepard_fallback returns exact values at data points.

When querying at exact data points, the fallback should return the
exact value stored at that point.
"""
# Create a 2D function
source = [
(0, 0, 10),
(1, 0, 20),
(0, 1, 30),
(1, 1, 40),
]

func = Function(
source=source, inputs=["x", "y"], outputs="z", interpolation="shepard"
)

# Trigger fallback
with warnings.catch_warnings():
warnings.simplefilter("ignore")
func.set_interpolation("linear_grid")

# Test exact data points - should return exact values
assert func(0, 0) == 10
assert func(1, 0) == 20
assert func(0, 1) == 30
assert func(1, 1) == 40