@@ -990,6 +990,24 @@ let translate ?ident_label (expr : expression) : result =
990990 | [% expr [% e? expr1] **. [% e? expr2]] ->
991991 let res1 = loop ~proj_in_scope expr1 in
992992 { res1 with typ = Tensor ; expr = [% expr NTDSL.O. ( **. ) [% e res1.expr] [% e expr2]] }
993+ | [% expr
994+ [% e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ }]
995+ [% e? expr1]
996+ ([% e? { pexp_desc = Pexp_ident _; _ } as spec] [% e? expr2])]
997+ when Hashtbl. mem einsum_binary_ops op_ident ->
998+ let res1 = loop ~proj_in_scope expr1 in
999+ let res2 = loop ~proj_in_scope expr2 in
1000+ let slot = List. hd_exn @@ List. sort [ res1.slot; res2.slot ] ~compare: compare_slots in
1001+ {
1002+ vbs = reduce_vbss [ res1.vbs; res2.vbs ];
1003+ typ = Tensor ;
1004+ slot;
1005+ expr =
1006+ [% expr
1007+ [% e Hashtbl. find_exn einsum_binary_ops op_ident loc]
1008+ [% e spec] [% e res1.expr] [% e res2.expr]];
1009+ array_opt_of_code = None ;
1010+ }
9931011 | [% expr
9941012 [% e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ }]
9951013 [% e? expr1]
@@ -1031,7 +1049,19 @@ let translate ?ident_label (expr : expression) : result =
10311049 ~capture_dims: [% e capture_dims_expr] [% e spec] [% e res1.expr] [% e res2.expr]];
10321050 array_opt_of_code = None ;
10331051 }
1034- | [% expr
1052+ | [% expr
1053+ [% e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ }]
1054+ [% e? expr1]
1055+ [% e? { pexp_desc = Pexp_ident _; _ } as spec]]
1056+ when Hashtbl. mem einsum_unary_ops op_ident ->
1057+ let res1 = loop ~proj_in_scope expr1 in
1058+ {
1059+ res1 with
1060+ typ = Tensor ;
1061+ expr =
1062+ [% expr [% e Hashtbl. find_exn einsum_unary_ops op_ident loc] [% e spec] [% e res1.expr]];
1063+ }
1064+ | [% expr
10351065 [% e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ }]
10361066 [% e? expr1]
10371067 [% e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ }]]
0 commit comments