diff --git a/test/extrapolation_tests.jl b/test/extrapolation_tests.jl index ad37fc0d..087d7ece 100644 --- a/test/extrapolation_tests.jl +++ b/test/extrapolation_tests.jl @@ -1,6 +1,7 @@ using DataInterpolations, Test using ForwardDiff using QuadGK +using Unitful function test_extrapolation(method, u, t) @testset "Extrapolation errors" begin @@ -43,6 +44,48 @@ function test_extrapolation(method, u, t) end end +@testset "Constant Interpolation with Unitful" begin + t_un = [1.0, 2.0]u"s" + u_un = [1.0, 2.0]u"m" + + for extrapolation_type in [ExtrapolationType.Constant, ExtrapolationType.Linear] + # Left extrapolation + A = ConstantInterpolation(u_un, t_un; extrapolation_left = extrapolation_type) + t_eval = 0.0u"s" + @test A(t_eval) == 1.0u"m" + + # Right extrapolation + A = ConstantInterpolation(u_un, t_un; extrapolation_right = extrapolation_type) + t_eval = 3.0u"s" + @test A(t_eval) == 2.0u"m" + end +end + +@testset "Linear Interpolation with Unitful" begin + t_un = [1.0, 2.0]u"s" + u_un = [1.0, 2.0]u"m" + + # Left constant extrapolation + A = LinearInterpolation(u_un, t_un; extrapolation_left = ExtrapolationType.Constant) + t_eval = 0.0u"s" + @test A(t_eval) == 1.0u"m" + + # Right constant extrapolation + A = LinearInterpolation(u_un, t_un; extrapolation_right = ExtrapolationType.Constant) + t_eval = 3.0u"s" + @test A(t_eval) == 2.0u"m" + + # Left linear extrapolation + A = LinearInterpolation(u_un, t_un; extrapolation_left = ExtrapolationType.Linear) + t_eval = 0.0u"s" + @test A(t_eval) == 0.0u"m" + + # Right constant extrapolation + A = LinearInterpolation(u_un, t_un; extrapolation_right = ExtrapolationType.Linear) + t_eval = 3.0u"s" + @test A(t_eval) == 3.0u"m" +end + @testset "Linear Interpolation" begin u = [1.0, 2.0] t = [1.0, 2.0]