Skip to content

Commit 25c71e5

Browse files
committed
Get rid of hard-coded pointers: all materialized nodes are kernel parameters
1 parent 606f3d2 commit 25c71e5

File tree

4 files changed

+18
-17
lines changed

4 files changed

+18
-17
lines changed

arrayjit/lib/anatomy_of_a_backend.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
- [The Anatomy of an OCANNL Backend](#the-anatomy-of-an-ocannl-backend)
66
- [Design around compiling and running code, backend interfaces](#design-around-compiling-and-running-code-backend-interfaces)
7-
- [Shared relocatable compilation, batch compilation](#shared-relocatable-compilation-batch-compilation)
7+
- [Batch compilation; in the future: lazy and cached compilation artifacts](#batch-compilation-in-the-future-lazy-and-cached-compilation-artifacts)
88
- [Tensor nodes, arrays, memory properties](#tensor-nodes-arrays-memory-properties)
99
- [Typical details of a backend implementation](#typical-details-of-a-backend-implementation)
1010
- [Conditionally emitting the tracing debugger code](#conditionally-emitting-the-tracing-debugger-code)
@@ -125,7 +125,7 @@ Conventionally, the compilation implementation is split into three functions / l
125125
- On GPU-like backends, we cannot load the code at compile time. For example, the CUDA driver API function `cuModuleLoadDataEx` loads the module into _the current context_, which is device-specific, so it must be called from within `link` or `link_batch`.
126126
- GPU-like backends necessitate distinguishing between `link` and `link_batch`, to prevent the same code from being loaded as multiple modules.
127127

128-
The `C_syntax` functor returns the `compile_proc` function for use by `compile` and `compile_batch` of the backends.
128+
The `C_syntax` functor returns the `compile_proc` function for use by `compile` and `compile_batch` of the backends. For simplicity, `C_syntax` passes all materialized nodes by parameters even for backends that use some nodes directly from the host rather than from the device / from context.
129129

130130
### Conditionally emitting the tracing debugger code
131131

arrayjit/lib/c_syntax.ml

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ struct
6161
(* let compute_array_offset ~idcs ~dims = Array.fold2_exn idcs dims ~init:0 ~f:(fun offset idx dim
6262
-> idx + (offset * dim)) *)
6363
let print_includes ppf =
64-
Stdlib.Format.(fprintf ppf {|@[<v 0>%a@,|} (pp_print_list pp_include) B.includes)
64+
Stdlib.Format.(fprintf ppf {|@[<v 0>%a@,@,|} (pp_print_list pp_include) B.includes)
6565
6666
let compile_main ~traced_store ppf llc : unit =
6767
let open Stdlib.Format in
@@ -266,20 +266,19 @@ struct
266266
List.rev
267267
@@ Hashtbl.fold traced_store ~init:[] ~f:(fun ~key:tn ~data:_ params ->
268268
(* A rough approximation to the type Gccjit_backend.mem_properties. *)
269-
let backend_info =
270-
Sexp.Atom
271-
(if Tn.is_virtual_force tn 334 then "Virt"
272-
else
273-
match in_ctx tn with
274-
| Some true -> "Ctx"
275-
| Some false -> "Local"
276-
| None -> "Unk")
269+
let backend_info, is_param =
270+
if Tn.is_virtual_force tn 334 then ("Virt", false)
271+
else if Option.value ~default:false @@ in_ctx tn then ("Ctx", true)
272+
else if Tn.is_materialized_force tn 335 then ("Global or ctx", true)
273+
else if Tn.known_not_materialized tn then ("Local", false)
274+
else assert false
277275
in
276+
let backend_info = Sexp.Atom backend_info in
278277
if not @@ Utils.sexp_mem ~elem:backend_info tn.backend_info then
279278
tn.backend_info <- Utils.sexp_append ~elem:backend_info tn.backend_info;
280279
(* We often don't know ahead of linking with relevant contexts what the stream sharing
281280
mode of the node will become. Conservatively, use passing as argument. *)
282-
if Option.value ~default:true (in_ctx tn) then
281+
if is_param then
283282
(B.typ_of_prec (Lazy.force tn.Tn.prec) ^ " *" ^ get_ident tn, Param_ptr tn) :: params
284283
else params)
285284
in
@@ -345,7 +344,7 @@ struct
345344
params);
346345
fprintf ppf "/* Local declarations and initialization. */@ ";
347346
Hashtbl.iteri traced_store ~f:(fun ~key:tn ~data:node ->
348-
if not (Tn.is_virtual_force tn 333 || Option.value ~default:true (in_ctx tn)) then
347+
if not (Tn.is_virtual_force tn 333 || Tn.is_materialized_force tn 336) then
349348
fprintf ppf "%s %s[%d]%s;@ "
350349
(B.typ_of_prec @@ Lazy.force tn.prec)
351350
(get_ident tn) (Tn.num_elems tn)

arrayjit/lib/tnode.ml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -478,21 +478,23 @@ let get_style ?(arg_name = "ll_ident_style") ?(no_dots = false) () =
478478
invalid_arg @@ "Wrong " ^ arg_name ^ ", must be one of: heuristic, name_and_label, name_only"
479479

480480
let header tn =
481+
let debug = Utils.settings.log_level > 0 in
481482
let mem_size =
482483
if Lazy.is_val tn.array then
483484
match tn.array with
484485
| (lazy None) -> "<not-hosted>"
485486
| (lazy (Some nd)) ->
486487
let size = Int.to_string_hum @@ Nd.size_in_bytes nd in
487-
if Utils.settings.log_level > 0 then size ^ " @ " ^ Nd.ptr_to_string_hum nd else size
488+
if debug then size ^ " @ " ^ Nd.ptr_to_string_hum nd else size
488489
else "<not-in-yet>"
489490
in
490491
let repeating_nograd_idents = Hashtbl.create ~size:1 (module String) in
491492
let repeating_grad_idents = Hashtbl.create ~size:1 (module String) in
492493
[%string
493494
{|%{id tn} %{label tn} as %{
494495
styled_ident ~repeating_nograd_idents ~repeating_grad_idents (`Heuristic_ocannl `Dot_grad) tn
495-
}: %{debug_memory_mode tn.memory_mode}; %{dims_to_string tn}; mem in bytes: %{mem_size}|}]
496+
}: %{debug_memory_mode tn.memory_mode}; %{dims_to_string tn}; mem in bytes: %{mem_size}%{
497+
if debug then "; debug: " ^ Sexp.to_string_hum tn.backend_info else ""}|}]
496498

497499
module Registry = Stdlib.Weak.Make (struct
498500
type nonrec t = t

bin/micrograd_demo.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ module Debug_runtime = Utils.Debug_runtime
1212

1313
let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
1414
Rand.init 0;
15-
Utils.enable_runtime_debug ();
15+
(* Utils.enable_runtime_debug (); *)
1616
(* Utils.settings.debug_log_from_routines <- true; *)
1717
let hid_dim = 16 in
1818
let len = 300 in
@@ -183,7 +183,7 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
183183
PrintBox_text.output Stdio.stdout plot_lr
184184

185185
let () = experiment 4 ~no_batch_shape_inference:true ~use_builtin_weight_decay:true ()
186-
let () = experiment 4 ~no_batch_shape_inference:false ~use_builtin_weight_decay:false ()
186+
let _suspended () = experiment 4 ~no_batch_shape_inference:false ~use_builtin_weight_decay:false ()
187187

188188
let _suspended () =
189189
for seed = 0 to 19 do

0 commit comments

Comments
 (0)