@@ -228,12 +228,20 @@ type example_train_result = {
228228(* * [run_once] is a wrapper around {!init_params} that additionally runs code of [f t] and returns
229229 the context. If [skip_init] is true (false by default), no initialization is performmed. If
230230 [reinit_all] is true (false by default), all parameters are reinitialized, otherwise only the
231- parameters that are not in [ctx.ctx_arrays] are initialized. *)
232- let % track3_sexp run_once ?(hosted = true ) ?(skip_init = false ) ?reinit_all ?(bindings = IDX. empty)
233- ~f ctx t =
231+ parameters that are not in [ctx.ctx_arrays] are initialized. If [output_cd_file] is true, the
232+ update code is output to a file before shape inference potentially crashes at [init_params]. *)
233+ let % track3_sexp run_once ?(output_cd_file = false ) ?(hosted = true ) ?(skip_init = false ) ?reinit_all
234+ ?(bindings = IDX. empty) ~f ctx t =
234235 if hosted then set_hosted t.Tensor. value;
235236 (* Compute the update early, to ensure the shape inference is done. *)
236237 let update = f t in
238+ (if Utils. settings.output_debug_files_in_build_directory || output_cd_file then
239+ let name = Asgns. get_name_exn update.Asgns. asgns in
240+ let cd_source = Utils. output_to_build_file ~fname: (name ^ " -debug.cd" ) in
241+ let static_indices = Idx. bound_symbols bindings in
242+ match cd_source with
243+ | None -> ()
244+ | Some callback -> callback (Asgns. to_doc ~name ~static_indices () update.Asgns. asgns));
237245 let ctx =
238246 if skip_init || Set. is_empty t.params then ctx
239247 else init_params ?reinit_all ~hosted ctx bindings t
@@ -244,16 +252,18 @@ let%track3_sexp run_once ?(hosted = true) ?(skip_init = false) ?reinit_all ?(bin
244252(* * Context-based versions of training functions for the new simplified API *)
245253
246254(* * [forward_once] is a wrapper around {!run_once} that runs the forward code of [t]. *)
247- let forward_once ?(hosted = true ) ?(skip_init = false ) ?reinit_all ?(bindings = IDX. empty) ctx t =
248- let ctx = run_once ~hosted ~skip_init ?reinit_all ~bindings ~f: forward ctx t in
255+ let forward_once ?output_cd_file ?(hosted = true ) ?(skip_init = false ) ?reinit_all
256+ ?(bindings = IDX. empty) ctx t =
257+ let ctx = run_once ?output_cd_file ~hosted ~skip_init ?reinit_all ~bindings ~f: forward ctx t in
249258 (* FIXME: this is going away soon. *)
250259 Tensor. remove_bprop_root t;
251260 ctx
252261
253262(* * [update_once] is a wrapper around {!run_once} that runs the gradient update code of [t]: both
254263 forward and backprop. *)
255- let update_once ?(hosted = true ) ?(skip_init = false ) ?reinit_all ?(bindings = IDX. empty) ctx t =
256- run_once ~hosted ~skip_init ?reinit_all ~bindings ~f: grad_update ctx t
264+ let update_once ?output_cd_file ?(hosted = true ) ?(skip_init = false ) ?reinit_all
265+ ?(bindings = IDX. empty) ctx t =
266+ run_once ?output_cd_file ~hosted ~skip_init ?reinit_all ~bindings ~f: grad_update ctx t
257267
258268(* * [printf] is a wrapper around {!Tensor.print} that assumes [~force:true], and by default sets
259269 [~with_code:false], [~with_grad:true], and [~style:`Default]. *)
0 commit comments