Slow Broadcasting (compared to Zygote) #1434

Open
avik-pal opened this issue May 12, 2024 · 4 comments

Comments

@avik-pal
Contributor

avik-pal commented May 12, 2024

```julia
using NNlib, Enzyme, Zygote, BenchmarkTools

gelu_act(x) = sum(abs2, gelu.(x))

x = randn(Float32, 32, 32, 1, 32)

@btime Enzyme.gradient(Reverse, $gelu_act, $x); # 1.298 ms (65 allocations: 386.95 KiB)

@btime Zygote.gradient($gelu_act, $x); # 735.499 μs (26 allocations: 384.75 KiB)
```

This comparison might be somewhat unfair, because `gelu` has an `rrule` defined, which Zygote can use instead of differentiating through `gelu` itself.
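For reference, the derivative that a hand-written rule for `gelu` has to supply is cheap to state. The sketch below assumes the tanh approximation of GELU, x/2 · (1 + tanh(√(2/π)(x + 0.044715x³))). `gelu_tanh`, `dgelu_tanh`, and `grad_gelu_act` are hypothetical helpers written for illustration. They are not NNlib's implementation, and NNlib's `gelu` may differ in version-specific details:

```julia
# Tanh-approximation GELU (assumed form, not NNlib's source).
function gelu_tanh(x::AbstractFloat)
    k = oftype(x, sqrt(2 / π))   # √(2/π), kept in the input's precision
    a = oftype(x, 0.044715)      # cubic coefficient of the approximation
    return x / 2 * (1 + tanh(k * (x + a * x^3)))
end

# Hand-derived derivative: with u = k(x + a x³) and t = tanh(u),
#   d/dx [x/2 * (1 + t)] = (1 + t)/2 + x/2 * (1 - t²) * k * (1 + 3a x²)
function dgelu_tanh(x::AbstractFloat)
    k = oftype(x, sqrt(2 / π))
    a = oftype(x, 0.044715)
    t = tanh(k * (x + a * x^3))
    return (1 + t) / 2 + x / 2 * (1 - t^2) * k * (1 + 3 * a * x^2)
end

# Gradient of sum(abs2, gelu.(x)): elementwise 2 * g(x) * g'(x).
grad_gelu_act(x::AbstractArray{<:AbstractFloat}) = 2 .* gelu_tanh.(x) .* dgelu_tanh.(x)
```

Applied elementwise, this gives the whole gradient of `gelu_act`, because d/dx g(x)² = 2g(x)g′(x). If the formula matches the installed NNlib, you can check whether both ADs agree by comparing `grad_gelu_act(x)` against `Zygote.gradient(gelu_act, x)[1]` with `isapprox`.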

@wsmoses
Member

wsmoses commented May 13, 2024

```julia
using BenchmarkTools, Enzyme, Zygote

gelu_act(x) = sum(abs2, sin.(x))  # same name as above, but now uses sin

x = randn(Float32, 32, 32, 1, 32);

Enzyme.gradient(Reverse, gelu_act, x);  # first call compiles the gradient

@btime Enzyme.gradient(Reverse, $gelu_act, $x);

@btime Zygote.gradient($gelu_act, $x)
```
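With `sin`, no rule-based shortcut is involved, and the exact gradient is simple enough to write by hand: d/dx sin(x)² = 2 sin(x) cos(x). Below is a sketch of such a reference loop. The names `sin_act_grad!` and `sin_act_grad` are hypothetical and exist only for this example. The loop can serve to check the AD results and as a rough lower bound on the cost of a fused gradient kernel:

```julia
# Reference gradient of f(x) = sum(abs2, sin.(x)):
#   ∂f/∂xᵢ = 2 sin(xᵢ) cos(xᵢ)  (= sin(2xᵢ))
function sin_act_grad!(dx::AbstractArray{T}, x::AbstractArray{T}) where {T<:AbstractFloat}
    # eachindex(dx, x) errors if dx and x do not share the same index range
    @inbounds for i in eachindex(dx, x)
        s, c = sincos(x[i])
        dx[i] = 2 * s * c
    end
    return dx
end

# Allocating wrapper: one output array, one fused loop.
sin_act_grad(x::AbstractArray{<:AbstractFloat}) = sin_act_grad!(similar(x), x)
```

For the repro above, `isapprox(sin_act_grad(x), Zygote.gradient(gelu_act, x)[1])` is expected to hold. `@btime sin_act_grad($x)` then gives a baseline for the two AD timings.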

@wsmoses
Member

wsmoses commented May 14, 2024

So the code generated here by the broadcast, before any Enzyme AD is applied, is actually quite awful. I wrote some primitive optimization passes to do a bit of cleanup (which may fix the runtime activity issue), but the indexing pattern is still really bad.

@vchuravy, any ideas what's happening here (besides the array presumably having more than 3 dimensions now, so Julia doesn't specialize)?

After simplification:
; Function Attrs: mustprogress willreturn
define "enzyme_type"="{[-1]:Float@float}" float @preprocess_julia_gelu_act_1468({} addrspace(10)* nocapture noundef nonnull readonly align 16 dereferenceable(40) "enzyme_type"="{[-1]:Pointer, [-1,0]:Pointer, [-1,0,-1]:Float@float, [-1,8]:Integer, [-1,9]:Integer, [-1,10]:Integer, [-1,11]:Integer, [-1,12]:Integer, [-1,13]:Integer, [-1,14]:Integer, [-1,15]:Integer, [-1,16]:Integer, [-1,17]:Integer, [-1,18]:Integer, [-1,19]:Integer, [-1,20]:Integer, [-1,21]:Integer, [-1,22]:Integer, [-1,23]:Integer, [-1,24]:Integer, [-1,25]:Integer, [-1,26]:Integer, [-1,27]:Integer, [-1,28]:Integer, [-1,29]:Integer, [-1,30]:Integer, [-1,31]:Integer, [-1,32]:Integer, [-1,33]:Integer, [-1,34]:Integer, [-1,35]:Integer, [-1,36]:Integer, [-1,37]:Integer, [-1,38]:Integer, [-1,39]:Integer}" "enzymejl_parmtype"="131622087272080" "enzymejl_parmtype_ref"="2" %0) local_unnamed_addr #21 !dbg !1226 {
top:
  %newstruct14 = alloca [4 x [1 x i64]], align 8
  %newstruct133 = alloca [4 x [1 x i64]], align 8
  %1 = call {}*** @julia.get_pgcstack() #22
  %current_task1161 = getelementptr inbounds {}**, {}*** %1, i64 -14
  %current_task1 = bitcast {}*** %current_task1161 to {}**
  %ptls_field162 = getelementptr inbounds {}**, {}*** %1, i64 2
  %2 = bitcast {}*** %ptls_field162 to i64***
  %ptls_load163164 = load i64**, i64*** %2, align 8, !tbaa !19
  %3 = getelementptr inbounds i64*, i64** %ptls_load163164, i64 2
  %safepoint = load i64*, i64** %3, align 8, !tbaa !23
  fence syncscope("singlethread") seq_cst
  call void @julia.safepoint(i64* %safepoint) #22, !dbg !1227
  fence syncscope("singlethread") seq_cst
  %4 = addrspacecast {} addrspace(10)* %0 to {} addrspace(10)* addrspace(11)*, !dbg !1228
  %arraysize_ptr = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %4, i64 3, !dbg !1228
  %5 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr to i64 addrspace(11)*, !dbg !1228
  %arraysize = load i64, i64 addrspace(11)* %5, align 8, !dbg !1228, !tbaa !23, !range !44, !alias.scope !45, !noalias !48
  %arraysize_ptr2 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %4, i64 4, !dbg !1228
  %6 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr2 to i64 addrspace(11)*, !dbg !1228
  %arraysize3 = load i64, i64 addrspace(11)* %6, align 16, !dbg !1228, !tbaa !23, !range !44, !alias.scope !45, !noalias !48
  %arraysize_ptr4 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %4, i64 5, !dbg !1228
  %7 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr4 to i64 addrspace(11)*, !dbg !1228
  %arraysize5 = load i64, i64 addrspace(11)* %7, align 8, !dbg !1228, !tbaa !23, !range !44, !alias.scope !45, !noalias !48
  %arraysize_ptr6 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %4, i64 6, !dbg !1228
  %8 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr6 to i64 addrspace(11)*, !dbg !1228
  %arraysize7 = load i64, i64 addrspace(11)* %8, align 16, !dbg !1228, !tbaa !23, !range !44, !alias.scope !45, !noalias !48
  %9 = getelementptr inbounds [4 x [1 x i64]], [4 x [1 x i64]]* %newstruct14, i64 0, i64 1, i64 0, !dbg !1237
  store i64 %arraysize3, i64* %9, align 8, !dbg !1237, !tbaa !149, !alias.scope !151, !noalias !1242
  %10 = getelementptr inbounds [4 x [1 x i64]], [4 x [1 x i64]]* %newstruct14, i64 0, i64 2, i64 0, !dbg !1237
  store i64 %arraysize5, i64* %10, align 8, !dbg !1237, !tbaa !149, !alias.scope !151, !noalias !1242
  %11 = getelementptr inbounds [4 x [1 x i64]], [4 x [1 x i64]]* %newstruct14, i64 0, i64 3, i64 0, !dbg !1237
  store i64 %arraysize7, i64* %11, align 8, !dbg !1237, !tbaa !149, !alias.scope !151, !noalias !1242
  %memcpy_refined_dst = getelementptr inbounds [4 x [1 x i64]], [4 x [1 x i64]]* %newstruct14, i64 0, i64 0, i64 0, !dbg !1241
  store i64 %arraysize, i64* %memcpy_refined_dst, align 8, !dbg !1241, !tbaa !149, !alias.scope !151, !noalias !1242
  %box = call noalias nonnull dereferenceable(32) "enzyme_inactive" {} addrspace(10)* @julia.gc_alloc_obj({}** nonnull %current_task1, i64 noundef 32, {} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 131621933207328 to {}*) to {} addrspace(10)*)) #23, !dbg !1245
  %12 = bitcast {} addrspace(10)* %box to i8 addrspace(10)*, !dbg !1245
  %newstruct15.sroa.0.0..sroa_cast = bitcast {} addrspace(10)* %box to i64 addrspace(10)*, !dbg !1245
  store i64 %arraysize, i64 addrspace(10)* %newstruct15.sroa.0.0..sroa_cast, align 8, !dbg !1245, !tbaa !165, !alias.scope !166, !noalias !1252
  %newstruct15.sroa.2.0..sroa_idx = getelementptr inbounds i8, i8 addrspace(10)* %12, i64 8, !dbg !1245
  %newstruct15.sroa.2.0..sroa_cast = bitcast i8 addrspace(10)* %newstruct15.sroa.2.0..sroa_idx to i64 addrspace(10)*, !dbg !1245
  store i64 %arraysize3, i64 addrspace(10)* %newstruct15.sroa.2.0..sroa_cast, align 8, !dbg !1245, !tbaa !165, !alias.scope !166, !noalias !1252
  %newstruct15.sroa.3.0..sroa_idx = getelementptr inbounds i8, i8 addrspace(10)* %12, i64 16, !dbg !1245
  %newstruct15.sroa.3.0..sroa_cast = bitcast i8 addrspace(10)* %newstruct15.sroa.3.0..sroa_idx to i64 addrspace(10)*, !dbg !1245
  store i64 %arraysize5, i64 addrspace(10)* %newstruct15.sroa.3.0..sroa_cast, align 8, !dbg !1245, !tbaa !165, !alias.scope !166, !noalias !1252
  %newstruct15.sroa.4.0..sroa_idx = getelementptr inbounds i8, i8 addrspace(10)* %12, i64 24, !dbg !1245
  %newstruct15.sroa.4.0..sroa_cast = bitcast i8 addrspace(10)* %newstruct15.sroa.4.0..sroa_idx to i64 addrspace(10)*, !dbg !1245
  store i64 %arraysize7, i64 addrspace(10)* %newstruct15.sroa.4.0..sroa_cast, align 8, !dbg !1245, !tbaa !165, !alias.scope !166, !noalias !1252
  %13 = call noalias nonnull {} addrspace(10)* @ijl_new_array({} addrspace(10)* noundef addrspacecast ({}* inttoptr (i64 131622087272080 to {}*) to {} addrspace(10)*), {} addrspace(10)* noundef nonnull %box) #24, !dbg !1245
  %14 = addrspacecast {} addrspace(10)* %13 to {} addrspace(10)* addrspace(11)*, !dbg !1253
  %arraysize_ptr17 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %14, i64 3, !dbg !1253
  %15 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr17 to i64 addrspace(11)*, !dbg !1253
  %arraysize18 = load i64, i64 addrspace(11)* %15, align 8, !dbg !1253, !tbaa !23, !range !44, !alias.scope !45, !noalias !48
  %arraysize_ptr19 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %14, i64 4, !dbg !1253
  %16 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr19 to i64 addrspace(11)*, !dbg !1253
  %arraysize20 = load i64, i64 addrspace(11)* %16, align 8, !dbg !1253, !tbaa !23, !range !44, !alias.scope !45, !noalias !48
  %arraysize_ptr21 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %14, i64 5, !dbg !1253
  %17 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr21 to i64 addrspace(11)*, !dbg !1253
  %arraysize22 = load i64, i64 addrspace(11)* %17, align 8, !dbg !1253, !tbaa !23, !range !44, !alias.scope !45, !noalias !48
  %arraysize_ptr23 = getelementptr inbounds {} addrspace(10)*, {} addrspace(10)* addrspace(11)* %14, i64 6, !dbg !1253
  %18 = bitcast {} addrspace(10)* addrspace(11)* %arraysize_ptr23 to i64 addrspace(11)*, !dbg !1253
  %arraysize24 = load i64, i64 addrspace(11)* %18, align 8, !dbg !1253, !tbaa !23, !range !44, !alias.scope !45, !noalias !48
  %.not = icmp ne i64 %arraysize18, %arraysize, !dbg !1261
  %.not171 = icmp ne i64 %arraysize20, %arraysize3
  %or.cond = select i1 %.not, i1 true, i1 %.not171, !dbg !1265
  %.not172 = icmp ne i64 %arraysize22, %arraysize5
  %or.cond173 = select i1 %or.cond, i1 true, i1 %.not172, !dbg !1265
  %19 = icmp ne i64 %arraysize24, %arraysize7
  %or.cond176 = select i1 %or.cond173, i1 true, i1 %19, !dbg !1265
  br i1 %or.cond176, label %L250, label %L58, !dbg !1265

L58:                                              ; preds = %top
  %20 = icmp eq i64 %arraysize7, 1, !dbg !1266
  %21 = icmp eq i64 %arraysize5, 1, !dbg !1279
  %22 = icmp eq i64 %arraysize3, 1, !dbg !1282
  %23 = icmp eq i64 %arraysize, 1, !dbg !1285
  %24 = icmp ne i64 %arraysize3, 0, !dbg !1288
  %25 = icmp ne i64 %arraysize5, 0, !dbg !1294
  %26 = icmp ne i64 %arraysize7, 0, !dbg !1294
  %27 = and i1 %24, %25, !dbg !1297
  %28 = and i1 %27, %26, !dbg !1297
  br i1 %28, label %L126.preheader, label %L272, !dbg !1292

L126.preheader:                                   ; preds = %L58
  %.not166 = icmp eq i64 %arraysize, 0
  %29 = addrspacecast {} addrspace(10)* %0 to float addrspace(13)* addrspace(11)*
  %30 = addrspacecast {} addrspace(10)* %13 to float addrspace(13)* addrspace(11)*
  br label %L126.outer, !dbg !1299

L126.outer:                                       ; preds = %L226, %L126.preheader
  %iv = phi i64 [ %iv.next, %L226 ], [ 0, %L126.preheader ]
  %iv.next = add nuw nsw i64 %iv, 1
  %value_phi56.op = add nsw i64 %iv.next, -1
  %31 = select i1 %20, i64 0, i64 %value_phi56.op
  %32 = mul i64 %31, %arraysize5
  %arrayptr168 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %29, align 16
  %33 = mul i64 %value_phi56.op, %arraysize5
  br label %L126, !dbg !1299

L126:                                             ; preds = %.thread, %L126.outer
  %iv1 = phi i64 [ %iv.next2, %.thread ], [ 0, %L126.outer ]
  %value_phi54 = phi i64 [ %value_phi92.ph, %.thread ], [ 1, %L126.outer ]
  %value_phi55 = phi i64 [ %value_phi93.ph, %.thread ], [ 1, %L126.outer ]
  %iv.next2 = add nuw nsw i64 %iv1, 1, !dbg !1299
  br i1 %.not166, label %L192, label %L141.lr.ph, !dbg !1299

L141.lr.ph:                                       ; preds = %L126
  %value_phi54.op = add nsw i64 %value_phi54, -1
  %34 = select i1 %22, i64 0, i64 %value_phi54.op
  %value_phi55.op = add i64 %value_phi55, -1
  %35 = select i1 %21, i64 0, i64 %value_phi55.op
  %reass.add = add i64 %35, %32
  %reass.mul = mul i64 %reass.add, %arraysize3
  %reass.add201 = add i64 %reass.mul, %34
  %reass.mul202 = mul i64 %reass.add201, %arraysize
  %reass.add199 = add i64 %value_phi55.op, %33
  %reass.mul200 = mul i64 %reass.add199, %arraysize3
  %reass.add203 = add i64 %reass.mul200, %value_phi54.op
  %reass.mul204 = mul i64 %reass.add203, %arraysize
  br label %L141, !dbg !1300

L141:                                             ; preds = %L141, %L141.lr.ph
  %iv3 = phi i64 [ %iv.next4, %L141 ], [ 0, %L141.lr.ph ]
  %iv.next4 = add nuw nsw i64 %iv3, 1, !dbg !1301
  %36 = select i1 %23, i64 0, i64 %iv3, !dbg !1304
  %37 = add i64 %36, %reass.mul202, !dbg !1304
  %38 = getelementptr inbounds float, float addrspace(13)* %arrayptr168, i64 %37, !dbg !1304
  %arrayref = load float, float addrspace(13)* %38, align 4, !dbg !1304, !tbaa !77, !alias.scope !80, !noalias !81
  %39 = call float @julia_sin_1489(float %arrayref) #22, !dbg !1312
  %40 = add i64 %iv3, %reass.mul204, !dbg !1314
  %arrayptr88169 = load float addrspace(13)*, float addrspace(13)* addrspace(11)* %30, align 8, !dbg !1314, !tbaa !23, !alias.scope !1316, !noalias !48, !nonnull !18
  %41 = getelementptr inbounds float, float addrspace(13)* %arrayptr88169, i64 %40, !dbg !1314
  store float %39, float addrspace(13)* %41, align 4, !dbg !1314, !tbaa !77, !alias.scope !80, !noalias !1317
  %exitcond.not = icmp eq i64 %iv.next4, %arraysize, !dbg !1318
  br i1 %exitcond.not, label %L192.loopexit, label %L141, !dbg !1300, !llvm.loop !1319

L192.loopexit:                                    ; preds = %L141
  br label %L192, !dbg !1320

L192:                                             ; preds = %L192.loopexit, %L126
  %42 = add i64 %value_phi54, 1, !dbg !1320
  %43 = icmp ugt i64 %value_phi54, 9223372036854775806, !dbg !1324
  %44 = icmp sgt i64 %42, %arraysize3, !dbg !1324
  %45 = or i1 %43, %44, !dbg !1327
  %46 = icmp eq i64 %value_phi54, %arraysize3
  %or.cond174 = or i1 %46, %45, !dbg !1327
  br i1 %or.cond174, label %L201, label %.thread, !dbg !1327

L201:                                             ; preds = %L192
  %47 = add i64 %value_phi55, 1, !dbg !1328
  %48 = icmp ugt i64 %value_phi55, 9223372036854775806, !dbg !1331
  %49 = icmp sgt i64 %47, %arraysize5, !dbg !1331
  %50 = or i1 %48, %49, !dbg !1334
  %51 = icmp eq i64 %value_phi55, %arraysize5
  %or.cond175 = or i1 %51, %50, !dbg !1334
  br i1 %or.cond175, label %L226, label %.thread, !dbg !1334

.thread:                                          ; preds = %L201, %L192
  %value_phi92.ph = phi i64 [ 1, %L201 ], [ %42, %L192 ]
  %value_phi93.ph = phi i64 [ %47, %L201 ], [ %value_phi55, %L192 ]
  br label %L126, !dbg !1323

L226:                                             ; preds = %L201
  %52 = add nuw nsw i64 %iv.next, 1, !dbg !1335
  %exitcond207.not = icmp eq i64 %iv.next, %arraysize7, !dbg !1338
  br i1 %exitcond207.not, label %L272.loopexit, label %L126.outer, !dbg !1323

L250:                                             ; preds = %top
  %53 = getelementptr inbounds [4 x [1 x i64]], [4 x [1 x i64]]* %newstruct133, i64 0, i64 0, i64 0, !dbg !1339
  store i64 %arraysize18, i64* %53, align 8, !dbg !1339, !tbaa !149, !alias.scope !151, !noalias !1242
  %54 = getelementptr inbounds [4 x [1 x i64]], [4 x [1 x i64]]* %newstruct133, i64 0, i64 1, i64 0, !dbg !1343
  store i64 %arraysize20, i64* %54, align 8, !dbg !1343, !tbaa !149, !alias.scope !151, !noalias !1242
  %55 = getelementptr inbounds [4 x [1 x i64]], [4 x [1 x i64]]* %newstruct133, i64 0, i64 2, i64 0, !dbg !1343
  store i64 %arraysize22, i64* %55, align 8, !dbg !1343, !tbaa !149, !alias.scope !151, !noalias !1242
  %56 = getelementptr inbounds [4 x [1 x i64]], [4 x [1 x i64]]* %newstruct133, i64 0, i64 3, i64 0, !dbg !1343
  store i64 %arraysize24, i64* %56, align 8, !dbg !1343, !tbaa !149, !alias.scope !151, !noalias !1242
  %57 = addrspacecast [4 x [1 x i64]]* %newstruct133 to [4 x [1 x i64]] addrspace(11)*, !dbg !1259
  %58 = addrspacecast [4 x [1 x i64]]* %newstruct14 to [4 x [1 x i64]] addrspace(11)*, !dbg !1259
  call fastcc void @julia_throwdm_1475([4 x [1 x i64]] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(32) %57, [4 x [1 x i64]] addrspace(11)* nocapture nofree noundef nonnull readonly align 8 dereferenceable(32) %58) #25, !dbg !1259
  unreachable, !dbg !1259

L272.loopexit:                                    ; preds = %L226
  br label %L272, !dbg !1347

L272:                                             ; preds = %L272.loopexit, %L58
  %59 = call fastcc float @julia__mapreduce_1477({} addrspace(10)* noalias nocapture nofree noundef nonnull readonly align 16 dereferenceable(40) %13) #22, !dbg !1347
  ret float %59, !dbg !1347
}
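The `select` instructions on `%20` through `%23` in the IR above (for example `select i1 %20, i64 0, i64 %value_phi56.op`) appear to be broadcast's handling of singleton dimensions. When a source dimension has extent 1, its index is pinned to the first element instead of advancing. The repro's `x` has size `(32, 32, 1, 32)`, so its third dimension is exactly such a singleton. Below is a simplified model of that index computation. `extruded_linear_index` is a hypothetical helper and is not part of `Base.Broadcast`. The real implementation goes through `Broadcast.Extruded` and `newindex` and differs in detail:

```julia
# Simplified model of broadcast "extrusion": when a source dimension has
# extent 1, its index is pinned to 1 (the IR's `select ..., i64 0, ...` in
# 0-based form); otherwise the destination index is used. Returns a 1-based
# column-major linear index into the source array.
# `extruded_linear_index` is a hypothetical helper, not a Base function.
function extruded_linear_index(srcsize::NTuple{N,Int}, I::NTuple{N,Int}) where {N}
    lin = 0        # 0-based linear offset
    stride = 1     # product of the extents of the preceding dimensions
    for d in 1:N
        i = srcsize[d] == 1 ? 1 : I[d]   # pin singleton dimensions
        lin += (i - 1) * stride
        stride *= srcsize[d]
    end
    return lin + 1
end
```

When the source and destination shapes are identical, as in this repro, pinning never changes the result, and the whole computation is equivalent to a single linear index. The generic code still carries the per-dimension selects and multiply-adds.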

@wsmoses
Member

wsmoses commented May 14, 2024

The IR in my previous comment is from after my fix, by the way.

Post-fix timings:

```julia
julia> @btime Enzyme.gradient(Reverse, $gelu_act, $x);
  927.718 μs (8 allocations: 384.28 KiB)

julia> @btime Zygote.gradient($gelu_act, $x)
  386.687 μs (38 allocations: 641.19 KiB)
```

The bigger issue right now, in my opinion, is that the loop bounds aren't statically inferable because of that awful index math. As a result, inner loops cache their true iteration counts inside the outer loops, which causes a lot of unnecessary caching. That's still not fixed.
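One way to sidestep that index math is to make the computation one-dimensional, so there is a single loop whose trip count is `length(x)`. This sketch assumes the input is a plain dense `Array`, which `vec` can reshape without copying. It is an untested workaround idea, not a confirmed fix for Enzyme, and `sin_act_flat` and `sin_act_vec` are hypothetical names:

```julia
# Same value as sum(abs2, sin.(x)), but with a single linear loop:
# one trip count (length(x)) and no per-dimension index arithmetic.
function sin_act_flat(x::AbstractArray{T}) where {T<:AbstractFloat}
    acc = zero(T)
    @inbounds for i in eachindex(x)   # linear indices for a dense Array
        acc += abs2(sin(x[i]))
    end
    return acc
end

# Alternative: keep the broadcast but flatten to 1-D first.
# For an Array, vec(x) shares memory with x (no copy).
sin_act_vec(x::AbstractArray) = sum(abs2, sin.(vec(x)))
```

To compare generated code and timings, you can pass either function to `Enzyme.gradient(Reverse, f, x)` in place of `gelu_act`.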

The pre-fix timings below are slower, so the minor fix does help performance somewhat, but again the index math is likely the root cause. At least others won't have to deal with runtime activity, though.

```julia
julia> Enzyme.gradient(Reverse, gelu_act, x);

julia> @btime Enzyme.gradient(Reverse, $gelu_act, $x);
  967.698 μs (8 allocations: 384.28 KiB)

julia> @btime Zygote.gradient($gelu_act, $x)
  377.638 μs (38 allocations: 641.19 KiB)
```
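Here are the reported numbers side by side. The pre-fix and post-fix runs come from separate sessions, so some run-to-run variation is expected:

```julia
# Timings reported in this thread, in μs.
enzyme_post, zygote_post = 927.718, 386.687
enzyme_pre,  zygote_pre  = 967.698, 377.638

slowdown_post = enzyme_post / zygote_post                   # ≈ 2.40
slowdown_pre  = enzyme_pre / zygote_pre                     # ≈ 2.56
fix_gain      = 1 - enzyme_post / enzyme_pre                # ≈ 0.041 (about 4%)
zygote_drift  = abs(zygote_post - zygote_pre) / zygote_pre  # ≈ 0.024
```

By these numbers, Enzyme was roughly 2.4 to 2.6× slower than Zygote on this benchmark. The fix's ~4% gain is not much larger than the ~2.4% drift in Zygote's own timing between the two runs, which fits the point above that the fix helps only modestly.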

@wsmoses
Member

wsmoses commented May 14, 2024

Module before optimization:
preopt.ll.txt

Labels
None yet
Projects
None yet
Development

No branches or pull requests

2 participants