Error with Duplicated{<:Cholesky} #1250

Closed
devmotion opened this issue Jan 26, 2024 · 7 comments

@devmotion
Contributor

devmotion commented Jan 26, 2024

When running the tests in TuringLang/Turing.jl#1887 with Enzyme#main I noticed a few problems with the Cholesky support added in #1220:

Error with Duplicated{<:Cholesky}

julia> using Enzyme, LinearAlgebra

julia> C = cholesky((x -> x * x')(rand(2, 2)) + I)
Cholesky{Float64, Matrix{Float64}}
U factor:
2×2 UpperTriangular{Float64, Matrix{Float64}}:
 1.31993  0.558389
         1.51658

julia> X = rand(2, 3)
2×3 Matrix{Float64}:
 0.820715  0.365977   0.527577
 0.424702  0.0866426  0.890884

julia> g(C, X) = sum(C \ X)
g (generic function with 1 method)

julia> X = rand(2, 2)
2×2 Matrix{Float64}:
 0.0265496  0.813724
 0.32185    0.344644

julia> g(C, X)
0.5603205601557636

julia> # very weird and "incorrect" but I didn't see any other way to work with `Duplicated`
       ∂g_∂C = Cholesky(zero(C.factors), 'U', 0)
Cholesky{Float64, Matrix{Float64}}
U factor:
2×2 UpperTriangular{Float64, Matrix{Float64}}:
 0.0  0.0
     0.0

julia> ∂g_∂X = zero(X)
2×2 Matrix{Float64}:
 0.0  0.0
 0.0  0.0

julia> autodiff(Reverse, g, Active, Duplicated(C, ∂g_∂C), Duplicated(X, ∂g_∂X))
ERROR: Enzyme execution failed.
Enzyme cannot deduce type
Current scope:
; Function Attrs: mustprogress willreturn
define double @preprocess_julia_g_6909_inner.1({ {} addrspace(10)*, i32, i64 } %0, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %1) local_unnamed_addr #36 !dbg !971 {
entry:
  %2 = alloca { {} addrspace(10)*, i32, i64 }, align 8, !dbg !972, !enzyme_type !508
  %3 = addrspacecast { {} addrspace(10)*, i32, i64 }* %2 to { {} addrspace(10)*, i32, i64 } addrspace(11)*, !dbg !972
  %.fca.0.extract = extractvalue { {} addrspace(10)*, i32, i64 } %0, 0, !dbg !972
  %.fca.0.gep = getelementptr inbounds { {} addrspace(10)*, i32, i64 }, { {} addrspace(10)*, i32, i64 }* %2, i64 0, i32 0, !dbg !972
  store {} addrspace(10)* %.fca.0.extract, {} addrspace(10)** %.fca.0.gep, align 8, !dbg !972, !noalias !973
  %.fca.1.extract = extractvalue { {} addrspace(10)*, i32, i64 } %0, 1, !dbg !972
  %.fca.1.gep = getelementptr inbounds { {} addrspace(10)*, i32, i64 }, { {} addrspace(10)*, i32, i64 }* %2, i64 0, i32 1, !dbg !972
  store i32 %.fca.1.extract, i32* %.fca.1.gep, align 8, !dbg !972, !noalias !973
  %.fca.2.extract = extractvalue { {} addrspace(10)*, i32, i64 } %0, 2, !dbg !972
  %.fca.2.gep = getelementptr inbounds { {} addrspace(10)*, i32, i64 }, { {} addrspace(10)*, i32, i64 }* %2, i64 0, i32 2, !dbg !972
  store i64 %.fca.2.extract, i64* %.fca.2.gep, align 8, !dbg !972, !noalias !973
  %4 = call {}*** @julia.get_pgcstack() #47
  %ptls_field.i3 = getelementptr inbounds {}**, {}*** %4, i64 2
  %5 = bitcast {}*** %ptls_field.i3 to i64***
  %ptls_load.i45 = load i64**, i64*** %5, align 8, !tbaa !46
  %6 = getelementptr inbounds i64*, i64** %ptls_load.i45, i64 2
  %safepoint.i = load i64*, i64** %6, align 8, !tbaa !50, !invariant.load !45
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint.i) #47, !dbg !976
  fence syncscope("singlethread") seq_cst
  %7 = call nonnull {} addrspace(10)* @julia___6921({ {} addrspace(10)*, i32, i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(24) %3, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %1) #47, !dbg !976
  %8 = call fastcc double @julia__mapreduce_6913({} addrspace(10)* nocapture nofree noundef nonnull readonly align 16 dereferenceable(40) %7) #47, !dbg !978
  ret double %8, !dbg !972
}

 Type analysis state:
<analysis>
{ {} addrspace(10)*, i32, i64 } %0: {[0]:Pointer, [0,0]:Pointer, [0,0,-1]:Float@double, [0,8]:Integer, [0,9]:Integer, [0,10]:Integer, [0,11]:Integer, [0,12]:Integer, [0,13]:Integer, [0,14]:Integer, [0,15]:Integer, [0,16]:Integer, [0,17]:Integer, [0,18]:Integer, [0,19]:Integer, [0,20]:Integer, [0,21]:Integer, [0,22]:Integer, [0,23]:Integer, [0,24]:Integer, [0,25]:Integer, [0,26]:Integer, [0,27]:Integer, [0,28]:Integer, [0,29]:Integer, [0,30]:Integer, [0,31]:Integer, [0,32]:Integer, [0,33]:Integer, [0,34]:Integer, [0,35]:Integer, [0,36]:Integer, [0,37]:Integer, [0,38]:Integer, [0,39]:Integer, [16]:Integer, [17]:Integer, [18]:Integer, [19]:Integer, [20]:Integer, [21]:Integer, [22]:Integer, [23]:Integer}, intvals: {}
{} addrspace(10)* %1: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@double, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}, intvals: {}
  %4 = call {}*** @julia.get_pgcstack() #47: {[-1]:Pointer, [-1,16]:Pointer}, intvals: {}
  %2 = alloca { {} addrspace(10)*, i32, i64 }, align 8, !dbg !46, !enzyme_type !47: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Pointer, [-1,0,0,-1]:Float@double, [-1,0,8]:Integer, [-1,0,9]:Integer, [-1,0,10]:Integer, [-1,0,11]:Integer, [-1,0,12]:Integer, [-1,0,13]:Integer, [-1,0,14]:Integer, [-1,0,15]:Integer, [-1,0,16]:Integer, [-1,0,17]:Integer, [-1,0,18]:Integer, [-1,0,19]:Integer, [-1,0,20]:Integer, [-1,0,21]:Integer, [-1,0,22]:Integer, [-1,0,23]:Integer, [-1,0,24]:Integer, [-1,0,25]:Integer, [-1,0,26]:Integer, [-1,0,27]:Integer, [-1,0,28]:Integer, [-1,0,29]:Integer, [-1,0,30]:Integer, [-1,0,31]:Integer, [-1,0,32]:Integer, [-1,0,33]:Integer, [-1,0,34]:Integer, [-1,0,35]:Integer, [-1,0,36]:Integer, [-1,0,37]:Integer, [-1,0,38]:Integer, [-1,0,39]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer}, intvals: {}
  %ptls_field.i3 = getelementptr inbounds {}**, {}*** %4, i64 2: {[-1]:Pointer, [-1,0]:Pointer}, intvals: {}
  %6 = getelementptr inbounds i64*, i64** %ptls_load.i45, i64 2: {[-1]:Pointer}, intvals: {}
  %8 = call fastcc double @julia__mapreduce_6913({} addrspace(10)* nocapture nofree noundef nonnull readonly align 16 dereferenceable(40) %7) #47, !dbg !64: {[-1]:Float@double}, intvals: {}
  %.fca.1.gep = getelementptr inbounds { {} addrspace(10)*, i32, i64 }, { {} addrspace(10)*, i32, i64 }* %2, i64 0, i32 1, !dbg !46: {[-1]:Pointer}, intvals: {}
  %7 = call nonnull {} addrspace(10)* @julia___6921({ {} addrspace(10)*, i32, i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(24) %3, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %1) #47, !dbg !62: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@double, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}, intvals: {}
  %.fca.0.gep = getelementptr inbounds { {} addrspace(10)*, i32, i64 }, { {} addrspace(10)*, i32, i64 }* %2, i64 0, i32 0, !dbg !46: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Pointer, [-1,0,0,-1]:Float@double, [-1,0,8]:Integer, [-1,0,9]:Integer, [-1,0,10]:Integer, [-1,0,11]:Integer, [-1,0,12]:Integer, [-1,0,13]:Integer, [-1,0,14]:Integer, [-1,0,15]:Integer, [-1,0,16]:Integer, [-1,0,17]:Integer, [-1,0,18]:Integer, [-1,0,19]:Integer, [-1,0,20]:Integer, [-1,0,21]:Integer, [-1,0,22]:Integer, [-1,0,23]:Integer, [-1,0,24]:Integer, [-1,0,25]:Integer, [-1,0,26]:Integer, [-1,0,27]:Integer, [-1,0,28]:Integer, [-1,0,29]:Integer, [-1,0,30]:Integer, [-1,0,31]:Integer, [-1,0,32]:Integer, [-1,0,33]:Integer, [-1,0,34]:Integer, [-1,0,35]:Integer, [-1,0,36]:Integer, [-1,0,37]:Integer, [-1,0,38]:Integer, [-1,0,39]:Integer}, intvals: {}
  %.fca.2.gep = getelementptr inbounds { {} addrspace(10)*, i32, i64 }, { {} addrspace(10)*, i32, i64 }* %2, i64 0, i32 2, !dbg !46: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}, intvals: {}
  %ptls_load.i45 = load i64**, i64*** %5, align 8, !tbaa !56: {[-1]:Pointer}, intvals: {}
  %safepoint.i = load i64*, i64** %6, align 8, !tbaa !60, !invariant.load !45: {}, intvals: {}
  %5 = bitcast {}*** %ptls_field.i3 to i64***: {[-1]:Pointer, [-1,0]:Pointer}, intvals: {}
  %3 = addrspacecast { {} addrspace(10)*, i32, i64 }* %2 to { {} addrspace(10)*, i32, i64 } addrspace(11)*, !dbg !46: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Pointer, [-1,0,0,-1]:Float@double, [-1,0,8]:Integer, [-1,0,9]:Integer, [-1,0,10]:Integer, [-1,0,11]:Integer, [-1,0,12]:Integer, [-1,0,13]:Integer, [-1,0,14]:Integer, [-1,0,15]:Integer, [-1,0,16]:Integer, [-1,0,17]:Integer, [-1,0,18]:Integer, [-1,0,19]:Integer, [-1,0,20]:Integer, [-1,0,21]:Integer, [-1,0,22]:Integer, [-1,0,23]:Integer, [-1,0,24]:Integer, [-1,0,25]:Integer, [-1,0,26]:Integer, [-1,0,27]:Integer, [-1,0,28]:Integer, [-1,0,29]:Integer, [-1,0,30]:Integer, [-1,0,31]:Integer, [-1,0,32]:Integer, [-1,0,33]:Integer, [-1,0,34]:Integer, [-1,0,35]:Integer, [-1,0,36]:Integer, [-1,0,37]:Integer, [-1,0,38]:Integer, [-1,0,39]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer}, intvals: {}
  %.fca.0.extract = extractvalue { {} addrspace(10)*, i32, i64 } %0, 0, !dbg !46: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@double, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}, intvals: {}
  %.fca.1.extract = extractvalue { {} addrspace(10)*, i32, i64 } %0, 1, !dbg !46: {}, intvals: {}
  %.fca.2.extract = extractvalue { {} addrspace(10)*, i32, i64 } %0, 2, !dbg !46: {[-1]:Integer}, intvals: {}
</analysis>

oldFunc: ; Function Attrs: mustprogress willreturn
define double @preprocess_julia_g_6909_inner.1({ {} addrspace(10)*, i32, i64 } %0, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %1) local_unnamed_addr #36 !dbg !971 {
entry:
  %2 = alloca { {} addrspace(10)*, i32, i64 }, align 8, !dbg !972, !enzyme_type !508
  %3 = addrspacecast { {} addrspace(10)*, i32, i64 }* %2 to { {} addrspace(10)*, i32, i64 } addrspace(11)*, !dbg !972
  %.fca.0.extract = extractvalue { {} addrspace(10)*, i32, i64 } %0, 0, !dbg !972
  %.fca.0.gep = getelementptr inbounds { {} addrspace(10)*, i32, i64 }, { {} addrspace(10)*, i32, i64 }* %2, i64 0, i32 0, !dbg !972
  store {} addrspace(10)* %.fca.0.extract, {} addrspace(10)** %.fca.0.gep, align 8, !dbg !972, !noalias !973
  %.fca.1.extract = extractvalue { {} addrspace(10)*, i32, i64 } %0, 1, !dbg !972
  %.fca.1.gep = getelementptr inbounds { {} addrspace(10)*, i32, i64 }, { {} addrspace(10)*, i32, i64 }* %2, i64 0, i32 1, !dbg !972
  store i32 %.fca.1.extract, i32* %.fca.1.gep, align 8, !dbg !972, !noalias !973
  %.fca.2.extract = extractvalue { {} addrspace(10)*, i32, i64 } %0, 2, !dbg !972
  %.fca.2.gep = getelementptr inbounds { {} addrspace(10)*, i32, i64 }, { {} addrspace(10)*, i32, i64 }* %2, i64 0, i32 2, !dbg !972
  store i64 %.fca.2.extract, i64* %.fca.2.gep, align 8, !dbg !972, !noalias !973
  %4 = call {}*** @julia.get_pgcstack() #47
  %ptls_field.i3 = getelementptr inbounds {}**, {}*** %4, i64 2
  %5 = bitcast {}*** %ptls_field.i3 to i64***
  %ptls_load.i45 = load i64**, i64*** %5, align 8, !tbaa !46
  %6 = getelementptr inbounds i64*, i64** %ptls_load.i45, i64 2
  %safepoint.i = load i64*, i64** %6, align 8, !tbaa !50, !invariant.load !45
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint.i) #47, !dbg !976
  fence syncscope("singlethread") seq_cst
  %7 = call nonnull {} addrspace(10)* @julia___6921({ {} addrspace(10)*, i32, i64 } addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(24) %3, {} addrspace(10)* noundef nonnull align 16 dereferenceable(40) %1) #47, !dbg !976
  %8 = call fastcc double @julia__mapreduce_6913({} addrspace(10)* nocapture nofree noundef nonnull readonly align 16 dereferenceable(40) %7) #47, !dbg !978
  ret double %8, !dbg !972
}

Cannot deduce adding type of: { {} addrspace(10)*, i32, i64 } %0
 + idxs {i32 1,}



Stacktrace:
 [1] throwerr(cstr::Cstring)
   @ Enzyme.Compiler ~/.julia/packages/Enzyme/nQNPC/src/compiler.jl:1317
 [2] g
   @ ./REPL[55]:0 [inlined]
 [3] diffejulia_g_6909_inner_1wrap
   @ ./REPL[55]:0
 [4] macro expansion
   @ Enzyme.Compiler ~/.julia/packages/Enzyme/nQNPC/src/compiler.jl:5306 [inlined]
 [5] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Duplicated{…}, ::Float64)
   @ Enzyme.Compiler ~/.julia/packages/Enzyme/nQNPC/src/compiler.jl:4984
 [6] (::Enzyme.Compiler.CombinedAdjointThunk{Ptr{…}, Const{…}, Active{…}, Tuple{…}, Val{…}, Val{…}()})(::Const{typeof(g)}, ::Duplicated{Cholesky{…}}, ::Vararg{Any})
   @ Enzyme.Compiler ~/.julia/packages/Enzyme/nQNPC/src/compiler.jl:4926
 [7] autodiff(::ReverseMode{false, FFIABI}, ::Const{typeof(g)}, ::Type{Active}, ::Duplicated{Cholesky{Float64, Matrix{Float64}}}, ::Vararg{Any})
   @ Enzyme ~/.julia/packages/Enzyme/nQNPC/src/Enzyme.jl:0
 [8] autodiff(::ReverseMode{false, FFIABI}, ::typeof(g), ::Type, ::Duplicated{Cholesky{Float64, Matrix{Float64}}}, ::Vararg{Any})
   @ Enzyme ~/.julia/packages/Enzyme/nQNPC/src/Enzyme.jl:224
 [9] top-level scope
   @ REPL[60]:1
Some type information was truncated. Use `show(err)` to see complete types.

It seems Cholesky is only supported when it is constructed inside the differentiated function. Continuing with the example:

julia> g2(A, X) = g(cholesky(A * A' + I), X)
g2 (generic function with 1 method)

julia> A = rand(2, 2);

julia> ∂g2_∂A = zero(A);

julia> ∂g2_∂X = zero(X);

julia> autodiff(Reverse, g2, Active, Duplicated(A, ∂g2_∂A), Duplicated(X, ∂g2_∂X))
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
warning: didn't implement memmove, using memcpy as fallback which can result in errors
((nothing, nothing),)

julia> ∂g2_∂A
2×2 Matrix{Float64}:
 0.0469133  -0.17766
 0.0138776   0.317032

julia> ∂g2_∂X
2×2 Matrix{Float64}:
 0.986821  0.986821
 0.936682  0.936682
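As a sanity check on the output above, the gradient with respect to X has a closed form that needs no Enzyme at all. With S = A * A' + I, g2 is linear in X, and its gradient is S' \ ones(size(X)) = S \ ones(size(X)) because S is symmetric. Every column is therefore the same vector, which matches the repeated columns printed above. The sketch below is a minimal check under stated assumptions: it uses fixed, hypothetical values of A and X (the session above used rand) and compares the closed form with central finite differences.

```julia
using LinearAlgebra

g(C, X) = sum(C \ X)
g2(A, X) = g(cholesky(A * A' + I), X)

# Hypothetical fixed inputs; the REPL session above used rand(2, 2).
A = [0.3 0.7; 0.2 0.5]
X = [0.1 0.4; 0.9 0.6]

# Closed form: d/dX sum(S \ X) = S' \ ones = S \ ones, since S is symmetric.
S = A * A' + I
dX = S \ ones(size(X))

# Central finite differences in each entry of X as an independent check.
h = 1e-6
dX_fd = similar(X)
for i in axes(X, 1), j in axes(X, 2)
    E = zeros(size(X))
    E[i, j] = h
    dX_fd[i, j] = (g2(A, X + E) - g2(A, X - E)) / (2h)
end

@assert isapprox(dX, dX_fd; rtol = 1e-6)
@assert dX[:, 1] ≈ dX[:, 2]   # identical columns, as in the Enzyme output above
```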

Error with Const{<:Cholesky}

Edit: Fixed by #1251

julia> using Enzyme, LinearAlgebra

julia> C = cholesky((x -> x * x')(rand(2, 2)) + I)
Cholesky{Float64, Matrix{Float64}}
U factor:
2×2 UpperTriangular{Float64, Matrix{Float64}}:
 1.0538  0.0961016
        1.05759

julia> X = rand(2, 3)
2×3 Matrix{Float64}:
 0.639257  0.419136  0.202028
 0.236303  0.18246   0.0172878

julia> g(C, X) = sum(C \ X)
g (generic function with 1 method)

julia> X = rand(2, 2)
2×2 Matrix{Float64}:
 0.474658  0.631187
 0.949853  0.203402

julia> g(C, X)
1.8509308630411507

julia> ∂g_∂X = zero(X)
2×2 Matrix{Float64}:
 0.0  0.0
 0.0  0.0

julia> autodiff(Reverse, g, Active, Const(C), Duplicated(X, ∂g_∂X))
ERROR: type Const has no field dval
Stacktrace:
  [1] getproperty
    @ ./Base.jl:37 [inlined]
  [2] reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{…}, func::Const{…}, dret::Type, cache::Tuple{…}, fact::Const{…}, B::Duplicated{…}; kwargs::@Kwargs{})
    @ Enzyme ~/.julia/packages/Enzyme/nQNPC/src/internal_rules.jl:805
  [3] reverse
    @ ~/.julia/packages/Enzyme/nQNPC/src/internal_rules.jl:791 [inlined]
  [4] g
    @ ./REPL[32]:1 [inlined]
  [5] g
    @ ./REPL[32]:0 [inlined]
  [6] diffejulia_g_6425_inner_1wrap
    @ ./REPL[32]:0
  [7] macro expansion
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nQNPC/src/compiler.jl:5306 [inlined]
  [8] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Const{…}, ::Duplicated{…}, ::Float64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nQNPC/src/compiler.jl:4984
  [9] (::Enzyme.Compiler.CombinedAdjointThunk{Ptr{Nothing}, Const{typeof(g)}, Active{Float64}, Tuple{Const{…}, Duplicated{…}}, Val{1}, Val{false}()})(::Const{typeof(g)}, ::Const{Cholesky{Float64, Matrix{…}}}, ::Vararg{Any})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nQNPC/src/compiler.jl:4926
 [10] autodiff(::ReverseMode{false, FFIABI}, ::Const{typeof(g)}, ::Type{Active}, ::Const{Cholesky{Float64, Matrix{Float64}}}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/nQNPC/src/Enzyme.jl:0
 [11] autodiff(::ReverseMode{false, FFIABI}, ::typeof(g), ::Type, ::Const{Cholesky{Float64, Matrix{Float64}}}, ::Vararg{Any})
    @ Enzyme ~/.julia/packages/Enzyme/nQNPC/src/Enzyme.jl:224
 [12] top-level scope
    @ REPL[40]:1
Some type information was truncated. Use `show(err)` to see complete types.

Error with non-square matrices

Edit: Fixed by #1251

Based on the only working example above (g2):

julia> using Enzyme, LinearAlgebra

julia> g(C, X) = sum(C \ X)
g (generic function with 1 method)

julia> g2(A, X) = g(cholesky(A * A' + I), X)
g2 (generic function with 1 method)

julia> A = rand(2, 2)
2×2 Matrix{Float64}:
 0.6714    0.765626
 0.115545  0.08158

julia> X = rand(2, 3)
2×3 Matrix{Float64}:
 0.841836   0.62481   0.296653
 0.0793566  0.235749  0.0785191

julia> g2(A, X)
1.1167203142840973

julia> ∂g2_∂A = zero(A);

julia> ∂g2_∂X = zero(X);

julia> autodiff(Reverse, g2, Active, Duplicated(A, ∂g2_∂A), Duplicated(X, ∂g2_∂X))
┌ Warning: Using fallback BLAS replacements, performance may be degraded
└ @ Enzyme.Compiler ~/.julia/packages/GPUCompiler/U36Ed/src/utils.jl:59
warning: didn't implement memmove, using memcpy as fallback which can result in errors
ERROR: DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 2 and 3
Stacktrace:
  [1] _bcs1
    @ ./broadcast.jl:555 [inlined]
  [2] _bcs
    @ ./broadcast.jl:549 [inlined]
  [3] broadcast_shape
    @ ./broadcast.jl:543 [inlined]
  [4] combine_axes
    @ ./broadcast.jl:524 [inlined]
  [5] _axes
    @ ./broadcast.jl:236 [inlined]
  [6] axes
    @ ./broadcast.jl:234 [inlined]
  [7] check_broadcast_axes
    @ ./broadcast.jl:582 [inlined]
  [8] check_broadcast_axes
    @ ./broadcast.jl:586 [inlined]
  [9] instantiate
    @ ./broadcast.jl:309 [inlined]
 [10] materialize!
    @ ./broadcast.jl:914 [inlined]
 [11] materialize!
    @ ./broadcast.jl:911 [inlined]
 [12] reverse(config::EnzymeCore.EnzymeRules.ConfigWidth{…}, func::Const{…}, dret::Type, cache::Tuple{…}, fact::Duplicated{…}, B::Duplicated{…}; kwargs::@Kwargs{})
    @ Enzyme ~/.julia/packages/Enzyme/nQNPC/src/internal_rules.jl:812
 [13] reverse
    @ ~/.julia/packages/Enzyme/nQNPC/src/internal_rules.jl:791 [inlined]
 [14] g
    @ ./REPL[74]:1 [inlined]
 [15] g2
    @ ./REPL[75]:1 [inlined]
 [16] diffejulia_g2_7886wrap
    @ ./REPL[75]:0
 [17] macro expansion
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nQNPC/src/compiler.jl:5306 [inlined]
 [18] enzyme_call(::Val{…}, ::Ptr{…}, ::Type{…}, ::Type{…}, ::Val{…}, ::Type{…}, ::Type{…}, ::Const{…}, ::Type{…}, ::Duplicated{…}, ::Duplicated{…}, ::Float64)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nQNPC/src/compiler.jl:4984
 [19] (::Enzyme.Compiler.CombinedAdjointThunk{Ptr{Nothing}, Const{typeof(g2)}, Active{Float64}, Tuple{Duplicated{…}, Duplicated{…}}, Val{1}, Val{false}()})(::Const{typeof(g2)}, ::Duplicated{Matrix{Float64}}, ::Vararg{Any})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/nQNPC/src/compiler.jl:4926
 [20] autodiff(::ReverseMode{false, FFIABI}, ::Const{typeof(g2)}, ::Type{Active}, ::Duplicated{Matrix{Float64}}, ::Vararg{Duplicated{Matrix{Float64}}})
    @ Enzyme ~/.julia/packages/Enzyme/nQNPC/src/Enzyme.jl:0
 [21] autodiff(::ReverseMode{false, FFIABI}, ::typeof(g2), ::Type, ::Duplicated{Matrix{Float64}}, ::Vararg{Duplicated{Matrix{Float64}}})
    @ Enzyme ~/.julia/packages/Enzyme/nQNPC/src/Enzyme.jl:224
 [22] top-level scope
    @ REPL[81]:1
Some type information was truncated. Use `show(err)` to see complete types.
@wsmoses
Member

wsmoses commented Jan 26, 2024

@michel2323

@devmotion
Contributor Author

As an additional comment: the issues with Const{<:Cholesky} and non-square matrices seem to be bugs in

dfacts = EnzymeRules.width(config) == 1 ? (fact.dval,) : fact.dval

and

dfact.factors .+= -x .* buffer

respectively. However, I'm not completely sure how to view Duplicated{<:Cholesky}. A Duplicated whose second argument is not even a valid instance (a Cholesky can only be constructed with special constructors in LinearAlgebra) seems off. Maybe it should just not be supported? In that case, however, I think it would be good to throw an error directly when someone tries to construct a Duplicated{<:Cholesky}.
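The shape problem in the second quoted line can be illustrated without Enzyme. For Y = A \ X with a square n×n matrix A and an n×m right-hand side X, the standard reverse-mode rule is Xbar = A' \ Ybar and Abar = -Xbar * Y'. This is an (n×m)·(m×n) matrix product, which is n×n for any m. If, as the DimensionMismatch (lengths 2 and 3) suggests, x and buffer in that line are n×m arrays, their elementwise product cannot be added to the n×n factors unless m = n. The sketch below checks the formula for a 2×3 right-hand side using only LinearAlgebra. The names grad_A, Xbar, and Ybar are hypothetical, and the gradient is taken with respect to the full matrix A, not Enzyme's internal representation in dfact.factors, so treat it as an illustration rather than the actual fix.

```julia
using LinearAlgebra

# Objective: square A (n×n), possibly non-square X (n×m).
f(A, X) = sum(A \ X)

# Reverse-mode rule for Y = A \ X with seed Ybar = ∂f/∂Y:
#   Xbar = A' \ Ybar   (n×m)
#   Abar = -Xbar * Y'  (n×m)*(m×n) = n×n, a matrix product, not `.*`
function grad_A(A, X)
    Y = A \ X
    Ybar = ones(size(Y))       # ∂sum(Y)/∂Y
    Xbar = A' \ Ybar
    return -Xbar * Y'
end

A = [2.0 0.5; 0.5 3.0] + I     # hypothetical symmetric positive definite test matrix
X = [1.0 2.0 3.0; 4.0 5.0 6.0] # 2×3, i.e. non-square

@assert cholesky(A) \ X ≈ A \ X   # same solve as in g(C, X)

G = grad_A(A, X)

# Central finite differences over each entry of A.
h = 1e-6
G_fd = similar(A)
for i in axes(A, 1), j in axes(A, 2)
    E = zeros(size(A))
    E[i, j] = h
    G_fd[i, j] = (f(A + E, X) - f(A - E, X)) / (2h)
end

@assert size(G) == size(A)
@assert isapprox(G, G_fd; rtol = 1e-5)
```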

@wsmoses
Member

wsmoses commented Jan 26, 2024

A Duplicated variable is reasonable, but its shadow doesn't necessarily correspond to a valid Cholesky decomposition.

For an arbitrary data structure (including Cholesky), the shadow of an element (i.e., the corresponding element of the Duplicated shadow) represents the current partial derivative of the objective with respect to that element. If you view the data structure as just a list of numbers, each shadow entry is simply the partial derivative with respect to the corresponding number.
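To make the "list of numbers" view concrete, here is a minimal sketch built around a hypothetical mutable struct Params that does not appear in the issue. After a reverse pass, each field of the shadow dp holds the partial derivative of the objective with respect to the matching field of p. Nothing requires dp to be a meaningful Params value in its own right, just as a Cholesky shadow need not be a valid factorization. The sketch assumes an Enzyme version in which autodiff(Reverse, f, Active, Duplicated(p, dp)) accumulates into the fields of a mutable struct shadow.

```julia
using Enzyme

# Hypothetical data structure: a weight vector and a scalar bias.
mutable struct Params
    w::Vector{Float64}
    b::Float64
end

# Objective f(p) = b * (w[1] + 2 * w[2]).
f(p::Params) = p.b * (p.w[1] + 2 * p.w[2])

p  = Params([1.0, 2.0], 3.0)
dp = Params([0.0, 0.0], 0.0)   # shadow with the same layout, zero-initialised

autodiff(Reverse, f, Active, Duplicated(p, dp))

# Each shadow entry is the partial derivative w.r.t. the matching entry of p:
#   ∂f/∂w[1] = b = 3, ∂f/∂w[2] = 2b = 6, ∂f/∂b = w[1] + 2w[2] = 5
@assert dp.w ≈ [3.0, 6.0]
@assert dp.b ≈ 5.0
```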

@wsmoses
Member

wsmoses commented Jan 26, 2024

The const issue should be a relatively simple fix.

Basically, if fact is constant, you should set dfacts to ntuple(_ -> nothing, width), and then, where dfact is updated in the for loop, add an if statement that checks that fact is not constant.

Would you like to try adding that fix?
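The pattern described above can be sketched outside of Enzyme's actual internal_rules.jl. The hypothetical helper accumulate_shadow! builds the tuple of per-lane shadows and substitutes nothing for each lane when the argument is Const, which has no dval field and is therefore the source of the "type Const has no field dval" error. It then guards the in-place update with an isa Const check. Only the Const and Duplicated annotation types come from EnzymeCore. The helper's name and signature are invented for illustration, and width is passed as a plain integer instead of being read from a rule config.

```julia
using EnzymeCore: Const, Duplicated

# Hypothetical helper (not Enzyme's actual rule): add `update` into the shadow(s)
# of `fact`, skipping the accumulation when `fact` is Const.
function accumulate_shadow!(fact, width::Int, update::AbstractMatrix)
    # A Const argument has no `dval` field, so use `nothing` for every lane.
    dfacts = if fact isa Const
        ntuple(_ -> nothing, width)
    elseif width == 1
        (fact.dval,)
    else
        fact.dval
    end
    for dfact in dfacts
        if !(fact isa Const)
            dfact .+= update
        end
    end
    return nothing
end

x  = [1.0 2.0; 3.0 4.0]
dx = zeros(2, 2)
accumulate_shadow!(Duplicated(x, dx), 1, ones(2, 2))  # dx is now all ones
accumulate_shadow!(Const(x), 1, ones(2, 2))           # no error, dx untouched
```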

@wsmoses
Copy link
Member

wsmoses commented Jan 27, 2024

@devmotion I fixed the Const issue in #1251. Possibly it also fixes the non-square part?

In any case, can you investigate and mark this resolved and/or split the remaining issues into distinct GH issues?

@devmotion devmotion changed the title Cholesky support broken: Bugs with Const{<:Cholesky}, Duplicated{<:Cholesky}, and non-square matrices Error with Duplicated{<:Cholesky} Jan 27, 2024
@wsmoses
Member

wsmoses commented Jan 28, 2024

The type issue should also be fixed by #1256

@devmotion, for the future, can you file each such problem as a separate issue?

@devmotion
Contributor Author

Sure, will do! Unfortunately, I did not have much time this weekend, so I have not yet opened issues for all the problems I noticed (such as #1254, and a few that already seem to be fixed by #1256).
