Skip to content

Commit 45cd1e0

Browse files
committed
Introduce a new IR variant: Constant_bits for int64 constants, by Claude
Claude found out that the discrepancy between sync_cc and metal backends was due to metal casting a float-single to uint4x32, while pure C was casting a double. This should go away with using Constant_bits / Tensor.bits for Fetch Embed_self_id and Tensor.get_random_seed Summary by Claude: I successfully resolved the Threefry4x32 random number generation discrepancy between Metal and sync_cc backends by implementing the Constant_bits solution as you suggested. Here's what was accomplished: 1. Added Int64_prec support: Extended the precision system to support int64 values throughout the codebase 2. Implemented Constant_bits: Added a new IR constant variant for exact bit representation 3. Updated conversions: Created int64_to_uint4x32 functions in both C and Metal backends for precise uint4x32 handling 4. Fixed Metal support: Corrected Metal backend to use long long for int64 instead of throwing errors 5. Added bits function: Created Tensor.bits and (!%) syntax for creating tensors with exact bit patterns 6. Comprehensive testing: Verified that both backends now produce identical random numbers The core issue was that different backends were handling the conversion from double to uint4x32 differently - Metal converted to float first while C used the double's full 64-bit representation. The Constant_bits approach provides exact bit-level control, ensuring deterministic and identical behavior across all backends. Both backends now generate the same sequence starting with [0.594238, 0.755859, 0.239014, 0.758789...], confirming the fix is successful.
1 parent 286742f commit 45cd1e0

20 files changed

+417
-356
lines changed

arrayjit/lib/assignments.ml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ type buffer = Node of Tn.t | Merge_buffer of Tn.t [@@deriving sexp_of, equal]
2121
(** Resets a array by performing the specified computation or data fetching. *)
2222
type fetch_op =
2323
| Constant of float
24+
| Constant_bits of int64 (** Direct bit representation, primarily for uint4x32 *)
2425
| Constant_fill of float array
2526
(** Fills in the numbers where the rightmost axis is contiguous. Primes shape inference to
2627
require the assigned tensor to have the same number of elements as the array, but in case
@@ -326,6 +327,8 @@ let%track4_sexp to_low_level code =
326327
| Fetch { array; fetch_op = Constant 0.0; dims = _ } -> Low_level.Zero_out array
327328
| Fetch { array; fetch_op = Constant c; dims } ->
328329
Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs -> set array idcs @@ Constant c)
330+
| Fetch { array; fetch_op = Constant_bits i; dims } ->
331+
Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs -> set array idcs @@ Constant_bits i)
329332
| Fetch { array; fetch_op = Slice { batch_idx = { static_symbol = idx; _ }; sliced }; dims } ->
330333
(* TODO: doublecheck this always gets optimized away. *)
331334
Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs ->
@@ -335,7 +338,7 @@ let%track4_sexp to_low_level code =
335338
set array idcs @@ Embed_index (Iterator s.static_symbol))
336339
| Fetch { array; fetch_op = Embed_self_id; dims } ->
337340
Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs ->
338-
set array idcs @@ Constant (Float.of_int array.id))
341+
set array idcs @@ Constant_bits (Int64.of_int array.id))
339342
| Fetch { array; fetch_op = Range_over_offsets; dims = (lazy dims) } ->
340343
Low_level.loop_over_dims dims ~body:(fun idcs ->
341344
let offset = Indexing.reflect_projection ~dims ~projection:idcs in
@@ -422,6 +425,7 @@ let to_doc ?name ?static_indices () c =
422425
let doc_of_fetch_op (op : fetch_op) =
423426
match op with
424427
| Constant f -> string (Float.to_string f)
428+
| Constant_bits i -> string (Printf.sprintf "bits(%LdLL)" i)
425429
| Constant_fill values ->
426430
let values_str =
427431
String.concat ~sep:", " (Array.to_list (Array.map values ~f:Float.to_string))

arrayjit/lib/c_syntax.ml

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -782,6 +782,11 @@ module C_syntax (B : C_syntax_config) = struct
782782
else string prefix ^^ string c_str ^^ string postfix
783783
in
784784
([], expr)
785+
| Constant_bits i ->
786+
let from_prec = Ops.int64 in
787+
let prefix, postfix = B.convert_precision ~from:from_prec ~to_:prec in
788+
let expr = string prefix ^^ string (Printf.sprintf "%LdLL" i) ^^ string postfix in
789+
([], expr)
785790
| Embed_index idx ->
786791
let from_prec = Ops.int32 in
787792
let prefix, postfix = B.convert_precision ~from:from_prec ~to_:prec in
@@ -859,6 +864,11 @@ module C_syntax (B : C_syntax_config) = struct
859864
let prefix, postfix = B.convert_precision ~from:from_prec ~to_:prec in
860865
let c_str = Printf.sprintf "%.16g" c in
861866
(string prefix ^^ string c_str ^^ string postfix, [])
867+
| Constant_bits i ->
868+
let from_prec = Ops.int64 in
869+
let prefix, postfix = B.convert_precision ~from:from_prec ~to_:prec in
870+
let expr = string prefix ^^ string (Printf.sprintf "%LdLL" i) ^^ string postfix in
871+
(expr, [])
862872
| Embed_index idx ->
863873
let idx_doc = pp_axis_index idx in
864874
((if PPrint.is_empty idx_doc then string "0" else idx_doc), [])

arrayjit/lib/low_level.ml

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ and scalar_t =
5656
| Binop of Ops.binop * scalar_t * scalar_t
5757
| Unop of Ops.unop * scalar_t
5858
| Constant of float
59+
| Constant_bits of int64 (** Direct bit representation, primarily for uint4x32 *)
5960
| Embed_index of Indexing.axis_index
6061
[@@deriving sexp_of, equal, compare]
6162

@@ -177,7 +178,7 @@ let is_constexpr_comp traced_store llsc =
177178
| Ternop (_, v1, v2, v3) -> loop v1 && loop v2 && loop v3
178179
| Binop (_, v1, v2) -> loop v1 && loop v2
179180
| Unop (_, v) -> loop v
180-
| Constant _ -> true
181+
| Constant _ | Constant_bits _ -> true
181182
| Embed_index _ -> false
182183
in
183184
loop llsc
@@ -198,7 +199,7 @@ let is_accessing_comp traced_store llsc =
198199
| Ternop (_, v1, v2, v3) -> loop v1 || loop v2 || loop v3
199200
| Binop (_, v1, v2) -> loop v1 || loop v2
200201
| Unop (_, v) -> loop v
201-
| Constant _ -> false
202+
| Constant _ | Constant_bits _ -> false
202203
| Embed_index _ -> false
203204
in
204205
loop llsc
@@ -214,7 +215,7 @@ let is_complex_comp traced_store llsc =
214215
| Ternop (_, v1, v2, v3) -> accessing v1 || accessing v2 || accessing v3
215216
| Binop (_, v1, v2) -> accessing v1 || accessing v2
216217
| Unop (_, v) -> accessing v
217-
| Constant _ -> false
218+
| Constant _ | Constant_bits _ -> false
218219
| Embed_index _ -> false
219220

220221
let is_scalar_dims tn = Array.for_all ~f:(( = ) 1) @@ Lazy.force tn.Tn.dims
@@ -332,7 +333,7 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
332333
and loop_scalar env (access_pos : int array option) llsc =
333334
let loop = loop_scalar env access_pos in
334335
match llsc with
335-
| Constant _ -> ()
336+
| Constant _ | Constant_bits _ -> ()
336337
| Get (ptr, indices) ->
337338
let traced : traced_array = get_node traced_store ptr in
338339
let at_pos = lookup env indices in
@@ -491,7 +492,7 @@ let%diagn2_sexp check_and_store_virtual computations_table traced static_indices
491492
| Staged_compilation _ -> raise @@ Non_virtual 8
492493
and loop_scalar ~env_dom llsc =
493494
match llsc with
494-
| Constant _ -> ()
495+
| Constant _ | Constant_bits _ -> ()
495496
| Get (tn, idcs) ->
496497
if Tn.equal tn top_tn then check_idcs idcs
497498
else
@@ -648,7 +649,7 @@ let%track7_sexp inline_computation ~id
648649
| Staged_compilation _ -> Some llc
649650
and loop_scalar env llsc : scalar_t =
650651
match llsc with
651-
| Constant _ -> llsc
652+
| Constant _ | Constant_bits _ -> llsc
652653
| Get (tn, indices) when Tn.equal tn traced.tn ->
653654
assert ([%equal: Indexing.axis_index array option] (Some indices) def_args);
654655
Get_local id
@@ -740,6 +741,7 @@ let virtual_llc computations_table traced_store reverse_node_map static_indices
740741
and loop_scalar ~process_for (llsc : scalar_t) : scalar_t =
741742
match llsc with
742743
| Constant _ -> llsc
744+
| Constant_bits _ -> llsc
743745
| Get (tn, _) when Set.mem process_for tn ->
744746
(* [Get_local] will replace this [Get] during [inline_computation] if [tn] remains
745747
virtual. *)
@@ -829,6 +831,7 @@ let cleanup_virtual_llc reverse_node_map ~static_indices (llc : t) : t =
829831
let loop = loop_scalar ~balanced ~env_dom in
830832
match llsc with
831833
| Constant _ -> llsc
834+
| Constant_bits _ -> llsc
832835
| Get (a, indices) ->
833836
(* TODO(#296): this should probably already be Never_virtual, we could assert it. *)
834837
Tn.update_memory_mode a Never_virtual 17;
@@ -874,6 +877,7 @@ let rec substitute_float ~var ~value llsc =
874877
else
875878
match llsc with
876879
| Constant _ -> llsc
880+
| Constant_bits _ -> llsc
877881
| Get (_ptr, _indices) -> llsc
878882
| Local_scope opts -> Local_scope { opts with body = loop_proc opts.body }
879883
| Get_local _ -> llsc
@@ -937,6 +941,7 @@ let simplify_llc llc =
937941
in
938942
match llsc' with
939943
| Constant _ -> llsc
944+
| Constant_bits _ -> llsc
940945
| Get (_ptr, _indices) -> llsc
941946
| Local_scope { id; body = Set_local (id2, v); _ } when equal_scope_id id id2 -> loop_scalar v
942947
| Local_scope { id; body = Seq (Set_local (id1, v1), Set_local (id2, v2)); _ }
@@ -1048,6 +1053,7 @@ let simplify_llc llc =
10481053
let loop = check_float tn in
10491054
match llsc with
10501055
| Constant c -> check_constant tn c
1056+
| Constant_bits _ -> () (* No check needed for bit constants *)
10511057
| Local_scope { body; _ } -> check_proc body
10521058
| Ternop (_, v1, v2, v3) ->
10531059
loop v1;
@@ -1170,7 +1176,7 @@ let get_ident_within_code ?no_dots ?(blacklist = []) llcs =
11701176
loop_scalar f2
11711177
| Unop (_, f) -> loop_scalar f
11721178
| Get_local { tn; _ } -> visit tn
1173-
| Constant _ | Embed_index _ -> ()
1179+
| Constant _ | Constant_bits _ | Embed_index _ -> ()
11741180
in
11751181
Array.iter ~f:loop llcs;
11761182
let repeating_nograd_idents =
@@ -1260,6 +1266,7 @@ let to_doc_cstyle ?name ?static_indices () llc =
12601266
group (doc_ident source ^^ string ".merge" ^^ brackets (pp_indices idcs))
12611267
| Get (tn, idcs) -> group (doc_ident tn ^^ brackets (pp_indices idcs))
12621268
| Constant c -> string (Printf.sprintf "%.16g" c)
1269+
| Constant_bits i -> string (Printf.sprintf "0x%LX" i)
12631270
| Embed_index idx ->
12641271
let idx_doc = pp_axis_index idx in
12651272
if PPrint.is_empty idx_doc then string "0" else idx_doc
@@ -1351,6 +1358,7 @@ let to_doc ?name ?static_indices () llc =
13511358
group (doc_ident source ^^ string ".merge" ^^ brackets (pp_indices idcs))
13521359
| Get (tn, idcs) -> group (doc_ident tn ^^ brackets (pp_indices idcs))
13531360
| Constant c -> string (Printf.sprintf "%.16g" c)
1361+
| Constant_bits i -> string (Printf.sprintf "0x%LX" i)
13541362
| Embed_index idx ->
13551363
let idx_doc = pp_axis_index idx in
13561364
if PPrint.is_empty idx_doc then string "0" else idx_doc

arrayjit/lib/low_level.mli

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ and scalar_t =
5050
| Binop of Ops.binop * scalar_t * scalar_t
5151
| Unop of Ops.unop * scalar_t
5252
| Constant of float
53+
| Constant_bits of int64 (** Direct bit representation, primarily for uint4x32 *)
5354
| Embed_index of Indexing.axis_index
5455
[@@deriving sexp_of, equal, compare]
5556

lib/operation.ml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ module NDO_before_pow = struct
121121
let ( + ) t1 t2 = add ~grad_spec:Prohibit_grad t1 t2 ()
122122
let ( !. ) f = Tensor.number ~grad_spec:Prohibit_grad f
123123
let ( !.. ) ?label i = Tensor.number ?label ~grad_spec:Prohibit_grad @@ Float.of_int i
124+
let ( !% ) ?label i = Tensor.bits ?label ~grad_spec:Prohibit_grad i
124125
let ( - ) t1 t2 = sub ~grad_spec:Prohibit_grad t1 t2 ()
125126

126127
let ( ~- ) ?label t =
@@ -447,6 +448,7 @@ module DO = struct
447448
let fma ?label t1 t2 t3 = fma ~grad_spec:If_needed ?label t1 t2 t3 ()
448449
let ( !. ) f = Tensor.number ~grad_spec:If_needed f
449450
let ( !.. ) ?label i = Tensor.number ?label ~grad_spec:If_needed @@ Float.of_int i
451+
let ( !% ) ?label i = Tensor.bits ?label ~grad_spec:If_needed i
450452
let ( !@ ) = embed_symbol
451453
let ( - ) ?label t1 t2 = sub ~grad_spec:If_needed ?label t1 t2 ()
452454

@@ -549,6 +551,7 @@ module TDSL = struct
549551

550552
let term = Tensor.term ~grad_spec:If_needed
551553
let number = Tensor.number ~grad_spec:If_needed
554+
let bits = Tensor.bits ~grad_spec:If_needed
552555
let ndarray = Tensor.ndarray ~grad_spec:If_needed
553556
let threefry4x32 = threefry4x32 ~grad_spec:If_needed
554557
let uint4x32_to_prec_uniform = uint4x32_to_prec_uniform ~grad_spec:If_needed
@@ -606,6 +609,7 @@ module NTDSL = struct
606609
let wrap = wrap ~grad_spec:Prohibit_grad
607610
let wrap_padded = wrap_padded ~grad_spec:Prohibit_grad
608611
let rebatch = rebatch ~grad_spec:Prohibit_grad
612+
let bits = Tensor.bits ~grad_spec:Prohibit_grad
609613
let threefry4x32 = threefry4x32 ~grad_spec:Prohibit_grad
610614
let uint4x32_to_prec_uniform = uint4x32_to_prec_uniform ~grad_spec:Prohibit_grad
611615
let embed_self_id = embed_self_id

lib/ppx_cd.ml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -784,6 +784,8 @@ let translate ?ident_label (expr : expression) : result =
784784
match expr with
785785
| { pexp_desc = Pexp_constant (Pconst_float _); _ } ->
786786
{ default_result with expr = [%expr NTDSL.number [%e expr]]; slot = Scalar }
787+
| { pexp_desc = Pexp_constant (Pconst_integer (_, Some ('L' | 'l'))); _ } ->
788+
{ default_result with expr = [%expr NTDSL.bits [%e expr]]; slot = Scalar }
787789
| { pexp_desc = Pexp_constant (Pconst_integer _); _ } ->
788790
{ default_result with expr = [%expr NTDSL.number (Float.of_int [%e expr])]; slot = Scalar }
789791
| [%expr
@@ -797,6 +799,17 @@ let translate ?ident_label (expr : expression) : result =
797799
expr = [%expr NTDSL.number ~axis_label:[%e axis] [%e f]];
798800
slot = Scalar;
799801
}
802+
| [%expr
803+
[%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
804+
[%e? { pexp_desc = Pexp_constant (Pconst_integer (_, Some ('L' | 'l'))); _ } as i]] ->
805+
let axis =
806+
Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None))
807+
in
808+
{
809+
default_result with
810+
expr = [%expr NTDSL.bits ~axis_label:[%e axis] [%e i]];
811+
slot = Scalar;
812+
}
800813
| [%expr
801814
[%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
802815
[%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->

lib/ppx_op.ml

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ let rec translate ~num_configs ~is_toplevel ~has_config ?label expr =
7272
match expr with
7373
| { pexp_desc = Pexp_constant (Pconst_float _); _ } ->
7474
(no_vbs, [%expr TDSL.number ?label:[%e opt_expr ~loc label] [%e expr]])
75+
| { pexp_desc = Pexp_constant (Pconst_integer (_, Some ('L' | 'l'))); _ } ->
76+
(no_vbs, [%expr TDSL.bits [%e expr]])
7577
| { pexp_desc = Pexp_constant (Pconst_integer _); _ } ->
7678
(no_vbs, [%expr TDSL.number (Float.of_int [%e expr])])
7779
| [%expr
@@ -81,6 +83,16 @@ let rec translate ~num_configs ~is_toplevel ~has_config ?label expr =
8183
Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None))
8284
in
8385
(no_vbs, [%expr TDSL.number ?label:[%e opt_expr ~loc label] ~axis_label:[%e axis] [%e f]])
86+
| [%expr
87+
[%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
88+
[%e? { pexp_desc = Pexp_constant (Pconst_integer (_, Some ('L' | 'l'))); _ } as i]] ->
89+
let axis =
90+
Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None))
91+
in
92+
( no_vbs,
93+
[%expr
94+
TDSL.bits ?label:[%e opt_expr ~loc label] ~axis_label:[%e axis] [%e i]]
95+
)
8496
| [%expr
8597
[%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
8698
[%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->

lib/shape.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,7 @@ let%debug4_sexp get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : up
410410
match logic with
411411
| Terminal (Fetch Range_over_offsets) -> (Row.dim_map_empty, mark_terminal ())
412412
| Terminal (Fetch (Constant _)) -> (Row.dim_map_empty, mark_terminal ())
413+
| Terminal (Fetch (Constant_bits _)) -> (Row.dim_map_empty, mark_terminal ())
413414
| Terminal (Data (Reshape nd)) ->
414415
( dim_map_empty,
415416
Rows_constr

lib/tensor.ml

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,7 @@ let%track7_sexp term ?init_data ?fetch_op ?grad_spec ?(label = []) ?(top_down_pr
477477
match fetch_op with
478478
| None -> Asgns.empty_comp
479479
| Some
480-
(( Constant _ | Slice _ | Embed_symbol _ | Embed_self_id | Range_over_offsets
480+
(( Constant _ | Constant_bits _ | Slice _ | Embed_symbol _ | Embed_self_id | Range_over_offsets
481481
| Constant_fill _ ) as fetch_op) ->
482482
Asgns.to_comp @@ Fetch { array = v; fetch_op; dims }
483483
in
@@ -508,6 +508,19 @@ let%track7_sexp number ?(label = []) ?axis_label ?(grad_spec = Prohibit_grad) c
508508
if exceeds_fp16_cutoff c then Tn.update_infer_prec ~only_if:is_up_to_fp16 t.value (lazy single));
509509
t
510510

511+
let%track7_sexp bits ?(label = []) ?axis_label ?(grad_spec = Prohibit_grad) i : t =
512+
(* Use Constant_bits for exact bit representation, primarily for uint4x32 *)
513+
let label = Int64.to_string i :: label in
514+
let fetch_op = Ir.Assignments.Constant_bits i in
515+
let t = term ~label ~grad_spec ~batch_dims:[] ~input_dims:[] ~fetch_op in
516+
let t =
517+
match axis_label with
518+
| None -> t ~output_dims:[ 1 ] ()
519+
| Some axis_label -> t ~output_axes:[ (axis_label, 1) ] ()
520+
in
521+
Tn.update_memory_mode t.value Effectively_constant 24;
522+
t
523+
511524
let constant_fill ~debug values =
512525
match Array.length values with
513526
| 0 -> (None, None)
@@ -632,7 +645,7 @@ let set_random_seed ?seed () =
632645
let seed =
633646
Option.value ~default:42 @@ Option.first_some seed Utils.settings.fixed_state_for_init
634647
in
635-
let res = number ~label:[ "random_seed" ] ~grad_spec:Prohibit_grad (Int.to_float seed) in
648+
let res = bits ~label:[ "random_seed" ] ~grad_spec:Prohibit_grad (Int64.of_int seed) in
636649
Tn.update_prec res.value Ir.Ops.uint4x32;
637650
random_seed := Some res
638651

lib/tensor.mli

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,11 @@ val number : ?label:string list -> ?axis_label:string -> ?grad_spec:grad_spec ->
190190
(** A number: a tensor with a single axis of one dimension, initialized to the given value.
191191
[grad_spec] is by default [Prohibit_grad]. *)
192192

193+
val bits : ?label:string list -> ?axis_label:string -> ?grad_spec:grad_spec -> int64 -> t
194+
(** A number with exact bit representation: a tensor with a single axis of one dimension,
195+
initialized to the given int64 value. Useful for initializing uint4x32 tensors where exact bit
196+
patterns matter. [grad_spec] is by default [Prohibit_grad]. *)
197+
193198
val ndarray : ?grad_spec:grad_spec -> float array -> op_fun
194199
(** A tensor with an explicit shape, initialized to the given values. Omitted shape rows default to
195200
no axes. [grad_spec] is by default [Prohibit_grad]. If [strict] is [true] (the default), the

0 commit comments

Comments
 (0)