diff --git a/deps/build_local.jl b/deps/build_local.jl index e62a2624ca..6f2eab3f46 100644 --- a/deps/build_local.jl +++ b/deps/build_local.jl @@ -6,6 +6,46 @@ Reactant_jll = Base.UUID("0192cb87-2b54-54ad-80e0-3be72ad8a3c0") using ArgParse +using Libdl + +# adapted from `cudaRuntimeGetVersion` in CUDA_Runtime_jll +function cuDriverGetVersion(library_handle) + function_handle = Libdl.dlsym(library_handle, "cuDriverGetVersion"; throw_error=false) + if function_handle === nothing + @debug "CUDA Driver library seems invalid (does not contain 'cuDriverGetVersion')" + return nothing + end + version_ref = Ref{Cint}() + status = ccall(function_handle, Cint, (Ptr{Cint},), version_ref) + if status != 0 + @debug "Call to 'cuDriverGetVersion' failed with status $(status)" + return nothing + end + major, ver = divrem(version_ref[], 1000) + minor, patch = divrem(ver, 10) + version = VersionNumber(major, minor, patch) + @debug "Detected CUDA Driver version $(version)" + return version +end + +function get_cuda_version() + cuname = if Sys.iswindows() + Libdl.find_library("nvcuda") + else + Libdl.find_library(["libcuda.so.1", "libcuda.so"]) + end + + if cuname == "" + return nothing + end + + handle = Libdl.dlopen(cuname) + current_cuda_version = cuDriverGetVersion(handle) + path = Libdl.dlpath(handle) + Libdl.dlclose(handle) + return current_cuda_version +end + s = ArgParseSettings() #! format: off @add_arg_table! s begin @@ -21,7 +61,7 @@ s = ArgParseSettings() default = something(Sys.which("gcc"), "/usr/bin/gcc") arg_type = String "--cc" - default = something(Sys.which("cc"), Sys.which("gcc"), Sys.which("clang"), "/usr/bin/cc") + default = something(Sys.which("clang"), Sys.which("cc"), Sys.which("gcc"), "/usr/bin/cc") arg_type = String "--hermetic_python_version" help = "Hermetic Python version." @@ -78,21 +118,32 @@ source_dir = joinpath(@__DIR__, "ReactantExtra") build_kind = parsed_args["debug"] ? "dbg" : "opt" build_backend = parsed_args["backend"] -@assert build_backend in ("auto", "cpu", "cuda") - -if build_backend == "auto" - build_backend = try - run(Cmd(`nvidia-smi`)) - "cuda" - catch - "cpu" + +if build_backend == "auto" || build_backend == "cuda" + cuda_ver = get_cuda_version() + @show cuda_ver + if cuda_ver === nothing + if build_backend == "cuda" + throw(AssertionError("Could not detect cuda version, but requested cuda with auto version build")) + end + build_backend = "cpu" + else + if Int(get_cuda_version().major) == 13 + build_backend = "cuda13" + else + build_backend = "cuda12" + end end end -arg = if build_backend == "cuda" - "--config=cuda" +arg = if build_backend == "cuda12" + "--config=cuda12" +elseif build_backend == "cuda13" + "--config=cuda13" elseif build_backend == "cpu" "" +else + throw(AssertionError("Unknown backend `$build_backend`")) end bazel_cmd = if !isnothing(Sys.which("bazelisk"))