
Commit a834849

Verifying context by tracking embedded_nodes;
big change, needs some cleanup but mostly done.
1 parent dfc1858 commit a834849

19 files changed: +375 -229 lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 ### Changed

 - Migrated to cudajit 0.5.
+- Verifying that code is linked with the right contexts, by tracking `embedded_nodes` with assignments.
 - TODO: Built per-tensor-node device-to-device synchronization into device-to-device copying functions, removed obsolete blocking synchronizations.

 ## [0.4.1] -- 2024-09-17
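
To make the changelog entry concrete, the tracking happens through a record that pairs assignments with the set of nodes they introduce. A sketch of the type added in arrayjit/lib/assignments.ml below (abbreviations as in that file):

```ocaml
(* Nodes used by [asgns] but absent from [embedded_nodes] must already live in
   the contexts that the computation gets linked with. *)
type comp = { asgns : t; embedded_nodes : Set.M(Tn).t }
```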

README.md

Lines changed: 1 addition & 1 deletion
@@ -24,7 +24,7 @@ OCANNL is sponsored by [Ahrefs](https://ocaml.org/success-stories/peta-byte-scal
 * Differentiable computations, centered around the [`%op`](lib/ppx_op.ml) syntax extension.
   * `%op` stands for "operation", it's meant to express tensors: `Tensor.t`, and tensor functions.
 * Plain computations, centered around the [`%cd`](lib/ppx_cd.ml) syntax extension. It integrates the `arrayjit` backend library with shape inference.
-  * `%cd` stands for "code", it's meant to express assignments: `Assignments.t`.
+  * `%cd` stands for "code", it's meant to express assignment computations: `Assignments.comp`.
 * The support for mixed-precision computations is upcoming.
   * E.g. higher-precision network components, or gradients at a higher precision than values.
   * Currently (v0.3), you can select the precision, and individual computation nodes track their precision, but mixing precisions might break things.
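
A minimal sketch of what the `%cd` change looks like at a use site (taken from the demo updates in this commit; it assumes the surrounding `Backend`, `routine`, `IDX`, and `mlp_result` bindings from those demos): the `%cd` payload now builds an `Assignments.comp`, which `Train.to_routine` consumes directly, and `~~(...)` attaches the block comment that `Block_comment` used to provide:

```ocaml
(* The [%cd ...] payload elaborates to an [Assignments.comp];
   [~~ ("moons" "infer"; ...)] names the block for printouts and logs. *)
let result_routine =
  Train.to_routine
    (module Backend)
    routine.context IDX.empty
    [%cd
      ~~("moons" "infer";
         mlp_result.forward)]
```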

arrayjit/lib/assignments.ml

Lines changed: 55 additions & 31 deletions
@@ -44,7 +44,18 @@ and t =
   | Fetch of { array : Tn.t; fetch_op : fetch_op; dims : int array Lazy.t }
 [@@deriving sexp_of]

-type comp = {asgns: t; }
+type comp = {
+  asgns : t;
+  embedded_nodes : Set.M(Tn).t;
+      (** The nodes in {!field-asgns} that are not in [embedded_nodes] need to already be in
+          contexts linked with the {!comp}. *)
+}
+[@@deriving sexp_of]
+(** Computations based on assignments. Note: the [arrayjit] library makes use of, but does not
+    produce nor verify the {!field-embedded_nodes} associated to some given {!field-asgns}. *)
+
+let to_comp asgns = { asgns; embedded_nodes = Set.empty (module Tnode) }
+let empty_comp = to_comp Noop

 let get_name_exn asgns =
   let punct_or_sp = Str.regexp "[-@*/:.;, ]" in
@@ -63,39 +74,54 @@ let get_name_exn asgns =
   let result = loop asgns in
   if String.is_empty result then invalid_arg "Assignments.get_name: no comments in code" else result

-(** Returns nodes that are inputs to the computation in a narrow sense: nodes that were potentially
-    computed by assignments executed before. *)
-let input_or_recurrent_nodes asgns =
+let is_total ~initialize_neutral ~projections =
+  initialize_neutral && Indexing.is_bijective projections
+
+(** Returns the left-hand-side nodes of total assignments. NOTE: [output_nodes] forces the
+    computation of the assignments' projections, so should only be called after shape inference. *)
+let output_nodes asgns =
   let open Utils.Set_O in
   let empty = Set.empty (module Tn) in
-  let single = function
-    | Node tn ->
-        if Tn.known_constant tn || Tn.known_volatile tn || Tn.known_not_materialized tn then
-          Set.empty (module Tn)
-        else Set.singleton (module Tn) tn
-    | Merge_buffer _ -> Set.empty (module Tn)
-  in
-  let maybe have lhs = if have then Set.singleton (module Tn) lhs else empty in
   let rec loop = function
     | Noop -> empty
-    | Seq (t1, t2) -> loop t1 + (loop t2 - assigned t1)
+    | Seq (t1, t2) -> loop t1 + loop t2
     | Block_comment (_, t) -> loop t
-    | Accum_binop { initialize_neutral; lhs; rhs1; rhs2; _ } ->
-        maybe (not initialize_neutral) lhs + single rhs1 + single rhs2
-    | Accum_unop { initialize_neutral; lhs; rhs; _ } ->
-        maybe (not initialize_neutral) lhs + single rhs
+    | Accum_unop { lhs; initialize_neutral; projections; _ }
+    | Accum_binop { lhs; initialize_neutral; projections; _ } ->
+        if is_total ~initialize_neutral ~projections:(Lazy.force projections) then
+          Set.singleton (module Tn) lhs
+        else empty
     | Fetch _ -> empty
-  and assigned = function
-    | Noop -> Set.empty (module Tn)
-    | Seq (t1, t2) -> assigned t1 + assigned t2
-    | Block_comment (_, t) -> assigned t
-    | Accum_binop { initialize_neutral; lhs; _ } -> maybe initialize_neutral lhs
-    | Accum_unop { initialize_neutral; lhs; _ } -> maybe initialize_neutral lhs
-    | Fetch { array; _ } -> Set.singleton (module Tn) array
   in
   loop asgns

-let sequential l = Option.value ~default:Noop @@ List.reduce l ~f:(fun st sts -> Seq (st, sts))
+(** Returns materialized nodes in the sense of {!Tnode.is_in_context_force}. NOTE: it ideally should
+    be called after compilation. *)
+let context_nodes asgns =
+  let open Utils.Set_O in
+  let empty = Set.empty (module Tn) in
+  let one tn = if Tnode.is_in_context_force tn 34 then Set.singleton (module Tn) tn else empty in
+  let of_node = function Node rhs -> one rhs | Merge_buffer _ -> empty in
+  let rec loop = function
+    | Noop -> empty
+    | Seq (t1, t2) -> loop t1 + loop t2
+    | Block_comment (_, t) -> loop t
+    | Accum_unop { lhs; rhs; _ } -> Set.union (one lhs) (of_node rhs)
+    | Accum_binop { lhs; rhs1; rhs2; _ } ->
+        Set.union_list (module Tn) [ one lhs; of_node rhs1; of_node rhs2 ]
+    | Fetch { array; _ } -> one array
+  in
+  loop asgns
+
+let sequential l =
+  Option.value ~default:Noop @@ List.reduce l ~f:(fun sts another_st -> Seq (sts, another_st))
+
+let sequence l =
+  Option.value ~default:{ asgns = Noop; embedded_nodes = Set.empty (module Tn) }
+  @@ List.reduce l
+       ~f:(fun
+            { asgns = sts; embedded_nodes = embs } { asgns = another_st; embedded_nodes = emb } ->
+          { asgns = Seq (sts, another_st); embedded_nodes = Set.union embs emb })

 let%diagn1_sexp to_low_level code =
   let open Indexing in
@@ -145,7 +171,6 @@ let%diagn1_sexp to_low_level code =
           derive_index ~product_syms:projections.product_iterators
             ~projection:projections.project_rhs.(1)
       in
-      let is_assignment = initialize_neutral && Indexing.is_bijective projections in
       let basecase rev_iters =
         let product = Array.of_list_rev_map rev_iters ~f:(fun s -> Indexing.Iterator s) in
         let rhs1_idcs = rhs1_idx ~product in
@@ -156,7 +181,7 @@
        let rhs1_ll = get rhs1 rhs1_idcs in
        let rhs2_ll = get rhs2 rhs2_idcs in
        let rhs2 = binop ~op ~rhs1:rhs1_ll ~rhs2:rhs2_ll in
-        if is_assignment then set lhs lhs_idcs rhs2
+        if is_total ~initialize_neutral ~projections then set lhs lhs_idcs rhs2
        else set lhs lhs_idcs @@ binop ~op:accum ~rhs1:lhs_ll ~rhs2
       in
       let rec for_loop rev_iters = function
@@ -178,7 +203,7 @@
            [%log "projections=", (projections : projections)];
            raise e
       in
-      if initialize_neutral && not is_assignment then
+      if initialize_neutral && not (is_total ~initialize_neutral ~projections) then
        let dims = lazy projections.lhs_dims in
        let fetch_op = Constant (Ops.neutral_elem accum) in
        Low_level.Seq (loop (Fetch { array = lhs; fetch_op; dims }), for_loops)
@@ -193,15 +218,14 @@
          derive_index ~product_syms:projections.product_iterators
            ~projection:projections.project_rhs.(0)
       in
-      let is_assignment = initialize_neutral && Indexing.is_bijective projections in
       let basecase rev_iters =
        let product = Array.of_list_rev_map rev_iters ~f:(fun s -> Indexing.Iterator s) in
        let lhs_idcs = lhs_idx ~product in
        let open Low_level in
        let lhs_ll = get (Node lhs) lhs_idcs in
        let rhs_ll = get rhs @@ rhs_idx ~product in
        let rhs2 = unop ~op ~rhs:rhs_ll in
-        if is_assignment then set lhs lhs_idcs rhs2
+        if is_total ~initialize_neutral ~projections then set lhs lhs_idcs rhs2
        else set lhs lhs_idcs @@ binop ~op:accum ~rhs1:lhs_ll ~rhs2
       in
       let rec for_loop rev_iters = function
@@ -218,7 +242,7 @@
          }
       in
       let for_loops = for_loop [] (Array.to_list projections.product_space) in
-      if initialize_neutral && not is_assignment then
+      if initialize_neutral && not (is_total ~initialize_neutral ~projections) then
        let dims = lazy projections.lhs_dims in
        let fetch_op = Constant (Ops.neutral_elem accum) in
        Low_level.Seq (loop (Fetch { array = lhs; fetch_op; dims }), for_loops)
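
A minimal sketch of how the new pieces compose (assuming two existing assignment values `fwd : t` and `bprop : t`; the caller remains responsible for the embedded-node sets, since per the doc comment above `arrayjit` neither produces nor verifies them):

```ocaml
(* Wrap plain assignments into computations; [to_comp] records no embedded nodes. *)
let fwd_comp = to_comp fwd
let bprop_comp = to_comp bprop

(* [sequence] is to [comp] what [sequential] is to [t]: it sequences the
   assignments and takes the union of the embedded-node sets. *)
let step : comp = sequence [ fwd_comp; bprop_comp ]
```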

arrayjit/lib/backends.ml

Lines changed: 19 additions & 14 deletions
@@ -35,7 +35,7 @@ module type No_device_backend = sig
   val expected_merge_node : code -> Tnode.t option
   val expected_merge_nodes : code_batch -> Tnode.t option array

-  val compile : ?shared:bool -> ?name:string -> Indexing.unit_bindings -> Assignments.t -> code
+  val compile : ?shared:bool -> ?name:string -> Indexing.unit_bindings -> Assignments.comp -> code
   (** If [~shared:true] (default [false]), the backend should prefer to do more compile work in a
       device-agnostic way. If [~shared:false], the backend can opt to postpone compiling altogether
       until [link] is called, to benefit from more optimizations. *)
@@ -45,7 +45,7 @@ module type No_device_backend = sig
     ?names:string array ->
     ?occupancy:(name:string -> src_n:int -> bool) ->
     Indexing.unit_bindings ->
-    Assignments.t array ->
+    Assignments.comp array ->
     code_batch
   (** Unlike the [~shared] parameter, [compile_batch] vs. [compile] is mostly about improving the
       compile time and debugging convenience by generating fewer files -- ideally does not affect
@@ -871,18 +871,20 @@ module Simple_no_device_backend (Backend : Simple_backend) : No_device_backend =
     | Compiled (lowereds, _) ->
         Array.filter_map lowereds ~f:(Option.map ~f:(fun l -> l.Low_level.traced_store))

-  let compile ?(shared = false) ?name bindings asgns : code =
-    let name, lowered = lower_assignments ?name bindings asgns in
+  let compile ?(shared = false) ?name bindings comp : code =
+    let name, lowered = lower_assignments ?name bindings comp.Assignments.asgns in
     if shared then Compiled (lowered, Backend.compile ~name ~opt_ctx_arrays:None bindings lowered)
     else Postponed { lowered; bindings; name }

-  let compile_batch ?(shared = false) ?names ?occupancy bindings asgns_l : code_batch =
-    let names, lowereds = lower_batch_assignments ?names ?occupancy bindings asgns_l in
+  let compile_batch ?(shared = false) ?names ?occupancy bindings comp_l : code_batch =
+    let names, lowereds =
+      lower_batch_assignments ?names ?occupancy bindings
+      @@ Array.map comp_l ~f:(fun c -> c.Assignments.asgns)
+    in
     if shared then Compiled (lowereds, compile_batch ~names ~opt_ctx_arrays:None bindings lowereds)
     else Postponed { lowereds; bindings; names }

-  let link ~from_prior_context ~merge_buffer (prior_context : context)
-      (code : code) =
+  let link ~from_prior_context ~merge_buffer (prior_context : context) (code : code) =
     Backend.(
       verify_prior_context ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context
         [| get_traced_store code |]);
@@ -897,8 +899,8 @@ module Simple_no_device_backend (Backend : Simple_backend) : No_device_backend =
     in
     { context; schedule; bindings; name }

-  let link_batch ~from_prior_context ~merge_buffer
-      (prior_context : context) (code_batch : code_batch) =
+  let link_batch ~from_prior_context ~merge_buffer (prior_context : context)
+      (code_batch : code_batch) =
     Backend.(
       verify_prior_context ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context
       @@ get_traced_stores code_batch);
@@ -963,16 +965,19 @@ module Cuda_backend : Backend = struct
   let work_for context = work_for context.ctx
   let will_wait_for context = will_wait_for context.ctx

-  let compile ?shared:_ ?name bindings asgns : code =
-    let name, lowered = lower_assignments ?name bindings asgns in
+  let compile ?shared:_ ?name bindings comp : code =
+    let name, lowered = lower_assignments ?name bindings comp.Assignments.asgns in
     {
       traced_store = lowered.traced_store;
       code = compile ~name bindings lowered;
       expected_merge_node = lowered.Low_level.merge_node;
     }

-  let compile_batch ?shared:_ ?names ?occupancy bindings asgns_l =
-    let names, lowereds = lower_batch_assignments ?names ?occupancy bindings asgns_l in
+  let compile_batch ?shared:_ ?names ?occupancy bindings comp_l =
+    let names, lowereds =
+      lower_batch_assignments ?names ?occupancy bindings
+      @@ Array.map comp_l ~f:(fun c -> c.Assignments.asgns)
+    in
     {
       traced_stores =
         Array.filter_map lowereds ~f:(Option.map ~f:(fun l -> l.Low_level.traced_store));
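
A hedged sketch of the updated calling convention (assuming some `Backend : No_device_backend`, `bindings : Indexing.unit_bindings`, and plain assignments `asgns : Assignments.t`; the binding names are illustrative): callers now pass an `Assignments.comp`, and the backend itself projects out `.asgns` before lowering, as in the diffs above:

```ocaml
(* Single computation: wrap the assignments, then compile. *)
let code = Backend.compile ~shared:true bindings (Assignments.to_comp asgns)

(* Batched variant: each array element is a [comp]. *)
let code_batch = Backend.compile_batch bindings [| Assignments.to_comp asgns |]
```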

arrayjit/lib/cc_backend.ml

Lines changed: 1 addition & 7 deletions
@@ -67,13 +67,7 @@ type procedure = {
 [@@deriving sexp_of]

 let expected_merge_node proc = proc.lowered.merge_node
-
-let is_in_context node =
-  Tnode.default_to_most_local node.Low_level.tn 33;
-  match node.tn.memory_mode with
-  | Some (Hosted (Constant | Volatile), _) -> false
-  | Some ((Virtual | Local), _) -> false
-  | _ -> true
+let is_in_context node = Tnode.is_in_context_force node.Low_level.tn 33

 let header_sep =
   let open Re in

arrayjit/lib/tnode.ml

Lines changed: 7 additions & 0 deletions
@@ -156,6 +156,13 @@ let is_materialized_force tn provenance =
   | Some ((On_device | Hosted _ | Materialized), _) -> true
   | Some ((Never_virtual | Device_only | Effectively_constant), _) -> assert false

+let is_in_context_force tn provenance =
+  default_to_most_local tn provenance;
+  match tn.memory_mode with
+  | Some (Hosted (Constant | Volatile), _) -> false
+  | Some ((Virtual | Local), _) -> false
+  | _ -> true
+
 let known_not_materialized tn =
   match tn.memory_mode with Some ((Virtual | Local), _) -> true | _ -> false

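
For illustration (the predicate is the one added above; the wrapper name and the provenance value 42 are hypothetical): `is_in_context_force` first forces a concrete memory mode via `default_to_most_local`, then reports whether the node should be backed by a per-context array, excluding hosted constants/volatiles and virtual/local nodes:

```ocaml
(* Hypothetical helper; the integer appears to be a provenance tag recording which
   call site forced the memory-mode decision (this commit's call sites use 33 and 34). *)
let needs_context_array tn = is_in_context_force tn 42
```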

arrayjit/lib/writing_a_backend.md

Lines changed: 1 addition & 1 deletion
@@ -20,7 +20,7 @@ TODO: update regarding events and device-to-device synchronization.

 ## Design around compiling and running code, backend interfaces

-Currently, OCANNL integrates new backends via code in [Backends](backends.ml), so it's the "sink" of backend module dependencies; [Backend_utils](backend_utils.ml) is the "source". `Backend_utils.Types` introduces the context-specific `routine` type, for code executable on a backend. The interface `Backends.No_device_backend` has `compile` functions that take `Assignments.t` as input, to allow full flexibility in backend implementations. There is a helper `Backends.lower_assignments` that wraps `Assignments.lower` and `Low_level.optimize_proc`, since currently all backends use the optimized C-like representation `Low_level.t`. The user-facing interface `Backends.Backend` builds on top of `No_device_backend` providing multi-device functionality. The functor `Multicore_backend` converts a `No_device_backend` targetting the CPU into a `Backend` whose devices are parallel threads (and ultimately the CPU cores).
+Currently, OCANNL integrates new backends via code in [Backends](backends.ml), so it's the "sink" of backend module dependencies; [Backend_utils](backend_utils.ml) is the "source". `Backend_utils.Types` introduces the context-specific `routine` type, for code executable on a backend. The interface `Backends.No_device_backend` has `compile` functions that take `Assignments.comp` as input, to allow full flexibility in backend implementations. There is a helper `Backends.lower_assignments` that wraps `Assignments.lower` and `Low_level.optimize_proc`, since currently all backends use the optimized C-like representation `Low_level.t`. The user-facing interface `Backends.Backend` builds on top of `No_device_backend` providing multi-device functionality. The functor `Multicore_backend` converts a `No_device_backend` targetting the CPU into a `Backend` whose devices are parallel threads (and ultimately the CPU cores).

 ```ocaml
 type lowered_bindings = (static_symbol, int ref) List.Assoc.t (* in indexing.ml *)

bin/micrograd_demo.ml

Lines changed: 10 additions & 3 deletions
@@ -1,6 +1,7 @@
 open Base
 open Ocannl
 module Tn = Arrayjit.Tnode
+module Asgns = Arrayjit.Assignments
 module IDX = Train.IDX
 module TDSL = Operation.TDSL
 module NTDSL = Operation.NTDSL
@@ -81,7 +82,9 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
   let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
   let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in
   let ctx = Backend.init device in
-  let routine = Train.to_routine (module Backend) ctx bindings (Seq (update.fwd_bprop, sgd)) in
+  let routine =
+    Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update.fwd_bprop; sgd ])
+  in
   Train.all_host_to_device (module Backend) routine.context scalar_loss;
   Train.all_host_to_device (module Backend) routine.context learning_rate;
   (* Stdio.print_endline "\n******** scalar_loss **********"; Tensor.print_tree ~with_id:true
@@ -122,8 +125,12 @@
   Train.set_on_host Changed_on_devices mlp_result.value;
   (* By using jitted.context here, we don't need to copy the parameters back to the host. *)
   let result_routine =
-    Train.to_routine (module Backend) routine.context IDX.empty
-    @@ Block_comment ("moons infer", mlp_result.forward)
+    Train.to_routine
+      (module Backend)
+      routine.context IDX.empty
+      [%cd
+        ~~("moons" "infer";
+           mlp_result.forward)]
   in
   Stdio.print_endline "\n******** mlp_result **********";
   Tensor.print_tree ~with_id:true ~with_grad:false ~depth:9 mlp_result;

bin/moons_demo.ml

Lines changed: 10 additions & 3 deletions
@@ -6,6 +6,7 @@ module TDSL = Operation.TDSL
 module NTDSL = Operation.NTDSL
 module CDSL = Train.CDSL
 module Utils = Arrayjit.Utils
+module Asgns = Arrayjit.Assignments
 module Rand = Arrayjit.Rand.Lib
 module Debug_runtime = Utils.Debug_runtime

@@ -59,7 +60,9 @@ let demo () =
   let module Backend = (val Arrayjit.Backends.fresh_backend ~backend_name:"cuda" ()) in
   let device = Backend.(new_virtual_device @@ get_device ~ordinal:0) in
   let ctx = Backend.init device in
-  let routine = Train.to_routine (module Backend) ctx bindings (Seq (update.fwd_bprop, sgd)) in
+  let routine =
+    Train.to_routine (module Backend) ctx bindings (Asgns.sequence [ update.fwd_bprop; sgd ])
+  in

   let points = Tensor.value_2d_points ~xdim:0 ~ydim:1 moons_flat in
   let classes = Tensor.value_1d_points ~xdim:0 moons_classes in
@@ -102,8 +105,12 @@
   let%op mlp_result = mlp "point" in
   Train.set_on_host Changed_on_devices mlp_result.value;
   let result_routine =
-    Train.to_routine (module Backend) routine.context IDX.empty
-    @@ Block_comment ("moons infer", mlp_result.forward)
+    Train.to_routine
+      (module Backend)
+      routine.context IDX.empty
+      [%cd
+        ~~("moons" "infer";
+           mlp_result.forward)]
   in
   let callback (x, y) =
     Tensor.set_values point [| x; y |];

bin/zero2hero_1of7.ml

Lines changed: 1 addition & 1 deletion
@@ -27,7 +27,7 @@ let _suspended () =
   Stdio.printf "\n%!";
   Tensor.print_tree ~with_id:true ~with_grad:true ~depth:9 v;
   Stdlib.Format.printf "\nHigh-level code:\n%!";
-  Stdlib.Format.printf "%a\n%!" (Arrayjit.Assignments.fprint_hum ()) code.fwd_bprop
+  Stdlib.Format.printf "%a\n%!" (Arrayjit.Assignments.fprint_hum ()) code.fwd_bprop.asgns

 let _suspended () =
   Rand.init 0;
