Skip to content

Commit 7fe3406

Browse files
authored
Merge pull request #368 from ahrefs/feature/tensor-rootness-check
Improve tensor rootness safety checks: precise checks in Tensor.consume_forward_code and Tensor.consume_backprop_code
2 parents e791459 + 3dcabdc commit 7fe3406

File tree

3 files changed

+32
-17
lines changed

3 files changed

+32
-17
lines changed

arrayjit/lib/assignments.ml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -113,8 +113,9 @@ let%debug3_sexp context_nodes ~(use_host_memory : 'a option) (asgns : t) : Tn.t_
113113
in
114114
loop asgns
115115

116-
(** Returns the nodes that are not read from after being written to. *)
117-
let%debug3_sexp guess_output_nodes (asgns : t) : Tn.t_set =
116+
(** In the second set, returns the nodes that are not read from after being written to. In the first
117+
set, returns the nodes that are ever read from. *)
118+
let%debug3_sexp collect_nodes_guess_output (asgns : t) : Tn.t_set * Tn.t_set =
118119
let open Utils.Set_O in
119120
let empty = Set.empty (module Tn) in
120121
let one = Set.singleton (module Tn) in
@@ -137,7 +138,7 @@ let%debug3_sexp guess_output_nodes (asgns : t) : Tn.t_set =
137138
| Set_vec_unop { lhs; rhs; _ } -> (of_node rhs, one lhs)
138139
| Fetch { array; _ } -> (empty, one array)
139140
in
140-
snd @@ loop asgns
141+
loop asgns
141142

142143
let sequential l =
143144
Option.value ~default:Noop @@ List.reduce l ~f:(fun sts another_st -> Seq (sts, another_st))

lib/tensor.ml

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -601,18 +601,22 @@ let consume_forward_code t =
601601
@@ Session_error
602602
( "Tensor.consume_forward_code: tensor is not a root for tnode: " ^ Tn.debug_name t.value,
603603
Some t );
604-
(* FIXME(#321): this is too aggressive, instead we should check if the code contains any
605-
non-embedded nodes that are embedded nodes of the other roots. *)
606-
let unsafe_roots =
607-
Map.data session_state.forward_roots
608-
|> List.filter ~f:(fun r -> not (List.is_empty r.children || r.id = t.id))
604+
(* Check if any non-embedded descendants of t are embedded in other roots *)
605+
let all_read = fst @@ Asgns.collect_nodes_guess_output t.forward.asgns in
606+
let non_embedded_descendants = Set.diff all_read t.forward.embedded_nodes in
607+
let other_roots =
608+
Map.data session_state.forward_roots |> List.filter ~f:(fun r -> r.id <> t.id)
609609
in
610-
if not @@ List.is_empty unsafe_roots then
610+
let conflicting_roots =
611+
List.filter other_roots ~f:(fun root ->
612+
not (Set.is_empty (Set.inter non_embedded_descendants root.forward.embedded_nodes)))
613+
in
614+
if not @@ List.is_empty conflicting_roots then
611615
raise
612616
@@ Session_error
613617
( [%string
614618
{|Tensor.consume_forward_code for %{debug_name t}:
615-
found potentially unsafe roots: %{String.concat ~sep:", " @@ List.map ~f:debug_name unsafe_roots}|}],
619+
found conflicting roots with shared non-embedded descendants: %{String.concat ~sep:", " @@ List.map ~f:debug_name conflicting_roots}|}],
616620
Some t );
617621
remove_fwd_root t;
618622
t.forward
@@ -630,16 +634,26 @@ let consume_backprop_code t =
630634
raise
631635
@@ Session_error
632636
("Tensor.consume_backprop_code: tensor is not a root for tnode: " ^ debug_grad t, Some t);
633-
let unsafe_roots =
634-
Map.data session_state.backprop_roots
635-
|> List.filter ~f:(fun r -> not (List.is_empty r.children || r.id = t.id))
636-
in
637-
if not @@ List.is_empty unsafe_roots then
637+
(* Check if any non-embedded grad descendants of t are embedded in other roots *)
638+
let all_read = fst @@ Asgns.collect_nodes_guess_output diff.backprop.asgns in
639+
let non_embedded_grad_descendants = Set.diff all_read diff.backprop.embedded_nodes in
640+
let other_roots =
641+
Map.data session_state.backprop_roots |> List.filter ~f:(fun r -> r.id <> t.id)
642+
in
643+
let conflicting_roots =
644+
List.filter other_roots ~f:(fun root ->
645+
match root.diff with
646+
| Some rdiff ->
647+
not
648+
(Set.is_empty (Set.inter non_embedded_grad_descendants rdiff.backprop.embedded_nodes))
649+
| None -> false)
650+
in
651+
if not @@ List.is_empty conflicting_roots then
638652
raise
639653
@@ Session_error
640654
( [%string
641655
{|Tensor.consume_backprop_code for %{debug_grad t}:
642-
found potentially unsafe roots: %{String.concat ~sep:", " @@ List.map ~f:debug_name unsafe_roots}|}],
656+
found conflicting roots with shared non-embedded grad descendants: %{String.concat ~sep:", " @@ List.map ~f:debug_grad conflicting_roots}|}],
643657
Some t );
644658
remove_bprop_root t;
645659
diff.backprop

lib/train.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ let to_routine (type buffer_ptr dev runner event optimize_ctx)
317317
and type event = event
318318
and type optimize_ctx = optimize_ctx) (context : Backend.context) ?(hosted = true) ?name
319319
bindings comp =
320-
if hosted then Set.iter (Asgns.guess_output_nodes comp.Asgns.asgns) ~f:set_hosted;
320+
if hosted then Set.iter (snd @@ Asgns.collect_nodes_guess_output comp.Asgns.asgns) ~f:set_hosted;
321321
Backend.link context @@ Backend.compile context.optimize_ctx ?name bindings comp
322322

323323
(** [init_params] initializes the parameters of [t], via running their forward code or copying from

0 commit comments

Comments
 (0)