@@ -53,7 +53,8 @@ let%expect_test "Half-moons data parallel" =
5353 epoch_loss
5454 in
5555 let module Backend = (val backend) in
56- let inputs, outputs, _model_result, infer_callback, batch_losses, epoch_losses, learning_rates =
56+ let inputs, outputs, _model_result, infer_callback, _batch_losses, _epoch_losses, _learning_rates
57+ =
5758 Train. example_train_loop ~seed ~batch_size ~max_num_devices: (batch_size / 2 ) ~init_lr
5859 ~data_len: len ~epochs ~inputs: moons_flat ~outputs: moons_classes ~model: mlp ~loss_fn
5960 ~weight_decay ~per_batch_callback ~per_epoch_callback
@@ -66,60 +67,36 @@ let%expect_test "Half-moons data parallel" =
6667 let callback (x , y ) = Float. ((infer_callback [| x; y |]).(0 ) > = 0. ) in
6768 let plot_moons =
6869 let open PrintBox_utils in
69- plot ~size: (120 , 40 ) ~x_label: " ixes " ~y_label: " ygreks "
70+ plot ~no_axes: true ~ size: (120 , 40 )
7071 [
7172 Scatterplot { points = points1; pixel = " #" };
7273 Scatterplot { points = points2; pixel = " %" };
7374 Boundary_map { pixel_false = " ." ; pixel_true = " *" ; callback };
7475 ]
7576 in
76- Stdio. printf " \n Half-moons scatterplot and decision boundary:\n %! " ;
77+ Stdio. printf " \n Half-moons scatterplot and decision boundary:\n " ;
7778 PrintBox_text. output Stdio. stdout plot_moons;
78- Stdio. printf " \n Batch Loss:\n %!" ;
79- let plot_loss =
80- let open PrintBox_utils in
81- plot ~size: (120 , 30 ) ~x_label: " step" ~y_label: " batch loss"
82- [ Line_plot { points = Array. of_list_rev batch_losses; pixel = " -" } ]
83- in
84- PrintBox_text. output Stdio. stdout plot_loss;
85- Stdio. printf " \n Epoch Loss:\n %!" ;
86- let plot_loss =
87- let open PrintBox_utils in
88- plot ~size: (120 , 30 ) ~x_label: " step" ~y_label: " epoch loss"
89- [ Line_plot { points = Array. of_list_rev epoch_losses; pixel = " -" } ]
79+ (* NOTE: as of OCANNL 0.4, moons_demo_parallel, while deterministic on a single machine, gives
80+ slightly different results on machines with a different hardware, e.g. arm64, ppc. Here we list
81+ the results from the various CI targets. The first result is the one typically observed, the
82+ second comes from targets debian-arm64 and debian-s390x, the third one from debian-ppc. *)
83+ let result = [% expect.output] in
84+ let typical_target =
85+ {| | }
9086 in
91- PrintBox_text. output Stdio. stdout plot_loss;
92- Stdio. printf " \n Batch Log-loss:\n %!" ;
93- let plot_loss =
94- let open PrintBox_utils in
95- plot ~size: (120 , 30 ) ~x_label: " step" ~y_label: " batch log loss"
96- [
97- Line_plot
98- {
99- points =
100- Array. of_list_rev_map batch_losses ~f: Float. (fun x -> max (log 0.00003 ) (log x));
101- pixel = " -" ;
102- };
103- ]
87+ let arm64_and_s390x_target =
88+ {| | }
10489 in
105- PrintBox_text. output Stdio. stdout plot_loss;
106- Stdio. printf " \n Epoch Log-loss:\n %!" ;
107- let plot_loss =
108- let open PrintBox_utils in
109- plot ~size: (120 , 30 ) ~x_label: " step" ~y_label: " epoch log loss"
110- [ Line_plot { points = Array. of_list_rev_map epoch_losses ~f: Float. log; pixel = " -" } ]
90+ let ppc64_target =
91+ {| | }
11192 in
112- PrintBox_text. output Stdio. stdout plot_loss;
113- Stdio. printf " \n Learning rate:\n %!" ;
114- let plot_lr =
115- let open PrintBox_utils in
116- plot ~size: (120 , 30 ) ~x_label: " step" ~y_label: " learning rate"
117- [ Line_plot { points = Array. of_list_rev learning_rates; pixel = " -" } ]
93+ let result_as_expected =
94+ List. mem
95+ [ typical_target; arm64_and_s390x_target; ppc64_target ]
96+ result ~equal: String. equal
11897 in
119- PrintBox_text. output Stdio. stdout plot_lr;
120- (* NOTE: as of OCANNL 0.4, moons_demo_parallel, while deterministic on a single machine, gives
121- slightly different results on machines with a different hardware, e.g. arm64, ppc. Here we list
122- the results from the various CI targets. The first result is the one typically observed, the
123- second comes from targets debian-arm64 and debian-s390x, the third one from debian-ppc. *)
124- let result = [% expect.output] in
125- Stdio. printf " \n R:%s\n END\n %!" result
98+ if result_as_expected then Stdio. print_string " moons_demo_parallel result is as expected"
99+ else (
100+ Stdio. print_endline " Unexpected result:" ;
101+ Stdio. print_string result);
102+ [% expect " moons_demo_parallel result is as expected" ]
0 commit comments