Enzyme crashes with MeasureTheory.logdensity #126

sethaxen opened this issue Sep 3, 2021 · 2 comments

sethaxen commented Sep 3, 2021

Calling `autodiff` on `MeasureTheory.logdensity` for a multivariate distribution causes Enzyme (and Julia itself) to crash. Here is an MWE; the full, very long stack trace follows it:

julia> using MeasureTheory, Enzyme

julia> d = Normal()^5
Normal() ^ 5

julia> x = rand(d)
5-element Vector{Float64}:
 -0.714839309237649
 -0.11854994590939029
  0.4096756581544794
 -0.8549528606231284
 -1.8390113461805115

julia> logdensity(d, x)
-2.402895298929556

julia> ∂x = zero(x);

julia> Enzyme.autodiff(logdensity, Active, Const(d), Duplicated(x, ∂x))
┌ Warning: ("primal differentiating jl_invoke call without split mode", Base.var"#mapreduce#673", MethodInstance for var"#mapreduce#673"(::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(mapreduce), ::Function, ::Function, ::MappedArrays.ReadonlyMappedArray{Normal{(), Tuple{}}, 1, FillArrays.Fill{Normal{(), Tuple{}}, 1, Tuple{Base.OneTo{Int64}}}, typeof(identity)}, ::Vararg{Union{Base.AbstractBroadcasted, AbstractArray}, N} where N), Any[Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}(), mapreduce, MeasureBase.logdensity, +, [Normal(), Normal(), Normal(), Normal(), Normal()], [-0.714839309237649, -0.11854994590939029, 0.4096756581544794, -0.8549528606231284, -1.8390113461805115]])
└ @ Enzyme.Compiler ~/.julia/packages/Enzyme/afnXq/src/compiler.jl:353
┌ Warning: ("done primal differentiating jl_invoke call without split mode", Base.var"#mapreduce#673", MethodInstance for var"#mapreduce#673"(::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(mapreduce), ::Function, ::Function, ::MappedArrays.ReadonlyMappedArray{Normal{(), Tuple{}}, 1, FillArrays.Fill{Normal{(), Tuple{}}, 1, Tuple{Base.OneTo{Int64}}}, typeof(identity)}, ::Vararg{Union{Base.AbstractBroadcasted, AbstractArray}, N} where N), Any[Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}(), mapreduce, MeasureBase.logdensity, +, [Normal(), Normal(), Normal(), Normal(), Normal()], [-0.714839309237649, -0.11854994590939029, 0.4096756581544794, -0.8549528606231284, -1.8390113461805115]], -2.402895298929556)
└ @ Enzyme.Compiler ~/.julia/packages/Enzyme/afnXq/src/compiler.jl:375
┌ Warning: ("reverse differentiating jl_invoke call without split mode", Base.var"#mapreduce#673", MethodInstance for var"#mapreduce#673"(::Base.Iterators.Pairs{Union{}, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}, ::typeof(mapreduce), ::Function, ::Function, ::MappedArrays.ReadonlyMappedArray{Normal{(), Tuple{}}, 1, FillArrays.Fill{Normal{(), Tuple{}}, 1, Tuple{Base.OneTo{Int64}}}, typeof(identity)}, ::Vararg{Union{Base.AbstractBroadcasted, AbstractArray}, N} where N))
└ @ Enzyme.Compiler ~/.julia/packages/Enzyme/afnXq/src/compiler.jl:404
define double @preprocess_julia__mapreduce_673_3591({ { [1 x [1 x i64]] } } %0, {} addrspace(10)* nonnull align 16 dereferenceable(40) %1) local_unnamed_addr !dbg !375 {
  %2 = alloca { [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] }, align 8
  %.fca. = extractvalue { { [1 x [1 x i64]] } } %0, 0, 0, 0, 0
  %3 = bitcast { [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] }* %2 to i8*
  call void @llvm.lifetime.start.p0i8(i64 16, i8* nonnull align 8 dereferenceable(16) %3)
  %4 = call {}*** @julia.ptls_states() #10
  %.fca. = getelementptr inbounds { [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] }, { [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] }* %2, i64 0, i32 0, i64 0, i32 0, i32 0, i32 0, i64 0, i64 0, !dbg !376
  store i64 %.fca., i64* %.fca., align 8, !dbg !376
  %.fca.0.0.1.gep = getelementptr inbounds { [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] }, { [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] }* %2, i64 0, i32 0, i64 0, i32 1, !dbg !376
  store {} addrspace(10)* %1, {} addrspace(10)** %.fca.0.0.1.gep, align 8, !dbg !376
  %5 = addrspacecast { [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] }* %2 to { [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] } addrspace(11)*, !dbg !376
  %6 = call fastcc nonnull {} addrspace(10)* @julia_collect_3604({ [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] } addrspace(11)* nocapture nonnull readonly align 8 dereferenceable(16) %5), !dbg !376
  %7 = bitcast {} addrspace(10)* %6 to {} addrspace(10)* addrspace(10)*, !dbg !378
  %8 = addrspacecast {} addrspace(10)* addrspace(10)* %7 to {} addrspace(10)* addrspace(11)*, !dbg !378
  %9 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %8, i64 3, !dbg !378
  %10 = bitcast {} addrspace(10)* addrspace(11)* %9 to i64 addrspace(11)*, !dbg !378
  %11 = load i64, i64 addrspace(11)* %10, align 8, !dbg !378, !tbaa !210, !range !128
  switch i64 %11, label %L19.i [
    i64 0, label %julia__mapreduce_673_3591.inner.exit
    i64 1, label %L17.i
  ], !dbg !387

L17.i:                                            ; preds = %entry
  %12 = bitcast {} addrspace(10)* %6 to double addrspace(13)* addrspace(10)*, !dbg !388
  %13 = addrspacecast double addrspace(13)* addrspace(10)* %12 to double addrspace(13)* addrspace(11)*, !dbg !388
  %14 = load double addrspace(13)*, double addrspace(13)* addrspace(11)* %13, align 8, !dbg !388, !tbaa !133, !nonnull !4
  %15 = load double, double addrspace(13)* %14, align 8, !dbg !388, !tbaa !135
  br label %julia__mapreduce_673_3591.inner.exit, !dbg !390

L19.i:                                            ; preds = %entry
  %16 = icmp ugt i64 %11, 15, !dbg !391
  br i1 %16, label %L35.i, label %L21.i, !dbg !392

L21.i:                                            ; preds = %L19.i
  %17 = bitcast {} addrspace(10)* %6 to double addrspace(13)* addrspace(10)*, !dbg !393
  %18 = addrspacecast double addrspace(13)* addrspace(10)* %17 to double addrspace(13)* addrspace(11)*, !dbg !393
  %19 = load double addrspace(13)*, double addrspace(13)* addrspace(11)* %18, align 8, !dbg !393, !tbaa !133, !nonnull !4
  %20 = load double, double addrspace(13)* %19, align 8, !dbg !393, !tbaa !135
  %21 = getelementptr inbounds double, double addrspace(13)* %19, i64 1, !dbg !395
  %22 = load double, double addrspace(13)* %21, align 8, !dbg !395, !tbaa !135
  %23 = fadd double %20, %22, !dbg !397
  %.not67 = icmp ugt i64 %11, 2, !dbg !399
  br i1 %.not67, label %L30.i.preheader, label %julia__mapreduce_673_3591.inner.exit, !dbg !400

L30.i.preheader:                                  ; preds = %L21.i
  br label %L30.i, !dbg !400

L30.i:                                            ; preds = %L30.i.preheader, %L30.i
  %tiv = phi i64 [ 0, %L30.i.preheader ], [, %L30.i ]
  %value_phi1.i8 = phi double [ %28, %L30.i ], [ %23, %L30.i.preheader ]
  %24 = add i64 %tiv, 2, !dbg !401 = add nuw nsw i64 %tiv, 1, !dbg !401
  %25 = add nuw nsw i64 %24, 1, !dbg !401
  %26 = getelementptr inbounds double, double addrspace(13)* %19, i64 %24, !dbg !403
  %27 = load double, double addrspace(13)* %26, align 8, !dbg !403, !tbaa !135
  %28 = fadd double %value_phi1.i8, %27, !dbg !404
  %exitcond.not = icmp eq i64 %25, %11, !dbg !399
  br i1 %exitcond.not, label %julia__mapreduce_673_3591.inner.exit.loopexit, label %L30.i, !dbg !400

L35.i:                                            ; preds = %L19.i
  %29 = call fastcc double @julia_mapreduce_impl_3595({} addrspace(10)* nocapture nonnull readonly align 16 dereferenceable(40) %6, i64 signext 1, i64 signext %11), !dbg !406
  br label %julia__mapreduce_673_3591.inner.exit, !dbg !407

julia__mapreduce_673_3591.inner.exit.loopexit:    ; preds = %L30.i
  br label %julia__mapreduce_673_3591.inner.exit, !dbg !377

julia__mapreduce_673_3591.inner.exit:             ; preds = %julia__mapreduce_673_3591.inner.exit.loopexit, %L35.i, %L21.i, %L17.i, %entry
  %value_phi.i = phi double [ %15, %L17.i ], [ %29, %L35.i ], [ 0.000000e+00, %entry ], [ %23, %L21.i ], [ %28, %julia__mapreduce_673_3591.inner.exit.loopexit ]
  call void @llvm.lifetime.end.p0i8(i64 16, i8* nonnull %3), !dbg !377
  ret double %value_phi.i

  %24 = add i64 %tiv, 2, !dbg !100: {[-1]:Integer}, intvals: {2,}
i64 0: {[-1]:Anything}, intvals: {0,}
i64 1: {[-1]:Integer}, intvals: {1,}
i64 15: {[-1]:Integer}, intvals: {15,}
i64 2: {[-1]:Integer}, intvals: {2,}
  %21 = getelementptr inbounds double, double addrspace(13)* %19, i64 1, !dbg !92: {[-1]:Pointer}, intvals: {}
  %22 = load double, double addrspace(13)* %21, align 8, !dbg !92, !tbaa !82: {[-1]:Float@double}, intvals: {} = add nuw nsw i64 %tiv, 1, !dbg !100: {[-1]:Integer}, intvals: {1,}
  %.fca. = extractvalue { { [1 x [1 x i64]] } } %0, 0, 0, 0, 0: {[-1]:Integer}, intvals: {}
  %value_phi1.i8 = phi double [ %28, %L30.i ], [ %23, %L30.i.preheader ]: {[-1]:Float@double}, intvals: {}
  %25 = add nuw nsw i64 %24, 1, !dbg !100: {[-1]:Integer}, intvals: {3,}
  %26 = getelementptr inbounds double, double addrspace(13)* %19, i64 %24, !dbg !103: {[-1]:Pointer}, intvals: {}
  %27 = load double, double addrspace(13)* %26, align 8, !dbg !103, !tbaa !82: {[-1]:Float@double}, intvals: {}
  %28 = fadd double %value_phi1.i8, %27, !dbg !104: {[-1]:Float@double}, intvals: {}
  %exitcond.not = icmp eq i64 %25, %11, !dbg !98: {[-1]:Integer}, intvals: {}
  %.fca.0.0.1.gep = getelementptr inbounds { [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] }, { [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] }* %2, i64 0, i32 0, i64 0, i32 1, !dbg !48: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,0]:Pointer, [-1,0,0,-1]:Float@double, [-1,0,8]:Integer, [-1,0,9]:Integer, [-1,0,10]:Integer, [-1,0,11]:Integer, [-1,0,12]:Integer, [-1,0,13]:Integer, [-1,0,14]:Integer, [-1,0,15]:Integer, [-1,0,16]:Integer, [-1,0,17]:Integer, [-1,0,18]:Integer, [-1,0,19]:Integer, [-1,0,20]:Integer, [-1,0,21]:Integer, [-1,0,22]:Integer, [-1,0,23]:Integer, [-1,0,24]:Integer, [-1,0,25]:Integer, [-1,0,26]:Integer, [-1,0,27]:Integer, [-1,0,28]:Integer, [-1,0,29]:Integer, [-1,0,30]:Integer, [-1,0,31]:Integer, [-1,0,32]:Integer, [-1,0,33]:Integer, [-1,0,34]:Integer, [-1,0,35]:Integer, [-1,0,36]:Integer, [-1,0,37]:Integer, [-1,0,38]:Integer, [-1,0,39]:Integer, [-1,0,40]:Integer}, intvals: {}
  %10 = bitcast {} addrspace(10)* addrspace(11)* %9 to i64 addrspace(11)*, !dbg !51: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}, intvals: {}
  %11 = load i64, i64 addrspace(11)* %10, align 8, !dbg !51, !tbaa !70, !range !75: {[-1]:Integer}, intvals: {}
  %6 = call fastcc nonnull {} addrspace(10)* @julia_collect_3604({ [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] } addrspace(11)* nocapture nonnull readonly align 8 dereferenceable(16) %5), !dbg !48: {}, intvals: {}
  %29 = call fastcc double @julia_mapreduce_impl_3595({} addrspace(10)* nocapture nonnull readonly align 16 dereferenceable(40) %6, i64 signext 1, i64 signext %11), !dbg !106: {[-1]:Pointer}, intvals: {}
  %16 = icmp ugt i64 %11, 15, !dbg !86: {[-1]:Integer}, intvals: {}
  %9 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %8, i64 3, !dbg !51: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}, intvals: {}
  %12 = bitcast {} addrspace(10)* %6 to double addrspace(13)* addrspace(10)*, !dbg !77: {[-1]:Pointer, [-1,0]:Pointer}, intvals: {}
  %13 = addrspacecast double addrspace(13)* addrspace(10)* %12 to double addrspace(13)* addrspace(11)*, !dbg !77: {[-1]:Pointer, [-1,0]:Pointer}, intvals: {}
  %value_phi.i = phi double [ %15, %L17.i ], [ %29, %L35.i ], [ 0.000000e+00, %entry ], [ %23, %L21.i ], [ %28, %julia__mapreduce_673_3591.inner.exit.loopexit ]: {[-1]:Pointer}, intvals: {}
  %7 = bitcast {} addrspace(10)* %6 to {} addrspace(10)* addrspace(10)*, !dbg !51: {}, intvals: {}
  %8 = addrspacecast {} addrspace(10)* addrspace(10)* %7 to {} addrspace(10)* addrspace(11)*, !dbg !51: {}, intvals: {}
  %14 = load double addrspace(13)*, double addrspace(13)* addrspace(11)* %13, align 8, !dbg !77, !tbaa !80, !nonnull !4: {[-1]:Pointer}, intvals: {}
  %15 = load double, double addrspace(13)* %14, align 8, !dbg !77, !tbaa !82: {[-1]:Pointer}, intvals: {}
  %17 = bitcast {} addrspace(10)* %6 to double addrspace(13)* addrspace(10)*, !dbg !90: {[-1]:Pointer, [-1,0]:Pointer}, intvals: {}
  %18 = addrspacecast double addrspace(13)* addrspace(10)* %17 to double addrspace(13)* addrspace(11)*, !dbg !90: {[-1]:Pointer, [-1,0]:Pointer}, intvals: {}
  %5 = addrspacecast { [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] }* %2 to { [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] } addrspace(11)*, !dbg !48: {[-1]:Pointer}, intvals: {}
{ { [1 x [1 x i64]] } } %0: {[0]:Integer, [1]:Integer, [2]:Integer, [3]:Integer, [4]:Integer, [5]:Integer, [6]:Integer, [7]:Integer}, intvals: {}
{} addrspace(10)* %1: {[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@double, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer, [-1,40]:Integer}, intvals: {}
  %2 = alloca { [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] }, align 8: {[-1]:Pointer}, intvals: {}
  %3 = bitcast { [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] }* %2 to i8*: {[-1]:Pointer}, intvals: {}
  %23 = fadd double %20, %22, !dbg !94: {[-1]:Float@double}, intvals: {}
  %.not67 = icmp ugt i64 %11, 2, !dbg !98: {[-1]:Integer}, intvals: {}
  %19 = load double addrspace(13)*, double addrspace(13)* addrspace(11)* %18, align 8, !dbg !90, !tbaa !80, !nonnull !4: {[-1]:Pointer}, intvals: {}
  %20 = load double, double addrspace(13)* %19, align 8, !dbg !90, !tbaa !82: {[-1]:Float@double}, intvals: {}
  %tiv = phi i64 [ 0, %L30.i.preheader ], [, %L30.i ]: {[-1]:Integer}, intvals: {0,}
  %.fca. = getelementptr inbounds { [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] }, { [1 x { { { [1 x [1 x i64]] } }, {} addrspace(10)* }] }* %2, i64 0, i32 0, i64 0, i32 0, i32 0, i32 0, i64 0, i64 0, !dbg !48: {[-1]:Pointer, [-1,0]:Integer, [-1,1]:Integer, [-1,2]:Integer, [-1,3]:Integer, [-1,4]:Integer, [-1,5]:Integer, [-1,6]:Integer, [-1,7]:Integer}, intvals: {}
Illegal updateAnalysis prev:{[-1]:Float@double} new: {[-1]:Pointer}
val:   %23 = fadd double %20, %22, !dbg !94 origin=  %value_phi.i = phi double [ %15, %L17.i ], [ %29, %L35.i ], [ 0.000000e+00, %entry ], [ %23, %L21.i ], [ %28, %julia__mapreduce_673_3591.inner.exit.loopexit ]
Assertion failed: (0 && "Performed illegal updateAnalysis"), function updateAnalysis, file /workspace/srcdir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp, line 586.

signal (6): Abort trap: 6
in expression starting at REPL[8]:1
__pthread_kill at /usr/lib/system/libsystem_kernel.dylib (unknown line)
Allocations: 54941224 (Pool: 54921697; Big: 19527); GC: 59
zsh: abort      julia
@wsmoses wsmoses self-assigned this Sep 4, 2021
wsmoses commented Sep 4, 2021

The compilation error is fixed by the latest commit pushed to #115.

This may require a new release of Enzyme_jll. (I haven't validated correctness yet, because MeasureTheory depends on the in-progress `jl_invoke` handling under the hood.)

wsmoses commented Sep 28, 2021

This now succeeds on my dev branch, and I believe on main as well, so I'm closing the issue.

@wsmoses wsmoses closed this as completed Sep 28, 2021
