Skip to content

Commit e650a3e

Browse files
committed
Experiment: observe more concise output for moons_demo_parallel
1 parent 9978033 commit e650a3e

File tree

1 file changed

+24
-47
lines changed

1 file changed

+24
-47
lines changed

test/moons_demo_parallel.ml

Lines changed: 24 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,8 @@ let%expect_test "Half-moons data parallel" =
5353
epoch_loss
5454
in
5555
let module Backend = (val backend) in
56-
let inputs, outputs, _model_result, infer_callback, batch_losses, epoch_losses, learning_rates =
56+
let inputs, outputs, _model_result, infer_callback, _batch_losses, _epoch_losses, _learning_rates
57+
=
5758
Train.example_train_loop ~seed ~batch_size ~max_num_devices:(batch_size / 2) ~init_lr
5859
~data_len:len ~epochs ~inputs:moons_flat ~outputs:moons_classes ~model:mlp ~loss_fn
5960
~weight_decay ~per_batch_callback ~per_epoch_callback
@@ -66,60 +67,36 @@ let%expect_test "Half-moons data parallel" =
6667
let callback (x, y) = Float.((infer_callback [| x; y |]).(0) >= 0.) in
6768
let plot_moons =
6869
let open PrintBox_utils in
69-
plot ~size:(120, 40) ~x_label:"ixes" ~y_label:"ygreks"
70+
plot ~no_axes:true ~size:(120, 40)
7071
[
7172
Scatterplot { points = points1; pixel = "#" };
7273
Scatterplot { points = points2; pixel = "%" };
7374
Boundary_map { pixel_false = "."; pixel_true = "*"; callback };
7475
]
7576
in
76-
Stdio.printf "\nHalf-moons scatterplot and decision boundary:\n%!";
77+
Stdio.printf "\nHalf-moons scatterplot and decision boundary:\n";
7778
PrintBox_text.output Stdio.stdout plot_moons;
78-
Stdio.printf "\nBatch Loss:\n%!";
79-
let plot_loss =
80-
let open PrintBox_utils in
81-
plot ~size:(120, 30) ~x_label:"step" ~y_label:"batch loss"
82-
[ Line_plot { points = Array.of_list_rev batch_losses; pixel = "-" } ]
83-
in
84-
PrintBox_text.output Stdio.stdout plot_loss;
85-
Stdio.printf "\nEpoch Loss:\n%!";
86-
let plot_loss =
87-
let open PrintBox_utils in
88-
plot ~size:(120, 30) ~x_label:"step" ~y_label:"epoch loss"
89-
[ Line_plot { points = Array.of_list_rev epoch_losses; pixel = "-" } ]
79+
(* NOTE: as of OCANNL 0.4, moons_demo_parallel, while deterministic on a single machine, gives
80+
slightly different results on machines with a different hardware, e.g. arm64, ppc. Here we list
81+
the results from the various CI targets. The first result is the one typically observed, the
82+
second comes from targets debian-arm64 and debian-s390x, the third one from debian-ppc. *)
83+
let result = [%expect.output] in
84+
let typical_target =
85+
{| |}
9086
in
91-
PrintBox_text.output Stdio.stdout plot_loss;
92-
Stdio.printf "\nBatch Log-loss:\n%!";
93-
let plot_loss =
94-
let open PrintBox_utils in
95-
plot ~size:(120, 30) ~x_label:"step" ~y_label:"batch log loss"
96-
[
97-
Line_plot
98-
{
99-
points =
100-
Array.of_list_rev_map batch_losses ~f:Float.(fun x -> max (log 0.00003) (log x));
101-
pixel = "-";
102-
};
103-
]
87+
let arm64_and_s390x_target =
88+
{| |}
10489
in
105-
PrintBox_text.output Stdio.stdout plot_loss;
106-
Stdio.printf "\nEpoch Log-loss:\n%!";
107-
let plot_loss =
108-
let open PrintBox_utils in
109-
plot ~size:(120, 30) ~x_label:"step" ~y_label:"epoch log loss"
110-
[ Line_plot { points = Array.of_list_rev_map epoch_losses ~f:Float.log; pixel = "-" } ]
90+
let ppc64_target =
91+
{| |}
11192
in
112-
PrintBox_text.output Stdio.stdout plot_loss;
113-
Stdio.printf "\nLearning rate:\n%!";
114-
let plot_lr =
115-
let open PrintBox_utils in
116-
plot ~size:(120, 30) ~x_label:"step" ~y_label:"learning rate"
117-
[ Line_plot { points = Array.of_list_rev learning_rates; pixel = "-" } ]
93+
let result_as_expected =
94+
List.mem
95+
[ typical_target; arm64_and_s390x_target; ppc64_target ]
96+
result ~equal:String.equal
11897
in
119-
PrintBox_text.output Stdio.stdout plot_lr;
120-
(* NOTE: as of OCANNL 0.4, moons_demo_parallel, while deterministic on a single machine, gives
121-
slightly different results on machines with a different hardware, e.g. arm64, ppc. Here we list
122-
the results from the various CI targets. The first result is the one typically observed, the
123-
second comes from targets debian-arm64 and debian-s390x, the third one from debian-ppc. *)
124-
let result = [%expect.output] in
125-
Stdio.printf "\nR:%s\nEND\n%!" result
98+
if result_as_expected then Stdio.print_string "moons_demo_parallel result is as expected"
99+
else (
100+
Stdio.print_endline "Unexpected result:";
101+
Stdio.print_string result);
102+
[%expect "moons_demo_parallel result is as expected"]

0 commit comments

Comments
 (0)