Skip to content

Commit 42e108d

Browse files
committed
Additional opportunity to output the .cd file: from run_once before init_params
1 parent 5ba9a5c commit 42e108d

File tree

2 files changed

+18
-8
lines changed

2 files changed

+18
-8
lines changed

lib/train.ml

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -228,12 +228,20 @@ type example_train_result = {
228228
(** [run_once] is a wrapper around {!init_params} that additionally runs code of [f t] and returns
229229
the context. If [skip_init] is true (false by default), no initialization is performmed. If
230230
[reinit_all] is true (false by default), all parameters are reinitialized, otherwise only the
231-
parameters that are not in [ctx.ctx_arrays] are initialized. *)
232-
let%track3_sexp run_once ?(hosted = true) ?(skip_init = false) ?reinit_all ?(bindings = IDX.empty)
233-
~f ctx t =
231+
parameters that are not in [ctx.ctx_arrays] are initialized. If [output_cd_file] is true, the
232+
update code is output to a file before shape inference potentially crashes at [init_params]. *)
233+
let%track3_sexp run_once ?(output_cd_file = false) ?(hosted = true) ?(skip_init = false) ?reinit_all
234+
?(bindings = IDX.empty) ~f ctx t =
234235
if hosted then set_hosted t.Tensor.value;
235236
(* Compute the update early, to ensure the shape inference is done. *)
236237
let update = f t in
238+
(if Utils.settings.output_debug_files_in_build_directory || output_cd_file then
239+
let name = Asgns.get_name_exn update.Asgns.asgns in
240+
let cd_source = Utils.output_to_build_file ~fname:(name ^ "-debug.cd") in
241+
let static_indices = Idx.bound_symbols bindings in
242+
match cd_source with
243+
| None -> ()
244+
| Some callback -> callback (Asgns.to_doc ~name ~static_indices () update.Asgns.asgns));
237245
let ctx =
238246
if skip_init || Set.is_empty t.params then ctx
239247
else init_params ?reinit_all ~hosted ctx bindings t
@@ -244,16 +252,18 @@ let%track3_sexp run_once ?(hosted = true) ?(skip_init = false) ?reinit_all ?(bin
244252
(** Context-based versions of training functions for the new simplified API *)
245253

246254
(** [forward_once] is a wrapper around {!run_once} that runs the forward code of [t]. *)
247-
let forward_once ?(hosted = true) ?(skip_init = false) ?reinit_all ?(bindings = IDX.empty) ctx t =
248-
let ctx = run_once ~hosted ~skip_init ?reinit_all ~bindings ~f:forward ctx t in
255+
let forward_once ?output_cd_file ?(hosted = true) ?(skip_init = false) ?reinit_all
256+
?(bindings = IDX.empty) ctx t =
257+
let ctx = run_once ?output_cd_file ~hosted ~skip_init ?reinit_all ~bindings ~f:forward ctx t in
249258
(* FIXME: this is going away soon. *)
250259
Tensor.remove_bprop_root t;
251260
ctx
252261

253262
(** [update_once] is a wrapper around {!run_once} that runs the gradient update code of [t]: both
254263
forward and backprop. *)
255-
let update_once ?(hosted = true) ?(skip_init = false) ?reinit_all ?(bindings = IDX.empty) ctx t =
256-
run_once ~hosted ~skip_init ?reinit_all ~bindings ~f:grad_update ctx t
264+
let update_once ?output_cd_file ?(hosted = true) ?(skip_init = false) ?reinit_all
265+
?(bindings = IDX.empty) ctx t =
266+
run_once ?output_cd_file ~hosted ~skip_init ?reinit_all ~bindings ~f:grad_update ctx t
257267

258268
(** [printf] is a wrapper around {!Tensor.print} that assumes [~force:true], and by default sets
259269
[~with_code:false], [~with_grad:true], and [~style:`Default]. *)

test/operations/transformer_test.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ let () =
5050
(* Forward pass *)
5151
let output = transformer_model ~train_step:None ~src ~tgt ~mask in
5252

53-
let _ctx = Ocannl.Train.forward_once ctx output in
53+
let _ctx = Ocannl.Train.forward_once ~output_cd_file:true ctx output in
5454

5555
(* Verify output shape *)
5656
Stdio.printf "Output shape:\n%s\n%!"

0 commit comments

Comments
 (0)