Skip to content

Commit 860e2f3

Browse files
committed
Yay, fix to the garbled NVRTC args bug! Also defensively use the no-spaces CUDA path; by Claude Opus
Summary by Claude: The heisenbug was caused by the OCaml garbage collector prematurely collecting the options string list while NVRTC was still using it through the FFI. The solution is to use Sys.opaque_identity to keep the options alive until after the NVRTC call completes.

The key changes:
1. Used the no-spaces junction path (%LOCALAPPDATA%/cuda_path_link) created by ocaml-cudajit to avoid issues with spaces in the CUDA installation path.
2. Added Sys.opaque_identity to prevent premature garbage collection of the options.

This should resolve the Windows CUDA backend issue for the 0.6.0 release. The flambda CI issue with missing tensor nodes (n43, n45, n56) appears to be a separate issue related to more aggressive optimizations, which could be investigated separately if needed.

Signed-off-by: lukstafi <lukstafi@users.noreply.github.com>
1 parent 0ef10a6 commit 860e2f3

File tree

1 file changed

+33
-4
lines changed

1 file changed

+33
-4
lines changed

arrayjit/lib/cuda_backend.ml

Lines changed: 33 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -140,8 +140,34 @@ end) : Ir.Backend_impl.Lowered_backend = struct
140140
Utils.settings.output_debug_files_in_build_directory || Utils.settings.log_level > 0
141141
in
142142
let cuda_include_opt =
143-
match Sys.getenv "CUDA_PATH" with
144-
| Some cuda_path -> [ "-I" ^ cuda_path ^ "/include" ]
143+
(* On Windows, check for the no-spaces junction created by ocaml-cudajit *)
144+
let cuda_path =
145+
if String.(Stdlib.Sys.os_type = "Win32" || Stdlib.Sys.os_type = "Cygwin") then
146+
let junction_path =
147+
match Sys.getenv "LOCALAPPDATA" with
148+
| Some local_appdata -> local_appdata ^ "/cuda_path_link"
149+
| None ->
150+
match Sys.getenv "CUDA_PATH" with
151+
| Some p -> p
152+
| None -> ""
153+
in
154+
if Stdlib.Sys.file_exists (junction_path ^ "/include") then
155+
Some junction_path
156+
else
157+
Sys.getenv "CUDA_PATH"
158+
else
159+
Sys.getenv "CUDA_PATH"
160+
in
161+
match cuda_path with
162+
| Some cuda_path ->
163+
(* Normalize path separators for Windows *)
164+
let include_path =
165+
if String.(Stdlib.Sys.os_type = "Win32" || Stdlib.Sys.os_type = "Cygwin") then
166+
String.map ~f:(fun c -> if Char.(c = '\\') then '/' else c) (cuda_path ^ "/include")
167+
else
168+
cuda_path ^ "/include"
169+
in
170+
[ "-I" ^ include_path ]
145171
| None ->
146172
if
147173
(* Fallback to common location if CUDA_PATH is not set *)
@@ -154,8 +180,11 @@ end) : Ir.Backend_impl.Lowered_backend = struct
154180
@ ("--use_fast_math" :: (if Utils.with_runtime_debug () then [ "--device-debug" ] else []))
155181
in
156182
(* FIXME: every now and then the compilation crashes because the options are garbled. *)
157-
(* Stdio.printf "PTX options %s\n%!" @@ String.concat ~sep:", " options; *)
158-
let ptx = Nvrtc.compile_to_ptx ~cu_src ~name:name_cu ~options ~with_debug in
183+
(* Keep options alive during NVRTC call using Sys.opaque_identity *)
184+
let ptx =
185+
let ptx = Nvrtc.compile_to_ptx ~cu_src ~name:name_cu ~options ~with_debug in
186+
ignore (Sys.opaque_identity options);
187+
ptx in
159188
if Utils.settings.output_debug_files_in_build_directory then (
160189
let oc = Out_channel.open_text @@ Utils.build_file @@ name ^ ".ptx" in
161190
Stdio.Out_channel.output_string oc @@ Nvrtc.string_from_ptx ptx;

0 commit comments

Comments
 (0)