@@ -14,6 +14,7 @@ type expr_type =
1414 | Merge_value of expression
1515 | Merge_grad of expression
1616 | No_grad_tensor_intro of { name : string ; name_expr : expression }
17+ | Function
1718
1819let is_unknown = function Unknown -> true | _ -> false
1920
@@ -62,6 +63,8 @@ let make_vb ~loc ~name ~name_expr ~hint_label =
6263 in
6364 let vb = A.Vb. mk ~loc pat v in
6465 vb
66+ (* let make_code ~loc ~name ~name_expr ~hint_label code_expr = [%expr { asgns = [%e code_expr];
67+ embedded_nodes = Base.Set.empty (module Ir.Tnode) }] *)
6568
6669let reduce_embs_arr ~loc (rs : array_setup list ) =
6770 List. filter_map rs ~f: (fun hs -> hs.fwd_code_or_noop)
@@ -152,7 +155,7 @@ let guess_pun_hint ~no_filler_label ~punned ~bad_pun_hints filler_typ filler =
152155 let loc = filler.pexp_loc in
153156 let hint = [% expr [% e filler].Ir.Tnode. label] in
154157 match (filler_typ, filler, no_filler_label) with
155- | Code , _ , _ -> None
158+ | ( Code | Function ) , _ , _ -> None
156159 | _ , { pexp_desc = Pexp_ident { txt = Lident name ; _ } ; _ } , _ when Set. mem bad_pun_hints name ->
157160 None
158161 | Array , _ , false -> Some (hint, false )
@@ -296,6 +299,15 @@ let setup_array ~punned ~bad_pun_hints ~is_lhs
296299 }]
297300 in
298301 { (default_setup false ) with fwd_code_or_noop; tensor = Some filler }
302+ | _ , Function ->
303+ {
304+ (default_setup false ) with
305+ fwd_code_or_noop = Some filler;
306+ array_opt =
307+ Ast_builder.Default. pexp_extension ~loc
308+ @@ Location. error_extensionf ~loc
309+ " ppx_ocannl %%cd: a syntactic function in place of an array is not supported" ;
310+ }
299311 | _ , Code when Option. is_none array_opt_of_code ->
300312 {
301313 (default_setup false ) with
@@ -332,7 +344,11 @@ let setup_array ~punned ~bad_pun_hints ~is_lhs
332344 @@ Location. error_extensionf ~loc " ppx_ocannl %%cd: merge buffers cannot be assigned to" ;
333345 }
334346 | _ , Merge_value t ->
335- { (default_setup false ) with array_opt = [% expr Some (Merge_buffer [% e filler])]; tensor = Some t }
347+ {
348+ (default_setup false ) with
349+ array_opt = [% expr Some (Merge_buffer [% e filler])];
350+ tensor = Some t;
351+ }
336352 | _ , Merge_grad t ->
337353 {
338354 (default_setup false ) with
@@ -388,7 +404,7 @@ let handle_cases ~bad_pun_hints ~proj_in_scope transl cases =
388404 array_opt_of_code = None ;
389405 } )
390406
391- let translate (expr : expression ) : result =
407+ let translate ? ident_label (expr : expression ) : result =
392408 let punned = Hashtbl. create (module String ) in
393409 let rec transl ~bad_pun_hints ~proj_in_scope (expr : expression ) : result =
394410 let loc = expr.pexp_loc in
@@ -778,9 +794,10 @@ let translate (expr : expression) : result =
778794 slot = Scalar ;
779795 }
780796 | { pexp_desc = Pexp_constant (Pconst_string (name , str_loc , _ )); _ } ->
781- (* TODO: consider passing toplevel binding name as a hint label *)
782797 let vbs =
783- Map. singleton (module String ) name @@ make_vb ~loc ~name ~name_expr: expr ~hint_label: None
798+ Map. singleton (module String ) name
799+ @@ make_vb ~loc ~name ~name_expr: expr
800+ ~hint_label: (Option. map ~f: (fun s -> [% expr [ [% e s] ]]) ident_label)
784801 in
785802 {
786803 vbs;
@@ -906,7 +923,7 @@ let translate (expr : expression) : result =
906923 @@ Location. error_extensionf ~loc
907924 " ppx_ocannl %%cd: write .grad.merge instead of .merge.grad" ;
908925 }
909- | Code | Array | Value_of_tensor _ | Grad_of_tensor _ | Merge_grad _ ->
926+ | Function | Code | Array | Value_of_tensor _ | Grad_of_tensor _ | Merge_grad _ ->
910927 {
911928 res1 with
912929 typ = Array ;
@@ -924,7 +941,7 @@ let translate (expr : expression) : result =
924941 typ = Value_of_tensor res1.expr;
925942 expr = [% expr [% e res1.expr].Tensor. value];
926943 }
927- | Code ->
944+ | Function | Code ->
928945 {
929946 res1 with
930947 typ = Array ;
@@ -942,7 +959,7 @@ let translate (expr : expression) : result =
942959 { res1 with typ = Merge_value res1.expr; expr = [% expr [% e res1.expr].Tensor. value] }
943960 | Value_of_tensor t ->
944961 { res1 with typ = Merge_value t; expr = [% expr [% e res1.expr].Tensor. value] }
945- | Array | Code ->
962+ | Function | Array | Code ->
946963 {
947964 res1 with
948965 typ = Array ;
@@ -1275,6 +1292,7 @@ let translate (expr : expression) : result =
12751292 let res = transl ~bad_pun_hints ~proj_in_scope body in
12761293 {
12771294 res with
1295+ typ = Function ;
12781296 expr =
12791297 { expr with pexp_desc = Pexp_function (args, constr, Pfunction_body res.expr) };
12801298 }
@@ -1286,6 +1304,7 @@ let translate (expr : expression) : result =
12861304 in
12871305 {
12881306 cases_result with
1307+ typ = Function ;
12891308 expr =
12901309 {
12911310 expr with
@@ -1396,7 +1415,7 @@ let translate (expr : expression) : result =
13961415 transl ~proj_in_scope: false ~bad_pun_hints: (Set. empty (module String )) expr
13971416
13981417let translate ?ident_label expr =
1399- let res = translate expr in
1418+ let res = translate ?ident_label:( Option. map ~f: pat2string ident_label) expr in
14001419 let loc = res.expr.pexp_loc in
14011420 let expr = res.expr in
14021421 ( res.vbs,
0 commit comments