@@ -65,8 +65,16 @@ let make_vb ~loc ~name ~name_expr ~hint_label =
6565 vb
6666
6767(* * The expression argument is of type: [Assignments.t]. *)
68- let assignment ~punned ~lhs ~rhses body =
68+ let assignment ~punned ~lhs ~rhses ? body_for_lhs ? raw_body () =
6969 let setups = lhs :: rhses in
70+ let body, is_for_lhs =
71+ match (body_for_lhs, raw_body) with
72+ | Some body_for_lhs , None ->
73+ let loc = body_for_lhs.pexp_loc in
74+ ([% expr Option. value ~default: Ir.Assignments. Noop [% e body_for_lhs]], true )
75+ | None , Some raw_body -> (raw_body, false )
76+ | _ -> assert false
77+ in
7078 let loc = body.pexp_loc in
7179 let forward_args = List. filter_map setups ~f: (fun { fwd_code_or_noop; _ } -> fwd_code_or_noop) in
7280 let vbs, body =
@@ -107,9 +115,18 @@ let assignment ~punned ~lhs ~rhses body =
107115 List. fold (body :: List. rev forward_args) ~init: [% expr []] ~f: (fun xs x ->
108116 [% expr [% e x] :: [% e xs]])
109117 in
110- let expr = [% expr Ir.Assignments. sequence [% e comps]] in
118+ let body = [% expr Ir.Assignments. sequence [% e comps]] in
119+ let body =
120+ if List. is_empty tensor_vbs then body else A.Exp. let_ ~loc Nonrecursive tensor_vbs body
121+ in
111122 let expr =
112- if List. is_empty tensor_vbs then expr else A.Exp. let_ ~loc Nonrecursive tensor_vbs expr
123+ if is_for_lhs then
124+ [% expr
125+ Option. value
126+ ~default:
127+ Ir.Assignments. { asgns = Noop ; embedded_nodes = Base.Set. empty (module Ir. Tnode ) }
128+ @@ Option. map [% e lhs.array_opt] ~f: (fun lhs -> [% e body])]
129+ else body
113130 in
114131 {
115132 vbs;
@@ -519,23 +536,22 @@ let translate ?ident_label (expr : expression) : result =
519536 (proj_lazy, [% expr projections.Tensor. projections_debug])
520537 in
521538 (* FIXME: might be better to treat missing [rhs1, rhs2, rhs3] as zeros or errors rather than
522- eliding the code. *)
523- let body =
539+ eliding the code, only lhs should decide whether to elide the code . *)
540+ let body_for_lhs =
524541 [% expr
525- Option. value ~default: Ir.Assignments. Noop
526- @@ Option. map3 [% e setup_r1.array_opt] [% e setup_r2.array_opt] [% e setup_r3.array_opt]
527- ~f: (fun rhs1 rhs2 rhs3 ->
528- Ir.Assignments. Accum_op
529- {
530- initialize_neutral = [% e initialize_neutral];
531- accum = [% e accu_op];
532- lhs = Option. value_exn [% e setup_l.array_opt];
533- rhs = Ternop { op = [% e tern_op]; rhs1; rhs2; rhs3 };
534- projections = [% e projections_lazy];
535- projections_debug = [% e projections_debug];
536- })]
542+ Option. map3 [% e setup_r1.array_opt] [% e setup_r2.array_opt] [% e setup_r3.array_opt]
543+ ~f: (fun rhs1 rhs2 rhs3 ->
544+ Ir.Assignments. Accum_op
545+ {
546+ initialize_neutral = [% e initialize_neutral];
547+ accum = [% e accu_op];
548+ lhs;
549+ rhs = Ternop { op = [% e tern_op]; rhs1; rhs2; rhs3 };
550+ projections = [% e projections_lazy];
551+ projections_debug = [% e projections_debug];
552+ })]
537553 in
538- assignment ~punned ~lhs: setup_l ~rhses: [ setup_r1; setup_r2; setup_r3 ] body
554+ assignment ~punned ~lhs: setup_l ~rhses: [ setup_r1; setup_r2; setup_r3 ] ~body_for_lhs ()
539555 in
540556 let process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ?projections ~proj_in_scope () =
541557 let initialize_neutral, accu_op = assignment_op accu_op in
@@ -582,24 +598,22 @@ let translate ?ident_label (expr : expression) : result =
582598 in
583599 (proj_lazy, [% expr projections.Tensor. projections_debug])
584600 in
585- (* TODO : might be better to treat missing [rhs1, rhs2] as zeros or errors rather than eliding
586- the code. *)
587- let body =
601+ (* FIXME : might be better to treat missing [rhs1, rhs2] as zeros or errors rather than eliding
602+ the code, only lhs should decide whether to elide the code . *)
603+ let body_for_lhs =
588604 [% expr
589- Option. value ~default: Ir.Assignments. Noop
590- @@ Option. map3 [% e setup_l.array_opt] [% e setup_r1.array_opt] [% e setup_r2.array_opt]
591- ~f: (fun lhs rhs1 rhs2 ->
592- Ir.Assignments. Accum_op
593- {
594- initialize_neutral = [% e initialize_neutral];
595- accum = [% e accu_op];
596- lhs;
597- rhs = Binop { op = [% e bin_op]; rhs1; rhs2 };
598- projections = [% e projections_lazy];
599- projections_debug = [% e projections_debug];
600- })]
605+ Option. map2 [% e setup_r1.array_opt] [% e setup_r2.array_opt] ~f: (fun rhs1 rhs2 ->
606+ Ir.Assignments. Accum_op
607+ {
608+ initialize_neutral = [% e initialize_neutral];
609+ accum = [% e accu_op];
610+ lhs;
611+ rhs = Binop { op = [% e bin_op]; rhs1; rhs2 };
612+ projections = [% e projections_lazy];
613+ projections_debug = [% e projections_debug];
614+ })]
601615 in
602- assignment ~punned ~lhs: setup_l ~rhses: [ setup_r1; setup_r2 ] body
616+ assignment ~punned ~lhs: setup_l ~rhses: [ setup_r1; setup_r2 ] ~body_for_lhs ()
603617 in
604618 let process_assign_unop ~accu_op ~lhs ~un_op ~rhs ?projections ~proj_in_scope () =
605619 let initialize_neutral, accum = assignment_op accu_op in
@@ -644,23 +658,22 @@ let translate ?ident_label (expr : expression) : result =
644658 in
645659 (proj_lazy, [% expr projections.Tensor. projections_debug])
646660 in
647- (* TODO : might be better to treat missing [rhs] as zeros or errors rather than eliding the
648- code. *)
649- let body =
661+ (* FIXME : might be better to treat missing [rhs] as zeros or errors rather than eliding the
662+ code, only lhs should decide whether to elide the code . *)
663+ let body_for_lhs =
650664 [% expr
651- Option. value ~default: Ir.Assignments. Noop
652- @@ Option. map2 [% e setup_l.array_opt] [% e setup_r.array_opt] ~f: (fun lhs rhs ->
653- Ir.Assignments. Accum_op
654- {
655- initialize_neutral = [% e initialize_neutral];
656- accum = [% e accum];
657- lhs;
658- rhs = Unop { op = [% e op]; rhs };
659- projections = [% e projections_lazy];
660- projections_debug = [% e projections_debug];
661- })]
665+ Option. map [% e setup_r.array_opt] ~f: (fun rhs ->
666+ Ir.Assignments. Accum_op
667+ {
668+ initialize_neutral = [% e initialize_neutral];
669+ accum = [% e accum];
670+ lhs;
671+ rhs = Unop { op = [% e op]; rhs };
672+ projections = [% e projections_lazy];
673+ projections_debug = [% e projections_debug];
674+ })]
662675 in
663- assignment ~punned ~lhs: setup_l ~rhses: [ setup_r ] body
676+ assignment ~punned ~lhs: setup_l ~rhses: [ setup_r ] ~body_for_lhs ()
664677 in
665678 let process_vec_unop ~lhs ~vec_un_op ~rhs ?projections ~proj_in_scope () =
666679 (* Vector unary operations do not have accumulation, they directly set values *)
@@ -700,20 +713,19 @@ let translate ?ident_label (expr : expression) : result =
700713 in
701714 (proj_lazy, [% expr projections.Tensor. projections_debug])
702715 in
703- let body =
716+ let body_for_lhs =
704717 [% expr
705- Option. value ~default: Ir.Assignments. Noop
706- @@ Option. map2 [% e setup_l.array_opt] [% e setup_r.array_opt] ~f: (fun lhs rhs ->
707- Ir.Assignments. Set_vec_unop
708- {
709- lhs;
710- op = [% e op];
711- rhs;
712- projections = [% e projections_lazy];
713- projections_debug = [% e projections_debug];
714- })]
718+ Option. map [% e setup_r.array_opt] ~f: (fun rhs ->
719+ Ir.Assignments. Set_vec_unop
720+ {
721+ lhs;
722+ op = [% e op];
723+ rhs;
724+ projections = [% e projections_lazy];
725+ projections_debug = [% e projections_debug];
726+ })]
715727 in
716- assignment ~punned ~lhs: setup_l ~rhses: [ setup_r ] body
728+ assignment ~punned ~lhs: setup_l ~rhses: [ setup_r ] ~body_for_lhs ()
717729 in
718730 let process_raw_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~logic =
719731 let initialize_neutral, accu_op = assignment_op accu_op in
@@ -726,15 +738,15 @@ let translate ?ident_label (expr : expression) : result =
726738 let t1_expr, rhs1_is_grad, rhs1_is_merge = args_for ~loc setup_r1 in
727739 let t2_expr, rhs2_is_grad, rhs2_is_merge = args_for ~loc setup_r2 in
728740 let t3_expr, rhs3_is_grad, rhs3_is_merge = args_for ~loc setup_r3 in
729- let body =
741+ let raw_body =
730742 [% expr
731743 Tensor. raw_ternop ~initialize_neutral: [% e initialize_neutral] ~accum: [% e accu_op]
732744 ~t: [% e t_expr] ~lhs_is_grad: [% e lhs_is_grad] ~op: [% e tern_op] ~t1: [% e t1_expr]
733745 ~rhs1_is_grad: [% e rhs1_is_grad] ~rhs1_is_merge: [% e rhs1_is_merge] ~t2: [% e t2_expr]
734746 ~rhs2_is_grad: [% e rhs2_is_grad] ~rhs2_is_merge: [% e rhs2_is_merge] ~t3: [% e t3_expr]
735747 ~rhs3_is_grad: [% e rhs3_is_grad] ~rhs3_is_merge: [% e rhs3_is_merge] ~logic: [% e logic]]
736748 in
737- assignment ~punned ~lhs: setup_l ~rhses: [ setup_r1; setup_r2; setup_r3 ] body
749+ assignment ~punned ~lhs: setup_l ~rhses: [ setup_r1; setup_r2; setup_r3 ] ~raw_body ()
738750 in
739751 let process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic =
740752 let initialize_neutral, accu_op = assignment_op accu_op in
@@ -745,14 +757,14 @@ let translate ?ident_label (expr : expression) : result =
745757 let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in
746758 let t1_expr, rhs1_is_grad, rhs1_is_merge = args_for ~loc setup_r1 in
747759 let t2_expr, rhs2_is_grad, rhs2_is_merge = args_for ~loc setup_r2 in
748- let body =
760+ let raw_body =
749761 [% expr
750762 Tensor. raw_binop ~initialize_neutral: [% e initialize_neutral] ~accum: [% e accu_op]
751763 ~t: [% e t_expr] ~lhs_is_grad: [% e lhs_is_grad] ~op: [% e bin_op] ~t1: [% e t1_expr]
752764 ~rhs1_is_grad: [% e rhs1_is_grad] ~rhs1_is_merge: [% e rhs1_is_merge] ~t2: [% e t2_expr]
753765 ~rhs2_is_grad: [% e rhs2_is_grad] ~rhs2_is_merge: [% e rhs2_is_merge] ~logic: [% e logic]]
754766 in
755- assignment ~punned ~lhs: setup_l ~rhses: [ setup_r1; setup_r2 ] body
767+ assignment ~punned ~lhs: setup_l ~rhses: [ setup_r1; setup_r2 ] ~raw_body ()
756768 in
757769 let process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic =
758770 let initialize_neutral, accu_op = assignment_op accu_op in
@@ -761,13 +773,13 @@ let translate ?ident_label (expr : expression) : result =
761773 let initialize_neutral = if initialize_neutral then [% expr true ] else [% expr false ] in
762774 let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in
763775 let t1_expr, rhs_is_grad, rhs_is_merge = args_for ~loc setup_r in
764- let body =
776+ let raw_body =
765777 [% expr
766778 Tensor. raw_unop ~initialize_neutral: [% e initialize_neutral] ~accum: [% e accu_op]
767779 ~t: [% e t_expr] ~lhs_is_grad: [% e lhs_is_grad] ~op: [% e un_op] ~t1: [% e t1_expr]
768780 ~rhs_is_grad: [% e rhs_is_grad] ~rhs_is_merge: [% e rhs_is_merge] ~logic: [% e logic]]
769781 in
770- assignment ~punned ~lhs: setup_l ~rhses: [ setup_r ] body
782+ assignment ~punned ~lhs: setup_l ~rhses: [ setup_r ] ~raw_body ()
771783 in
772784 match expr with
773785 | { pexp_desc = Pexp_constant (Pconst_float _ ); _ } ->
0 commit comments