Skip to content

Commit fb31801

Browse files
committed
Make local heap allocation opt-in; rename pp_float -> pp_scalar
1 parent fcfd6da commit fb31801

File tree

4 files changed

+73
-35
lines changed

4 files changed

+73
-35
lines changed

arrayjit/lib/c_syntax.ml

Lines changed: 61 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,16 @@ module type C_syntax_config = sig
7171
implementation should handle quoting [base_message_literal], choosing the log function
7272
(printf, fprintf, os_log), and prepending any necessary prefixes (like a log_id or
7373
captured_log_prefix) to the format string and arguments. *)
74+
75+
val local_heap_alloc :
76+
(zero_initialized:bool ->
77+
num_elems:int ->
78+
typ_doc:PPrint.document ->
79+
ident_doc:PPrint.document ->
80+
PPrint.document)
81+
option
82+
83+
val local_heap_dealloc : (ident_doc:PPrint.document -> PPrint.document) option
7484
end
7585

7686
module Pure_C_config (Input : sig
@@ -455,6 +465,30 @@ struct
455465
^^ (if List.is_empty args_docs then empty else comma ^^ space)
456466
^^ separate (comma ^^ space) args_docs
457467
^^ rparen ^^ semi
468+
469+
let local_heap_alloc ~zero_initialized ~num_elems ~typ_doc ~ident_doc =
470+
let open PPrint in
471+
let alloc_expr =
472+
if zero_initialized then
473+
string "calloc"
474+
^^ parens (OCaml.int num_elems ^^ comma ^^ space ^^ string "sizeof" ^^ parens typ_doc)
475+
else
476+
string "malloc"
477+
^^ parens
478+
(OCaml.int num_elems ^^ space ^^ string "*" ^^ space ^^ string "sizeof"
479+
^^ parens typ_doc)
480+
in
481+
typ_doc ^^ space ^^ string "*" ^^ ident_doc ^^ space ^^ equals ^^ space
482+
^^ parens (typ_doc ^^ string "*")
483+
^^ alloc_expr
484+
485+
let local_heap_alloc = Some local_heap_alloc
486+
487+
let local_heap_dealloc ~ident_doc =
488+
let open PPrint in
489+
string "free(" ^^ ident_doc ^^ string ")"
490+
491+
let local_heap_dealloc = Some local_heap_dealloc
458492
end
459493

460494
module C_syntax (B : C_syntax_config) = struct
@@ -529,7 +563,7 @@ module C_syntax (B : C_syntax_config) = struct
529563
let ident_doc = string (get_ident tn) in
530564
let dims = Lazy.force tn.dims in
531565
let prec = Lazy.force tn.prec in
532-
let local_defs, val_doc = pp_float prec llsc in
566+
let local_defs, val_doc = pp_scalar prec llsc in
533567
let offset_doc = pp_array_offset (idcs, dims) in
534568
let assignment =
535569
group
@@ -595,7 +629,7 @@ module C_syntax (B : C_syntax_config) = struct
595629
let dims = Lazy.force tn.dims in
596630
let prec = Lazy.force tn.prec in
597631
let arg_prec = Ops.uint4x32 in
598-
let local_defs, arg_doc = pp_float arg_prec arg in
632+
let local_defs, arg_doc = pp_scalar arg_prec arg in
599633
(* Generate the function call *)
600634
let result_doc = B.vec_unop_syntax prec vec_unop arg_doc in
601635
(* Generate assignments for each output element *)
@@ -671,13 +705,13 @@ module C_syntax (B : C_syntax_config) = struct
671705
else if PPrint.is_empty local_defs then assignments
672706
else local_defs ^^ hardline ^^ assignments
673707
| Set_local ({ scope_id; tn = { prec; _ } }, value) ->
674-
let local_defs, value_doc = pp_float (Lazy.force prec) value in
708+
let local_defs, value_doc = pp_scalar (Lazy.force prec) value in
675709
let assignment =
676710
string ("v" ^ Int.to_string scope_id) ^^ string " = " ^^ value_doc ^^ semi
677711
in
678712
if PPrint.is_empty local_defs then assignment else local_defs ^^ hardline ^^ assignment
679713

680-
and pp_float (prec : Ops.prec) (vcomp : Low_level.scalar_t) : PPrint.document * PPrint.document =
714+
and pp_scalar (prec : Ops.prec) (vcomp : Low_level.scalar_t) : PPrint.document * PPrint.document =
681715
(* Returns (local definitions, value expression) *)
682716
let open PPrint in
683717
match vcomp with
@@ -735,29 +769,29 @@ module C_syntax (B : C_syntax_config) = struct
735769
let idx_doc = if PPrint.is_empty idx_doc then string "0" else idx_doc in
736770
let expr = string prefix ^^ idx_doc ^^ string postfix in
737771
(empty, expr)
738-
| Binop (Arg1, v1, _v2) -> pp_float prec v1
739-
| Binop (Arg2, _v1, v2) -> pp_float prec v2
772+
| Binop (Arg1, v1, _v2) -> pp_scalar prec v1
773+
| Binop (Arg2, _v1, v2) -> pp_scalar prec v2
740774
| Ternop (op, v1, v2, v3) ->
741-
let d1, e1 = pp_float prec v1 in
742-
let d2, e2 = pp_float prec v2 in
743-
let d3, e3 = pp_float prec v3 in
775+
let d1, e1 = pp_scalar prec v1 in
776+
let d2, e2 = pp_scalar prec v2 in
777+
let d3, e3 = pp_scalar prec v3 in
744778
let defs =
745779
List.filter_map [ d1; d2; d3 ] ~f:(fun d -> if PPrint.is_empty d then None else Some d)
746780
|> separate hardline
747781
in
748782
let expr = group (B.ternop_syntax prec op e1 e2 e3) in
749783
(defs, expr)
750784
| Binop (op, v1, v2) ->
751-
let d1, e1 = pp_float prec v1 in
752-
let d2, e2 = pp_float prec v2 in
785+
let d1, e1 = pp_scalar prec v1 in
786+
let d2, e2 = pp_scalar prec v2 in
753787
let defs =
754788
List.filter_map [ d1; d2 ] ~f:(fun d -> if PPrint.is_empty d then None else Some d)
755789
|> separate hardline
756790
in
757791
let expr = group (B.binop_syntax prec op e1 e2) in
758792
(defs, expr)
759793
| Unop (op, v) ->
760-
let defs, expr_v = pp_float prec v in
794+
let defs, expr_v = pp_scalar prec v in
761795
let expr = group (B.unop_syntax prec op expr_v) in
762796
(defs, expr)
763797

@@ -937,6 +971,9 @@ module C_syntax (B : C_syntax_config) = struct
937971
body := !body ^^ debug_init_doc ^^ hardline);
938972

939973
let heap_allocated = ref [] in
974+
let stack_threshold_in_bytes =
975+
Int.of_string @@ Utils.get_global_arg ~default:"16384" ~arg_name:"stack_threshold_in_bytes"
976+
in
940977
let local_decls =
941978
string "/* Local declarations and initialization. */"
942979
^^ hardline
@@ -947,27 +984,18 @@ module C_syntax (B : C_syntax_config) = struct
947984
let ident_doc = string (get_ident tn) in
948985
let num_elems = Tn.num_elems tn in
949986
let size_doc = OCaml.int num_elems in
950-
(* Use heap allocation for arrays larger than 16KB to avoid stack overflow in Domain
951-
threads *)
952-
let stack_threshold = 16384 / (Ops.prec_in_bytes @@ Lazy.force tn.prec) in
953-
if num_elems > stack_threshold then (
987+
(* Use heap allocation for arrays larger than stack_threshold_in_bytes to avoid stack
988+
overflow in Domain threads *)
989+
990+
if
991+
Option.is_some B.local_heap_alloc && stack_threshold_in_bytes > 0
992+
&& num_elems > stack_threshold_in_bytes / (Ops.prec_in_bytes @@ Lazy.force tn.prec)
993+
then (
954994
(* Heap allocation for large arrays *)
955995
heap_allocated := get_ident tn :: !heap_allocated;
956-
let alloc_expr =
957-
if node.Low_level.zero_initialized then
958-
string "calloc"
959-
^^ parens
960-
(OCaml.int num_elems ^^ comma ^^ space ^^ string "sizeof"
961-
^^ parens typ_doc)
962-
else
963-
string "malloc"
964-
^^ parens
965-
(OCaml.int num_elems ^^ space ^^ string "*" ^^ space ^^ string "sizeof"
966-
^^ parens typ_doc)
967-
in
968-
typ_doc ^^ space ^^ string "*" ^^ ident_doc ^^ space ^^ equals ^^ space
969-
^^ parens (typ_doc ^^ string "*")
970-
^^ alloc_expr ^^ semi ^^ hardline)
996+
Option.value_exn B.local_heap_alloc
997+
~zero_initialized:node.Low_level.zero_initialized ~num_elems ~typ_doc ~ident_doc
998+
^^ semi ^^ hardline)
971999
else
9721000
(* Stack allocation for small arrays *)
9731001
let init_doc =
@@ -983,13 +1011,13 @@ module C_syntax (B : C_syntax_config) = struct
9831011
body := !body ^^ main_logic;
9841012

9851013
(* Free heap-allocated arrays *)
986-
if not (List.is_empty !heap_allocated) then
1014+
if Option.is_some B.local_heap_dealloc && not (List.is_empty !heap_allocated) then
9871015
body :=
9881016
!body ^^ hardline
9891017
^^ string "/* Cleanup heap-allocated arrays. */"
9901018
^^ hardline
9911019
^^ separate_map hardline
992-
(fun ident -> string "free" ^^ parens (string ident) ^^ semi)
1020+
(fun ident -> Option.value_exn B.local_heap_dealloc ~ident_doc:(string ident) ^^ semi)
9931021
!heap_allocated
9941022
^^ hardline;
9951023

arrayjit/lib/cuda_backend.ml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,9 @@ end) : Ir.Backend_impl.Lowered_backend = struct
705705
^^ comma ^^ space
706706
^^ separate (comma ^^ space) all_args
707707
^^ rparen ^^ semi
708+
709+
let local_heap_alloc = None
710+
let local_heap_dealloc = None
708711
end
709712

710713
let builtins_large_header =

arrayjit/lib/metal_backend.ml

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -620,6 +620,9 @@ end) : Ir.Backend_impl.Lowered_backend = struct
620620
string metal_log_object_name ^^ string ".log_debug(" ^^ base_doc ^^ comma ^^ space
621621
^^ separate (comma ^^ space) args_docs
622622
^^ rparen ^^ semi
623+
624+
let local_heap_alloc = None
625+
let local_heap_dealloc = None
623626
end
624627

625628
let%diagn_sexp compile_metal_source ~name ~source ~device =
@@ -700,8 +703,9 @@ end) : Ir.Backend_impl.Lowered_backend = struct
700703
traced_stores;
701704
}
702705

703-
let%diagn2_sexp link_proc ~prior_context ~library ~func_name ~params ~lowered_bindings ~ctx_arrays
704-
=
706+
let%debug4_sexp link_proc ~prior_context ~library ~func_name
707+
~(params : (string * param_source) list) ~lowered_bindings ~(ctx_arrays : buffer_ptr Tn.t_map)
708+
: Task.t =
705709
let stream = prior_context.stream in
706710
let device = stream.device.dev in
707711
let queue = stream.runner.queue in

ocannl_config.example

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,9 @@ cc_backend_compiler_command=
8787
# are used for the dynamic libraries (default behavior).
8888
output_dlls_in_build_directory=false
8989

90+
# Per-tensor-node cutoff in bytes for considering off-stack allocation, opt-in for backends.
91+
# Decrease it if you see crashes like `bus error`. Values smaller-or-equal 0 mean "no limits".
92+
stack_threshold_in_bytes=16384
9093

9194
# Only tensor nodes with up to this many visits per array cell (in a dedicated interpreter)
9295
# can be inlined. Values worth considering: 0 (disables inlining) to 3.

0 commit comments

Comments
 (0)