Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/Tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
group:
- "Core"
- "Downstream"
- "JET"
uses: "SciML/.github/.github/workflows/tests.yml@v1"
with:
group: "${{ matrix.group }}"
Expand Down
6 changes: 4 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "RecursiveArrayTools"
uuid = "731186ca-8d62-57ce-b412-fbd966d074cd"
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]
version = "3.41.0"
authors = ["Chris Rackauckas <accounts@chrisrackauckas.com>"]

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -48,6 +48,7 @@ DocStringExtensions = "0.9.3"
FastBroadcast = "0.3.5"
ForwardDiff = "0.10.38, 1"
GPUArraysCore = "0.2"
JET = "0.9, 0.11"
KernelAbstractions = "0.9.36"
LinearAlgebra = "1.10"
Measurements = "2.11"
Expand Down Expand Up @@ -76,6 +77,7 @@ julia = "1.10"
Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
FastBroadcast = "7034ab61-46d4-4ed7-9d0f-46aef9175898"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
JET = "c3a54625-cd67-489e-a8e7-0a5a0ff4e31b"
KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
Measurements = "eff96d63-e80a-5855-80a2-b1b0885c5ab7"
MonteCarloMeasurements = "0987c9cc-fe09-11e8-30f0-b96dd679fdca"
Expand All @@ -93,4 +95,4 @@ Unitful = "1986cc42-f94f-5a68-af5c-568840ba703d"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Aqua", "FastBroadcast", "ForwardDiff", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Random", "SafeTestsets", "SciMLBase", "SparseArrays", "StaticArrays", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]
test = ["Aqua", "FastBroadcast", "ForwardDiff", "JET", "KernelAbstractions", "Measurements", "NLsolve", "Pkg", "Random", "SafeTestsets", "SciMLBase", "SparseArrays", "StaticArrays", "StructArrays", "Tables", "Test", "Unitful", "Zygote"]
8 changes: 4 additions & 4 deletions src/named_array_partition.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""
NamedArrayPartition(; kwargs...)
NamedArrayPartition(x::NamedTuple)
NamedArrayPartition(x::NamedTuple)

Similar to an `ArrayPartition` but the individual arrays can be accessed via the
constructor-specified names. However, unlike `ArrayPartition`, each individual array
Expand All @@ -22,7 +22,7 @@ function NamedArrayPartition(x::NamedTuple)
return NamedArrayPartition(ArrayPartition{T, S}(values(x)), names_to_indices)
end

# Note: overloading `getproperty` means we cannot access `NamedArrayPartition`
# Note: overloading `getproperty` means we cannot access `NamedArrayPartition`
# fields except through `getfield` and accessor functions.
ArrayPartition(x::NamedArrayPartition) = getfield(x, :array_partition)

Expand Down Expand Up @@ -53,7 +53,7 @@ end
function Base.similar(
A::NamedArrayPartition, ::Type{T}, ::Type{S}, R::DataType...) where {T, S}
NamedArrayPartition(
similar(getfield(A, :array_partition), T, S, R), getfield(A, :names_to_indices))
similar(getfield(A, :array_partition), T, S, R...), getfield(A, :names_to_indices))
end

Base.Array(x::NamedArrayPartition) = Array(ArrayPartition(x))
Expand All @@ -68,7 +68,7 @@ function Base.getproperty(x::NamedArrayPartition, s::Symbol)
getindex(ArrayPartition(x).x, getproperty(getfield(x, :names_to_indices), s))
end

# this enables x.s = some_array.
# this enables x.s = some_array.
@inline function Base.setproperty!(x::NamedArrayPartition, s::Symbol, v)
index = getproperty(getfield(x, :names_to_indices), s)
ArrayPartition(x).x[index] .= v
Expand Down
6 changes: 4 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ function recursivefill!(b::AbstractArray{T, N},
a::T2) where {T <: StaticArraysCore.SArray,
T2 <: Union{Number, Bool}, N}
@inbounds for i in eachindex(b)
b[i] = fill(a, typeof(b[i]))
# Preserve static array shape while replacing all entries with the scalar
b[i] = map(_ -> a, b[i])
end
end

Expand All @@ -128,7 +129,8 @@ function recursivefill!(bs::AbstractVectorOfArray{T, N},
T2 <: Union{Number, Bool}, N}
@inbounds for b in bs, i in eachindex(b)

b[i] = fill(a, typeof(b[i]))
# Preserve static array shape while replacing all entries with the scalar
b[i] = map(_ -> a, b[i])
end
end

Expand Down
3 changes: 3 additions & 0 deletions src/vector_of_array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -924,6 +924,8 @@ function Base.view(A::AbstractVectorOfArray{T, N, <:AbstractVector{T}},
J = map(i -> Base.unalias(A, i), to_indices(A, I))
elseif length(I) == 2 && (I[1] == Colon() || I[1] == 1)
J = map(i -> Base.unalias(A, i), to_indices(A, Base.tail(I)))
else
J = map(i -> Base.unalias(A, i), to_indices(A, I))
end
@boundscheck checkbounds(A, J...)
SubArray(A, J)
Expand Down Expand Up @@ -1200,6 +1202,7 @@ end

struct VectorOfArrayStyle{N} <: Broadcast.AbstractArrayStyle{N} end # N is only used when voa sees other abstract arrays
VectorOfArrayStyle{N}(::Val{N}) where {N} = VectorOfArrayStyle{N}()
VectorOfArrayStyle(::Val{N}) where {N} = VectorOfArrayStyle{N}()

# The order is important here. We want to override Base.Broadcast.DefaultArrayStyle to return another Base.Broadcast.DefaultArrayStyle.
Broadcast.BroadcastStyle(a::VectorOfArrayStyle, ::Base.Broadcast.DefaultArrayStyle{0}) = a
Expand Down
14 changes: 14 additions & 0 deletions test/jet_tests.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using JET, Test, RecursiveArrayTools

# Get all reports first
result = JET.report_package(RecursiveArrayTools; target_modules = (RecursiveArrayTools,))
reports = JET.get_reports(result)

# Filter out similar_type inference errors from StaticArraysCore
filtered_reports = filter(reports) do report
s = string(report)
!(occursin("similar_type", s) && occursin("StaticArraysCore", s))
end

# Check if there are any non-filtered errors
@test isempty(filtered_reports)
4 changes: 4 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,4 +56,8 @@ end
@time @safetestset "VectorOfArray GPU" include("gpu/vectorofarray_gpu.jl")
@time @safetestset "ArrayPartition GPU" include("gpu/arraypartition_gpu.jl")
end

if GROUP == "JET" || GROUP == "All"
@time @safetestset "JET Tests" include("jet_tests.jl")
end
end
Loading