@@ -54,13 +54,14 @@ let%expect_test "uint4x32_to_prec_uniform different precisions" =
5454 let key = Ocannl.Tensor. get_random_seed () in
5555 let counter = TDSL. range 5 in
5656 let random_bits = O. threefry4x32 key counter in
57+ let ctx = ref None in
5758
5859 (* Test different target precisions *)
5960 let test_precision prec prec_name =
6061 let uniform = O. uint4x32_to_prec_uniform random_bits in
6162 Ir.Tnode. update_prec uniform.value prec;
6263 Ocannl.Train. set_hosted uniform.value;
63- ignore (Ocannl.Train. forward_once (module Backend ) uniform);
64+ ctx : = Some (Ocannl.Train .forward_once (module Backend ) ?ctx: ! ctx uniform );
6465 let result = Ir.Tnode. get_values uniform.value in
6566 Stdio. printf " %s precision - first value: %f, second value: %f\n " prec_name result.(0 )
6667 result.(1 );
@@ -69,15 +70,14 @@ let%expect_test "uint4x32_to_prec_uniform different precisions" =
6970 in
7071
7172 test_precision Ir.Ops. single " Single" ;
72- test_precision Ir.Ops. double " Double" ;
73+ (* Metal backend doesn't support double precision. *)
74+ (* test_precision Ir.Ops.double "Double"; *)
7375 test_precision Ir.Ops. half " Half" ;
7476
7577 [% expect
7678 {|
7779 Single precision - first value : 0.756113 , second value : 0.758716
7880 All values in [0 , 1 ) range : true
79- Double precision - first value : 0.756113 , second value : 0.758716
80- All values in [0 , 1 ) range : true
8181 Half precision - first value : 0.756113 , second value : 0.758716
8282 All values in [0 , 1 ) range : true
8383 |}]
0 commit comments