@@ -3,7 +3,7 @@ open Ppxlib
33open Ppx_arrayjit.Ppx_helper
44open Ppx_shared
55
6- let ndarray_op ?ident_label ?axis_labels expr =
6+ let ndarray_op ?label ?axis_labels expr =
77 let loc = expr.pexp_loc in
88 let values, batch_dims, output_dims, input_dims = ndarray_constant expr in
99 let edims dims = Ast_builder.Default. elist ~loc dims in
@@ -13,10 +13,8 @@ let ndarray_op ?ident_label ?axis_labels expr =
1313 | Some axis_labels -> [% expr TDSL. ndarray ~axis_labels: [% e axis_labels]]
1414 in
1515 [% expr
16- [% e op]
17- ~label: [% e opt_pat2string_list ~loc ident_label]
18- ~batch_dims: [% e edims batch_dims] ~input_dims: [% e edims input_dims]
19- ~output_dims: [% e edims output_dims] [% e values]]
16+ [% e op] ?label:[% e opt_expr ~loc label] ~batch_dims: [% e edims batch_dims]
17+ ~input_dims: [% e edims input_dims] ~output_dims: [% e edims output_dims] [% e values]]
2018
2119let make_p ~has_config ~loc =
2220 if has_config then [% expr TDSL. param ~more_label: config.label] else [% expr TDSL. param]
@@ -60,12 +58,12 @@ let make_vb_nd ~has_config ~loc ~str_loc ?axis_labels ~ident ~init_nd string =
6058 let vb = Ast_helper.Vb. mk ~loc pat v in
6159 (pat, vb)
6260
63- let rec translate ~has_config ?ident_label expr =
61+ let rec translate ~has_config ?label expr =
6462 let loc = expr.pexp_loc in
6563 let loop = translate ~has_config in
6664 match expr with
6765 | { pexp_desc = Pexp_constant (Pconst_float _ ); _ } ->
68- (no_vbs, [% expr TDSL. number ~ label: [% e opt_pat2string_list ~loc ident_label ] [% e expr]])
66+ (no_vbs, [% expr TDSL. number ? label:[% e opt_expr ~loc label ] [% e expr]])
6967 | { pexp_desc = Pexp_constant (Pconst_integer _ ); _ } ->
7068 (no_vbs, [% expr TDSL. number (Float. of_int [% e expr])])
7169 | [% expr
@@ -74,10 +72,7 @@ let rec translate ~has_config ?ident_label expr =
7472 let axis =
7573 Ast_helper.Exp. constant ~loc: pexp_loc (Pconst_string (String. of_char ch, pexp_loc, None ))
7674 in
77- ( no_vbs,
78- [% expr
79- TDSL. number ~label: [% e opt_pat2string_list ~loc ident_label] ~axis_label: [% e axis] [% e f]]
80- )
75+ (no_vbs, [% expr TDSL. number ?label:[% e opt_expr ~loc label] ~axis_label: [% e axis] [% e f]])
8176 | [% expr
8277 [% e? { pexp_desc = Pexp_constant (Pconst_char ch); pexp_loc; _ }]
8378 [% e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
@@ -86,24 +81,21 @@ let rec translate ~has_config ?ident_label expr =
8681 in
8782 ( no_vbs,
8883 [% expr
89- TDSL. number
90- ~label: [% e opt_pat2string_list ~loc ident_label]
91- ~axis_label: [% e axis]
92- (Float. of_int [% e i])] )
84+ TDSL. number ?label:[% e opt_expr ~loc label] ~axis_label: [% e axis] (Float. of_int [% e i])]
85+ )
9386 | [% expr
9487 [% e? expr1]
9588 *+ [% e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec] [% e? expr2]]
9689 when String. contains spec_str '>' ->
9790 let vbs1, e1 = loop expr1 in
9891 let vbs2, e2 = loop expr2 in
9992 ( reduce_vbss [ vbs1; vbs2 ],
100- [% expr
101- TDSL. einsum ~label: [% e opt_pat2string_list ~loc ident_label] [% e spec] [% e e1] [% e e2]] )
93+ [% expr TDSL. einsum ?label:[% e opt_expr ~loc label] [% e spec] [% e e1] [% e e2]] )
10294 | [% expr
10395 [% e? expr1] ++ [% e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]]
10496 when String. contains spec_str '>' ->
10597 let vbs1, e1 = loop expr1 in
106- (vbs1, [% expr TDSL. einsum1 ~ label: [% e opt_pat2string_list ~loc ident_label ] [% e spec] [% e e1]])
98+ (vbs1, [% expr TDSL. einsum1 ? label:[% e opt_expr ~loc label ] [% e spec] [% e e1]])
10799 | [% expr
108100 [% e? { pexp_desc = Pexp_constant (Pconst_string (ident, str_loc, _)); _ } as s]
109101 [% e?
@@ -129,45 +121,38 @@ let rec translate ~has_config ?ident_label expr =
129121 (Map. singleton (module String ) ident vb, pat2expr pat)
130122 | { pexp_desc = Pexp_array _; _ }
131123 | { pexp_desc = Pexp_construct ({ txt = Lident "::" ; _ } , _ ); _ } ->
132- (no_vbs, ndarray_op ?ident_label expr)
124+ (no_vbs, ndarray_op ?label expr)
133125 | [% expr [% e? expr1] **. [% e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
134126 (* We need to hardcode these two patterns to prevent the numbers from being converted to
135127 tensors. *)
136128 let vbs, e1 = loop expr1 in
137- ( vbs,
138- [% expr
139- TDSL.O. ( **. )
140- ~label: [% e opt_pat2string_list ~loc ident_label]
141- [% e e1]
142- (Float. of_int [% e i])] )
129+ (vbs, [% expr TDSL.O. ( **. ) ?label:[% e opt_expr ~loc label] [% e e1] (Float. of_int [% e i])])
143130 | [% expr [% e? expr1] **. [% e? expr2]] ->
144131 let vbs, e1 = loop expr1 in
145- ( vbs,
146- [% expr TDSL.O. ( **. ) ~label: [% e opt_pat2string_list ~loc ident_label] [% e e1] [% e expr2]]
147- )
132+ (vbs, [% expr TDSL.O. ( **. ) ?label:[% e opt_expr ~loc label] [% e e1] [% e expr2]])
148133 | [% expr [% e? expr1] [% e? expr2] [% e? expr3]] ->
149- let vbs1, e1 = loop ?ident_label expr1 in
134+ let vbs1, e1 = loop ?label expr1 in
150135 let vbs2, e2 = loop expr2 in
151136 let vbs3, e3 = loop expr3 in
152137 (reduce_vbss [ vbs1; vbs2; vbs3 ], [% expr [% e e1] [% e e2] [% e e3]])
153138 | [% expr [% e? expr1] [% e? expr2]] ->
154- let vbs1, e1 = loop ?ident_label expr1 in
139+ let vbs1, e1 = loop ?label expr1 in
155140 let vbs2, e2 = loop expr2 in
156141 (Map. merge_skewed vbs1 vbs2 ~combine: (fun ~key :_ _v1 v2 -> v2), [% expr [% e e1] [% e e2]])
157142 | [% expr fun ~config -> [% e? body]] ->
158- let vbs, body = translate ~has_config: true ?ident_label body in
143+ let vbs, body = translate ~has_config: true ?label body in
159144 (no_vbs, [% expr fun ~config -> [% e let_opt ~loc vbs body]])
160145 | [% expr fun ~(config : [%typ? config_ty] ) -> [% e? body]] ->
161- let vbs, body = translate ~has_config: true ?ident_label body in
146+ let vbs, body = translate ~has_config: true ?label body in
162147 (no_vbs, [% expr fun ~(config : [%typ ty] ) -> [% e let_opt ~loc vbs body]])
163148 | [% expr fun [%p ? pat ] -> [% e? body]] ->
164- let vbs, body = loop ?ident_label body in
149+ let vbs, body = loop ?label body in
165150 (vbs, [% expr fun [%p pat ] -> [% e body]])
166151 | [% expr
167152 while [% e? test_expr] do
168153 [% e? body_expr]
169154 done ] ->
170- let vbs, body = loop ?ident_label body_expr in
155+ let vbs, body = loop ?label body_expr in
171156 ( vbs,
172157 [% expr
173158 while [% e test_expr] do
@@ -177,7 +162,7 @@ let rec translate ~has_config ?ident_label expr =
177162 for [% p? pat] = [% e? init] to [% e? final] do
178163 [% e? body_expr]
179164 done ] ->
180- let vbs, body = loop ?ident_label body_expr in
165+ let vbs, body = loop ?label body_expr in
181166 ( vbs,
182167 [% expr
183168 for [% p pat] = [% e init] to [% e final] do
@@ -187,7 +172,7 @@ let rec translate ~has_config ?ident_label expr =
187172 for [% p? pat] = [% e? init] downto [% e? final] do
188173 [% e? body_expr]
189174 done ] ->
190- let vbs, body = loop ?ident_label body_expr in
175+ let vbs, body = loop ?label body_expr in
191176 ( vbs,
192177 [% expr
193178 for [% p pat] = [% e init] downto [% e final] do
@@ -197,48 +182,52 @@ let rec translate ~has_config ?ident_label expr =
197182 [% e? expr1];
198183 [% e? expr2]] ->
199184 let vbs1, e1 = loop expr1 in
200- let vbs2, e2 = loop ?ident_label expr2 in
185+ let vbs2, e2 = loop ?label expr2 in
201186 ( reduce_vbss [ vbs1; vbs2 ],
202187 [% expr
203188 [% e e1];
204189 [% e e2]] )
205190 | [% expr if [% e? expr1] then [% e? expr2] else [% e? expr3]] ->
206- let vbs2, e2 = loop ?ident_label expr2 in
207- let vbs3, e3 = loop ?ident_label expr3 in
191+ let vbs2, e2 = loop ?label expr2 in
192+ let vbs3, e3 = loop ?label expr3 in
208193 (reduce_vbss [ vbs2; vbs3 ], [% expr if [% e expr1] then [% e e2] else [% e e3]])
209194 | [% expr if [% e? expr1] then [% e? expr2]] ->
210- let vbs2, e2 = loop ?ident_label expr2 in
195+ let vbs2, e2 = loop ?label expr2 in
211196 (vbs2, [% expr if [% e expr1] then [% e e2]])
212197 | { pexp_desc = Pexp_match (expr1 , cases ); _ } ->
213198 let vbss, cases =
214199 List. unzip
215200 @@ List. map cases ~f: (fun ({ pc_rhs; _ } as c ) ->
216- let vbs, pc_rhs = loop ?ident_label pc_rhs in
201+ let vbs, pc_rhs = loop ?label pc_rhs in
217202 (vbs, { c with pc_rhs }))
218203 in
219204 (reduce_vbss vbss, { expr with pexp_desc = Pexp_match (expr1, cases) })
220205 | { pexp_desc = Pexp_let (recflag , bindings , body ); _ } ->
221206 let vbss1, bindings =
222207 List. unzip
223208 @@ List. map bindings ~f: (fun binding ->
224- let vbs, pvb_expr = loop ~ident_label: binding.pvb_pat binding.pvb_expr in
209+ let vbs, pvb_expr =
210+ loop ~label: [% expr [ [% e pat2string binding.pvb_pat] ]] binding.pvb_expr
211+ in
225212 (vbs, { binding with pvb_expr }))
226213 in
227- let vbs2, body = loop ?ident_label body in
214+ let vbs2, body = loop ?label body in
228215 let all_bindings = (Map. data @@ reduce_vbss vbss1) @ bindings @ Map. data vbs2 in
229216 (no_vbs, { expr with pexp_desc = Pexp_let (recflag, all_bindings, body) })
230217 | { pexp_desc = Pexp_open (decl , body ); _ } ->
231- let vbs, body = loop ?ident_label body in
218+ let vbs, body = loop ?label body in
232219 (vbs, { expr with pexp_desc = Pexp_open (decl, body) })
233220 | { pexp_desc = Pexp_letmodule (name , module_expr , body ); _ } ->
234- let vbs, body = loop ?ident_label body in
221+ let vbs, body = loop ?label body in
235222 (vbs, { expr with pexp_desc = Pexp_letmodule (name, module_expr, body) })
236223 | { pexp_desc = Pexp_ident { txt = Lident op_ident ; _ } ; _ } when is_operator op_ident ->
237- (no_vbs, [% expr [% e expr] ~ label: [% e opt_pat2string_list ~loc ident_label ]])
224+ (no_vbs, [% expr [% e expr] ? label:[% e opt_expr ~loc label ]])
238225 | expr -> (no_vbs, expr)
239226
240227let translate ?ident_label expr =
241- let vbs, expr = translate ~has_config: false ?ident_label expr in
228+ let vbs, expr =
229+ translate ~has_config: false ~label: (opt_pat2string_list ~loc: expr.pexp_loc ident_label) expr
230+ in
242231 let loc = expr.pexp_loc in
243232 ( vbs,
244233 match ident_label with
0 commit comments