@@ -90,6 +90,7 @@ type virtualize_settings = {
9090 mutable max_visits : int ;
9191 mutable max_tracing_dim : int ;
9292 mutable inline_scalar_constexprs : bool ;
93+ mutable inline_simple_computations : bool ;
9394}
9495
9596let virtualize_settings =
@@ -105,7 +106,16 @@ let virtualize_settings =
105106 let inline_scalar_constexprs =
106107 Bool. of_string @@ Utils. get_global_arg ~arg_name: " inline_scalar_constexprs" ~default: " true"
107108 in
108- { enable_device_only; max_visits; max_tracing_dim; inline_scalar_constexprs }
109+ let inline_simple_computations =
110+ Bool. of_string @@ Utils. get_global_arg ~arg_name: " inline_simple_computations" ~default: " true"
111+ in
112+ {
113+ enable_device_only;
114+ max_visits;
115+ max_tracing_dim;
116+ inline_scalar_constexprs;
117+ inline_simple_computations;
118+ }
109119
110120type visits = Visits of int | Recurrent [@@ deriving sexp , equal , variants ]
111121
@@ -118,6 +128,7 @@ type traced_array = {
118128 mutable read_before_write : bool ;
119129 mutable read_only : bool ;
120130 mutable is_scalar_constexpr : bool ;
131+ mutable is_complex : bool ;
121132}
122133[@@ deriving sexp_of ]
123134
@@ -147,6 +158,7 @@ let get_node store tn =
147158 read_before_write = false ;
148159 read_only = false ;
149160 is_scalar_constexpr = false ;
161+ is_complex = false ;
150162 })
151163
152164let visit ~is_assigned old =
@@ -175,6 +187,24 @@ let is_constexpr_comp traced_store llv =
175187 in
176188 loop llv
177189
190+ let is_complex_comp traced_store llv =
191+ let rec loop llv =
192+ match llv with
193+ | Get_local { tn; _ } | Local_scope { id = { tn; _ } ; _ } ->
194+ let traced = get_node traced_store tn in
195+ traced.is_complex
196+ | Access (_ , _ ) -> true
197+ | Get (tn , _ ) ->
198+ let traced = get_node traced_store tn in
199+ not traced.is_scalar_constexpr
200+ | Ternop (_ , v1 , v2 , v3 ) -> loop v1 || loop v2 || loop v3
201+ | Binop (_ , v1 , v2 ) -> loop v1 || loop v2
202+ | Unop (_ , v ) -> loop v
203+ | Constant _ -> false
204+ | Embed_index _ -> false
205+ in
206+ loop llv
207+
178208let is_scalar_dims tn = Array. for_all ~f: (( = ) 1 ) @@ Lazy. force tn.Tn. dims
179209
180210let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
@@ -187,23 +217,27 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
187217 List. fold symbols ~init: offset ~f: (fun acc (coeff , s ) ->
188218 acc + (coeff * (Option. value ~default: 0 @@ Map. find env s))))
189219 in
190- let rec loop_proc env llc =
191- let loop = loop_proc env in
220+ let rec loop_proc ~ first_visit env llc =
221+ let loop = loop_proc ~first_visit env in
192222 match llc with
193223 | Noop -> ()
194224 | (Seq (c1 , c2 ) : t ) ->
195225 loop c1;
196226 loop c2
197227 | For_loop { index; from_; to_ = _ ; body; trace_it = false } ->
198- loop_proc (Map. add_exn ~key: index ~data: from_ env) body
228+ loop_proc ~first_visit (Map. add_exn ~key: index ~data: from_ env) body
199229 | For_loop { index; from_; to_; body; trace_it = true } ->
200230 for data = from_ to min to_ (from_ + virtualize_settings.max_tracing_dim) do
201- loop_proc (Map. add_exn ~key: index ~data env) body
231+ loop_proc
232+ ~first_visit: (first_visit && data = from_)
233+ (Map. add_exn ~key: index ~data env)
234+ body
202235 done
203236 | Zero_out tn ->
204237 let traced : traced_array = get_node traced_store tn in
205238 if Hash_set. is_empty traced.assignments && Hashtbl. is_empty traced.accesses then (
206239 traced.zero_initialized < - true ;
240+ traced.is_complex < - false ;
207241 if is_scalar_dims tn then traced.is_scalar_constexpr < - true );
208242 traced.zeroed_out < - true
209243 | Set { tn; idcs; llv; debug = _ } ->
@@ -215,6 +249,8 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
215249 then traced.is_scalar_constexpr < - is_constexpr_comp traced_store llv
216250 (* Note: this prevents detection if the same constant is assigned inside a loop. *)
217251 else if not @@ Hash_set. is_empty traced.assignments then traced.is_scalar_constexpr < - false ;
252+ if first_visit then
253+ traced.is_complex < - traced.is_complex || is_complex_comp traced_store llv;
218254 Hash_set. add traced.assignments (lookup env idcs);
219255 Array. iter idcs ~f: (function
220256 | Fixed_idx _ -> ()
@@ -238,7 +274,7 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
238274 let at_pos = lookup env indices in
239275 Hashtbl. update traced.accesses at_pos
240276 ~f: (visit ~is_assigned: (traced.zeroed_out || Hash_set. mem traced.assignments at_pos))
241- | Local_scope { body; _ } -> loop_proc env body
277+ | Local_scope { body; _ } -> loop_proc ~first_visit: true env body
242278 | Get_local _ -> ()
243279 | Access (Merge_buffer { source } , _ ) ->
244280 let source_node_id = source.Tn. id in
@@ -263,15 +299,21 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
263299 loop llv2
264300 | Unop (_ , llv ) -> loop llv
265301 in
266- loop_proc Indexing. empty_env llc;
302+ loop_proc ~first_visit: true Indexing. empty_env llc;
267303 Hashtbl. iter traced_store ~f: (fun traced ->
268304 let tn = traced.tn in
269305 if
270306 virtualize_settings.inline_scalar_constexprs && traced.is_scalar_constexpr
271307 && not (Tn. known_non_virtual tn)
272308 then Tn. update_memory_mode tn Virtual 40 ;
273- if Option. is_none tn.memory_mode && Hashtbl. exists traced.accesses ~f: is_too_many then
274- Tn. update_memory_mode tn Never_virtual 1 ;
309+ let skip_simple =
310+ virtualize_settings.inline_simple_computations && (not traced.is_complex)
311+ && not (Tn. known_non_virtual tn)
312+ in
313+ if
314+ (not skip_simple) && Option. is_none tn.memory_mode
315+ && Hashtbl. exists traced.accesses ~f: is_too_many
316+ then Tn. update_memory_mode tn Never_virtual 1 ;
275317 if (not traced.zeroed_out) && Hash_set. is_empty traced.assignments then (
276318 (* The tensor node is read-only/recurrent for this computation, but maybe computed or
277319 specified as virtual by another routine. However, if the memory mode is unspecified, we
@@ -287,7 +329,8 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
287329 before the first routine using it gets compiled; another routine re-uses that \
288330 computation. Debug: %{Tn.debug_memory_mode tn.Tn.memory_mode}" ]))
289331 else if Tn. known_non_virtual tn then Tn. update_memory_mode tn Materialized 35 );
290- if Hashtbl. exists traced.accesses ~f: is_recurrent then (
332+ (* We allow sharing virtual nodes across routines. *)
333+ if Hashtbl. exists traced.accesses ~f: is_recurrent && not (Tn. known_virtual tn) then (
291334 traced.read_before_write < - true ;
292335 if Tn. mode_is_unspecified tn then
293336 Tn. update_memory_mode tn (Hosted (Changed_on_devices Unset )) 38
0 commit comments