@@ -13,20 +13,20 @@ let demo () =
1313 let seed = 3 in
1414 Rand. init seed;
1515 Utils. settings.fixed_state_for_init < - Some seed;
16- Utils. enable_runtime_debug () ;
17- Utils. set_log_level 0 ;
18- (* Utils.settings.debug_log_from_routines <- true; *)
16+ Utils. settings.output_debug_files_in_build_directory < - true ;
17+ (* Utils.enable_runtime_debug (); *)
1918 let hid_dim = 16 in
2019 let len = 512 in
2120 let batch_size = 32 in
22- let n_batches = 2 * len / batch_size in
2321 let epochs = 75 in
22+ (* Utils.settings.debug_log_from_routines <- true; *)
23+ (* TINY for debugging: *)
24+ (* let hid_dim = 2 in let len = 16 in let batch_size = 2 in let epochs = 2 in *)
25+ let n_batches = 2 * len / batch_size in
2426 let steps = epochs * n_batches in
2527 let weight_decay = 0.0002 in
26- Utils. settings.fixed_state_for_init < - Some 4 ;
2728
28- let % op mlp x =
29- " b3" + (" w3" * ?/ (" b2" hid_dim + (" w2" * ?/ (" b1" hid_dim + (" w1" * x))))) in
29+ let % op mlp x = " b3" + (" w3" * ?/ (" b2" hid_dim + (" w2" * ?/ (" b1" hid_dim + (" w1" * x))))) in
3030
3131 let noise () = Rand. float_range (- 0.1 ) 0.1 in
3232 let moons_flat =
@@ -47,7 +47,7 @@ let demo () =
4747 let step_n, bindings = IDX. get_static_symbol bindings in
4848 let % op moons_input = moons_flat @| batch_n in
4949 let % op moons_class = moons_classes @| batch_n in
50-
50+
5151 let % op margin_loss = ?/ (1 - (moons_class *. mlp moons_input)) in
5252 let % op scalar_loss = (margin_loss ++ " ...|... => 0" ) /. ! ..batch_size in
5353
@@ -86,6 +86,7 @@ let demo () =
8686 for epoch = 0 to epochs - 1 do
8787 for batch = 0 to n_batches - 1 do
8888 batch_ref := batch;
89+ Utils. capture_stdout_logs @@ fun () ->
8990 Train. run routine;
9091 assert (Backend. to_host routine.context learning_rate.value);
9192 assert (Backend. to_host routine.context scalar_loss.value);
@@ -106,6 +107,7 @@ let demo () =
106107 in
107108 let callback (x , y ) =
108109 Tensor. set_values point [| x; y |];
110+ Utils. capture_stdout_logs @@ fun () ->
109111 assert (Backend. from_host result_routine.context point.value);
110112 Train. run result_routine;
111113 assert (Backend. to_host result_routine.context mlp_result.value);
0 commit comments