Skip to content

Commit 6389d4f

Browse files
authored
Merge pull request #366 from ahrefs/uniform1-variants
feat: Add uniform1 variants for non-vectorized random number generation
2 parents 7c81b8c + 1884a72 commit 6389d4f

20 files changed

+338
-86
lines changed

CLAUDE.md

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,10 @@ opam install cudajit # for CUDA backend
7777

7878
### Testing
7979

80-
- Tests are implemented either as inline expectations using `ppx_expect`; or as cram-style tests where an `.ml` file is compiled, executed, and its output compared against an `.expected` file
81-
- Tutorial files in `test/` serve as both documentation and integration tests
80+
- Tests are implemented either as inline expectations using `ppx_expect`; or as cram-style tests using Dune's `test` stanza where an `.ml` file is compiled, executed, and its output compared against an `.expected` file
81+
- The two approaches are exclusive: a test using using `.expected` file target cannot also use `%expect` inline expectations
82+
- `.expected` tests are easier to debug, `%expect` tests should only be used when the outputs are illustrative
83+
- Tutorial files, i.e. `%expect` tests, in `test/` serve as both documentation and integration tests
8284
- Use `dune promote` to accept test output changes
8385
- **Test Placement Guidelines**:
8486
* Always add tests under one of the test subdirectories

arrayjit/lib/assignments.ml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,6 @@ let%track4_sexp to_low_level code =
284284
let rhs_idcs = Array.map projections.project_rhs.(0) ~f:subst_index in
285285
let open Low_level in
286286
let rhs_ll = get rhs rhs_idcs in
287-
(* For now, we know the only vec_unop is Uint4x32_to_prec_uniform *)
288287
let length =
289288
match op with
290289
| Ops.Uint4x32_to_prec_uniform -> (

arrayjit/lib/c_syntax.ml

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -403,6 +403,7 @@ module C_syntax (B : C_syntax_config) = struct
403403
let ident_doc = string (get_ident tn) in
404404
let dims = Lazy.force tn.dims in
405405
let prec = Lazy.force tn.prec in
406+
(* FIXME: this precision is hardcoded, bad, bad practice. *)
406407
let arg_prec = Ops.uint4x32 in
407408
let local_defs, arg_doc = pp_scalar arg_prec arg in
408409
let local_defs = pp_local_defs local_defs in
@@ -427,9 +428,15 @@ module C_syntax (B : C_syntax_config) = struct
427428
(* For non-Fixed_idx (Iterator, etc), add i to the computed offset *)
428429
pp_array_offset (idcs, dims) ^^ string (" + " ^ Int.to_string i)
429430
in
430-
ident_doc ^^ brackets offset_doc ^^ string " = " ^^ vec_var
431-
^^ string (".v[" ^ Int.to_string i ^ "]")
432-
^^ semi)
431+
let value_doc =
432+
if length = 1 then
433+
(* When length=1, vec_typ_of_prec returns a scalar type, so no .v[] access *)
434+
vec_var
435+
else
436+
(* When length>1, access the vector element *)
437+
vec_var ^^ string (".v[" ^ Int.to_string i ^ "]")
438+
in
439+
ident_doc ^^ brackets offset_doc ^^ string " = " ^^ value_doc ^^ semi)
433440
in
434441
separate hardline elem_assigns
435442
in
@@ -574,7 +581,16 @@ module C_syntax (B : C_syntax_config) = struct
574581
let expr = group (B.binop_syntax prec op e1 e2) in
575582
(defs, expr)
576583
| Unop (op, v) ->
577-
let defs, expr_v = pp_scalar prec v in
584+
let arg_prec =
585+
match op with
586+
| Ops.Uint4x32_to_prec_uniform1 ->
587+
(* The argument to Uint4x32_to_prec_uniform1 must be evaluated with uint4x32 precision,
588+
regardless of the target precision. This handles the case where the operation is
589+
inlined as part of a scalar expression. *)
590+
Ops.uint4x32
591+
| _ -> prec
592+
in
593+
let defs, expr_v = pp_scalar arg_prec v in
578594
let expr = group (B.unop_syntax prec op expr_v) in
579595
(defs, expr)
580596

arrayjit/lib/metal_backend.ml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -580,6 +580,8 @@ end) : Ir.Backend_impl.Lowered_backend = struct
580580
| Recip_sqrt, _ -> func_doc "rsqrt"
581581
| Tanh_approx, _ -> func_doc "tanh"
582582
| Not, _ -> fun v -> string "!" ^^ v
583+
| Uint4x32_to_prec_uniform1, _ ->
584+
fun v -> func_doc "uint4x32_to_prec_uniform1" v
583585
(* Logical not *)
584586

585587
(* Keep vec_unop_syntax same as in pure C syntax. *)

arrayjit/lib/ops.ml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -350,6 +350,11 @@ type unop =
350350
| Neg
351351
| Tanh_approx
352352
| Not (** 0. -> 1. | _ -> 0. *)
353+
| Uint4x32_to_prec_uniform1
354+
(** Non-vectorized variant of [Uint4x32_to_prec_uniform] that converts the given Uint4x32 to a
355+
single value of the output precision. Less bit-efficient but operates poitwise. For random
356+
bits, the result is uniform over the range of the precision for integer precisions, and
357+
over the range \[0.0, 1.0) for floating point precisions. *)
353358
[@@deriving sexp, compare, equal]
354359

355360
type vec_unop =
@@ -431,6 +436,8 @@ let interpret_unop op v =
431436
| Neg -> ~-.v
432437
| Tanh_approx -> tanh v
433438
| Not -> if v = 0. then 1. else 0.
439+
| Uint4x32_to_prec_uniform1 ->
440+
invalid_arg "Ops.interpret_unop: Uint4x32_to_prec_uniform1 argument outside the domain of float"
434441

435442
let interpret_ternop op v1 v2 v3 =
436443
let open Float in
@@ -580,6 +587,7 @@ let unop_cd_syntax = function
580587
| Neg -> "neg"
581588
| Tanh_approx -> "tanh"
582589
| Not -> "not"
590+
| Uint4x32_to_prec_uniform1 -> "uint4x32_to_prec_uniform1"
583591

584592
let vec_unop_cd_syntax = function Uint4x32_to_prec_uniform -> "uint4x32_to_prec_uniform"
585593

@@ -627,6 +635,9 @@ let unop_c_syntax prec op =
627635
invalid_arg "Ops.unop_c_syntax: Tanh_approx not supported for integer precisions"
628636
| Tanh_approx, _ -> ("tanhf(", ")")
629637
| Not, _ -> ("(", " == 0.0 ? 1.0 : 0.0)")
638+
| Uint4x32_to_prec_uniform1, Uint4x32_prec _ ->
639+
invalid_arg "Ops.vec_unop_c_syntax: Uint4x32_to_prec_uniform1 not supported for Uint4x32"
640+
| Uint4x32_to_prec_uniform1, _ -> ("uint4x32_to_" ^ prec_string prec ^ "_uniform(", ")")
630641

631642
let vec_unop_c_syntax prec op =
632643
match (op, prec) with

lib/operation.ml

Lines changed: 48 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -289,6 +289,19 @@ let uint4x32_to_prec_uniform ?grad_spec =
289289
~op_asn ~grad_asn ?grad_spec (* Modifying the label would cause identifier pollution. *)
290290
?label ~top_down_prec:true t1
291291

292+
let uint4x32_to_prec_uniform1 ?grad_spec =
293+
let module NTDSL = Initial_NTDSL in
294+
let%cd op_asn ~v ~t1 ~projections = v =: uint4x32_to_prec_uniform1 v1 in
295+
let%cd grad_asn ~t:_ ~g:_ ~t1:_ ~projections:_ = Asgns.empty_comp in
296+
fun t1 ?label ?top_down_prec ->
297+
(* Ignore what the caller says, since we must learn the precision from the outside. *)
298+
ignore (top_down_prec : bool option);
299+
Tn.update_prec t1.Tensor.value Ir.Ops.uint4x32;
300+
Tensor.unop (* A placeholder that will be replaced by the actual precision by Tensor.op. *)
301+
~transpose_op:Pointwise_un ~op_asn ~grad_asn
302+
?grad_spec (* Modifying the label would cause identifier pollution. *)
303+
?label ~top_down_prec:true t1
304+
292305
let lt ?(label = []) =
293306
let module NTDSL = Initial_NTDSL in
294307
let%cd op_asn ~v ~t1 ~t2 ~projections = v =: (v1 < v2) in
@@ -355,15 +368,16 @@ let where ?(label = []) ~grad_spec t1 t2 t3 =
355368
Tensor.ternop ~label:("where" :: label) ~ternary_op:Pointwise_tern ~op_asn ~grad_asn ~grad_spec t1
356369
t2 t3
357370

371+
(** [range] is a 1D tensor of shape [upto], spans [[0, upto)]. *)
358372
let range ?(label = []) ?(grad_spec = Tensor.Prohibit_grad) ?axis_label upto =
359373
let result =
360374
Tensor.term ~fetch_op:Range_over_offsets ~grad_spec ~batch_dims:[]
361-
~label:(("0" ^ "..." ^ Int.to_string upto) :: label)
375+
~label:(("0" ^ "..." ^ Int.to_string (upto - 1)) :: label)
362376
~input_dims:[]
363377
in
364378
match axis_label with
365-
| None -> result ~output_dims:[ upto + 1 ] ()
366-
| Some l -> result ~output_axes:[ (l, upto + 1) ] ()
379+
| None -> result ~output_dims:[ upto ] ()
380+
| Some l -> result ~output_axes:[ (l, upto) ] ()
367381

368382
let range_of_shape ?(label = []) ?(grad_spec = Tensor.Prohibit_grad) ?batch_dims ?input_dims
369383
?output_dims ?batch_axes ?input_axes ?output_axes () =
@@ -433,6 +447,24 @@ let uniform_at ?grad_spec counter =
433447
~label:[ "range_over_offsets" ] ())
434448
())
435449

450+
(** A wasteful variant of {!uniform} that produces a single value from each 4x32 random bits. *)
451+
let uniform1 ?grad_spec () =
452+
uint4x32_to_prec_uniform1 ?grad_spec
453+
(threefry4x32
454+
(threefry4x32 (embed_self_id ()) (Tensor.get_random_seed ()) ())
455+
(Tensor.term ~fetch_op:Range_over_offsets ~grad_spec:Prohibit_grad
456+
~label:[ "range_over_offsets" ] ())
457+
())
458+
459+
(** A wasteful variant of {!uniform_at} that produces a single value from each 4x32 random bits. *)
460+
let uniform_at1 ?grad_spec counter =
461+
uint4x32_to_prec_uniform1 ?grad_spec
462+
(threefry4x32
463+
(threefry4x32 (threefry4x32 (embed_self_id ()) (Tensor.get_random_seed ()) ()) counter ())
464+
(Tensor.term ~fetch_op:Range_over_offsets ~grad_spec:Prohibit_grad
465+
~label:[ "range_over_offsets" ] ())
466+
())
467+
436468
module DO = struct
437469
let ( * ) ?label t1 t2 = matmul ~grad_spec:If_needed ?label t1 t2 ()
438470
let ( *. ) ?label t1 t2 = pointmul ~grad_spec:If_needed ?label t1 t2 ()
@@ -442,6 +474,9 @@ module DO = struct
442474
let uint4x32_to_prec_uniform ?label t1 =
443475
uint4x32_to_prec_uniform ~grad_spec:If_needed t1 ?label ()
444476

477+
let uint4x32_to_prec_uniform1 ?label t1 =
478+
uint4x32_to_prec_uniform1 ~grad_spec:If_needed t1 ?label ()
479+
445480
let ( **. ) ?label base exp = pointpow ?label exp base ~grad_spec:If_needed ()
446481
let relu ?label t = relu ~grad_spec:If_needed ?label t ()
447482
let sat01 ?label t = sat01 ~grad_spec:If_needed ?label t ()
@@ -478,6 +513,8 @@ module DO = struct
478513
let ndarray = Tensor.ndarray ~grad_spec:If_needed
479514
let uniform ?label () = uniform ~grad_spec:Require_grad () ?label ()
480515
let uniform_at ?label counter = uniform_at ~grad_spec:Require_grad ?label counter ()
516+
let uniform1 ?label () = uniform1 ~grad_spec:Require_grad () ?label ()
517+
let uniform_at1 ?label counter = uniform_at1 ~grad_spec:Require_grad ?label counter ()
481518
end
482519

483520
module NDO = struct
@@ -502,6 +539,9 @@ module NDO = struct
502539
let uint4x32_to_prec_uniform ?label t1 =
503540
uint4x32_to_prec_uniform ~grad_spec:Prohibit_grad ?label t1 ()
504541

542+
let uint4x32_to_prec_uniform1 ?label t1 =
543+
uint4x32_to_prec_uniform1 ~grad_spec:Prohibit_grad ?label t1 ()
544+
505545
let recip ?label t = recip ~grad_spec:Prohibit_grad ?label t ()
506546
let recip_sqrt ?label t = recip_sqrt ~grad_spec:Prohibit_grad ?label t ()
507547
let tanh ?label t = tanh ~grad_spec:Prohibit_grad ?label t ()
@@ -515,6 +555,8 @@ module NDO = struct
515555
let ndarray = Tensor.ndarray ~grad_spec:Prohibit_grad
516556
let uniform ?label () = uniform ~grad_spec:Prohibit_grad () ?label ()
517557
let uniform_at ?label counter = uniform_at ~grad_spec:Prohibit_grad ?label counter ()
558+
let uniform1 ?label () = uniform1 ~grad_spec:Prohibit_grad () ?label ()
559+
let uniform_at1 ?label counter = uniform_at1 ~grad_spec:Prohibit_grad ?label counter ()
518560
end
519561

520562
(** The input [i] dimensions default to empty. The batch and output dimensions will be inferred if
@@ -555,6 +597,7 @@ module TDSL = struct
555597
let ndarray = Tensor.ndarray ~grad_spec:If_needed
556598
let threefry4x32 = threefry4x32 ~grad_spec:If_needed
557599
let uint4x32_to_prec_uniform = uint4x32_to_prec_uniform ~grad_spec:If_needed
600+
let uint4x32_to_prec_uniform1 = uint4x32_to_prec_uniform1 ~grad_spec:If_needed
558601
let embed_self_id = embed_self_id
559602

560603
(** The default initialization operation for {!param} calls. *)
@@ -615,6 +658,8 @@ module NTDSL = struct
615658
let embed_self_id = embed_self_id
616659
let uniform = uniform ~grad_spec:Prohibit_grad
617660
let uniform_at = uniform_at ~grad_spec:Prohibit_grad
661+
let uniform1 = uniform1 ~grad_spec:Prohibit_grad
662+
let uniform_at1 = uniform_at1 ~grad_spec:Prohibit_grad
618663

619664
let counter ?(label = []) =
620665
let module NTDSL = Initial_NTDSL in

lib/ppx_cd.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -454,7 +454,7 @@ let translate ?ident_label (expr : expression) : result =
454454
@@ Location.error_extensionf ~loc
455455
"ppx_ocannl %%cd: expected a unary operator, one of: %s"
456456
"id, relu, sat01, exp, log, exp2, log2, sin, cos, sqrt, recip, recip_sqrt, \
457-
neg, tanh" ))
457+
neg, tanh, uint4x32_to_prec_uniform1" ))
458458
in
459459
let vec_unary_op vec_un_op =
460460
loc

lib/ppx_shared.ml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ let unary_ops =
190190
("neg", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Neg]));
191191
("tanh", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Tanh_approx]));
192192
("not", fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Not]));
193+
( "uint4x32_to_prec_uniform1",
194+
fun loc -> ([%expr Shape.Pointwise_un], [%expr Ir.Ops.Uint4x32_to_prec_uniform1]) );
193195
]
194196

195197
(** Vector unary primitive ops. *)

lib/tensor.mli

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,9 @@ val set_random_seed : ?seed:int -> unit -> unit
245245

246246
val get_random_seed : unit -> t
247247
(** Returns a tensor with the current random seed. Lazily initialized using {!set_random_seed} and
248-
reset when {!unsafe_reinitialize} is called. *)
248+
reset when {!unsafe_reinitialize} is called. IMPORTANT: all sites using the same global random
249+
seed, e.g. using [get_random_seed ()] not separated by a call to {!unsafe_reinitialize}, must
250+
descend from the first caller's optimization context. *)
249251

250252
(** {2 Printing.} *)
251253

test/einsum/einsum_trivia.ml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -961,13 +961,13 @@ let%expect_test "outer_sum simulating axis concatenation" =
961961
and type optimize_ctx = Backend.optimize_ctx)
962962
in
963963

964-
let ri = TDSL.range 3 in
964+
let ri = TDSL.range 4 in
965965
let%op ti = ri ++ "i=>i0" in
966966
(* Write position 2 of ti, otherwise shape inference concludes it's dim-1 and broadcasted. *)
967967
let%cd _ = ti =: 0 ++ "i=>i2" in
968-
let rj = TDSL.range 4 in
968+
let rj = TDSL.range 5 in
969969
let%op tj = rj ++ "j=>j1" in
970-
let rk = TDSL.range 5 in
970+
let rk = TDSL.range 6 in
971971
let%op tk = rk ++ "k=>k2" in
972972
let positions = TDSL.outer_sum "ijl;kl=>ijkl" (TDSL.outer_sum "il;jl=>ijl" ti tj ()) tk () in
973973
Train.set_hosted tk.value;

0 commit comments

Comments
 (0)