@@ -71,6 +71,16 @@ module type C_syntax_config = sig
7171 implementation should handle quoting [base_message_literal], choosing the log function
7272 (printf, fprintf, os_log), and prepending any necessary prefixes (like a log_id or
7373 captured_log_prefix) to the format string and arguments. *)
74+
(** Optional backend hook that renders the declaration of a heap-allocated local array.
    Labeled arguments: [zero_initialized] selects a zero-filling allocation, [num_elems] is the
    element count, [typ_doc] is the rendered element type and [ident_doc] the rendered
    identifier. The returned document is the declaration without its terminating semicolon; the
    code generator appends the semicolon and a line break itself. When this is [None] the
    generator does not take the heap-allocation branch for local arrays. *)
75+ val local_heap_alloc :
76+ (zero_initialized :bool ->
77+ num_elems :int ->
78+ typ_doc :PPrint .document ->
79+ ident_doc :PPrint .document ->
80+ PPrint .document )
81+ option
82+
(** Optional backend hook that renders the statement releasing a heap-allocated local array
    named by [ident_doc]. The returned document has no terminating semicolon; the code
    generator appends it. When this is [None] the generator emits no cleanup section. *)
83+ val local_heap_dealloc : (ident_doc :PPrint .document -> PPrint .document ) option
7484end
7585
7686module Pure_C_config (Input : sig
@@ -455,6 +465,30 @@ struct
455465 ^^ (if List. is_empty args_docs then empty else comma ^^ space)
456466 ^^ separate (comma ^^ space) args_docs
457467 ^^ rparen ^^ semi
468+
(* Renders the C declaration of a heap-allocated local array: a pointer to [typ_doc] named
   [ident_doc], initialized with a cast allocation call. No terminating semicolon is emitted
   here.
   NOTE(review): nothing in this function emits a check of the allocation result for NULL;
   confirm whether the generated code is expected to handle allocation failure elsewhere. *)
469+ let local_heap_alloc ~zero_initialized ~num_elems ~typ_doc ~ident_doc =
470+ let open PPrint in
471+ let alloc_expr =
(* Zero-initialized arrays use calloc with separate (count, element size) arguments. *)
472+ if zero_initialized then
473+ string " calloc"
474+ ^^ parens (OCaml. int num_elems ^^ comma ^^ space ^^ string " sizeof" ^^ parens typ_doc)
475+ else
(* Otherwise malloc receives the total size, written as count * sizeof(element type). *)
476+ string " malloc"
477+ ^^ parens
478+ (OCaml. int num_elems ^^ space ^^ string " *" ^^ space ^^ string " sizeof"
479+ ^^ parens typ_doc)
480+ in
(* Pointer declaration, then the allocation result cast to the element pointer type. *)
481+ typ_doc ^^ space ^^ string " *" ^^ ident_doc ^^ space ^^ equals ^^ space
482+ ^^ parens (typ_doc ^^ string " *" )
483+ ^^ alloc_expr
484+
(* Rebind the name to an option, matching the optional hook declared in C_syntax_config. *)
485+ let local_heap_alloc = Some local_heap_alloc
486+
(* Renders a C call to free applied to [ident_doc]. No terminating semicolon is emitted
   here. *)
487+ let local_heap_dealloc ~ident_doc =
488+ let open PPrint in
489+ string " free(" ^^ ident_doc ^^ string " )"
490+
(* Rebind the name to an option, matching the optional hook declared in C_syntax_config. *)
491+ let local_heap_dealloc = Some local_heap_dealloc
458492end
459493
460494module C_syntax (B : C_syntax_config ) = struct
@@ -529,7 +563,7 @@ module C_syntax (B : C_syntax_config) = struct
529563 let ident_doc = string (get_ident tn) in
530564 let dims = Lazy. force tn.dims in
531565 let prec = Lazy. force tn.prec in
532- let local_defs, val_doc = pp_float prec llsc in
566+ let local_defs, val_doc = pp_scalar prec llsc in
533567 let offset_doc = pp_array_offset (idcs, dims) in
534568 let assignment =
535569 group
@@ -595,7 +629,7 @@ module C_syntax (B : C_syntax_config) = struct
595629 let dims = Lazy. force tn.dims in
596630 let prec = Lazy. force tn.prec in
597631 let arg_prec = Ops. uint4x32 in
598- let local_defs, arg_doc = pp_float arg_prec arg in
632+ let local_defs, arg_doc = pp_scalar arg_prec arg in
599633 (* Generate the function call *)
600634 let result_doc = B. vec_unop_syntax prec vec_unop arg_doc in
601635 (* Generate assignments for each output element *)
@@ -671,13 +705,13 @@ module C_syntax (B : C_syntax_config) = struct
671705 else if PPrint. is_empty local_defs then assignments
672706 else local_defs ^^ hardline ^^ assignments
673707 | Set_local ({ scope_id; tn = { prec; _ } } , value ) ->
674- let local_defs, value_doc = pp_float (Lazy. force prec) value in
708+ let local_defs, value_doc = pp_scalar (Lazy. force prec) value in
675709 let assignment =
676710 string (" v" ^ Int. to_string scope_id) ^^ string " = " ^^ value_doc ^^ semi
677711 in
678712 if PPrint. is_empty local_defs then assignment else local_defs ^^ hardline ^^ assignment
679713
680- and pp_float (prec : Ops.prec ) (vcomp : Low_level.scalar_t ) : PPrint.document * PPrint.document =
714+ and pp_scalar (prec : Ops.prec ) (vcomp : Low_level.scalar_t ) : PPrint.document * PPrint.document =
681715 (* Returns (local definitions, value expression) *)
682716 let open PPrint in
683717 match vcomp with
@@ -735,29 +769,29 @@ module C_syntax (B : C_syntax_config) = struct
735769 let idx_doc = if PPrint. is_empty idx_doc then string " 0" else idx_doc in
736770 let expr = string prefix ^^ idx_doc ^^ string postfix in
737771 (empty, expr)
738- | Binop (Arg1, v1 , _v2 ) -> pp_float prec v1
739- | Binop (Arg2, _v1 , v2 ) -> pp_float prec v2
772+ | Binop (Arg1, v1 , _v2 ) -> pp_scalar prec v1
773+ | Binop (Arg2, _v1 , v2 ) -> pp_scalar prec v2
740774 | Ternop (op , v1 , v2 , v3 ) ->
741- let d1, e1 = pp_float prec v1 in
742- let d2, e2 = pp_float prec v2 in
743- let d3, e3 = pp_float prec v3 in
775+ let d1, e1 = pp_scalar prec v1 in
776+ let d2, e2 = pp_scalar prec v2 in
777+ let d3, e3 = pp_scalar prec v3 in
744778 let defs =
745779 List. filter_map [ d1; d2; d3 ] ~f: (fun d -> if PPrint. is_empty d then None else Some d)
746780 |> separate hardline
747781 in
748782 let expr = group (B. ternop_syntax prec op e1 e2 e3) in
749783 (defs, expr)
750784 | Binop (op , v1 , v2 ) ->
751- let d1, e1 = pp_float prec v1 in
752- let d2, e2 = pp_float prec v2 in
785+ let d1, e1 = pp_scalar prec v1 in
786+ let d2, e2 = pp_scalar prec v2 in
753787 let defs =
754788 List. filter_map [ d1; d2 ] ~f: (fun d -> if PPrint. is_empty d then None else Some d)
755789 |> separate hardline
756790 in
757791 let expr = group (B. binop_syntax prec op e1 e2) in
758792 (defs, expr)
759793 | Unop (op , v ) ->
760- let defs, expr_v = pp_float prec v in
794+ let defs, expr_v = pp_scalar prec v in
761795 let expr = group (B. unop_syntax prec op expr_v) in
762796 (defs, expr)
763797
@@ -937,6 +971,9 @@ module C_syntax (B : C_syntax_config) = struct
937971 body := ! body ^^ debug_init_doc ^^ hardline);
938972
939973 let heap_allocated = ref [] in
974+ let stack_threshold_in_bytes =
975+ Int. of_string @@ Utils. get_global_arg ~default: " 16384" ~arg_name: " stack_threshold_in_bytes"
976+ in
940977 let local_decls =
941978 string " /* Local declarations and initialization. */"
942979 ^^ hardline
@@ -947,27 +984,18 @@ module C_syntax (B : C_syntax_config) = struct
947984 let ident_doc = string (get_ident tn) in
948985 let num_elems = Tn. num_elems tn in
949986 let size_doc = OCaml. int num_elems in
950- (* Use heap allocation for arrays larger than 16KB to avoid stack overflow in Domain
951- threads *)
952- let stack_threshold = 16384 / (Ops. prec_in_bytes @@ Lazy. force tn.prec) in
953- if num_elems > stack_threshold then (
987+ (* Heap-allocate arrays larger than stack_threshold_in_bytes (only when the backend provides
988+ local_heap_alloc and the threshold is positive) to avoid stack overflow in Domain threads *)
989+
990+ if
991+ Option. is_some B. local_heap_alloc && stack_threshold_in_bytes > 0
992+ && num_elems > stack_threshold_in_bytes / (Ops. prec_in_bytes @@ Lazy. force tn.prec)
993+ then (
954994 (* Heap allocation for large arrays *)
955995 heap_allocated := get_ident tn :: ! heap_allocated;
956- let alloc_expr =
957- if node.Low_level. zero_initialized then
958- string " calloc"
959- ^^ parens
960- (OCaml. int num_elems ^^ comma ^^ space ^^ string " sizeof"
961- ^^ parens typ_doc)
962- else
963- string " malloc"
964- ^^ parens
965- (OCaml. int num_elems ^^ space ^^ string " *" ^^ space ^^ string " sizeof"
966- ^^ parens typ_doc)
967- in
968- typ_doc ^^ space ^^ string " *" ^^ ident_doc ^^ space ^^ equals ^^ space
969- ^^ parens (typ_doc ^^ string " *" )
970- ^^ alloc_expr ^^ semi ^^ hardline)
996+ Option. value_exn B. local_heap_alloc
997+ ~zero_initialized: node.Low_level. zero_initialized ~num_elems ~typ_doc ~ident_doc
998+ ^^ semi ^^ hardline)
971999 else
9721000 (* Stack allocation for small arrays *)
9731001 let init_doc =
@@ -983,13 +1011,13 @@ module C_syntax (B : C_syntax_config) = struct
9831011 body := ! body ^^ main_logic;
9841012
9851013 (* Free heap-allocated arrays *)
986- if not (List. is_empty ! heap_allocated) then
1014+ if Option. is_some B. local_heap_dealloc && not (List. is_empty ! heap_allocated) then
9871015 body :=
9881016 ! body ^^ hardline
9891017 ^^ string " /* Cleanup heap-allocated arrays. */"
9901018 ^^ hardline
9911019 ^^ separate_map hardline
992- (fun ident -> string " free " ^^ parens (string ident) ^^ semi)
1020+ (fun ident -> Option. value_exn B. local_heap_dealloc ~ident_doc: (string ident) ^^ semi)
9931021 ! heap_allocated
9941022 ^^ hardline;
9951023
0 commit comments