@@ -116,10 +116,24 @@ let iter_embedded ~f t =
   Set.iter ~f t.forward.embedded_nodes;
   Option.iter t.diff ~f:(fun diff -> Set.iter ~f diff.backprop.embedded_nodes)
 
-let init_params _t =
-  (* Based on the interface documentation, this should collect forward code of t.params *)
-  (* For now, return empty since the 'params' field is missing from the current implementation *)
-  Asgns.empty_comp
+let rec init_params t =
+  let open Asgns in
+  let rem_embedded = ref @@ Set.empty (module Tn) in
+  let asgns =
+    Block_comment
+      ( "init params for " ^ Tn.debug_name t.value,
+        sequential
+        @@ Set.fold t.params ~init:[] ~f:(fun acc param ->
+               if Set.is_empty param.params then param.forward.asgns :: acc
+               else
+                 let asgns = init_params param in
+                 rem_embedded := Set.union !rem_embedded asgns.embedded_nodes;
+                 Seq (asgns.asgns, param.forward.asgns) :: acc) )
+  in
+  let embedded_nodes =
+    Set.fold ~init:!rem_embedded t.params ~f:(fun acc p -> Set.add acc p.value)
+  in
+  { asgns; embedded_nodes }
 
 let initial_default_prec =
   Ir.Ops.prec_of_string (Utils.get_global_arg ~default:"single" ~arg_name:"default_prec")
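
Note on the recursion in init_params above: a leaf created by param (last hunk below)
stores its pre-update self, whose params set is empty, inside a singleton params set,
so the Set.is_empty param.params branch terminates the recursion. A minimal usage
sketch, assuming param and init_params as defined in this diff; the name w and the
dimensions are illustrative, not taken from the commit:

    (* Hypothetical sketch, not part of the commit. *)
    let w = param ~output_dims:[ 4 ] "w" in
    let init = init_params w in
    (* init.asgns is Block_comment ("init params for <debug name of w>", ...)
       wrapping w's forward assignments; init.embedded_nodes contains w.value. *)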
@@ -299,7 +313,8 @@ let op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
       session_state.backprop_roots <- Map.remove session_state.backprop_roots ti.id);
   (* The order is not relevant, we keep the same order as in backprop for readability. *)
   let diff = Some { grad = g; zero_grads; backprop } in
-  let tensor = { params = Set.empty (module T); forward; diff; id; value = v; shape; children } in
+  let params = Set.union_list (module T) @@ List.map ordered_ts ~f:(fun ti -> ti.params) in
+  let tensor = { params; forward; diff; id; value = v; shape; children } in
   session_state.forward_roots <- Map.add_exn session_state.forward_roots ~key:id ~data:tensor;
   session_state.backprop_roots <- Map.add_exn session_state.backprop_roots ~key:id ~data:tensor;
   tensor
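
With the change above, params propagates bottom-up through op: a composite tensor's
params is the union of its operands' params (ordered_ts), and set semantics deduplicate
parameters shared between operands. A self-contained sketch of the accumulation
pattern, using Base's string sets as a stand-in for the tensor set (names are
illustrative, not from the commit):

    open Base

    (* Union the per-operand parameter sets, as op does via Set.union_list. *)
    let union_params sets = Set.union_list (module String) sets

    let () =
      let ps =
        union_params
          [ Set.of_list (module String) [ "w" ]; Set.of_list (module String) [ "w"; "b" ] ]
      in
      (* "w" appears in both operand sets but only once in the union. *)
      assert (Set.length ps = 2)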
@@ -409,10 +424,10 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?
       Tn.update_prec ~only_if:is_up_to_fp16 t.value single);
   t
 
-let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ?value ?values
-    label =
+let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ?value
+    ?values label =
   let fetch_op_fn ~v:_ =
-    match values, value with
+    match (values, value) with
     | Some values, None -> Asgns.Constant_fill values
     | None, Some value -> Asgns.Constant value
     | None, None -> Asgns.Range_over_offsets
@@ -429,7 +444,8 @@ let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?
      update computations. *)
   let g = (Option.value_exn ~here:[%here] t.diff).grad in
   Tn.update_memory_mode g Never_virtual 26;
-  t
+  remove_fwd_root t;
+  { t with params = Set.singleton (module T) t }
 
 let debug_name t = Tn.debug_name t.value
 let debug_grad t = Tn.debug_name (Option.value_exn t.diff).grad
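
Two details in the param change above are worth spelling out. First, remove_fwd_root t
drops the parameter's forward code from the session's forward roots, since it is now
emitted via init_params rather than as part of an ordinary forward pass. Second,
{ t with params = Set.singleton (module T) t } stores the pre-update record t, whose
own params set is empty; that is what makes the Set.is_empty base case of init_params
fire instead of recursing forever. A hedged usage sketch, assuming the signature in
this diff and that ?values takes a float array (as Constant_fill suggests); the name b
and the numbers are illustrative:

    (* Hypothetical sketch, not part of the commit. *)
    let b = param ~output_dims:[ 2 ] ~values:[| 0.; 1. |] "b" in
    (* b.params is the singleton containing b's pre-update self, and b's
       forward code is no longer registered as a session forward root. *)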