@@ -271,13 +271,21 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
271271 && not (Tn. known_non_virtual tn)
272272 then Tn. update_memory_mode tn Virtual 40 ;
273273 if Option. is_none tn.memory_mode && Hashtbl. exists traced.accesses ~f: is_too_many then
274- Tn. update_memory_mode tn Never_virtual 1
274+ Tn. update_memory_mode tn Never_virtual 1 ;
275+ if (not traced.zeroed_out) && Hash_set. is_empty traced.assignments then (
275276 (* The tensor node is read-only/recurrent for this computation, but maybe computed or
276277 specified as virtual by another routine. However, if the memory mode is unspecified, we
277- assume this will be the first computation involving the tensor node. *) ;
278- if (not traced.zeroed_out) && Hash_set. is_empty traced.assignments then (
278+ assume this will be the first computation involving the tensor node. *)
279279 traced.read_only < - true ;
280280 if Tn. mode_is_unspecified tn then Tn. update_memory_mode tn (Hosted Constant ) 37
281+ else if Tn. known_not_materialized tn then (
282+ if Tn. known_non_virtual tn then
283+ raise
284+ (Utils. User_error
285+ [% string
286+ " Mark %{Tn.debug_name tn} as materialized (e.g. via Train.set_materialized) \
287+ before the first routine using it gets compiled; another routine re-uses that \
288+ computation. Debug: %{Tn.debug_memory_mode tn.Tn.memory_mode}" ]))
281289 else if Tn. known_non_virtual tn then Tn. update_memory_mode tn Materialized 35 );
282290 if Hashtbl. exists traced.accesses ~f: is_recurrent then (
283291 traced.read_before_write < - true ;
0 commit comments