@@ -365,6 +365,31 @@ let args_for ~loc = function
365365
366366let reduce_res_vbs rs = reduce_vbss @@ List. map rs ~f: (fun r -> r.vbs)
367367
368+ (* * Helper function to handle cases (for Pexp_match, Pexp_function with cases, etc.) *)
369+ let handle_cases ~bad_pun_hints ~proj_in_scope transl cases =
370+ let fields, transformed_cases =
371+ List. unzip
372+ @@ List. map cases ~f: (fun ({ pc_rhs; _ } as c ) ->
373+ let res = transl ~bad_pun_hints ~proj_in_scope pc_rhs in
374+ ((res.vbs, res.typ, res.slot), { c with pc_rhs = res.expr }))
375+ in
376+ let vbss, typs, slots = List. unzip3 fields in
377+ (* TODO: make the inference of typ and slot more strict by detecting mismatches. *)
378+ let typ = Option. value ~default: Unknown @@ List. find typs ~f: (Fn. non is_unknown) in
379+ let slot =
380+ Option. value ~default: Undet @@ List. find ~f: (function Undet -> false | _ -> true ) slots
381+ in
382+ let loc = (List. hd_exn cases).pc_lhs.ppat_loc in
383+ ( transformed_cases,
384+ {
385+ vbs = reduce_vbss vbss;
386+ typ;
387+ slot;
388+ expr = [% expr () ];
389+ (* This will be replaced by the caller *)
390+ array_opt_of_code = None ;
391+ } )
392+
368393let translate (expr : expression ) : result =
369394 let punned = Hashtbl. create (module String ) in
370395 let rec transl ~bad_pun_hints ~proj_in_scope (expr : expression ) : result =
@@ -736,15 +761,26 @@ let translate (expr : expression) : result =
736761 | { pexp_desc = Pexp_ident { txt = Lident op_ident ; _ } ; _ } when is_primitive_op op_ident ->
737762 default_result
738763 | [% expr ! .[% e? expr1]] ->
739- (* Hardcoding these two patterns to improve projection derivation expressivity. *)
740- let res1 = loop ~proj_in_scope expr1 in
741- { res1 with typ = Tensor ; slot = Scalar ; expr = [% expr NTDSL.O. ( ! . ) [% e res1.expr]] }
764+ (* Hardcoding these two patterns (!. and !..) to improve projection derivation expressivity
765+ and avoid treating the constants as already tensors. *)
766+ {
767+ typ = Tensor ;
768+ slot = Scalar ;
769+ expr = [% expr NTDSL.O. ( ! . ) [% e expr1]];
770+ array_opt_of_code = None ;
771+ vbs = no_vbs;
772+ }
742773 | [% expr ! ..[% e? expr1]] ->
743- let res1 = loop ~proj_in_scope expr1 in
744- { res1 with typ = Tensor ; slot = Scalar ; expr = [% expr NTDSL.O. ( ! .. ) [% e res1.expr]] }
774+ {
775+ typ = Tensor ;
776+ slot = Scalar ;
777+ expr = [% expr NTDSL.O. ( ! .. ) [% e expr1]];
778+ array_opt_of_code = None ;
779+ vbs = no_vbs;
780+ }
745781 | [% expr [% e? expr1] **. [% e? { pexp_desc = Pexp_constant (Pconst_integer _); _ } as i]] ->
746- (* FIXME: `**.` should take a tensor and require that it's a literal. *)
747- (* We need to hardcode these two patterns to prevent the numbers from being converted to tensors. *)
782+ (* We need to hardcode these two patterns (for **. ) to prevent the numbers from
783+ being converted to tensors. *)
748784 let res1 = loop ~proj_in_scope expr1 in
749785 {
750786 res1 with
@@ -1135,17 +1171,49 @@ let translate (expr : expression) : result =
11351171 expr = [% expr [% e res1.expr] [% e res2.expr]];
11361172 array_opt_of_code = None ;
11371173 }
1138- | { pexp_desc = Pexp_fun (( arg_label : arg_label ), arg , pat , expr1 ); _ } as expr ->
1174+ | { pexp_desc = Pexp_function ( args , constr , body ); _ } as expr ->
11391175 let proj_in_scope =
11401176 proj_in_scope
1141- ||
1142- match arg_label with
1143- | (Labelled s | Optional s ) when String. equal s " projections" -> true
1144- | _ -> false
1177+ || List. exists args ~f: (function
1178+ | { pparam_desc = Pparam_val ((Labelled s | Optional s), _, _); _ }
1179+ when String. equal s " projections" ->
1180+ true
1181+ | _ -> false )
1182+ in
1183+ let bad_pun_hints =
1184+ Set. union_list (module String )
1185+ @@ bad_pun_hints
1186+ :: List. map args ~f: (fun arg ->
1187+ match arg.pparam_desc with
1188+ | Pparam_val (_ , _ , pat ) -> collect_pat_idents pat
1189+ | _ -> Set. empty (module String ))
1190+ in
1191+ let result =
1192+ match body with
1193+ | Pfunction_body body ->
1194+ let res = transl ~bad_pun_hints ~proj_in_scope body in
1195+ {
1196+ res with
1197+ expr =
1198+ { expr with pexp_desc = Pexp_function (args, constr, Pfunction_body res.expr) };
1199+ }
1200+ | Pfunction_cases (cases , loc , attrs ) ->
1201+ let transformed_cases, cases_result =
1202+ handle_cases ~bad_pun_hints ~proj_in_scope
1203+ (fun ~bad_pun_hints ~proj_in_scope -> transl ~bad_pun_hints ~proj_in_scope )
1204+ cases
1205+ in
1206+ {
1207+ cases_result with
1208+ expr =
1209+ {
1210+ expr with
1211+ pexp_desc =
1212+ Pexp_function (args, constr, Pfunction_cases (transformed_cases, loc, attrs));
1213+ };
1214+ }
11451215 in
1146- let bad_pun_hints = Set. union bad_pun_hints @@ collect_pat_idents pat in
1147- let res1 = transl ~bad_pun_hints ~proj_in_scope expr1 in
1148- { res1 with expr = { expr with pexp_desc = Pexp_fun (arg_label, arg, pat, res1.expr) } }
1216+ result
11491217 | [% expr
11501218 while [% e? _test_expr] do
11511219 [% e? _body]
@@ -1222,26 +1290,13 @@ let translate (expr : expression) : result =
12221290 array_opt_of_code = res2.array_opt_of_code;
12231291 }
12241292 | { pexp_desc = Pexp_match (expr1 , cases ); _ } ->
1225- let fields, cases =
1226- List. unzip
1227- @@ List. map cases ~f: (fun ({ pc_rhs; _ } as c ) ->
1228- let res = loop ~proj_in_scope pc_rhs in
1229- ((res.vbs, res.typ, res.slot), { c with pc_rhs = res.expr }))
1293+ let transformed_cases, cases_result =
1294+ handle_cases ~bad_pun_hints ~proj_in_scope transl cases
12301295 in
1231- let vbss, typs, slots = List. unzip3 fields in
1232- let typ = Option. value ~default: Unknown @@ List. find typs ~f: (Fn. non is_unknown) in
1233- let slot =
1234- Option. value ~default: Undet @@ List. find ~f: (function Undet -> false | _ -> true ) slots
1235- in
1236- {
1237- vbs = reduce_vbss vbss;
1238- typ;
1239- slot;
1240- expr = { expr with pexp_desc = Pexp_match (expr1, cases) };
1241- array_opt_of_code = None ;
1242- }
1296+ { cases_result with expr = { expr with pexp_desc = Pexp_match (expr1, transformed_cases) } }
12431297 | { pexp_desc = Pexp_let (_recflag , _bindings , _body ); _ } ->
1244- (* TODO(80): to properly support local bindings, we need to collect the type environment. *)
1298+ (* TODO(#80): to properly support local bindings, we need to collect the type
1299+ environment. *)
12451300 {
12461301 default_result with
12471302 typ = Unknown ;
0 commit comments