@@ -11,32 +11,31 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
1111module Tn = Tnode
1212
1313module C_syntax (B : sig
14- type buffer_ptr
15-
16- val procs : (Low_level .optimized * buffer_ptr ctx_arrays option ) array
14+ val procs : Low_level .optimized array
1715 (* * The low-level prcedure to compile, and the arrays of the context it will be linked to if not
1816 shared and already known. *)
1917
20- val hardcoded_context_ptr : (buffer_ptr -> Ops .prec -> string ) option
2118 val use_host_memory : bool
2219 val logs_to_stdout : bool
2320 val main_kernel_prefix : string
2421 val kernel_prep_line : string
25- val include_lines : string list
22+ val includes : string list
2623 val typ_of_prec : Ops .prec -> string
2724 val binop_syntax : Ops .prec -> Ops .binop -> string * string * string
2825 val unop_syntax : Ops .prec -> Ops .unop -> string * string
2926 val convert_precision : from :Ops .prec -> to_ :Ops .prec -> string * string
3027end ) =
3128struct
3229 let get_ident =
33- Low_level. get_ident_within_code ~no_dots: true @@ Array. map B. procs ~f: (fun ( l , _ ) -> l.llc)
30+ Low_level. get_ident_within_code ~no_dots: true @@ Array. map B. procs ~f: (fun l -> l.llc)
3431
3532 let in_ctx tn = B. (Tn. is_in_context ~use_host_memory tn)
3633
3734 let pp_zero_out ppf tn =
3835 Stdlib.Format. fprintf ppf " @[<2>memset(%s, 0, %d);@]@ " (get_ident tn) @@ Tn. size_in_bytes tn
3936
37+ let pp_include ppf s = Stdlib.Format. fprintf ppf " #include %s" s
38+
4039 open Indexing.Pp_helpers
4140
4241 let pp_array_offset ppf (idcs , dims ) =
@@ -61,33 +60,8 @@ struct
6160
6261 (* let compute_array_offset ~idcs ~dims = Array.fold2_exn idcs dims ~init:0 ~f:(fun offset idx dim
6362 -> idx + (offset * dim)) *)
64- let % debug3_sexp compile_globals ppf : Tn. t Hash_set. t =
65- let open Stdlib.Format in
66- let is_global = Hash_set. create (module Tn ) in
67- fprintf ppf {|@ [<v 0>%a@,/* Global declarations. */@,|} (pp_print_list pp_print_string)
68- B.include_lines;
69- Array.iter B.procs ~f:(fun (l, ctx_arrays) ->
70- Hashtbl.iter l.Low_level.traced_store ~f:(fun (node : Low_level.traced_array) ->
71- let tn = node.tn in
72- if not @@ Hash_set.mem is_global tn then
73- let ctx_ptr = B.hardcoded_context_ptr in
74- let mem : (Tn.memory_mode * int) option = tn.memory_mode in
75- match (in_ctx tn, ctx_ptr, ctx_arrays, mem) with
76- | Some true, Some get_ptr, Some ctx_arrays, _ ->
77- let ident = get_ident tn in
78- let ctx_array =
79- Option.value_exn ~here:[%here] ~message:ident @@ Map.find ctx_arrays tn
80- in
81- fprintf ppf "#define %s (%s)@," ident @@ get_ptr ctx_array (Lazy.force tn.prec);
82- Hash_set.add is_global tn
83- | Some false, _, _, Some (Hosted _, _)
84- when B.(Tn.known_shared_with_host ~use_host_memory tn) ->
85- let nd = Option.value_exn ~here:[%here] @@ Lazy.force tn.array in
86- fprintf ppf "#define %s (%s)@," (get_ident tn) (Ndarray.c_ptr_to_string nd);
87- Hash_set.add is_global tn
88- | _ -> ()));
89- fprintf ppf "@,@]";
90- is_global
63+ let print_includes ppf =
64+ Stdlib.Format. (fprintf ppf {|@ [<v 0>%a@,|} (pp_print_list pp_include) B.includes)
9165
9266 let compile_main ~traced_store ppf llc : unit =
9367 let open Stdlib.Format in
@@ -285,18 +259,16 @@ struct
285259 in
286260 pp_ll ppf llc
287261
288- let%track3_sexp compile_proc ~name ppf idx_params ~is_global
289- Low_level.{ traced_store; llc; merge_node } =
262+ let%track3_sexp compile_proc ~name ppf idx_params Low_level.{ traced_store; llc; merge_node } =
290263 let open Stdlib.Format in
291264 let params : (string * param_source) list =
292- (* Preserve the order in the hashtable, so it's the same as e.g. in compile_globals . *)
265+ (* Preserve the order in the hashtable. *)
293266 List.rev
294267 @@ Hashtbl.fold traced_store ~init:[] ~f:(fun ~key:tn ~data:_ params ->
295268 (* A rough approximation to the type Gccjit_backend.mem_properties. *)
296269 let backend_info =
297270 Sexp.Atom
298- (if Hash_set.mem is_global tn then " Host "
299- else if Tn.is_virtual_force tn 334 then " Virt "
271+ (if Tn.is_virtual_force tn 334 then " Virt "
300272 else
301273 match in_ctx tn with
302274 | Some true -> " Ctx "
@@ -307,7 +279,7 @@ struct
307279 tn.backend_info <- Utils.sexp_append ~elem:backend_info tn.backend_info;
308280 (* We often don't know ahead of linking with relevant contexts what the stream sharing
309281 mode of the node will become. Conservatively, use passing as argument. *)
310- if Option.value ~default:true (in_ctx tn) && not (Hash_set.mem is_global tn) then
282+ if Option.value ~default:true (in_ctx tn) then
311283 (B.typ_of_prec (Lazy.force tn.Tn.prec) ^ " * " ^ get_ident tn, Param_ptr tn) :: params
312284 else params)
313285 in
@@ -373,12 +345,7 @@ struct
373345 params);
374346 fprintf ppf "/* Local declarations and initialization. */@ ";
375347 Hashtbl.iteri traced_store ~f:(fun ~key:tn ~data:node ->
376- if
377- not
378- (Tn.is_virtual_force tn 333
379- || Option.value ~default:true (in_ctx tn)
380- || Hash_set.mem is_global tn)
381- then
348+ if not (Tn.is_virtual_force tn 333 || Option.value ~default:true (in_ctx tn)) then
382349 fprintf ppf "%s %s[%d]%s;@ "
383350 (B.typ_of_prec @@ Lazy.force tn.prec)
384351 (get_ident tn) (Tn.num_elems tn)
0 commit comments