
Specialize broadcast to avoid integer divisions. #304

Merged · 5 commits · Mar 8, 2024
Conversation

@maleadt (Member) commented Mar 6, 2024

By using hardware 2d/3d indices whenever possible.
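For intuition: recovering N-dimensional cartesian indices from a 1-D linear thread index costs one integer div/mod per dimension, which is exactly what native 2-D/3-D hardware indices sidestep. A minimal sketch of the two index schemes (illustrative Python only, not Metal.jl code; function names are made up):

```python
def linear_to_cartesian(i, dims):
    """Decompose a 0-based linear index into 0-based cartesian indices
    (column-major, as in Julia). Each dimension costs one divmod,
    i.e. an integer division on the GPU."""
    coords = []
    for d in dims:
        i, r = divmod(i, d)  # quotient feeds the next dimension
        coords.append(r)
    return tuple(coords)

def cartesian_from_hw(x, y):
    """With a 2-D hardware grid, each thread's (x, y) position is
    available directly; no division is needed."""
    return (x, y)
```

For a 4×3 column-major array, linear index 7 decomposes to cartesian (3, 1), matching what a 2-D grid would hand the thread for free.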

Benchmark script:

using Metal
using Chairmarks

function memcopy(output_data::AbstractArray{T}, input_data::AbstractArray{T}) where T
    i = thread_position_in_grid_1d()   # 1-based linear thread index
    if 1 <= i <= length(input_data)    # guard against overshooting threads
        @inbounds output_data[i] = input_data[i]
    end
    return
end

function test(dims; T=Float32)
    cpu_in = rand(T, dims)
    gpu_in = MtlArray(cpu_in)
    gpu_out = similar(gpu_in)
    print(join(dims, "×"), "-element $T (", Base.format_bytes(sizeof(cpu_in)), "): ")

    # verify results
    gpu_out .= gpu_in
    @assert Array(gpu_in) == Array(gpu_out)

    # reference
    threads = 256
    groups = cld(length(cpu_in), threads)
    bench = @b Metal.@sync @metal threads=threads groups=groups memcopy(gpu_out, gpu_in)
    reference = Base.format_bytes(2*sizeof(gpu_in) / bench.time) * "/s"

    # broadcast
    bench = @b Metal.@sync (gpu_out .= gpu_in; nothing)
    speed = Base.format_bytes(2*sizeof(gpu_in) / bench.time) * "/s"

    println(speed, " / ", reference)

    return
end

function main()
    for dims in [(2^28,), (2^14, 2^14), (2^8, 2^8, 2^8),
                 (1, 2^28), (2^28,1),
                 (1, 2^14, 2^14), (2^14, 1, 2^14), (2^14, 2^14, 1),
                 (1, 1, 2^28), (1, 2^28, 1), (2^28, 1, 1)]
        test(dims)
    end
end

Before (on an M3 Pro with a maximum of 150 GB/s memory bandwidth):

268435456-element Float32 (1024.000 MiB): 125.527 GiB/s / 126.282 GiB/s
16384×16384-element Float32 (1024.000 MiB): 31.089 GiB/s / 126.004 GiB/s
256×256×256-element Float32 (64.000 MiB): 17.313 GiB/s / 108.182 GiB/s
1×268435456-element Float32 (1024.000 MiB): 121.685 GiB/s / 125.155 GiB/s
268435456×1-element Float32 (1024.000 MiB): 121.661 GiB/s / 125.648 GiB/s
1×16384×16384-element Float32 (1024.000 MiB): 26.239 GiB/s / 125.158 GiB/s
16384×1×16384-element Float32 (1024.000 MiB): 26.102 GiB/s / 125.538 GiB/s
16384×16384×1-element Float32 (1024.000 MiB): 26.079 GiB/s / 125.631 GiB/s
1×1×268435456-element Float32 (1024.000 MiB): 72.273 GiB/s / 124.147 GiB/s
1×268435456×1-element Float32 (1024.000 MiB): 72.466 GiB/s / 125.649 GiB/s
268435456×1×1-element Float32 (1024.000 MiB): 72.353 GiB/s / 125.043 GiB/s

After:

268435456-element Float32 (1024.000 MiB): 122.613 GiB/s / 125.370 GiB/s
16384×16384-element Float32 (1024.000 MiB): 121.427 GiB/s / 125.990 GiB/s
256×256×256-element Float32 (64.000 MiB): 104.428 GiB/s / 107.949 GiB/s
1×268435456-element Float32 (1024.000 MiB): 125.704 GiB/s / 125.532 GiB/s
268435456×1-element Float32 (1024.000 MiB): 78.905 GiB/s / 125.298 GiB/s
1×16384×16384-element Float32 (1024.000 MiB): 123.495 GiB/s / 126.113 GiB/s
16384×1×16384-element Float32 (1024.000 MiB): 118.807 GiB/s / 125.526 GiB/s
16384×16384×1-element Float32 (1024.000 MiB): 120.582 GiB/s / 125.738 GiB/s
1×1×268435456-element Float32 (1024.000 MiB): 122.905 GiB/s / 125.729 GiB/s
1×268435456×1-element Float32 (1024.000 MiB): 80.628 GiB/s / 125.679 GiB/s
268435456×1×1-element Float32 (1024.000 MiB): 78.646 GiB/s / 125.736 GiB/s
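As a sanity check on the numbers above: the script's `2*sizeof(gpu_in) / bench.time` formula counts both the read and the write of each byte, so moving a 1 GiB buffer at ~125 GiB/s effective bandwidth implies a kernel time of roughly 16 ms. A quick back-of-the-envelope check in plain Python:

```python
GiB = 2**30
nbytes = 4 * 2**28            # 268435456 Float32 elements = 1 GiB
moved = 2 * nbytes            # each element is read once and written once
time_s = moved / (125 * GiB)  # at ~125 GiB/s effective bandwidth
# kernel time ≈ 16 ms
```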

Fixes #41

@maleadt added the performance (Gotta go fast.) and arrays (Things about the array abstraction.) labels on Mar 6, 2024
@maleadt (Member, Author) commented Mar 6, 2024

Looks like a validator crash that I'll need to reduce.

In addition to the changes here, we could take the static CartesianIndex I put in JuliaGPU/GPUArrays.jl#520 and use it to specialize the cartesian fallback. However, as noted in JuliaGPU/GPUArrays.jl#520 (comment), some applications use broadcast operations with many different shapes. Maybe we should switch to a specialized version only after the cartesian one has been executed, say, 5 times?
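The heuristic floated above could be sketched as a per-shape counter that keeps dispatching the generic kernel until a shape proves hot enough to justify recompiling a specialized one. A hypothetical sketch in Python (names, threshold, and dispatch structure are all invented for illustration, not Metal.jl API):

```python
from collections import defaultdict

SPECIALIZE_AFTER = 5          # hypothetical threshold discussed above
_shape_counts = defaultdict(int)

def pick_kernel(shape):
    """Return which kernel variant to launch for this broadcast shape:
    the generic cartesian fallback (runtime divisions) until the shape
    has been seen SPECIALIZE_AFTER times, then a shape-specialized
    kernel with the divisions folded away at compile time."""
    _shape_counts[shape] += 1
    if _shape_counts[shape] > SPECIALIZE_AFTER:
        return "specialized"
    return "generic"
```

The trade-off this encodes: workloads that broadcast many distinct shapes never pay repeated specialization costs, while hot shapes eventually get the fast path.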

@maleadt (Member, Author) commented Mar 7, 2024

Reduced the validator failure in #308

@maleadt merged commit c556832 into main on Mar 8, 2024 (1 check passed).
@maleadt deleted the tb/broadcast_nd branch on March 8, 2024.
Successfully merging this pull request may close: slow broadcast copy in 2D