Skip to content

Commit e6f0e7e

Browse files
committed
Remove uniform1 for now, will be in 0.6.1
1 parent 9a2006a commit e6f0e7e

File tree

1 file changed

+10
-21
lines changed

1 file changed

+10
-21
lines changed

lib/operation.ml

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -285,17 +285,6 @@ let uint4x32_to_prec_uniform ?grad_spec =
285285
~transpose_op:(Uint4x32_to_prec (lazy (assert false)))
286286
~op_asn ~grad_asn ?grad_spec t1
287287

288-
(** A wasteful variant of {!uint4x32_to_prec_uniform} that produces a single value from each 4x32
289-
random bits. *)
290-
let uint4x32_to_prec_uniform1 ?grad_spec =
291-
(* FIXME: Must use non-vectorized version of uint4x32_to_prec_uniform. *)
292-
let module NTDSL = Initial_NTDSL in
293-
let%cd op_asn ~v ~t1 ~projections = v =: uint4x32_to_prec_uniform v1 in
294-
let%cd grad_asn ~t:_ ~g:_ ~t1:_ ~projections:_ = Asgns.empty_comp in
295-
fun t1 ->
296-
Tn.update_prec t1.Tensor.value Ir.Ops.uint4x32;
297-
Tensor.unop ~transpose_op:Pointwise_un ~op_asn ~grad_asn ?grad_spec t1
298-
299288
let lt ?(label = []) =
300289
let module NTDSL = Initial_NTDSL in
301290
let%cd op_asn ~v ~t1 ~t2 ~projections = v =: (v1 < v2) in
@@ -424,22 +413,19 @@ let embed_self_id ?(label = []) () : Tensor.t =
424413
Tensor.term ~fetch_op:Embed_self_id ~grad_spec:Prohibit_grad ~label:("!@self_id" :: label)
425414
~batch_dims:[] ~input_dims:[] ~output_dims:[ 1 ] ()
426415

427-
(* FIXME: this should work, but it's a shape mismatch error *)
428-
(* let uniform_plus_one ?grad_spec () ?label = add ?label (Tensor.number ~grad_spec:Prohibit_grad 1.)
429-
(uint4x32_to_prec_uniform ?grad_spec (threefry4x32 (embed_self_id ()) (Tensor.term
430-
~fetch_op:Range_over_offsets ~grad_spec:Prohibit_grad ~label:[ "range_over_offsets" ] ()) ())
431-
()) *)
432416
let uniform ?grad_spec () =
433417
uint4x32_to_prec_uniform ?grad_spec
434418
(threefry4x32 (embed_self_id ())
435419
(Tensor.term ~fetch_op:Range_over_offsets ~grad_spec:Prohibit_grad
436420
~label:[ "range_over_offsets" ] ())
437421
())
438422

439-
(** A wasteful variant of {!uniform} that produces a single value from each 4x32 random bits. *)
440-
let uniform1 ?grad_spec () =
441-
uint4x32_to_prec_uniform1 ?grad_spec
442-
(threefry4x32 (embed_self_id ())
423+
(** Generates a single uniform random number using a counter symbol for PRNG state. This is useful
424+
for sequential sampling in recurrent contexts. *)
425+
let uniform_at ?grad_spec counter =
426+
uint4x32_to_prec_uniform ?grad_spec
427+
(threefry4x32
428+
(threefry4x32 (embed_self_id ()) counter ())
443429
(Tensor.term ~fetch_op:Range_over_offsets ~grad_spec:Prohibit_grad
444430
~label:[ "range_over_offsets" ] ())
445431
())
@@ -487,7 +473,7 @@ module DO = struct
487473
let einsum1 ?label spec t1 = einsum1 ?label spec t1 ~grad_spec:If_needed ()
488474
let ndarray = Tensor.ndarray ~grad_spec:If_needed
489475
let uniform ?label () = uniform ~grad_spec:Require_grad () ?label ()
490-
let uniform1 ?label () = uniform1 ~grad_spec:Require_grad () ?label ()
476+
let uniform_at ?label counter = uniform_at ~grad_spec:Require_grad ?label counter
491477
end
492478

493479
module NDO = struct
@@ -523,6 +509,8 @@ module NDO = struct
523509
let einsum ?label spec t1 t2 = einsum spec t1 t2 ~grad_spec:Prohibit_grad ?label ()
524510
let einsum1 ?label spec t1 = einsum1 spec t1 ~grad_spec:Prohibit_grad ?label ()
525511
let ndarray = Tensor.ndarray ~grad_spec:Prohibit_grad
512+
let uniform ?label () = uniform ~grad_spec:Prohibit_grad () ?label ()
513+
let uniform_at ?label counter = uniform_at ~grad_spec:Prohibit_grad ?label counter
526514
end
527515

528516
(** The input [i] dimensions default to empty. The batch and output dimensions will be inferred if
@@ -623,6 +611,7 @@ module NTDSL = struct
623611
let uint4x32_to_prec_uniform = uint4x32_to_prec_uniform ~grad_spec:Prohibit_grad
624612
let embed_self_id = embed_self_id
625613
let uniform = uniform ~grad_spec:Prohibit_grad
614+
let uniform_at = uniform_at ~grad_spec:Prohibit_grad
626615

627616
let counter ?(label = []) =
628617
let module NTDSL = Initial_NTDSL in

0 commit comments

Comments
 (0)