@@ -87,6 +87,7 @@ type virtualize_settings = {
8787 mutable max_tracing_dim : int ;
8888 mutable inline_scalar_constexprs : bool ;
8989 mutable inline_simple_computations : bool ;
90+ mutable inline_complex_computations : bool ;
9091}
9192
9293let virtualize_settings =
@@ -103,12 +104,17 @@ let virtualize_settings =
103104 let inline_simple_computations =
104105 Utils. get_global_flag ~default: true ~arg_name: " inline_simple_computations"
105106 in
107+ let inline_complex_computations =
108+ (* TODO(#351): change to true once CSE is implemented *)
109+ Utils. get_global_flag ~default: false ~arg_name: " inline_complex_computations"
110+ in
106111 {
107112 enable_device_only;
108113 max_visits;
109114 max_tracing_dim;
110115 inline_scalar_constexprs;
111116 inline_simple_computations;
117+ inline_complex_computations;
112118 }
113119
114120type visits = Visits of int | Recurrent [@@ deriving sexp , equal , variants ]
@@ -224,9 +230,6 @@ let is_scalar_dims tn = Array.for_all ~f:(( = ) 1) @@ Lazy.force tn.Tn.dims
224230
225231let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
226232 (* FIXME(#351): avoid excessive inlining while CSE is not implemented *)
227- let inline_complex_computations =
228- Utils. get_global_flag ~default: false ~arg_name: " inline_complex_computations"
229- in
230233 let is_too_many = function Visits i -> i > max_visits | Recurrent -> true in
231234 (* FIXME: migrate hashtable to use offsets instead of indices *)
232235 let lookup env indices =
@@ -341,7 +344,7 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
341344 let traced : traced_array = get_node traced_store ptr in
342345 let at_pos = lookup env indices in
343346 if
344- (not inline_complex_computations)
347+ (not virtualize_settings. inline_complex_computations)
345348 || Option. value_map access_pos ~default: true ~f: (fun pos ->
346349 not ([% equal: int array ] pos at_pos))
347350 then
0 commit comments