Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,32 @@ version = "0.13.8"
[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrayInterface = "0d7ed370-da01-4f52-bd93-41d350b8b718"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[weakdeps]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

[extensions]
ConstructionBaseExt = "ConstructionBase"
GPUArraysExt = "GPUArrays"
RecursiveArrayToolsExt = "RecursiveArrayTools"
ReverseDiffExt = "ReverseDiff"
SciMLBaseExt = "SciMLBase"
StaticArraysExt = "StaticArrays"

[compat]
ArrayInterface = "6, 7"
Expand All @@ -18,7 +41,13 @@ StaticArrayInterface = "1"
julia = "1.6"

[extras]
ConstructionBase = "187b0558-2788-49d3-abe0-74a17ed4e7c9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
GPUArrays = "0c68f7d7-f131-5f86-a1c3-88cf8149b2d7"
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SciMLBase = "0bca4576-84f4-4d90-8ffe-ffa030f20462"
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
Expand Down
8 changes: 8 additions & 0 deletions ext/ConstructionBaseExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
module ConstructionBaseExt

using ComponentArrays
isdefined(Base, :get_extension) ? (using ConstructionBase) : (using ..ConstructionBase)

ConstructionBase.setproperties(x::ComponentVector, patch::NamedTuple) = ComponentVector(x; patch...)

end
7 changes: 7 additions & 0 deletions src/compat/gpuarrays.jl → ext/GPUArraysExt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
module GPUArraysExt

using ComponentArrays, LinearAlgebra
isdefined(Base, :get_extension) ? (using GPUArrays) : (using ..GPUArrays)

const GPUComponentArray = ComponentArray{T,N,<:GPUArrays.AbstractGPUArray,Ax} where {T,N,Ax}
const GPUComponentVector{T,Ax} = ComponentArray{T,1,<:GPUArrays.AbstractGPUVector,Ax}
const GPUComponentMatrix{T,Ax} = ComponentArray{T,2,<:GPUArrays.AbstractGPUMatrix,Ax}
Expand Down Expand Up @@ -271,3 +276,5 @@ function LinearAlgebra.mul!(C::GPUComponentVecorMat,
}, a::Real, b::Real)
return GPUArrays.generic_matmatmul!(C, A, B, a, b)
end

end
Original file line number Diff line number Diff line change
@@ -1,5 +1,12 @@
module RecursiveArrayToolsExt

using ComponentArrays
isdefined(Base, :get_extension) ? (using RecursiveArrayTools) : (using ..RecursiveArrayTools)

AVOA = RecursiveArrayTools.AbstractVectorOfArray

function Base.Array(VA::AVOA{T,N,A}) where {T,N,A<:AbstractVector{<:ComponentVector}}
return ComponentArray(reduce(hcat, VA.u), only(getaxes(VA.u[1])), FlatAxis())
end
end

end
9 changes: 8 additions & 1 deletion src/compat/reversediff.jl → ext/ReverseDiffExt.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,8 @@
module ReverseDiffExt

using ComponentArrays
isdefined(Base, :get_extension) ? (using ReverseDiff) : (using ..ReverseDiff)

const TrackedComponentArray{V, D, N, DA, A, Ax} = ReverseDiff.TrackedArray{V,D,N,ComponentArray{V,N,A,Ax},DA}

maybe_tracked_array(val::AbstractArray, der, tape, inds, origin) = ReverseDiff.TrackedArray(val, der, tape)
Expand Down Expand Up @@ -25,4 +30,6 @@ function Base.getproperty(tca::TrackedComponentArray, s::Symbol)
t = ReverseDiff.tape(tca)
return maybe_tracked_array(val, der, t, (s,), tca)
end
end
end

end
7 changes: 7 additions & 0 deletions src/compat/scimlbase.jl → ext/SciMLBaseExt.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
# Plotting stuff
module SciMLBaseExt

using ComponentArrays
isdefined(Base, :get_extension) ? (using SciMLBase) : (using ..SciMLBase)

function SciMLBase.getsyms(sol::SciMLBase.AbstractODESolution{T,N,C}) where {T,N,C<:AbstractVector{<:ComponentArray}}
if SciMLBase.has_syms(sol.prob.f)
return sol.prob.f.syms
else
return Symbol.(labels(sol.u[1]))
end
end

end
8 changes: 6 additions & 2 deletions src/compat/staticarrays.jl → ext/StaticArraysExt.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
ComponentArray{A}(::UndefInitializer, ax::Axes) where {A<:StaticArrays.StaticArray,Axes<:Tuple} =
ComponentArray(similar(A), ax...)
module StaticArraysExt

using ComponentArrays
isdefined(Base, :get_extension) ? (using StaticArrays) : (using ..StaticArrays)

ComponentArray{A}(::UndefInitializer, ax::Axes) where {A<:StaticArrays.StaticArray,Axes<:Tuple} =
ComponentArray(similar(A), ax...)

end
24 changes: 13 additions & 11 deletions src/ComponentArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ import ChainRulesCore
import StaticArrayInterface, ArrayInterface

using LinearAlgebra
using Requires

if !isdefined(Base, :get_extension)
using Requires
end

const FlatIdx = Union{Integer, CartesianIndex, CartesianIndices, AbstractArray{<:Integer}}
const FlatOrColonIdx = Union{FlatIdx, Colon}
Expand Down Expand Up @@ -49,16 +52,15 @@ export labels, label2index

include("compat/chainrulescore.jl")


required(filename) = include(joinpath("compat", filename))

function __init__()
@require ConstructionBase="187b0558-2788-49d3-abe0-74a17ed4e7c9" required("constructionbase.jl")
@require SciMLBase="0bca4576-84f4-4d90-8ffe-ffa030f20462" required("scimlbase.jl")
@require RecursiveArrayTools="731186ca-8d62-57ce-b412-fbd966d074cd" required("recursivearraytools.jl")
@require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" required("staticarrays.jl")
@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" required("reversediff.jl")
@require GPUArrays="0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" required("gpuarrays.jl")
@static if !isdefined(Base, :get_extension)
@require ConstructionBase="187b0558-2788-49d3-abe0-74a17ed4e7c9" include("../ext/ConstructionBaseExt.jl")
@require SciMLBase="0bca4576-84f4-4d90-8ffe-ffa030f20462" include("../ext/SciMLBaseExt.jl")
@require RecursiveArrayTools="731186ca-8d62-57ce-b412-fbd966d074cd" include("../ext/RecursiveArrayToolsExt.jl")
@require StaticArrays="90137ffa-7385-5640-81b9-e52037218182" include("../ext/StaticArraysExt.jl")
@require ReverseDiff="37e2e3b7-166d-5795-8a7a-e32c996b4267" include("../ext/ReverseDiffExt.jl")
@require GPUArrays="0c68f7d7-f131-5f86-a1c3-88cf8149b2d7" include("../ext/GPUArraysExt.jl")
end
end

end
end
1 change: 0 additions & 1 deletion src/compat/constructionbase.jl

This file was deleted.

2 changes: 1 addition & 1 deletion src/similar_convert_copy.jl
Original file line number Diff line number Diff line change
Expand Up @@ -75,4 +75,4 @@ Base.NamedTuple(x::ComponentVector) = _namedtuple(x)


## AbstractAxis conversion and promotion
Base.convert(::Type{Ax}, ax::AbstractAxis) where {Ax<:AbstractAxis} = ax
Base.convert(::Type{Ax}, ax::AbstractAxis) where {Ax<:AbstractAxis} = ax