Skip to content

Commit 930bd0c

Browse files
committed
In progress: ppx_op pass the label for the primary tensor directly as a string list
In preparation for including a parameter tensor's label in the primary tensor label.
1 parent d1a2868 commit 930bd0c

File tree

3 files changed

+255
-205
lines changed

3 files changed

+255
-205
lines changed

lib/ppx_op.ml

Lines changed: 36 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ open Ppxlib
33
open Ppx_arrayjit.Ppx_helper
44
open Ppx_shared
55

6-
let ndarray_op ?ident_label ?axis_labels expr =
6+
let ndarray_op ?label ?axis_labels expr =
77
let loc = expr.pexp_loc in
88
let values, batch_dims, output_dims, input_dims = ndarray_constant expr in
99
let edims dims = Ast_builder.Default.elist ~loc dims in
@@ -13,10 +13,8 @@ let ndarray_op ?ident_label ?axis_labels expr =
1313
| Some axis_labels -> [%expr TDSL.ndarray ~axis_labels:[%e axis_labels]]
1414
in
1515
[%expr
16-
[%e op]
17-
~label:[%e opt_pat2string_list ~loc ident_label]
18-
~batch_dims:[%e edims batch_dims] ~input_dims:[%e edims input_dims]
19-
~output_dims:[%e edims output_dims] [%e values]]
16+
[%e op] ?label:[%e opt_expr ~loc label] ~batch_dims:[%e edims batch_dims]
17+
~input_dims:[%e edims input_dims] ~output_dims:[%e edims output_dims] [%e values]]
2018

2119
let make_p ~has_config ~loc =
2220
if has_config then [%expr TDSL.param ~more_label:config.label] else [%expr TDSL.param]
@@ -60,12 +58,12 @@ let make_vb_nd ~has_config ~loc ~str_loc ?axis_labels ~ident ~init_nd string =
6058
let vb = Ast_helper.Vb.mk ~loc pat v in
6159
(pat, vb)
6260

63-
let rec translate ~has_config ?ident_label expr =
61+
let rec translate ~has_config ?label expr =
6462
let loc = expr.pexp_loc in
6563
let loop = translate ~has_config in
6664
match expr with
6765
| { pexp_desc = Pexp_constant (Pconst_float _); _ } ->
68-
(no_vbs, [%expr TDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] [%e expr]])
66+
(no_vbs, [%expr TDSL.number ?label:[%e opt_expr ~loc label] [%e expr]])
6967
| { pexp_desc = Pexp_constant (Pconst_integer _); _ } ->
7068
(no_vbs, [%expr TDSL.number (Float.of_int [%e expr])])
7169
| [%expr
@@ -74,10 +72,7 @@ let rec translate ~has_config ?ident_label expr =
7472
let axis =
7573
Ast_helper.Exp.constant ~loc:pexp_loc (Pconst_string (String.of_char ch, pexp_loc, None))
7674
in
77-
( no_vbs,
78-
[%expr
79-
TDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] ~axis_label:[%e axis] [%e f]]
80-
)
75+
(no_vbs, [%expr TDSL.number ?label:[%e opt_expr ~loc label] ~axis_label:[%e axis] [%e f]])
8176
| [%expr
8277
[%e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
8378
[%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
@@ -86,24 +81,21 @@ let rec translate ~has_config ?ident_label expr =
8681
in
8782
( no_vbs,
8883
[%expr
89-
TDSL.number
90-
~label:[%e opt_pat2string_list ~loc ident_label]
91-
~axis_label:[%e axis]
92-
(Float.of_int [%e i])] )
84+
TDSL.number ?label:[%e opt_expr ~loc label] ~axis_label:[%e axis] (Float.of_int [%e i])]
85+
)
9386
| [%expr
9487
[%e? expr1]
9588
*+ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec] [%e? expr2]]
9689
when String.contains spec_str '>' ->
9790
let vbs1, e1 = loop expr1 in
9891
let vbs2, e2 = loop expr2 in
9992
( reduce_vbss [ vbs1; vbs2 ],
100-
[%expr
101-
TDSL.einsum ~label:[%e opt_pat2string_list ~loc ident_label] [%e spec] [%e e1] [%e e2]] )
93+
[%expr TDSL.einsum ?label:[%e opt_expr ~loc label] [%e spec] [%e e1] [%e e2]] )
10294
| [%expr
10395
[%e? expr1] ++ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]]
10496
when String.contains spec_str '>' ->
10597
let vbs1, e1 = loop expr1 in
106-
(vbs1, [%expr TDSL.einsum1 ~label:[%e opt_pat2string_list ~loc ident_label] [%e spec] [%e e1]])
98+
(vbs1, [%expr TDSL.einsum1 ?label:[%e opt_expr ~loc label] [%e spec] [%e e1]])
10799
| [%expr
108100
[%e? { pexp_desc = Pexp_constant (Pconst_string (ident, str_loc, _)); _ } as s]
109101
[%e?
@@ -129,45 +121,38 @@ let rec translate ~has_config ?ident_label expr =
129121
(Map.singleton (module String) ident vb, pat2expr pat)
130122
| { pexp_desc = Pexp_array _; _ }
131123
| { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } ->
132-
(no_vbs, ndarray_op ?ident_label expr)
124+
(no_vbs, ndarray_op ?label expr)
133125
| [%expr [%e? expr1] **. [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
134126
(* We need to hardcode these two patterns to prevent the numbers from being converted to
135127
tensors. *)
136128
let vbs, e1 = loop expr1 in
137-
( vbs,
138-
[%expr
139-
TDSL.O.( **. )
140-
~label:[%e opt_pat2string_list ~loc ident_label]
141-
[%e e1]
142-
(Float.of_int [%e i])] )
129+
(vbs, [%expr TDSL.O.( **. ) ?label:[%e opt_expr ~loc label] [%e e1] (Float.of_int [%e i])])
143130
| [%expr [%e? expr1] **. [%e? expr2]] ->
144131
let vbs, e1 = loop expr1 in
145-
( vbs,
146-
[%expr TDSL.O.( **. ) ~label:[%e opt_pat2string_list ~loc ident_label] [%e e1] [%e expr2]]
147-
)
132+
(vbs, [%expr TDSL.O.( **. ) ?label:[%e opt_expr ~loc label] [%e e1] [%e expr2]])
148133
| [%expr [%e? expr1] [%e? expr2] [%e? expr3]] ->
149-
let vbs1, e1 = loop ?ident_label expr1 in
134+
let vbs1, e1 = loop ?label expr1 in
150135
let vbs2, e2 = loop expr2 in
151136
let vbs3, e3 = loop expr3 in
152137
(reduce_vbss [ vbs1; vbs2; vbs3 ], [%expr [%e e1] [%e e2] [%e e3]])
153138
| [%expr [%e? expr1] [%e? expr2]] ->
154-
let vbs1, e1 = loop ?ident_label expr1 in
139+
let vbs1, e1 = loop ?label expr1 in
155140
let vbs2, e2 = loop expr2 in
156141
(Map.merge_skewed vbs1 vbs2 ~combine:(fun ~key:_ _v1 v2 -> v2), [%expr [%e e1] [%e e2]])
157142
| [%expr fun ~config -> [%e? body]] ->
158-
let vbs, body = translate ~has_config:true ?ident_label body in
143+
let vbs, body = translate ~has_config:true ?label body in
159144
(no_vbs, [%expr fun ~config -> [%e let_opt ~loc vbs body]])
160145
| [%expr fun ~(config : [%typ? config_ty]) -> [%e? body]] ->
161-
let vbs, body = translate ~has_config:true ?ident_label body in
146+
let vbs, body = translate ~has_config:true ?label body in
162147
(no_vbs, [%expr fun ~(config : [%typ ty]) -> [%e let_opt ~loc vbs body]])
163148
| [%expr fun [%p? pat] -> [%e? body]] ->
164-
let vbs, body = loop ?ident_label body in
149+
let vbs, body = loop ?label body in
165150
(vbs, [%expr fun [%p pat] -> [%e body]])
166151
| [%expr
167152
while [%e? test_expr] do
168153
[%e? body_expr]
169154
done] ->
170-
let vbs, body = loop ?ident_label body_expr in
155+
let vbs, body = loop ?label body_expr in
171156
( vbs,
172157
[%expr
173158
while [%e test_expr] do
@@ -177,7 +162,7 @@ let rec translate ~has_config ?ident_label expr =
177162
for [%p? pat] = [%e? init] to [%e? final] do
178163
[%e? body_expr]
179164
done] ->
180-
let vbs, body = loop ?ident_label body_expr in
165+
let vbs, body = loop ?label body_expr in
181166
( vbs,
182167
[%expr
183168
for [%p pat] = [%e init] to [%e final] do
@@ -187,7 +172,7 @@ let rec translate ~has_config ?ident_label expr =
187172
for [%p? pat] = [%e? init] downto [%e? final] do
188173
[%e? body_expr]
189174
done] ->
190-
let vbs, body = loop ?ident_label body_expr in
175+
let vbs, body = loop ?label body_expr in
191176
( vbs,
192177
[%expr
193178
for [%p pat] = [%e init] downto [%e final] do
@@ -197,48 +182,52 @@ let rec translate ~has_config ?ident_label expr =
197182
[%e? expr1];
198183
[%e? expr2]] ->
199184
let vbs1, e1 = loop expr1 in
200-
let vbs2, e2 = loop ?ident_label expr2 in
185+
let vbs2, e2 = loop ?label expr2 in
201186
( reduce_vbss [ vbs1; vbs2 ],
202187
[%expr
203188
[%e e1];
204189
[%e e2]] )
205190
| [%expr if [%e? expr1] then [%e? expr2] else [%e? expr3]] ->
206-
let vbs2, e2 = loop ?ident_label expr2 in
207-
let vbs3, e3 = loop ?ident_label expr3 in
191+
let vbs2, e2 = loop ?label expr2 in
192+
let vbs3, e3 = loop ?label expr3 in
208193
(reduce_vbss [ vbs2; vbs3 ], [%expr if [%e expr1] then [%e e2] else [%e e3]])
209194
| [%expr if [%e? expr1] then [%e? expr2]] ->
210-
let vbs2, e2 = loop ?ident_label expr2 in
195+
let vbs2, e2 = loop ?label expr2 in
211196
(vbs2, [%expr if [%e expr1] then [%e e2]])
212197
| { pexp_desc = Pexp_match (expr1, cases); _ } ->
213198
let vbss, cases =
214199
List.unzip
215200
@@ List.map cases ~f:(fun ({ pc_rhs; _ } as c) ->
216-
let vbs, pc_rhs = loop ?ident_label pc_rhs in
201+
let vbs, pc_rhs = loop ?label pc_rhs in
217202
(vbs, { c with pc_rhs }))
218203
in
219204
(reduce_vbss vbss, { expr with pexp_desc = Pexp_match (expr1, cases) })
220205
| { pexp_desc = Pexp_let (recflag, bindings, body); _ } ->
221206
let vbss1, bindings =
222207
List.unzip
223208
@@ List.map bindings ~f:(fun binding ->
224-
let vbs, pvb_expr = loop ~ident_label:binding.pvb_pat binding.pvb_expr in
209+
let vbs, pvb_expr =
210+
loop ~label:[%expr [ [%e pat2string binding.pvb_pat] ]] binding.pvb_expr
211+
in
225212
(vbs, { binding with pvb_expr }))
226213
in
227-
let vbs2, body = loop ?ident_label body in
214+
let vbs2, body = loop ?label body in
228215
let all_bindings = (Map.data @@ reduce_vbss vbss1) @ bindings @ Map.data vbs2 in
229216
(no_vbs, { expr with pexp_desc = Pexp_let (recflag, all_bindings, body) })
230217
| { pexp_desc = Pexp_open (decl, body); _ } ->
231-
let vbs, body = loop ?ident_label body in
218+
let vbs, body = loop ?label body in
232219
(vbs, { expr with pexp_desc = Pexp_open (decl, body) })
233220
| { pexp_desc = Pexp_letmodule (name, module_expr, body); _ } ->
234-
let vbs, body = loop ?ident_label body in
221+
let vbs, body = loop ?label body in
235222
(vbs, { expr with pexp_desc = Pexp_letmodule (name, module_expr, body) })
236223
| { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } when is_operator op_ident ->
237-
(no_vbs, [%expr [%e expr] ~label:[%e opt_pat2string_list ~loc ident_label]])
224+
(no_vbs, [%expr [%e expr] ?label:[%e opt_expr ~loc label]])
238225
| expr -> (no_vbs, expr)
239226

240227
let translate ?ident_label expr =
241-
let vbs, expr = translate ~has_config:false ?ident_label expr in
228+
let vbs, expr =
229+
translate ~has_config:false ~label:(opt_pat2string_list ~loc:expr.pexp_loc ident_label) expr
230+
in
242231
let loc = expr.pexp_loc in
243232
( vbs,
244233
match ident_label with

0 commit comments

Comments
 (0)