Skip to content

Commit e0f5eb1

Browse files
committed
Avoid generating dead tensors by delaying computation of RHSes, e.g. for non-existent backpropagations
As side effect also fixes potential crashes where such dead tensor cases use ternary op assignments.
1 parent fed7a6d commit e0f5eb1

File tree

5 files changed

+137
-131
lines changed

5 files changed

+137
-131
lines changed

lib/operation.ml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -347,11 +347,9 @@ let fma ?(label = []) ~grad_spec t1 t2 t3 =
347347
let where ?(label = []) ~grad_spec t1 t2 t3 =
348348
let module NTDSL = NTDSL_before_div in
349349
let%cd op_asn ~v ~t1 ~t2 ~t3 ~projections = v =: where v1 v2 v3 in
350-
(* Just to illustrate that both [0] and [!..0] are handled. *)
351-
let zero_cst = 0 in
352350
let%cd grad_asn ~t:_ ~g ~t1 ~t2 ~t3 ~projections =
353351
g2 =+ where v1 g 0;
354-
g3 =+ where v1 !..zero_cst g
352+
g3 =+ where v1 0 g
355353
in
356354
Tensor.ternop ~label:("where" :: label) ~ternary_op:Pointwise_tern ~op_asn ~grad_asn ~grad_spec t1
357355
t2 t3

lib/ppx_cd.ml

Lines changed: 79 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,16 @@ let make_vb ~loc ~name ~name_expr ~hint_label =
6565
vb
6666

6767
(** The expression argument is of type: [Assignments.t]. *)
68-
let assignment ~punned ~lhs ~rhses body =
68+
let assignment ~punned ~lhs ~rhses ?body_for_lhs ?raw_body () =
6969
let setups = lhs :: rhses in
70+
let body, is_for_lhs =
71+
match (body_for_lhs, raw_body) with
72+
| Some body_for_lhs, None ->
73+
let loc = body_for_lhs.pexp_loc in
74+
([%expr Option.value ~default:Ir.Assignments.Noop [%e body_for_lhs]], true)
75+
| None, Some raw_body -> (raw_body, false)
76+
| _ -> assert false
77+
in
7078
let loc = body.pexp_loc in
7179
let forward_args = List.filter_map setups ~f:(fun { fwd_code_or_noop; _ } -> fwd_code_or_noop) in
7280
let vbs, body =
@@ -107,9 +115,18 @@ let assignment ~punned ~lhs ~rhses body =
107115
List.fold (body :: List.rev forward_args) ~init:[%expr []] ~f:(fun xs x ->
108116
[%expr [%e x] :: [%e xs]])
109117
in
110-
let expr = [%expr Ir.Assignments.sequence [%e comps]] in
118+
let body = [%expr Ir.Assignments.sequence [%e comps]] in
119+
let body =
120+
if List.is_empty tensor_vbs then body else A.Exp.let_ ~loc Nonrecursive tensor_vbs body
121+
in
111122
let expr =
112-
if List.is_empty tensor_vbs then expr else A.Exp.let_ ~loc Nonrecursive tensor_vbs expr
123+
if is_for_lhs then
124+
[%expr
125+
Option.value
126+
~default:
127+
Ir.Assignments.{ asgns = Noop; embedded_nodes = Base.Set.empty (module Ir.Tnode) }
128+
@@ Option.map [%e lhs.array_opt] ~f:(fun lhs -> [%e body])]
129+
else body
113130
in
114131
{
115132
vbs;
@@ -519,23 +536,22 @@ let translate ?ident_label (expr : expression) : result =
519536
(proj_lazy, [%expr projections.Tensor.projections_debug])
520537
in
521538
(* FIXME: might be better to treat missing [rhs1, rhs2, rhs3] as zeros or errors rather than
522-
eliding the code. *)
523-
let body =
539+
eliding the code, only lhs should decide whether to elide the code. *)
540+
let body_for_lhs =
524541
[%expr
525-
Option.value ~default:Ir.Assignments.Noop
526-
@@ Option.map3 [%e setup_r1.array_opt] [%e setup_r2.array_opt] [%e setup_r3.array_opt]
527-
~f:(fun rhs1 rhs2 rhs3 ->
528-
Ir.Assignments.Accum_op
529-
{
530-
initialize_neutral = [%e initialize_neutral];
531-
accum = [%e accu_op];
532-
lhs = Option.value_exn [%e setup_l.array_opt];
533-
rhs = Ternop { op = [%e tern_op]; rhs1; rhs2; rhs3 };
534-
projections = [%e projections_lazy];
535-
projections_debug = [%e projections_debug];
536-
})]
542+
Option.map3 [%e setup_r1.array_opt] [%e setup_r2.array_opt] [%e setup_r3.array_opt]
543+
~f:(fun rhs1 rhs2 rhs3 ->
544+
Ir.Assignments.Accum_op
545+
{
546+
initialize_neutral = [%e initialize_neutral];
547+
accum = [%e accu_op];
548+
lhs;
549+
rhs = Ternop { op = [%e tern_op]; rhs1; rhs2; rhs3 };
550+
projections = [%e projections_lazy];
551+
projections_debug = [%e projections_debug];
552+
})]
537553
in
538-
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r1; setup_r2; setup_r3 ] body
554+
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r1; setup_r2; setup_r3 ] ~body_for_lhs ()
539555
in
540556
let process_assign_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ?projections ~proj_in_scope () =
541557
let initialize_neutral, accu_op = assignment_op accu_op in
@@ -582,24 +598,22 @@ let translate ?ident_label (expr : expression) : result =
582598
in
583599
(proj_lazy, [%expr projections.Tensor.projections_debug])
584600
in
585-
(* TODO: might be better to treat missing [rhs1, rhs2] as zeros or errors rather than eliding
586-
the code. *)
587-
let body =
601+
(* FIXME: might be better to treat missing [rhs1, rhs2] as zeros or errors rather than eliding
602+
the code, only lhs should decide whether to elide the code. *)
603+
let body_for_lhs =
588604
[%expr
589-
Option.value ~default:Ir.Assignments.Noop
590-
@@ Option.map3 [%e setup_l.array_opt] [%e setup_r1.array_opt] [%e setup_r2.array_opt]
591-
~f:(fun lhs rhs1 rhs2 ->
592-
Ir.Assignments.Accum_op
593-
{
594-
initialize_neutral = [%e initialize_neutral];
595-
accum = [%e accu_op];
596-
lhs;
597-
rhs = Binop { op = [%e bin_op]; rhs1; rhs2 };
598-
projections = [%e projections_lazy];
599-
projections_debug = [%e projections_debug];
600-
})]
605+
Option.map2 [%e setup_r1.array_opt] [%e setup_r2.array_opt] ~f:(fun rhs1 rhs2 ->
606+
Ir.Assignments.Accum_op
607+
{
608+
initialize_neutral = [%e initialize_neutral];
609+
accum = [%e accu_op];
610+
lhs;
611+
rhs = Binop { op = [%e bin_op]; rhs1; rhs2 };
612+
projections = [%e projections_lazy];
613+
projections_debug = [%e projections_debug];
614+
})]
601615
in
602-
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r1; setup_r2 ] body
616+
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r1; setup_r2 ] ~body_for_lhs ()
603617
in
604618
let process_assign_unop ~accu_op ~lhs ~un_op ~rhs ?projections ~proj_in_scope () =
605619
let initialize_neutral, accum = assignment_op accu_op in
@@ -644,23 +658,22 @@ let translate ?ident_label (expr : expression) : result =
644658
in
645659
(proj_lazy, [%expr projections.Tensor.projections_debug])
646660
in
647-
(* TODO: might be better to treat missing [rhs] as zeros or errors rather than eliding the
648-
code. *)
649-
let body =
661+
(* FIXME: might be better to treat missing [rhs] as zeros or errors rather than eliding the
662+
code, only lhs should decide whether to elide the code. *)
663+
let body_for_lhs =
650664
[%expr
651-
Option.value ~default:Ir.Assignments.Noop
652-
@@ Option.map2 [%e setup_l.array_opt] [%e setup_r.array_opt] ~f:(fun lhs rhs ->
653-
Ir.Assignments.Accum_op
654-
{
655-
initialize_neutral = [%e initialize_neutral];
656-
accum = [%e accum];
657-
lhs;
658-
rhs = Unop { op = [%e op]; rhs };
659-
projections = [%e projections_lazy];
660-
projections_debug = [%e projections_debug];
661-
})]
665+
Option.map [%e setup_r.array_opt] ~f:(fun rhs ->
666+
Ir.Assignments.Accum_op
667+
{
668+
initialize_neutral = [%e initialize_neutral];
669+
accum = [%e accum];
670+
lhs;
671+
rhs = Unop { op = [%e op]; rhs };
672+
projections = [%e projections_lazy];
673+
projections_debug = [%e projections_debug];
674+
})]
662675
in
663-
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r ] body
676+
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r ] ~body_for_lhs ()
664677
in
665678
let process_vec_unop ~lhs ~vec_un_op ~rhs ?projections ~proj_in_scope () =
666679
(* Vector unary operations do not have accumulation, they directly set values *)
@@ -700,20 +713,19 @@ let translate ?ident_label (expr : expression) : result =
700713
in
701714
(proj_lazy, [%expr projections.Tensor.projections_debug])
702715
in
703-
let body =
716+
let body_for_lhs =
704717
[%expr
705-
Option.value ~default:Ir.Assignments.Noop
706-
@@ Option.map2 [%e setup_l.array_opt] [%e setup_r.array_opt] ~f:(fun lhs rhs ->
707-
Ir.Assignments.Set_vec_unop
708-
{
709-
lhs;
710-
op = [%e op];
711-
rhs;
712-
projections = [%e projections_lazy];
713-
projections_debug = [%e projections_debug];
714-
})]
718+
Option.map [%e setup_r.array_opt] ~f:(fun rhs ->
719+
Ir.Assignments.Set_vec_unop
720+
{
721+
lhs;
722+
op = [%e op];
723+
rhs;
724+
projections = [%e projections_lazy];
725+
projections_debug = [%e projections_debug];
726+
})]
715727
in
716-
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r ] body
728+
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r ] ~body_for_lhs ()
717729
in
718730
let process_raw_ternop ~accu_op ~lhs ~tern_op ~rhs1 ~rhs2 ~rhs3 ~logic =
719731
let initialize_neutral, accu_op = assignment_op accu_op in
@@ -726,15 +738,15 @@ let translate ?ident_label (expr : expression) : result =
726738
let t1_expr, rhs1_is_grad, rhs1_is_merge = args_for ~loc setup_r1 in
727739
let t2_expr, rhs2_is_grad, rhs2_is_merge = args_for ~loc setup_r2 in
728740
let t3_expr, rhs3_is_grad, rhs3_is_merge = args_for ~loc setup_r3 in
729-
let body =
741+
let raw_body =
730742
[%expr
731743
Tensor.raw_ternop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op]
732744
~t:[%e t_expr] ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e tern_op] ~t1:[%e t1_expr]
733745
~rhs1_is_grad:[%e rhs1_is_grad] ~rhs1_is_merge:[%e rhs1_is_merge] ~t2:[%e t2_expr]
734746
~rhs2_is_grad:[%e rhs2_is_grad] ~rhs2_is_merge:[%e rhs2_is_merge] ~t3:[%e t3_expr]
735747
~rhs3_is_grad:[%e rhs3_is_grad] ~rhs3_is_merge:[%e rhs3_is_merge] ~logic:[%e logic]]
736748
in
737-
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r1; setup_r2; setup_r3 ] body
749+
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r1; setup_r2; setup_r3 ] ~raw_body ()
738750
in
739751
let process_raw_binop ~accu_op ~lhs ~bin_op ~rhs1 ~rhs2 ~logic =
740752
let initialize_neutral, accu_op = assignment_op accu_op in
@@ -745,14 +757,14 @@ let translate ?ident_label (expr : expression) : result =
745757
let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in
746758
let t1_expr, rhs1_is_grad, rhs1_is_merge = args_for ~loc setup_r1 in
747759
let t2_expr, rhs2_is_grad, rhs2_is_merge = args_for ~loc setup_r2 in
748-
let body =
760+
let raw_body =
749761
[%expr
750762
Tensor.raw_binop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op]
751763
~t:[%e t_expr] ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e bin_op] ~t1:[%e t1_expr]
752764
~rhs1_is_grad:[%e rhs1_is_grad] ~rhs1_is_merge:[%e rhs1_is_merge] ~t2:[%e t2_expr]
753765
~rhs2_is_grad:[%e rhs2_is_grad] ~rhs2_is_merge:[%e rhs2_is_merge] ~logic:[%e logic]]
754766
in
755-
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r1; setup_r2 ] body
767+
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r1; setup_r2 ] ~raw_body ()
756768
in
757769
let process_raw_unop ~accu_op ~lhs ~un_op ~rhs ~logic =
758770
let initialize_neutral, accu_op = assignment_op accu_op in
@@ -761,13 +773,13 @@ let translate ?ident_label (expr : expression) : result =
761773
let initialize_neutral = if initialize_neutral then [%expr true] else [%expr false] in
762774
let t_expr, lhs_is_grad, _ = args_for ~loc setup_l in
763775
let t1_expr, rhs_is_grad, rhs_is_merge = args_for ~loc setup_r in
764-
let body =
776+
let raw_body =
765777
[%expr
766778
Tensor.raw_unop ~initialize_neutral:[%e initialize_neutral] ~accum:[%e accu_op]
767779
~t:[%e t_expr] ~lhs_is_grad:[%e lhs_is_grad] ~op:[%e un_op] ~t1:[%e t1_expr]
768780
~rhs_is_grad:[%e rhs_is_grad] ~rhs_is_merge:[%e rhs_is_merge] ~logic:[%e logic]]
769781
in
770-
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r ] body
782+
assignment ~punned ~lhs:setup_l ~rhses:[ setup_r ] ~raw_body ()
771783
in
772784
match expr with
773785
| { pexp_desc = Pexp_constant (Pconst_float _); _ } ->

test/einsum/moons_demo_variant.expected

Lines changed: 30 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -40,43 +40,38 @@ n39 grad_- as n38.grad: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
4040
n40 relu_margin_loss as relu_margin_loss: Virt/15; single prec 10x1; mem in bytes: <not-in-yet>
4141
n41 grad_relu_margin_loss as relu_margin_loss.grad: Local/1046; single prec 10x1; mem in bytes: <not-in-yet>
4242
n42 10 as _10: Virt/40; single prec 1; mem in bytes: <not-in-yet>
43-
n43 => as n43: Local/1046; single prec 1; mem in bytes: <not-in-yet>
4443
n44 grad_=> as n43.grad: Virt/40; single prec 1; mem in bytes: <not-in-yet>
4544
n46 grad_/._scalar_loss as scalar_loss.grad: Virt/40; single prec 1; mem in bytes: <not-in-yet>
46-
n47 2 as _2: Virt/40; single prec 1; mem in bytes: <not-in-yet>
47-
n48 **. as n48: Virt/40; single prec 1; mem in bytes: <not-in-yet>
48-
n49 -1 as n49: Virt/40; single prec 1; mem in bytes: <not-in-yet>
49-
n50 *. as n50: Virt/152; single prec 1; mem in bytes: <not-in-yet>
50-
n52 1 as _1: Virt/40; single prec 1; mem in bytes: <not-in-yet>
51-
n53 200 as _200: Virt/40; single prec 1; mem in bytes: <not-in-yet>
52-
n54 !@ as n54: Virt/152; single prec 1; mem in bytes: <not-in-yet>
53-
n55 200 as _200: Virt/40; single prec 1; mem in bytes: <not-in-yet>
54-
n56 2 as _2: Virt/40; single prec 1; mem in bytes: <not-in-yet>
55-
n57 *. as n57: Virt/40; single prec 1; mem in bytes: <not-in-yet>
56-
n58 - as n58: Virt/152; single prec 1; mem in bytes: <not-in-yet>
57-
n59 0.1 as n59: Virt/40; single prec 1; mem in bytes: <not-in-yet>
58-
n60 *. as n60: Virt/152; single prec 1; mem in bytes: <not-in-yet>
59-
n62 sgd_delta_b1 as sgd_delta_b1: Virt/15; single prec 16; mem in bytes: <not-in-yet>
60-
n63 sgd_momentum_b1 as sgd_momentum_b1: unknown; single prec <not-in-yet>; mem in bytes: <not-in-yet>
61-
n64 0.0001 as n64: Virt/40; single prec 1; mem in bytes: <not-in-yet>
62-
n65 *. as n65: Virt/15; single prec 16; mem in bytes: <not-in-yet>
63-
n66 sgd_delta_w1 as sgd_delta_w1: Virt/15; single prec 16x2; mem in bytes: <not-in-yet>
64-
n67 sgd_momentum_w1 as sgd_momentum_w1: unknown; single prec <not-in-yet>; mem in bytes: <not-in-yet>
65-
n68 0.0001 as n68: Virt/40; single prec 1; mem in bytes: <not-in-yet>
66-
n69 *. as n69: Virt/15; single prec 16x2; mem in bytes: <not-in-yet>
67-
n70 sgd_delta_w2 as sgd_delta_w2: Virt/15; single prec 1x16; mem in bytes: <not-in-yet>
68-
n71 sgd_momentum_w2 as sgd_momentum_w2: unknown; single prec <not-in-yet>; mem in bytes: <not-in-yet>
69-
n72 0.0001 as n72: Virt/40; single prec 1; mem in bytes: <not-in-yet>
70-
n73 *. as n73: Virt/15; single prec 1x16; mem in bytes: <not-in-yet>
71-
n74 point_mlp_result as point_mlp_result: Host&shared/37039; single prec 2; mem in bytes: <not-in-yet>
72-
n75 * as n75: Local/1046; single prec 16; mem in bytes: <not-in-yet>
73-
n76 grad_* as n75.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
74-
n77 + as n77: Virt/15; single prec 16; mem in bytes: <not-in-yet>
75-
n78 grad_+ as n77.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
76-
n79 relu as relu: Virt/15; single prec 16; mem in bytes: <not-in-yet>
77-
n80 grad_relu as relu.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
78-
n81 *_mlp_point_mlp_result as mlp_point_mlp_result: Host&stream/412410; single prec 1; mem in bytes: <not-in-yet>
79-
n82 grad_*_mlp_point_mlp_result as mlp_point_mlp_result.grad: unknown; single prec 1; mem in bytes: <not-in-yet>
45+
n47 1 as _1: Virt/40; single prec 1; mem in bytes: <not-in-yet>
46+
n48 200 as _200: Virt/40; single prec 1; mem in bytes: <not-in-yet>
47+
n49 !@ as n49: Virt/152; single prec 1; mem in bytes: <not-in-yet>
48+
n50 200 as _200: Virt/40; single prec 1; mem in bytes: <not-in-yet>
49+
n51 2 as _2: Virt/40; single prec 1; mem in bytes: <not-in-yet>
50+
n52 *. as n52: Virt/40; single prec 1; mem in bytes: <not-in-yet>
51+
n53 - as n53: Virt/152; single prec 1; mem in bytes: <not-in-yet>
52+
n54 0.1 as n54: Virt/40; single prec 1; mem in bytes: <not-in-yet>
53+
n55 *. as n55: Virt/152; single prec 1; mem in bytes: <not-in-yet>
54+
n57 sgd_delta_b1 as sgd_delta_b1: Virt/15; single prec 16; mem in bytes: <not-in-yet>
55+
n58 sgd_momentum_b1 as sgd_momentum_b1: unknown; single prec <not-in-yet>; mem in bytes: <not-in-yet>
56+
n59 0.0001 as n59: Virt/40; single prec 1; mem in bytes: <not-in-yet>
57+
n60 *. as n60: Virt/15; single prec 16; mem in bytes: <not-in-yet>
58+
n61 sgd_delta_w1 as sgd_delta_w1: Virt/15; single prec 16x2; mem in bytes: <not-in-yet>
59+
n62 sgd_momentum_w1 as sgd_momentum_w1: unknown; single prec <not-in-yet>; mem in bytes: <not-in-yet>
60+
n63 0.0001 as n63: Virt/40; single prec 1; mem in bytes: <not-in-yet>
61+
n64 *. as n64: Virt/15; single prec 16x2; mem in bytes: <not-in-yet>
62+
n65 sgd_delta_w2 as sgd_delta_w2: Virt/15; single prec 1x16; mem in bytes: <not-in-yet>
63+
n66 sgd_momentum_w2 as sgd_momentum_w2: unknown; single prec <not-in-yet>; mem in bytes: <not-in-yet>
64+
n67 0.0001 as n67: Virt/40; single prec 1; mem in bytes: <not-in-yet>
65+
n68 *. as n68: Virt/15; single prec 1x16; mem in bytes: <not-in-yet>
66+
n69 point_mlp_result as point_mlp_result: Host&shared/37039; single prec 2; mem in bytes: <not-in-yet>
67+
n70 * as n70: Local/1046; single prec 16; mem in bytes: <not-in-yet>
68+
n71 grad_* as n70.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
69+
n72 + as n72: Virt/15; single prec 16; mem in bytes: <not-in-yet>
70+
n73 grad_+ as n72.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
71+
n74 relu as relu: Virt/15; single prec 16; mem in bytes: <not-in-yet>
72+
n75 grad_relu as relu.grad: unknown; single prec 16; mem in bytes: <not-in-yet>
73+
n76 *_mlp_point_mlp_result as mlp_point_mlp_result: Host&stream/412410; single prec 1; mem in bytes: <not-in-yet>
74+
n77 grad_*_mlp_point_mlp_result as mlp_point_mlp_result.grad: unknown; single prec 1; mem in bytes: <not-in-yet>
8075
Tnode: Finished printing headers.
8176
mlp_result's name: mlp_point_mlp_result
8277
(mlp moons_input) name: mlp_moons_input

0 commit comments

Comments
 (0)