Skip to content

Commit 0f0336b

Browse files
committed
Rename backend_utils -> c_syntax, uniformly validate merge nodes
1 parent d54b5e0 commit 0f0336b

File tree

9 files changed

+57
-31
lines changed

CHANGES.md

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,15 @@
99
- Migrated to cudajit 0.5.
1010
- Verifying that code is linked with the right contexts, by tracking `embedded_nodes` with assignments.
1111
- Renaming: (virtual) `device` -> `stream`, `physical_device` -> `device`.
12-
- New files: split out `backend_types.ml` from `backends.ml`; moved `Tnode.task` to `task.ml`; TODO: renamed `backend_utils.ml` to `c_syntax.ml`.
12+
- New files: split out `backend_types.ml` from `backends.ml`; moved `Tnode.task` to `task.ml`; renamed `backend_utils.ml` to `c_syntax.ml`.
1313
- TODO: Moved the multicore backend from a `device = stream` model to a single device model.
1414
- TODO: Fixed #286: cross-stream-sharing incorporated into `Tnode.memory_mode`.
1515
- TODO: Built per-tensor-node stream-to-stream synchronization into device-to-device copying functions, removed obsolete blocking synchronizations.
1616

17+
### Fixed
18+
19+
- Validating merge nodes for the CUDA backend.
20+
1721
## [0.4.1] -- 2024-09-17
1822

1923
### Added

arrayjit/lib/backend_types.ml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,7 @@ module type Backend = sig
165165
[will_wait_for src (all_work (get_ctx_stream dst))]. *)
166166

167167
type device
168-
type stream
168+
type stream [@@deriving sexp_of]
169169

170170
val init : stream -> context
171171
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> stream -> buffer_ptr
@@ -180,7 +180,6 @@ module type Backend = sig
180180
val is_idle : stream -> bool
181181
(** Whether the stream is currently waiting for work. *)
182182

183-
val sexp_of_stream : stream -> Sexp.t
184183
val get_device : ordinal:int -> device
185184
val num_devices : unit -> int
186185

@@ -298,14 +297,18 @@ module type Lowered_backend = sig
298297
val get_buffer : Tnode.t -> context -> buffer_ptr option
299298

300299
type device
301-
type stream
300+
type stream [@@deriving sexp_of]
302301

303302
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> stream -> buffer_ptr
304303
val init : stream -> context
305304
val await : stream -> unit
306305
val is_idle : stream -> bool
307306
val all_work : stream -> event
308-
val sexp_of_stream : stream -> Sexplib.Sexp.t
307+
308+
val scheduled_merge_node : stream -> Tnode.t option
309+
(** [scheduled_merge_node stream] is the tensor node that would be in the [stream]'s merge
310+
buffer right after [await stream], or [None] if the merge buffer would hold no node. *)
311+
309312
val num_devices : unit -> int
310313
val suggested_num_streams : device -> int
311314
val get_device : ordinal:int -> device

arrayjit/lib/backends.ml

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,17 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
77
[%%global_debug_log_level 9]
88
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
99

10+
let check_merge_buffer ~scheduled_node ~code_node =
11+
let name = function Some tn -> Tnode.debug_name tn | None -> "none" in
12+
match (scheduled_node, code_node) with
13+
| _, None -> ()
14+
| Some actual, Some expected when Tnode.equal actual expected -> ()
15+
| _ ->
16+
raise
17+
@@ Utils.User_error
18+
("Merge buffer mismatch, on stream: " ^ name scheduled_node ^ ", expected by code: "
19+
^ name code_node)
20+
1021
module Multicore_backend (Backend : Backend_types.No_device_backend) : Backend_types.Backend =
1122
struct
1223
module Domain = Domain [@warning "-3"]
@@ -690,6 +701,11 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back
690701
verify from_prior_context;
691702
link_compiled ~merge_buffer prior_context proc
692703
in
704+
let schedule =
705+
Task.prepend schedule ~work:(fun () ->
706+
check_merge_buffer ~scheduled_node:(Option.map !merge_buffer ~f:snd)
707+
~code_node:(expected_merge_node code))
708+
in
693709
{ context; schedule; bindings; name }
694710

695711
let link_batch ~merge_buffer (prior_context : context) (code_batch : code_batch) =
@@ -711,9 +727,15 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back
711727
verify from_prior_context;
712728
procs
713729
in
714-
Array.fold_map procs ~init:prior_context ~f:(fun context -> function
730+
let code_nodes = expected_merge_nodes code_batch in
731+
Array.fold_mapi procs ~init:prior_context ~f:(fun i context -> function
715732
| Some proc ->
716733
let context, bindings, schedule, name = link_compiled ~merge_buffer context proc in
734+
let schedule =
735+
Task.prepend schedule ~work:(fun () ->
736+
check_merge_buffer ~scheduled_node:(Option.map !merge_buffer ~f:snd)
737+
~code_node:code_nodes.(i))
738+
in
717739
(context, Some { context; schedule; bindings; name })
718740
| None -> (context, None))
719741

@@ -800,6 +822,12 @@ module Lowered_backend (Device : Backend_types.Lowered_backend) : Backend_types.
800822
verify_prior_context ~ctx_arrays ~is_in_context ~prior_context:context.ctx
801823
~from_prior_context:code.from_prior_context [| code.traced_store |];
802824
let ctx, bindings, schedule = link context.ctx code.code in
825+
let schedule =
826+
Task.prepend schedule ~work:(fun () ->
827+
check_merge_buffer
828+
~scheduled_node:(scheduled_merge_node @@ get_ctx_stream context.ctx)
829+
~code_node:(expected_merge_node code))
830+
in
803831
{ context = { ctx; expected_merge_node = code.expected_merge_node }; schedule; bindings; name }
804832

805833
let link_batch context code_batch =
@@ -809,12 +837,14 @@ module Lowered_backend (Device : Backend_types.Lowered_backend) : Backend_types.
809837
( { ctx; expected_merge_node = context.expected_merge_node },
810838
Array.mapi schedules ~f:(fun i ->
811839
Option.map ~f:(fun schedule ->
812-
{
813-
context = { ctx; expected_merge_node = code_batch.expected_merge_nodes.(i) };
814-
schedule;
815-
bindings;
816-
name;
817-
})) )
840+
let expected_merge_node = code_batch.expected_merge_nodes.(i) in
841+
let schedule =
842+
Task.prepend schedule ~work:(fun () ->
843+
check_merge_buffer
844+
~scheduled_node:(scheduled_merge_node @@ get_ctx_stream context.ctx)
845+
~code_node:expected_merge_node)
846+
in
847+
{ context = { ctx; expected_merge_node }; schedule; bindings; name })) )
818848

819849
let init stream = { ctx = init stream; expected_merge_node = None }
820850
let get_ctx_stream context = get_ctx_stream context.ctx

arrayjit/lib/backend_utils.ml renamed to arrayjit/lib/c_syntax.ml

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -390,15 +390,3 @@ struct
390390
fprintf ppf "@;<0 -2>}@]@.";
391391
params
392392
end
393-
394-
let check_merge_buffer ~merge_buffer ~code_node =
395-
let stream_node = Option.map !merge_buffer ~f:snd in
396-
let name = function Some tn -> Tn.debug_name tn | None -> "none" in
397-
match (stream_node, code_node) with
398-
| _, None -> ()
399-
| Some actual, Some expected when Tn.equal actual expected -> ()
400-
| _ ->
401-
raise
402-
@@ Utils.User_error
403-
("Merge buffer mismatch, on stream: " ^ name stream_node ^ ", expected by code: "
404-
^ name code_node)

arrayjit/lib/cc_backend.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -151,7 +151,7 @@ let%diagn_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_
151151
else ctx_arrays
152152
| Some _ -> ctx_arrays))
153153
in
154-
let module Syntax = Backend_utils.C_syntax (C_syntax_config (struct
154+
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
155155
let for_lowereds = [| lowered |]
156156
let opt_ctx_arrays = opt_ctx_arrays
157157
end)) in
@@ -185,7 +185,7 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
185185
else ctx_arrays
186186
| Some _ -> ctx_arrays)))
187187
in
188-
let module Syntax = Backend_utils.C_syntax (C_syntax_config (struct
188+
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
189189
let for_lowereds = for_lowereds
190190
let opt_ctx_arrays = opt_ctx_arrays
191191
end)) in

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,6 +97,8 @@ let will_wait_for context event = Cu.Delimited_event.wait context.stream.cu_stre
9797
let sync event = Cu.Delimited_event.synchronize event
9898
let all_work stream = Cu.Delimited_event.record stream.cu_stream
9999

100+
let scheduled_merge_node stream = Option.map ~f:snd stream.merge_buffer
101+
100102
let is_initialized, initialize =
101103
let initialized = ref false in
102104
let init (config : config) : unit =
@@ -462,7 +464,7 @@ end
462464
let compile ~name bindings ({ Low_level.traced_store; _ } as lowered) =
463465
(* TODO: The following link seems to claim it's better to expand into loops than use memset.
464466
https://stackoverflow.com/questions/23712558/how-do-i-best-initialize-a-local-memory-array-to-0 *)
465-
let module Syntax = Backend_utils.C_syntax (C_syntax_config (struct
467+
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
466468
let for_lowereds = [| lowered |]
467469
end)) in
468470
let idx_params = Indexing.bound_symbols bindings in
@@ -477,7 +479,7 @@ let compile ~name bindings ({ Low_level.traced_store; _ } as lowered) =
477479

478480
let compile_batch ~names bindings lowereds =
479481
let for_lowereds = Array.filter_map ~f:Fn.id lowereds in
480-
let module Syntax = Backend_utils.C_syntax (C_syntax_config (struct
482+
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
481483
let for_lowereds = for_lowereds
482484
end)) in
483485
let idx_params = Indexing.bound_symbols bindings in

arrayjit/lib/dune

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
assignments
3737
task
3838
backend_types
39-
backend_utils
39+
c_syntax
4040
cc_backend
4141
gcc_backend
4242
cuda_backend

arrayjit/lib/gcc_backend.gccjit.ml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -858,7 +858,6 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
858858
in
859859
let%diagn_l_sexp work () : unit =
860860
[%log_result name];
861-
Backend_utils.check_merge_buffer ~merge_buffer ~code_node:code.expected_merge_node;
862861
Indexing.apply run_variadic ();
863862
if Utils.debug_log_from_routines () then (
864863
Utils.log_trace_tree (Stdio.In_channel.read_lines log_file_name);

arrayjit/lib/writing_a_backend.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ type mem_properties =
138138
| Constant_from_host (** The array is read directly from the host. *)
139139
```
140140

141-
while the CC and CUDA backends do it implicitly via the input to the `Backend_utils.C_syntax` functor:
141+
while the CC and CUDA backends do it implicitly via the input to the `C_syntax.C_syntax` functor:
142142

143143
```ocaml
144144
module C_syntax (B : sig

0 commit comments

Comments
 (0)