Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 62 additions & 11 deletions deps/build_local.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,46 @@ Reactant_jll = Base.UUID("0192cb87-2b54-54ad-80e0-3be72ad8a3c0")

using ArgParse

using Libdl

# adapted from `cudaRuntimeGetVersion` in CUDA_Runtime_jll
function cuDriverGetVersion(library_handle)
function_handle = Libdl.dlsym(library_handle, "cuDriverGetVersion"; throw_error=false)
if function_handle === nothing
@debug "CUDA Driver library seems invalid (does not contain 'cuDriverGetVersion')"
return nothing
end
version_ref = Ref{Cint}()
status = ccall(function_handle, Cint, (Ptr{Cint},), version_ref)
if status != 0
@debug "Call to 'cuDriverGetVersion' failed with status $(status)"
return nothing
end
major, ver = divrem(version_ref[], 1000)
minor, patch = divrem(ver, 10)
version = VersionNumber(major, minor, patch)
@debug "Detected CUDA Driver version $(version)"
return version
end

function get_cuda_version()
cuname = if Sys.iswindows()
Libdl.find_library("nvcuda")
else
Libdl.find_library(["libcuda.so.1", "libcuda.so"])
end

if cuname == ""
return nothing
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
return nothing
return nothing

end

handle = Libdl.dlopen(cuname)
current_cuda_version = cuDriverGetVersion(handle)
path = Libdl.dlpath(handle)
Libdl.dlclose(handle)
return current_cuda_version
end

s = ArgParseSettings()
#! format: off
@add_arg_table! s begin
Expand All @@ -21,7 +61,7 @@ s = ArgParseSettings()
default = something(Sys.which("gcc"), "/usr/bin/gcc")
arg_type = String
"--cc"
default = something(Sys.which("cc"), Sys.which("gcc"), Sys.which("clang"), "/usr/bin/cc")
default = something(Sys.which("clang"), Sys.which("cc"), Sys.which("gcc"), "/usr/bin/cc")
arg_type = String
"--hermetic_python_version"
help = "Hermetic Python version."
Expand Down Expand Up @@ -78,21 +118,32 @@ source_dir = joinpath(@__DIR__, "ReactantExtra")
build_kind = parsed_args["debug"] ? "dbg" : "opt"

build_backend = parsed_args["backend"]
@assert build_backend in ("auto", "cpu", "cuda")

if build_backend == "auto"
build_backend = try
run(Cmd(`nvidia-smi`))
"cuda"
catch
"cpu"

if build_backend == "auto" || build_backend == "cuda"
cuda_ver = get_cuda_version()
@show cuda_ver
if cuda_ver === nothing
if build_backend == "cuda"
throw(AssertionError("Could not detect cuda version, but requested cuda with auto version build"))
end
build_backend = "cpu"
Comment on lines +126 to +129
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
if build_backend == "cuda"
throw(AssertionError("Could not detect cuda version, but requested cuda with auto version build"))
end
build_backend = "cpu"
if build_backend == "cuda"
throw(
AssertionError(
"Could not detect cuda version, but requested cuda with auto version build",
),
)
end
build_backend = "cpu"

else
if Int(get_cuda_version().major) == 13
build_backend = "cuda13"
else
build_backend = "cuda12"
end
Comment on lines +131 to +135
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[JuliaFormatter] reported by reviewdog 🐶

Suggested change
if Int(get_cuda_version().major) == 13
build_backend = "cuda13"
else
build_backend = "cuda12"
end
if Int(get_cuda_version().major) == 13
build_backend = "cuda13"
else
build_backend = "cuda12"
end

end
end

arg = if build_backend == "cuda"
"--config=cuda"
arg = if build_backend == "cuda12"
"--config=cuda12"
elseif build_backend == "cuda13"
"--config=cuda13"
elseif build_backend == "cpu"
""
else
throw(AssertionError("Unknown backend `$build_backend`"))
end

bazel_cmd = if !isnothing(Sys.which("bazelisk"))
Expand Down
Loading