@@ -601,18 +601,22 @@ let consume_forward_code t =
601601 @@ Session_error
602602 ( " Tensor.consume_forward_code: tensor is not a root for tnode: " ^ Tn. debug_name t.value,
603603 Some t );
604- (* FIXME(#321): this is too aggressive, instead we should check if the code contains any
605- non-embedded nodes that are embedded nodes of the other roots. *)
606- let unsafe_roots =
607- Map. data session_state.forward_roots
608- |> List. filter ~f: (fun r -> not ( List. is_empty r.children || r. id = t.id) )
604+ (* Check if any non-embedded descendants of t are embedded in other roots *)
605+ let all_read = fst @@ Asgns. collect_nodes_guess_output t.forward.asgns in
606+ let non_embedded_descendants = Set. diff all_read t.forward.embedded_nodes in
607+ let other_roots =
608+ Map. data session_state.forward_roots |> List. filter ~f: (fun r -> r. id <> t.id)
609609 in
610- if not @@ List. is_empty unsafe_roots then
610+ let conflicting_roots =
611+ List. filter other_roots ~f: (fun root ->
612+ not (Set. is_empty (Set. inter non_embedded_descendants root.forward.embedded_nodes)))
613+ in
614+ if not @@ List. is_empty conflicting_roots then
611615 raise
612616 @@ Session_error
613617 ( [% string
614618 {| Tensor. consume_forward_code for % {debug_name t}:
615- found potentially unsafe roots: % {String. concat ~sep: " , " @@ List. map ~f: debug_name unsafe_roots }| }],
619+ found conflicting roots with shared non - embedded descendants : % {String. concat ~sep: " , " @@ List. map ~f: debug_name conflicting_roots }| }],
616620 Some t );
617621 remove_fwd_root t;
618622 t.forward
@@ -630,16 +634,26 @@ let consume_backprop_code t =
630634 raise
631635 @@ Session_error
632636 (" Tensor.consume_backprop_code: tensor is not a root for tnode: " ^ debug_grad t, Some t);
633- let unsafe_roots =
634- Map. data session_state.backprop_roots
635- |> List. filter ~f: (fun r -> not (List. is_empty r.children || r.id = t.id))
636- in
637- if not @@ List. is_empty unsafe_roots then
637+ (* Check if any non-embedded grad descendants of t are embedded in other roots *)
638+ let all_read = fst @@ Asgns. collect_nodes_guess_output diff.backprop.asgns in
639+ let non_embedded_grad_descendants = Set. diff all_read diff.backprop.embedded_nodes in
640+ let other_roots =
641+ Map. data session_state.backprop_roots |> List. filter ~f: (fun r -> r.id <> t.id)
642+ in
643+ let conflicting_roots =
644+ List. filter other_roots ~f: (fun root ->
645+ match root.diff with
646+ | Some rdiff ->
647+ not
648+ (Set. is_empty (Set. inter non_embedded_grad_descendants rdiff.backprop.embedded_nodes))
649+ | None -> false )
650+ in
651+ if not @@ List. is_empty conflicting_roots then
638652 raise
639653 @@ Session_error
640654 ( [% string
641655 {| Tensor. consume_backprop_code for % {debug_grad t}:
642- found potentially unsafe roots: % {String. concat ~sep: " , " @@ List. map ~f: debug_name unsafe_roots }| }],
656+ found conflicting roots with shared non - embedded grad descendants : % {String. concat ~sep: " , " @@ List. map ~f: debug_grad conflicting_roots }| }],
643657 Some t );
644658 remove_bprop_root t;
645659 diff.backprop
0 commit comments