Error of custom type interaction with StaticArrays (type unstable update of an immutable variable) #1263

Closed
just-walk opened this issue Jan 31, 2024 · 1 comment


@just-walk
Contributor

Enzyme fails when differentiating through a custom type that wraps an SMatrix and carries labels and an abstract basis type as type parameters. The error is discussed here, where @wsmoses suggested that this case is "not-yet-implemented" and offered a workaround: type-stabilize the code.
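The type instability appears to come from the `__x` field annotation in the MWE: `SMatrix{N, N, T}` omits the length parameter `L` of `SArray`, so it is a `UnionAll` rather than a concrete type, and wrapping it in a one-member `Union` does not change that. The check below is a minimal sketch assuming only StaticArrays; `Holder` is a hypothetical stand-in for the MWE type, not part of the original report.

```julia
using StaticArrays

# SMatrix{N, M, T} leaves the length parameter L free, so this is a UnionAll.
const PartialSMatrix = SMatrix{3, 3, Float64}
# Supplying L = N * N = 9 gives a fully specified, concrete type.
const FullSMatrix = SMatrix{3, 3, Float64, 9}

# Hypothetical struct using the same field annotation as the MWE.
struct Holder{N, T}
    x::Union{SMatrix{N, N, T}}
end

println(Union{PartialSMatrix} == PartialSMatrix)               # true: a one-member Union is just its member
println(isconcretetype(PartialSMatrix))                        # false
println(isconcretetype(FullSMatrix))                           # true
println(isconcretetype(fieldtype(Holder{3, Float64}, :x)))     # false: the field stays abstract
```

Because the field type is not concrete, `getfield(v, :__x)` cannot be inferred to a single type, which is consistent with the stack trace ending in Enzyme's `typeunstablerules.jl`.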

MWE and its output are below:

```julia
using Enzyme
using Random
using StaticArrays

abstract type AbstractBasisType end

struct Contravariant <: AbstractBasisType end

struct CurvilinearBasisVectors{N, T, C, B, V <: AbstractBasisType} <: StaticMatrix{N, N, T}
    __x::Union{SMatrix{N, N, T}}
    function CurvilinearBasisVectors{N, T, C, B, V}(b::AbstractMatrix) where {N, T, C, B, V}
        return new{N, T, C, B, V}(SMatrix{N, N, T}(b))
    end
end

Base.@propagate_inbounds function Base.getindex(v::CurvilinearBasisVectors{N, T, C, B, V}, i::Int) where {N, T, C, B, V}
    return view(getfield(v, :__x), i)[]
end

basis_labels = (:∇x, :∇y, :∇z);
coord_labels = (:x, :y, :z);

a = CurvilinearBasisVectors{3,Float64,basis_labels,coord_labels,Contravariant}(
    rand(3,3),
);
da = CurvilinearBasisVectors{3,Float64,basis_labels,coord_labels,Contravariant}(
    zeros(3,3),
);

@show a,da

function f(x, y)
    sum(sum(y .* x))
end

@show f(a, 2.0)

@show autodiff(Reverse, f, Active, Duplicated(a, da), Active(2.0))
@show da
```
```
(a, da) = ([0.03087747827664955 0.2528585283866567 0.2994280949524749; 0.3884839505154861 0.6683809551000912 0.9092201295064689; 0.564563029791595 0.06953550568414368 0.8692541334330144], [0.0 0.0 0.0; 0.0 0.0 0.0; 0.0 0.0 0.0])
f(a, 2.0) = 8.105203611293161
ERROR: LoadError: setfield!: immutable struct of type SArray cannot be changed
Stacktrace:
  [1] rt_jl_getfield_rev(::SMatrix{3, 3, Float64, 9}, ::Base.RefValue{NTuple{9, Float64}}, ::Type{Val{:data}}, ::Val{false})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/rules/typeunstablerules.jl:257
  [2] getindex
    @ ~/.julia/packages/StaticArrays/eGKzB/src/SArray.jl:62 [inlined]
  [3] view
    @ ~/.julia/packages/StaticArrays/eGKzB/src/abstractarray.jl:291 [inlined]
  [4] getindex
    @ ~/software/julia/enzyme/test-basis.jl:17 [inlined]
  [5] macro expansion
    @ ~/.julia/packages/StaticArrays/eGKzB/src/broadcast.jl:135 [inlined]
  [6] __broadcast
    @ ~/.julia/packages/StaticArrays/eGKzB/src/broadcast.jl:123 [inlined]
  [7] _broadcast
    @ ~/.julia/packages/StaticArrays/eGKzB/src/broadcast.jl:119 [inlined]
  [8] copy
    @ ~/.julia/packages/StaticArrays/eGKzB/src/broadcast.jl:60 [inlined]
  [9] materialize
    @ ./broadcast.jl:903 [inlined]
 [10] f
    @ ~/software/julia/enzyme/test-basis.jl:33 [inlined]
 [11] f
    @ ~/software/julia/enzyme/test-basis.jl:0 [inlined]
 [12] diffejulia_f_5401_inner_1wrap
    @ ~/software/julia/enzyme/test-basis.jl:0
 [13] macro expansion
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/compiler.jl:5306 [inlined]
 [14] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Active{…}, ::Float64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/compiler.jl:4984
 [15] (::Enzyme.Compiler.CombinedAdjointThunk{…})(::Const{…}, ::Duplicated{…}, ::Vararg{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/jOGYG/src/compiler.jl:4926
 [16] autodiff(::ReverseMode{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Vararg{…})
    @ Enzyme ~/.julia/packages/Enzyme/jOGYG/src/Enzyme.jl:215
 [17] autodiff(::ReverseMode{false, FFIABI}, ::typeof(f), ::Type, ::Duplicated{CurvilinearBasisVectors{…}}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/jOGYG/src/Enzyme.jl:224
 [18] macro expansion
    @ show.jl:1181 [inlined]
 [19] top-level scope
    @ ~/software/julia/enzyme/test-basis.jl:38
 [20] include(fname::String)
    @ Base.MainInclude ./client.jl:489
 [21] top-level scope
    @ REPL[4]:1
```
@wsmoses
Member

wsmoses commented Feb 10, 2024

Marking this as a duplicate of #970 (though this is an easier MWE).

@wsmoses closed this as completed Feb 10, 2024