Skip to content

Commit d07e9fe

Browse files
committed
New syntax extension: %%extend_dsls
1 parent c381186 commit d07e9fe

File tree

7 files changed

+183
-1
lines changed

7 files changed

+183
-1
lines changed

docs/syntax_extensions.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
- [Label from function argument](#label-from-function-argument)
1919
- [Configuring inline declarations: inline output dimensions, initial values](#configuring-inline-declarations-inline-output-dimensions-initial-values)
2020
- [Lifting of the applications of config arguments: if an error, refactor your code](#lifting-of-the-applications-of-config-arguments-if-an-error-refactor-your-code)
21+
- [The syntax extension %%extend_dsls](#the-syntax-extension-extend_dsls)
2122
- [Implementation details](#implementation-details)
2223
- [The hard-coded to-the-power-of operator](#the-hard-coded-to-the-power-of-operator)
2324
- [Intricacies of the syntax extension %cd](#intricacies-of-the-syntax-extension-cd)
@@ -528,6 +529,10 @@ let mlp ~label ~hid_dims () =
528529
fun x -> List.fold layers ~init:x ~f:(fun x layer -> layer x)
529530
```
530531

532+
## The syntax extension %%extend_dsls
533+
534+
This syntax extension creates a module `DSL_modules` with the same submodules as `Operation.DSL_modules`. It removes the boilerplate associated with introducing new operators into the modules `TDSL`, `NTDSL`, `PDSL` and their `O` submodules. The payload (i.e. content) of `%%extend_dsls` must be non-recursive let-bindings. They are parsed using a slight variant of the `%op` syntax, and are inserted into the DSL modules. One unique feature of `%%extend_dsls` parsing is that inline tensor definitions, like in `%cd`, do not introduce gradients for the tensors, but, like `%op`, they do introduce initialization for the inline-defined tensors.
535+
531536
## Implementation details
532537

533538
### The hard-coded to-the-power-of operator

tensor/dune

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
(name ppx_ocannl)
3737
(public_name neural_nets_lib.ppx_ocannl)
3838
(kind ppx_rewriter)
39-
(modules ppx_shared ppx_cd ppx_op ppx_ocannl)
39+
(modules ppx_shared ppx_cd ppx_op ppx_extend_dsls ppx_ocannl)
4040
(libraries base ppxlib str ppx_arrayjit)
4141
(preprocess
4242
(pps ppxlib.metaquot)))

tensor/ppx_extend_dsls.ml

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,116 @@
1+
open Base
2+
open Ppxlib
3+
open Ppx_shared
4+
5+
(* Helper to transform a structure item by opening the DSL operators *)
6+
let transform_dsl_binding ~loc ~dsl_name binding =
7+
let transform_expr expr =
8+
let vbs, result =
9+
Ppx_op.translate ~no_grads_for_inline_defs:true
10+
@@ add_module_qualifier_to_applied_function ~module_name:dsl_name expr
11+
in
12+
if Map.is_empty vbs then result
13+
else
14+
Ast_builder.Default.pexp_extension ~loc
15+
@@ Location.error_extensionf ~loc
16+
"%%extend_dsls functions with inline definitions must take a unit parameter"
17+
in
18+
let params, pvb_expr =
19+
match binding.pvb_expr with
20+
| { pexp_desc = Pexp_function (params, _, _); _ } as expr -> (params, transform_expr expr)
21+
| _ -> ([], transform_expr binding.pvb_expr)
22+
in
23+
(params, { binding with pvb_expr })
24+
25+
(* Module-level expansion: create module bindings for TDSL, NTDSL, PDSL *)
26+
let str_expander ~loc:pstr_loc ~path:_ str_items =
27+
let transform_op_binding params binding =
28+
let loc = binding.pvb_loc in
29+
let label_p =
30+
{ pparam_loc = loc; pparam_desc = Pparam_val (Optional "label", None, [%pat? label]) }
31+
in
32+
let f = function
33+
| { pparam_desc = Pparam_val (label, _, pat); _ } -> (label, pat2expr pat)
34+
| _ -> assert false
35+
in
36+
let args = List.map params ~f in
37+
let body =
38+
Ast_helper.Exp.apply ~loc (pat2expr binding.pvb_pat)
39+
(args @ [ (Optional "label", [%expr label]); (Nolabel, [%expr ()]) ])
40+
in
41+
let pvb_expr =
42+
{
43+
binding.pvb_expr with
44+
pexp_desc = Pexp_function (label_p :: params, None, Pfunction_body body);
45+
}
46+
in
47+
{ binding with pvb_expr }
48+
in
49+
let items_for_dsl dsl_name =
50+
let item_bindings, op_item_bindings =
51+
List.unzip
52+
@@ List.concat_map str_items ~f:(function
53+
| { pstr_desc = Pstr_value (Nonrecursive, bindings); pstr_loc = loc; _ } ->
54+
List.map bindings ~f:(fun binding ->
55+
let params, binding = transform_dsl_binding ~loc ~dsl_name binding in
56+
let op_binding = transform_op_binding params binding in
57+
(binding, op_binding))
58+
| { pstr_loc = loc; _ } ->
59+
let pat = Ast_helper.Pat.var ~loc { txt = "syntax_error"; loc } in
60+
let v =
61+
Ast_builder.Default.pexp_extension ~loc
62+
@@ Location.error_extensionf ~loc
63+
"ppx_extend_dsls: currently only non-recursive value bindings are supported"
64+
in
65+
[ (Ast_helper.Vb.mk ~loc pat v, Ast_helper.Vb.mk ~loc pat v) ])
66+
in
67+
let item = { pstr_desc = Pstr_value (Nonrecursive, item_bindings); pstr_loc } in
68+
let op_item = { pstr_desc = Pstr_value (Nonrecursive, op_item_bindings); pstr_loc } in
69+
(item, op_item)
70+
in
71+
let loc = pstr_loc in
72+
let item_TDSL, op_item_TDSL = items_for_dsl "TDSL" in
73+
let item_NTDSL, op_item_NTDSL = items_for_dsl "NTDSL" in
74+
let item_PDSL, op_item_PDSL = items_for_dsl "PDSL" in
75+
[%stri
76+
module DSL_modules = struct
77+
module Ir = Ir
78+
module Shape = Shape
79+
module Tensor = Tensor
80+
81+
module TDSL = struct
82+
include TDSL
83+
84+
[%%i item_TDSL]
85+
86+
module O = struct
87+
include TDSL.O
88+
89+
[%%i op_item_TDSL]
90+
end
91+
end
92+
93+
module NTDSL = struct
94+
include NTDSL
95+
96+
[%%i item_NTDSL]
97+
98+
module O = struct
99+
include NTDSL.O
100+
101+
[%%i op_item_NTDSL]
102+
end
103+
end
104+
105+
module PDSL = struct
106+
include PDSL
107+
108+
[%%i item_PDSL]
109+
110+
module O = struct
111+
include PDSL.O
112+
113+
[%%i op_item_PDSL]
114+
end
115+
end
116+
end]

tensor/ppx_ocannl.ml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,10 @@ let rules =
1818
@@ Extension.declare "op" Extension.Context.structure_item
1919
Ast_pattern.(pstr __)
2020
Ppx_op.str_expander;
21+
Ppxlib.Context_free.Rule.extension
22+
@@ Extension.declare "extend_dsls" Extension.Context.structure_item
23+
Ast_pattern.(pstr __)
24+
Ppx_extend_dsls.str_expander;
2125
]
2226

2327
let () = Driver.register_transformation ~rules "ppx_ocannl"

test/operations/dune

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -314,6 +314,17 @@
314314
(preprocess
315315
(pps ppx_here ppx_ocannl ppx_sexp_conv)))
316316

317+
(test
318+
(name test_extend_dsls)
319+
(package neural_nets_lib)
320+
(deps
321+
ocannl_config
322+
(env_var OCANNL_BACKEND))
323+
(modules test_extend_dsls)
324+
(libraries base ocannl stdio)
325+
(preprocess
326+
(pps ppx_here ppx_ocannl ppx_expect)))
327+
317328
(library
318329
(name operations_tutorials)
319330
(package neural_nets_lib)
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
Retrieving commandline, environment, or config file variable ocannl_log_level
2+
Found 0, in the config file
3+
HERE: test/operations/test_extend_dsls.ml:23:21
4+
[23]: relu shape 0:$80 <not-in-yet>
5+
6+
HERE: test/operations/test_extend_dsls.ml:29:21
7+
[46]: relu shape 0:$161 <not-in-yet>
8+
9+
HERE: test/operations/test_extend_dsls.ml:35:21
10+
[75]: relu shape 0:$242 <not-in-yet>
11+
grad_relu <not-in-yet>
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
open! Base
2+
open Ocannl.Operation.DSL_modules
3+
4+
(* Test module-level expansion *)
5+
[%%extend_dsls
6+
let my_tensor x = TDSL.O.( + ) x !.2.0 + !.1.0
7+
8+
let my_complex_op a b c =
9+
let d = a * b in
10+
let e = d + c in
11+
relu e
12+
13+
let my_random_op () = { u1 = uniform () } - { u2 = uniform () }]
14+
15+
(* The above should create three modules: TDSL, NTDSL, and PDSL, each containing the functions with
16+
appropriate DSL operators *)
17+
18+
let () =
19+
(* Test that the functions are available in each DSL module *)
20+
let open! DSL_modules.TDSL.O in
21+
let x = uniform () in
22+
let result = my_complex_op (my_tensor x) (my_random_op ()) x in
23+
Tensor.print ~here:[%here] ~force:false ~with_code:false ~with_grad:true `Inline result
24+
25+
let () =
26+
let open! DSL_modules.NTDSL.O in
27+
let x = uniform () in
28+
let result = my_complex_op (my_tensor x) (my_random_op ()) x in
29+
Tensor.print ~here:[%here] ~force:false ~with_code:false ~with_grad:true `Inline result
30+
31+
let () =
32+
let open! DSL_modules.PDSL.O in
33+
let x = uniform () in
34+
let result = my_complex_op (my_tensor x) (my_random_op ()) x in
35+
Tensor.print ~here:[%here] ~force:false ~with_code:false ~with_grad:true `Inline result

0 commit comments

Comments
 (0)