@@ -6,6 +6,7 @@ module Idx = Arrayjit.Indexing
66module Debug_runtime = Arrayjit.Utils.Debug_runtime
77
88type tn = Tn.t
9+ type tn_set = Set.M(Arrayjit.Tnode).t
910type asgns = Asgns.t
1011type init_op = Arrayjit.Ops.init_op
1112type fetch_op = Asgns.fetch_op
@@ -23,9 +24,10 @@ type t = {
2324 forward : Asgns .t ;
2425 diff : diff option ;
2526 id : int ;
26- value : Tn .t ;
27+ value : tn ;
2728 shape : Shape .t ;
2829 children : subtensor list ;
30+ non_embedded : tn_set ;
2931}
3032
3133and subtensor = { subtensor : t ; embedded : bool }
@@ -147,12 +149,14 @@ let op ~(label : string list) ?(compose_op = Shape.Pointwise_bin)
147149 ?(transpose_op = Shape. Pointwise_un ) ?(init_op = default_init_op) ~op_asn ~grad_asn
148150 ?(grad_spec = If_needed ) make_shape (orig_ts : t list ) : t =
149151 let ordered_ts = List. dedup_and_sort orig_ts ~compare: (fun t1 t2 -> Int. ascending t1.id t2.id) in
152+ let non_embedded = ref @@ Set. empty (module Tn ) in
150153 let children =
151154 List. folding_map orig_ts
152155 ~init: (Set. empty (module Int ))
153156 ~f: (fun used ti ->
154- ( Set. add used ti.id,
155- { subtensor = ti; embedded = is_fwd_root ti && not (Set. mem used ti.id) } ))
157+ let root = is_fwd_root ti in
158+ if not root then non_embedded := Set. add ! non_embedded ti.value;
159+ (Set. add used ti.id, { subtensor = ti; embedded = root && not (Set. mem used ti.id) }))
156160 in
157161 let id = session_state.next_id in
158162 session_state.next_id <- session_state.next_id + 1;
@@ -187,7 +191,9 @@ let op ~(label : string list) ?(compose_op = Shape.Pointwise_bin)
187191 || Fn. non is_require_grad grad_spec
188192 && List. for_all orig_ts ~f: (fun ti -> Option. is_none ti.diff)
189193 then (
190- let tensor = { forward; diff = None ; id; value = v; shape; children } in
194+ let tensor =
195+ { forward; diff = None ; id; value = v; shape; children; non_embedded = ! non_embedded }
196+ in
191197 session_state.forward_roots <- Map.add_exn session_state.forward_roots ~key:id ~data:tensor;
192198 tensor)
193199 else
@@ -216,7 +222,11 @@ let op ~(label : string list) ?(compose_op = Shape.Pointwise_bin)
216222 that all ancestors of a node are backpropagated before the node is backpropagated, even for
217223 non-tree DAGs. *)
218224 let backprop =
219- let bprop = dcode ~f: (fun diff -> diff.backprop) in
225+ let bprop =
226+ dcode ~f: (fun diff ->
227+ non_embedded := Set. add ! non_embedded diff.grad;
228+ diff.backprop)
229+ in
220230 let bcks =
221231 List. map ordered_ts ~f: (fun ti -> if is_bck_root ti then bprop ti else Asgns. Noop )
222232 in
@@ -226,7 +236,7 @@ let op ~(label : string list) ?(compose_op = Shape.Pointwise_bin)
226236 session_state.backprop_roots <- Map.remove session_state.backprop_roots ti.id);
227237 (* The order is not relevant, we keep the same order as in backprop for readability. *)
228238 let diff = Some { grad = g; zero_grads; backprop } in
229- let tensor = { forward; diff; id; value = v; shape; children } in
239+ let tensor = { forward; diff; id; value = v; shape; children; non_embedded = ! non_embedded } in
230240 session_state.forward_roots <- Map.add_exn session_state.forward_roots ~key:id ~data:tensor;
231241 session_state.backprop_roots <- Map.add_exn session_state.backprop_roots ~key:id ~data:tensor;
232242 tensor
@@ -350,30 +360,33 @@ let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?
350360 Tn.update_memory_mode g Never_virtual 26;
351361 t
352362
353- let rec iter_embedded_arrays ~f t =
354- f t.value;
355- Option. iter t.diff ~f: (fun diff -> f diff.grad);
356- List. iter ~f: (fun ch -> if ch.embedded then iter_embedded_arrays ~f ch.subtensor) t.children
357-
358- let rec non_and_embedded_nodes t =
363+ let rec inputs_and_outputs t =
364+ (* TODO: consider either caching here, or as a field of t. *)
365+ let opt_grad t = Option. value_map ~default: [] ~f: (fun diff -> [ diff.grad ]) t.diff in
366+ let dir_outputs t =
367+ Set. of_list (module Tn )
368+ @@ List. filter ~f: (fun tn -> not @@ Set. mem t.non_embedded tn)
369+ @@ (t.value :: opt_grad t)
370+ in
371+ let open Arrayjit.Utils.Set_O in
359372 let non_embedded, embedded =
360373 List. fold t.children
361- ~init: (Set. empty (module Self ), Set. empty ( module Self ))
374+ ~init: (t.non_embedded, Set. of_list (module Tn ) (t.value :: opt_grad t ))
362375 ~f: (fun (non_embedded , embedded ) ch ->
363- if ch.embedded then (non_embedded, Set. add embedded ch.subtensor)
364- else (Set. add non_embedded ch.subtensor, embedded))
376+ (ch.subtensor.non_embedded + non_embedded, dir_outputs ch.subtensor + embedded))
365377 in
366- let open Arrayjit.Utils.Set_O in
367378 let non_embedded, embedded =
368379 List. fold t.children ~init: (non_embedded, embedded)
369380 ~f: (fun ((non_embedded , embedded ) as accu ) ch ->
370381 if ch.embedded then
371- let more_non, more = non_and_embedded_nodes ch.subtensor in
382+ let more_non, more = inputs_and_outputs ch.subtensor in
372383 (non_embedded + more_non, embedded + more)
373384 else accu)
374385 in
375386 (non_embedded - embedded, embedded)
376387
388+ let iter_outputs ~f t = Set. iter ~f @@ snd @@ inputs_and_outputs t
389+ let input_nodes t = fst @@ inputs_and_outputs t
377390let debug_name t = Tn. debug_name t.value
378391let debug_grad t = Tn. debug_name (Option. value_exn t.diff).grad
379392
0 commit comments