Skip to content

Commit 82ea915

Browse files
committed
A new convenience operation offsets and fix to ndarray (it's not composable so doesn't belong to O)
1 parent 9e75d1a commit 82ea915

File tree

5 files changed

+23
-14
lines changed

5 files changed

+23
-14
lines changed

lib/operation.ml

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,7 +404,7 @@ let einsum1 ?(label = []) ?(capture_dims = []) spec =
404404
~op_asn ~grad_asn ~label:("=>" :: label)
405405

406406
module NDO_before_einmax1 = struct
407-
let (+) ?label t1 t2 = add ?label ~grad_spec:Prohibit_grad t1 t2 ()
407+
let ( + ) ?label t1 t2 = add ?label ~grad_spec:Prohibit_grad t1 t2 ()
408408
let where ?label t1 t2 t3 = where ?label ~grad_spec:Prohibit_grad t1 t2 t3 ()
409409
let not ?label t = not ?label ~grad_spec:Prohibit_grad t ()
410410
let ( < ) ?label t1 t2 = lt ?label ~grad_spec:Prohibit_grad t1 t2 ()
@@ -437,6 +437,9 @@ let tropical ?(label = []) ?(capture_dims = []) spec =
437437
~compose_op:(Shape.Einsum (spec, capture_dims))
438438
~op_asn ~grad_asn ~label:("@^=>+" :: label)
439439

440+
(** A fully-shape-inferred tensor that is initialized with the offset of each cell. *)
441+
let offsets = Tensor.term ~fetch_op:Range_over_offsets ?init_data:None
442+
440443
(** [range] is a 1D tensor of shape [upto], spans [0] inclusive, [upto] exclusive. *)
441444
let range ?(label = []) ?(grad_spec = Tensor.Prohibit_grad) ?axis_label upto =
442445
let result =
@@ -599,6 +602,7 @@ struct
599602
let einsum1 = einsum1 ~grad_spec:Grad_spec.grad_spec
600603
let einmax1 = einmax1 ~grad_spec:Grad_spec.grad_spec
601604
let tropical = tropical ~grad_spec:Grad_spec.grad_spec
605+
let offsets = offsets ~grad_spec:Grad_spec.grad_spec
602606
let range = range ~grad_spec:Grad_spec.grad_spec
603607
let range_of_shape = range_of_shape ~grad_spec:Grad_spec.grad_spec
604608
let stop_gradient = stop_gradient
@@ -692,10 +696,11 @@ struct
692696
let ( <> ) ?label t1 t2 = ne ?label t1 t2 ()
693697
let embed_self_id = embed_self_id
694698
let einsum ?label ?capture_dims spec t1 t2 = einsum ?label ?capture_dims spec t1 t2 ()
699+
let outer_sum ?label ?capture_dims spec t1 t2 = outer_sum ?label ?capture_dims spec t1 t2 ()
695700
let einsum1 ?label ?capture_dims spec t1 = einsum1 ?label ?capture_dims spec t1 ()
696701
let einmax1 ?label ?capture_dims spec t1 = einmax1 ?label ?capture_dims spec t1 ()
697702
let tropical ?label ?capture_dims spec t1 t2 = tropical ?label ?capture_dims spec t1 t2 ()
698-
let ndarray = ndarray
703+
let offsets ?label () = offsets ?label ()
699704
let uniform ?label () = uniform () ?label ()
700705
let uniform_at ?label counter = uniform_at ?label counter ()
701706
let uniform1 ?label () = uniform1 () ?label ()

lib/ppx_cd.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -914,7 +914,7 @@ let translate ?ident_label (expr : expression) : result =
914914
})
915915
| { pexp_desc = Pexp_array _; _ }
916916
| { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } ->
917-
{ default_result with expr = ndarray_op expr }
917+
{ default_result with expr = ndarray_op ~ndarray_fn:[%expr NTDSL.ndarray] expr }
918918
| { pexp_desc = Pexp_ident { txt = Lident ("v" | "lhs"); _ }; _ } ->
919919
{ default_result with typ = Array; slot = LHS }
920920
| { pexp_desc = Pexp_ident { txt = Lident "g"; _ }; _ } ->

lib/ppx_op.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ let operators =
4545
("embed_self_id", "embed_self_id");
4646
("einsum", "einsum");
4747
("einsum1", "einsum1");
48-
("ndarray", "ndarray");
48+
("offsets", "offsets");
4949
("uniform", "uniform");
5050
("uniform_at", "uniform_at");
5151
("uniform1", "uniform1");
@@ -325,7 +325,7 @@ let rec translate ~num_configs ~is_toplevel ~opt_label ?label expr =
325325
"ppx_ocannl %%op: record field label must be a simple identifier" ))
326326
| { pexp_desc = Pexp_array _; _ }
327327
| { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } ->
328-
(no_vbs, ndarray_op ?label expr)
328+
(no_vbs, ndarray_op ?label ~ndarray_fn:[%expr TDSL.ndarray] expr)
329329
| [%expr !.[%e? expr1]] ->
330330
(* Hardcoding the patterns for (!.), (!..), and ( **. ) to avoid treating the constants as
331331
already tensors. *)

lib/ppx_shared.ml

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -256,13 +256,17 @@ let assignment_ops =
256256
let einsum_binary_ops =
257257
Hashtbl.of_alist_exn
258258
(module String)
259-
[ ("+*", fun loc -> [%expr einsum]); ("@^+", fun loc -> [%expr tropical]) ]
259+
[
260+
("+*", fun loc -> [%expr einsum]);
261+
("@^+", fun loc -> [%expr tropical]);
262+
("+++", fun loc -> [%expr outer_sum]);
263+
]
260264

261265
let einsum_unary_ops =
262-
Hashtbl.of_alist_exn
263-
(module String)
264-
[ ("++", fun loc -> [%expr einsum1]); ("@^^", fun loc -> [%expr einmax1]) ]
265-
266+
Hashtbl.of_alist_exn
267+
(module String)
268+
[ ("++", fun loc -> [%expr einsum1]); ("@^^", fun loc -> [%expr einmax1]) ]
269+
266270
let is_primitive_op op_ident =
267271
List.exists ~f:(Fn.flip Hashtbl.mem op_ident) [ ternary_ops; unary_ops; binary_ops ]
268272

@@ -313,11 +317,11 @@ let translate_str translate ({ pstr_desc; pstr_loc = loc; _ } as str) =
313317
let str_expander_with_punning translate ~loc ~path (payload : structure_item list) =
314318
flatten_str ~loc ~path @@ List.map payload ~f:(translate_str translate)
315319

316-
let ndarray_op ?axis_labels ?label expr =
320+
let ndarray_op ?axis_labels ?label ~ndarray_fn expr =
317321
let loc = expr.pexp_loc in
318322
let values, batch_dims, output_dims, input_dims = ndarray_constant expr in
319323
let edims dims = Ast_builder.Default.elist ~loc dims in
320-
let w_val = [%expr ndarray [%e values]] in
324+
let w_val = [%expr [%e ndarray_fn] [%e values]] in
321325
let op =
322326
match (axis_labels, label) with
323327
| None, None -> w_val

test/ppx/test_ppx_op_expected.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,13 +32,13 @@ let y2 =
3232
((( *. ) ?label:None) x1 hey3) x2
3333
let a =
3434
let open! TDSL.O in
35-
((ndarray
35+
((TDSL.ndarray
3636
[|(Float.of_int 1);(Float.of_int 2);(Float.of_int 3);(Float.of_int 4);(
3737
Float.of_int 5);(Float.of_int 6)|]) ~label:["a"]) ~batch_dims:[]
3838
~input_dims:[3] ~output_dims:[2] ()
3939
let b =
4040
let open! TDSL.O in
41-
((ndarray
41+
((TDSL.ndarray
4242
[|(Float.of_int 7);(Float.of_int 8);(Float.of_int 9);(Float.of_int 10)|])
4343
~label:["b"]) ~batch_dims:[2] ~input_dims:[] ~output_dims:[2] ()
4444
let y =

0 commit comments

Comments
 (0)