Skip to content

Commit 1ed3279

Browse files
committed
Untested: new %op functionality: refine the param's label when under ~config
1 parent 5d820fb commit 1ed3279

File tree

4 files changed

+61
-46
lines changed

4 files changed

+61
-46
lines changed

lib/ppx_op.ml

Lines changed: 43 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,27 @@ let ndarray_op ?ident_label ?axis_labels expr =
1818
~batch_dims:[%e edims batch_dims] ~input_dims:[%e edims input_dims]
1919
~output_dims:[%e edims output_dims] [%e values]]
2020

21-
let make_vb ?value ~loc ~str_loc ~ident string =
21+
let make_p ~has_config ~loc =
22+
if has_config then [%expr TDSL.param ~more_label:config.label] else [%expr TDSL.param]
23+
24+
let make_vb ?value ~has_config ~loc ~str_loc ~ident string =
2225
let pat = Ast_helper.Pat.var ~loc { loc = str_loc; txt = ident } in
2326
let value = match value with Some c -> [%expr Some [%e c]] | None -> [%expr None] in
24-
let v = [%expr TDSL.param ?values:[%e value] [%e string]] in
27+
let v = [%expr [%e make_p ~has_config ~loc] ?values:[%e value] [%e string]] in
2528
let vb = Ast_helper.Vb.mk ~loc pat v in
2629
(pat, vb)
2730

28-
let make_vb_dims ~loc ~str_loc ~ident ~dims ~dims_loc string =
31+
let make_vb_dims ~has_config ~loc ~str_loc ~ident ~dims ~dims_loc string =
2932
let pat = Ast_helper.Pat.var ~loc { loc = str_loc; txt = ident } in
3033
let dims =
3134
let loc = dims_loc in
3235
List.fold_right dims ~init:[%expr []] ~f:(fun d ds -> [%expr [%e d] :: [%e ds]])
3336
in
34-
let v = [%expr TDSL.param ~output_dims:[%e dims] [%e string]] in
37+
let v = [%expr [%e make_p ~has_config ~loc] ~output_dims:[%e dims] [%e string]] in
3538
let vb = Ast_helper.Vb.mk ~loc pat v in
3639
(pat, vb)
3740

38-
let make_vb_nd ~loc ~str_loc ?axis_labels ~ident ~init_nd string =
41+
let make_vb_nd ~has_config ~loc ~str_loc ?axis_labels ~ident ~init_nd string =
3942
let pat = Ast_helper.Pat.var ~loc { loc = str_loc; txt = ident } in
4043
let values, batch_dims, output_dims, input_dims = ndarray_constant init_nd in
4144
let v =
@@ -47,8 +50,8 @@ let make_vb_nd ~loc ~str_loc ?axis_labels ~ident ~init_nd string =
4750
let edims dims = Ast_builder.Default.elist ~loc dims in
4851
let op =
4952
match axis_labels with
50-
| None -> [%expr TDSL.param]
51-
| Some axis_labels -> [%expr TDSL.param ~axis_labels:[%e axis_labels]]
53+
| None -> make_p ~has_config ~loc
54+
| Some axis_labels -> [%expr [%e make_p ~has_config ~loc] ~axis_labels:[%e axis_labels]]
5255
in
5356
[%expr
5457
[%e op] ~input_dims:[%e edims input_dims] ~output_dims:[%e edims output_dims]
@@ -57,8 +60,9 @@ let make_vb_nd ~loc ~str_loc ?axis_labels ~ident ~init_nd string =
5760
let vb = Ast_helper.Vb.mk ~loc pat v in
5861
(pat, vb)
5962

60-
let rec translate ?ident_label expr =
63+
let rec translate ~has_config ?ident_label expr =
6164
let loc = expr.pexp_loc in
65+
let loop = translate ~has_config in
6266
match expr with
6367
| { pexp_desc = Pexp_constant (Pconst_float _); _ } ->
6468
(no_vbs, [%expr TDSL.number ~label:[%e opt_pat2string_list ~loc ident_label] [%e expr]])
@@ -90,77 +94,76 @@ let rec translate ?ident_label expr =
9094
[%e? expr1]
9195
*+ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec] [%e? expr2]]
9296
when String.contains spec_str '>' ->
93-
let vbs1, e1 = translate expr1 in
94-
let vbs2, e2 = translate expr2 in
97+
let vbs1, e1 = loop expr1 in
98+
let vbs2, e2 = loop expr2 in
9599
( reduce_vbss [ vbs1; vbs2 ],
96100
[%expr
97101
TDSL.einsum ~label:[%e opt_pat2string_list ~loc ident_label] [%e spec] [%e e1] [%e e2]] )
98102
| [%expr
99103
[%e? expr1] ++ [%e? { pexp_desc = Pexp_constant (Pconst_string (spec_str, _, _)); _ } as spec]]
100104
when String.contains spec_str '>' ->
101-
let vbs1, e1 = translate expr1 in
105+
let vbs1, e1 = loop expr1 in
102106
(vbs1, [%expr TDSL.einsum1 ~label:[%e opt_pat2string_list ~loc ident_label] [%e spec] [%e e1]])
103107
| [%expr
104108
[%e? { pexp_desc = Pexp_constant (Pconst_string (ident, str_loc, _)); _ } as s]
105109
[%e?
106110
( { pexp_desc = Pexp_constant (Pconst_integer _); pexp_loc = dims_loc; _ }
107111
| { pexp_desc = Pexp_ident _; pexp_loc = dims_loc; _ } ) as d]] ->
108-
let pat, vb = make_vb_dims ~loc ~str_loc ~ident ~dims:[ d ] ~dims_loc s in
112+
let pat, vb = make_vb_dims ~has_config ~loc ~str_loc ~ident ~dims:[ d ] ~dims_loc s in
109113
(Map.singleton (module String) ident vb, pat2expr pat)
110114
| [%expr
111115
[%e? { pexp_desc = Pexp_constant (Pconst_string (ident, str_loc, _)); _ } as s]
112116
[%e?
113117
( { pexp_desc = Pexp_array _; _ }
114118
| { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } ) as init_nd]] ->
115-
let pat, vb = make_vb_nd ~loc ~str_loc ~ident ~init_nd s in
119+
let pat, vb = make_vb_nd ~has_config ~loc ~str_loc ~ident ~init_nd s in
116120
(Map.singleton (module String) ident vb, pat2expr pat)
117121
| [%expr
118122
[%e? { pexp_desc = Pexp_constant (Pconst_string (ident, str_loc, _)); _ } as s]
119123
[%e? { pexp_desc = Pexp_tuple dims; pexp_loc = dims_loc; _ }]] ->
120-
let pat, vb = make_vb_dims ~loc ~str_loc ~ident ~dims ~dims_loc s in
124+
let pat, vb = make_vb_dims ~has_config ~loc ~str_loc ~ident ~dims ~dims_loc s in
121125
(Map.singleton (module String) ident vb, pat2expr pat)
122126
| { pexp_desc = Pexp_constant (Pconst_string (ident, str_loc, _)); _ } ->
123-
let pat, vb = make_vb ~loc ~str_loc ~ident expr in
127+
let pat, vb = make_vb ~has_config ~loc ~str_loc ~ident expr in
124128
(Map.singleton (module String) ident vb, pat2expr pat)
125129
| { pexp_desc = Pexp_array _; _ }
126130
| { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } ->
127131
(no_vbs, ndarray_op ?ident_label expr)
128132
| [%expr [%e? expr1] **. [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
129133
(* We need to hardcode these two patterns to prevent the numbers from being converted to
130134
tensors. *)
131-
let vbs, e1 = translate expr1 in
135+
let vbs, e1 = loop expr1 in
132136
( vbs,
133137
[%expr
134138
TDSL.O.( **. )
135139
~label:[%e opt_pat2string_list ~loc ident_label]
136140
[%e e1]
137141
(Float.of_int [%e i])] )
138142
| [%expr [%e? expr1] **. [%e? expr2]] ->
139-
let vbs, e1 = translate expr1 in
143+
let vbs, e1 = loop expr1 in
140144
( vbs,
141145
[%expr TDSL.O.( **. ) ~label:[%e opt_pat2string_list ~loc ident_label] [%e e1] [%e expr2]]
142146
)
143147
| [%expr [%e? expr1] [%e? expr2] [%e? expr3]] ->
144-
let vbs1, e1 = translate ?ident_label expr1 in
145-
let vbs2, e2 = translate expr2 in
146-
let vbs3, e3 = translate expr3 in
148+
let vbs1, e1 = loop ?ident_label expr1 in
149+
let vbs2, e2 = loop expr2 in
150+
let vbs3, e3 = loop expr3 in
147151
(reduce_vbss [ vbs1; vbs2; vbs3 ], [%expr [%e e1] [%e e2] [%e e3]])
148152
| [%expr [%e? expr1] [%e? expr2]] ->
149-
let vbs1, e1 = translate ?ident_label expr1 in
150-
let vbs2, e2 = translate expr2 in
153+
let vbs1, e1 = loop ?ident_label expr1 in
154+
let vbs2, e2 = loop expr2 in
151155
(Map.merge_skewed vbs1 vbs2 ~combine:(fun ~key:_ _v1 v2 -> v2), [%expr [%e e1] [%e e2]])
152156
| [%expr fun ~config -> [%e? body]] ->
153-
let vbs, body = translate ?ident_label body in
154-
( no_vbs,
155-
[%expr fun ~config -> [%e let_opt ~loc vbs body]] )
157+
let vbs, body = translate ~has_config:true ?ident_label body in
158+
(no_vbs, [%expr fun ~config -> [%e let_opt ~loc vbs body]])
156159
| [%expr fun [%p? pat] -> [%e? body]] ->
157-
let vbs, body = translate ?ident_label body in
160+
let vbs, body = loop ?ident_label body in
158161
(vbs, [%expr fun [%p pat] -> [%e body]])
159162
| [%expr
160163
while [%e? test_expr] do
161164
[%e? body_expr]
162165
done] ->
163-
let vbs, body = translate ?ident_label body_expr in
166+
let vbs, body = loop ?ident_label body_expr in
164167
( vbs,
165168
[%expr
166169
while [%e test_expr] do
@@ -170,7 +173,7 @@ let rec translate ?ident_label expr =
170173
for [%p? pat] = [%e? init] to [%e? final] do
171174
[%e? body_expr]
172175
done] ->
173-
let vbs, body = translate ?ident_label body_expr in
176+
let vbs, body = loop ?ident_label body_expr in
174177
( vbs,
175178
[%expr
176179
for [%p pat] = [%e init] to [%e final] do
@@ -180,7 +183,7 @@ let rec translate ?ident_label expr =
180183
for [%p? pat] = [%e? init] downto [%e? final] do
181184
[%e? body_expr]
182185
done] ->
183-
let vbs, body = translate ?ident_label body_expr in
186+
let vbs, body = loop ?ident_label body_expr in
184187
( vbs,
185188
[%expr
186189
for [%p pat] = [%e init] downto [%e final] do
@@ -189,49 +192,49 @@ let rec translate ?ident_label expr =
189192
| [%expr
190193
[%e? expr1];
191194
[%e? expr2]] ->
192-
let vbs1, e1 = translate expr1 in
193-
let vbs2, e2 = translate ?ident_label expr2 in
195+
let vbs1, e1 = loop expr1 in
196+
let vbs2, e2 = loop ?ident_label expr2 in
194197
( reduce_vbss [ vbs1; vbs2 ],
195198
[%expr
196199
[%e e1];
197200
[%e e2]] )
198201
| [%expr if [%e? expr1] then [%e? expr2] else [%e? expr3]] ->
199-
let vbs2, e2 = translate ?ident_label expr2 in
200-
let vbs3, e3 = translate ?ident_label expr3 in
202+
let vbs2, e2 = loop ?ident_label expr2 in
203+
let vbs3, e3 = loop ?ident_label expr3 in
201204
(reduce_vbss [ vbs2; vbs3 ], [%expr if [%e expr1] then [%e e2] else [%e e3]])
202205
| [%expr if [%e? expr1] then [%e? expr2]] ->
203-
let vbs2, e2 = translate ?ident_label expr2 in
206+
let vbs2, e2 = loop ?ident_label expr2 in
204207
(vbs2, [%expr if [%e expr1] then [%e e2]])
205208
| { pexp_desc = Pexp_match (expr1, cases); _ } ->
206209
let vbss, cases =
207210
List.unzip
208211
@@ List.map cases ~f:(fun ({ pc_rhs; _ } as c) ->
209-
let vbs, pc_rhs = translate ?ident_label pc_rhs in
212+
let vbs, pc_rhs = loop ?ident_label pc_rhs in
210213
(vbs, { c with pc_rhs }))
211214
in
212215
(reduce_vbss vbss, { expr with pexp_desc = Pexp_match (expr1, cases) })
213216
| { pexp_desc = Pexp_let (recflag, bindings, body); _ } ->
214217
let vbss1, bindings =
215218
List.unzip
216219
@@ List.map bindings ~f:(fun binding ->
217-
let vbs, pvb_expr = translate ~ident_label:binding.pvb_pat binding.pvb_expr in
220+
let vbs, pvb_expr = loop ~ident_label:binding.pvb_pat binding.pvb_expr in
218221
(vbs, { binding with pvb_expr }))
219222
in
220-
let vbs2, body = translate ?ident_label body in
223+
let vbs2, body = loop ?ident_label body in
221224
let all_bindings = (Map.data @@ reduce_vbss vbss1) @ bindings @ Map.data vbs2 in
222225
(no_vbs, { expr with pexp_desc = Pexp_let (recflag, all_bindings, body) })
223226
| { pexp_desc = Pexp_open (decl, body); _ } ->
224-
let vbs, body = translate ?ident_label body in
227+
let vbs, body = loop ?ident_label body in
225228
(vbs, { expr with pexp_desc = Pexp_open (decl, body) })
226229
| { pexp_desc = Pexp_letmodule (name, module_expr, body); _ } ->
227-
let vbs, body = translate ?ident_label body in
230+
let vbs, body = loop ?ident_label body in
228231
(vbs, { expr with pexp_desc = Pexp_letmodule (name, module_expr, body) })
229232
| { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } when is_operator op_ident ->
230233
(no_vbs, [%expr [%e expr] ~label:[%e opt_pat2string_list ~loc ident_label]])
231234
| expr -> (no_vbs, expr)
232235

233236
let translate ?ident_label expr =
234-
let vbs, expr = translate ?ident_label expr in
237+
let vbs, expr = translate ~has_config:false ?ident_label expr in
235238
let loc = expr.pexp_loc in
236239
( vbs,
237240
match ident_label with

lib/syntax_extensions.md

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,18 @@
22

33
## Notes true for bogth `%cd` and `%op`
44

5+
### Wildcard bindings
6+
57
When an extension is over a wildcard (ignore result) binding: `let%cd _ = ...` and `let%op _ = ...`, the generated code is wrapped in `Tensor.with_unchanged_roots`, to prevent it from upsetting rootness checks. The use-case for writing `%op` and `%cd` notations with ignored result is to generate additional shape inference constraints.
68

9+
### Inline declarations
10+
11+
Both `%cd` and `%op` syntaxes support inline declarations of tensors. For `%op` these are differentiable, for `%cd` non-differentiable tensors. A declaration site uses the string syntax, the content of the string is the is bound to the newly created tensor, and the string itself functions equivalently to using the newly introduced identifier. The scope of the binding is the full scope of the extension point, even if the declaring string appeared in the body of a function that's inside the extension point scope (except for `%op` there is a special case of `~config` labeled argument discussed below). The first element of the label of the created tensor is the string that introduced it.
12+
13+
For `%cd`, the declaration is (currently) only allowed on the left-hand-side, i.e. in the assigned-to position, of an assignment. If possible, one of the tensors on the right-hand-side is picked to provide additional label information. In particular, tensors that are function parameters inside the scope of the extension point, cannot be picked to provide label information, as they would escape their scope at the point the tensor is created.
14+
15+
For `%op`, the declaration is allowed anywhere. If there is a `~config` function parameter used inside the extension scope, for example as `fun ~config ... -> ...` or a more specific example `let%op mlp ~config x = ...`, the scope of an inline-declared tensor is no longer the full scope of the extension point. Instead, the tensor is defined right underneath the introduction of the `~config` parameter: `fun ~config -> let <definitions of the inline-declared tensors> in ...`. The config value passed to the generated code must be a record with at least a field `label : string list`. The inline-declared tensor that's defined under a `~config` parameter is defined as `TDSL.param ~more_label:config.label ...`
16+
717
## Syntax extension `%cd`, standing for "code", to express assignments: `Assignments.t`
818

919
### Implementation details
@@ -25,4 +35,4 @@ type expr_type =
2535
type projections_slot = LHS | RHS1 | RHS2 | Nonslot | Undet
2636
```
2737

28-
## Syntax extension `%op`, standing for "operation", to express tensors: `Tensor.t`
38+
## Syntax extension `%op`, standing for "operation", to express tensors: `Tensor.t`

lib/tensor.ml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -316,16 +316,16 @@ let ndarray ?(label = []) ?(grad_spec = Prohibit_grad) ?batch_dims ?input_dims ?
316316
Tn.update_memory_mode t.value Effectively_constant 24;
317317
t
318318

319-
let param ?input_dims ?output_dims ?input_axes ?output_axes ?deduced ?(strict = false) ?values label
320-
=
319+
let param ?(more_label = []) ?input_dims ?output_dims ?input_axes ?output_axes ?deduced
320+
?(strict = false) ?values label =
321321
let init_op =
322322
match values with
323323
| Some values -> Arrayjit.Ops.Constant_fill { values; strict }
324324
| None -> Standard_uniform
325325
in
326326
let t =
327-
term ~label:[ label ] ~grad_spec:Require_grad ~batch_dims:[] ?input_dims ?output_dims
328-
?input_axes ?output_axes ?deduced ~init_op ()
327+
term ~label:(label :: more_label) ~grad_spec:Require_grad ~batch_dims:[] ?input_dims
328+
?output_dims ?input_axes ?output_axes ?deduced ~init_op ()
329329
in
330330
let v = t.value in
331331
(* It is convenient to use the param syntax for volatiles (mutable inputs). *)

lib/tensor.mli

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ val ndarray :
148148
over to populate the [value] node. *)
149149

150150
val param :
151+
?more_label:string list ->
151152
?input_dims:int list ->
152153
?output_dims:int list ->
153154
?input_axes:(string * int) list ->
@@ -158,7 +159,8 @@ val param :
158159
string ->
159160
t
160161
(* A tensor with no batch axes; input and output axes are by default inferred. [grad_spec] is set to
161-
[Require_grad]. *)
162+
[Require_grad]. The resulting tensor's label is the passed string, appended by [more_label] if
163+
any. *)
162164

163165
val iter_embedded_arrays : f:(tn -> unit) -> t -> unit
164166
val non_and_embedded_nodes : t -> (t, comparator_witness) Set.t * (t, comparator_witness) Set.t

0 commit comments

Comments
 (0)