Skip to content

Commit 6f12ca8

Browse files
committed
Illustrate ppx_minidebug logging from the cuda backend
1 parent 53fec9b commit 6f12ca8

File tree

1 file changed

+10
-8
lines changed

1 file changed

+10
-8
lines changed

bin/moons_demo.ml

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,20 +13,20 @@ let demo () =
1313
let seed = 3 in
1414
Rand.init seed;
1515
Utils.settings.fixed_state_for_init <- Some seed;
16-
Utils.enable_runtime_debug ();
17-
Utils.set_log_level 0;
18-
(* Utils.settings.debug_log_from_routines <- true; *)
16+
Utils.settings.output_debug_files_in_build_directory <- true;
17+
(* Utils.enable_runtime_debug (); *)
1918
let hid_dim = 16 in
2019
let len = 512 in
2120
let batch_size = 32 in
22-
let n_batches = 2 * len / batch_size in
2321
let epochs = 75 in
22+
(* Utils.settings.debug_log_from_routines <- true; *)
23+
(* TINY for debugging: *)
24+
(* let hid_dim = 2 in let len = 16 in let batch_size = 2 in let epochs = 2 in *)
25+
let n_batches = 2 * len / batch_size in
2426
let steps = epochs * n_batches in
2527
let weight_decay = 0.0002 in
26-
Utils.settings.fixed_state_for_init <- Some 4;
2728

28-
let%op mlp x =
29-
"b3" + ("w3" * ?/("b2" hid_dim + ("w2" * ?/("b1" hid_dim + ("w1" * x))))) in
29+
let%op mlp x = "b3" + ("w3" * ?/("b2" hid_dim + ("w2" * ?/("b1" hid_dim + ("w1" * x))))) in
3030

3131
let noise () = Rand.float_range (-0.1) 0.1 in
3232
let moons_flat =
@@ -47,7 +47,7 @@ let demo () =
4747
let step_n, bindings = IDX.get_static_symbol bindings in
4848
let%op moons_input = moons_flat @| batch_n in
4949
let%op moons_class = moons_classes @| batch_n in
50-
50+
5151
let%op margin_loss = ?/(1 - (moons_class *. mlp moons_input)) in
5252
let%op scalar_loss = (margin_loss ++ "...|... => 0") /. !..batch_size in
5353

@@ -86,6 +86,7 @@ let demo () =
8686
for epoch = 0 to epochs - 1 do
8787
for batch = 0 to n_batches - 1 do
8888
batch_ref := batch;
89+
Utils.capture_stdout_logs @@ fun () ->
8990
Train.run routine;
9091
assert (Backend.to_host routine.context learning_rate.value);
9192
assert (Backend.to_host routine.context scalar_loss.value);
@@ -106,6 +107,7 @@ let demo () =
106107
in
107108
let callback (x, y) =
108109
Tensor.set_values point [| x; y |];
110+
Utils.capture_stdout_logs @@ fun () ->
109111
assert (Backend.from_host result_routine.context point.value);
110112
Train.run result_routine;
111113
assert (Backend.to_host result_routine.context mlp_result.value);

0 commit comments

Comments
 (0)