Skip to content
Merged

Cuv2 #423

Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
50 commits
Select commit Hold shift + click to select a range
5915756
Kernel-supporting jll
wsmoses Dec 17, 2024
1b35e8e
fix rulescc
wsmoses Dec 17, 2024
3f364ca
adapt to hedron dep
wsmoses Dec 17, 2024
2d745c4
init target
wsmoses Dec 17, 2024
2892212
fixup
wsmoses Dec 17, 2024
261b3c2
additional fixups
wsmoses Dec 17, 2024
7ef39a4
fixup
wsmoses Dec 17, 2024
e86af4f
fix
wsmoses Dec 17, 2024
f1d289c
registry utils
wsmoses Dec 17, 2024
802f445
callname
wsmoses Dec 17, 2024
9aefd5e
reg
wsmoses Dec 17, 2024
312ee5b
fix
wsmoses Dec 17, 2024
bd94773
fix bld
wsmoses Dec 17, 2024
ef143c3
cleanup
wsmoses Dec 17, 2024
1be6732
no pip
wsmoses Dec 17, 2024
8e553de
fix
wsmoses Dec 17, 2024
a2c664c
force rules python to older version before bug
wsmoses Dec 17, 2024
e41bb8f
fixup jll
wsmoses Dec 17, 2024
6f92d00
with proto
wsmoses Dec 18, 2024
1787ece
fix
wsmoses Dec 18, 2024
d38eac9
fix
wsmoses Dec 18, 2024
b50c8f1
Update WORKSPACE
wsmoses Dec 18, 2024
4002eff
more deps for apple
wsmoses Dec 18, 2024
65189e2
bump
wsmoses Dec 18, 2024
5c9ef9a
fix
wsmoses Dec 18, 2024
a149e0a
workspace bump
wsmoses Dec 18, 2024
a049cf2
workspace
wsmoses Dec 18, 2024
cc3e5e5
Update Compiler.jl
wsmoses Dec 19, 2024
1116598
Update ReactantCUDAExt.jl
wsmoses Dec 19, 2024
ab5e575
Update ReactantCUDAExt.jl
wsmoses Dec 19, 2024
54ae823
Update ReactantCUDAExt.jl
wsmoses Dec 19, 2024
cc47802
Update ReactantCUDAExt.jl
wsmoses Dec 19, 2024
ce05591
Update ReactantCUDAExt.jl
wsmoses Dec 19, 2024
e5ca9dd
Update Project.toml
wsmoses Dec 19, 2024
0a7da97
Update ReactantCUDAExt.jl
wsmoses Dec 19, 2024
a43ccf0
Update Project.toml
wsmoses Dec 19, 2024
cc411a3
Update Project.toml
wsmoses Dec 19, 2024
0d6260c
fix
Dec 19, 2024
17a2e12
Update ReactantCUDAExt.jl
wsmoses Dec 19, 2024
31c20e4
Update ReactantCUDAExt.jl
wsmoses Dec 19, 2024
eca6ee9
Update ReactantCUDAExt.jl
wsmoses Dec 19, 2024
54fcea8
Update cuda.jl
wsmoses Dec 19, 2024
f595178
Update cuda.jl
wsmoses Dec 19, 2024
ac98093
Update cuda.jl
wsmoses Dec 19, 2024
f8d8c95
Cuda kernel v2
Dec 23, 2024
43ace8b
Merge branch 'main' into cuv2
wsmoses Dec 23, 2024
1c71a3b
Update Project.toml
wsmoses Dec 23, 2024
c22c01a
Update API.cpp
wsmoses Dec 23, 2024
fa621fe
Apply suggestions from code review
wsmoses Dec 24, 2024
e9b769f
Merge branch 'main' into cuv2
wsmoses Dec 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,8 @@ function __init__()
end
Reactant.Compiler.cuLaunch[] = Base.reinterpret(UInt, ptr1)
Reactant.Compiler.cuModule[] = Base.reinterpret(UInt, ptr2)
return Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3)
Reactant.Compiler.cuFunc[] = Base.reinterpret(UInt, ptr3)
return nothing
end

end # module ReactantCUDAExt
2 changes: 1 addition & 1 deletion src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ function compile_mlir!(mod, f, args; optimize::Union{Bool,Symbol}=true)
if isdefined(Reactant_jll, :ptxas_path)
toolkit = Reactant_jll.ptxas_path[1:(end - length("/bin/ptxas"))]
end
kern = "lower-kernel{toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}"
kern = "lower-kernel{run_init=true toolkitPath=$toolkit cuLaunchKernelPtr=$(cuLaunch[]) cuModuleLoadDataPtr=$(cuModule[]) cuModuleGetFunctionPtr=$(cuFunc[])}"
if optimize === :all
run_pass_pipeline!(mod, join([opt_passes, "enzyme-batch", opt_passes], ","))
run_pass_pipeline!(mod, "enzyme,arith-raise{stablehlo=true}"; enable_verifier=false)
Expand Down
17 changes: 12 additions & 5 deletions test/cuda.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,16 @@ using Reactant
using Test
using CUDA

using Reactant_jll
@show Reactant_jll.libReactantExtra_path

function square_kernel!(x)
i = threadIdx().x
x[i] *= x[i]
#i = threadIdx().x
#x[i] *= x[i]
#@cuprintf("overwrote value of %f was thrown during kernel execution on thread (%d, %d, %d) in block (%d, %d, %d).\n",
# 0.0, threadIdx().x, threadIdx().y, threadIdx().z, blockIdx().x, blockIdx().y, blockIdx().z)
#x[i], threadIdx().x, threadIdx().y, threadIdx().z, blockIdx().x, blockIdx().y, blockIdx().z)

# sync_threads()
return nothing
end
Expand All @@ -18,9 +25,9 @@ end
@testset "Square Kernel" begin
oA = collect(1:1:64)
A = Reactant.to_rarray(oA)
@show @code_hlo optimize = false square!(A)
@show @code_hlo optimize = :before_kernel square!(A)
@show @code_hlo square!(A)
# @show @code_hlo optimize = false square!(A)
# @show @code_hlo optimize = :before_kernel square!(A)
# @show @code_hlo square!(A)
func! = @compile square!(A)
func!(A)
@show A
Expand Down
Loading