Skip to content

Commit fb4b363

Browse files
committed
First pass on a slew of bugs uncovered by the Metal backend
Very strange to be getting session-level bugs in one of the backends but not the other.
1 parent 5919417 commit fb4b363

File tree

10 files changed

+105
-43
lines changed

10 files changed

+105
-43
lines changed

arrayjit/lib/c_syntax.ml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -274,6 +274,12 @@ struct
274274

275275
let binop_syntax prec op v1 v2 =
276276
match op with
277+
| Ops.Threefry4x32 -> (
278+
match prec with
279+
| Ops.Uint4x32_prec _ ->
280+
let open PPrint in
281+
group (string "arrayjit_threefry4x32(" ^^ v1 ^^ string ", " ^^ v2 ^^ string ")")
282+
| _ -> invalid_arg "Pure_C_config.binop_syntax: Threefry4x32 on non-uint4x32 precision")
277283
| Ops.Satur01_gate -> (
278284
match prec with
279285
| Ops.Byte_prec _ | Ops.Uint16_prec _ | Ops.Int32_prec _ | Ops.Uint4x32_prec _ ->

arrayjit/lib/cuda_backend.ml

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -588,7 +588,13 @@ end) : Ir.Backend_impl.Lowered_backend = struct
588588
| Cmpeq, _ -> f "=="
589589
| Or, _ -> f "||"
590590
| And, _ -> f "&&"
591-
| Threefry4x32, _ -> func "arrayjit_threefry4x32"
591+
| Threefry4x32, _ ->
592+
(* Threefry4x32 must output to uint4x32 precision *)
593+
(match prec with
594+
| Ops.Uint4x32_prec _ -> func "arrayjit_threefry4x32"
595+
| _ -> raise @@ Utils.User_error
596+
(Printf.sprintf "CUDA backend: Threefry4x32 requires target precision to be uint4x32, but got %s"
597+
(Ops.prec_string prec)))
592598

593599
let unop_syntax prec v =
594600
let open PPrint in

arrayjit/lib/low_level.ml

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -531,11 +531,17 @@ let inline_computation ~id computations_table traced static_indices call_args =
531531
@@ List.map ~f:(fun s -> s.Indexing.static_symbol) static_indices
532532
in
533533
let make_subst i lhs_ind =
534-
let rhs_ind = call_args.(i) in
535-
match lhs_ind with
536-
| Indexing.Iterator lhs_s when not (Set.mem static_indices lhs_s) -> Some (lhs_s, rhs_ind)
537-
| _ when Indexing.equal_axis_index lhs_ind rhs_ind -> None
538-
| _ -> raise @@ Non_virtual 13
534+
if i >= Array.length call_args then
535+
failwith
536+
[%string
537+
"make_subst: call_args too short, maybe stale optimization context? Tnode: \
538+
%{Tn.debug_name traced.tn} #%{traced.tn.Tn.id#Int} i: %{i#Int}"]
539+
else
540+
let rhs_ind = call_args.(i) in
541+
match lhs_ind with
542+
| Indexing.Iterator lhs_s when not (Set.mem static_indices lhs_s) -> Some (lhs_s, rhs_ind)
543+
| _ when Indexing.equal_axis_index lhs_ind rhs_ind -> None
544+
| _ -> raise @@ Non_virtual 13
539545
in
540546
(* In the order of computation. *)
541547
let loop_proc (def_args, def) : t option =

arrayjit/lib/metal_backend.ml

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -448,13 +448,13 @@ end) : Ir.Backend_impl.Lowered_backend = struct
448448
| Ops.Bfloat16_prec _ -> "bfloat" (* Metal supports bfloat16 natively *)
449449
| Ops.Fp8_prec _ -> invalid_arg "Metal backend does not support FP8 precision"
450450
| Ops.Single_prec _ -> "float"
451-
| Ops.Double_prec _ -> "double"
451+
| Ops.Double_prec _ -> raise @@ Utils.User_error "Metal backend does not support double precision"
452452
| Ops.Void_prec -> "void"
453453

454454
let vec_typ_of_prec ~length prec =
455455
match (prec, length) with
456456
| Ops.Single_prec _, 4 -> "float4_t"
457-
| Ops.Double_prec _, 2 -> "float2_t" (* Metal uses float2 since it lacks double *)
457+
| Ops.Double_prec _, 2 -> raise @@ Utils.User_error "Metal backend does not support double precision"
458458
| Ops.Int32_prec _, 4 -> "int32x4_t"
459459
| (Ops.Byte_prec _ | Ops.Fp8_prec _), 16 -> "int8x16_t"
460460
| (Ops.Uint16_prec _ | Ops.Bfloat16_prec _), 8 -> "uint16x8_t"
@@ -472,7 +472,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
472472
| Ops.Bfloat16_prec _ -> "bf" (* TODO: Verify actual Metal suffix for bfloat16 *)
473473
| Ops.Fp8_prec _ -> invalid_arg "Metal backend does not support FP8 precision"
474474
| Ops.Single_prec _ -> "f"
475-
| Ops.Double_prec _ -> ""
475+
| Ops.Double_prec _ -> raise @@ Utils.User_error "Metal backend does not support double precision"
476476
| Ops.Void_prec -> ""
477477

478478
let ternop_syntax _prec op =
@@ -514,12 +514,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
514514
^^ space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
515515
^^ string "0.0f"))
516516
| Relu_gate, Ops.Double_prec _ ->
517-
fun v1 v2 ->
518-
group
519-
(parens
520-
(group (parens (v1 ^^ string " > 0.0"))
521-
^^ space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
522-
^^ string "0.0"))
517+
raise @@ Utils.User_error "Metal backend does not support double precision"
523518
| Relu_gate, _ (* Byte_prec, Void_prec *) ->
524519
fun v1 v2 ->
525520
group
@@ -537,7 +532,13 @@ end) : Ir.Backend_impl.Lowered_backend = struct
537532
^^ space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
538533
^^ string ("0.0" ^ s)))
539534
| ToPowOf, _ -> func "pow"
540-
| Threefry4x32, _ -> func "arrayjit_threefry4x32"
535+
| Threefry4x32, _ ->
536+
(* Threefry4x32 must output to uint4x32 precision *)
537+
(match prec with
538+
| Ops.Uint4x32_prec _ -> func "arrayjit_threefry4x32"
539+
| _ -> raise @@ Utils.User_error
540+
(Printf.sprintf "Metal backend: Threefry4x32 requires target precision to be uint4x32, but got %s"
541+
(Ops.prec_string prec)))
541542
| Arg1, _ | Arg2, _ -> invalid_arg "Metal C_syntax_config: Arg1/Arg2 not operators"
542543

543544
let unop_syntax prec op =
@@ -554,7 +555,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
554555
| Sqrt, _ -> func_doc "sqrt"
555556
| Relu, Ops.Half_prec _ -> fun v -> func_doc "max" (separate comma_sep [ string "0.0h"; v ])
556557
| Relu, Ops.Single_prec _ -> fun v -> func_doc "max" (separate comma_sep [ string "0.0f"; v ])
557-
| Relu, Ops.Double_prec _ -> fun v -> func_doc "max" (separate comma_sep [ string "0.0"; v ])
558+
| Relu, Ops.Double_prec _ -> raise @@ Utils.User_error "Metal backend does not support double precision"
558559
| Relu, _ (* Byte_prec, Void_prec *) ->
559560
fun v -> func_doc "max" (separate comma_sep [ string "0"; v ])
560561
| Satur01, p ->

arrayjit/lib/ops.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ let binop_c_syntax prec v =
485485
| Or, _ -> ("(", " ||", ")")
486486
| And, _ -> ("(", " &&", ")")
487487
| Threefry4x32, _ ->
488-
(* This corresponds to the pure C implementation in arrayjit_stubs.c. *)
488+
(* This corresponds to the pure C implementation in builtins.c. *)
489489
("arrayjit_threefry4x32(", ",", ")")
490490

491491
let is_assign_op = function

test/einsum/moons_demo_variant.expected

Lines changed: 19 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
Retrieving commandline, environment, or config file variable ocannl_log_level
22
Found 0, in the config file
33
Tnode: collecting accessible arrays...
4-
n0 moons_flat as moons_flat: Host-const/37; double prec 40x10x2; mem in bytes: 6_400
5-
n1 moons_classes as moons_classes: Host-const/37; double prec 40x10x1; mem in bytes: 3_200
4+
n0 moons_flat as moons_flat: Host-const/37; single prec 40x10x2; mem in bytes: 3_200
5+
n1 moons_classes as moons_classes: Host-const/37; single prec 40x10x1; mem in bytes: 1_600
66
n2 range_over_offsets as range_over_offsets: Virt/15; single prec 4; mem in bytes: <not-in-yet>
77
n3 !@self_id as n3: Virt/40; single prec 1; mem in bytes: <not-in-yet>
88
n4 threefry4x32 as threefry4x32: Virt/15; single prec 4; mem in bytes: <not-in-yet>
@@ -28,42 +28,42 @@ n23 !@self_id as n23: Virt/40; single prec 1; mem in bytes: <not-in-yet>
2828
n24 threefry4x32 as threefry4x32: Virt/15; single prec 4; mem in bytes: <not-in-yet>
2929
n25 w3 as w3: Host&stream/412410; single prec 1x16; mem in bytes: <not-in-yet>
3030
n26 grad_w3 as w3.grad: Local/26046; single prec 1x16; mem in bytes: <not-in-yet>
31-
n27 @|_moons_input as moons_input: Local/1046; double prec 10x2; mem in bytes: <not-in-yet>
32-
n30 @|_moons_class as moons_class: Local/1046; double prec 10x1; mem in bytes: <not-in-yet>
33-
n33 * as n33: Local/1046; double prec 10x16; mem in bytes: <not-in-yet>
31+
n27 @|_moons_input as moons_input: Local/1046; single prec 10x2; mem in bytes: <not-in-yet>
32+
n30 @|_moons_class as moons_class: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
33+
n33 * as n33: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
3434
n34 grad_* as n33.grad: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
35-
n35 + as n35: Local/1046; double prec 10x16; mem in bytes: <not-in-yet>
35+
n35 + as n35: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
3636
n36 grad_+ as n35.grad: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
37-
n37 relu as relu: Local/1046; double prec 10x16; mem in bytes: <not-in-yet>
37+
n37 relu as relu: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
3838
n38 grad_relu as relu.grad: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
39-
n39 * as n39: Local/1046; double prec 10x16; mem in bytes: <not-in-yet>
39+
n39 * as n39: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
4040
n40 grad_* as n39.grad: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
41-
n41 + as n41: Local/1046; double prec 10x16; mem in bytes: <not-in-yet>
41+
n41 + as n41: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
4242
n42 grad_+ as n41.grad: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
43-
n43 relu as relu: Local/1046; double prec 10x16; mem in bytes: <not-in-yet>
43+
n43 relu as relu: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
4444
n44 grad_relu as relu.grad: Local/1046; single prec 10x16; mem in bytes: <not-in-yet>
45-
n45 * as n45: Local/1046; double prec 10x1; mem in bytes: <not-in-yet>
45+
n45 * as n45: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
4646
n46 grad_* as n45.grad: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
4747
n47 0.5 as n47: Virt/40; single prec 1; mem in bytes: <not-in-yet>
48-
n48 +_mlp_@|_moons_input as mlp_moons_input: Virt/15; double prec 10x1; mem in bytes: <not-in-yet>
48+
n48 +_mlp_@|_moons_input as mlp_moons_input: Virt/15; single prec 10x1; mem in bytes: <not-in-yet>
4949
n49 grad_+_mlp_@|_moons_input as mlp_moons_input.grad: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
50-
n50 *. as n50: Virt/15; double prec 10x1; mem in bytes: <not-in-yet>
50+
n50 *. as n50: Virt/15; single prec 10x1; mem in bytes: <not-in-yet>
5151
n51 grad_*. as n50.grad: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
5252
n52 1 as 1: Virt/40; single prec 1; mem in bytes: <not-in-yet>
53-
n53 - as n53: Local/1046; double prec 10x1; mem in bytes: <not-in-yet>
53+
n53 - as n53: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
5454
n54 grad_- as n53.grad: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
55-
n55 relu_margin_loss as relu_margin_loss: Virt/15; double prec 10x1; mem in bytes: <not-in-yet>
55+
n55 relu_margin_loss as relu_margin_loss: Virt/15; single prec 10x1; mem in bytes: <not-in-yet>
5656
n56 grad_relu_margin_loss as relu_margin_loss.grad: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
5757
n57 10 as 10: Virt/40; single prec 1; mem in bytes: <not-in-yet>
58-
n58 => as n58: Local/1046; double prec 1; mem in bytes: <not-in-yet>
58+
n58 => as n58: Local/1046; single prec 1; mem in bytes: <not-in-yet>
5959
n59 grad_=> as n58.grad: Virt/40; single prec 1; mem in bytes: <not-in-yet>
60-
n60 /._scalar_loss as scalar_loss: Host&stream/412410; double prec 1; mem in bytes: 8
60+
n60 /._scalar_loss as scalar_loss: Host&stream/412410; single prec 1; mem in bytes: 4
6161
n61 grad_/._scalar_loss as scalar_loss.grad: Virt/40; single prec 1; mem in bytes: <not-in-yet>
6262
n62 2 as 2: Virt/40; single prec 1; mem in bytes: <not-in-yet>
6363
n63 **. as n63: Virt/40; single prec 1; mem in bytes: <not-in-yet>
6464
n64 -1 as n64: Virt/40; single prec 1; mem in bytes: <not-in-yet>
65-
n65 *. as n65: Virt/152; double prec 1; mem in bytes: <not-in-yet>
66-
n66 /. as n66: Host&stream/412410; double prec 1; mem in bytes: <not-in-yet>
65+
n65 *. as n65: Virt/152; single prec 1; mem in bytes: <not-in-yet>
66+
n66 /. as n66: Host&stream/412410; single prec 1; mem in bytes: <not-in-yet>
6767
n67 1 as 1: Virt/40; single prec 1; mem in bytes: <not-in-yet>
6868
n68 80 as 80: Virt/40; single prec 1; mem in bytes: <not-in-yet>
6969
n69 !@ as n69: Virt/152; single prec 1; mem in bytes: <not-in-yet>

test/einsum/moons_demo_variant.ml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,11 @@ let () =
2020
let epochs = 2 in
2121
let steps = epochs * 2 * len / batch_size in
2222
let moons_config = Datasets.Half_moons.Config.{ noise_range = 0.1; seed = Some 5 } in
23-
let moons_coordinates, moons_labels = Datasets.Half_moons.generate ~config:moons_config ~len () in
24-
let moons_flat_ndarray = Ir.Ndarray.as_array Ir.Ops.Double moons_coordinates in
25-
let moons_classes_ndarray = Ir.Ndarray.as_array Ir.Ops.Double moons_labels in
23+
let moons_coordinates, moons_labels =
24+
Datasets.Half_moons.generate_single_prec ~config:moons_config ~len ()
25+
in
26+
let moons_flat_ndarray = Ir.Ndarray.as_array Ir.Ops.Single moons_coordinates in
27+
let moons_classes_ndarray = Ir.Ndarray.as_array Ir.Ops.Single moons_labels in
2628
let batch_n, bindings = IDX.get_static_symbol ~static_range:n_batches IDX.empty in
2729
let step_n, bindings = IDX.get_static_symbol bindings in
2830
let moons_flat = TDSL.rebatch ~l:"moons_flat" moons_flat_ndarray () in

test/operations/dune

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,13 @@
8080
(preprocess
8181
(pps ppx_here ppx_ocannl)))
8282

83+
(test
84+
(name test_threefry_precision)
85+
(modules test_threefry_precision)
86+
(libraries base ocannl)
87+
(preprocess
88+
(pps ppx_here ppx_ocannl)))
89+
8390
(library
8491
(name operations_tutorials)
8592
(package neural_nets_lib)
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
open Base
2+
open Ocannl
3+
4+
let () =
5+
Utils.settings.output_debug_files_in_build_directory <- true;
6+
Utils.settings.log_level <- 1;
7+
let module TDSL = Operation.TDSL in
8+
9+
(* Create a simple Threefry4x32 operation *)
10+
let key = TDSL.number ~label:["key"] 42.0 in
11+
let counter = TDSL.number ~label:["counter"] 1.0 in
12+
let rng_result = TDSL.threefry4x32 ~label:["rng_result"] key counter () in
13+
14+
(* Print the precision of the result *)
15+
Stdlib.Printf.printf "Threefry4x32 result precision: %s\n"
16+
(Ir.Ops.prec_string (Lazy.force rng_result.value.prec));
17+
18+
(* Try to use it in a computation - this should trigger the error *)
19+
let uniform_result = TDSL.uint4x32_to_prec_uniform ~label:["uniform"] rng_result () in
20+
Stdlib.Printf.printf "Uniform result precision: %s\n"
21+
(Ir.Ops.prec_string (Lazy.force uniform_result.value.prec));
22+
let module Backend = (val Backends.fresh_backend ()) in
23+
try
24+
let _ctx = Train.forward_once (module Backend) uniform_result in
25+
Stdlib.Printf.printf "Compilation successful!\n";
26+
(* Also check the actual value precision in the context *)
27+
let tn = rng_result.value in
28+
Stdlib.Printf.printf "Actual tensor precision in context: %s\n"
29+
(Ir.Ops.prec_string (Lazy.force tn.prec))
30+
with
31+
| Utils.User_error msg -> Stdlib.Printf.printf "Error: %s\n" msg
32+
| e -> Stdlib.Printf.printf "Unexpected error: %s\n" (Exn.to_string e)

test/training/moons_demo_parallel.ml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,11 @@ let main () =
2222
let epochs = 60 in
2323
(* let epochs = 1 in *)
2424
let moons_config = Datasets.Half_moons.Config.{ noise_range = 0.1; seed = Some seed } in
25-
let moons_coordinates, moons_labels = Datasets.Half_moons.generate ~config:moons_config ~len () in
26-
let moons_flat_ndarray = Ir.Ndarray.as_array Ir.Ops.Double moons_coordinates in
27-
let moons_classes_ndarray = Ir.Ndarray.as_array Ir.Ops.Double moons_labels in
25+
let moons_coordinates, moons_labels =
26+
Datasets.Half_moons.generate_single_prec ~config:moons_config ~len ()
27+
in
28+
let moons_flat_ndarray = Ir.Ndarray.as_array Ir.Ops.Single moons_coordinates in
29+
let moons_classes_ndarray = Ir.Ndarray.as_array Ir.Ops.Single moons_labels in
2830
let moons_flat = TDSL.rebatch ~l:"moons_flat" moons_flat_ndarray () in
2931
let moons_classes = TDSL.rebatch ~l:"moons_classes" moons_classes_ndarray () in
3032
let%op mlp x = "w3" * relu ("b2" hid_dim + ("w2" * relu ("b1" hid_dim + ("w1" * x)))) in

0 commit comments

Comments
 (0)