From e752e74f02ee1b5dbf979db46830166154beb44f Mon Sep 17 00:00:00 2001 From: Shambles Date: Wed, 20 Jul 2022 23:25:59 -0600 Subject: [PATCH 01/19] Update chainrulescore.jl --- src/compat/chainrulescore.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/compat/chainrulescore.jl b/src/compat/chainrulescore.jl index 7709ef7d..f9406591 100644 --- a/src/compat/chainrulescore.jl +++ b/src/compat/chainrulescore.jl @@ -1,6 +1,6 @@ function ChainRulesCore.rrule(::typeof(getproperty), x::ComponentArray, s::Union{Symbol, Val}) function getproperty_adjoint(Δ) - zero_x = ComponentArray(zeros(eltype(Δ), size(x)), getaxes(x)) + zero_x = zero(similar(x, eltype(Δ))) setproperty!(zero_x, s, Δ) return (ChainRulesCore.NoTangent(), zero_x, ChainRulesCore.NoTangent()) end From d87e85c684529e3066d52ac13c9d5cf33cd5068f Mon Sep 17 00:00:00 2001 From: Shambles Date: Wed, 20 Jul 2022 23:38:38 -0600 Subject: [PATCH 02/19] Update Project.toml --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 400c08d1..9102fd3b 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ComponentArrays" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] -version = "0.12.3" +version = "0.12.4" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" From d548520da84758ecb8f1f2e31f521451a001047d Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Thu, 21 Jul 2022 20:15:31 -0600 Subject: [PATCH 03/19] fix scalar indexing using ForwardDiff --- Project.toml | 2 +- src/similar_convert_copy.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/Project.toml b/Project.toml index 9102fd3b..9a10dd56 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ComponentArrays" uuid = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" authors = ["Jonnie Diegelman <47193959+jonniedie@users.noreply.github.com>"] -version = "0.12.4" +version = "0.12.5" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" diff --git a/src/similar_convert_copy.jl b/src/similar_convert_copy.jl index 39027a92..3e0e9bd4 100644 --- a/src/similar_convert_copy.jl +++ b/src/similar_convert_copy.jl @@ -33,7 +33,7 @@ Base.copy(x::ComponentArray) = ComponentArray(copy(getdata(x)), getaxes(x)) Base.copyto!(dest::AbstractArray, src::ComponentArray) = copyto!(dest, getdata(src)) function Base.copyto!(dest::ComponentArray, src::AbstractArray) - copyto!(getdata(dest), src) + copyto!(getdata(dest), collect(src)) return dest end function Base.copyto!(dest::ComponentArray, src::ComponentArray) @@ -73,4 +73,4 @@ Base.NamedTuple(x::ComponentVector) = _namedtuple(x) ## AbstractAxis conversion and promotion -Base.convert(::Type{Ax}, ax::AbstractAxis) where {Ax<:AbstractAxis} = ax \ No newline at end of file +Base.convert(::Type{Ax}, ax::AbstractAxis) where {Ax<:AbstractAxis} = ax From 422bdaea135c47c3cfd4b2cf96a91411ef0496c6 Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Fri, 22 Jul 2022 13:38:19 -0600 Subject: [PATCH 04/19] a better way --- src/ComponentArrays.jl | 3 ++- src/compat/forwarddiff.jl | 4 ++++ src/similar_convert_copy.jl | 3 ++- 3 files changed, 8 insertions(+), 2 deletions(-) create mode 100644 src/compat/forwarddiff.jl diff --git a/src/ComponentArrays.jl b/src/ComponentArrays.jl index 5e95c1df..970b60f4 100644 --- a/src/ComponentArrays.jl +++ b/src/ComponentArrays.jl @@ -59,6 +59,7 @@ function __init__() @require RecursiveArrayTools="731186ca-8d62-57ce-b412-fbd966d074cd" required("recursivearraytools.jl") @require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" required("staticarrays.jl") @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" required("reversediff.jl") + @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" required("forwarddiff.jl") end -end \ No newline at end of file +end diff --git a/src/compat/forwarddiff.jl b/src/compat/forwarddiff.jl new file mode 100644 index 00000000..fe45aea6 --- /dev/null +++ b/src/compat/forwarddiff.jl @@ -0,0 +1,4 @@ +function Base.copyto!(dest::ComponentArray, src::ForwardDiff.Partials{N,V}) where {N,V} + copyto!(getdata(dest), collect(src)) + return dest +end diff --git a/src/similar_convert_copy.jl b/src/similar_convert_copy.jl index 3e0e9bd4..a97f4573 100644 --- a/src/similar_convert_copy.jl +++ b/src/similar_convert_copy.jl @@ -33,9 +33,10 @@ Base.copy(x::ComponentArray) = ComponentArray(copy(getdata(x)), getaxes(x)) Base.copyto!(dest::AbstractArray, src::ComponentArray) = copyto!(dest, getdata(src)) function Base.copyto!(dest::ComponentArray, src::AbstractArray) - copyto!(getdata(dest), collect(src)) + copyto!(getdata(dest), src) return dest end + function Base.copyto!(dest::ComponentArray, src::ComponentArray) copyto!(getdata(dest), getdata(src)) return dest From e672fde46122c6347301a6741e22260309186a41 Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Fri, 29 Jul 2022 11:52:18 -0600 Subject: [PATCH 05/19] GPUComponentArray --- src/compat/forwarddiff.jl | 2 +- src/similar_convert_copy.jl | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/src/compat/forwarddiff.jl b/src/compat/forwarddiff.jl index fe45aea6..59f2f291 100644 --- a/src/compat/forwarddiff.jl +++ b/src/compat/forwarddiff.jl @@ -1,4 +1,4 @@ -function Base.copyto!(dest::ComponentArray, src::ForwardDiff.Partials{N,V}) where {N,V} +function Base.copyto!(dest::GPUComponentArray, src::ForwardDiff.Partials{N,V}) where {N,V} copyto!(getdata(dest), collect(src)) return dest end diff --git a/src/similar_convert_copy.jl b/src/similar_convert_copy.jl index a97f4573..2c5c2910 100644 --- a/src/similar_convert_copy.jl +++ b/src/similar_convert_copy.jl @@ -36,7 +36,6 @@ function Base.copyto!(dest::ComponentArray, src::AbstractArray) copyto!(getdata(dest), src) return dest end - function Base.copyto!(dest::ComponentArray, src::ComponentArray) copyto!(getdata(dest), getdata(src)) return dest From 9471bef983d0a4c893d49b404f404a2b821cd71e Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Fri, 29 Jul 2022 12:10:08 -0600 Subject: [PATCH 06/19] test --- test/gpu_tests.jl | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/test/gpu_tests.jl b/test/gpu_tests.jl index 5a1a67f9..c61a731b 100644 --- a/test/gpu_tests.jl +++ b/test/gpu_tests.jl @@ -1,4 +1,5 @@ using JLArrays +using ForwardDiff JLArrays.allowscalar(false) @@ -7,7 +8,7 @@ jlca = ComponentArray(jla, Axis(a=1:2, b=3:4)) @testset "Broadcasting" begin @test identity.(jlca + jla) ./ 2 == jlca - + @test getdata(map(identity, jlca)) isa JLArray @test all(==(0), map(-, jlca, jla)) @test all(map(-, jlca, jlca) .== 0) @@ -15,10 +16,14 @@ jlca = ComponentArray(jla, Axis(a=1:2, b=3:4)) @test any(==(1), jlca) @test count(>(2), jlca) == 2 - + # Make sure mapreducing multiple arrays works @test mapreduce(==, +, jlca, jla) == 4 @test mapreduce(abs2, +, jlca) == 30 @test all(map(sin, jlca) .== sin.(jlca) .== sin.(jla) .≈ sin.(1:4)) end + +@testset "ForwardDiff" begin + ForwardDiff.gradient(sum, jlca) == one.(jlca) +end From fade2b7a97313636724491dd54014a811cb333fc Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Fri, 29 Jul 2022 13:45:56 -0600 Subject: [PATCH 07/19] fix --- src/ComponentArrays.jl | 4 +++- src/compat/{forwarddiff.jl => forwarddiff_gpu.jl} | 2 +- src/compat/gpuarrays.jl | 4 ++-- test/gpu_tests.jl | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) rename src/compat/{forwarddiff.jl => forwarddiff_gpu.jl} (82%) diff --git a/src/ComponentArrays.jl b/src/ComponentArrays.jl index 0185d0bc..ff096a16 100644 --- a/src/ComponentArrays.jl +++ b/src/ComponentArrays.jl @@ -59,8 +59,10 @@ function __init__() @require RecursiveArrayTools="731186ca-8d62-57ce-b412-fbd966d074cd" required("recursivearraytools.jl") @require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" required("staticarrays.jl") @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" required("reversediff.jl") - @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" required("forwarddiff.jl") @require GPUArrays="0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" required("gpuarrays.jl") + @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin + @require GPUArrays="0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" required("forwarddiff_gpu.jl") + end end end diff --git a/src/compat/forwarddiff.jl b/src/compat/forwarddiff_gpu.jl similarity index 82% rename from src/compat/forwarddiff.jl rename to src/compat/forwarddiff_gpu.jl index 59f2f291..3d6720ba 100644 --- a/src/compat/forwarddiff.jl +++ b/src/compat/forwarddiff_gpu.jl @@ -1,4 +1,4 @@ -function Base.copyto!(dest::GPUComponentArray, src::ForwardDiff.Partials{N,V}) where {N,V} +function Base.copyto!(dest::GPUComponentArray, src::ForwardDiff.Partials) copyto!(getdata(dest), collect(src)) return dest end diff --git a/src/compat/gpuarrays.jl b/src/compat/gpuarrays.jl index c95561a7..da2e2edf 100644 --- a/src/compat/gpuarrays.jl +++ b/src/compat/gpuarrays.jl @@ -1,4 +1,4 @@ -const GPUComponentArray = ComponentArray{T,N,<:GPUArrays.AbstractGPUArray,Ax} where {T,N,Ax} +const GPUComponentArray = ComponentArray{T,N,<:GPUArrays.AbstractGPUArray,Ax} where {T,N,Ax<:Tuple{Vararg{AbstractAxis}}} GPUArrays.backend(x::ComponentArray) = GPUArrays.backend(getdata(x)) @@ -46,4 +46,4 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)), Base.$(fname!)(f::Function, r::GPUComponentArray, A::GPUComponentArray{T}) where T = GPUArrays.mapreducedim!(f, $(op), getdata(r), getdata(A); init=neutral_element($(op), T)) end -end \ No newline at end of file +end diff --git a/test/gpu_tests.jl b/test/gpu_tests.jl index c61a731b..2f1ea6e7 100644 --- a/test/gpu_tests.jl +++ b/test/gpu_tests.jl @@ -1,5 +1,5 @@ using JLArrays -using ForwardDiff +import ForwardDiff JLArrays.allowscalar(false) From b4443d6187daddd8f005374d2a93caa086170cf3 Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Sat, 30 Jul 2022 01:04:31 -0600 Subject: [PATCH 08/19] jacobian --- src/ComponentArrays.jl | 4 +--- src/compat/forwarddiff.jl | 1 + src/compat/forwarddiff_gpu.jl | 4 ---- test/gpu_tests.jl | 2 +- 4 files changed, 3 insertions(+), 8 deletions(-) create mode 100644 src/compat/forwarddiff.jl delete mode 100644 src/compat/forwarddiff_gpu.jl diff --git a/src/ComponentArrays.jl b/src/ComponentArrays.jl index ff096a16..91a781f3 100644 --- a/src/ComponentArrays.jl +++ b/src/ComponentArrays.jl @@ -60,9 +60,7 @@ function __init__() @require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" required("staticarrays.jl") @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" required("reversediff.jl") @require GPUArrays="0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" required("gpuarrays.jl") - @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" begin - @require GPUArrays="0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" required("forwarddiff_gpu.jl") - end + @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" required("forwarddiff.jl") end end diff --git a/src/compat/forwarddiff.jl b/src/compat/forwarddiff.jl new file mode 100644 index 00000000..22d1782f --- /dev/null +++ b/src/compat/forwarddiff.jl @@ -0,0 +1 @@ +ForwardDiff.jacobian(f, x::ComponentArray, args...) = ForwardDiff.jacobian(f, getdata(x), args...) diff --git a/src/compat/forwarddiff_gpu.jl b/src/compat/forwarddiff_gpu.jl deleted file mode 100644 index 3d6720ba..00000000 --- a/src/compat/forwarddiff_gpu.jl +++ /dev/null @@ -1,4 +0,0 @@ -function Base.copyto!(dest::GPUComponentArray, src::ForwardDiff.Partials) - copyto!(getdata(dest), collect(src)) - return dest -end diff --git a/test/gpu_tests.jl b/test/gpu_tests.jl index 2f1ea6e7..034262c9 100644 --- a/test/gpu_tests.jl +++ b/test/gpu_tests.jl @@ -25,5 +25,5 @@ jlca = ComponentArray(jla, Axis(a=1:2, b=3:4)) end @testset "ForwardDiff" begin - ForwardDiff.gradient(sum, jlca) == one.(jlca) + @test ForwardDiff.jacobian(identity, jlca) == ForwardDiff.jacobian(identity, jla) end From f258f82167157730cd9f6f15804813868e6cbb79 Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Sat, 11 Mar 2023 12:42:48 -0700 Subject: [PATCH 09/19] Use extension --- Project.toml | 31 +++++++++++++++++++ ext/ConstructionBaseExt.jl | 8 +++++ ext/ForwardDiffExt.jl | 8 +++++ .../gpuarrays.jl => ext/GPUArraysExt.jl | 7 +++++ .../RecursiveArrayToolsExt.jl | 9 +++++- .../reversediff.jl => ext/ReverseDiffExt.jl | 9 +++++- .../scimlbase.jl => ext/SciMLBaseExt.jl | 7 +++++ .../staticarrays.jl => ext/StaticArraysExt.jl | 8 +++-- src/ComponentArrays.jl | 21 ++++++++----- src/compat/constructionbase.jl | 1 - src/compat/forwarddiff.jl | 1 - 11 files changed, 96 insertions(+), 14 deletions(-) create mode 100644 ext/ConstructionBaseExt.jl create mode 100644 ext/ForwardDiffExt.jl rename src/compat/gpuarrays.jl => ext/GPUArraysExt.jl (94%) rename src/compat/recursivearraytools.jl => ext/RecursiveArrayToolsExt.jl (58%) rename src/compat/reversediff.jl => ext/ReverseDiffExt.jl (89%) rename src/compat/scimlbase.jl => ext/SciMLBaseExt.jl (68%) rename src/compat/staticarrays.jl => ext/StaticArraysExt.jl (51%) delete mode 100644 src/compat/constructionbase.jl delete mode 100644 src/compat/forwarddiff.jl diff --git a/Project.toml b/Project.toml index 8bf355c3..a2b44326 100644 --- a/Project.toml +++ b/Project.toml @@ -6,8 +6,33 @@ version = "0.13.1" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[weakdeps] +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" + +[extensions] +ConstructionBaseExt = "ConstructionBase" +SciMLBaseExt = "SciMLBase" +RecursiveArrayToolsExt = "RecursiveArrayTools" +StaticArraysExt = "StaticArrays" +ReverseDiffExt = "ReverseDiff" +GPUArraysExt = "GPUArrays" +ForwardDiffExt = "ForwardDiff" [compat] ArrayInterface = "6" @@ -18,6 +43,12 @@ julia = "1.6" [extras] ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" [targets] test = ["Test"] diff --git a/ext/ConstructionBaseExt.jl b/ext/ConstructionBaseExt.jl new file mode 100644 index 00000000..f43f2435 --- /dev/null +++ b/ext/ConstructionBaseExt.jl @@ -0,0 +1,8 @@ +module ConstructionBaseExt + +using ComponentArrays +isdefined(Base, :get_extension) ? (using ConstructionBase) : (using ..ConstructionBase) + +ConstructionBase.setproperties(x::ComponentVector, patch::NamedTuple) = ComponentVector(x; patch...) + +end diff --git a/ext/ForwardDiffExt.jl b/ext/ForwardDiffExt.jl new file mode 100644 index 00000000..b28a36a8 --- /dev/null +++ b/ext/ForwardDiffExt.jl @@ -0,0 +1,8 @@ +module ForwardDiffExt + +using ComponentArrays +isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff) + +ForwardDiff.jacobian(f, x::ComponentArray, args...) = ForwardDiff.jacobian(f, getdata(x), args...) + +end diff --git a/src/compat/gpuarrays.jl b/ext/GPUArraysExt.jl similarity index 94% rename from src/compat/gpuarrays.jl rename to ext/GPUArraysExt.jl index da2e2edf..8b4eead2 100644 --- a/src/compat/gpuarrays.jl +++ b/ext/GPUArraysExt.jl @@ -1,3 +1,8 @@ +module GPUArraysExt + +using ComponentArrays +isdefined(Base, :get_extension) ? (using GPUArrays) : (using ..GPUArrays) + const GPUComponentArray = ComponentArray{T,N,<:GPUArrays.AbstractGPUArray,Ax} where {T,N,Ax<:Tuple{Vararg{AbstractAxis}}} GPUArrays.backend(x::ComponentArray) = GPUArrays.backend(getdata(x)) @@ -47,3 +52,5 @@ for (fname, op) in [(:sum, :(Base.add_sum)), (:prod, :(Base.mul_prod)), GPUArrays.mapreducedim!(f, $(op), getdata(r), getdata(A); init=neutral_element($(op), T)) end end + +end diff --git a/src/compat/recursivearraytools.jl b/ext/RecursiveArrayToolsExt.jl similarity index 58% rename from src/compat/recursivearraytools.jl rename to ext/RecursiveArrayToolsExt.jl index dce7b3cd..616dc6eb 100644 --- a/src/compat/recursivearraytools.jl +++ b/ext/RecursiveArrayToolsExt.jl @@ -1,5 +1,12 @@ +module RecursiveArrayToolsExt + +using CompoentArrays +isdefined(Base, :get_extension) ? (using RecursiveArrayTools) : (using ..RecursiveArrayTools) + AVOA = RecursiveArrayTools.AbstractVectorOfArray function Base.Array(VA::AVOA{T,N,A}) where {T,N,A<:AbstractVector{<:ComponentVector}} return ComponentArray(reduce(hcat, VA.u), only(getaxes(VA.u[1])), FlatAxis()) -end \ No newline at end of file +end + +end diff --git a/src/compat/reversediff.jl b/ext/ReverseDiffExt.jl similarity index 89% rename from src/compat/reversediff.jl rename to ext/ReverseDiffExt.jl index 87275583..d891dfaa 100644 --- a/src/compat/reversediff.jl +++ b/ext/ReverseDiffExt.jl @@ -1,3 +1,8 @@ +module ReverseDiffExt + +using ComponentArrays +isdefined(Base, :get_extension) ? (using ReverseDiff) : (using ..ReverseDiff) + const TrackedComponentArray{V, D, N, DA, A, Ax} = ReverseDiff.TrackedArray{V,D,N,ComponentArray{V,N,A,Ax},DA} maybe_tracked_array(val::AbstractArray, der, tape, inds, origin) = ReverseDiff.TrackedArray(val, der, tape) @@ -25,4 +30,6 @@ function Base.getproperty(tca::TrackedComponentArray, s::Symbol) t = ReverseDiff.tape(tca) return maybe_tracked_array(val, der, t, (s,), tca) end -end \ No newline at end of file +end + +end diff --git a/src/compat/scimlbase.jl b/ext/SciMLBaseExt.jl similarity index 68% rename from src/compat/scimlbase.jl rename to ext/SciMLBaseExt.jl index b90f2e4c..675ebb1b 100644 --- a/src/compat/scimlbase.jl +++ b/ext/SciMLBaseExt.jl @@ -1,4 +1,9 @@ # Plotting stuff +module SciMLBaseExt + +using ComponentArrays +isdefined(Base, :get_extension) ? (using SciMLBase) : (using ..SciMLBase) + function SciMLBase.getsyms(sol::SciMLBase.AbstractODESolution{T,N,C}) where {T,N,C<:AbstractVector{<:ComponentArray}} if SciMLBase.has_syms(sol.prob.f) return sol.prob.f.syms @@ -6,3 +11,5 @@ function SciMLBase.getsyms(sol::SciMLBase.AbstractODESolution{T,N,C}) where {T,N return Symbol.(labels(sol.u[1])) end end + +end diff --git a/src/compat/staticarrays.jl b/ext/StaticArraysExt.jl similarity index 51% rename from src/compat/staticarrays.jl rename to ext/StaticArraysExt.jl index 618bce10..ab774605 100644 --- a/src/compat/staticarrays.jl +++ b/ext/StaticArraysExt.jl @@ -1,5 +1,9 @@ -ComponentArray{A}(::UndefInitializer, ax::Axes) where {A<:StaticArrays.StaticArray,Axes<:Tuple} = - ComponentArray(similar(A), ax...) +module StaticArraysExt +using ComponentArrays +isdefined(Base, :get_extension) ? (using StaticArrays) : (using ..StaticArrays) +ComponentArray{A}(::UndefInitializer, ax::Axes) where {A<:StaticArrays.StaticArray,Axes<:Tuple} = + ComponentArray(similar(A), ax...) +end diff --git a/src/ComponentArrays.jl b/src/ComponentArrays.jl index 96d65736..a0fefbce 100644 --- a/src/ComponentArrays.jl +++ b/src/ComponentArrays.jl @@ -5,7 +5,10 @@ import ArrayInterface import ArrayInterface.ArrayInterfaceCore using LinearAlgebra -using Requires + +if !isdefined(Base, :get_extension) + using Requires +end const FlatIdx = Union{Integer, CartesianIndex, CartesianIndices, AbstractArray{<:Integer}} const FlatOrColonIdx = Union{FlatIdx, Colon} @@ -54,13 +57,15 @@ include("compat/chainrulescore.jl") required(filename) = include(joinpath("compat", filename)) function __init__() - @require ConstructionBase="187b0558-2788-49d3-abe0-74a17ed4e7c9" required("constructionbase.jl") - @require SciMLBase="0bca4576-84f4-4d90-8ffe-ffa030f20462" required("scimlbase.jl") - @require RecursiveArrayTools="731186ca-8d62-57ce-b412-fbd966d074cd" required("recursivearraytools.jl") - @require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" required("staticarrays.jl") - @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" required("reversediff.jl") - @require GPUArrays="0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" required("gpuarrays.jl") - @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" required("forwarddiff.jl") + @static if !isdefined(Base, :get_extension) + @require ConstructionBase="187b0558-2788-49d3-abe0-74a17ed4e7c9" required("../ext/ConstructionBaseExt.jl") + @require SciMLBase="0bca4576-84f4-4d90-8ffe-ffa030f20462" required("../ext/SciMLBaseExt.jl") + @require RecursiveArrayTools="731186ca-8d62-57ce-b412-fbd966d074cd" required("../ext/RecursiveArrayToolsExt.jl") + @require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" required("../ext/StaticArraysExt.jl") + @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" required("../ext/ReverseDiffExt.jl") + @require GPUArrays="0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" required("../ext/GPUArraysExt.jl") + @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" required("../ext/ForwardDiffExt.jl") + end end end diff --git a/src/compat/constructionbase.jl b/src/compat/constructionbase.jl deleted file mode 100644 index a42bd79c..00000000 --- a/src/compat/constructionbase.jl +++ /dev/null @@ -1 +0,0 @@ -ConstructionBase.setproperties(x::ComponentVector, patch::NamedTuple) = ComponentVector(x; patch...) \ No newline at end of file diff --git a/src/compat/forwarddiff.jl b/src/compat/forwarddiff.jl deleted file mode 100644 index 22d1782f..00000000 --- a/src/compat/forwarddiff.jl +++ /dev/null @@ -1 +0,0 @@ -ForwardDiff.jacobian(f, x::ComponentArray, args...) = ForwardDiff.jacobian(f, getdata(x), args...) From 054f3b8b3b7fbb71531c2d5b8cba24f0b513c977 Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Sat, 11 Mar 2023 12:51:23 -0700 Subject: [PATCH 10/19] Update RecursiveArrayToolsExt.jl --- ext/RecursiveArrayToolsExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/RecursiveArrayToolsExt.jl b/ext/RecursiveArrayToolsExt.jl index 616dc6eb..3cff26ea 100644 --- a/ext/RecursiveArrayToolsExt.jl +++ b/ext/RecursiveArrayToolsExt.jl @@ -1,6 +1,6 @@ module RecursiveArrayToolsExt -using CompoentArrays +using ComponentArrays isdefined(Base, :get_extension) ? (using RecursiveArrayTools) : (using ..RecursiveArrayTools) AVOA = RecursiveArrayTools.AbstractVectorOfArray From f5b5e893e822f5d43b0f56ff3ab693c6d3dd0005 Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Sat, 11 Mar 2023 13:23:51 -0700 Subject: [PATCH 11/19] Update gpu_tests.jl --- test/gpu_tests.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/test/gpu_tests.jl b/test/gpu_tests.jl index a34b6bdf..4a345c47 100644 --- a/test/gpu_tests.jl +++ b/test/gpu_tests.jl @@ -1,5 +1,4 @@ using JLArrays -import ForwardDiff JLArrays.allowscalar(false) @@ -68,7 +67,3 @@ end @test_nowarn mul!(deepcopy(A), A', transpose(A), 1, 2); end end - -@testset "ForwardDiff" begin - @test ForwardDiff.jacobian(identity, jlca) == ForwardDiff.jacobian(identity, jla) -end From fe16327b936737c017222f70193e0e7e0fbb0a33 Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Sat, 11 Mar 2023 13:28:54 -0700 Subject: [PATCH 12/19] Update ComponentArrays.jl --- src/ComponentArrays.jl | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/ComponentArrays.jl b/src/ComponentArrays.jl index 56fc216c..52841bde 100644 --- a/src/ComponentArrays.jl +++ b/src/ComponentArrays.jl @@ -53,17 +53,16 @@ export labels, label2index include("compat/chainrulescore.jl") -required(filename) = include(joinpath("compat", filename)) function __init__() @static if !isdefined(Base, :get_extension) - @require ConstructionBase="187b0558-2788-49d3-abe0-74a17ed4e7c9" required("../ext/ConstructionBaseExt.jl") - @require SciMLBase="0bca4576-84f4-4d90-8ffe-ffa030f20462" required("../ext/SciMLBaseExt.jl") - @require RecursiveArrayTools="731186ca-8d62-57ce-b412-fbd966d074cd" required("../ext/RecursiveArrayToolsExt.jl") - @require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" required("../ext/StaticArraysExt.jl") - @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" required("../ext/ReverseDiffExt.jl") - @require GPUArrays="0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" required("../ext/GPUArraysExt.jl") - @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" required("../ext/ForwardDiffExt.jl") + @require ConstructionBase="187b0558-2788-49d3-abe0-74a17ed4e7c9" include("../ext/ConstructionBaseExt.jl") + @require SciMLBase="0bca4576-84f4-4d90-8ffe-ffa030f20462" include("../ext/SciMLBaseExt.jl") + @require RecursiveArrayTools="731186ca-8d62-57ce-b412-fbd966d074cd" include("../ext/RecursiveArrayToolsExt.jl") + @require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" include("../ext/StaticArraysExt.jl") + @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("../ext/ReverseDiffExt.jl") + @require GPUArrays="0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" include("../ext/GPUArraysExt.jl") + @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("../ext/ForwardDiffExt.jl") end end From ed323a0a9197d8c849b6c0fc4317abb09eefb55e Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Sat, 11 Mar 2023 13:40:16 -0700 Subject: [PATCH 13/19] Update ComponentArrays.jl --- src/ComponentArrays.jl | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/ComponentArrays.jl b/src/ComponentArrays.jl index 52841bde..cdb21d70 100644 --- a/src/ComponentArrays.jl +++ b/src/ComponentArrays.jl @@ -52,8 +52,6 @@ export labels, label2index include("compat/chainrulescore.jl") - - function __init__() @static if !isdefined(Base, :get_extension) @require ConstructionBase="187b0558-2788-49d3-abe0-74a17ed4e7c9" include("../ext/ConstructionBaseExt.jl") From b393ed53593eb9586a42b80a41adf24d74409909 Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Sat, 11 Mar 2023 13:40:17 -0700 Subject: [PATCH 14/19] Update Project.toml --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index d8419efb..8c662c18 100644 --- a/Project.toml +++ b/Project.toml @@ -25,6 +25,7 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [extensions] ConstructionBaseExt = "ConstructionBase" From 4ada8d6456b51721cfb4db47759849ae832dcb3f Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Sat, 11 Mar 2023 13:43:23 -0700 Subject: [PATCH 15/19] Update GPUArraysExt.jl --- ext/GPUArraysExt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ext/GPUArraysExt.jl b/ext/GPUArraysExt.jl index d1a18e3c..1072d7b0 100644 --- a/ext/GPUArraysExt.jl +++ b/ext/GPUArraysExt.jl @@ -1,6 +1,6 @@ module GPUArraysExt -using ComponentArrays +using ComponentArrays, LinearAlgebra isdefined(Base, :get_extension) ? (using GPUArrays) : (using ..GPUArrays) const GPUComponentArray = ComponentArray{T,N,<:GPUArrays.AbstractGPUArray,Ax} where {T,N,Ax} From da66653dfe74865280c30110ac370cbd309aef42 Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Sat, 11 Mar 2023 13:43:26 -0700 Subject: [PATCH 16/19] Update Project.toml --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 8c662c18..b3d37a8d 100644 --- a/Project.toml +++ b/Project.toml @@ -52,6 +52,7 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" [targets] test = ["Test"] From 23f680f629b3eba29cec77bdceb2a3c321833da7 Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Sat, 11 Mar 2023 13:47:12 -0700 Subject: [PATCH 17/19] Update Project.toml --- Project.toml | 30 ++++++++++++++---------------- 1 file changed, 14 insertions(+), 16 deletions(-) diff --git a/Project.toml b/Project.toml index b3d37a8d..7ba9b538 100644 --- a/Project.toml +++ b/Project.toml @@ -14,27 +14,26 @@ RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Requires = "ae029012-a4dd-5104-9daa-d747884805df" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [weakdeps] ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] ConstructionBaseExt = "ConstructionBase" -SciMLBaseExt = "SciMLBase" +ForwardDiffExt = "ForwardDiff" +GPUArraysExt = "GPUArrays" RecursiveArrayToolsExt = "RecursiveArrayTools" -StaticArraysExt = "StaticArrays" ReverseDiffExt = "ReverseDiff" -GPUArraysExt = "GPUArrays" -ForwardDiffExt = "ForwardDiff" +SciMLBaseExt = "SciMLBase" +StaticArraysExt = "StaticArrays" [compat] ArrayInterface = "6, 7" @@ -44,15 +43,14 @@ StaticArrayInterface = "1" julia = "1.6" [extras] -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" -StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" -GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" -LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Test"] From 6f9baf468b88aedca2a541c63c9b5685d67014e0 Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Sat, 11 Mar 2023 14:01:53 -0700 Subject: [PATCH 18/19] remove forwarddiffext --- Project.toml | 2 -- ext/ForwardDiffExt.jl | 8 -------- 2 files changed, 10 deletions(-) delete mode 100644 ext/ForwardDiffExt.jl diff --git a/Project.toml b/Project.toml index 7ba9b538..187f3600 100644 --- a/Project.toml +++ b/Project.toml @@ -19,7 +19,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [weakdeps] ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" -ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" @@ -28,7 +27,6 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" [extensions] ConstructionBaseExt = "ConstructionBase" -ForwardDiffExt = "ForwardDiff" GPUArraysExt = "GPUArrays" RecursiveArrayToolsExt = "RecursiveArrayTools" ReverseDiffExt = "ReverseDiff" diff --git a/ext/ForwardDiffExt.jl b/ext/ForwardDiffExt.jl deleted file mode 100644 index b28a36a8..00000000 --- a/ext/ForwardDiffExt.jl +++ /dev/null @@ -1,8 +0,0 @@ -module ForwardDiffExt - -using ComponentArrays -isdefined(Base, :get_extension) ? (using ForwardDiff) : (using ..ForwardDiff) - -ForwardDiff.jacobian(f, x::ComponentArray, args...) = ForwardDiff.jacobian(f, getdata(x), args...) - -end From c216454552d0d397aa06a9ddcb9148f57b8b0bcb Mon Sep 17 00:00:00 2001 From: MilkshakeForReal Date: Sat, 11 Mar 2023 14:09:31 -0700 Subject: [PATCH 19/19] Update ComponentArrays.jl --- src/ComponentArrays.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/ComponentArrays.jl b/src/ComponentArrays.jl index cdb21d70..cd6f2802 100644 --- a/src/ComponentArrays.jl +++ b/src/ComponentArrays.jl @@ -60,7 +60,6 @@ function __init__() @require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" include("../ext/StaticArraysExt.jl") @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("../ext/ReverseDiffExt.jl") @require GPUArrays="0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" include("../ext/GPUArraysExt.jl") - @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("../ext/ForwardDiffExt.jl") end end