@@ -606,8 +606,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
606606 | Tanh_approx , Single_prec _ -> func " __tanhf"
607607 | Tanh_approx , _ -> func " tanh"
608608 | Not , _ -> f " (" " == 0.0 ? 1.0 : 0.0)"
609- | Uint4x32_to_prec_uniform , _ ->
610- func (" uint4x32_to_" ^ Ops. prec_string prec ^ " _uniform" )
609+ | Uint4x32_to_prec_uniform , _ -> func (" uint4x32_to_" ^ Ops. prec_string prec ^ " _uniform" )
611610
612611 let ternop_syntax prec v =
613612 let open PPrint in
@@ -657,6 +656,24 @@ end) : Ir.Backend_impl.Lowered_backend = struct
657656 ^^ rparen ^^ semi
658657 end
659658
(** CUDA source fragment declaring [arrayjit_threefry4x32] as a [__device__] function-pointer
    variable, initially [nullptr].

    It is appended to the generated source after the small builtins because the declaration
    uses the [uint4x32_t] type. The variable is filled in after the kernel module is loaded:
    see [set_builtins_in_kernel], which copies a pointer value into the global of this name. *)
let builtins_large_header =
  {|
__device__ uint4x32_t ( * arrayjit_threefry4x32)(uint4x32_t key, uint4x32_t counter) = nullptr;
|}
663+
(** [prepend_builtins b] appends to [b] the preamble shared by the generated CUDA sources:

    - a [printf] declaration, only when [Utils.debug_log_from_routines ()] is true;
    - the full contents of [builtins_small.cu], looked up in the directory part of [__FILE__];
    - [builtins_large_header].

    There is no exception handler here, so a failure to read [builtins_small.cu] propagates to
    the caller instead of producing a source without the builtins. *)
let prepend_builtins b =
  if Utils.debug_log_from_routines () then
    Buffer.add_string b " __device__ int printf (const char * format, ... );\n ";
  Buffer.add_string b " \n\n ";
  (* The file name must not carry any padding: it is joined verbatim onto the directory. *)
  let builtins_path =
    Stdlib.Filename.concat (Stdlib.Filename.dirname Stdlib.__FILE__) "builtins_small.cu"
  in
  Buffer.add_string b (Stdio.In_channel.read_all builtins_path);
  (* The large-builtins header has to follow the small builtins: its declaration is written in
     terms of uint4x32_t. *)
  Buffer.add_string b builtins_large_header;
  Buffer.add_string b " \n\n "
676+
660677 let % diagn2_sexp compile ~name bindings ({ Low_level. traced_store; _ } as lowered) =
661678 (* TODO: The following link seems to claim it's better to expand into loops than use memset.
662679 https://stackoverflow.com/questions/23712558/how-do-i-best-initialize-a-local-memory-array-to-0 *)
@@ -665,8 +682,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
665682 end )) in
666683 let idx_params = Indexing. bound_symbols bindings in
667684 let b = Buffer. create 4096 in
668- if Utils. debug_log_from_routines () then
669- Buffer. add_string b " __device__ int printf (const char * format, ... );\n " ;
685+ prepend_builtins b;
670686 let declarations_doc = Syntax. print_declarations () in
671687 let params, proc_doc = Syntax. compile_proc ~name idx_params lowered in
672688 let final_doc = PPrint. (declarations_doc ^^ proc_doc) in
@@ -680,16 +696,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
680696 end )) in
681697 let idx_params = Indexing. bound_symbols bindings in
682698 let b = Buffer. create 4096 in
683- (* Read and prepend the CUDA builtins file *)
684- let builtins_path =
685- Stdlib.Filename. concat (Stdlib.Filename. dirname Stdlib. __FILE__) " arrayjit_builtins.cu"
686- in
687- (try
688- let builtins_content = Stdio.In_channel. read_all builtins_path in
689- Buffer. add_string b builtins_content;
690- Buffer. add_string b " \n\n "
691- with _ -> () );
692- (* Silently skip if file not found *)
699+ prepend_builtins b;
693700 let declarations_doc = Syntax. print_declarations () in
694701 let params_and_docs =
695702 Array. map2_exn names lowereds
@@ -787,10 +794,29 @@ end) : Ir.Backend_impl.Lowered_backend = struct
787794 Cu.Module. [ GENERATE_DEBUG_INFO true ; GENERATE_LINE_INFO true ]
788795 else []
789796
(** [set_ptr_in_kernel kernel_module src name] looks up the global called [name] in
    [kernel_module] and copies 8 bytes from device address [src] into it.

    Intended use (see [set_builtins_in_kernel]): [src] is the address of a global holding a
    helper-function pointer, and [name] is the kernel's function-pointer variable that should
    receive that value. Any exception from the lookup or the copy propagates. *)
let set_ptr_in_kernel kernel_module src name =
  (* NOTE(review): assumes device function pointers are 8 bytes wide -- confirm for all
     supported targets. *)
  let pointer_size_in_bytes = 8 in
  let dst, _ = Cu.Module.get_global kernel_module ~name in
  (* Device-to-device copy of the pointer value stored at [src] into the kernel's variable. *)
  Cu.Deviceptr.memcpy_D_to_D ~dst ~src ~size_in_bytes:pointer_size_in_bytes ()
801+
(** Device address of the [arrayjit_threefry4x32] global exported by [builtins_large.cu].

    Forcing this value reads [builtins_large.cu] from the directory part of [__FILE__], compiles
    it with [cuda_to_ptx], loads the result as a module and looks the global up. The work is
    deferred with [lazy] so that it runs on first use -- i.e. from [set_builtins_in_kernel],
    after the caller has made a context current -- rather than while the enclosing module is
    being evaluated.

    NOTE(review): the helper module is loaded once, in whichever context is current at the first
    force. If kernels are linked in several contexts this address probably needs to be cached
    per context -- confirm. *)
let builtins_large_threefry4x32_ptr =
  lazy
    (assert !initialized;
     let builtins_path =
       Stdlib.Filename.concat (Stdlib.Filename.dirname Stdlib.__FILE__) "builtins_large.cu"
     in
     let cu_src = Stdio.In_channel.read_all builtins_path in
     let code = cuda_to_ptx ~name:"builtins_large" cu_src in
     let builtins_module = Cu.Module.load_data_ex code (run_options ()) in
     let threefry4x32_ptr, _ =
       Cu.Module.get_global builtins_module ~name:"arrayjit_threefry4x32"
     in
     threefry4x32_ptr)

(** [set_builtins_in_kernel kernel_module] points the [arrayjit_threefry4x32] function-pointer
    variable of [kernel_module] (declared by [builtins_large_header]) at the implementation
    compiled from [builtins_large.cu].

    The first call compiles and loads the helper module; later calls reuse the cached address.
    Fails with an assertion error if the backend has not been initialized, and propagates any
    exception raised while reading, compiling or loading the builtins. *)
let set_builtins_in_kernel kernel_module =
  let threefry4x32_ptr = Lazy.force builtins_large_threefry4x32_ptr in
  set_ptr_in_kernel kernel_module threefry4x32_ptr "arrayjit_threefry4x32"
814+
790815 let % track3_sexp link prior_context (code : code ) ctx_arrays =
791816 let ctx = ctx_of prior_context in
792817 set_ctx ctx;
793818 let run_module = Cu.Module. load_data_ex code.ptx (run_options () ) in
819+ set_builtins_in_kernel run_module;
794820 let idx_params = Indexing. bound_symbols code.bindings in
795821 let lowered_bindings : Indexing.lowered_bindings =
796822 List. map idx_params ~f: (fun s -> (s, ref 0 ))
@@ -809,6 +835,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
809835 let ctx = ctx_of prior_context in
810836 set_ctx ctx;
811837 let run_module = Cu.Module. load_data_ex code_batch.ptx (run_options () ) in
838+ set_builtins_in_kernel run_module;
812839 let procs =
813840 Array. mapi code_batch.params_and_names ~f: (fun i pns ->
814841 Option. value ~default: None
0 commit comments