Skip to content

Commit 48aecf5

Browse files
committed
Untested: allow non-literal specification strings for einsum-like operators (%cd and %op)
1 parent 82ea915 commit 48aecf5

File tree

2 files changed

+52
-1
lines changed

2 files changed

+52
-1
lines changed

lib/ppx_cd.ml

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -990,6 +990,24 @@ let translate ?ident_label (expr : expression) : result =
990990
| [%expr [%e? expr1] **. [%e? expr2]] ->
991991
let res1 = loop ~proj_in_scope expr1 in
992992
{ res1 with typ = Tensor; expr = [%expr NTDSL.O.( **. ) [%e res1.expr] [%e expr2]] }
993+
| [%expr
994+
[%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ }]
995+
[%e? expr1]
996+
([%e? { pexp_desc = Pexp_ident _; _ } as spec] [%e? expr2])]
997+
when Hashtbl.mem einsum_binary_ops op_ident ->
998+
let res1 = loop ~proj_in_scope expr1 in
999+
let res2 = loop ~proj_in_scope expr2 in
1000+
let slot = List.hd_exn @@ List.sort [ res1.slot; res2.slot ] ~compare:compare_slots in
1001+
{
1002+
vbs = reduce_vbss [ res1.vbs; res2.vbs ];
1003+
typ = Tensor;
1004+
slot;
1005+
expr =
1006+
[%expr
1007+
[%e Hashtbl.find_exn einsum_binary_ops op_ident loc]
1008+
[%e spec] [%e res1.expr] [%e res2.expr]];
1009+
array_opt_of_code = None;
1010+
}
9931011
| [%expr
9941012
[%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ }]
9951013
[%e? expr1]
@@ -1031,7 +1049,19 @@ let translate ?ident_label (expr : expression) : result =
10311049
~capture_dims:[%e capture_dims_expr] [%e spec] [%e res1.expr] [%e res2.expr]];
10321050
array_opt_of_code = None;
10331051
}
1034-
| [%expr
1052+
| [%expr
1053+
[%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ }]
1054+
[%e? expr1]
1055+
[%e? { pexp_desc = Pexp_ident _; _ } as spec]]
1056+
when Hashtbl.mem einsum_unary_ops op_ident ->
1057+
let res1 = loop ~proj_in_scope expr1 in
1058+
{
1059+
res1 with
1060+
typ = Tensor;
1061+
expr =
1062+
[%expr [%e Hashtbl.find_exn einsum_unary_ops op_ident loc] [%e spec] [%e res1.expr]];
1063+
}
1064+
| [%expr
10351065
[%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ }]
10361066
[%e? expr1]
10371067
[%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ }]]

lib/ppx_op.ml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -183,6 +183,17 @@ let rec translate ~num_configs ~is_toplevel ~opt_label ?label expr =
183183
[%expr
184184
TDSL.number ?label:[%e opt_expr ~loc label] ~axis_label:[%e axis] (Float.of_int [%e i])]
185185
)
186+
| [%expr
187+
[%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ }]
188+
[%e? expr1]
189+
([%e? { pexp_desc = Pexp_ident _; _ } as spec] [%e? expr2])]
190+
when Hashtbl.mem einsum_binary_ops op_ident ->
191+
let vbs1, e1 = loop expr1 in
192+
let vbs2, e2 = loop expr2 in
193+
( reduce_vbss [ vbs1; vbs2 ],
194+
[%expr
195+
[%e Hashtbl.find_exn einsum_binary_ops op_ident loc]
196+
?label:[%e opt_expr ~loc label] [%e spec] [%e e1] [%e e2]] )
186197
| [%expr
187198
[%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ }]
188199
[%e? expr1]
@@ -195,6 +206,16 @@ let rec translate ~num_configs ~is_toplevel ~opt_label ?label expr =
195206
[%expr
196207
[%e Hashtbl.find_exn einsum_binary_ops op_ident loc]
197208
?label:[%e opt_expr ~loc label] [%e spec] [%e e1] [%e e2]] )
209+
| [%expr
210+
[%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ }]
211+
[%e? expr1]
212+
[%e? { pexp_desc = Pexp_ident _; _ } as spec]]
213+
when Hashtbl.mem einsum_unary_ops op_ident ->
214+
let vbs1, e1 = loop expr1 in
215+
( vbs1,
216+
[%expr
217+
[%e Hashtbl.find_exn einsum_unary_ops op_ident loc]
218+
?label:[%e opt_expr ~loc label] [%e spec] [%e e1]] )
198219
| [%expr
199220
[%e? { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ }]
200221
[%e? expr1]

0 commit comments

Comments
 (0)