Enzyme compilation failed for cmpxchg #655

Closed
jeremiedb opened this issue Feb 23, 2023 · 16 comments

@jeremiedb

jeremiedb commented Feb 23, 2023

The following Julia code attempts to differentiate a convolution from NNlib but fails:

using Enzyme
Enzyme.API.runtimeActivity!(true)
using NNlib

w = randn(Float32, 3, 3, 5, 7);
dw = zero(w)
loss(w, x) = sum(conv(x, w))
x = randn(Float32, (3, 3, 5, 8));
grads = Enzyme.autodiff(Reverse, loss, Duplicated(w, dw), Const(x));

This results in the following error trace (truncated for readability):

┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler C:\Users\jerem\.julia\packages\GPUCompiler\kb6yJ\src\utils.jl:50
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler C:\Users\jerem\.julia\packages\GPUCompiler\kb6yJ\src\utils.jl:50
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler C:\Users\jerem\.julia\packages\GPUCompiler\kb6yJ\src\utils.jl:50
ERROR: Enzyme compilation failed.
Current scope:
; Function Attrs: mustprogress noinline uwtable willreturn
define internal fastcc noundef i8 @preprocess_julia__trylock_7199({} addrspace(10)* nonnull align 8 dereferenceable(56) %0, {} addrspace(10)* nofree nonnull align 8 dereferenceable(96) %1) unnamed_addr #119 !dbg !6737 {
top:
  %2 = call {}*** @julia.get_pgcstack() #110
  %ptls_field8 = getelementptr inbounds {}**, {}*** %2, i64 2, !dbg !6738

...

in Mode: ReverseModePrimal
cannot handle unknown instruction
  %10 = cmpxchg i8 addrspace(11)* %9, i8 0, i8 1 acquire acquire, align 4, !dbg !142, !tbaa !146
@vtjnash

vtjnash commented Feb 24, 2023

Should `trylock` also have a rule (that appears to be where this comes from)?

@wsmoses wsmoses transferred this issue from EnzymeAD/Enzyme Mar 6, 2023
@wsmoses
Member

wsmoses commented Mar 6, 2023

It's likely reasonable to mark it inactive, but I'd want to see what's calling `trylock`, since it might make more sense to mark the outer code that calls `trylock` as inactive.

@jeremiedb, can you see where it's being called?
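For context, the rejected `cmpxchg` looks like the lock-acquire step of a `ReentrantLock`: the trace below points at `_trylock` (lock.jl:82) via `replaceproperty!`, and the IR does a one-byte compare-and-swap 0 => 1 with acquire ordering, then stores a reentrancy count and the owning task. The sketch below reproduces only that acquire pattern with a hypothetical `ToyLock` type (not Base's implementation; requires Julia 1.7+ for `@atomic` fields). It is meant to show which Julia construct produces the instruction, not how Enzyme should handle it.

```julia
# Hypothetical toy lock mirroring the acquire step seen in the IR;
# it is NOT Base.ReentrantLock.
mutable struct ToyLock
    @atomic havelock::UInt8  # 0x00 = free, 0x01 = held
end

ToyLock() = ToyLock(0x00)

# Compare-and-swap the flag 0x00 => 0x01 with acquire ordering.
# @atomicreplace returns (old, success); in LLVM this becomes a `cmpxchg`.
function toy_trylock(l::ToyLock)
    _, success = @atomicreplace :acquire :acquire l.havelock 0x00 => 0x01
    return success
end

# Release the flag with release ordering.
function toy_unlock(l::ToyLock)
    @atomic :release l.havelock = 0x00
    return nothing
end

l = ToyLock()
@show toy_trylock(l)   # acquires the free lock
@show toy_trylock(l)   # fails while the lock is held
toy_unlock(l)
@show toy_trylock(l)   # succeeds again after release
```

Running `@code_llvm toy_trylock(ToyLock())` should show a `cmpxchg ... acquire acquire` similar to the one in the error, which may make a convenient minimal reproducer for the unsupported instruction.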

@jeremiedb
Author

jeremiedb commented Mar 7, 2023

The complete error trace for the above conv call is:

julia> grads = Enzyme.autodiff(Reverse, loss, Duplicated(w, dw), Const(x));
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler C:\Users\jerem\.julia\packages\GPUCompiler\kb6yJ\src\utils.jl:50
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler C:\Users\jerem\.julia\packages\GPUCompiler\kb6yJ\src\utils.jl:50
┌ Warning: active variables passed by value to jl_new_task are not yet supported
└ @ Enzyme.Compiler C:\Users\jerem\.julia\packages\GPUCompiler\kb6yJ\src\utils.jl:50
ERROR: Enzyme compilation failed.
Current scope:
; Function Attrs: mustprogress noinline uwtable willreturn
define internal fastcc noundef i8 @preprocess_julia__trylock_11344({} addrspace(10)* nonnull align 8 dereferenceable(56) %0, {} addrspace(10)* nofree nonnull align 8 dereferenceable(96) %1) unnamed_addr #127 !dbg !6627 {
top:
  %2 = call {}*** @julia.get_pgcstack() #128
  %ptls_field8 = getelementptr inbounds {}**, {}*** %2, i64 2, !dbg !6628
  %3 = bitcast {}*** %ptls_field8 to i32**, !dbg !6628
  %ptls_load910 = load i32*, i32** %3, align 8, !dbg !6628, !tbaa !2355
  %4 = getelementptr inbounds i32, i32* %ptls_load910, i64 8, !dbg !6628
  %5 = load i32, i32* %4, align 4, !dbg !6628
  %6 = add i32 %5, 1, !dbg !6628
  store i32 %6, i32* %4, align 4, !dbg !6628
  %7 = bitcast {} addrspace(10)* %0 to i8 addrspace(10)*, !dbg !6630
  %8 = addrspacecast i8 addrspace(10)* %7 to i8 addrspace(11)*, !dbg !6630
  %9 = getelementptr inbounds i8, i8 addrspace(11)* %8, i64 12, !dbg !6630
  %10 = cmpxchg i8 addrspace(11)* %9, i8 0, i8 1 acquire acquire, align 4, !dbg !6630, !tbaa !266
  %11 = extractvalue { i8, i1 } %10, 1, !dbg !6630
  br i1 %11, label %L6, label %L9, !dbg !6631

common.ret:                                       ; preds = %L16, %L9, %L6
  %common.ret.op = phi i8 [ 1, %L6 ], [ 0, %L9 ], [ 0, %L16 ]
  ret i8 %common.ret.op, !dbg !6632

L6:                                               ; preds = %top
  %12 = getelementptr inbounds i8, i8 addrspace(11)* %8, i64 8, !dbg !6633
  %13 = bitcast i8 addrspace(11)* %12 to i32 addrspace(11)*, !dbg !6633
  store i32 1, i32 addrspace(11)* %13, align 8, !dbg !6633, !tbaa !266
  %14 = bitcast {} addrspace(10)* %0 to {} addrspace(10)* addrspace(10)*, !dbg !6635
  store atomic {} addrspace(10)* %1, {} addrspace(10)* addrspace(10)* %14 release, align 8, !dbg !6635, !tbaa !266
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* nofree noundef nonnull %0, {} addrspace(10)* nofree nonnull %1) #129, !dbg !6635
  br label %common.ret

L9:                                               ; preds = %top
  %ptls_load41314 = load i32*, i32** %3, align 8, !dbg !6637, !tbaa !2355
  %15 = getelementptr inbounds i32, i32* %ptls_load41314, i64 8, !dbg !6637
  %16 = load i32, i32* %15, align 4, !dbg !6637
  %17 = add i32 %16, -1, !dbg !6637
  %18 = icmp eq i32 %16, 0, !dbg !6637
  %19 = select i1 %18, i32 0, i32 %17, !dbg !6637
  store i32 %19, i32* %15, align 4, !dbg !6637
  %20 = load atomic i32, i32* inttoptr (i64 140730518437480 to i32*) monotonic, align 8, !dbg !6639, !tbaa !600
  %.not = icmp eq i32 %20, 0, !dbg !6640
  br i1 %.not, label %common.ret, label %L16, !dbg !6639

L16:                                              ; preds = %L9
  call void @jl_gc_run_pending_finalizers(i64 noundef 0) #128, !dbg !6643
  br label %common.ret, !dbg !6643
}

; Function Attrs: mustprogress noinline uwtable willreturn
define internal fastcc noundef { {} addrspace(10)*, i8 } @fakeaugmented_julia__trylock_11344({} addrspace(10)* nonnull align 8 dereferenceable(56) %0, {} addrspace(10)* %"'", {} addrspace(10)* nofree nonnull align 8 dereferenceable(96) %1) unnamed_addr #127 !dbg !6644 {
top:
  %2 = call {}*** @julia.get_pgcstack() #128
  %ptls_field8 = getelementptr inbounds {}**, {}*** %2, i64 2, !dbg !6645
  %3 = bitcast {}*** %ptls_field8 to i32**, !dbg !6645
  %ptls_load910 = load i32*, i32** %3, align 8, !dbg !6645, !tbaa !2355
  %"ptls_load910'il_phi" = phi i32* , !dbg !6645
  %4 = getelementptr inbounds i32, i32* %ptls_load910, i64 8, !dbg !6645
  %5 = load i32, i32* %4, align 4, !dbg !6645
  %"'il_phi" = phi i32 , !dbg !6645
  %6 = add i32 %5, 1, !dbg !6645
  store i32 %6, i32* %4, align 4, !dbg !6645
  %7 = bitcast {} addrspace(10)* %0 to i8 addrspace(10)*, !dbg !6647
  %8 = addrspacecast i8 addrspace(10)* %7 to i8 addrspace(11)*, !dbg !6647
  %9 = getelementptr inbounds i8, i8 addrspace(11)* %8, i64 12, !dbg !6647
  %10 = cmpxchg i8 addrspace(11)* %9, i8 0, i8 1 acquire acquire, align 4, !dbg !6647, !tbaa !266
  %11 = extractvalue { i8, i1 } %10, 1, !dbg !6647
  br i1 %11, label %L6, label %L9, !dbg !6648

common.ret:                                       ; preds = %L16, %L9, %L6
  %common.ret.op = phi i8 [ 1, %L6 ], [ 0, %L9 ], [ 0, %L16 ]
  %12 = insertvalue { {} addrspace(10)*, i8 } undef, i8 %common.ret.op, 1, !dbg !6649
  ret { {} addrspace(10)*, i8 } %12, !dbg !6649

L6:                                               ; preds = %top
  %13 = getelementptr inbounds i8, i8 addrspace(11)* %8, i64 8, !dbg !6650
  %14 = bitcast i8 addrspace(11)* %13 to i32 addrspace(11)*, !dbg !6650
  store i32 1, i32 addrspace(11)* %14, align 8, !dbg !6650, !tbaa !266
  %15 = bitcast {} addrspace(10)* %0 to {} addrspace(10)* addrspace(10)*, !dbg !6652
  store atomic {} addrspace(10)* %1, {} addrspace(10)* addrspace(10)* %15 release, align 8, !dbg !6652, !tbaa !266
  call void ({} addrspace(10)*, ...) @julia.write_barrier({} addrspace(10)* nofree noundef nonnull %0, {} addrspace(10)* nofree nonnull %1) #129, !dbg !6652
  br label %common.ret

L9:                                               ; preds = %top
  %ptls_load41314 = load i32*, i32** %3, align 8, !dbg !6654, !tbaa !2355
  %"ptls_load41314'il_phi" = phi i32* , !dbg !6654
  %16 = getelementptr inbounds i32, i32* %ptls_load41314, i64 8, !dbg !6654
  %17 = load i32, i32* %16, align 4, !dbg !6654
  %"'il_phi1" = phi i32 , !dbg !6654
  %18 = add i32 %17, -1, !dbg !6654
  %19 = icmp eq i32 %17, 0, !dbg !6654
  %20 = select i1 %19, i32 0, i32 %18, !dbg !6654
  store i32 %20, i32* %16, align 4, !dbg !6654
  %21 = load atomic i32, i32* inttoptr (i64 140730518437480 to i32*) monotonic, align 8, !dbg !6656, !tbaa !600
  %"'il_phi2" = phi i32 , !dbg !6657
  %.not = icmp eq i32 %21, 0, !dbg !6657
  br i1 %.not, label %common.ret, label %L16, !dbg !6656

L16:                                              ; preds = %L9
  call void @jl_gc_run_pending_finalizers(i64 noundef 0) #128, !dbg !6660
  br label %common.ret, !dbg !6660

allocsForInversion:                               ; No predecessors!
}

in Mode: ReverseModePrimal
cannot handle unknown instruction
  %10 = cmpxchg i8 addrspace(11)* %9, i8 0, i8 1 acquire acquire, align 4, !dbg !143, !tbaa !147

Stacktrace:
 [1] replaceproperty!
   @ .\Base.jl:58
 [2] _trylock
   @ .\lock.jl:82

Stacktrace:
  [1] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:4735
  [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
    @ Enzyme.API C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\api.jl:123
  [3] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Array{Float32, 4}, Array{Float32, 4}}}}, mod::LLVM.Module, primalf::LLVM.Function, adjoint::GPUCompiler.FunctionSpec{typeof(loss), Tuple{Duplicated{Array{Float32, 4}}, Const{Array{Float32, 4}}}}, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, dupClosure::Bool, wrap::Bool, modifiedBetween::Tuple{Bool, Bool, Bool}, returnPrimal::Bool, jlrules::Vector{String})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:6195
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Array{Float32, 4}, Array{Float32, 4}}}}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.Context, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:7446
  [5] _thunk
    @ C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:7958 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Array{Float32, 4}, Array{Float32, 4}}}})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:7952
  [7] cached_compilation(job::GPUCompiler.CompilerJob, key::UInt64, specid::UInt64)
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:7996
  [8] #s451#163
    @ C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:8056 [inlined]
  [9] var"#s451#163"(F::Any, Fn::Any, DF::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, specid::Any, ReturnPrimal::Any, ShadowInit::Any, ::Any, #unused#::Type, f::Any, df::Any, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler .\none:0
 [10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core .\boot.jl:582
 [11] thunk
    @ C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:8089 [inlined]
 [12] thunk
    @ C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\compiler.jl:8082 [inlined]
 [13] autodiff
    @ C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\Enzyme.jl:197 [inlined]
 [14] autodiff(::EnzymeCore.ReverseMode{false, false}, ::typeof(loss), ::Duplicated{Array{Float32, 4}}, ::Const{Array{Float32, 4}})
    @ Enzyme C:\Users\jerem\.julia\packages\Enzyme\OWDZG\src\Enzyme.jl:223
 [15] top-level scope
    @ c:\Users\jerem\OneDrive\github\ADTests.jl\experiments\enzyme\conv.jl:18

And the LLVM IR for `conv(x, w)` is:

julia> @code_llvm conv(x, w)
;  @ C:\Users\jerem\.julia\packages\NNlib\ydqxJ\src\conv.jl:50 within `conv`
; Function Attrs: uwtable
define nonnull {}* @julia_conv_12575({}* nonnull align 16 dereferenceable(40) %0, {}* nonnull align 16 dereferenceable(40) %1) #0 {
top:
  %2 = alloca <4 x i64>, align 8
  %tmpcast = bitcast <4 x i64>* %2 to [4 x i64]*
  %3 = alloca <4 x i64>, align 8
  %tmpcast9 = bitcast <4 x i64>* %3 to [4 x i64]*
  %4 = alloca { [2 x i64], [2 x i64], i64, i64, i64, [2 x i64], [4 x i64], [2 x i64], i8 }, align 8
; ┌ @ C:\Users\jerem\.julia\packages\NNlib\ydqxJ\src\conv.jl:54 within `#conv#231`
; │┌ @ array.jl:153 within `size`
; ││┌ @ ntuple.jl:69 within `ntuple`
; │││┌ @ ntuple.jl:74 within `macro expansion`
; ││││┌ @ array.jl:153 within `#108`
; │││││┌ @ array.jl:150 within `size`
        %5 = bitcast {}* %0 to {}**
        %6 = getelementptr inbounds {}*, {}** %5, i64 3
        %7 = bitcast {}** %6 to <4 x i64>*
        %8 = load <4 x i64>, <4 x i64>* %7, align 8
; ││││└└
; ││││ @ ntuple.jl:75 within `macro expansion`
      store <4 x i64> %8, <4 x i64>* %2, align 8
; ││││ @ ntuple.jl:74 within `macro expansion`
; ││││┌ @ array.jl:153 within `#108`
; │││││┌ @ array.jl:150 within `size`
        %9 = bitcast {}* %1 to {}**
        %10 = getelementptr inbounds {}*, {}** %9, i64 3
        %11 = bitcast {}** %10 to <4 x i64>*
        %12 = load <4 x i64>, <4 x i64>* %11, align 8
; ││││└└
; ││││ @ ntuple.jl:75 within `macro expansion`
      store <4 x i64> %12, <4 x i64>* %3, align 8
; │└└└
; │┌ @ C:\Users\jerem\.julia\packages\NNlib\ydqxJ\src\dim_helpers\DenseConvDims.jl:20 within `Type##kw`
    call void @"j_#DenseConvDims#8_12577"({ [2 x i64], [2 x i64], i64, i64, i64, [2 x i64], [4 x i64], [2 x i64], i8 }* noalias nocapture nonnull sret({ [2 x i64], [2 x i64], i64, i64, i64, [2 x i64], [4 x i64], [2 x i64], i8 }) %4, [2 x i64]* nocapture readonly @_j_const1, [2 x i64]* nocapture readonly @_j_const2, [2 x i64]* nocapture readonly @_j_const1, i64 signext 1, i8 zeroext 0, {}* readonly inttoptr (i64 2862949472816 to {}*), [4 x i64]* nocapture readonly %tmpcast, [4 x i64]* nocapture readonly %tmpcast9) #0
; │└
; │ @ C:\Users\jerem\.julia\packages\NNlib\ydqxJ\src\conv.jl:56 within `#conv#231`
; │┌ @ C:\Users\jerem\.julia\packages\NNlib\ydqxJ\src\conv.jl:83 within `conv`
    %13 = call nonnull {}* @"j_#conv#233_12578"({}* nonnull %0, {}* nonnull %1, { [2 x i64], [2 x i64], i64, i64, i64, [2 x i64], [4 x i64], [2 x i64], i8 }* nocapture readonly %4) #0
; └└
  ret {}* %13
}

I've been looking at breaking down the NNlib.conv calls in https://github.com/jeremiedb/ADTests.jl/blob/main/experiments/enzyme/conv-debug.jl

It's still unclear to me which instruction is the source of the issue. Two potential candidates could be:

@wsmoses
Member

wsmoses commented Mar 7, 2023

I don't think it's the GC.@preserve, but the Threads.@sync might be it. @vtjnash @vchuravy?
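One way to test the `Threads.@sync` hypothesis is to compare a task-spawning loop against an equivalent serial loop. The sketch below assumes the parallel code has the usual per-batch shape; `kernel!`, `batched_parallel!` and `batched_serial!` are hypothetical names, and NNlib's actual implementation may be structured differently. Each `Threads.@spawn` allocates a new task (the `jl_new_task` in the warnings), and waiting on those tasks goes through the scheduler's synchronization, which is one plausible but unconfirmed route to a lock call.

```julia
# Hypothetical per-batch kernel standing in for the real per-sample work.
function kernel!(out, x, i)
    out[i] = sum(abs2, view(x, :, i))
    return nothing
end

# Parallel variant: one task per batch element, joined by @sync.
# Each Threads.@spawn creates a new Task.
function batched_parallel!(out, x)
    @sync for i in axes(x, 2)
        Threads.@spawn kernel!(out, x, i)
    end
    return out
end

# Serial variant: same work, no task creation or task synchronization.
function batched_serial!(out, x)
    for i in axes(x, 2)
        kernel!(out, x, i)
    end
    return out
end

x = reshape(collect(1.0:6.0), 2, 3)
out_par = batched_parallel!(zeros(3), x)
out_ser = batched_serial!(zeros(3), x)
```

If differentiating the serial variant succeeds while the parallel one reproduces the `jl_new_task` warnings or the `cmpxchg` error, that would point at the task machinery rather than the numerical kernel.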

@vchuravy
Member

vchuravy commented Mar 7, 2023

mutable struct Atomic{T}
    @atomic x::T
end

function f(x, y)
    @atomic x.x max y
    val = @atomic x.x
    val^2
end

@show f(Atomic(0.0), 1.0)

using Enzyme

x = Atomic(2.0)
dx = Atomic(0.0)

autodiff(Reverse, f, Active, Duplicated(x, dx), Active(1.0))

@jeremiedb
Author

For info, the `f` on `Atomic` above results in a different error, a segfault:

PS C:\Users\jerem\OneDrive\github\ADTests.jl> julia --project=@. --threads=1 .\experiments\enzyme\conv-debug.jl
f(Atomic(0.0), 3.0) = 9.0
module: ; ModuleID = 'text'
source_filename = "text"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:10:11:12:13"
target triple = "x86_64-w64-mingw32"

; Function Attrs: nofree nosync readnone uwtable
define internal fastcc double @julia_max_1932(double %0, double %1) unnamed_addr #0 !dbg !7 {
top:
  %2 = call {}*** @julia.get_pgcstack()
  %3 = fcmp olt double %0, %1, !dbg !9
  %4 = bitcast double %1 to i64, !dbg !16
  %5 = bitcast double %0 to i64, !dbg !16
  %.not = icmp sgt i64 %4, -1, !dbg !19
  %6 = icmp slt i64 %5, 0, !dbg !24
  %7 = and i1 %6, %.not, !dbg !24
  %8 = or i1 %3, %7, !dbg !26
  %9 = fcmp ord double %0, 0.000000e+00, !dbg !28
  %10 = select i1 %9, double %1, double %0, !dbg !32
  %11 = fcmp ord double %1, 0.000000e+00, !dbg !28
  %12 = select i1 %11, double %0, double %1, !dbg !32
  %13 = select i1 %8, double %10, double %12, !dbg !32
  ret double %13, !dbg !15
}
...
eval at .\boot.jl:368 [inlined]
include_string at .\loading.jl:1428
_include at .\loading.jl:1488
include at .\Base.jl:419
jfptr_include_51599.clone_1 at C:\Users\jerem\AppData\Local\Programs\Julia-1.8.5\lib\julia\sys.dll (unknown line)
exec_options at .\client.jl:303
_start at .\client.jl:522
jfptr__start_54247.clone_1 at C:\Users\jerem\AppData\Local\Programs\Julia-1.8.5\lib\julia\sys.dll (unknown line)
jl_apply at C:/workdir/src\julia.h:1843 [inlined]
true_main at C:/workdir/src\jlapi.c:575
jl_repl_entrypoint at C:/workdir/src\jlapi.c:719
mainCRTStartup at C:/workdir/cli\loader_exe.c:59
BaseThreadInitThunk at C:\WINDOWS\System32\KERNEL32.DLL (unknown line)
RtlUserThreadStart at C:\WINDOWS\SYSTEM32\ntdll.dll (unknown line)
Allocations: 28609313 (Pool: 28566019; Big: 43294); GC: 30

@jeremiedb
Author

jeremiedb commented Mar 8, 2023

Trying to further isolate the issue with NNlib.conv, the call to NNlib.gemm! seems to be problematic, although it results in a different error message. Since gemm! takes Val arguments, could this be related to #654?

using Enzyme
using NNlib

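# Computes y = 1.0 * x * w + 0.0 * y by calling NNlib's pointer-based gemm!
# directly (no transposes): M = size(x, 1), N = size(w, 2), K = size(x, 2).
function my_gemm!(y, x, w)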
    x_ptr = pointer(x)
    w_ptr = pointer(w)
    y_ptr = pointer(y)
    NNlib.gemm!(
        Val(false),
        Val(false),
        size(x, 1),
        size(w, 2),
        size(x, 2),
        1.0,
        x_ptr,
        w_ptr,
        0.0,
        y_ptr,
    )
    return y
end

x = rand(2, 3)
w = rand(3, 5)
y = zeros(2, 5)

dx = zeros(2, 3)
dw = zeros(3, 5)
dy = zeros(2, 5)

my_gemm!(y, x, w)
loss(y, x, w) = sum(my_gemm!(y, x, w))
loss(y, x, w)

autodiff(Reverse, loss, Duplicated(y, dy), Const(x), Duplicated(w, dw))
...
!509 = !DILocation(line: 150, scope: !30, inlinedAt: !510)
!510 = !DILocation(line: 14, scope: !499)
!511 = !DILocation(line: 8, scope: !35, inlinedAt: !512)
!512 = !DILocation(line: 104, scope: !38, inlinedAt: !513)
!513 = !DILocation(line: 412, scope: !41, inlinedAt: !514)
!514 = !DILocation(line: 48, scope: !44, inlinedAt: !510)
!515 = !DILocation(line: 26, scope: !499)

No augmented forward pass found for .text
declare void @.text(i8*, i8*, i8*, i8*, i8*, i8*, i64, i8*, i64, i8*, i8*, i64, i8*) local_unnamed_addr #6



Stacktrace:
  [1] julia_error(cstr::Cstring, val::Ptr{LLVM.API.LLVMOpaqueValue}, errtype::Enzyme.API.ErrorType, data::Ptr{Nothing})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:4735
  [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{Enzyme.API.CDIFFE_TYPE}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{Nothing}, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{Bool}, augmented::Ptr{Nothing}, atomicAdd::Bool)
    @ Enzyme.API C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\api.jl:123
  [3] enzyme!(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}}}, mod::LLVM.Module, primalf::LLVM.Function, adjoint::GPUCompiler.FunctionSpec{typeof(loss), Tuple{Duplicated{Matrix{Float64}}, Const{Matrix{Float64}}, Duplicated{Matrix{Float64}}}}, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, dupClosure::Bool, wrap::Bool, modifiedBetween::NTuple{4, Bool}, returnPrimal::Bool, jlrules::Vector{String})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:6195
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}}}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, ctx::LLVM.Context, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:7446
  [5] _thunk
    @ C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:7958 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams, GPUCompiler.FunctionSpec{typeof(loss), Tuple{Matrix{Float64}, Matrix{Float64}, Matrix{Float64}}}})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:7952
  [7] cached_compilation(job::GPUCompiler.CompilerJob, key::UInt64, specid::UInt64)
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:7996
  [8] #s451#163
    @ C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:8056 [inlined]
  [9] var"#s451#163"(F::Any, Fn::Any, DF::Any, A::Any, TT::Any, Mode::Any, ModifiedBetween::Any, width::Any, specid::Any, ReturnPrimal::Any, ShadowInit::Any, ::Any, #unused#::Type, f::Any, df::Any, #unused#::Type, tt::Any, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Type, #unused#::Any)
    @ Enzyme.Compiler .\none:0
 [10] (::Core.GeneratedFunctionStub)(::Any, ::Vararg{Any})
    @ Core .\boot.jl:582
 [11] thunk
    @ C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:8089 [inlined]
 [12] thunk(f::typeof(loss), df::Nothing, ::Type{Active{Float64}}, tt::Type{Tuple{Duplicated{Matrix{Float64}}, Const{Matrix{Float64}}, Duplicated{Matrix{Float64}}}}, ::Val{Enzyme.API.DEM_ReverseModeCombined}, ::Val{1}, ::Val{(false, false, false, false)}, ::Val{false})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\compiler.jl:8082
 [13] autodiff(::EnzymeCore.ReverseMode{false, false}, ::typeof(loss), ::Type{Active{Float64}}, ::Duplicated{Matrix{Float64}}, ::Vararg{Any})
    @ Enzyme C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\Enzyme.jl:197
 [14] autodiff(::EnzymeCore.ReverseMode{false, false}, ::typeof(loss), ::Duplicated{Matrix{Float64}}, ::Const{Matrix{Float64}}, ::Vararg{Any})
    @ Enzyme C:\Users\jerem\.julia\packages\Enzyme\M5Bxx\src\Enzyme.jl:223
 [15] top-level scope
    @ C:\Users\jerem\OneDrive\github\ADTests.jl\experiments\enzyme\conv-debug.jl:41
in expression starting at C:\Users\jerem\OneDrive\github\ADTests.jl\experiments\enzyme\conv-debug.jl:41
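The `.text` call that has no augmented forward pass receives two `'N'` flags (the stored byte 78), then M, N, K, alpha, A, lda, B, ldb, beta, C and ldc, which matches the argument list of a BLAS `dgemm`. So it is presumably the foreign BLAS kernel behind `gemm!`, which Enzyme cannot see into (an inference from the IR, not confirmed). To check whether the rest of the pipeline differentiates, one option is to swap in a plain-Julia loop with the same semantics as the `gemm!` call above. This is a sketch with the hypothetical name `naive_gemm!`, meant only for isolating the failure and much slower than BLAS.

```julia
# Hypothetical pure-Julia stand-in for the gemm! call in my_gemm!:
# y = alpha * x * w + beta * y, no transposes, loops only (no foreign call).
function naive_gemm!(y, x, w; alpha = 1.0, beta = 0.0)
    M, K = size(x)
    K2, N = size(w)
    @assert K == K2 && size(y) == (M, N) "dimension mismatch"
    for j in 1:N, i in 1:M
        acc = zero(eltype(y))
        for k in 1:K
            acc += x[i, k] * w[k, j]
        end
        # Like BLAS, ignore the old contents of y when beta == 0.
        y[i, j] = iszero(beta) ? alpha * acc : alpha * acc + beta * y[i, j]
    end
    return y
end

naive_loss(y, x, w) = sum(naive_gemm!(y, x, w))
```

With Enzyme loaded, `autodiff(Reverse, naive_loss, Duplicated(y, dy), Const(x), Duplicated(w, dw))` should exercise the same call shape as the failing example without the BLAS call.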

@jakubMitura14

Is it now possible to autodifferentiate convolutions with Enzyme?

@wsmoses
Member

wsmoses commented Nov 21, 2023

It should, yes.

@jakubMitura14

Fantastic, thanks!

@jeremiedb
Author

@jakubMitura14 Did you manage to differentiate a convolution?
On CPU, it still errors on my end with Enzyme#main. Calling the my_gemm! above, which is the core operator used within NNlib's conv, now results in the following stack trace:

julia> autodiff(Reverse, loss, Duplicated(y, dy), Const(x), Duplicated(w, dw))
ERROR: Enzyme execution failed.
Enzyme compilation failed.
Current scope:
; Function Attrs: mustprogress uwtable willreturn
define internal fastcc void @preprocess_julia_my_gemm__12471({} addrspace(10)* nonnull align 16 dereferenceable(40) %0, {} addrspace(10)* nonnull align 16 dereferenceable(40) %1, {} addrspace(10)* nonnull align 16 dereferenceable(40) %2) unnamed_addr #12 !dbg !477 {
top:
  %3 = call noalias nonnull dereferenceable(1) dereferenceable_or_null(1) i8* @malloc(i64 1), !enzyme_fromstack !478
  %4 = call noalias nonnull dereferenceable(1) dereferenceable_or_null(1) i8* @malloc(i64 1), !enzyme_fromstack !478
  %5 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !479
  %6 = bitcast i8* %5 to i64*, !enzyme_caststack !4
  %7 = bitcast i64* %6 to i8*
  %8 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !479
  %9 = bitcast i8* %8 to i64*, !enzyme_caststack !4
  %10 = bitcast i64* %9 to i8*
  %11 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !479
  %12 = bitcast i8* %11 to i64*, !enzyme_caststack !4
  %13 = bitcast i64* %12 to i8*
  %14 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !479
  %15 = bitcast i8* %14 to i64*, !enzyme_caststack !4
  %16 = bitcast i64* %15 to i8*
  %17 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !479
  %18 = bitcast i8* %17 to i64*, !enzyme_caststack !4
  %19 = bitcast i64* %18 to i8*
  %20 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !479
  %21 = bitcast i8* %20 to i64*, !enzyme_caststack !4
  %22 = bitcast i64* %21 to i8*
  %23 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !479
  %24 = bitcast i8* %23 to i64*, !enzyme_caststack !4
  %25 = bitcast i64* %24 to i8*
  %26 = call noalias nonnull dereferenceable(8) dereferenceable_or_null(8) i8* @malloc(i64 8), !enzyme_fromstack !479
  %27 = bitcast i8* %26 to i64*, !enzyme_caststack !4
  %28 = bitcast i64* %27 to i8*
  %29 = call {}*** @julia.get_pgcstack() #13
  %30 = addrspacecast {} addrspace(10)* %1 to {} addrspace(11)*, !dbg !480
  %31 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* %30) #14, !dbg !480
  %32 = bitcast {}* %31 to i8**, !dbg !480
  %33 = load i8*, i8** %32, align 8, !dbg !480, !tbaa !42, !invariant.load !4, !nonnull !4
  %34 = ptrtoint i8* %33 to i64, !dbg !480
  %35 = addrspacecast {} addrspace(10)* %2 to {} addrspace(11)*, !dbg !483
  %36 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* %35) #14, !dbg !483
  %37 = bitcast {}* %36 to i8**, !dbg !483
  %38 = load i8*, i8** %37, align 8, !dbg !483, !tbaa !42, !invariant.load !4, !nonnull !4
  %39 = ptrtoint i8* %38 to i64, !dbg !483
  %40 = addrspacecast {} addrspace(10)* %0 to {} addrspace(11)*, !dbg !486
  %41 = call nonnull {}* @julia.pointer_from_objref({} addrspace(11)* %40) #14, !dbg !486
  %42 = bitcast {}* %41 to i8**, !dbg !486
  %43 = load i8*, i8** %42, align 8, !dbg !486, !tbaa !42, !invariant.load !4, !nonnull !4
  %44 = ptrtoint i8* %43 to i64, !dbg !486
  %45 = bitcast {} addrspace(10)* %1 to {} addrspace(10)* addrspace(10)*, !dbg !489
  %46 = addrspacecast {} addrspace(10)* addrspace(10)* %45 to {} addrspace(10)* addrspace(11)*, !dbg !489
  %47 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %46, i64 3, !dbg !489
  %48 = bitcast {} addrspace(10)* addrspace(11)* %47 to i64 addrspace(11)*, !dbg !489
  %49 = load i64, i64 addrspace(11)* %48, align 8, !dbg !489, !tbaa !42, !range !46, !invariant.load !4
  %50 = bitcast {} addrspace(10)* %2 to {} addrspace(10)* addrspace(10)*, !dbg !489
  %51 = addrspacecast {} addrspace(10)* addrspace(10)* %50 to {} addrspace(10)* addrspace(11)*, !dbg !489
  %52 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %51, i64 4, !dbg !489
  %53 = bitcast {} addrspace(10)* addrspace(11)* %52 to i64 addrspace(11)*, !dbg !489
  %54 = load i64, i64 addrspace(11)* %53, align 16, !dbg !489, !tbaa !42, !range !46, !invariant.load !4
  %55 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %46, i64 4, !dbg !489
  %56 = bitcast {} addrspace(10)* addrspace(11)* %55 to i64 addrspace(11)*, !dbg !489
  %57 = load i64, i64 addrspace(11)* %56, align 16, !dbg !489, !tbaa !42, !range !46, !invariant.load !4
  call void @llvm.lifetime.start.p0i8(i64 noundef 1, i8* noundef nonnull %4) #13
  store i8 78, i8* %4, align 1, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 1, i8* noundef nonnull %3) #13
  store i8 78, i8* %3, align 1, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %7) #13
  store i64 %49, i64* %6, align 16, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %10) #13
  store i64 %54, i64* %9, align 16, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %13) #13
  store i64 %57, i64* %12, align 16, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %16) #13
  %58 = bitcast i64* %15 to double*, !dbg !491
  store double 1.000000e+00, double* %58, align 16, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %19) #13
  store i64 %49, i64* %18, align 16, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %22) #13
  store i64 %57, i64* %21, align 16, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %25) #13
  %59 = bitcast i64* %24 to double*, !dbg !491
  store double 0.000000e+00, double* %59, align 16, !dbg !491, !tbaa !197, !noalias !495
  call void @llvm.lifetime.start.p0i8(i64 noundef 8, i8* noundef nonnull %28) #13
  store i64 %49, i64* %27, align 16, !dbg !491, !tbaa !197, !noalias !495
  call void @.text(i8* noundef nonnull %4, i8* noundef nonnull %3, i8* noundef nonnull %7, i8* noundef nonnull %10, i8* noundef nonnull %13, i8* noundef nonnull %16, i64 %34, i8* noundef nonnull %19, i64 %39, i8* noundef nonnull %22, i8* noundef nonnull %25, i64 %44, i8* noundef nonnull %28) #13 [ "jl_roots"({} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null) ], !dbg !494
  ret void, !dbg !498
}

No augmented forward pass found for .text
 at context:   call void @.text(i8* noundef nonnull %4, i8* noundef nonnull %3, i8* noundef nonnull %7, i8* noundef nonnull %10, i8* noundef nonnull %13, i8* noundef nonnull %16, i64 %34, i8* noundef nonnull %19, i64 %39, i8* noundef nonnull %22, i8* noundef nonnull %25, i64 %44, i8* noundef nonnull %28) #13 [ "jl_roots"({} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null, {} addrspace(10)* null) ], !dbg !45

Stacktrace:
 [1] gemm!
   @ C:\Users\jerem\.julia\packages\NNlib\Fg3DQ\src\gemm.jl:48
 [2] my_gemm!
   @ c:\Users\jerem\OneDrive\github\ADTests.jl\experiments\enzyme\conv-debug.jl:14


Stacktrace:
 [1] throwerr(cstr::Cstring)
   @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\tZYHp\src\compiler.jl:1251

@jakubMitura14

No, I have not tried it yet; I needed to work on something different for a while.

@wsmoses
Member

wsmoses commented Nov 22, 2023

@jeremiedb you should be able to directly differentiate NNlib.conv (it has an EnzymeRule in NNlib)

@jeremiedb
Author

Oh, I missed the newly added EnzymeRules in NNlib; thanks for pointing that out!

However, I just hit a new issue with Enzyme v0.11.10, NNlib v0.9.8 and Julia 1.10-rc1 on Windows.

using Enzyme
using NNlib

loss(w, x) = sum(conv(x, w))
w = randn(Float32, 3, 3, 5, 7);
dw = zero(w);
x = randn(Float32, (3, 3, 5, 8));
loss(w, x);
grads = Enzyme.autodiff(Reverse, loss, Duplicated(w, dw), Const(x));

The call to `loss` works fine, but `autodiff` then fails with the following:

julia> grads = Enzyme.autodiff(Reverse, loss, Duplicated(w, dw), Const(x))
ERROR: AssertionError: legal
Stacktrace:
  [1] array_shadow_handler(B::Ptr{LLVM.API.LLVMOpaqueBuilder}, OrigCI::Ptr{LLVM.API.LLVMOpaqueValue}, numArgs::UInt64, Args::Ptr{Ptr{…}}, gutils::Ptr{Nothing})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\compiler.jl:982
  [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\api.jl:141
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, jlrules::Vector{…}, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\compiler.jl:7726
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\compiler.jl:9278
  [5] codegen
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\compiler.jl:8886 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool) (repeats 2 times)
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\compiler.jl:9830
  [7] cached_compilation
    @ C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\compiler.jl:9864 [inlined]
  [8] (::Enzyme.Compiler.var"#474#475"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, Tuple{}, Int64, Bool, Bool, UInt64, DataType})(ctx::LLVM.Context)
    @ Enzyme.Compiler C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\compiler.jl:9921
  [9] JuliaContext(f::Enzyme.Compiler.var"#474#475"{DataType, DataType, DataType, Enzyme.API.CDerivativeMode, Tuple{}, Int64, Bool, Bool, UInt64, DataType})
    @ GPUCompiler C:\Users\jerem\.julia\packages\GPUCompiler\U36Ed\src\driver.jl:47
 [10] #s325#473
    @ C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\compiler.jl:9882 [inlined]
 [11]
    @ Enzyme.Compiler .\none:0
 [12] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core .\boot.jl:600
 [13] autodiff
    @ Enzyme C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\Enzyme.jl:207 [inlined]
 [14] autodiff
    @ Enzyme C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\Enzyme.jl:236 [inlined]
 [15] autodiff(::ReverseMode{false, FFIABI}, ::typeof(loss), ::Duplicated{Array{Float32, 4}}, ::Const{Array{Float32, 4}})
    @ Enzyme C:\Users\jerem\.julia\packages\Enzyme\rbuCz\src\Enzyme.jl:222
 [16] top-level scope
    @ REPL[16]:1
Some type information was truncated. Use `show(err)` to see complete types.

I should be able to test on a Linux machine tomorrow, in case it's a Windows-specific issue.

@jeremiedb
Author

jeremiedb commented Nov 24, 2023

@wsmoses I've isolated an MWE of the above "legal" error, which arises on both Windows and Ubuntu. For illustration, the following two-level loop works fine:

using Enzyme

function my_conv_1(x, w)
    y = zero(x)
    for b in axes(y, 3)
        for wi in axes(y, 2)
            y[:, wi, b] .= w .* x[:, wi, b]
        end
    end
    return y
end
x = rand(Float32, 3, 5, 8);
w = rand(Float32, 3);
y = my_conv_1(x, w);
loss1(x, w) = sum(my_conv_1(x, w))
dw = zero(w);
loss1(x, w)
grads = Enzyme.autodiff(Reverse, loss1, Const(x), Duplicated(w, dw));
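As a sanity check on the `dw` that Enzyme writes for `loss1` (an addition, not part of the original report): `loss1(x, w)` expands to the sum over `k, wi, b` of `w[k] * x[k, wi, b]`, so the analytic gradient is `dw[k] = sum(x[k, :, :])`, i.e. `vec(sum(x; dims = (2, 3)))`. The sketch below checks that formula against central finite differences in plain Julia, without Enzyme; the helper `fd_grad` is hypothetical and uses Float64 inputs so the differences are accurate.

```julia
# Same MWE as above, repeated so this block is self-contained.
function my_conv_1(x, w)
    y = zero(x)
    for b in axes(y, 3)
        for wi in axes(y, 2)
            y[:, wi, b] .= w .* x[:, wi, b]
        end
    end
    return y
end
loss1(x, w) = sum(my_conv_1(x, w))

# Hypothetical helper: central finite differences of f with respect to each entry of w.
function fd_grad(f, w; h = 1e-6)
    g = similar(w)
    for k in eachindex(w)
        wp = copy(w); wp[k] += h
        wm = copy(w); wm[k] -= h
        g[k] = (f(wp) - f(wm)) / (2h)
    end
    return g
end

x = rand(3, 5, 8)   # Float64 so the finite differences are accurate
w = rand(3)

# loss1 is linear in w, so d/dw[k] is the sum of the slice x[k, :, :].
dw_expected = vec(sum(x; dims = (2, 3)))
dw_fd = fd_grad(w -> loss1(x, w), w)
@assert isapprox(dw_expected, dw_fd; rtol = 1e-6)
```

After a successful `Enzyme.autodiff(Reverse, loss1, Const(x), Duplicated(w, dw))` with `dw = zero(w)`, `dw` should match `dw_expected` up to rounding.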

However, when adding another dimension to the data, it errors:

function my_conv_2(x, w)
    y = zero(x)
    for b in axes(y, 4)
        for hi in axes(y, 3)
            for wi in axes(y, 2)
                y[:, wi, hi, b] .= w .* x[:, wi, hi, b]
            end
        end
    end
    return y
end
x = rand(Float32, 3, 5, 5, 8);
w = rand(Float32, 3);
y = my_conv_2(x, w);
loss2(x, w) = sum(my_conv_2(x, w))
dw = zero(w);
loss2(x, w)
grads = Enzyme.autodiff(Reverse, loss2, Const(x), Duplicated(w, dw));

Stacktrace:

ERROR: AssertionError: legal
Stacktrace:
  [1] array_shadow_handler(B::Ptr{…}, OrigCI::Ptr{…}, numArgs::UInt64, Args::Ptr{…}, gutils::Ptr{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:982
  [2] EnzymeCreatePrimalAndGradient(logic::Enzyme.Logic, todiff::LLVM.Function, retType::Enzyme.API.CDIFFE_TYPE, constant_args::Vector{…}, TA::Enzyme.TypeAnalysis, returnValue::Bool, dretUsed::Bool, mode::Enzyme.API.CDerivativeMode, width::Int64, additionalArg::Ptr{…}, forceAnonymousTape::Bool, typeInfo::Enzyme.FnTypeInfo, uncacheable_args::Vector{…}, augmented::Ptr{…}, atomicAdd::Bool)
    @ Enzyme.API ~/.julia/packages/Enzyme/rbuCz/src/api.jl:141
  [3] enzyme!(job::GPUCompiler.CompilerJob{…}, mod::LLVM.Module, primalf::LLVM.Function, TT::Type, mode::Enzyme.API.CDerivativeMode, width::Int64, parallel::Bool, actualRetType::Type, wrap::Bool, modifiedBetween::Tuple{…}, returnPrimal::Bool, jlrules::Vector{…}, expectedTapeType::Type, loweredArgs::Set{…}, boxedArgs::Set{…})
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:7726
  [4] codegen(output::Symbol, job::GPUCompiler.CompilerJob{…}; libraries::Bool, deferred_codegen::Bool, optimize::Bool, toplevel::Bool, strip::Bool, validate::Bool, only_entry::Bool, parent_job::Nothing)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9278
  [5] codegen
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:8886 [inlined]
  [6] _thunk(job::GPUCompiler.CompilerJob{Enzyme.Compiler.EnzymeTarget, Enzyme.Compiler.EnzymeCompilerParams}, postopt::Bool) (repeats 2 times)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9830
  [7] cached_compilation
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9864 [inlined]
  [8] (::Enzyme.Compiler.var"#474#475"{})(ctx::LLVM.Context)
    @ Enzyme.Compiler ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9921
  [9] JuliaContext(f::Enzyme.Compiler.var"#474#475"{})
    @ GPUCompiler ~/.julia/packages/GPUCompiler/U36Ed/src/driver.jl:47
 [10] #s325#473
    @ ~/.julia/packages/Enzyme/rbuCz/src/compiler.jl:9882 [inlined]
 [11] 
    @ Enzyme.Compiler ./none:0
 [12] (::Core.GeneratedFunctionStub)(::UInt64, ::LineNumberNode, ::Any, ::Vararg{Any})
    @ Core ./boot.jl:600
 [13] autodiff
    @ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:207 [inlined]
 [14] autodiff
    @ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:236 [inlined]
 [15] autodiff(::ReverseMode{false, FFIABI}, ::typeof(loss2), ::Const{Array{Float32, 4}}, ::Duplicated{Vector{Float32}})
    @ Enzyme ~/.julia/packages/Enzyme/rbuCz/src/Enzyme.jl:222
 [16] top-level scope
    @ ~/github/ADTests.jl/experiments/conv-v3.jl:1
Some type information was truncated. Use `show(err)` to see complete types.

Curiously, it also errors if the above only performs a single loop:

function my_conv_3(x, w)
    y = zero(x)
    for hi in axes(y, 3)
        y[1] += w[1] * x[1]
    end
    return y
end
x = rand(Float32, 3, 5, 5, 8);
w = rand(Float32, 3);
y = my_conv_3(x, w);
loss3(x, w) = sum(my_conv_3(x, w))
dw = zero(w);
loss3(x, w)
grads = Enzyme.autodiff(Reverse, loss3, Const(x), Duplicated(w, dw));
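The arithmetic in `my_conv_3` is trivial, which makes it a convenient reference (again an addition, checked here without Enzyme): the loop only updates `y[1]`, adding `w[1] * x[1]` once per index of `axes(y, 3)`, so the expected gradient is zero everywhere except `dw[1] = size(x, 3) * x[1]`. A minimal sketch comparing that against one-sided finite differences:

```julia
# Same MWE as above, repeated so this block is self-contained.
function my_conv_3(x, w)
    y = zero(x)
    for hi in axes(y, 3)
        y[1] += w[1] * x[1]
    end
    return y
end
loss3(x, w) = sum(my_conv_3(x, w))

x = rand(3, 5, 5, 8)
w = rand(3)

# Only w[1] contributes, once per index along dimension 3.
dw_expected = zeros(length(w))
dw_expected[1] = size(x, 3) * x[1]

# loss3 is linear in w, so a one-sided difference is exact up to rounding.
h = 1e-6
dw_fd = [(loss3(x, w .+ h .* (eachindex(w) .== k)) - loss3(x, w)) / h for k in eachindex(w)]
@assert isapprox(dw_expected, dw_fd; atol = 1e-6)
```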

From the above, it looks like something goes wrong in the handling of arrays with 4 or more dimensions. Did I miss something obvious?

Status `~/github/ADTests.jl/Project.toml`
  [052768ef] CUDA v5.1.1
  [d360d2e6] ChainRulesCore v1.18.0
  [7da242da] Enzyme v0.11.10
  [587475ba] Flux v0.14.6
  [bdcacae8] LoopVectorization v0.12.166
  [872c559c] NNlib v0.9.8
  [3bd65402] Optimisers v0.3.1
  [37e2e3b7] ReverseDiff v1.15.1
  [bc48ee85] Tullio v0.3.7
  [cd998857] Yota v0.8.5
  [e88e6eb3] Zygote v0.6.67

@wsmoses
Member

wsmoses commented Nov 26, 2023

@jeremiedb The issue had nothing to do with NNlib; it was Julia's special-case handling of 1-, 2- and 3-dimensional arrays. It should be fixed in #1157.
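One possible stopgap on an Enzyme version without that fix (an untested assumption, not something suggested in the thread) is to keep the arrays the differentiated code works on at three dimensions by folding the trailing dimensions together with `reshape` and reusing the 3-D kernel. The sketch below only verifies, in plain Julia, that this gives the same loss as `my_conv_2`; whether it actually avoids the `legal` assertion is not checked here, and `loss2_via_3d` is a hypothetical name.

```julia
function my_conv_1(x, w)
    y = zero(x)
    for b in axes(y, 3)
        for wi in axes(y, 2)
            y[:, wi, b] .= w .* x[:, wi, b]
        end
    end
    return y
end

function my_conv_2(x, w)
    y = zero(x)
    for b in axes(y, 4)
        for hi in axes(y, 3)
            for wi in axes(y, 2)
                y[:, wi, hi, b] .= w .* x[:, wi, hi, b]
            end
        end
    end
    return y
end

loss2(x, w) = sum(my_conv_2(x, w))

# Hypothetical workaround: merge dims 3 and 4 so the kernel only sees 3-D arrays.
# Column-major layout keeps x[:, wi, hi, b] aligned with the merged slice.
loss2_via_3d(x, w) = sum(my_conv_1(reshape(x, size(x, 1), size(x, 2), :), w))

x = rand(Float32, 3, 5, 5, 8)
w = rand(Float32, 3)
@assert loss2_via_3d(x, w) ≈ loss2(x, w)
```

If it does sidestep the error, the call would be `Enzyme.autodiff(Reverse, loss2_via_3d, Const(x), Duplicated(w, dw))`.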

@wsmoses wsmoses closed this as completed Nov 26, 2023