
Commit 9682b4a

Untested: infrastructure for ternary operations
1 parent 842daaa commit 9682b4a

6 files changed, +230 -14 lines changed

lib/ppx_cd.ml

Lines changed: 161 additions & 1 deletion
@@ -33,7 +33,7 @@ type expr_type =
 
 let is_unknown = function Unknown -> true | _ -> false
 
-type projections_slot = LHS | RHS1 | RHS2 | Nonslot | Undet [@@deriving equal, sexp]
+type projections_slot = LHS | RHS1 | RHS2 | RHS3 | Nonslot | Undet [@@deriving equal, sexp]
 
 let assignment_op expr =
   (* This should stay in sync with Arrayjit.Ops.assign_op_cd_syntax. *)
@@ -72,6 +72,7 @@ let assignment_op expr =
 
 let binary_op expr =
   (* This and is_binary_op should stay in sync with Arrayjit.Ops.binop_cd_syntax. *)
+  (* FIXME: get rid of this and use binary_ops table instead. *)
   let loc = expr.pexp_loc in
   match expr with
   | [%expr ( + )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Add])
@@ -106,6 +107,18 @@ let binary_op expr =
           "+ (Add), - (Sub), * (Mul), / (Div), ** (ToPowOf), -?/ (Relu_gate), -/> (Arg2), < \
            (Cmplt), <> (Cmpne), || (Or), && (And), % (Mod), @^ (Max), ^^ (Min)" )
 
+let ternary_op expr =
+  (* FIXME: get rid of this and use ternary_ops table instead. *)
+  let loc = expr.pexp_loc in
+  match expr with
+  | [%expr where] -> ([%expr Shape.Pointwise_tern], [%expr Arrayjit.Ops.Where])
+  | [%expr fma] -> ([%expr Shape.Compose_accumulate], [%expr Arrayjit.Ops.FMA])
+  | _ ->
+      ( [%expr Shape.Pointwise_bin],
+        Ast_builder.Default.pexp_extension ~loc
+        @@ Location.error_extensionf ~loc "ppx_ocannl %%cd: expected a ternary operator, one of: %s"
+             "where, fma" )
+
 type result = {
   vbs : value_binding Map.M(String).t;
       (** [vbs] are the bindings introduced by inline tensor declarations (aka. punning). These
@@ -206,6 +219,7 @@ let project_p_slot debug loc slot =
   | LHS -> [%expr p.project_lhs]
   | RHS1 -> [%expr p.project_rhs.(0)]
   | RHS2 -> [%expr p.project_rhs.(1)]
+  | RHS3 -> [%expr p.project_rhs.(2)]
   | Nonslot ->
       Ast_builder.Default.pexp_extension ~loc
       @@ Location.error_extensionf ~loc
@@ -221,6 +235,7 @@ let project_p_dims debug loc slot =
   | LHS -> [%expr p.lhs_dims]
   | RHS1 -> [%expr p.rhs_dims.(0)]
   | RHS2 -> [%expr p.rhs_dims.(1)]
+  | RHS3 -> [%expr p.rhs_dims.(2)]
   | Nonslot ->
       Ast_builder.Default.pexp_extension ~loc
       @@ Location.error_extensionf ~loc
@@ -344,6 +359,7 @@ let setup_array ~punned ~bad_pun_hints ~is_lhs
     | LHS -> [%pat? nondiff__lhs]
     | RHS1 -> [%pat? nondiff__rhs1]
     | RHS2 -> [%pat? nondiff__rhs2]
+    | RHS3 -> [%pat? nondiff__rhs3]
     | Nonslot | Undet -> [%pat? nondiff__tensor]
   in
   let t = pat2expr v in
@@ -444,6 +460,74 @@ let translate (expr : expression) : result =
       { vbs = no_vbs; typ = Tensor; slot = Undet; expr; array_opt_of_code = None }
   in
   let loop = transl ~bad_pun_hints in
+  (* FIXME: collapse these (code reuse) *)
+  let process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ?projections ~proj_in_scope
+      () =
+    let initialize_neutral, accu_op = assignment_op accu_op in
+    let setup_l =
+      setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ~proj_in_scope:true lhs
+    in
+    let _, tern_op = ternary_op tern_op in
+    let setup_r1 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs1 in
+    let setup_r2 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs2 in
+    let setup_r3 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs3 in
+    let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
+    let projections =
+      match projections with
+      | Some prjs -> prjs
+      | None ->
+          let lhs_dims = project_p_dims "LHS" lhs.pexp_loc setup_l.slot in
+          let rhs1_dims = project_p_dims "RHS1" lhs.pexp_loc setup_r1.slot in
+          let rhs2_dims = project_p_dims "RHS2" lhs.pexp_loc setup_r2.slot in
+          let rhs3_dims = project_p_dims "RHS3" lhs.pexp_loc setup_r3.slot in
+          let project_lhs = project_p_slot "LHS" lhs.pexp_loc setup_l.slot in
+          let project_rhs1 = project_p_slot "RHS1" rhs1.pexp_loc setup_r1.slot in
+          let project_rhs2 = project_p_slot "RHS2" rhs2.pexp_loc setup_r2.slot in
+          let project_rhs3 = project_p_slot "RHS3" rhs3.pexp_loc setup_r3.slot in
+          [%expr
+            lazy
+              (let p = Lazy.force projections in
+               Arrayjit.Indexing.
+                 {
+                   product_space = p.product_space;
+                   product_iterators = p.product_iterators;
+                   lhs_dims = [%e lhs_dims];
+                   rhs_dims = [| [%e rhs1_dims]; [%e rhs2_dims]; [%e rhs3_dims] |];
+                   project_lhs = [%e project_lhs];
+                   project_rhs = [| [%e project_rhs1]; [%e project_rhs2]; [%e project_rhs3] |];
+                   debug_info =
+                     {
+                       p.debug_info with
+                       trace =
+                         ( "ppx_cd " ^ [%e expr2string_or_empty accu_op] ^ " "
+                           ^ [%e expr2string_or_empty tern_op],
+                           Arrayjit.Indexing.unique_debug_id () )
+                         :: p.debug_info.trace;
+                     };
+                 })]
+    in
+    (* TODO: might be better to treat missing [rhs1, rhs2, rhs3] as zeros or errors rather than
+       eliding the code. *)
+    let body =
+      [%expr
+        Option.value ~default:Arrayjit.Assignments.Noop
+        @@ Option.bind [%e setup_l.array_opt] ~f:(fun lhs ->
+               Option.map3 [%e setup_r1.array_opt] [%e setup_r2.array_opt] [%e setup_r3.array_opt]
+                 ~f:(fun rhs1 rhs2 rhs3 ->
+                   Arrayjit.Assignments.Accum_ternop
+                     {
+                       initialize_neutral = [%e initialize_neutral];
+                       accum = [%e accu_op];
+                       lhs;
+                       op = [%e tern_op];
+                       rhs1;
+                       rhs2;
+                       rhs3;
+                       projections = [%e projections];
+                     }))]
+    in
+    assignment ~punned ~lhs:setup_l ~rhses:[ setup_r1; setup_r2; setup_r3 ] body
+  in
   let process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ?projections ~proj_in_scope () =
     let initialize_neutral, accu_op = assignment_op accu_op in
     let setup_l =
@@ -561,6 +645,27 @@ let translate (expr : expression) : result =
     in
     assignment ~punned ~lhs:setup_l ~rhses:[ setup_r ] body
   in
+  let process_raw_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~logic =
+    let initialize_neutral, accu_op = assignment_op accu_op in
+    let setup_l = setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ~proj_in_scope lhs in
+    let setup_r1 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs1 in
+    let setup_r2 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs2 in
+    let setup_r3 = setup_array ~punned ~bad_pun_hints ~is_lhs:false @@ loop ~proj_in_scope rhs3 in
+    let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
+    let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in
+    let t1_expr, rhs1_is_grad, rhs1_is_merge = args_for ~loc setup_r1 in
+    let t2_expr, rhs2_is_grad, rhs2_is_merge = args_for ~loc setup_r2 in
+    let t3_expr, rhs3_is_grad, rhs3_is_merge = args_for ~loc setup_r3 in
+    let body =
+      [%expr
+        Tensor.raw_ternop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op]
+          ~t:[%e t_expr] ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e tern_op] ~t1:[%e t1_expr]
+          ~rhs1_is_grad:[%e rhs1_is_grad] ~rhs1_is_merge:[%e rhs1_is_merge] ~t2:[%e t2_expr]
+          ~rhs2_is_grad:[%e rhs2_is_grad] ~rhs2_is_merge:[%e rhs2_is_merge] ~t3:[%e t3_expr]
+          ~rhs3_is_grad:[%e rhs3_is_grad] ~rhs3_is_merge:[%e rhs3_is_merge] ~logic:[%e logic]]
+    in
+    assignment ~punned ~lhs:setup_l ~rhses:[ setup_r1; setup_r2; setup_r3 ] body
+  in
   let process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic =
     let initialize_neutral, accu_op = assignment_op accu_op in
     let setup_l = setup_array ~punned ~bad_pun_hints ~is_lhs:true @@ loop ~proj_in_scope lhs in
@@ -655,6 +760,19 @@ let translate (expr : expression) : result =
          slot = RHS2;
          expr = [%expr Option.map t2.Tensor.diff ~f:(fun d -> d.Tensor.grad)];
        }
+    | { pexp_desc = Pexp_ident { txt = Lident "rhs3"; _ }; _ } ->
+        { default_result with typ = Array; slot = RHS3 }
+    | { pexp_desc = Pexp_ident { txt = Lident "t3"; _ }; _ } ->
+        { default_result with typ = Tensor; slot = RHS3 }
+    | { pexp_desc = Pexp_ident { txt = Lident "v3"; _ }; _ } ->
+        { default_result with typ = Array; slot = RHS3; expr = [%expr t3.Tensor.value] }
+    | { pexp_desc = Pexp_ident { txt = Lident "g3"; _ }; _ } ->
+        {
+          default_result with
+          typ = Grad_of_tensor [%expr t3];
+          slot = RHS3;
+          expr = [%expr Option.map t3.Tensor.diff ~f:(fun d -> d.Tensor.grad)];
+        }
     | { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } when is_primitive_op op_ident ->
         default_result
     | [%expr [%e? expr1] **. [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
@@ -811,7 +929,15 @@ let translate (expr : expression) : result =
        [%e? accu_op]
          [%e? lhs]
          ([%e? bin_op] [%e? rhs1] ([%e? rhs2] ~projections:[%e? projections]))] ->
+        (* Note: when clause not needed here and below, it's an error if bin_op is not a primitive
+           binary op. *)
        process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~projections ~proj_in_scope:true ()
+    | [%expr
+        [%e? accu_op]
+          [%e? lhs]
+          ([%e? tern_op] ([%e? rhs1], [%e? rhs2], [%e? rhs3]) ~projections:[%e? projections])] ->
+        process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~projections
+          ~proj_in_scope:true ()
    | [%expr
        [%e? accu_op]
          [%e? lhs]
@@ -852,6 +978,25 @@ let translate (expr : expression) : result =
        in
        let _, bin_op = binary_op bin_op in
        process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic
+    | [%expr
+        [%e? accu_op]
+          [%e? lhs]
+          ([%e? tern_op]
+             ([%e? rhs1], [%e? rhs2], [%e? rhs3])
+             ~logic:[%e? { pexp_desc = Pexp_constant (Pconst_string (spec, s_loc, _)); _ }])] ->
+        let logic =
+          let loc = s_loc in
+          if String.equal spec "." then [%expr Shape.Pointwise_tern]
+          else if String.equal spec "@" then [%expr Shape.Compose_accumulate]
+          else
+            Ast_builder.Default.pexp_extension ~loc
+            @@ Location.error_extensionf ~loc
+                 "ppx_ocannl %%cd: expected <.> or <@>, found <%s> -- einsum notation for ternary \
+                  operators not supported yet, see issue #305"
+                 spec
+        in
+        let _, tern_op = ternary_op tern_op in
+        process_raw_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~logic
    | [%expr
        [%e? accu_op]
          [%e? lhs]
@@ -882,6 +1027,13 @@ let translate (expr : expression) : result =
             [%e? rhs2])]
      when is_assignment accu_ident && Hashtbl.mem binary_ops binop_ident && proj_in_scope ->
        process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~proj_in_scope ()
+    | [%expr
+        [%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
+          [%e? lhs]
+          ([%e? { pexp_desc = Pexp_ident { txt = Lident ternop_ident; _ }; _ } as tern_op]
+             ([%e? rhs1], [%e? rhs2], [%e? rhs3]))]
+      when is_assignment accu_ident && Hashtbl.mem ternary_ops ternop_ident && proj_in_scope ->
+        process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~proj_in_scope ()
    | [%expr
        [%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
          [%e? lhs]
@@ -905,6 +1057,14 @@ let translate (expr : expression) : result =
      when is_assignment accu_ident && Hashtbl.mem binary_ops binop_ident ->
        let logic, bin_op = binary_op bin_op in
        process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic
+    | [%expr
+        [%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
+          [%e? lhs]
+          ([%e? { pexp_desc = Pexp_ident { txt = Lident ternop_ident; _ }; _ } as tern_op]
+             ([%e? rhs1], [%e? rhs2], [%e? rhs3]))]
+      when is_assignment accu_ident && Hashtbl.mem ternary_ops ternop_ident ->
+        let logic, tern_op = ternary_op tern_op in
+        process_raw_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~logic
    | [%expr
        [%e? { pexp_desc = Pexp_ident { txt = Lident accu_ident; _ }; _ } as accu_op]
          [%e? lhs]
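
For orientation, the surface syntax these new cases target looks like the following %cd fragments. This is a hand-written sketch, not part of the commit: the tensor names and the =:/=+ assignment operators are illustrative, and the snippets are as untested as the commit itself.

    (* Plain form: the ternary operator is looked up in the ternary_ops table. *)
    t =: where (c, a, b);
    (* With a projections record already in scope (proj_in_scope). *)
    t =+ fma (a, b, acc) ~projections;
    (* With a logic spec string: "." selects Shape.Pointwise_tern, "@" selects
       Shape.Compose_accumulate; einsum specs are rejected for ternary operators (issue #305). *)
    t =: where (c, a, b) ~logic:"."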

lib/shape.ml

Lines changed: 43 additions & 13 deletions
@@ -85,6 +85,8 @@ type transpose_type =
   | Batch_slice of Arrayjit.Indexing.static_symbol
 [@@deriving equal, sexp]
 
+type ternary_type = Pointwise_tern | Compose_accumulate [@@deriving sexp, equal]
+
 let identifier_multichar = Angstrom.take_while1 Char.is_alphanum
 
 let opt_separators : _ Angstrom.t =
@@ -203,26 +205,19 @@ let einsum_of_spec spec =
   | Error msg ->
       raise @@ Utils.User_error ("Shape.einsum_of_spec: while parsing: " ^ spec ^ " error: " ^ msg)
 
-(** How to propagate shape updates and do the last update of [Tensor.t.shape] when finalizing the
-    tensor. Axes are broadcast-expanded on a bottom-up update to fit the incoming shape. *)
 type logic =
   | Broadcast of compose_type * t * t
-      (** Matches the shapes for a binary operation.
-
-          For [Broadcast (Einsum (ls1, ls2, ls3), s1, s2)], the labels of [s1] and [s2] must match
-          according to the [ls1], [ls2] lineup, and the resulting shape inherits the labels
-          according to the [ls3] lineup. *)
   | Transpose of transpose_type * t
-      (** Permutes the axes of a shape. One case of [Transpose] is to swap inputs with outputs of
-          [s1], hence the name. *)
+  | Broadcast_tern of ternary_type * t * t * t
   | Terminal of Arrayjit.Ops.init_op
-      (** Extracts any available shape information from the initialization. E.g. for
-          [File_mapped fn], opens the file [fn] to check its length. *)
 [@@deriving equal, sexp]
 
 let logic_to_spec = function
-  | Broadcast (Pointwise_bin, _, _) | Transpose (Pointwise_un, _) -> "."
-  | Broadcast (Compose, _, _) -> "@"
+  | Broadcast (Pointwise_bin, _, _)
+  | Transpose (Pointwise_un, _)
+  | Broadcast_tern (Pointwise_tern, _, _, _) ->
+      "."
+  | Broadcast (Compose, _, _) | Broadcast_tern (Compose_accumulate, _, _, _) -> "@"
   | Broadcast (Einsum spec, _, _) | Transpose (Permute spec, _) -> spec
   | Transpose (Transpose, _) -> "T"
   | Transpose (Batch_slice _, _) -> "@|"
@@ -430,6 +425,31 @@ let get_inequalities ({ shape = cur_sh; logic; id = _ } as _upd : update_step) :
          Row_ineq { cur = cur_sh.output; subr = sh1.output };
          Row_ineq { cur = cur_sh.output; subr = sh2.output };
        ] )
+  | Broadcast_tern (Compose_accumulate, sh1, sh2, sh3) ->
+      ( Row.dim_map_empty,
+        [
+          Row_ineq { cur = sh1.input; subr = sh2.output };
+          Row_ineq { cur = cur_sh.batch; subr = sh1.batch };
+          Row_ineq { cur = cur_sh.batch; subr = sh2.batch };
+          Row_ineq { cur = cur_sh.input; subr = sh2.input };
+          Row_ineq { cur = cur_sh.output; subr = sh1.output };
+          Row_ineq { cur = cur_sh.batch; subr = sh3.batch };
+          Row_ineq { cur = cur_sh.input; subr = sh3.input };
+          Row_ineq { cur = cur_sh.output; subr = sh3.output };
+        ] )
+  | Broadcast_tern (Pointwise_tern, sh1, sh2, sh3) ->
+      ( Row.dim_map_empty,
+        [
+          Row_ineq { cur = cur_sh.batch; subr = sh1.batch };
+          Row_ineq { cur = cur_sh.batch; subr = sh2.batch };
+          Row_ineq { cur = cur_sh.batch; subr = sh3.batch };
+          Row_ineq { cur = cur_sh.input; subr = sh1.input };
+          Row_ineq { cur = cur_sh.input; subr = sh2.input };
+          Row_ineq { cur = cur_sh.input; subr = sh3.input };
+          Row_ineq { cur = cur_sh.output; subr = sh1.output };
+          Row_ineq { cur = cur_sh.output; subr = sh2.output };
+          Row_ineq { cur = cur_sh.output; subr = sh3.output };
+        ] )
   | Transpose (Batch_slice { static_range; static_symbol }, sh) ->
      let slice_v = get_var () in
      let slice_var = Var slice_v in
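
Reading Row_ineq { cur; subr } as "subr must broadcast into cur", the Compose_accumulate constraints can be sanity-checked on made-up shapes (illustrative only, not from the commit):

    (* sh1 : input [j] -> output [i]   -- the composed operand t1
       sh2 : input [k] -> output [j]   -- t2; Row_ineq { cur = sh1.input; subr = sh2.output } ties the j axes
       sh3 : input [k] -> output [i]   -- t3; its batch/input/output broadcast into the result
       cur : input [k] -> output [i]   -- the result takes sh2's input and sh1's output,
                                          matching FMA's t1 * t2 + t3 *)

The Pointwise_tern case simply requires every row (batch, input, output) of all three operands to broadcast into the corresponding row of the result, as needed for Where.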
@@ -553,6 +573,10 @@ let iter_shapes update_step ~f =
   | Broadcast (_, sh1, sh2) ->
      f sh1;
      f sh2
+  | Broadcast_tern (_, sh1, sh2, sh3) ->
+      f sh1;
+      f sh2;
+      f sh3
 
 let all_rows update_step =
   let rows_sh sh = [ sh.batch; sh.input; sh.output ] in
@@ -562,6 +586,7 @@ let all_rows update_step =
   | Terminal _ -> []
   | Transpose (_, sh1) -> rows_sh sh1
   | Broadcast (_, sh1, sh2) -> rows_sh sh1 @ rows_sh sh2
+  | Broadcast_tern (_, sh1, sh2, sh3) -> rows_sh sh1 @ rows_sh sh2 @ rows_sh sh3
 
 let apply_env_t env sh =
   sh.batch <- Row.subst_row env sh.batch;
@@ -661,6 +686,10 @@ let fresh_proj_ids update =
   | Broadcast (_, sh1, sh2) ->
      fresh_shape sh1;
      fresh_shape sh2
+  | Broadcast_tern (_, sh1, sh2, sh3) ->
+      fresh_shape sh1;
+      fresh_shape sh2;
+      fresh_shape sh3
 
 (** Computes the indexing into subtensors given the shape information of a tensor.
     [derive_projections] should only be invoked when the shapes are fully inferred already! *)
@@ -692,6 +721,7 @@ let derive_projections (update_step : update_step) : Idx.projections =
     | Terminal _ -> []
     | Transpose (_, sh) -> [ sh ]
     | Broadcast (_, sh1, sh2) -> [ sh1; sh2 ]
+    | Broadcast_tern (_, sh1, sh2, sh3) -> [ sh1; sh2; sh3 ]
   in
   let lhs_dims = to_dims lhs in
   let rhs_dims = Array.of_list_map ~f:to_dims rhs in
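
Note that for Broadcast_tern the rhs list has three shapes, so the derived rhs_dims and project_rhs arrays have three entries — which is what p.rhs_dims.(2) and p.project_rhs.(2) in lib/ppx_cd.ml above index into.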

lib/shape.mli

Lines changed: 7 additions & 0 deletions
@@ -83,6 +83,12 @@ type transpose_type =
   | Batch_slice of Arrayjit.Indexing.static_symbol  (** Removes the leftmost batch axis. *)
 [@@deriving equal, sexp]
 
+(** If you miss expressivity here, leave a note on {{:https://github.com/ahrefs/ocannl/issues/305}issue 305}. *)
+type ternary_type =
+  | Pointwise_tern  (** As in the operation [Where]. *)
+  | Compose_accumulate  (** As in the operation [FMA]. *)
+[@@deriving equal, sexp]
+
 val make :
   ?batch_dims:int list ->
   ?input_dims:int list ->
@@ -123,6 +129,7 @@ type logic =
   | Transpose of transpose_type * t
       (** Permutes the axes of a shape. One case of [Transpose] is to swap inputs with outputs of
           [s1], hence the name. *)
+  | Broadcast_tern of ternary_type * t * t * t  (** Matches the shapes for a ternary operation. *)
   | Terminal of Arrayjit.Ops.init_op
       (** Extracts any available shape information from the initialization. E.g. for
           [File_mapped fn], opens the file [fn] to check its length. *)

lib/tensor.ml

Lines changed: 5 additions & 0 deletions
@@ -264,6 +264,11 @@ let binop ~label ?compose_op ~op_asn ~grad_asn ?grad_spec t1 t2 =
   let grad_asn ~v ~g ~projections = grad_asn ~v ~g ~t1 ~t2 ~projections in
   op ~label ?compose_op ?transpose_op:None ~op_asn ~grad_asn ?grad_spec (Shape.make ()) [ t1; t2 ]
 
+let ternop ~label ?compose_op ~op_asn ~grad_asn ?grad_spec t1 t2 t3 =
+  let op_asn ~v ~projections = op_asn ~v ~t1 ~t2 ~t3 ~projections in
+  let grad_asn ~v ~g ~projections = grad_asn ~v ~g ~t1 ~t2 ~t3 ~projections in
+  op ~label ?compose_op ?transpose_op:None ~op_asn ~grad_asn ?grad_spec (Shape.make ()) [ t1; t2; t3 ]
+
 let unop ~label ?transpose_op ~op_asn ~grad_asn ?grad_spec t1 =
   let op_asn ~v ~projections = op_asn ~v ~t1 ~projections in
   let grad_asn ~v ~g ~projections = grad_asn ~v ~g ~t1 ~projections in
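
The only change [ternop] makes relative to [binop] is threading a third operand through the callbacks. A minimal sketch of the callback shapes it expects follows; the names and bodies are placeholders, and the assumption that the callbacks return Arrayjit.Assignments.t mirrors the ppx-generated code above rather than anything stated in this commit:

    (* op_asn receives the result node [v], the three operands, and the projections;
       it would emit the forward assignment, e.g. an Accum_ternop as built by ppx_cd. *)
    let example_op_asn ~v ~t1 ~t2 ~t3 ~projections : Arrayjit.Assignments.t =
      ignore (v, t1, t2, t3, projections);
      Arrayjit.Assignments.Noop

    (* grad_asn additionally receives the gradient node [g] of the result. *)
    let example_grad_asn ~v ~g ~t1 ~t2 ~t3 ~projections : Arrayjit.Assignments.t =
      ignore (v, g, t1, t2, t3, projections);
      Arrayjit.Assignments.Noop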
