Skip to content

Commit 66009f5

Browse files
lukstafi and claude
committed
ppx: Fix application handling and add %oc anti-quotation with unit-parameter heuristic
This commit improves the %op and %cd syntax extensions: 1. **Fixed general application handling**: Replaced hardcoded 2-3 argument patterns with proper Pexp_apply handling that works for any number of labeled/unlabeled arguments. 2. **Added %oc anti-quotation**: Provides an escape hatch to preserve pure OCaml expressions within %op/%cd contexts without tensor/assignment transformation. 3. **Implemented unit-parameter heuristic in %op**: When a function application contains a unit () argument, all arguments before it are automatically preserved as OCaml expressions. This aligns with OCANNL's pattern where configuration parameters come before the lifting point. These changes make the syntax cleaner and more intuitive, eliminating most needs for explicit escaping while maintaining flexibility for edge cases. 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 3a8d078 commit 66009f5

File tree

5 files changed

+103
-39
lines changed

5 files changed

+103
-39
lines changed

lib/ppx_cd.ml

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -817,6 +817,17 @@ let translate ?ident_label (expr : expression) : result =
817817
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r ] ~raw_body ()
818818
in
819819
match expr with
820+
| { pexp_desc = Pexp_extension ({ txt = "oc"; _ }, payload); _ } -> (
821+
(* %oc anti-quotation: preserve the expression without transformation *)
822+
match payload with
823+
| PStr [ { pstr_desc = Pstr_eval (expr, _); _ } ] -> { default_result with expr }
824+
| _ ->
825+
{
826+
default_result with
827+
expr =
828+
Ast_builder.Default.pexp_extension ~loc
829+
@@ Location.error_extensionf ~loc "%%oc expects a single expression";
830+
})
820831
| { pexp_desc = Pexp_constant (Pconst_float _); _ } ->
821832
{ default_result with expr = [%expr NTDSL.number [%e expr]]; slot = Scalar }
822833
| { pexp_desc = Pexp_constant (Pconst_integer (_, Some ('L' | 'l'))); _ } ->

lib/ppx_op.ml

Lines changed: 32 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -153,6 +153,14 @@ let rec translate ~num_configs ~is_toplevel ~opt_label ?label expr =
153153
let loc = expr.pexp_loc in
154154
let loop = translate ~num_configs ~is_toplevel:false ~opt_label in
155155
match expr with
156+
| { pexp_desc = Pexp_extension ({ txt = "oc"; _ }, payload); _ } -> (
157+
(* %oc anti-quotation: preserve the expression without transformation *)
158+
match payload with
159+
| PStr [ { pstr_desc = Pstr_eval (expr, _); _ } ] -> (no_vbs, expr)
160+
| _ ->
161+
( no_vbs,
162+
Ast_builder.Default.pexp_extension ~loc
163+
@@ Location.error_extensionf ~loc "%%oc expects a single expression" ))
156164
| { pexp_desc = Pexp_constant (Pconst_float _); _ } ->
157165
(no_vbs, [%expr TDSL.number ?label:[%e opt_expr ~loc label] [%e expr]])
158166
| { pexp_desc = Pexp_constant (Pconst_integer (_, Some ('L' | 'l'))); _ } ->
@@ -374,15 +382,30 @@ let rec translate ~num_configs ~is_toplevel ~opt_label ?label expr =
374382
let vbs3, e3 = loop expr3 in
375383
let vbs4, e4 = loop expr4 in
376384
(reduce_vbss [ vbs2; vbs3; vbs4 ], [%expr [%e e1] [%e e2] [%e e3] [%e e4]])
377-
| [%expr [%e? expr1] [%e? expr2] [%e? expr3]] ->
378-
let vbs1, e1 = loop ?label expr1 in
379-
let vbs2, e2 = loop expr2 in
380-
let vbs3, e3 = loop expr3 in
381-
(reduce_vbss [ vbs1; vbs2; vbs3 ], [%expr [%e e1] [%e e2] [%e e3]])
382-
| [%expr [%e? expr1] [%e? expr2]] ->
383-
let vbs1, e1 = loop ?label expr1 in
384-
let vbs2, e2 = loop expr2 in
385-
(reduce_vbss [ vbs1; vbs2 ], [%expr [%e e1] [%e e2]])
385+
| { pexp_desc = Pexp_apply (fn_expr, args); _ } ->
386+
(* Smart application handling with unit-parameter heuristic:
387+
If there's a unit () argument, don't transform args before it *)
388+
let unit_position =
389+
List.find_mapi args ~f:(fun i (_, arg_expr) ->
390+
match arg_expr.pexp_desc with
391+
| Pexp_construct ({ txt = Lident "()"; _ }, None) -> Some i
392+
| _ -> None)
393+
in
394+
let vbs_fn, e_fn = loop ?label fn_expr in
395+
let vbs_args, processed_args =
396+
List.unzip
397+
@@ List.mapi args ~f:(fun i (arg_label, arg_expr) ->
398+
match unit_position with
399+
| Some unit_pos when i < unit_pos ->
400+
(* Before unit: preserve as OCaml expression *)
401+
(no_vbs, (arg_label, arg_expr))
402+
| _ ->
403+
(* After unit or no unit: transform *)
404+
let vbs, e = loop arg_expr in
405+
(vbs, (arg_label, e)))
406+
in
407+
let all_vbs = reduce_vbss (vbs_fn :: vbs_args) in
408+
(all_vbs, Ast_builder.Default.pexp_apply ~loc e_fn processed_args)
386409
| { pexp_desc = Pexp_function (args, constr, body); _ } when is_toplevel -> (
387410
(* Check if there's a unit parameter or a labeled parameter with label "label" *)
388411
let rec find_unit acc = function

lib/syntax_extensions.md

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
- Syntax extension `%cd` stands for "code", to express assignments and computations: `Assignments.comp`.
2626
- Syntax extension `%op` stands for "operation", to express tensors: `Tensor.t`.
2727
- Both extensions use record syntax `{ tensor_name }` or `{ tensor_name = init_expr }` for inline tensor declarations.
28+
- Anti-quotation `%oc` escapes expressions to preserve them as pure OCaml without transformation.
2829

2930
## Preliminaries
3031

@@ -41,6 +42,35 @@ Functions inside `Operation.NTDSL` use `~grad_spec:Prohibit_grad` when calling i
4142

4243
The extension points open `NTDSL.O`, resp. `TDSL.O`, for the scope of the extension point, to expose the corresponding operators.
4344

45+
### The %oc anti-quotation and the unit-parameter heuristic
46+
47+
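Within `%op` and `%cd` contexts, expressions typically undergo transformation to build tensors or assignments. To keep an expression as pure OCaml instead, OCANNL offers two mechanisms: a unit-parameter heuristic that applies automatically in `%op` only, and the explicit `%oc` anti-quotation that works in both `%op` and `%cd`: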
48+
49+
#### Unit-parameter heuristic (automatic in %op)
50+
51+
In the `%op` syntax, when a function application contains a unit `()` argument, all arguments appearing **before** the unit are automatically preserved as pure OCaml expressions. This aligns with OCANNL's design pattern where configuration happens before the unit parameter:
52+
53+
```ocaml
54+
(* Arguments before () are automatically preserved as OCaml *)
55+
let%op my_fn ~label x =
56+
other_fn ~label:(("prefix_" ^ name) :: label) ~config:value () x
57+
(* label and config are preserved; x after () is transformed *)
58+
```
59+
60+
#### Explicit %oc anti-quotation
61+
62+
For cases where you need explicit control or the heuristic doesn't apply, the `%oc` (mnemonic: "OCaml") anti-quotation escapes from the transformation context:
63+
64+
```ocaml
65+
(* Force preservation even after () or in edge cases *)
66+
let%op special = process_data data [%oc complex_ocaml_expr]
67+
```
68+
69+
The `%oc` extension expects a single expression and returns it unchanged. Use cases:
70+
- Overriding the unit-parameter heuristic when needed
71+
- Preserving expressions in contexts without a unit parameter
72+
- Escaping from the DSL in `%cd` contexts (which don't use the unit heuristic)
73+
4474
## Primitive operations
4575

4676
To accommodate stylistic preferences, OCANNL supports both curried and uncurried syntaxes for primitive operation application. Binary operators are associated with infix operators, in addition to having alphabetic identifiers. This stems from the following restriction: in the `%cd` syntax, the assignment is always an infix operator, and it needs to pick the accumulation operation.

test/ppx/test_ppx_op.ml

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,19 +20,19 @@ let z3 =
2020
[%op { hey9 } +* "is*a+d*bc;b=>iac" { hey10 }]
2121

2222
let () = ignore (y0, y1, y2, a, b, y, z, z2, z3)
23-
let%op mlp_layer ~label ~hid_dim () x = relu (({ w } * x) + { b; o = [ hid_dim ] })
23+
let%op mlp_layer ~label ~hid_dim () ~x = relu (({ w } * x) + { b; o = [ hid_dim ] })
2424

2525
let%op _use_layer =
2626
let l1 = mlp_layer ~label:[ "L" ] ~hid_dim:3 () in
2727
let l2 = mlp_layer ~label:[ "L2" ] ~hid_dim:3 () in
28-
fun x -> l1 (l2 x)
28+
fun x -> l1 ~x:(l2 ~x)
2929

3030
let%op _config_layer ~label () =
31-
let l = mlp_layer ~label:(label @ [ "L" ]) ~hid_dim:3 () in
32-
fun x -> l x
31+
let l = mlp_layer ~label:(label @ [ "L" ]) ~hid_dim:3 () in
32+
fun x -> l ~x
3333

3434
let%op _three_layer_perceptron ~label ~dim1 ~dim2 ~dim3 () =
35-
let l1 = mlp_layer ~label:(label @ [ "L1" ]) ~hid_dim:dim1 () in
36-
let l2 = mlp_layer ~label:(label @ [ "L2" ]) ~hid_dim:dim2 () in
37-
let l3 = mlp_layer ~label:(label @ [ "L3" ]) ~hid_dim:dim3 () in
38-
fun x -> l3 (l2 (l1 x))
35+
let l1 = mlp_layer ~label:(label @ [ "L1" ]) ~hid_dim:dim1 () in
36+
let l2 = mlp_layer ~label:(label @ [ "L2" ]) ~hid_dim:dim2 () in
37+
let l3 = mlp_layer ~label:(label @ [ "L3" ]) ~hid_dim:dim3 () in
38+
fun x -> l3 ~x:(l2 ~x:(l1 ~x))

test/ppx/test_ppx_op_expected.ml

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,30 @@ let y0 =
66
(TDSL.param ?more_label:None ?value:None ?values:None ?param_init:None
77
"hey1") () in
88
let open! TDSL.O in
9-
((+) ?label:(Some ["y0"]))
10-
((( *. ) ?label:None) (TDSL.number (Float.of_int 2)) hey1)
9+
(+) ?label:(Some ["y0"])
10+
(( *. ) ?label:None (TDSL.number (Float.of_int 2)) hey1)
1111
(TDSL.number (Float.of_int 3))
1212
let y1 =
1313
let hey2 =
1414
(TDSL.param ?more_label:None ?value:None ?values:None ?param_init:None
1515
"hey2") () in
1616
let open! TDSL.O in
1717
fun x ->
18-
((+) ?label:(Some
19-
(List.concat [["y1"]; (x.Tensor.value).Ir.Tnode.label])))
20-
((( * ) ?label:None) hey2 (TDSL.number (Float.of_int 2))) x
18+
(+) ?label:(Some
19+
(List.concat [["y1"]; (x.Tensor.value).Ir.Tnode.label]))
20+
(( * ) ?label:None hey2 (TDSL.number (Float.of_int 2))) x
2121
let y2 =
2222
let hey3 =
2323
(TDSL.param ?more_label:None ?value:None ?values:None ?param_init:None
2424
"hey3") () in
2525
let open! TDSL.O in
2626
fun x1 x2 ->
27-
((+) ?label:(Some
28-
(List.concat
29-
[["y2"];
30-
(x1.Tensor.value).Ir.Tnode.label;
31-
(x2.Tensor.value).Ir.Tnode.label])))
32-
((( *. ) ?label:None) x1 hey3) x2
27+
(+) ?label:(Some
28+
(List.concat
29+
[["y2"];
30+
(x1.Tensor.value).Ir.Tnode.label;
31+
(x2.Tensor.value).Ir.Tnode.label]))
32+
(( *. ) ?label:None x1 hey3) x2
3333
let a =
3434
let open! TDSL.O in
3535
((TDSL.ndarray
@@ -46,8 +46,8 @@ let y =
4646
(TDSL.param ?more_label:None ?value:None ?values:None ?param_init:None
4747
"hey4") () in
4848
let open! TDSL.O in
49-
((+) ?label:(Some ["y"]))
50-
((( * ) ?label:None) hey4 (TDSL.number ?label:None ~axis_label:"q" 2.0))
49+
(+) ?label:(Some ["y"])
50+
(( * ) ?label:None hey4 (TDSL.number ?label:None ~axis_label:"q" 2.0))
5151
(TDSL.number ?label:None ~axis_label:"p" 1.0)
5252
let z =
5353
let hey5 =
@@ -57,9 +57,9 @@ let z =
5757
(TDSL.param ?more_label:None ?value:None ?values:None ?param_init:None
5858
"hey6") () in
5959
let open! TDSL.O in
60-
((+) ?label:(Some ["z"]))
61-
((( * ) ?label:None) (TDSL.number ?label:None ~axis_label:"q" 2.0) hey5)
62-
((( * ) ?label:None) hey6 (TDSL.number ?label:None ~axis_label:"p" 1.0))
60+
(+) ?label:(Some ["z"])
61+
(( * ) ?label:None (TDSL.number ?label:None ~axis_label:"q" 2.0) hey5)
62+
(( * ) ?label:None hey6 (TDSL.number ?label:None ~axis_label:"p" 1.0))
6363
let stride = 2
6464
and dilation = 3
6565
let z2 =
@@ -98,21 +98,21 @@ let mlp_layer =
9898
and w =
9999
(TDSL.param ?more_label:(Some label) ?value:None ?values:None
100100
?param_init:None "w") () in
101-
fun x ->
102-
(relu ?label:(Some ["mlp_layer"]))
103-
(((+) ?label:None) ((( * ) ?label:None) w x) b)
101+
fun ~x ->
102+
relu ?label:(Some ["mlp_layer"])
103+
((+) ?label:None (( * ) ?label:None w x) b)
104104
let _use_layer =
105105
let open! TDSL.O in
106106
let l1 = mlp_layer ~label:["L"] ~hid_dim:3 () in
107-
let l2 = mlp_layer ~label:["L2"] ~hid_dim:3 () in fun x -> l1 (l2 x)
107+
let l2 = mlp_layer ~label:["L2"] ~hid_dim:3 () in fun x -> l1 ~x:(l2 ~x)
108108
let _config_layer =
109109
let open! TDSL.O in
110110
fun ~label () ->
111-
let l = mlp_layer ~label:(label @ ["L"]) ~hid_dim:3 () in fun x -> l x
111+
let l = mlp_layer ~label:(label @ ["L"]) ~hid_dim:3 () in fun x -> l ~x
112112
let _three_layer_perceptron =
113113
let open! TDSL.O in
114114
fun ~label ~dim1 ~dim2 ~dim3 () ->
115115
let l1 = mlp_layer ~label:(label @ ["L1"]) ~hid_dim:dim1 () in
116116
let l2 = mlp_layer ~label:(label @ ["L2"]) ~hid_dim:dim2 () in
117117
let l3 = mlp_layer ~label:(label @ ["L3"]) ~hid_dim:dim3 () in
118-
fun x -> l3 (l2 (l1 x))
118+
fun x -> l3 ~x:(l2 ~x:(l1 ~x))

0 commit comments

Comments
 (0)