error/crash when differentiating member function #1388

Closed
samuelpmishLLNL opened this issue Aug 25, 2023 · 5 comments

@samuelpmishLLNL
Collaborator

see: https://fwd.gymni.ch/L6n0Wk

Depending on the presence or absence of a member variable in the struct "Material", Enzyme emits a cryptic error message and dumps a stack trace. I also can't figure out how to call __enzyme_fwddiff on the operator() without getting an error message like:

error: Enzyme: Cannot cast __enzyme_autodiff primal argument 3, found   %0 = load i32, ptr @enzyme_const, align 4, !tbaa !6, type i32 - to arg 1 ptr
@wsmoses
Member

wsmoses commented Sep 19, 2023

Copying from Slack: the error you pasted above no longer happens (yay).

Instead the remaining issue is a type analysis problem.

In that minimal case, ironically, the problem is that it's so small that it carries no information that the input type is a double.
Basically, the entirety of that code is differentiating a memcpy,
and the correct memcpy derivative (in reverse mode) is different for float/double/etc.
Enzyme runs a type analysis to deduce these types
[e.g. if it sees a value loaded and then fadded or exp'd, it knows the memory containing it is double].
Here, since there's no other code, it can't tell.
We actually made significant progress recently on better errors for that, which you can see in the explorer now [if you force it not to use the cache]:
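
To illustrate why the element type matters, here is a minimal sketch of a reverse-mode rule for a memcpy of floating-point data. It is not Enzyme's implementation; `reverse_memcpy` and `demo` are hypothetical names made up for this example. The rule adds the adjoint of the destination into the adjoint of the source and then clears the destination adjoint, one element at a time, so the same shadow bytes give different results when read as double versus float.

```cpp
#include <cstddef>
#include <cstring>

// Hypothetical sketch, not Enzyme code.
// Reverse-mode rule for `memcpy(dst, src, nbytes)` when the bytes hold
// floating-point values of type T: d_src += d_dst, then d_dst = 0.
// The loop stride and the addition both depend on T.
template <typename T>
void reverse_memcpy(void* d_dst, void* d_src, std::size_t nbytes) {
    char* dst_bytes = static_cast<char*>(d_dst);
    char* src_bytes = static_cast<char*>(d_src);
    for (std::size_t off = 0; off + sizeof(T) <= nbytes; off += sizeof(T)) {
        T dst_adj, src_adj;
        // memcpy in and out of typed temporaries avoids aliasing problems.
        std::memcpy(&dst_adj, dst_bytes + off, sizeof(T));
        std::memcpy(&src_adj, src_bytes + off, sizeof(T));
        src_adj += dst_adj;  // accumulate the adjoint into the source shadow
        dst_adj = T(0);      // the destination shadow has been consumed
        std::memcpy(src_bytes + off, &src_adj, sizeof(T));
        std::memcpy(dst_bytes + off, &dst_adj, sizeof(T));
    }
}

// Applies the rule to one double's worth of shadow memory while treating
// the bytes as elements of type T, and returns the resulting source shadow.
template <typename T>
double demo(double d_dst_value, double d_src_value) {
    double d_dst = d_dst_value;
    double d_src = d_src_value;
    reverse_memcpy<T>(&d_dst, &d_src, sizeof(double));
    return d_src;
}
```

With `T = double` the shadows 1.0 and 2.0 accumulate to 3.0; with `T = float` the same bytes are added as two 4-byte halves and the result is not 3.0, which is the kind of silent error a wrong type guess would cause.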

<source>:17:12: error: Enzyme: Cannot deduce type of copy   tail call void @llvm.memcpy.p0.p0.i64(ptr noundef nonnull align 8 dereferenceable(72) %0, ptr noundef nonnull align 8 dereferenceable(72) %2, i64 72, i1 false) #12, !dbg !67, !tbaa.struct !68
<analysis>
ptr %0: {[-1]:Pointer}, intvals: {}
double %1: {[-1]:Float@double}, intvals: {}
ptr %2: {[-1]:Pointer}, intvals: {}
ptr %3: {[-1]:Pointer}, intvals: {}
</analysis>
    return du_dX;
           ^
1 error generated.
Compiler returned: 1

Since Enzyme cannot prove the type from the information available, it throws an error rather than risk incorrectness. You can tell Enzyme to make its best guess, or even tell it the type explicitly.
Here, for example: https://fwd.gymni.ch/GimdSi

@wsmoses
Member

wsmoses commented Sep 19, 2023

Looks like the TBAA that was emitted is kind of worthless (see below); otherwise we could have used it to deduce the type of the memcpy.

https://fwd.gymni.ch/xr0I7u

; ModuleID = '<source>'
source_filename = "<source>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

%struct.mat3 = type { [3 x [3 x double]] }
%struct.vec3 = type { [3 x double] }

$_Z8wrapper1I8MaterialJ4mat34vec3EEDaT_DpT0_ = comdat any

$_Z8wrapper2I8MaterialJ4mat34vec3EEDaT_DpRKT0_ = comdat any

@enzyme_dup = dso_local local_unnamed_addr global i32 0, align 4
@enzyme_dupnoneed = dso_local local_unnamed_addr global i32 0, align 4
@enzyme_out = dso_local local_unnamed_addr global i32 0, align 4
@enzyme_const = dso_local local_unnamed_addr global i32 0, align 4
@__const.main.du_dX = private unnamed_addr constant %struct.mat3 { [3 x [3 x double]] [[3 x double] [double 1.000000e+00, double 2.000000e+00, double 3.000000e+00], [3 x double] [double 4.000000e+00, double 5.000000e+00, double 6.000000e+00], [3 x double] [double 7.000000e+00, double 8.000000e+00, double 9.000000e+00]] }, align 8
@__const.main.k = private unnamed_addr constant %struct.vec3 { [3 x double] [double 0.000000e+00, double 1.000000e+00, double 2.000000e+00] }, align 8
@__const.main.dk = private unnamed_addr constant %struct.vec3 { [3 x double] [double 0.000000e+00, double 1.000000e+00, double 0.000000e+00] }, align 8

; Function Attrs: norecurse uwtable mustprogress
define dso_local i32 @main() local_unnamed_addr #0 {
  %1 = alloca %struct.mat3, align 8
  %2 = alloca %struct.vec3, align 8
  %3 = alloca %struct.vec3, align 8
  %4 = alloca %struct.mat3, align 8
  %5 = alloca %struct.mat3, align 8
  %6 = bitcast %struct.mat3* %1 to i8*
  call void @llvm.lifetime.start.p0i8(i64 72, i8* nonnull %6) #5
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* nonnull align 8 dereferenceable(72) %6, i8* nonnull align 8 dereferenceable(72) bitcast (%struct.mat3* @__const.main.du_dX to i8*), i64 72, i1 false)
  %7 = bitcast %struct.vec3* %2 to i8*
  call void @llvm.lifetime.start.p0i8(i64 24, i8* nonnull %7) #5
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* nonnull align 8 dereferenceable(24) %7, i8* nonnull align 8 dereferenceable(24) bitcast (%struct.vec3* @__const.main.k to i8*), i64 24, i1 false)
  %8 = bitcast %struct.vec3* %3 to i8*
  call void @llvm.lifetime.start.p0i8(i64 24, i8* nonnull %8) #5
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* nonnull align 8 dereferenceable(24) %8, i8* nonnull align 8 dereferenceable(24) bitcast (%struct.vec3* @__const.main.dk to i8*), i64 24, i1 false)
  %9 = bitcast %struct.mat3* %4 to i8*
  call void @llvm.lifetime.start.p0i8(i64 72, i8* nonnull %9) #5
  %10 = load i32, i32* @enzyme_const, align 4, !tbaa !3
  %11 = load i32, i32* @enzyme_dup, align 4, !tbaa !3
  call void @_Z16__enzyme_fwddiffI4mat3Ji8MaterialiS0_i4vec3S2_EET_PvDpT0_(%struct.mat3* nonnull sret(%struct.mat3) align 8 %4, i8* bitcast (void (%struct.mat3*, double, %struct.mat3*, %struct.vec3*)* @_Z8wrapper1I8MaterialJ4mat34vec3EEDaT_DpT0_ to i8*), i32 %10, double 0.000000e+00, i32 %10, %struct.mat3* nonnull byval(%struct.mat3) align 8 @__const.main.du_dX, i32 %11, %struct.vec3* nonnull byval(%struct.vec3) align 8 @__const.main.k, %struct.vec3* nonnull byval(%struct.vec3) align 8 @__const.main.dk)
  %12 = bitcast %struct.mat3* %5 to i8*
  call void @llvm.lifetime.start.p0i8(i64 72, i8* nonnull %12) #5
  %13 = load i32, i32* @enzyme_const, align 4, !tbaa !3
  %14 = load i32, i32* @enzyme_dup, align 4, !tbaa !3
  call void @_Z16__enzyme_fwddiffI4mat3Ji8MaterialiPS0_iP4vec3S4_EET_PvDpT0_(%struct.mat3* nonnull sret(%struct.mat3) align 8 %5, i8* bitcast (void (%struct.mat3*, double, %struct.mat3*, %struct.vec3*)* @_Z8wrapper2I8MaterialJ4mat34vec3EEDaT_DpRKT0_ to i8*), i32 %13, double 0.000000e+00, i32 %13, %struct.mat3* nonnull %1, i32 %14, %struct.vec3* nonnull %2, %struct.vec3* nonnull %3)
  call void @llvm.lifetime.end.p0i8(i64 72, i8* nonnull %12) #5
  call void @llvm.lifetime.end.p0i8(i64 72, i8* nonnull %9) #5
  call void @llvm.lifetime.end.p0i8(i64 24, i8* nonnull %8) #5
  call void @llvm.lifetime.end.p0i8(i64 24, i8* nonnull %7) #5
  call void @llvm.lifetime.end.p0i8(i64 72, i8* nonnull %6) #5
  ret i32 0
}

; Function Attrs: argmemonly nofree nosync nounwind willreturn
declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #1

; Function Attrs: argmemonly nofree nosync nounwind willreturn
declare void @llvm.memcpy.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64, i1 immarg) #1

declare dso_local void @_Z16__enzyme_fwddiffI4mat3Ji8MaterialiS0_i4vec3S2_EET_PvDpT0_(%struct.mat3* sret(%struct.mat3) align 8, i8*, i32, double, i32, %struct.mat3* byval(%struct.mat3) align 8, i32, %struct.vec3* byval(%struct.vec3) align 8, %struct.vec3* byval(%struct.vec3) align 8) local_unnamed_addr #2

; Function Attrs: uwtable willreturn mustprogress
define linkonce_odr dso_local void @_Z8wrapper1I8MaterialJ4mat34vec3EEDaT_DpT0_(%struct.mat3* noalias sret(%struct.mat3) align 8 %0, double %1, %struct.mat3* byval(%struct.mat3) align 8 %2, %struct.vec3* byval(%struct.vec3) align 8 %3) #3 comdat {
  %5 = bitcast %struct.mat3* %0 to i8*
  %6 = bitcast %struct.mat3* %2 to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* nonnull align 8 dereferenceable(72) %5, i8* nonnull align 8 dereferenceable(72) %6, i64 72, i1 false) #5, !tbaa.struct !7
  ret void
}

declare dso_local void @_Z16__enzyme_fwddiffI4mat3Ji8MaterialiPS0_iP4vec3S4_EET_PvDpT0_(%struct.mat3* sret(%struct.mat3) align 8, i8*, i32, double, i32, %struct.mat3*, i32, %struct.vec3*, %struct.vec3*) local_unnamed_addr #2

; Function Attrs: nounwind uwtable willreturn mustprogress
define linkonce_odr dso_local void @_Z8wrapper2I8MaterialJ4mat34vec3EEDaT_DpRKT0_(%struct.mat3* noalias sret(%struct.mat3) align 8 %0, double %1, %struct.mat3* nonnull align 8 dereferenceable(72) %2, %struct.vec3* nonnull align 8 dereferenceable(24) %3) #4 comdat {
  %5 = bitcast %struct.mat3* %0 to i8*
  %6 = bitcast %struct.mat3* %2 to i8*
  tail call void @llvm.memcpy.p0i8.p0i8.i64(i8* nonnull align 8 dereferenceable(72) %5, i8* nonnull align 8 dereferenceable(72) %6, i64 72, i1 false) #5, !tbaa.struct !7
  ret void
}

; Function Attrs: argmemonly nofree nosync nounwind willreturn
declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #1

attributes #0 = { norecurse uwtable mustprogress "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #1 = { argmemonly nofree nosync nounwind willreturn }
attributes #2 = { "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #3 = { uwtable willreturn mustprogress "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #4 = { nounwind uwtable willreturn mustprogress "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #5 = { nounwind }

!llvm.linker.options = !{}
!llvm.module.flags = !{!0, !1}
!llvm.ident = !{!2}

!0 = !{i32 7, !"Dwarf Version", i32 4}
!1 = !{i32 1, !"wchar_size", i32 4}
!2 = !{!"clang version 12.0.1 (https://github.com/llvm/llvm-project.git fed41342a82f5a3a9201819a82bf7a48313e296b)"}
!3 = !{!4, !4, i64 0}
!4 = !{!"int", !5, i64 0}
!5 = !{!"omnipotent char", !6, i64 0}
!6 = !{!"Simple C++ TBAA"}
!7 = !{i64 0, i64 72, !8}
!8 = !{!5, !5, i64 0}

@wsmoses
Member

wsmoses commented Sep 19, 2023

The memcpy TBAA coming straight out of clang already looks worthless:

; ModuleID = '<source>'
source_filename = "<source>"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
target triple = "x86_64-unknown-linux-gnu"

%struct.mat3 = type { [3 x [3 x double]] }
%struct.vec3 = type { [3 x double] }
%struct.Material = type { double }

$_Z8wrapper1I8MaterialJ4mat34vec3EEDaT_DpT0_ = comdat any

$_Z8wrapper2I8MaterialJ4mat34vec3EEDaT_DpRKT0_ = comdat any

$_ZN8MaterialclERK4mat3RK4vec3 = comdat any

@enzyme_dup = dso_local global i32 0, align 4
@enzyme_dupnoneed = dso_local global i32 0, align 4
@enzyme_out = dso_local global i32 0, align 4
@enzyme_const = dso_local global i32 0, align 4
@__const.main.du_dX = private unnamed_addr constant %struct.mat3 { [3 x [3 x double]] [[3 x double] [double 1.000000e+00, double 2.000000e+00, double 3.000000e+00], [3 x double] [double 4.000000e+00, double 5.000000e+00, double 6.000000e+00], [3 x double] [double 7.000000e+00, double 8.000000e+00, double 9.000000e+00]] }, align 8
@__const.main.k = private unnamed_addr constant %struct.vec3 { [3 x double] [double 0.000000e+00, double 1.000000e+00, double 2.000000e+00] }, align 8
@__const.main.dk = private unnamed_addr constant %struct.vec3 { [3 x double] [double 0.000000e+00, double 1.000000e+00, double 0.000000e+00] }, align 8

; Function Attrs: norecurse uwtable mustprogress
define dso_local i32 @main() #0 {
  %1 = alloca %struct.Material, align 8
  %2 = alloca %struct.mat3, align 8
  %3 = alloca %struct.vec3, align 8
  %4 = alloca %struct.vec3, align 8
  %5 = alloca %struct.mat3, align 8
  %6 = alloca %struct.Material, align 8
  %7 = alloca %struct.mat3, align 8
  %8 = alloca %struct.vec3, align 8
  %9 = alloca %struct.vec3, align 8
  %10 = alloca %struct.mat3, align 8
  %11 = alloca %struct.Material, align 8
  %12 = bitcast %struct.Material* %1 to i8*
  call void @llvm.lifetime.start.p0i8(i64 8, i8* %12) #6
  %13 = bitcast %struct.Material* %1 to i8*
  call void @llvm.memset.p0i8.i64(i8* align 8 %13, i8 0, i64 8, i1 false)
  %14 = bitcast %struct.mat3* %2 to i8*
  call void @llvm.lifetime.start.p0i8(i64 72, i8* %14) #6
  %15 = bitcast %struct.mat3* %2 to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %15, i8* align 8 bitcast (%struct.mat3* @__const.main.du_dX to i8*), i64 72, i1 false)
  %16 = bitcast %struct.vec3* %3 to i8*
  call void @llvm.lifetime.start.p0i8(i64 24, i8* %16) #6
  %17 = bitcast %struct.vec3* %3 to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %17, i8* align 8 bitcast (%struct.vec3* @__const.main.k to i8*), i64 24, i1 false)
  %18 = bitcast %struct.vec3* %4 to i8*
  call void @llvm.lifetime.start.p0i8(i64 24, i8* %18) #6
  %19 = bitcast %struct.vec3* %4 to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %19, i8* align 8 bitcast (%struct.vec3* @__const.main.dk to i8*), i64 24, i1 false)
  %20 = bitcast %struct.mat3* %5 to i8*
  call void @llvm.lifetime.start.p0i8(i64 72, i8* %20) #6
  %21 = load i32, i32* @enzyme_const, align 4, !tbaa !3
  %22 = bitcast %struct.Material* %6 to i8*
  %23 = bitcast %struct.Material* %1 to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %22, i8* align 8 %23, i64 8, i1 false), !tbaa.struct !7
  %24 = load i32, i32* @enzyme_const, align 4, !tbaa !3
  %25 = bitcast %struct.mat3* %7 to i8*
  %26 = bitcast %struct.mat3* %2 to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %25, i8* align 8 %26, i64 72, i1 false), !tbaa.struct !10
  %27 = load i32, i32* @enzyme_dup, align 4, !tbaa !3
  %28 = bitcast %struct.vec3* %8 to i8*
  %29 = bitcast %struct.vec3* %3 to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %28, i8* align 8 %29, i64 24, i1 false), !tbaa.struct !12
  %30 = bitcast %struct.vec3* %9 to i8*
  %31 = bitcast %struct.vec3* %4 to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %30, i8* align 8 %31, i64 24, i1 false), !tbaa.struct !12
  %32 = getelementptr inbounds %struct.Material, %struct.Material* %6, i32 0, i32 0
  %33 = load double, double* %32, align 8
  call void @_Z16__enzyme_fwddiffI4mat3Ji8MaterialiS0_i4vec3S2_EET_PvDpT0_(%struct.mat3* sret(%struct.mat3) align 8 %5, i8* bitcast (void (%struct.mat3*, double, %struct.mat3*, %struct.vec3*)* @_Z8wrapper1I8MaterialJ4mat34vec3EEDaT_DpT0_ to i8*), i32 %21, double %33, i32 %24, %struct.mat3* byval(%struct.mat3) align 8 %7, i32 %27, %struct.vec3* byval(%struct.vec3) align 8 %8, %struct.vec3* byval(%struct.vec3) align 8 %9)
  %34 = bitcast %struct.mat3* %10 to i8*
  call void @llvm.lifetime.start.p0i8(i64 72, i8* %34) #6
  %35 = load i32, i32* @enzyme_const, align 4, !tbaa !3
  %36 = bitcast %struct.Material* %11 to i8*
  %37 = bitcast %struct.Material* %1 to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %36, i8* align 8 %37, i64 8, i1 false), !tbaa.struct !7
  %38 = load i32, i32* @enzyme_const, align 4, !tbaa !3
  %39 = load i32, i32* @enzyme_dup, align 4, !tbaa !3
  %40 = getelementptr inbounds %struct.Material, %struct.Material* %11, i32 0, i32 0
  %41 = load double, double* %40, align 8
  call void @_Z16__enzyme_fwddiffI4mat3Ji8MaterialiPS0_iP4vec3S4_EET_PvDpT0_(%struct.mat3* sret(%struct.mat3) align 8 %10, i8* bitcast (void (%struct.mat3*, double, %struct.mat3*, %struct.vec3*)* @_Z8wrapper2I8MaterialJ4mat34vec3EEDaT_DpRKT0_ to i8*), i32 %35, double %41, i32 %38, %struct.mat3* %2, i32 %39, %struct.vec3* %3, %struct.vec3* %4)
  %42 = bitcast %struct.mat3* %10 to i8*
  call void @llvm.lifetime.end.p0i8(i64 72, i8* %42) #6
  %43 = bitcast %struct.mat3* %5 to i8*
  call void @llvm.lifetime.end.p0i8(i64 72, i8* %43) #6
  %44 = bitcast %struct.vec3* %4 to i8*
  call void @llvm.lifetime.end.p0i8(i64 24, i8* %44) #6
  %45 = bitcast %struct.vec3* %3 to i8*
  call void @llvm.lifetime.end.p0i8(i64 24, i8* %45) #6
  %46 = bitcast %struct.mat3* %2 to i8*
  call void @llvm.lifetime.end.p0i8(i64 72, i8* %46) #6
  %47 = bitcast %struct.Material* %1 to i8*
  call void @llvm.lifetime.end.p0i8(i64 8, i8* %47) #6
  ret i32 0
}

; Function Attrs: argmemonly nofree nosync nounwind willreturn
declare void @llvm.lifetime.start.p0i8(i64 immarg, i8* nocapture) #1

; Function Attrs: argmemonly nofree nosync nounwind willreturn writeonly
declare void @llvm.memset.p0i8.i64(i8* nocapture writeonly, i8, i64, i1 immarg) #2

; Function Attrs: argmemonly nofree nosync nounwind willreturn
declare void @llvm.memcpy.p0i8.p0i8.i64(i8* noalias nocapture writeonly, i8* noalias nocapture readonly, i64, i1 immarg) #1

declare dso_local void @_Z16__enzyme_fwddiffI4mat3Ji8MaterialiS0_i4vec3S2_EET_PvDpT0_(%struct.mat3* sret(%struct.mat3) align 8, i8*, i32, double, i32, %struct.mat3* byval(%struct.mat3) align 8, i32, %struct.vec3* byval(%struct.vec3) align 8, %struct.vec3* byval(%struct.vec3) align 8) #3

; Function Attrs: uwtable mustprogress
define linkonce_odr dso_local void @_Z8wrapper1I8MaterialJ4mat34vec3EEDaT_DpT0_(%struct.mat3* noalias sret(%struct.mat3) align 8 %0, double %1, %struct.mat3* byval(%struct.mat3) align 8 %2, %struct.vec3* byval(%struct.vec3) align 8 %3) #4 comdat {
  %5 = alloca %struct.Material, align 8
  %6 = getelementptr inbounds %struct.Material, %struct.Material* %5, i32 0, i32 0
  store double %1, double* %6, align 8
  call void @_ZN8MaterialclERK4mat3RK4vec3(%struct.mat3* sret(%struct.mat3) align 8 %0, %struct.Material* nonnull dereferenceable(8) %5, %struct.mat3* nonnull align 8 dereferenceable(72) %2, %struct.vec3* nonnull align 8 dereferenceable(24) %3)
  ret void
}

declare dso_local void @_Z16__enzyme_fwddiffI4mat3Ji8MaterialiPS0_iP4vec3S4_EET_PvDpT0_(%struct.mat3* sret(%struct.mat3) align 8, i8*, i32, double, i32, %struct.mat3*, i32, %struct.vec3*, %struct.vec3*) #3

; Function Attrs: nounwind uwtable mustprogress
define linkonce_odr dso_local void @_Z8wrapper2I8MaterialJ4mat34vec3EEDaT_DpRKT0_(%struct.mat3* noalias sret(%struct.mat3) align 8 %0, double %1, %struct.mat3* nonnull align 8 dereferenceable(72) %2, %struct.vec3* nonnull align 8 dereferenceable(24) %3) #5 comdat {
  %5 = alloca %struct.Material, align 8
  %6 = alloca %struct.mat3*, align 8
  %7 = alloca %struct.vec3*, align 8
  %8 = getelementptr inbounds %struct.Material, %struct.Material* %5, i32 0, i32 0
  store double %1, double* %8, align 8
  store %struct.mat3* %2, %struct.mat3** %6, align 8, !tbaa !13
  store %struct.vec3* %3, %struct.vec3** %7, align 8, !tbaa !13
  %9 = load %struct.mat3*, %struct.mat3** %6, align 8, !tbaa !13
  %10 = load %struct.vec3*, %struct.vec3** %7, align 8, !tbaa !13
  call void @_ZN8MaterialclERK4mat3RK4vec3(%struct.mat3* sret(%struct.mat3) align 8 %0, %struct.Material* nonnull dereferenceable(8) %5, %struct.mat3* nonnull align 8 dereferenceable(72) %9, %struct.vec3* nonnull align 8 dereferenceable(24) %10)
  ret void
}

; Function Attrs: argmemonly nofree nosync nounwind willreturn
declare void @llvm.lifetime.end.p0i8(i64 immarg, i8* nocapture) #1

; Function Attrs: nounwind uwtable mustprogress
define linkonce_odr dso_local void @_ZN8MaterialclERK4mat3RK4vec3(%struct.mat3* noalias sret(%struct.mat3) align 8 %0, %struct.Material* nonnull dereferenceable(8) %1, %struct.mat3* nonnull align 8 dereferenceable(72) %2, %struct.vec3* nonnull align 8 dereferenceable(24) %3) #5 comdat align 2 {
  %5 = alloca %struct.Material*, align 8
  %6 = alloca %struct.mat3*, align 8
  %7 = alloca %struct.vec3*, align 8
  store %struct.Material* %1, %struct.Material** %5, align 8, !tbaa !13
  store %struct.mat3* %2, %struct.mat3** %6, align 8, !tbaa !13
  store %struct.vec3* %3, %struct.vec3** %7, align 8, !tbaa !13
  %8 = load %struct.Material*, %struct.Material** %5, align 8
  %9 = load %struct.mat3*, %struct.mat3** %6, align 8, !tbaa !13
  %10 = bitcast %struct.mat3* %0 to i8*
  %11 = bitcast %struct.mat3* %9 to i8*
  call void @llvm.memcpy.p0i8.p0i8.i64(i8* align 8 %10, i8* align 8 %11, i64 72, i1 false), !tbaa.struct !10
  ret void
}

attributes #0 = { norecurse uwtable mustprogress "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #1 = { argmemonly nofree nosync nounwind willreturn }
attributes #2 = { argmemonly nofree nosync nounwind willreturn writeonly }
attributes #3 = { "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #4 = { uwtable mustprogress "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #5 = { nounwind uwtable mustprogress "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
attributes #6 = { nounwind }

!llvm.linker.options = !{}
!llvm.module.flags = !{!0, !1}
!llvm.ident = !{!2}

!0 = !{i32 7, !"Dwarf Version", i32 4}
!1 = !{i32 1, !"wchar_size", i32 4}
!2 = !{!"clang version 12.0.1 (https://github.com/llvm/llvm-project.git fed41342a82f5a3a9201819a82bf7a48313e296b)"}
!3 = !{!4, !4, i64 0}
!4 = !{!"int", !5, i64 0}
!5 = !{!"omnipotent char", !6, i64 0}
!6 = !{!"Simple C++ TBAA"}
!7 = !{i64 0, i64 8, !8}
!8 = !{!9, !9, i64 0}
!9 = !{!"double", !5, i64 0}
!10 = !{i64 0, i64 72, !11}
!11 = !{!5, !5, i64 0}
!12 = !{i64 0, i64 24, !11}
!13 = !{!14, !14, i64 0}
!14 = !{!"any pointer", !5, i64 0}

My guess is that clang's sret memcpy lowering doesn't put the full type info into its !tbaa.struct (!10 here, which only points at "omnipotent char") that it should.

cc @jdoerfert if you know where that is offhand for fixing in upstream LLVM
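
For orientation, the IR above is consistent with C++ along the following lines. This is a hedged reconstruction from the mangled names and struct layouts in the dump, not the original source behind the Compiler Explorer links; the member names `data` and `unused` are assumptions, and the Enzyme calls themselves are left out so the snippet compiles on its own.

```cpp
// Hypothetical reconstruction, inferred from the IR; not the original source.
struct mat3 { double data[3][3]; };  // %struct.mat3 = type { [3 x [3 x double]] }
struct vec3 { double data[3]; };     // %struct.vec3 = type { [3 x double] }

struct Material {
    double unused;  // %struct.Material = type { double }

    // Returns a copy of its first argument. Clang lowers this return to the
    // 72-byte memcpy into the sret slot seen in _ZN8MaterialclERK4mat3RK4vec3.
    mat3 operator()(const mat3& du_dX, const vec3& /*k*/) { return du_dX; }
};

// _Z8wrapper1I8MaterialJ4mat34vec3EEDaT_DpT0_ : arguments taken by value
template <typename T, typename... Args>
auto wrapper1(T f, Args... args) { return f(args...); }

// _Z8wrapper2I8MaterialJ4mat34vec3EEDaT_DpRKT0_ : arguments taken by const reference
template <typename T, typename... Args>
auto wrapper2(T f, const Args&... args) { return f(args...); }
```

The only data movement in either wrapper is that struct copy, so nothing in the function body ever touches the bytes as doubles.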

@samuelpmishLLNL
Collaborator Author

Thanks for the update! The original Compiler Explorer example was too simple; here's a more representative one:
https://fwd.gymni.ch/BsL8dM

@samuelpmishLLNL
Collaborator Author

@wsmoses explained to me that the underlying cause here is a function object with no member variables being passed to __enzyme_XXXdiff(...). In this case:

    mat3 output = __enzyme_fwddiff<mat3>((void*)wrapper<Material, mat3, vec3>, 
        enzyme_const, mat, // mat has no member variables
        enzyme_const, &du_dX, 
        enzyme_dup, &k, &dk);

the argument mat is either optimized away or omitted because it has no member variables, so Enzyme effectively sees:

    mat3 output = __enzyme_fwddiff<mat3>((void*)wrapper<Material, mat3, vec3>, 
        enzyme_const, // nothing here any more
        enzyme_const, &du_dX, 
        enzyme_dup, &k, &dk);

which Enzyme rejects as a malformed call, because the first enzyme_const annotation has no argument to apply to. So, the fix is to omit both the function object and the associated enzyme annotation:

    mat3 output = __enzyme_fwddiff<mat3>((void*)wrapper<Material, mat3, vec3>, 
        // enzyme_const, mat
        enzyme_const, &du_dX, 
        enzyme_dup, &k, &dk);

That way, there won't accidentally be two enzyme annotations back-to-back. See: https://fwd.gymni.ch/lFgkQd
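
One way to spot this case ahead of time is a compile-time check on the function object type. The sketch below assumes C++17; `StatelessMaterial`, `StatefulMaterial` and `is_stateless_functor` are hypothetical names for illustration and are not part of Enzyme. Whether an empty object is actually dropped from the call depends on the ABI and the compiler, so treat the check as a hint rather than a guarantee.

```cpp
#include <type_traits>

struct StatelessMaterial {};            // hypothetical: no member variables
struct StatefulMaterial { double E; };  // hypothetical: one member variable

// True when T has no non-static data members, i.e. the situation described
// above in which the object may not show up as an argument in the lowered call.
template <typename T>
constexpr bool is_stateless_functor = std::is_empty_v<T>;

// For a stateless functor, leave out both the object and its enzyme_const
// annotation; for a stateful one, pass both.
static_assert(is_stateless_functor<StatelessMaterial>);
static_assert(!is_stateless_functor<StatefulMaterial>);
```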
