Skip to content

Commit e05af9c

Browse files
committed
Migrate to ppxlib 0.36: ppx_cd and cleanup
1 parent d857a40 commit e05af9c

File tree

7 files changed

+111
-46
lines changed

7 files changed

+111
-46
lines changed

arrayjit.opam

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ depends: [
2525
"sexplib"
2626
"num"
2727
"saturn_lockfree" {>= "0.5.0"}
28-
"ppxlib"
28+
"ppxlib" {>= "0.36.0"}
2929
"ppx_compare"
3030
"ppx_hash"
3131
"ppx_here"

dune-project

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@
5656
num
5757
(saturn_lockfree
5858
(>= 0.5.0))
59-
ppxlib
59+
(ppxlib (>= 0.36.0))
6060
ppx_compare
6161
ppx_hash
6262
ppx_here
@@ -116,7 +116,7 @@
116116
curl
117117
time_now
118118
camlzip
119-
ppxlib
119+
(ppxlib (>= 0.36.0))
120120
ppx_compare
121121
ppx_fields_conv
122122
ppx_hash

lib/ppx_cd.ml

Lines changed: 88 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -365,6 +365,31 @@ let args_for ~loc = function
365365

366366
let reduce_res_vbs rs = reduce_vbss @@ List.map rs ~f:(fun r -> r.vbs)
367367

368+
(** Helper function to handle cases (for Pexp_match, Pexp_function with cases, etc.) *)
369+
let handle_cases ~bad_pun_hints ~proj_in_scope transl cases =
370+
let fields, transformed_cases =
371+
List.unzip
372+
@@ List.map cases ~f:(fun ({ pc_rhs; _ } as c) ->
373+
let res = transl ~bad_pun_hints ~proj_in_scope pc_rhs in
374+
((res.vbs, res.typ, res.slot), { c with pc_rhs = res.expr }))
375+
in
376+
let vbss, typs, slots = List.unzip3 fields in
377+
(* TODO: make the inference of typ and slot more strict by detecting mismatches. *)
378+
let typ = Option.value ~default:Unknown @@ List.find typs ~f:(Fn.non is_unknown) in
379+
let slot =
380+
Option.value ~default:Undet @@ List.find ~f:(function Undet -> false | _ -> true) slots
381+
in
382+
let loc = (List.hd_exn cases).pc_lhs.ppat_loc in
383+
( transformed_cases,
384+
{
385+
vbs = reduce_vbss vbss;
386+
typ;
387+
slot;
388+
expr = [%expr ()];
389+
(* This will be replaced by the caller *)
390+
array_opt_of_code = None;
391+
} )
392+
368393
let translate (expr : expression) : result =
369394
let punned = Hashtbl.create (module String) in
370395
let rec transl ~bad_pun_hints ~proj_in_scope (expr : expression) : result =
@@ -736,15 +761,26 @@ let translate (expr : expression) : result =
736761
| { pexp_desc = Pexp_ident { txt = Lident op_ident; _ }; _ } when is_primitive_op op_ident ->
737762
default_result
738763
| [%expr !.[%e? expr1]] ->
739-
(* Hardcoding these two patterns to improve projection derivation expressivity. *)
740-
let res1 = loop ~proj_in_scope expr1 in
741-
{ res1 with typ = Tensor; slot = Scalar; expr = [%expr NTDSL.O.( !. ) [%e res1.expr]] }
764+
(* Hardcoding these two patterns (!. and !..) to improve projection derivation expressivity
765+
and avoid treating the constants as already tensors. *)
766+
{
767+
typ = Tensor;
768+
slot = Scalar;
769+
expr = [%expr NTDSL.O.( !. ) [%e expr1]];
770+
array_opt_of_code = None;
771+
vbs = no_vbs;
772+
}
742773
| [%expr !..[%e? expr1]] ->
743-
let res1 = loop ~proj_in_scope expr1 in
744-
{ res1 with typ = Tensor; slot = Scalar; expr = [%expr NTDSL.O.( !.. ) [%e res1.expr]] }
774+
{
775+
typ = Tensor;
776+
slot = Scalar;
777+
expr = [%expr NTDSL.O.( !.. ) [%e expr1]];
778+
array_opt_of_code = None;
779+
vbs = no_vbs;
780+
}
745781
| [%expr [%e? expr1] **. [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
746-
(* FIXME: `**.` should take a tensor and require that it's a literal. *)
747-
(* We need to hardcode these two patterns to prevent the numbers from being converted to tensors. *)
782+
(* We need to hardcode these two patterns (for **. ) to prevent the numbers from
783+
being converted to tensors. *)
748784
let res1 = loop ~proj_in_scope expr1 in
749785
{
750786
res1 with
@@ -1135,17 +1171,49 @@ let translate (expr : expression) : result =
11351171
expr = [%expr [%e res1.expr] [%e res2.expr]];
11361172
array_opt_of_code = None;
11371173
}
1138-
| { pexp_desc = Pexp_fun ((arg_label : arg_label), arg, pat, expr1); _ } as expr ->
1174+
| { pexp_desc = Pexp_function (args, constr, body); _ } as expr ->
11391175
let proj_in_scope =
11401176
proj_in_scope
1141-
||
1142-
match arg_label with
1143-
| (Labelled s | Optional s) when String.equal s "projections" -> true
1144-
| _ -> false
1177+
|| List.exists args ~f:(function
1178+
| { pparam_desc = Pparam_val ((Labelled s | Optional s), _, _); _ }
1179+
when String.equal s "projections" ->
1180+
true
1181+
| _ -> false)
1182+
in
1183+
let bad_pun_hints =
1184+
Set.union_list (module String)
1185+
@@ bad_pun_hints
1186+
:: List.map args ~f:(fun arg ->
1187+
match arg.pparam_desc with
1188+
| Pparam_val (_, _, pat) -> collect_pat_idents pat
1189+
| _ -> Set.empty (module String))
1190+
in
1191+
let result =
1192+
match body with
1193+
| Pfunction_body body ->
1194+
let res = transl ~bad_pun_hints ~proj_in_scope body in
1195+
{
1196+
res with
1197+
expr =
1198+
{ expr with pexp_desc = Pexp_function (args, constr, Pfunction_body res.expr) };
1199+
}
1200+
| Pfunction_cases (cases, loc, attrs) ->
1201+
let transformed_cases, cases_result =
1202+
handle_cases ~bad_pun_hints ~proj_in_scope
1203+
(fun ~bad_pun_hints ~proj_in_scope -> transl ~bad_pun_hints ~proj_in_scope)
1204+
cases
1205+
in
1206+
{
1207+
cases_result with
1208+
expr =
1209+
{
1210+
expr with
1211+
pexp_desc =
1212+
Pexp_function (args, constr, Pfunction_cases (transformed_cases, loc, attrs));
1213+
};
1214+
}
11451215
in
1146-
let bad_pun_hints = Set.union bad_pun_hints @@ collect_pat_idents pat in
1147-
let res1 = transl ~bad_pun_hints ~proj_in_scope expr1 in
1148-
{ res1 with expr = { expr with pexp_desc = Pexp_fun (arg_label, arg, pat, res1.expr) } }
1216+
result
11491217
| [%expr
11501218
while [%e? _test_expr] do
11511219
[%e? _body]
@@ -1222,26 +1290,13 @@ let translate (expr : expression) : result =
12221290
array_opt_of_code = res2.array_opt_of_code;
12231291
}
12241292
| { pexp_desc = Pexp_match (expr1, cases); _ } ->
1225-
let fields, cases =
1226-
List.unzip
1227-
@@ List.map cases ~f:(fun ({ pc_rhs; _ } as c) ->
1228-
let res = loop ~proj_in_scope pc_rhs in
1229-
((res.vbs, res.typ, res.slot), { c with pc_rhs = res.expr }))
1293+
let transformed_cases, cases_result =
1294+
handle_cases ~bad_pun_hints ~proj_in_scope transl cases
12301295
in
1231-
let vbss, typs, slots = List.unzip3 fields in
1232-
let typ = Option.value ~default:Unknown @@ List.find typs ~f:(Fn.non is_unknown) in
1233-
let slot =
1234-
Option.value ~default:Undet @@ List.find ~f:(function Undet -> false | _ -> true) slots
1235-
in
1236-
{
1237-
vbs = reduce_vbss vbss;
1238-
typ;
1239-
slot;
1240-
expr = { expr with pexp_desc = Pexp_match (expr1, cases) };
1241-
array_opt_of_code = None;
1242-
}
1296+
{ cases_result with expr = { expr with pexp_desc = Pexp_match (expr1, transformed_cases) } }
12431297
| { pexp_desc = Pexp_let (_recflag, _bindings, _body); _ } ->
1244-
(* TODO(80): to properly support local bindings, we need to collect the type environment. *)
1298+
(* TODO(#80): to properly support local bindings, we need to collect the type
1299+
environment. *)
12451300
{
12461301
default_result with
12471302
typ = Unknown;

lib/ppx_op.ml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -143,9 +143,12 @@ let rec translate ~num_configs ~is_toplevel ~has_config ?label expr =
143143
| { pexp_desc = Pexp_array _; _ }
144144
| { pexp_desc = Pexp_construct ({ txt = Lident "::"; _ }, _); _ } ->
145145
(no_vbs, ndarray_op ?label expr)
146+
| [%expr !.[%e? expr1]] ->
147+
(* Hardcoding the patterns for (!.), (!..), and ( **. ) to avoid treating the constants as
148+
already tensors. *)
149+
(no_vbs, [%expr TDSL.O.( !. ) [%e expr1]])
150+
| [%expr !..[%e? expr1]] -> (no_vbs, [%expr TDSL.O.( !.. ) [%e expr1]])
146151
| [%expr [%e? expr1] **. [%e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
147-
(* We need to hardcode these two patterns to prevent the numbers from being converted to
148-
tensors. *)
149152
let vbs, e1 = loop expr1 in
150153
(vbs, [%expr TDSL.O.( **. ) ?label:[%e opt_expr ~loc label] [%e e1] (Float.of_int [%e i])])
151154
| [%expr [%e? expr1] **. [%e? expr2]] ->

lib/train.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -256,7 +256,7 @@ let%track3_sexp parallel_update (type buffer_ptr dev runner event)
256256
and type dev = dev
257257
and type runner = runner
258258
and type event = event) ~(grad_updates : Backend.context BT.routine array)
259-
~(sgd_update : Backend.context BT.routine) ~copy_to_merge ~post_sync updaten : unit -> unit =
259+
~(sgd_update : Backend.context BT.routine) ~copy_to_merge ~post_sync updaten =
260260
assert (not @@ Array.is_empty grad_updates);
261261
let num_streams : int = Array.length grad_updates in
262262
let bindings : Idx.static_symbol list = List.map ~f:fst sgd_update.bindings in

neural_nets_lib.opam

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ depends: [
2828
"curl"
2929
"time_now"
3030
"camlzip"
31-
"ppxlib"
31+
"ppxlib" {>= "0.36.0"}
3232
"ppx_compare"
3333
"ppx_fields_conv"
3434
"ppx_hash"

test_ppx/test_ppx_op_expected.ml

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,19 @@ let y1 =
1111
let hey2 = TDSL.param ?values:None "hey2" in
1212
let open! TDSL.O in
1313
fun x ->
14-
((+) ?label:(Some (["y1"] @ (x.Tensor.value).Ir.Tnode.label)))
14+
((+) ?label:(Some
15+
(List.concat [["y1"]; (x.Tensor.value).Ir.Tnode.label])))
1516
((( * ) ?label:None) hey2 (TDSL.number (Float.of_int 2))) x
1617
let y2 =
1718
let hey3 = TDSL.param ?values:None "hey3" in
1819
let open! TDSL.O in
19-
fun x1 ->
20-
fun x2 ->
21-
((+) ?label:(Some (["y2"] @ (x1.Tensor.value).Ir.Tnode.label)))
22-
((( *. ) ?label:None) x1 hey3) x2
20+
fun x1 x2 ->
21+
((+) ?label:(Some
22+
(List.concat
23+
[["y2"];
24+
(x1.Tensor.value).Ir.Tnode.label;
25+
(x2.Tensor.value).Ir.Tnode.label])))
26+
((( *. ) ?label:None) x1 hey3) x2
2327
let a =
2428
let open! TDSL.O in
2529
TDSL.ndarray ?label:(Some ["a"]) ~batch_dims:[] ~input_dims:[3]
@@ -56,7 +60,10 @@ let mlp_layer =
5660
"b"
5761
and w = (TDSL.param ~more_label:(config.label)) ?values:None "w" in
5862
fun x ->
59-
(relu ?label:(Some (["mlp_layer"] @ (x.Tensor.value).Ir.Tnode.label)))
63+
(relu
64+
?label:(Some
65+
(List.concat
66+
[["mlp_layer"]; (x.Tensor.value).Ir.Tnode.label])))
6067
(((+) ?label:None) ((( * ) ?label:None) w x) b)
6168
let _use_layer =
6269
let config_block__0 = mlp_layer ~config:{ label = ["L2"]; hid_dim = 3 }

0 commit comments

Comments
 (0)