Skip to content

Commit 6d0c60f

Browse files
committed
moons_demo_parallel even more lenient expectation; CLAUDE.md typo; experiment with arrayjit/bin build
The experiment will probably fail... About https://ocaml.ci.dev/github/ahrefs/ocannl/commit/f29d8d7b2361a0bdf7145e58275309f54caf0ac3/variant/%28lint-fmt%29 Signed-off-by: Lukasz Stafiniak <lukstafi@gmail.com>
1 parent f29d8d7 commit 6d0c60f

File tree

4 files changed

+27
-22
lines changed

4 files changed

+27
-22
lines changed

CLAUDE.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ opam install cudajit # for CUDA backend
102102
**Important Debug Settings**:
103103
- `output_debug_files_in_build_directory=true` - enables `build_files/` generation
104104
- `debug_log_from_routines=true` - enables runtime logging
105-
- `debug_log_to_stream_files=true` - writes logs to `log_files/<backend>-<stream>-<stream>.log`
105+
- `debug_log_to_stream_files=true` - writes logs to `log_files/<backend>-<deviceF>-<stream>.log`
106106
- `clean_up_artifacts_on_startup=false` - preserves debug files between runs
107107

108108
### Backend Development

arrayjit/bin/dune

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
(executable
2+
(package arrayjit)
3+
(public_name arrayjit_read_config)
24
(name read_config)
35
(modules read_config)
46
(libraries utils)

test/training/moons_demo_parallel.expected

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,17 @@
11
Retrieving commandline, environment, or config file variable ocannl_log_level
22
Found 0, in the config file
3-
Epoch=0, lr=0.199, loss=86.008
4-
Epoch=1, lr=0.198, loss=30.702
5-
Epoch=2, lr=0.198, loss=27.25
6-
Epoch=3, lr=0.197, loss=24.172
7-
Epoch=4, lr=0.196, loss=21.705
8-
Epoch=5, lr=0.195, loss=18.794
9-
Epoch=6, lr=0.194, loss=17.389
10-
Epoch=7, lr=0.193, loss=15.646
11-
Epoch=8, lr=0.193, loss=15.238
12-
Epoch=9, lr=0.192, loss=13.168
3+
Epoch=0, loss under target 87: true
4+
Epoch=1, loss under target 32: true
5+
Epoch=2, loss under target 29: true
6+
Epoch=3, loss under target 26: true
7+
Epoch=4, loss under target 23: true
8+
Epoch=5, loss under target 20: true
9+
Epoch=6, loss under target 19: true
10+
Epoch=7, loss under target 17: true
11+
Epoch=8, loss under target 16: true
12+
Epoch=9, loss under target 15: true
1313
..........
14-
Epoch loss: 0.000
14+
Final epoch loss under target 0.002: true
1515

1616
Epoch loss:
1717
┌────────┬─────────────────────────┐
@@ -23,8 +23,8 @@ Epoch loss:
2323
│ │ │
2424
│l │- │
2525
│o │- │
26-
│s │ --- -
27-
│s0.00 │ ----------------------│
26+
│s │ ----
27+
│s0.00 │ -----------------------│
2828
├────────┼─────────────────────────┤
2929
│ │0.00 1.19e+2│
3030
│ │ step │

test/training/moons_demo_parallel.ml

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -40,10 +40,12 @@ let main () =
4040
let module Backend = (val Backends.fresh_backend ()) in
4141
let per_batch_callback ~at_batch:_ ~at_step:_ ~learning_rate:_ ~batch_loss:_ ~epoch_loss:_ = () in
4242
(* Tn.print_accessible_headers (); *)
43-
let per_epoch_callback ~at_step:_ ~at_epoch ~learning_rate ~epoch_loss =
44-
if at_epoch = epochs - 5 then Stdio.printf "\n%!";
43+
let epoch_loss_target_limits = [| 87.; 32.; 29.; 26.; 23.; 20.; 19.; 17.; 16.; 15. |] in
44+
let per_epoch_callback ~at_step:_ ~at_epoch ~learning_rate:_ ~epoch_loss =
4545
if at_epoch < 10 then
46-
Stdio.printf "Epoch=%d, lr=%.3g, loss=%.5g\n%!" at_epoch learning_rate epoch_loss;
46+
Stdio.printf "Epoch=%d, loss under target %g: %b\n%!" at_epoch
47+
epoch_loss_target_limits.(at_epoch)
48+
Float.(epoch_loss_target_limits.(at_epoch) > epoch_loss);
4749
if at_epoch > 10 && at_epoch % 10 = 0 then Stdio.printf ".%!"
4850
in
4951
let {
@@ -63,7 +65,7 @@ let main () =
6365
()
6466
in
6567
let epoch_loss = List.hd_exn rev_epoch_losses in
66-
Stdio.printf "Epoch loss: %.3f\n%!" epoch_loss;
68+
Stdio.printf "\nFinal epoch loss under target 0.002: %b\n%!" Float.(0.002 > epoch_loss);
6769
(* if Float.(epoch_loss < 1.5) then Stdio.printf "Success\n" else *)
6870
let points = Tn.points_2d ~xdim:0 ~ydim:1 inputs.value in
6971
let classes = Tn.points_1d ~xdim:0 outputs.value in
@@ -93,11 +95,12 @@ let main () =
9395
]
9496
in
9597
(* PrintBox_text.output Stdio.stdout plot_loss; *)
96-
Stdio.printf "\nEpoch loss:\n%!";
97-
let plot_loss =
98-
PrintBox_utils.plot ~x_label:"step" ~y_label:"epoch loss" ~small:true
98+
(* Stdio.printf "\nEpoch loss:\n%!"; *)
99+
let _plot_loss =
100+
PrintBox_utils.plot ~x_label:"step" ~y_label:"epoch loss"
99101
[ Line_plot { points = Array.of_list_rev rev_epoch_losses; content = PrintBox.line "-" } ]
100102
in
101-
PrintBox_text.output Stdio.stdout plot_loss
103+
(* PrintBox_text.output Stdio.stdout plot_loss *)
104+
()
102105

103106
let () = main ()

0 commit comments

Comments
 (0)