Skip to content

Commit 5659e7a

Browse files
committed
Better error for memory mode prohibiting node reuse
1 parent 6da09c4 commit 5659e7a

File tree

2 files changed

+14
-3
lines changed

2 files changed

+14
-3
lines changed

arrayjit/lib/low_level.ml

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -271,13 +271,21 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
271271
&& not (Tn.known_non_virtual tn)
272272
then Tn.update_memory_mode tn Virtual 40;
273273
if Option.is_none tn.memory_mode && Hashtbl.exists traced.accesses ~f:is_too_many then
274-
Tn.update_memory_mode tn Never_virtual 1
274+
Tn.update_memory_mode tn Never_virtual 1;
275+
if (not traced.zeroed_out) && Hash_set.is_empty traced.assignments then (
275276
(* The tensor node is read-only/recurrent for this computation, but maybe computed or
276277
specified as virtual by another routine. However, if the memory mode is unspecified, we
277-
assume this will be the first computation involving the tensor node. *);
278-
if (not traced.zeroed_out) && Hash_set.is_empty traced.assignments then (
278+
assume this will be the first computation involving the tensor node. *)
279279
traced.read_only <- true;
280280
if Tn.mode_is_unspecified tn then Tn.update_memory_mode tn (Hosted Constant) 37
281+
else if Tn.known_not_materialized tn then (
282+
if Tn.known_non_virtual tn then
283+
raise
284+
(Utils.User_error
285+
[%string
286+
"Mark %{Tn.debug_name tn} as materialized (e.g. via Train.set_materialized) \
287+
before the first routine using it gets compiled; another routine re-uses that \
288+
computation. Debug: %{Tn.debug_memory_mode tn.Tn.memory_mode}"]))
281289
else if Tn.known_non_virtual tn then Tn.update_memory_mode tn Materialized 35);
282290
if Hashtbl.exists traced.accesses ~f:is_recurrent then (
283291
traced.read_before_write <- true;

arrayjit/lib/tnode.ml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -196,6 +196,9 @@ let log_debug_info ~from_log_level tn =
196196

197197
(** The one exception to "most local" is that the sharing property is kept at [Unset]. *)
198198
let default_to_most_local tn provenance =
199+
let provenance =
200+
match tn.memory_mode with Some (_, prov) -> (1000 * prov) + provenance | None -> provenance
201+
in
199202
match tn.memory_mode with
200203
| None | Some (Effectively_constant, _) -> tn.memory_mode <- Some (Virtual, provenance)
201204
| Some (Never_virtual, _) -> tn.memory_mode <- Some (Local, provenance)

0 commit comments

Comments
 (0)