@@ -404,7 +404,7 @@ let einsum1 ?(label = []) ?(capture_dims = []) spec =
404404 ~op_asn ~grad_asn ~label: (" =>" :: label)
405405
406406module NDO_before_einmax1 = struct
407- let (+ ) ?label t1 t2 = add ?label ~grad_spec: Prohibit_grad t1 t2 ()
407+ let ( + ) ?label t1 t2 = add ?label ~grad_spec: Prohibit_grad t1 t2 ()
408408 let where ?label t1 t2 t3 = where ?label ~grad_spec: Prohibit_grad t1 t2 t3 ()
409409 let not ?label t = not ?label ~grad_spec: Prohibit_grad t ()
410410 let ( < ) ?label t1 t2 = lt ?label ~grad_spec: Prohibit_grad t1 t2 ()
@@ -437,6 +437,9 @@ let tropical ?(label = []) ?(capture_dims = []) spec =
437437 ~compose_op: (Shape. Einsum (spec, capture_dims))
438438 ~op_asn ~grad_asn ~label: (" @^=>+" :: label)
439439
440+ (* * A fully-shape-inferred tensor that is initialized with the offset of each cell. *)
441+ let offsets = Tensor. term ~fetch_op: Range_over_offsets ?init_data:None
442+
440443(* * [range] is a 1D tensor of shape [upto], spans [0] inclusive, [upto] exclusive. *)
441444let range ?(label = [] ) ?(grad_spec = Tensor. Prohibit_grad ) ?axis_label upto =
442445 let result =
@@ -599,6 +602,7 @@ struct
599602 let einsum1 = einsum1 ~grad_spec: Grad_spec. grad_spec
600603 let einmax1 = einmax1 ~grad_spec: Grad_spec. grad_spec
601604 let tropical = tropical ~grad_spec: Grad_spec. grad_spec
605+ let offsets = offsets ~grad_spec: Grad_spec. grad_spec
602606 let range = range ~grad_spec: Grad_spec. grad_spec
603607 let range_of_shape = range_of_shape ~grad_spec: Grad_spec. grad_spec
604608 let stop_gradient = stop_gradient
@@ -692,10 +696,11 @@ struct
692696 let ( <> ) ?label t1 t2 = ne ?label t1 t2 ()
693697 let embed_self_id = embed_self_id
694698 let einsum ?label ?capture_dims spec t1 t2 = einsum ?label ?capture_dims spec t1 t2 ()
699+ let outer_sum ?label ?capture_dims spec t1 t2 = outer_sum ?label ?capture_dims spec t1 t2 ()
695700 let einsum1 ?label ?capture_dims spec t1 = einsum1 ?label ?capture_dims spec t1 ()
696701 let einmax1 ?label ?capture_dims spec t1 = einmax1 ?label ?capture_dims spec t1 ()
697702 let tropical ?label ?capture_dims spec t1 t2 = tropical ?label ?capture_dims spec t1 t2 ()
698- let ndarray = ndarray
703+ let offsets ? label () = offsets ?label ()
699704 let uniform ?label () = uniform () ?label ()
700705 let uniform_at ?label counter = uniform_at ?label counter ()
701706 let uniform1 ?label () = uniform1 () ?label ()
0 commit comments