Skip to content

Commit c8da5c4

Browse files
authored
Merge pull request #378 from ahrefs/fix-label-parameter-type
Fix ~label parameter type in %op syntax extension
2 parents 16b612b + 43255ed commit c8da5c4

File tree

5 files changed

+106
-127
lines changed

5 files changed

+106
-127
lines changed

lib/nn_blocks.ml

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,11 @@ open! Base
55
module TDSL = Operation.TDSL
66
module NTDSL = Operation.NTDSL
77

8-
type mlp_layer_config = { label : string list; hid_dim : int }
8+
let%op mlp_layer ~label ~hid_dim () x = relu (({ w = uniform () } * x) + { b = 0.; o = [ hid_dim ] })
99

10-
let%op mlp_layer ~config x = relu (({ w = uniform () } * x) + { b = 0.; o = [ config.hid_dim ] })
11-
12-
type mlp_config = { label : string list; hid_dims : int list }
13-
14-
let mlp ~config =
10+
let mlp ~hid_dims =
1511
let layers =
16-
List.mapi config.hid_dims ~f:(fun i hid_dim ->
17-
mlp_layer ~config:{ label = [ "L" ^ Int.to_string i ]; hid_dim })
12+
List.mapi hid_dims ~f:(fun i hid_dim ->
13+
mlp_layer ~label:[ "L" ^ Int.to_string i ] ~hid_dim ())
1814
in
1915
fun x -> List.fold layers ~init:x ~f:(fun x layer -> layer x)

lib/ppx_op.ml

Lines changed: 53 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@ open Ppxlib
33
open Ppx_arrayjit.Ppx_helper
44
open Ppx_shared
55

6-
let make_p ~has_config ~loc ?value ?values ?param_init ~extra_args name =
7-
let more_label = if has_config then [%expr Some config.label] else [%expr None] in
6+
let make_p ~opt_label ~loc ?value ?values ?param_init ~extra_args name =
7+
let more_label = match opt_label with
8+
| Some (_label_name, label_pat) -> [%expr Some [%e pat2expr label_pat]]
9+
| None -> [%expr None] in
810
let value = match value with Some c -> [%expr Some [%e c]] | None -> [%expr None] in
911
let values = match values with Some c -> [%expr Some [%e c]] | None -> [%expr None] in
1012
let param_init =
@@ -36,13 +38,13 @@ let make_p ~has_config ~loc ?value ?values ?param_init ~extra_args name =
3638
in
3739
[%expr [%e with_extra_args] ()]
3840

39-
let make_vb ~has_config ?value ?param_init ~extra_args ~loc name =
41+
let make_vb ~opt_label ?value ?param_init ~extra_args ~loc name =
4042
let pat = Ast_helper.Pat.var ~loc:name.loc name in
41-
let v = make_p ~has_config ~loc ?value ?param_init ~extra_args name in
43+
let v = make_p ~opt_label ~loc ?value ?param_init ~extra_args name in
4244
let vb = Ast_helper.Vb.mk ~loc pat v in
4345
(pat, vb)
4446

45-
let make_vb_nd ~has_config ~init_nd ~extra_args ~loc name =
47+
let make_vb_nd ~opt_label ~init_nd ~extra_args ~loc name =
4648
let pat = Ast_helper.Pat.var ~loc:name.loc name in
4749
let values, batch_dims, output_dims, input_dims = ndarray_constant init_nd in
4850
let v =
@@ -59,7 +61,7 @@ let make_vb_nd ~has_config ~init_nd ~extra_args ~loc name =
5961
:: ({ txt = Lident "output_dims"; loc }, output_dims_expr)
6062
:: extra_args
6163
in
62-
make_p ~has_config ~loc ~values ~extra_args name
64+
make_p ~opt_label ~loc ~values ~extra_args name
6365
in
6466
let vb = Ast_helper.Vb.mk ~loc pat v in
6567
(pat, vb)
@@ -80,9 +82,9 @@ let lift_config_vb ~loop ~num_configs ?label ~expr1 ~c_expr arg_exprs =
8082
| [ e2; e3 ] -> [%expr [%e pat2expr pat] [%e e2] [%e e3]]
8183
| _ -> assert false )
8284

83-
let rec translate ~num_configs ~is_toplevel ~has_config ?label expr =
85+
let rec translate ~num_configs ~is_toplevel ~opt_label ?label expr =
8486
let loc = expr.pexp_loc in
85-
let loop = translate ~num_configs ~is_toplevel:false ~has_config in
87+
let loop = translate ~num_configs ~is_toplevel:false ~opt_label in
8688
match expr with
8789
| { pexp_desc = Pexp_constant (Pconst_float _); _ } ->
8890
(no_vbs, [%expr TDSL.number ?label:[%e opt_expr ~loc label] [%e expr]])
@@ -141,7 +143,7 @@ let rec translate ~num_configs ~is_toplevel ~has_config ?label expr =
141143
match first_label.txt with
142144
| Lident tensor_name ->
143145
let name = { loc = first_label.loc; txt = tensor_name } in
144-
let pat, vb = make_vb ~has_config ~value ~extra_args ~loc name in
146+
let pat, vb = make_vb ~opt_label ~value ~extra_args ~loc name in
145147
(Map.singleton (module String) tensor_name vb, pat2expr pat)
146148
| _ ->
147149
( no_vbs,
@@ -160,7 +162,7 @@ let rec translate ~num_configs ~is_toplevel ~has_config ?label expr =
160162
| Lident tensor_name ->
161163
let value = [%expr Float.of_int [%e int_val]] in
162164
let name = { loc = first_label.loc; txt = tensor_name } in
163-
let pat, vb = make_vb ~has_config ~value ~extra_args ~loc name in
165+
let pat, vb = make_vb ~opt_label ~value ~extra_args ~loc name in
164166
(Map.singleton (module String) tensor_name vb, pat2expr pat)
165167
| _ ->
166168
( no_vbs,
@@ -181,7 +183,7 @@ let rec translate ~num_configs ~is_toplevel ~has_config ?label expr =
181183
match first_label.txt with
182184
| Lident tensor_name ->
183185
let name = { loc = first_label.loc; txt = tensor_name } in
184-
let pat, vb = make_vb_nd ~has_config ~init_nd ~extra_args ~loc name in
186+
let pat, vb = make_vb_nd ~opt_label ~init_nd ~extra_args ~loc name in
185187
(* Note: expect a type error if batch_dims exist or extra_args modify the shape *)
186188
(Map.singleton (module String) tensor_name vb, pat2expr pat)
187189
| _ ->
@@ -204,7 +206,7 @@ let rec translate ~num_configs ~is_toplevel ~has_config ?label expr =
204206
(vbs, Some e)
205207
in
206208
let name = { loc = first_label.loc; txt = tensor_name } in
207-
let pat, vb = make_vb ~has_config ?param_init ~extra_args ~loc name in
209+
let pat, vb = make_vb ~opt_label ?param_init ~extra_args ~loc name in
208210
(* Combine with any bindings from the initialization *)
209211
let all_vbs = Map.add_exn init_vbs ~key:tensor_name ~data:vb in
210212
(all_vbs, pat2expr pat)
@@ -258,29 +260,44 @@ let rec translate ~num_configs ~is_toplevel ~has_config ?label expr =
258260
let vbs1, e1 = loop ?label expr1 in
259261
let vbs2, e2 = loop expr2 in
260262
(reduce_vbss [ vbs1; vbs2 ], [%expr [%e e1] [%e e2]])
261-
| {
262-
pexp_desc =
263-
Pexp_function
264-
( ({ pparam_desc = Pparam_val (Labelled "config", c_e, c_pat); _ } as arg) :: args,
265-
constr,
266-
body );
267-
_;
268-
} ->
269-
let vbs, body =
270-
translate ~num_configs ~is_toplevel:true ~has_config:true ?label
271-
{ expr with pexp_desc = Pexp_function (args, constr, body) }
263+
| { pexp_desc = Pexp_function (args, constr, body); _ } when is_toplevel -> (
264+
(* Check if there's a unit parameter or a labeled parameter with label "label" *)
265+
let rec find_unit_pos idx = function
266+
| [] -> None
267+
| { pparam_desc = Pparam_val (Nolabel, _, pat); _ } :: _
268+
when match pat.ppat_desc with
269+
| Ppat_construct ({ txt = Lident "()"; _ }, None) -> true
270+
| _ -> false ->
271+
Some idx
272+
| _ :: rest -> find_unit_pos (idx + 1) rest
272273
in
273-
let body = let_opt ~loc vbs body in
274-
( no_vbs,
275-
{
276-
expr with
277-
pexp_desc =
278-
Pexp_function
279-
( [ { arg with pparam_desc = Pparam_val (Labelled "config", c_e, c_pat) } ],
280-
constr,
281-
Pfunction_body body );
282-
} )
283-
| { pexp_desc = Pexp_function (args, constr, body); _ } when is_toplevel ->
274+
let rec find_label_param = function
275+
| [] -> None
276+
| { pparam_desc = Pparam_val (Labelled "label", _, pat); _ } :: _ -> Some ("label", pat)
277+
| _ :: rest -> find_label_param rest
278+
in
279+
match find_unit_pos 0 args with
280+
| Some unit_idx ->
281+
(* Split args at unit parameter *)
282+
let before_unit, unit_and_after = List.split_n args unit_idx in
283+
let unit_param, after_unit = match unit_and_after with
284+
| unit :: rest -> (unit, rest)
285+
| [] -> failwith "Internal error: unit_and_after should not be empty" in
286+
let opt_label = find_label_param before_unit in
287+
let vbs, inner_body =
288+
translate ~num_configs ~is_toplevel:false ~opt_label ?label
289+
{ expr with pexp_desc = Pexp_function (after_unit, constr, body) }
290+
in
291+
let inner_body = let_opt ~loc vbs inner_body in
292+
(* The inner_body already has after_unit parameters processed, so use it directly *)
293+
let new_body = inner_body in
294+
( no_vbs,
295+
if List.is_empty before_unit then
296+
{ expr with pexp_desc = Pexp_function ([unit_param], constr, Pfunction_body new_body) }
297+
else
298+
{ expr with pexp_desc = Pexp_function (before_unit @ [unit_param], constr, Pfunction_body new_body) } )
299+
| None ->
300+
(* No unit parameter, normal processing *)
284301
let labels =
285302
Option.to_list label
286303
@ List.filter_map args ~f:(function
@@ -323,7 +340,7 @@ let rec translate ~num_configs ~is_toplevel ~has_config ?label expr =
323340
~f:(fun acc vbs -> Map.merge_disjoint_exn acc vbs),
324341
Pfunction_cases (cases, loc, attrs) )
325342
in
326-
(vbs, { expr with pexp_desc = Pexp_function (args, constr, body) })
343+
(vbs, { expr with pexp_desc = Pexp_function (args, constr, body) }) )
327344
| { pexp_desc = Pexp_function (args, constr, body); _ } ->
328345
let vbs, body =
329346
match body with
@@ -422,7 +439,7 @@ let rec translate ~num_configs ~is_toplevel ~has_config ?label expr =
422439

423440
let translate ?ident_label expr =
424441
let vbs, expr =
425-
translate ~num_configs:(ref 0) ~is_toplevel:true ~has_config:false
442+
translate ~num_configs:(ref 0) ~is_toplevel:true ~opt_label:None
426443
~label:(opt_pat2string_list ~loc:expr.pexp_loc ident_label)
427444
expr
428445
in

lib/syntax_extensions.md

Lines changed: 25 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -307,12 +307,10 @@ Inline declarations can also be used outside of assignments for creating non-dif
307307
in
308308
```
309309

310-
For `%op`, the declaration is allowed anywhere. If there is a `~config` function parameter used inside the extension scope, for example as `fun ~config ... -> ...` or a more specific example `let%op mlp ~config x = ...`, the scope of an inline-declared tensor is no longer the full scope of the extension point. Instead, the tensor is defined right underneath the introduction of the `~config` parameter: `fun ~config -> let <definitions of the inline-declared tensors> in ...`. The config value passed to the generated code must be a record with at least a field `label : string list`. The inline-declared tensor that's defined under a `~config` parameter is defined as `TDSL.param ~more_label:config.label ...` Example showing two param tensors declared inline, including `config.label` in their labels:
310+
For `%op`, the declaration is allowed anywhere. If the function has a unit `()` parameter, for example `let%op mlp_layer ~label ~hid_dim () x = ...`, the scope of an inline-declared tensor is no longer the full scope of the extension point. Instead, the tensors are defined right after the unit parameter: `fun ... () -> let <definitions of the inline-declared tensors> in ...`. If there is a labeled parameter named `label` before the unit parameter (e.g., `~label`), it must be of type `string list`, and each inline-declared tensor uses it to enrich its label: the tensor is defined as `TDSL.param ~more_label:label ...`. Example showing two param tensors declared inline, with scope delimited by `()` and labels enriched by the `label` parameter:
311311

312312
```ocaml
313-
type mlp_layer_config = { label : string list; hid_dim : int }
314-
315-
let%op mlp_layer ~config x = relu ({ w } * x + { b; o = [ config.hid_dim ] })
313+
let%op mlp_layer ~label ~hid_dim () x = relu ({ w } * x + { b; o = [ hid_dim ] })
316314
```
317315

318316
## Using OCANNL's generalized einsum notation
@@ -418,13 +416,15 @@ This syntax used to be very important, because comments in assignments are used
418416

419417
### Name from binding
420418

421-
When an extension point is applied to a let-binding, e.g. `let%op mlp_layer ~config x = relu ({ w } * x + { b; o = [ config.hid_dim ] })`, it uses the name of the binding (`mlp_layer` in the example) for the label of the primary tensor created by the extension, if any. This is why the resulting layer tensor in the example has its label starting with `"mlp_layer"`. If the extension is over a semicolon-separated sequence of expressions, the primary tensor can only be in the last component of the sequence, other syntax constructs are handled analogously.
419+
When an extension point is applied to a let-binding, e.g. `let%op mlp_layer ~label ~hid_dim () x = relu ({ w } * x + { b; o = [ hid_dim ] })`, it uses the name of the binding (`mlp_layer` in the example) for the label of the primary tensor created by the extension, if any. This is why the resulting layer tensor in the example has its label starting with `"mlp_layer"`. If the extension is over a semicolon-separated sequence of expressions, the primary tensor can only be in the last component of the sequence; other syntax constructs are handled analogously.
420+
421+
The example `let%op mlp_layer ~label ~hid_dim () x = relu ({ w } * x + { b; o = [ hid_dim ] })` also illustrates providing an additional string list to populate the label of the tensor: the `~label` parameter must be of type `string list`.
422422

423423
### Label from function argument
424424

425425
The resulting (primary) tensor's label will also have incorporated the label of the input argument, if any. In our example, the resulting `mlp_layer` tensor will also include the label of the actually applied `x`.
426426

427-
Note that we do not include `config.label`, even if `config` is available, because the actually applied input argument will typically have more specific information.
427+
Note that we do not include separate config labels, because the actually applied input argument will typically have more specific information.
428428

429429
### Configuring inline declarations: inline output dimensions, initial values
430430

@@ -449,59 +449,51 @@ A very simple example from [micrograd_demo: Micrograd README basic example](test
449449

450450
### Lifting of the applications of config arguments: if an error, refactor your code
451451

452-
If you recall, inline declared param tensors get lifted out of functions except for the function `fun ~config ->`, where they get defined. Our example `let%op mlp_layer ~config x = relu ({ w } * x + { b; o = [ config.hid_dim ] })` translates as:
452+
If you recall, inline-declared param tensors get lifted out of functions, except that they do not cross a unit `()` parameter: they are defined right after it. Our example `let%op mlp_layer ~label ~hid_dim () x = relu ({ w } * x + { b; o = [ hid_dim ] })` translates as:
453453

454454
```ocaml
455-
let mlp_layer ~config =
456-
let w = TDSL.param ~more_label:config.label "w" ()
457-
and b = TDSL.param ~more_label:config.label ~output_dims:[ config.hid_dim ] "b" () in
455+
let mlp_layer ~label ~hid_dim () =
456+
let w = TDSL.param ~more_label:label "w" ()
457+
and b = TDSL.param ~more_label:label ~output_dims:[ hid_dim ] "b" () in
458458
fun x -> TDSL.O.(relu (w * x + b))
459459
```
460460

461461
For this to work properly, when employing such network blocks, their params also need to be introduced at the right moment. Therefore, the `%op` syntax ensures that this example:
462462

463463
```ocaml
464-
type tlp_config = { label : string list; dim1 : int; dim2 : int; dim3 : int }
465-
466-
let%op three_layer_perceptron ~config x =
467-
mlp_layer ~config:{ label = [ "L3" ]; hid_dim = config.dim3 }
468-
(mlp_layer ~config:{ label = [ "L2" ]; hid_dim = config.dim2 }
469-
(mlp_layer ~config:{ label = [ "L1" ]; hid_dim = config.dim1 } x))
464+
let%op three_layer_perceptron ~label ~dim1 ~dim2 ~dim3 () x =
465+
mlp_layer ~label:[ "L3" ] ~hid_dim:dim3 ()
466+
(mlp_layer ~label:[ "L2" ] ~hid_dim:dim2 ()
467+
(mlp_layer ~label:[ "L1" ] ~hid_dim:dim1 () x))
470468
```
471469

472470
gets expanded to:
473471

474472
```ocaml
475-
type tlp_config = { label : string list; dim1 : int; dim2 : int; dim3 : int }
476-
477-
let three_layer_perceptron ~config =
478-
let config_block__1 = mlp_layer ~config:{ label = [ "L3" ]; hid_dim = config.dim3 }
479-
and config_block__2 = mlp_layer ~config:{ label = [ "L2" ]; hid_dim = config.dim2 }
480-
and config_block__3 = mlp_layer ~config:{ label = [ "L1" ]; hid_dim = config.dim1 } in
473+
let three_layer_perceptron ~label ~dim1 ~dim2 ~dim3 () =
474+
let config_block__1 = mlp_layer ~label:[ "L3" ] ~hid_dim:dim3 ()
475+
and config_block__2 = mlp_layer ~label:[ "L2" ] ~hid_dim:dim2 ()
476+
and config_block__3 = mlp_layer ~label:[ "L1" ] ~hid_dim:dim1 () in
481477
fun x -> config_block__1 (config_block__2 (config_block__3 x))
482478
```
483479

484480
However, this raises a concern for more complex situations. Consider this code that fails to compile:
485481

486482
```ocaml
487-
type mlp_config = { label : string list; hid_dims : int list }
488-
489-
let%op mlp ~config x =
490-
List.foldi config.hid_dims ~init:x ~f:(fun i x hid_dim ->
491-
mlp_layer ~config:{ label = [ "L" ^ Int.to_string i ]; hid_dim } x)
483+
let%op mlp ~label ~hid_dims () x =
484+
List.foldi hid_dims ~init:x ~f:(fun i x hid_dim ->
485+
mlp_layer ~label:[ "L" ^ Int.to_string i ] ~hid_dim () x)
492486
```
493487

494488
The attempted lifting breaks because of the escaping variables `i` and `hid_dim`. This reminds us to rewrite the example, ensuring the proper introduction of params:
495489

496490
```ocaml
497-
type mlp_config = { label : string list; hid_dims : int list }
498-
499-
let mlp ~config =
491+
let mlp ~label ~hid_dims =
500492
let layers =
501-
List.mapi config.hid_dims ~f:(fun i hid_dim ->
502-
mlp_layer ~config:{ label = [ "L" ^ Int.to_string i ]; hid_dim })
493+
List.mapi hid_dims ~f:(fun i hid_dim ->
494+
mlp_layer ~label:[ "L" ^ Int.to_string i ] ~hid_dim ())
503495
in
504-
fun x -> List.fold layers ~init:x ~f:(fun x layer -> layer x)
496+
fun () x -> List.fold layers ~init:x ~f:(fun x layer -> layer x)
505497
```
506498

507499
Unfortunately, we need to be mindful to introduce params at the right times.

test/ppx/test_ppx_op.ml

Lines changed: 8 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -21,21 +21,17 @@ let z3 =
2121

2222
let () = ignore (y0, y1, y2, a, b, y, z, z2, z3)
2323

24-
type mlp_layer_config = { label : string list; hid_dim : int }
25-
26-
let%op mlp_layer ~config x = relu (({ w } * x) + { b; o = [ config.hid_dim ] })
24+
let%op mlp_layer ~label ~hid_dim () x = relu (({ w } * x) + { b; o = [ hid_dim ] })
2725

2826
let%op _use_layer x =
29-
mlp_layer ~config:{ label = [ "L" ]; hid_dim = 3 }
30-
(mlp_layer ~config:{ label = [ "L2" ]; hid_dim = 3 } x)
31-
32-
let%op _config_layer ~config:_ x = mlp_layer ~config:{ label = [ "L" ]; hid_dim = 3 } x
27+
mlp_layer ~label:[ "L" ] ~hid_dim:3 ()
28+
(mlp_layer ~label:[ "L2" ] ~hid_dim:3 () x)
3329

34-
type tlp_config = { label : string list; dim1 : int; dim2 : int; dim3 : int }
30+
let%op _config_layer ~config:_ x = mlp_layer ~label:[ "L" ] ~hid_dim:3 () x
3531

36-
let%op _three_layer_perceptron ~(config : tlp_config) x =
32+
let%op _three_layer_perceptron ~label ~dim1 ~dim2 ~dim3 () x =
3733
mlp_layer
38-
~config:{ label = "L3" :: config.label; hid_dim = config.dim3 }
34+
~label:(label @ [ "L3" ]) ~hid_dim:dim3 ()
3935
(mlp_layer
40-
~config:{ label = "L2" :: config.label; hid_dim = config.dim2 }
41-
(mlp_layer ~config:{ label = "L1" :: config.label; hid_dim = config.dim1 } x))
36+
~label:(label @ [ "L2" ]) ~hid_dim:dim2 ()
37+
(mlp_layer ~label:(label @ [ "L1" ]) ~hid_dim:dim1 () x))

0 commit comments

Comments
 (0)