Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 11 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
IteratorInterfaceExtensions = "82899510-4779-5014-852e-03e436cf321d"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
RecipesBase = "3cdcf5f2-1ef4-517c-9805-6587b60abb01"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StaticArraysCore = "1e83bf80-4336-4d27-bf5d-d5a4f845583c"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
SymbolicIndexingInterface = "2efcf032-c050-4f8e-a9bb-153293bab1f5"
Expand All @@ -30,12 +31,16 @@ FillArrays = "0.11, 0.12, 0.13"
GPUArraysCore = "0.1"
IteratorInterfaceExtensions = "1"
RecipesBase = "0.7, 0.8, 1.0"
Requires = "1.0"
StaticArraysCore = "1.1"
SymbolicIndexingInterface = "0.1, 0.2"
Tables = "1"
ZygoteRules = "0.2"
julia = "1.6"

[extensions]
RecursiveArrayToolsTrackerExt = "Tracker"

[extras]
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Expand All @@ -44,11 +49,16 @@ NLsolve = "2774e3e8-f4cf-5e23-947b-6d7e65073b56"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote"]
test = ["SafeTestsets", "Aqua", "ForwardDiff", "LabelledArrays", "NLsolve", "OrdinaryDiffEq", "Pkg", "Test", "Unitful", "Random", "StaticArrays", "StructArrays", "Zygote"]

[weakdeps]
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
18 changes: 18 additions & 0 deletions ext/RecursiveArrayToolsTrackerExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
module RecursiveArrayToolsTrackerExt

import RecursiveArrayTools
isdefined(Base, :get_extension) ? (import Tracker) : (import ..Tracker)

function RecursiveArrayTools.recursivecopy!(b::AbstractArray{T, N},
a::AbstractArray{T2, N}) where {
T <:
Tracker.TrackedArray,
T2 <:
Tracker.TrackedArray,
N}
@inbounds for i in eachindex(a)
b[i] = copy(a[i])
end
end

end
7 changes: 7 additions & 0 deletions src/RecursiveArrayTools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ function ChainRulesCore.rrule(T::Type{<:GPUArraysCore.AbstractGPUArray},
T(xs), ȳ -> (NoTangent(), ȳ)
end

import Requires
@static if !isdefined(Base, :get_extension)
function __init__()
Requires.@require Tracker="9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c" begin include("../ext/RecursiveArrayToolsTrackerExt.jl") end
end
end

export VectorOfArray, DiffEqArray, AbstractVectorOfArray, AbstractDiffEqArray,
AllObserved, vecarr_to_vectors, tuples

Expand Down
8 changes: 4 additions & 4 deletions src/tabletraits.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,12 @@ function Tables.rows(A::AbstractDiffEqArray)
N = length(A.u[1])
names = [
:timestamp,
(A.sc !== nothing && A.sc.syms !== nothing ? (A.sc.syms[i] for i in 1:N) :
(!(A.sc isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) ? (states(A.sc)[i] for i in 1:N) :
(Symbol("value", i) for i in 1:N))...,
]
types = Type[eltype(A.t), (eltype(A.u[1]) for _ in 1:N)...]
else
names = [:timestamp, A.sc !== nothing && A.sc.syms !== nothing ? A.sc.syms[1] : :value]
names = [:timestamp, !(A.sc isa SymbolicIndexingInterface.SymbolCache{Nothing, Nothing, Nothing}) ? states(A.sc)[1] : :value]
types = Type[eltype(A.t), VT]
end
return AbstractDiffEqArrayRows(names, types, A.t, A.u)
Expand All @@ -31,8 +31,8 @@ struct AbstractDiffEqArrayRows{T, U}
u::U
end
function AbstractDiffEqArrayRows(names, types, t, u)
AbstractDiffEqArrayRows(names, types,
Dict(nm => i for (i, nm) in enumerate(names)), t, u)
AbstractDiffEqArrayRows(Symbol.(names), types,
Dict(Symbol(nm) => i for (i, nm) in enumerate(names)), t, u)
end

Base.length(x::AbstractDiffEqArrayRows) = length(x.u)
Expand Down
4 changes: 3 additions & 1 deletion test/downstream/Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
[deps]
ModelingToolkit = "961ee093-0014-501f-94e3-6117800e7a78"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"

[compat]
ModelingToolkit = "8.33"
OrdinaryDiffEq = "6.31"
OrdinaryDiffEq = "6.31"
Tracker = "0.2"
7 changes: 7 additions & 0 deletions test/downstream/TrackerExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
using RecursiveArrayTools, Tracker, Test

x = [5.0]
a = [Tracker.TrackedArray(x)]
b = [Tracker.TrackedArray(copy([5.2]))]
RecursiveArrayTools.recursivecopy!(a,b)
@test a[1][1] == 5.2
29 changes: 16 additions & 13 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ using Pkg
using RecursiveArrayTools
using Test
using Aqua
using SafeTestsets

Aqua.test_all(RecursiveArrayTools, ambiguities = false)
@test_broken isempty(Test.detect_ambiguities(RecursiveArrayTools))
const GROUP = get(ENV, "GROUP", "All")
Expand All @@ -21,26 +23,27 @@ end

@time begin
if GROUP == "Core" || GROUP == "All"
@time @testset "Utils Tests" begin include("utils_test.jl") end
@time @testset "Partitions Tests" begin include("partitions_test.jl") end
@time @testset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end
@time @testset "SymbolicIndexingInterface API test" begin include("symbolic_indexing_interface_test.jl") end
@time @testset "VecOfArr Interface Tests" begin include("interface_tests.jl") end
@time @testset "Table traits" begin include("tabletraits.jl") end
@time @testset "StaticArrays Tests" begin include("copy_static_array_test.jl") end
@time @testset "Linear Algebra Tests" begin include("linalg.jl") end
@time @testset "Upstream Tests" begin include("upstream.jl") end
@time @testset "Adjoint Tests" begin include("adjoints.jl") end
@time @safetestset "Utils Tests" begin include("utils_test.jl") end
@time @safetestset "Partitions Tests" begin include("partitions_test.jl") end
@time @safetestset "VecOfArr Indexing Tests" begin include("basic_indexing.jl") end
@time @safetestset "SymbolicIndexingInterface API test" begin include("symbolic_indexing_interface_test.jl") end
@time @safetestset "VecOfArr Interface Tests" begin include("interface_tests.jl") end
@time @safetestset "Table traits" begin include("tabletraits.jl") end
@time @safetestset "StaticArrays Tests" begin include("copy_static_array_test.jl") end
@time @safetestset "Linear Algebra Tests" begin include("linalg.jl") end
@time @safetestset "Upstream Tests" begin include("upstream.jl") end
@time @safetestset "Adjoint Tests" begin include("adjoints.jl") end
end

if !is_APPVEYOR && GROUP == "Downstream"
activate_downstream_env()
@time @testset "DiffEqArray Indexing Tests" begin include("downstream/symbol_indexing.jl") end
@time @testset "Event Tests with ArrayPartition" begin include("downstream/downstream_events.jl") end
@time @safetestset "DiffEqArray Indexing Tests" begin include("downstream/symbol_indexing.jl") end
@time @safetestset "Event Tests with ArrayPartition" begin include("downstream/downstream_events.jl") end
@time @safetestset "TrackerExt" begin include("downstream/TrackerExt.jl") end
end

if !is_APPVEYOR && GROUP == "GPU"
activate_gpu_env()
@time @testset "VectorOfArray GPU" begin include("gpu/vectorofarray_gpu.jl") end
@time @safetestset "VectorOfArray GPU" begin include("gpu/vectorofarray_gpu.jl") end
end
end