
Avoid cartesian iteration where possible. #454

Merged
merged 3 commits · Feb 20, 2023

Conversation

maleadt (Member) commented Feb 14, 2023

@vchuravy Does KA.jl handle iterator selection better?

maleadt (Member Author) commented Feb 14, 2023

This doesn't seem to help much, as most broadcast objects require cartesian indexing:

julia> IndexStyle(Broadcast.instantiate(Broadcast.broadcasted(+, zeros(5), 5*ones(1, 4))))
IndexCartesian()
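
For contrast, only one-dimensional broadcasts advertise linear indexing (as far as I can tell from Base's IndexStyle rules), so anything with a multi-dimensional output stays cartesian even when the shapes match:

julia> IndexStyle(Broadcast.instantiate(Broadcast.broadcasted(+, zeros(5), ones(5))))
IndexLinear()

julia> IndexStyle(Broadcast.instantiate(Broadcast.broadcasted(+, zeros(5, 4), ones(5, 4))))
IndexCartesian()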

An alternative is to pass the CartesianIndices as a compile-time constant (via Val), which makes the divisors in the index conversion constant, but LLVM isn't actually able to optimize the divisions away:

@inline function _copyto!(dest::AbstractArray, bc::Broadcasted)
    axes(dest) == axes(bc) || Broadcast.throwdm(axes(dest), axes(bc))
    isempty(dest) && return dest
    bc′ = Broadcast.preprocess(dest, bc)

    # grid-stride kernel
    function broadcast_kernel(ctx, dest, ::Val{idx}, bc′, nelem) where {idx}
        i = 0
        while i < nelem
            i += 1

            # the CartesianIndices are passed as a constant value,
            # to prevent expensive integer divisions on non-constant values
            j = @linearidx(dest, i)
            J = @inbounds idx[j]
            @inbounds dest[j] = bc′[J]
        end
        return
    end
    elements = length(dest)
    elements_per_thread = typemax(Int)
    idx = CartesianIndices(dest)
    heuristic = launch_heuristic(backend(dest), broadcast_kernel, dest, Val(idx), bc′, 1;
                                 elements, elements_per_thread)
    config = launch_configuration(backend(dest), heuristic;
                                  elements, elements_per_thread)
    gpu_call(broadcast_kernel, dest, Val(idx), bc′, config.elements_per_thread;
             threads=config.threads, blocks=config.blocks)

    return dest
end
pass:                                             ; preds = %L5
  %51 = add i64 %50, -1
  %52 = sdiv i64 %51, 2
  %.neg = mul i64 %52, -2
  %53 = add i64 %.neg, %50
  %54 = mul i64 %12, %52
  %55 = add i64 %53, -1
  %56 = add i64 %55, %54
  %57 = getelementptr inbounds float, float addrspace(1)* %13, i64 %56
  %58 = load float, float addrspace(1)* %57, align 4
  %59 = getelementptr inbounds float, float addrspace(1)* %.unpack9, i64 %51
  store float %58, float addrspace(1)* %59, align 4
  %.not = icmp slt i64 %48, %4
  br i1 %.not, label %L5, label %common.ret
}

... instead of

pass.us:                                          ; preds = %L5.us
  %32 = add i64 %31, -1
  %33 = sdiv i64 %32, %.unpack5.unpack
  %34 = mul i64 %33, %.unpack5.unpack
  %35 = sub i64 %32, %34
  %36 = add i64 %33, 1
  %37 = select i1 %.not11, i64 %.fca.0.0.2.1.extract, i64 %36
  %38 = add i64 %37, -1
  %39 = mul i64 %13, %38
  %40 = add i64 %35, %39
  %41 = getelementptr inbounds float, float addrspace(1)* %14, i64 %40
  %42 = load float, float addrspace(1)* %41, align 4
  %43 = getelementptr inbounds float, float addrspace(1)* %.unpack9, i64 %32
  store float %42, float addrspace(1)* %43, align 4
  %.not.us = icmp slt i64 %29, %4
  br i1 %.not.us, label %L5.us, label %common.ret

common.ret:                                       ; preds = %pass.us, %L5.us, %pass.us.us, %L5.us.us, %L5.lr.ph.L5.lr.ph.split_crit_edge, %conversion
  ret void

fail:                                             ; preds = %L5.lr.ph.L5.lr.ph.split_crit_edge
  call fastcc void @gpu_report_exception() #2
  call fastcc void @gpu_signal_exception() #2
  call void @llvm.trap()
  unreachable
}

FWIW, a driver script to test this:

using Metal

function kernel_copy!(a, b)
    (i,j) = thread_position_in_grid_2d()
    @inbounds a[i,j] = b[i,j]
    return
end

function benchmark(n=2^14, nsample=10)
    test(n)

    function measure(f, name)
        a = MtlArray(rand(Float32, n,n))
        b = similar(a)

        ts = zeros(nsample)
        for i in 1:nsample
            ts[i] = @elapsed Metal.@sync begin
                f(a, b)
            end
        end

        tmin = minimum(ts)

        size_in_bytes = 2*length(a)*sizeof(Float32) #1R+1W
        byte_per_ns = size_in_bytes / (tmin*1.e9)

        println("$name performance: $(round(byte_per_ns; digits=3)) GB/s")
    end

    threads = (32,32)
    grid_size = cld.(n, threads)
    measure("kernel") do a, b
        @metal threads=threads grid=grid_size kernel_copy!(a, b)
    end

    measure("broadcast") do a, b
        a .= b
    end
end

function test(n=2^14)
    a = MtlArray(rand(Float32, n,n))
    b = similar(a)

    threads = (32,32)
    grid_size = cld.(n, threads)
    @metal threads=threads grid=grid_size kernel_copy!(a, b)
    @assert Array(a) == Array(b)

    b = similar(a)
    a .= b
    @assert Array(a) == Array(b)
end

function codegen()
    a = MtlArray(rand(Float32, 2, 2))
    b = MtlArray(rand(Float32, 2, 2))

    #@device_code_llvm debuginfo=:none @metal kernel_copy!(a, b)

    @device_code_llvm debuginfo=:none a .= b
end

On my M1 Pro, the hand-written kernel gives about 180 GB/s, the current broadcast does 10 GB/s, and the optimized version here does 24 GB/s.

maleadt (Member Author) commented Feb 14, 2023

... and (incorrectly) forcing the broadcast to use linear indexing all the way brings performance up to 180 GB/s, so the problem really is the cartesian indexing.

EDIT: but passing a volatile CartesianIndex to the kernel is fast, so the problem is squarely with indexing the CartesianIndices and not with its use in Broadcast. That's good news. I might have an idea for a fix.

maleadt (Member Author) commented Feb 15, 2023

EDIT: but passing a volatile CartesianIndex to the kernel is fast, so the problem is squarely with indexing the CartesianIndices and not with its use in Broadcast. That's good news. I might have an idea for a fix.

Turns out that wasn't true. I implemented my idea for a fix at https://github.com/maleadt/StaticCartesian.jl: essentially, it not only puts the CartesianIndices iterator in the type domain (exposing the constant divisors to LLVM), but also implements the bit-twiddling optimizations I was talking about in Julia. As a result, getindex on an iterator only performs simple bit operations, no divisions. Since my experiment yesterday showed that avoiding this getindex (but still using cartesian indices) was fast, I assumed that this would fix the problem. Alas, performance only improved by 2x, from 10 GB/s to 20 GB/s, while the hand-written kernel gets to 180 GB/s.
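
To illustrate the idea (a minimal sketch; this is not StaticCartesian.jl's actual API, and the names are made up): carrying the iterator's shape in the type means the divisors used by the linear-to-cartesian conversion are compile-time constants at every getindex call site.

struct StaticCartesianIndices{dims} end
StaticCartesianIndices(A::AbstractArray) = StaticCartesianIndices{size(A)}()

# `dims` is a type parameter, so the CartesianIndices reconstructed here is a
# compile-time constant and its divisors are literals in the generated code
Base.@propagate_inbounds Base.getindex(I::StaticCartesianIndices{dims}, i::Integer) where {dims} =
    CartesianIndices(dims)[i]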

In an attempt to improve this, I used the AIR intrinsic for mulhi:

mulhi(x::Int32, y::Int32) = ccall("extern air.mul_hi.s.i32", llvmcall, Int32, (Int32, Int32), x, y)
mulhi(x::UInt32, y::UInt32) = ccall("extern air.mul_hi.u.i32", llvmcall, UInt32, (UInt32, UInt32), x, y)
mulhi(x::Int64, y::Int64) = ccall("extern air.mul_hi.s.i64", llvmcall, Int64, (Int64, Int64), x, y)
mulhi(x::UInt64, y::UInt64) = ccall("extern air.mul_hi.u.i64", llvmcall, UInt64, (UInt64, UInt64), x, y)
@device_override StaticCartesian.mulhi(x::T, y::T) where T <: Union{Int32,UInt32,Int64,UInt64} = mulhi(x, y)
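
For reference, this is how such an intrinsic gets used for division by a constant (a minimal sketch with a CPU stand-in for mulhi; the magic number is the standard one from Hacker's Delight for an unsigned division by 3):

# CPU stand-in for the device intrinsic, just to check the arithmetic on the host
mulhi_cpu(x::UInt32, y::UInt32) = ((UInt64(x) * UInt64(y)) >> 32) % UInt32

# 0xAAAAAAAB == cld(UInt64(2)^33, 3), so (x * 0xAAAAAAAB) >> 33 == x ÷ 3 for any UInt32
udiv3(x::UInt32) = mulhi_cpu(x, 0xAAAAAAAB) >> 1

@assert all(udiv3(x) == x ÷ UInt32(3) for x in rand(UInt32, 10_000))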

That only brought performance to 30 GB/s, still way too low.

Disappointingly, going back and "just" putting the original CartesianIndices iterator in the type domain (thus emitting div instructions with constant divisors) also puts us at around 25 GB/s. So maybe Apple's back-end compiler is already doing the division-by-constant-to-bit-twiddling optimization (LLVM definitely isn't)?

However, while trying all this, I noticed something very weird: removing the @inbounds from the CartesianIndices getindex makes the broadcast kernel run at 120 GB/s (again, out of around 175 GB/s)... I'm not sure how to explain this. Maybe the checkbounds branches provide some additional static information that the back-end can use? Quickly trying to add that info through LLVM.Interop.assume didn't help.

Or maybe the measurements here are off; I wish we had a decent profiler...

maleadt (Member Author) commented Feb 15, 2023

Looks like with @inbounds we emit 2 sdiv instructions, while without it we get 3 udivs, which the back-end presumably knows how to handle better. It's too bad StaticCartesian.jl doesn't fix this then, as it gets rid of the divisions entirely.

slow (no bounds-check, sdiv):

;  @ /Users/tim/Julia/pkg/GPUArrays/src/host/broadcast.jl:74 within `broadcast_kernel`
; ┌ @ abstractarray.jl:1241 within `getindex`
; │┌ @ abstractarray.jl:1286 within `_getindex`
; ││┌ @ abstractarray.jl:1293 within `_to_subscript_indices`
; │││┌ @ abstractarray.jl:1315 within `_unsafe_ind2sub`
; ││││┌ @ abstractarray.jl:2639 within `_ind2sub` @ abstractarray.jl:2677
; │││││┌ @ int.jl:86 within `-`
        %31 = add nsw i64 %29, -1
; │││││└
; │││││┌ @ abstractarray.jl:2690 within `_ind2sub_recurse`
; ││││││┌ @ abstractarray.jl:2697 within `_div`
; │││││││┌ @ int.jl:288 within `div`
          %32 = sdiv i64 %31, 16383
; └└└└└└└└

fast (bounds-check, udiv):

;  @ /Users/tim/Julia/pkg/GPUArrays/src/host/broadcast.jl:74 within `broadcast_kernel`
; ┌ @ abstractarray.jl:1241 within `getindex`
; │┌ @ abstractarray.jl:1285 within `_getindex`
; ││┌ @ abstractarray.jl:668 within `checkbounds` @ abstractarray.jl:653
; │││┌ @ abstractarray.jl:727 within `checkindex`
; ││││┌ @ bool.jl:38 within `&`
       %.off.us = add nsw i64 %29, -1
       %31 = icmp ugt i64 %.off.us, 268402688
; │││└└
; │││ @ abstractarray.jl:668 within `checkbounds`
     br i1 %31, label %L52, label %pass.us

pass.us:                                          ; preds = %L25.us
; ││└
; ││ @ abstractarray.jl:1286 within `_getindex`
; ││┌ @ abstractarray.jl:1293 within `_to_subscript_indices`
; │││┌ @ abstractarray.jl:1315 within `_unsafe_ind2sub`
; ││││┌ @ abstractarray.jl:2639 within `_ind2sub` @ abstractarray.jl:2677
; │││││┌ @ abstractarray.jl:2690 within `_ind2sub_recurse`
; ││││││┌ @ abstractarray.jl:2697 within `_div`
; │││││││┌ @ int.jl:288 within `div`
          %.lhs.trunc.us = trunc i64 %.off.us to i32
          %32 = udiv i32 %.lhs.trunc.us, 16383
          %.zext.us = zext i32 %32 to i64
; └└└└└└└└

maleadt (Member Author) commented Feb 15, 2023

Seems like the div is a red herring: forcing udiv emission using assume(i >= 1) doesn't yield the speed-up I see without @inbounds.

EDIT: ah, adding the upper bound too makes it fast 🚀
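
For reference, a sketch of what that looks like (hypothetical helper, not the actual GPUArrays code; assume comes from LLVM.jl's LLVM.Interop):

using LLVM.Interop: assume

# give LLVM the same range information the bounds check would have provided,
# then index without the check
function getindex_assumed(iter::CartesianIndices, i::Int)
    assume(1 <= i)              # the lower bound alone wasn't enough
    assume(i <= length(iter))   # adding the upper bound is what makes it fast
    @inbounds iter[i]
end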

maxwindiff (Contributor)

That's really cool. Indeed, the back-end generates bit-twiddling code for these kernels (disassembled using https://github.com/dougallj/applegpu): https://gist.github.com/maxwindiff/a1850531f72c20ff5c922ac3743f2093

kernel void div2(device uint *a)  { *a /= 2; }
kernel void div3(device uint *a)  { *a /= 3; }
kernel void div5(device uint *a)  { *a /= 5; }
kernel void div7(device uint *a)  { *a /= 7; }
kernel void div11(device uint *a) { *a /= 11; }
kernel void div13(device uint *a) { *a /= 13; }

maleadt (Member Author) commented Feb 16, 2023

Cool, I didn't know about that disassembler! We should integrate that with Metal.jl.

Did you use it with Metal C code, or how did you get a binary dump of the code generated by Julia? I'm working on wrapping MtlBinaryArchive, but that's a fair bit of work, so I'm wondering if I missed something.

maxwindiff (Contributor)

I used Metal C + Xcode. I was thinking of trying to add it to Metal.jl but you are too fast 😂

maleadt (Member Author) commented Feb 17, 2023

I used Metal C + Xcode.

No need for that anymore: JuliaGPU/Metal.jl#96 🎉

ToucheSir

This may have broken Flux (FluxML/Flux.jl#2214). I'll try to create an MWE next week if nobody gets to it first.

chengchingwen (Contributor)

It also seems to cause a huge performance issue with Transformers. I don't have an MWE yet, but it looks like the kernel is being recompiled over and over again. Most of the time is spent on the CPU; the GPU is barely used.
