diff --git a/Project.toml b/Project.toml index 7280c0ee..187f3600 100644 --- a/Project.toml +++ b/Project.toml @@ -6,9 +6,32 @@ version = "0.13.8" [deps] ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9" ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" Requires = "ae029012-a4dd-5104-9daa-d747884805df" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[weakdeps] +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" + +[extensions] +ConstructionBaseExt = "ConstructionBase" +GPUArraysExt = "GPUArrays" +RecursiveArrayToolsExt = "RecursiveArrayTools" +ReverseDiffExt = "ReverseDiff" +SciMLBaseExt = "SciMLBase" +StaticArraysExt = "StaticArrays" [compat] ArrayInterface = "6, 7" @@ -18,7 +41,13 @@ StaticArrayInterface = "1" julia = "1.6" [extras] +ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" +GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" +RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd" +ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267" +SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462" +StaticArrays = "90137ffa-7385-5640-81b9-e52037218182" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] diff --git a/ext/ConstructionBaseExt.jl b/ext/ConstructionBaseExt.jl new file mode 100644 index 00000000..f43f2435 --- /dev/null +++ b/ext/ConstructionBaseExt.jl @@ -0,0 +1,8 @@ +module ConstructionBaseExt + +using ComponentArrays +isdefined(Base, :get_extension) ? (using ConstructionBase) : (using ..ConstructionBase) + +ConstructionBase.setproperties(x::ComponentVector, patch::NamedTuple) = ComponentVector(x; patch...) + +end diff --git a/src/compat/gpuarrays.jl b/ext/GPUArraysExt.jl similarity index 98% rename from src/compat/gpuarrays.jl rename to ext/GPUArraysExt.jl index ace99e6a..1072d7b0 100644 --- a/src/compat/gpuarrays.jl +++ b/ext/GPUArraysExt.jl @@ -1,3 +1,8 @@ +module GPUArraysExt + +using ComponentArrays, LinearAlgebra +isdefined(Base, :get_extension) ? (using GPUArrays) : (using ..GPUArrays) + const GPUComponentArray = ComponentArray{T,N,<:GPUArrays.AbstractGPUArray,Ax} where {T,N,Ax} const GPUComponentVector{T,Ax} = ComponentArray{T,1,<:GPUArrays.AbstractGPUVector,Ax} const GPUComponentMatrix{T,Ax} = ComponentArray{T,2,<:GPUArrays.AbstractGPUMatrix,Ax} @@ -271,3 +276,5 @@ function LinearAlgebra.mul!(C::GPUComponentVecorMat, }, a::Real, b::Real) return GPUArrays.generic_matmatmul!(C, A, B, a, b) end + +end diff --git a/src/compat/recursivearraytools.jl b/ext/RecursiveArrayToolsExt.jl similarity index 58% rename from src/compat/recursivearraytools.jl rename to ext/RecursiveArrayToolsExt.jl index dce7b3cd..3cff26ea 100644 --- a/src/compat/recursivearraytools.jl +++ b/ext/RecursiveArrayToolsExt.jl @@ -1,5 +1,12 @@ +module RecursiveArrayToolsExt + +using ComponentArrays +isdefined(Base, :get_extension) ? (using RecursiveArrayTools) : (using ..RecursiveArrayTools) + AVOA = RecursiveArrayTools.AbstractVectorOfArray function Base.Array(VA::AVOA{T,N,A}) where {T,N,A<:AbstractVector{<:ComponentVector}} return ComponentArray(reduce(hcat, VA.u), only(getaxes(VA.u[1])), FlatAxis()) -end \ No newline at end of file +end + +end diff --git a/src/compat/reversediff.jl b/ext/ReverseDiffExt.jl similarity index 89% rename from src/compat/reversediff.jl rename to ext/ReverseDiffExt.jl index 87275583..d891dfaa 100644 --- a/src/compat/reversediff.jl +++ b/ext/ReverseDiffExt.jl @@ -1,3 +1,8 @@ +module ReverseDiffExt + +using ComponentArrays +isdefined(Base, :get_extension) ? (using ReverseDiff) : (using ..ReverseDiff) + const TrackedComponentArray{V, D, N, DA, A, Ax} = ReverseDiff.TrackedArray{V,D,N,ComponentArray{V,N,A,Ax},DA} maybe_tracked_array(val::AbstractArray, der, tape, inds, origin) = ReverseDiff.TrackedArray(val, der, tape) @@ -25,4 +30,6 @@ function Base.getproperty(tca::TrackedComponentArray, s::Symbol) t = ReverseDiff.tape(tca) return maybe_tracked_array(val, der, t, (s,), tca) end -end \ No newline at end of file +end + +end diff --git a/src/compat/scimlbase.jl b/ext/SciMLBaseExt.jl similarity index 68% rename from src/compat/scimlbase.jl rename to ext/SciMLBaseExt.jl index b90f2e4c..675ebb1b 100644 --- a/src/compat/scimlbase.jl +++ b/ext/SciMLBaseExt.jl @@ -1,4 +1,9 @@ # Plotting stuff +module SciMLBaseExt + +using ComponentArrays +isdefined(Base, :get_extension) ? (using SciMLBase) : (using ..SciMLBase) + function SciMLBase.getsyms(sol::SciMLBase.AbstractODESolution{T,N,C}) where {T,N,C<:AbstractVector{<:ComponentArray}} if SciMLBase.has_syms(sol.prob.f) return sol.prob.f.syms @@ -6,3 +11,5 @@ function SciMLBase.getsyms(sol::SciMLBase.AbstractODESolution{T,N,C}) where {T,N return Symbol.(labels(sol.u[1])) end end + +end diff --git a/src/compat/staticarrays.jl b/ext/StaticArraysExt.jl similarity index 51% rename from src/compat/staticarrays.jl rename to ext/StaticArraysExt.jl index 618bce10..ab774605 100644 --- a/src/compat/staticarrays.jl +++ b/ext/StaticArraysExt.jl @@ -1,5 +1,9 @@ -ComponentArray{A}(::UndefInitializer, ax::Axes) where {A<:StaticArrays.StaticArray,Axes<:Tuple} = - ComponentArray(similar(A), ax...) +module StaticArraysExt +using ComponentArrays +isdefined(Base, :get_extension) ? (using StaticArrays) : (using ..StaticArrays) +ComponentArray{A}(::UndefInitializer, ax::Axes) where {A<:StaticArrays.StaticArray,Axes<:Tuple} = + ComponentArray(similar(A), ax...) +end diff --git a/src/ComponentArrays.jl b/src/ComponentArrays.jl index 7777eeda..cd6f2802 100644 --- a/src/ComponentArrays.jl +++ b/src/ComponentArrays.jl @@ -4,7 +4,10 @@ import ChainRulesCore import StaticArrayInterface, ArrayInterface using LinearAlgebra -using Requires + +if !isdefined(Base, :get_extension) + using Requires +end const FlatIdx = Union{Integer, CartesianIndex, CartesianIndices, AbstractArray{<:Integer}} const FlatOrColonIdx = Union{FlatIdx, Colon} @@ -49,16 +52,15 @@ export labels, label2index include("compat/chainrulescore.jl") - -required(filename) = include(joinpath("compat", filename)) - function __init__() - @require ConstructionBase="187b0558-2788-49d3-abe0-74a17ed4e7c9" required("constructionbase.jl") - @require SciMLBase="0bca4576-84f4-4d90-8ffe-ffa030f20462" required("scimlbase.jl") - @require RecursiveArrayTools="731186ca-8d62-57ce-b412-fbd966d074cd" required("recursivearraytools.jl") - @require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" required("staticarrays.jl") - @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" required("reversediff.jl") - @require GPUArrays="0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" required("gpuarrays.jl") + @static if !isdefined(Base, :get_extension) + @require ConstructionBase="187b0558-2788-49d3-abe0-74a17ed4e7c9" include("../ext/ConstructionBaseExt.jl") + @require SciMLBase="0bca4576-84f4-4d90-8ffe-ffa030f20462" include("../ext/SciMLBaseExt.jl") + @require RecursiveArrayTools="731186ca-8d62-57ce-b412-fbd966d074cd" include("../ext/RecursiveArrayToolsExt.jl") + @require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" include("../ext/StaticArraysExt.jl") + @require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("../ext/ReverseDiffExt.jl") + @require GPUArrays="0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" include("../ext/GPUArraysExt.jl") + end end -end \ No newline at end of file +end diff --git a/src/compat/constructionbase.jl b/src/compat/constructionbase.jl deleted file mode 100644 index a42bd79c..00000000 --- a/src/compat/constructionbase.jl +++ /dev/null @@ -1 +0,0 @@ -ConstructionBase.setproperties(x::ComponentVector, patch::NamedTuple) = ComponentVector(x; patch...) \ No newline at end of file diff --git a/src/similar_convert_copy.jl b/src/similar_convert_copy.jl index ccc78377..b0b35c4a 100644 --- a/src/similar_convert_copy.jl +++ b/src/similar_convert_copy.jl @@ -75,4 +75,4 @@ Base.NamedTuple(x::ComponentVector) = _namedtuple(x) ## AbstractAxis conversion and promotion -Base.convert(::Type{Ax}, ax::AbstractAxis) where {Ax<:AbstractAxis} = ax \ No newline at end of file +Base.convert(::Type{Ax}, ax::AbstractAxis) where {Ax<:AbstractAxis} = ax