diff --git a/Project.toml b/Project.toml index e182604f..802bdf69 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527" IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01" +Requires = "ae029012-a4dd-5104-9daa-d747884805df" StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5" @@ -30,12 +31,16 @@ FillArrays = "0.11, 0.12, 0.13" GPUArraysCore = "0.1" IteratorInterfaceExtensions = "1" RecipesBase = "0.7, 0.8, 1.0" +Requires = "1.0" StaticArraysCore = "1.1" SymbolicIndexingInterface = "0.1, 0.2" Tables = "1" ZygoteRules = "0.2" julia = "1.6" +[extensions] +RecursiveArrayToolsTrackerExt = "Tracker" + [extras] Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" @@ -44,11 +49,16 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [targets] -test = ["Aqua", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote"] +test = ["SafeTestsets", "Aqua", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote"] + +[weakdeps] +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" \ No newline at end of file diff --git a/ext/RecursiveArrayToolsTrackerExt.jl b/ext/RecursiveArrayToolsTrackerExt.jl new file mode 100644 index 00000000..43c7f4a7 --- /dev/null +++ b/ext/RecursiveArrayToolsTrackerExt.jl @@ -0,0 +1,18 @@ +module RecursiveArrayToolsTrackerExt + +import RecursiveArrayTools +isdefined(Base, :get_extension) ? (import Tracker) : (import ..Tracker) + +function RecursiveArrayTools.recursivecopy!(b::AbstractArray{T, N}, + a::AbstractArray{T2, N}) where { + T <: + Tracker.TrackedArray, + T2 <: + Tracker.TrackedArray, + N} + @inbounds for i in eachindex(a) + b[i] = copy(a[i]) + end +end + +end \ No newline at end of file diff --git a/src/RecursiveArrayTools.jl b/src/RecursiveArrayTools.jl index bcb269f3..c93ca011 100644 --- a/src/RecursiveArrayTools.jl +++ b/src/RecursiveArrayTools.jl @@ -42,6 +42,13 @@ function ChainRulesCore.rrule(T::Type{<:GPUArraysCore.AbstractGPUArray}, T(xs), ȳ -> (NoTangent(), ȳ) end +import Requires +@static if !isdefined(Base, :get_extension) + function __init__() + Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/RecursiveArrayToolsTrackerExt.jl") end + end +end + export VectorOfArray, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray, AllObserved, vecarr_to_vectors, tuples diff --git a/src/tabletraits.jl b/src/tabletraits.jl index df4ccef7..f8799aa5 100644 --- a/src/tabletraits.jl +++ b/src/tabletraits.jl @@ -7,12 +7,12 @@ function Tables.rows(A::AbstractDiffEqArray) N = length(A.u[1]) names = [ :timestamp, - (A.sc !== nothing && A.sc.syms !== nothing ? (A.sc.syms[i] for i in 1:N) : + (!(A.sc isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) ? (states(A.sc)[i] for i in 1:N) : (Symbol("value", i) for i in 1:N))..., ] types = Type[eltype(A.t), (eltype(A.u[1]) for _ in 1:N)...] else - names = [:timestamp, A.sc !== nothing && A.sc.syms !== nothing ? A.sc.syms[1] : :value] + names = [:timestamp, !(A.sc isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) ? states(A.sc)[1] : :value] types = Type[eltype(A.t), VT] end return AbstractDiffEqArrayRows(names, types, A.t, A.u) @@ -31,8 +31,8 @@ struct AbstractDiffEqArrayRows{T, U} u::U end function AbstractDiffEqArrayRows(names, types, t, u) - AbstractDiffEqArrayRows(names, types, - Dict(nm => i for (i, nm) in enumerate(names)), t, u) + AbstractDiffEqArrayRows(Symbol.(names), types, + Dict(Symbol(nm) => i for (i, nm) in enumerate(names)), t, u) end Base.length(x::AbstractDiffEqArrayRows) = length(x.u) diff --git a/test/downstream/Project.toml b/test/downstream/Project.toml index 21059020..89471dae 100644 --- a/test/downstream/Project.toml +++ b/test/downstream/Project.toml @@ -1,7 +1,9 @@ [deps] ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" [compat] ModelingToolkit = "8.33" -OrdinaryDiffEq = "6.31" \ No newline at end of file +OrdinaryDiffEq = "6.31" +Tracker = "0.2" \ No newline at end of file diff --git a/test/downstream/TrackerExt.jl b/test/downstream/TrackerExt.jl new file mode 100644 index 00000000..275f2336 --- /dev/null +++ b/test/downstream/TrackerExt.jl @@ -0,0 +1,7 @@ +using RecursiveArrayTools, Tracker, Test + +x = [5.0] +a = [Tracker.TrackedArray(x)] +b = [Tracker.TrackedArray(copy([5.2]))] +RecursiveArrayTools.recursivecopy!(a,b) +@test a[1][1] == 5.2 \ No newline at end of file diff --git a/test/runtests.jl b/test/runtests.jl index 5d99e528..591ad998 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,8 @@ using Pkg using RecursiveArrayTools using Test using Aqua +using SafeTestsets + Aqua.test_all(RecursiveArrayTools, ambiguities = false) @test_broken isempty(Test.detect_ambiguities(RecursiveArrayTools)) const GROUP = get(ENV, "GROUP", "All") @@ -21,26 +23,27 @@ end @time begin if GROUP == "Core" || GROUP == "All" - @time @testset "Utils Tests" begin include("utils_test.jl") end - @time @testset "Partitions Tests" begin include("partitions_test.jl") end - @time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end - @time @testset "SymbolicIndexingInterface API test" begin include("symbolic_indexing_interface_test.jl") end - @time @testset "VecOfArr Interface Tests" begin include("interface_tests.jl") end - @time @testset "Table traits" begin include("tabletraits.jl") end - @time @testset "StaticArrays Tests" begin include("copy_static_array_test.jl") end - @time @testset "Linear Algebra Tests" begin include("linalg.jl") end - @time @testset "Upstream Tests" begin include("upstream.jl") end - @time @testset "Adjoint Tests" begin include("adjoints.jl") end + @time @safetestset "Utils Tests" begin include("utils_test.jl") end + @time @safetestset "Partitions Tests" begin include("partitions_test.jl") end + @time @safetestset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end + @time @safetestset "SymbolicIndexingInterface API test" begin include("symbolic_indexing_interface_test.jl") end + @time @safetestset "VecOfArr Interface Tests" begin include("interface_tests.jl") end + @time @safetestset "Table traits" begin include("tabletraits.jl") end + @time @safetestset "StaticArrays Tests" begin include("copy_static_array_test.jl") end + @time @safetestset "Linear Algebra Tests" begin include("linalg.jl") end + @time @safetestset "Upstream Tests" begin include("upstream.jl") end + @time @safetestset "Adjoint Tests" begin include("adjoints.jl") end end if !is_APPVEYOR && GROUP == "Downstream" activate_downstream_env() - @time @testset "DiffEqArray Indexing Tests" begin include("downstream/symbol_indexing.jl") end - @time @testset "Event Tests with ArrayPartition" begin include("downstream/downstream_events.jl") end + @time @safetestset "DiffEqArray Indexing Tests" begin include("downstream/symbol_indexing.jl") end + @time @safetestset "Event Tests with ArrayPartition" begin include("downstream/downstream_events.jl") end + @time @safetestset "TrackerExt" begin include("downstream/TrackerExt.jl") end end if !is_APPVEYOR && GROUP == "GPU" activate_gpu_env() - @time @testset "VectorOfArray GPU" begin include("gpu/vectorofarray_gpu.jl") end + @time @safetestset "VectorOfArray GPU" begin include("gpu/vectorofarray_gpu.jl") end end end