Skip to content

Commit 606f3d2

Browse files
committed
In progress: get rid of hard-coded pointers, and of opt_ctx_arrays
1 parent b5d6104 commit 606f3d2

File tree

8 files changed

+48
-104
lines changed

8 files changed

+48
-104
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
- Huge refactoring of backend internal interfaces and API (not repeating same code).
2323
- Built per-tensor-node stream-to-stream synchronization into copying functions.
2424
- Re-introduced whole-device blocking synchronization, which now is just a slight optimization as it also cleans up event book-keeping.
25+
- Simplifications: no more explicit compilation postponing; no more hard-coded pointers (all non-local arrays are passed by parameter).
2526

2627
### Fixed
2728

arrayjit/lib/backend_impl.ml

Lines changed: 1 addition & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -158,23 +158,13 @@ module type Lowered_no_device_backend = sig
158158

159159
type procedure [@@deriving sexp_of]
160160

161-
val compile :
162-
name:string ->
163-
opt_ctx_arrays:ctx_arrays option ->
164-
Indexing.unit_bindings ->
165-
Low_level.optimized ->
166-
procedure
167-
(** [opt_ctx_arrays], if any, already contain the arrays of the context that will result from
168-
linking the code. *)
161+
val compile : name:string -> Indexing.unit_bindings -> Low_level.optimized -> procedure
169162

170163
val compile_batch :
171164
names:string option array ->
172-
opt_ctx_arrays:ctx_arrays option array option ->
173165
Indexing.unit_bindings ->
174166
Low_level.optimized option array ->
175167
procedure option array
176-
(** [opt_ctx_arrays], if any, already contain the arrays of the contexts that will result from
177-
linking the code. *)
178168

179169
val link_compiled :
180170
merge_buffer:buffer option ref ->

arrayjit/lib/backends.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -218,11 +218,11 @@ module Add_device
218218
[@@deriving sexp_of]
219219

220220
let compile ~name bindings lowered : code =
221-
let proc = compile ~name ~opt_ctx_arrays:None bindings lowered in
221+
let proc = compile ~name bindings lowered in
222222
{ lowered; proc }
223223

224224
let compile_batch ~names bindings lowereds : code_batch =
225-
let procs = compile_batch ~names ~opt_ctx_arrays:None bindings lowereds in
225+
let procs = compile_batch ~names bindings lowereds in
226226
{ lowereds; procs }
227227

228228
include Add_scheduler (Backend)

arrayjit/lib/c_syntax.ml

Lines changed: 12 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -11,32 +11,31 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
1111
module Tn = Tnode
1212

1313
module C_syntax (B : sig
14-
type buffer_ptr
15-
16-
val procs : (Low_level.optimized * buffer_ptr ctx_arrays option) array
14+
val procs : Low_level.optimized array
1715
(** The low-level prcedure to compile, and the arrays of the context it will be linked to if not
1816
shared and already known. *)
1917

20-
val hardcoded_context_ptr : (buffer_ptr -> Ops.prec -> string) option
2118
val use_host_memory : bool
2219
val logs_to_stdout : bool
2320
val main_kernel_prefix : string
2421
val kernel_prep_line : string
25-
val include_lines : string list
22+
val includes : string list
2623
val typ_of_prec : Ops.prec -> string
2724
val binop_syntax : Ops.prec -> Ops.binop -> string * string * string
2825
val unop_syntax : Ops.prec -> Ops.unop -> string * string
2926
val convert_precision : from:Ops.prec -> to_:Ops.prec -> string * string
3027
end) =
3128
struct
3229
let get_ident =
33-
Low_level.get_ident_within_code ~no_dots:true @@ Array.map B.procs ~f:(fun (l, _) -> l.llc)
30+
Low_level.get_ident_within_code ~no_dots:true @@ Array.map B.procs ~f:(fun l -> l.llc)
3431

3532
let in_ctx tn = B.(Tn.is_in_context ~use_host_memory tn)
3633

3734
let pp_zero_out ppf tn =
3835
Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " (get_ident tn) @@ Tn.size_in_bytes tn
3936

37+
let pp_include ppf s = Stdlib.Format.fprintf ppf "#include %s" s
38+
4039
open Indexing.Pp_helpers
4140

4241
let pp_array_offset ppf (idcs, dims) =
@@ -61,33 +60,8 @@ struct
6160

6261
(* let compute_array_offset ~idcs ~dims = Array.fold2_exn idcs dims ~init:0 ~f:(fun offset idx dim
6362
-> idx + (offset * dim)) *)
64-
let%debug3_sexp compile_globals ppf : Tn.t Hash_set.t =
65-
let open Stdlib.Format in
66-
let is_global = Hash_set.create (module Tn) in
67-
fprintf ppf {|@[<v 0>%a@,/* Global declarations. */@,|} (pp_print_list pp_print_string)
68-
B.include_lines;
69-
Array.iter B.procs ~f:(fun (l, ctx_arrays) ->
70-
Hashtbl.iter l.Low_level.traced_store ~f:(fun (node : Low_level.traced_array) ->
71-
let tn = node.tn in
72-
if not @@ Hash_set.mem is_global tn then
73-
let ctx_ptr = B.hardcoded_context_ptr in
74-
let mem : (Tn.memory_mode * int) option = tn.memory_mode in
75-
match (in_ctx tn, ctx_ptr, ctx_arrays, mem) with
76-
| Some true, Some get_ptr, Some ctx_arrays, _ ->
77-
let ident = get_ident tn in
78-
let ctx_array =
79-
Option.value_exn ~here:[%here] ~message:ident @@ Map.find ctx_arrays tn
80-
in
81-
fprintf ppf "#define %s (%s)@," ident @@ get_ptr ctx_array (Lazy.force tn.prec);
82-
Hash_set.add is_global tn
83-
| Some false, _, _, Some (Hosted _, _)
84-
when B.(Tn.known_shared_with_host ~use_host_memory tn) ->
85-
let nd = Option.value_exn ~here:[%here] @@ Lazy.force tn.array in
86-
fprintf ppf "#define %s (%s)@," (get_ident tn) (Ndarray.c_ptr_to_string nd);
87-
Hash_set.add is_global tn
88-
| _ -> ()));
89-
fprintf ppf "@,@]";
90-
is_global
63+
let print_includes ppf =
64+
Stdlib.Format.(fprintf ppf {|@[<v 0>%a@,|} (pp_print_list pp_include) B.includes)
9165
9266
let compile_main ~traced_store ppf llc : unit =
9367
let open Stdlib.Format in
@@ -285,18 +259,16 @@ struct
285259
in
286260
pp_ll ppf llc
287261
288-
let%track3_sexp compile_proc ~name ppf idx_params ~is_global
289-
Low_level.{ traced_store; llc; merge_node } =
262+
let%track3_sexp compile_proc ~name ppf idx_params Low_level.{ traced_store; llc; merge_node } =
290263
let open Stdlib.Format in
291264
let params : (string * param_source) list =
292-
(* Preserve the order in the hashtable, so it's the same as e.g. in compile_globals. *)
265+
(* Preserve the order in the hashtable. *)
293266
List.rev
294267
@@ Hashtbl.fold traced_store ~init:[] ~f:(fun ~key:tn ~data:_ params ->
295268
(* A rough approximation to the type Gccjit_backend.mem_properties. *)
296269
let backend_info =
297270
Sexp.Atom
298-
(if Hash_set.mem is_global tn then "Host"
299-
else if Tn.is_virtual_force tn 334 then "Virt"
271+
(if Tn.is_virtual_force tn 334 then "Virt"
300272
else
301273
match in_ctx tn with
302274
| Some true -> "Ctx"
@@ -307,7 +279,7 @@ struct
307279
tn.backend_info <- Utils.sexp_append ~elem:backend_info tn.backend_info;
308280
(* We often don't know ahead of linking with relevant contexts what the stream sharing
309281
mode of the node will become. Conservatively, use passing as argument. *)
310-
if Option.value ~default:true (in_ctx tn) && not (Hash_set.mem is_global tn) then
282+
if Option.value ~default:true (in_ctx tn) then
311283
(B.typ_of_prec (Lazy.force tn.Tn.prec) ^ " *" ^ get_ident tn, Param_ptr tn) :: params
312284
else params)
313285
in
@@ -373,12 +345,7 @@ struct
373345
params);
374346
fprintf ppf "/* Local declarations and initialization. */@ ";
375347
Hashtbl.iteri traced_store ~f:(fun ~key:tn ~data:node ->
376-
if
377-
not
378-
(Tn.is_virtual_force tn 333
379-
|| Option.value ~default:true (in_ctx tn)
380-
|| Hash_set.mem is_global tn)
381-
then
348+
if not (Tn.is_virtual_force tn 333 || Option.value ~default:true (in_ctx tn)) then
382349
fprintf ppf "%s %s[%d]%s;@ "
383350
(B.typ_of_prec @@ Lazy.force tn.prec)
384351
(get_ident tn) (Tn.num_elems tn)

arrayjit/lib/cc_backend.ml

Lines changed: 10 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -72,47 +72,37 @@ let c_compile_and_load ~f_name =
7272
result
7373

7474
module C_syntax_config (Input : sig
75-
val procs : (Low_level.optimized * buffer_ptr ctx_arrays option) array
75+
val procs : Low_level.optimized array
7676
end) =
7777
struct
78-
type nonrec buffer_ptr = buffer_ptr
79-
8078
let procs = Input.procs
81-
let hardcoded_context_ptr = c_ptr_to_string
8279
let use_host_memory = use_host_memory
8380
let logs_to_stdout = false
8481
let main_kernel_prefix = ""
8582
let kernel_prep_line = ""
86-
87-
let include_lines =
88-
[ "#include <stdio.h>"; "#include <stdlib.h>"; "#include <string.h>"; "#include <math.h>" ]
89-
83+
let includes = [ "<stdio.h>"; "<stdlib.h>"; "<string.h>"; "<math.h>" ]
9084
let typ_of_prec = Ops.c_typ_of_prec
9185
let binop_syntax = Ops.binop_c_syntax
9286
let unop_syntax = Ops.unop_c_syntax
9387
let convert_precision = Ops.c_convert_precision
9488
end
9589

96-
let%diagn_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_level.optimized) =
90+
let%diagn_sexp compile ~(name : string) bindings (lowered : Low_level.optimized) =
9791
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
98-
let procs = [| (lowered, opt_ctx_arrays) |]
92+
let procs = [| lowered |]
9993
end)) in
10094
(* FIXME: do we really want all of them, or only the used ones? *)
10195
let idx_params = Indexing.bound_symbols bindings in
10296
let pp_file = Utils.pp_file ~base_name:name ~extension:".c" in
103-
let is_global = Syntax.compile_globals pp_file.ppf in
104-
let params = Syntax.compile_proc ~name pp_file.ppf idx_params ~is_global lowered in
97+
Syntax.print_includes pp_file.ppf;
98+
let params = Syntax.compile_proc ~name pp_file.ppf idx_params lowered in
10599
pp_file.finalize ();
106100
let result = c_compile_and_load ~f_name:pp_file.f_name in
107101
{ result; params; bindings; name }
108102

109-
let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
110-
(lowereds : Low_level.optimized option array) =
103+
let%diagn_sexp compile_batch ~names bindings (lowereds : Low_level.optimized option array) =
111104
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
112-
let procs =
113-
Array.filter_mapi lowereds ~f:(fun i ->
114-
Option.map ~f:(fun lowereds ->
115-
(lowereds, Option.(map opt_ctx_arrays ~f:(fun ctx_arrays -> value_exn ctx_arrays.(i))))))
105+
let procs = Array.filter_opt lowereds
116106
end)) in
117107
(* FIXME: do we really want all of them, or only the used ones? *)
118108
let idx_params = Indexing.bound_symbols bindings in
@@ -122,11 +112,11 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
122112
@@ common_prefix (Array.to_list @@ Array.concat_map ~f:Option.to_array names))
123113
in
124114
let pp_file = Utils.pp_file ~base_name ~extension:".c" in
125-
let is_global = Syntax.compile_globals pp_file.ppf in
115+
Syntax.print_includes pp_file.ppf;
126116
let params =
127117
Array.mapi lowereds ~f:(fun i lowered ->
128118
Option.map2 names.(i) lowered ~f:(fun name lowered ->
129-
Syntax.compile_proc ~name pp_file.ppf idx_params ~is_global lowered))
119+
Syntax.compile_proc ~name pp_file.ppf idx_params lowered))
130120
in
131121
pp_file.finalize ();
132122
let result = c_compile_and_load ~f_name:pp_file.f_name in

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -267,21 +267,20 @@ let%diagn2_sexp cuda_to_ptx ~name cu_src =
267267
ptx
268268

269269
module C_syntax_config (Input : sig
270-
val procs : (Low_level.optimized * ctx_arrays option) array
270+
val procs : Low_level.optimized array
271271
end) =
272272
struct
273273
type nonrec buffer_ptr = buffer_ptr [@@deriving sexp_of]
274274

275275
let procs = Input.procs
276-
let hardcoded_context_ptr = None
277276
let use_host_memory = use_host_memory
278277
let logs_to_stdout = true
279278
let main_kernel_prefix = "extern \"C\" __global__"
280279

281280
let kernel_prep_line =
282281
"/* FIXME: single-threaded for now. */if (threadIdx.x != 0 || blockIdx.x != 0) { return; }"
283282

284-
let include_lines = [ "#include <cuda_fp16.h>" ]
283+
let includes = [ "<cuda_fp16.h>" ]
285284

286285
let typ_of_prec = function
287286
| Ops.Byte_prec _ -> "unsigned char"
@@ -341,31 +340,31 @@ let compile ~name bindings ({ Low_level.traced_store; _ } as lowered) =
341340
(* TODO: The following link seems to claim it's better to expand into loops than use memset.
342341
https://stackoverflow.com/questions/23712558/how-do-i-best-initialize-a-local-memory-array-to-0 *)
343342
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
344-
let procs = [| (lowered, None) |]
343+
let procs = [| lowered |]
345344
end)) in
346345
let idx_params = Indexing.bound_symbols bindings in
347346
let b = Buffer.create 4096 in
348347
let ppf = Stdlib.Format.formatter_of_buffer b in
349348
if Utils.debug_log_from_routines () then
350349
Stdlib.Format.fprintf ppf "@,__device__ int printf (const char * format, ... );@,";
351-
let is_global = Syntax.compile_globals ppf in
352-
let params = Syntax.compile_proc ~name ~is_global ppf idx_params lowered in
350+
Syntax.print_includes ppf;
351+
let params = Syntax.compile_proc ~name ppf idx_params lowered in
353352
let ptx = cuda_to_ptx ~name @@ Buffer.contents b in
354353
{ traced_store; ptx; params; bindings; name }
355354

356355
let compile_batch ~names bindings lowereds =
357356
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
358-
let procs = Array.filter_map lowereds ~f:(Option.map ~f:(fun lowereds -> (lowereds, None)))
357+
let procs = Array.filter_opt lowereds
359358
end)) in
360359
let idx_params = Indexing.bound_symbols bindings in
361360
let b = Buffer.create 4096 in
362361
let ppf = Stdlib.Format.formatter_of_buffer b in
363-
let is_global = Syntax.compile_globals ppf in
362+
Syntax.print_includes ppf;
364363
let params_and_names =
365364
Array.map2_exn names lowereds
366365
~f:
367366
(Option.map2 ~f:(fun name lowered ->
368-
(Syntax.compile_proc ~name ~is_global ppf idx_params lowered, name)))
367+
(Syntax.compile_proc ~name ppf idx_params lowered, name)))
369368
in
370369
let name : string =
371370
String.(

arrayjit/lib/gcc_backend.gccjit.ml

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,8 @@ let zero_out ctx block node =
9898

9999
let get_c_ptr ctx num_typ ptr = Gccjit.(RValue.ptr ctx (Type.pointer num_typ) ptr)
100100

101-
let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~opt_ctx_arrays ~param_ptrs
102-
initializations (tn : Tn.t) =
101+
let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~param_ptrs initializations
102+
(tn : Tn.t) =
103103
let open Gccjit in
104104
let traced = Low_level.(get_node traced_store tn) in
105105
let dims = Lazy.force tn.dims in
@@ -116,14 +116,12 @@ let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~opt_ctx_arrays
116116
let hosted = Tn.is_hosted_force tn 344 in
117117
let in_ctx = Tn.is_in_context ~use_host_memory tn in
118118
let ptr =
119-
match (in_ctx, opt_ctx_arrays, hosted) with
120-
| Some true, Some ctx_arrays, _ ->
121-
Lazy.from_val @@ get_c_ptr ctx num_typ @@ Map.find_exn ctx_arrays tn
122-
| (Some true | None), None, _ ->
119+
match (in_ctx, hosted) with
120+
| Some true, _ ->
123121
let p = Param.create ctx ptr_typ ident in
124122
param_ptrs := (p, Param_ptr tn) :: !param_ptrs;
125123
Lazy.from_val (RValue.param p)
126-
| (Some false | None), _, true -> (
124+
| (Some false | None), true -> (
127125
let addr arr =
128126
Lazy.from_val @@ get_c_ptr ctx num_typ @@ Ctypes.bigarray_start Ctypes_static.Genarray arr
129127
in
@@ -133,7 +131,7 @@ let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~opt_ctx_arrays
133131
| Some (Single_nd arr) -> addr arr
134132
| Some (Double_nd arr) -> addr arr
135133
| None -> assert false)
136-
| (Some false | None), _, false ->
134+
| (Some false | None), false ->
137135
let arr_typ = Type.array ctx num_typ size_in_elems in
138136
let v = ref None in
139137
let initialize _init_block func = v := Some (Function.local func arr_typ ident) in
@@ -500,7 +498,7 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node;
500498
loop_proc ~toplevel:true ~name ~env body;
501499
!current_block
502500

503-
let%diagn_sexp compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident
501+
let%diagn_sexp compile_proc ~name ctx bindings ~get_ident
504502
Low_level.{ traced_store; llc = proc; merge_node } =
505503
let open Gccjit in
506504
let c_index = Type.get ctx Type.Int in
@@ -536,7 +534,7 @@ let%diagn_sexp compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident
536534
let data =
537535
prepare_node
538536
~debug_log_zero_out:(debug_log_zero_out ctx log_functions get_ident)
539-
~get_ident ctx traced_store ~opt_ctx_arrays ~param_ptrs initializations tn
537+
~get_ident ctx traced_store ~param_ptrs initializations tn
540538
in
541539
Hashtbl.add_exn nodes ~key:tn ~data);
542540
let params : (gccjit_param * param_source) list = !param_ptrs in
@@ -590,7 +588,7 @@ let%diagn_sexp compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident
590588
Block.return_void after_proc;
591589
(ctx_info, params)
592590

593-
let compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_level.optimized) =
591+
let compile ~(name : string) bindings (lowered : Low_level.optimized) =
594592
let get_ident = Low_level.get_ident_within_code ~no_dots:true [| lowered.llc |] in
595593
let open Gccjit in
596594
if Option.is_none !root_ctx then initialize ();
@@ -599,7 +597,7 @@ let compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_level.optim
599597
(* if Utils.settings.with_debug && Utils.settings.output_debug_files_in_build_directory then (
600598
Context.set_option ctx Context.Keep_intermediates true; Context.set_option ctx
601599
Context.Dump_everything true); *)
602-
let info, params = compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident lowered in
600+
let info, params = compile_proc ~name ctx bindings ~get_ident lowered in
603601
(if Utils.settings.output_debug_files_in_build_directory then
604602
let f_name = Utils.build_file @@ name ^ "-gccjit-debug.c" in
605603
Context.dump_to_file ctx ~update_locs:true f_name);
@@ -608,7 +606,7 @@ let compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_level.optim
608606
Context.release ctx;
609607
{ info; result; bindings; name; params = List.map ~f:snd params }
610608

611-
let%diagn_sexp compile_batch ~(names : string option array) ~opt_ctx_arrays bindings
609+
let%diagn_sexp compile_batch ~(names : string option array) bindings
612610
(lowereds : Low_level.optimized option array) =
613611
let get_ident =
614612
Low_level.get_ident_within_code ~no_dots:true
@@ -623,10 +621,9 @@ let%diagn_sexp compile_batch ~(names : string option array) ~opt_ctx_arrays bind
623621
Context.Dump_everything true); *)
624622
let funcs =
625623
Array.mapi lowereds ~f:(fun i lowered ->
626-
let opt_ctx_arrays = Option.(join @@ map opt_ctx_arrays ~f:(fun arrs -> arrs.(i))) in
627624
match (names.(i), lowered) with
628625
| Some name, Some lowered ->
629-
let info, params = compile_proc ~name ~opt_ctx_arrays ctx bindings ~get_ident lowered in
626+
let info, params = compile_proc ~name ctx bindings ~get_ident lowered in
630627
Some (info, params)
631628
| _ -> None)
632629
in

0 commit comments

Comments
 (0)