@@ -285,17 +285,6 @@ let uint4x32_to_prec_uniform ?grad_spec =
285285 ~transpose_op: (Uint4x32_to_prec (lazy (assert false )))
286286 ~op_asn ~grad_asn ?grad_spec t1
287287
288- (* * A wasteful variant of {!uint4x32_to_prec_uniform} that produces a single value from each 4x32
289- random bits. *)
290- let uint4x32_to_prec_uniform1 ?grad_spec =
291- (* FIXME: Must use non-vectorized version of uint4x32_to_prec_uniform. *)
292- let module NTDSL = Initial_NTDSL in
293- let % cd op_asn ~v ~t1 ~projections = v =: uint4x32_to_prec_uniform v1 in
294- let % cd grad_asn ~t: _ ~g: _ ~t1: _ ~projections: _ = Asgns. empty_comp in
295- fun t1 ->
296- Tn. update_prec t1.Tensor. value Ir.Ops. uint4x32;
297- Tensor. unop ~transpose_op: Pointwise_un ~op_asn ~grad_asn ?grad_spec t1
298-
299288let lt ?(label = [] ) =
300289 let module NTDSL = Initial_NTDSL in
301290 let % cd op_asn ~v ~t1 ~t2 ~projections = v =: (v1 < v2) in
@@ -424,22 +413,19 @@ let embed_self_id ?(label = []) () : Tensor.t =
424413 Tensor. term ~fetch_op: Embed_self_id ~grad_spec: Prohibit_grad ~label: (" !@self_id" :: label)
425414 ~batch_dims: [] ~input_dims: [] ~output_dims: [ 1 ] ()
426415
427- (* FIXME: this should work, but it's a shape mismatch error *)
428- (* let uniform_plus_one ?grad_spec () ?label = add ?label (Tensor.number ~grad_spec:Prohibit_grad 1.)
429- (uint4x32_to_prec_uniform ?grad_spec (threefry4x32 (embed_self_id ()) (Tensor.term
430- ~fetch_op:Range_over_offsets ~grad_spec:Prohibit_grad ~label:[ "range_over_offsets" ] ()) ())
431- ()) *)
432416let uniform ?grad_spec () =
433417 uint4x32_to_prec_uniform ?grad_spec
434418 (threefry4x32 (embed_self_id () )
435419 (Tensor. term ~fetch_op: Range_over_offsets ~grad_spec: Prohibit_grad
436420 ~label: [ " range_over_offsets" ] () )
437421 () )
438422
439- (* * A wasteful variant of {!uniform} that produces a single value from each 4x32 random bits. *)
440- let uniform1 ?grad_spec () =
441- uint4x32_to_prec_uniform1 ?grad_spec
442- (threefry4x32 (embed_self_id () )
423+ (* * Generates a single uniform random number using a counter symbol for PRNG state. This is useful
424+ for sequential sampling in recurrent contexts. *)
425+ let uniform_at ?grad_spec counter =
426+ uint4x32_to_prec_uniform ?grad_spec
427+ (threefry4x32
428+ (threefry4x32 (embed_self_id () ) counter () )
443429 (Tensor. term ~fetch_op: Range_over_offsets ~grad_spec: Prohibit_grad
444430 ~label: [ " range_over_offsets" ] () )
445431 () )
@@ -487,7 +473,7 @@ module DO = struct
487473 let einsum1 ?label spec t1 = einsum1 ?label spec t1 ~grad_spec: If_needed ()
488474 let ndarray = Tensor. ndarray ~grad_spec: If_needed
489475 let uniform ?label () = uniform ~grad_spec: Require_grad () ?label ()
490- let uniform1 ?label () = uniform1 ~grad_spec: Require_grad () ?label ()
476+ let uniform_at ?label counter = uniform_at ~grad_spec: Require_grad ?label counter
491477end
492478
493479module NDO = struct
@@ -523,6 +509,8 @@ module NDO = struct
523509 let einsum ?label spec t1 t2 = einsum spec t1 t2 ~grad_spec: Prohibit_grad ?label ()
524510 let einsum1 ?label spec t1 = einsum1 spec t1 ~grad_spec: Prohibit_grad ?label ()
525511 let ndarray = Tensor. ndarray ~grad_spec: Prohibit_grad
512+ let uniform ?label () = uniform ~grad_spec: Prohibit_grad () ?label ()
513+ let uniform_at ?label counter = uniform_at ~grad_spec: Prohibit_grad ?label counter
526514end
527515
528516(* * The input [i] dimensions default to empty. The batch and output dimensions will be inferred if
@@ -623,6 +611,7 @@ module NTDSL = struct
623611 let uint4x32_to_prec_uniform = uint4x32_to_prec_uniform ~grad_spec: Prohibit_grad
624612 let embed_self_id = embed_self_id
625613 let uniform = uniform ~grad_spec: Prohibit_grad
614+ let uniform_at = uniform_at ~grad_spec: Prohibit_grad
626615
627616 let counter ?(label = [] ) =
628617 let module NTDSL = Initial_NTDSL in
0 commit comments