From a33f79422605eee064a1f7e56504306435c1154a Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 26 Jun 2024 09:48:25 +0000 Subject: [PATCH 1/4] feat: preserve indices when copying tracked arrays --- src/similar_convert_copy.jl | 3 +++ test/runtests.jl | 5 +++++ 2 files changed, 8 insertions(+) diff --git a/src/similar_convert_copy.jl b/src/similar_convert_copy.jl index d8cecc02..22c500ed 100644 --- a/src/similar_convert_copy.jl +++ b/src/similar_convert_copy.jl @@ -55,6 +55,9 @@ end function Base.convert(::Type{ComponentArray{T1,N,A1,Ax1}}, x::ComponentArray{T2,N,A2,Ax2}) where {T1,T2,N,A1,A2,Ax1,Ax2} return T1.(x) end +function Base.convert(::Type{ComponentArray{T,N,A1,Ax1}}, x::ComponentArray{T,N,A2,Ax2}) where {T,N,A1,A2,Ax1,Ax2} + return x +end Base.convert(T::Type{<:Array}, x::ComponentArray) = convert(T, getdata(x)) Base.convert(::Type{Cholesky{T1,Matrix{T1}}}, x::Cholesky{T2,<:ComponentArray}) where {T1,T2} = Cholesky(Matrix{T1}(x.factors), x.uplo, x.info) diff --git a/test/runtests.jl b/test/runtests.jl index da0acf1d..b0665dba 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,6 +1,7 @@ using ComponentArrays using BenchmarkTools using ForwardDiff +using Tracker using InvertedIndices using LabelledArrays using LinearAlgebra @@ -400,6 +401,10 @@ end @test convert(Array, ca) == getdata(ca) @test convert(Matrix{Float32}, cmat) isa Matrix{Float32} + + tr = Tracker.param(ca) + ca_ = convert(typeof(ca), tr) + @test ca_.x == ca.x end @testset "Broadcasting" begin From c4e33dfb0c52b01a3489dafc2b92593e1518655e Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 26 Jun 2024 12:46:35 +0000 Subject: [PATCH 2/4] test: correct field name --- test/runtests.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/runtests.jl b/test/runtests.jl index b0665dba..8b9c6c6b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -404,7 +404,7 @@ end tr = Tracker.param(ca) ca_ = convert(typeof(ca), tr) - @test ca_.x == ca.x + @test ca_.a == ca.a end @testset "Broadcasting" begin From 81df7b34c30f4a8904efffd9250a50a3730a2678 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 26 Jun 2024 19:41:31 +0000 Subject: [PATCH 3/4] chore: match axes in convert --- src/similar_convert_copy.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/similar_convert_copy.jl b/src/similar_convert_copy.jl index 22c500ed..6b31c1fb 100644 --- a/src/similar_convert_copy.jl +++ b/src/similar_convert_copy.jl @@ -55,7 +55,7 @@ end function Base.convert(::Type{ComponentArray{T1,N,A1,Ax1}}, x::ComponentArray{T2,N,A2,Ax2}) where {T1,T2,N,A1,A2,Ax1,Ax2} return T1.(x) end -function Base.convert(::Type{ComponentArray{T,N,A1,Ax1}}, x::ComponentArray{T,N,A2,Ax2}) where {T,N,A1,A2,Ax1,Ax2} +function Base.convert(::Type{ComponentArray{T,N,A1,Ax}}, x::ComponentArray{T,N,A2,Ax}) where {T,N,A1,A2,Ax} return x end Base.convert(T::Type{<:Array}, x::ComponentArray) = convert(T, getdata(x)) From 77b31e5109dae8be31df65e06cedec11e0bd8c91 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Wed, 26 Jun 2024 19:57:46 +0000 Subject: [PATCH 4/4] chore: define method to avoid ambiguity --- src/similar_convert_copy.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/similar_convert_copy.jl b/src/similar_convert_copy.jl index 6b31c1fb..1a43265d 100644 --- a/src/similar_convert_copy.jl +++ b/src/similar_convert_copy.jl @@ -58,6 +58,9 @@ end function Base.convert(::Type{ComponentArray{T,N,A1,Ax}}, x::ComponentArray{T,N,A2,Ax}) where {T,N,A1,A2,Ax} return x end +function Base.convert(::Type{ComponentArray{T,N,A,Ax}}, x::ComponentArray{T,N,A,Ax}) where {T,N,A,Ax} + return x +end Base.convert(T::Type{<:Array}, x::ComponentArray) = convert(T, getdata(x)) Base.convert(::Type{Cholesky{T1,Matrix{T1}}}, x::Cholesky{T2,<:ComponentArray}) where {T1,T2} = Cholesky(Matrix{T1}(x.factors), x.uplo, x.info)