Skip to content

Commit db42dc4

Browse files
committed
Enhance %cd syntax to allow inline tensor declarations in standalone expressions. Adjust related errors, comments and documentation for consistency.
1 parent a75dce1 commit db42dc4

File tree

4 files changed

+49
-28
lines changed

4 files changed

+49
-28
lines changed

lib/ppx_cd.ml

Lines changed: 29 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,7 @@ let assignment ~punned ~lhs ~rhses body =
111111
| _ -> (no_vbs, body)
112112
in
113113
let body =
114+
(* Note: this is not a binding from an inline declaration, it's a temporary binding. *)
114115
if Option.is_some lhs.vb then
115116
Ast_builder.Default.pexp_extension ~loc
116117
@@ Location.error_extensionf ~loc
@@ -200,7 +201,6 @@ let empty_comp ~loc = [%expr { asgns = Ir.Assignments.Noop; embedded_nodes = [%e
200201

201202
let setup_array ~punned ~bad_pun_hints ~is_lhs
202203
{ typ = filler_typ; slot; expr = filler; vbs; array_opt_of_code } =
203-
assert (Map.is_empty vbs);
204204
let loc = filler.pexp_loc in
205205
let opt_buffer tn =
206206
if is_lhs then [%expr Some [%e tn]] else [%expr Some (Ir.Assignments.Node [%e tn])]
@@ -220,17 +220,18 @@ let setup_array ~punned ~bad_pun_hints ~is_lhs
220220
pun_hint_tnode;
221221
}
222222
in
223-
match filler_typ with
224-
| No_grad_tensor_intro _ when not is_lhs ->
223+
match (Map.is_empty vbs, filler_typ) with
224+
| (false, _ | _, No_grad_tensor_intro _) when not is_lhs ->
225225
{
226226
default_setup with
227227
array_opt =
228228
Ast_builder.Default.pexp_extension ~loc
229229
@@ Location.error_extensionf ~loc
230-
"ppx_ocannl %%cd: punning is only allowed in the assigned-to position";
230+
"ppx_ocannl %%cd: inline tensor declarations are not allowed in assignment \
231+
right-hand side, to prevent over-use in locations with less label information";
231232
}
232-
| (Tensor | Unknown) when match filler with { pexp_desc = Pexp_ident _; _ } -> true | _ -> false
233-
->
233+
| _, (Tensor | Unknown)
234+
when match filler with { pexp_desc = Pexp_ident _; _ } -> true | _ -> false ->
234235
let t = filler in
235236
let fwd_code_or_noop =
236237
Some
@@ -241,7 +242,7 @@ let setup_array ~punned ~bad_pun_hints ~is_lhs
241242
else [%e empty_comp ~loc]]
242243
in
243244
{ default_setup with fwd_code_or_noop; tensor = Some t }
244-
| Value_of_tensor ({ pexp_desc = Pexp_ident _; _ } as t) ->
245+
| _, Value_of_tensor ({ pexp_desc = Pexp_ident _; _ } as t) ->
245246
let fwd_code_or_noop =
246247
Some
247248
[%expr
@@ -256,7 +257,7 @@ let setup_array ~punned ~bad_pun_hints ~is_lhs
256257
array_opt = opt_buffer [%expr [%e t].Tensor.value];
257258
tensor = Some t;
258259
}
259-
| Value_of_tensor t ->
260+
| _, Value_of_tensor t ->
260261
{
261262
default_setup with
262263
array_opt =
@@ -266,7 +267,7 @@ let setup_array ~punned ~bad_pun_hints ~is_lhs
266267
identifier";
267268
tensor = Some t;
268269
}
269-
| Tensor | Unknown ->
270+
| _, (Tensor | Unknown) ->
270271
(* Need to bind the expression computing the tensor so we don't recompute it. *)
271272
let v =
272273
match slot with
@@ -293,7 +294,7 @@ let setup_array ~punned ~bad_pun_hints ~is_lhs
293294
array_opt = opt_buffer [%expr [%e t].Tensor.value];
294295
tensor = Some t;
295296
}
296-
| No_grad_tensor_intro _ ->
297+
| _, No_grad_tensor_intro _ ->
297298
(* Inline tensors are guaranteed to be leaf tensors, so they don't have forward code, but they
298299
are embedded. *)
299300
let fwd_code_or_noop =
@@ -306,7 +307,7 @@ let setup_array ~punned ~bad_pun_hints ~is_lhs
306307
}]
307308
in
308309
{ default_setup with fwd_code_or_noop; tensor = Some filler }
309-
| Code when Option.is_none array_opt_of_code ->
310+
| _, Code when Option.is_none array_opt_of_code ->
310311
{
311312
default_setup with
312313
fwd_code_or_noop = Some filler;
@@ -315,16 +316,16 @@ let setup_array ~punned ~bad_pun_hints ~is_lhs
315316
@@ Location.error_extensionf ~loc
316317
"ppx_ocannl %%cd: could not determine a lead array of provided code";
317318
}
318-
| Code ->
319+
| _, Code ->
319320
{
320321
default_setup with
321322
fwd_code_or_noop = Some filler;
322323
array_opt = buffer (Option.value_exn array_opt_of_code);
323324
}
324-
| Array -> { default_setup with array_opt = opt_buffer filler }
325-
| Grad_of_tensor ({ pexp_desc = Pexp_ident _; _ } as t) ->
325+
| _, Array -> { default_setup with array_opt = opt_buffer filler }
326+
| _, Grad_of_tensor ({ pexp_desc = Pexp_ident _; _ } as t) ->
326327
{ default_setup with array_opt = buffer filler; tensor = Some t }
327-
| Grad_of_tensor t ->
328+
| _, Grad_of_tensor t ->
328329
{
329330
default_setup with
330331
array_opt =
@@ -334,16 +335,16 @@ let setup_array ~punned ~bad_pun_hints ~is_lhs
334335
identifier";
335336
tensor = Some t;
336337
}
337-
| (Merge_value _ | Merge_grad _) when is_lhs ->
338+
| _, (Merge_value _ | Merge_grad _) when is_lhs ->
338339
{
339340
default_setup with
340341
array_opt =
341342
Ast_builder.Default.pexp_extension ~loc
342343
@@ Location.error_extensionf ~loc "ppx_ocannl %%cd: merge buffers cannot be assigned to";
343344
}
344-
| Merge_value t ->
345+
| _, Merge_value t ->
345346
{ default_setup with array_opt = [%expr Some (Merge_buffer [%e filler])]; tensor = Some t }
346-
| Merge_grad t ->
347+
| _, Merge_grad t ->
347348
{
348349
default_setup with
349350
array_opt = [%expr Option.map [%e filler] ~f:(fun tn -> Ir.Assignments.Merge_buffer tn)];
@@ -440,7 +441,7 @@ let translate (expr : expression) : result =
440441
@@ Location.error_extensionf ~loc
441442
"ppx_ocannl %%cd: expected a ternary operator, one of: where, fma" ))
442443
in
443-
(* FIXME: collapse these (code reuse) *)
444+
(* TODO: collapse these (code reuse) *)
444445
let process_assign_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ?projections ~proj_in_scope
445446
() =
446447
let initialize_neutral, accu_op = assignment_op accu_op in
@@ -707,10 +708,16 @@ let translate (expr : expression) : result =
707708
slot = Scalar;
708709
}
709710
| { pexp_desc = Pexp_constant (Pconst_string (name, str_loc, _)); _ } ->
711+
(* TODO: consider passing toplevel binding name as a hint label *)
712+
let vbs =
713+
Map.singleton (module String) name @@ make_vb ~loc ~name ~name_expr:expr ~hint_label:None
714+
in
710715
{
711-
default_result with
716+
vbs;
712717
typ = No_grad_tensor_intro { name; name_expr = expr };
713718
expr = A.Exp.ident ~loc:str_loc { txt = Lident name; loc = str_loc };
719+
array_opt_of_code = None;
720+
slot = Undet;
714721
}
715722
| { pexp_desc = Pexp_array _; _ }
716723
| { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } ->
@@ -852,9 +859,8 @@ let translate (expr : expression) : result =
852859
}
853860
| Code ->
854861
{
855-
default_result with
862+
res1 with
856863
typ = Array;
857-
slot = res1.slot;
858864
expr =
859865
Ast_builder.Default.pexp_extension ~loc
860866
@@ Location.error_extensionf ~loc
@@ -878,7 +884,7 @@ let translate (expr : expression) : result =
878884
@@ Location.error_extensionf ~loc
879885
"ppx_ocannl %%cd: only tensor nodes (e.g. `.value` or `.grad`) can be merged";
880886
}
881-
| Grad_of_tensor t -> { res1 with vbs = no_vbs; typ = Merge_grad t }
887+
| Grad_of_tensor t -> { res1 with typ = Merge_grad t }
882888
| Merge_value _ | Merge_grad _ ->
883889
{
884890
res1 with

lib/ppx_shared.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ let reduce_vbss = List.reduce_exn ~f:(Map.merge_skewed ~combine:(fun ~key:_ _v1
241241
let expr_expander_with_punning translate ~loc ~path:_ payload =
242242
match payload with
243243
| { pexp_desc = Pexp_let (recflag, bindings, body); _ } ->
244-
(* We are at the %op annotation level: do not tranlsate the body. *)
244+
(* We are at the %op/%cd annotation level: do not translate the body. *)
245245
let vbss, bindings =
246246
List.unzip
247247
@@ List.map bindings ~f:(fun vb ->

lib/syntax_extensions.md

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ When an extension is over a wildcard (ignore result) binding: `let%cd _ = ...` a
255255

256256
Both `%cd` and `%op` syntaxes support inline declarations of tensors. For `%op` these are differentiable, for `%cd` non-differentiable tensors. A declaration site uses the string syntax: the content of the string is the identifier that is bound to the newly created tensor, and the string itself functions equivalently to using the newly introduced identifier. The scope of the binding is the full scope of the extension point, even if the declaring string appeared in the body of a function that's inside the extension point scope (except that for `%op` there is a special case of the `~config` labeled argument, discussed below). The first element of the label of the created tensor is the string that introduced it.
257257

258-
For `%cd`, the declaration is (currently) only allowed on the left-hand-side, i.e. in the assigned-to position, of an assignment. If possible, one of the tensors on the right-hand-side is picked to provide additional label information. In particular, tensors that are function parameters inside the scope of the extension point, cannot be picked to provide label information, as they would escape their scope at the point the tensor is created. Example showing two tensor nodes declared inline, both of them include the label of the param `p` in their labels:
258+
For `%cd`, inline declarations are allowed both in the assigned-to position (left-hand side) of assignments and in standalone tensor expressions. When used in assignments, one of the tensors on the right-hand-side is picked to provide additional label information if possible. In particular, tensors that are function parameters inside the scope of the extension point cannot be picked to provide label information, as they would escape their scope at the point the tensor is created. Inline declarations are still prohibited within the right-hand side of assignments to discourage over-use in locations with less label information. The following example shows two tensor nodes declared inline; both of them include the label of the param `p` in their labels:
259259

260260
```ocaml
261261
let sgd_one ~learning_rate ?(momentum = 0.0) ?(weight_decay = 0.0) ?(nesterov = false) p =
@@ -267,6 +267,21 @@ let sgd_one ~learning_rate ?(momentum = 0.0) ?(weight_decay = 0.0) ?(nesterov =
267267
p =- learning_rate *. sgd_delta]
268268
```
269269

270+
Inline declarations can also be used outside of assignments for creating non-differentiable tensors, to mimic the behavior of `%op` but without the burden of initialization that a parameter would introduce:
271+
272+
```ocaml
273+
let%cd mlp_result = mlp "point" in
274+
let result_routine =
275+
Train.to_routine (module Backend) sgd_routine.context IDX.empty
276+
[%cd ~~("mlp infer"; mlp_result.forward)]
277+
in
278+
let callback (x, y) =
279+
Tn.set_values point [| x; y |];
280+
Train.run result_routine;
281+
Float.(mlp_result.@[0] >= 0.)
282+
in
283+
```
284+
270285
For `%op`, the declaration is allowed anywhere. If there is a `~config` function parameter used inside the extension scope, for example as `fun ~config ... -> ...` or a more specific example `let%op mlp ~config x = ...`, the scope of an inline-declared tensor is no longer the full scope of the extension point. Instead, the tensor is defined right underneath the introduction of the `~config` parameter: `fun ~config -> let <definitions of the inline-declared tensors> in ...`. The config value passed to the generated code must be a record with at least a field `label : string list`. The inline-declared tensor that's defined under a `~config` parameter is defined as `TDSL.param ~more_label:config.label ...`. Example showing two param tensors declared inline, including `config.label` in their labels:
271286

272287
```ocaml

test/einsum/moons_demo_variant.ml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@ let () =
5959
let%op learning_rate = 0.1 *. ((2 *. !..steps) - !@step_n) /. !..steps in
6060
Train.set_hosted learning_rate.value;
6161
let sgd = Train.sgd_update ~learning_rate ~weight_decay scalar_loss in
62-
let init = Train.to_routine (module Backend) ctx bindings init_params in
62+
let init_routine = Train.to_routine (module Backend) ctx bindings init_params in
6363
let sgd_routine =
64-
Train.to_routine (module Backend) init.context bindings (Asgns.sequence [ update; sgd ])
64+
Train.to_routine (module Backend) init_routine.context bindings (Asgns.sequence [ update; sgd ])
6565
in
6666
let step_ref = IDX.find_exn sgd_routine.bindings step_n in
6767
step_ref := 0;
68-
Train.run init;
68+
Train.run init_routine;
6969
for _epoch = 1 to epochs do
7070
Train.sequential_loop sgd_routine.bindings ~f:(fun () ->
7171
Train.run sgd_routine;

0 commit comments

Comments
 (0)