@@ -220,6 +220,9 @@ let is_complex_comp traced_store llsc =
220220let is_scalar_dims tn = Array. for_all ~f: (( = ) 1 ) @@ Lazy. force tn.Tn. dims
221221
222222let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
223+ let inline_complex_computations =
224+ Utils. get_global_flag ~default: true ~arg_name: " inline_complex_computations"
225+ in
223226 let is_too_many = function Visits i -> i > max_visits | Recurrent -> true in
224227 (* FIXME: migrate hashtable to use offsets instead of indices *)
225228 let lookup env indices =
@@ -256,7 +259,7 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
256259 if is_scalar_dims tn then traced.is_scalar_constexpr < - true );
257260 traced.zeroed_out < - true
258261 | Set { tn; idcs; llsc; debug = _ } ->
259- loop_scalar env llsc;
262+ loop_scalar env ( Some (lookup env idcs)) llsc;
260263 let traced : traced_array = get_node traced_store tn in
261264 if
262265 Hash_set. is_empty traced.assignments
@@ -280,7 +283,7 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
280283 let old_tn = Hashtbl. find_or_add reverse_node_map s ~default: (fun () -> tn) in
281284 assert (Tn. equal old_tn tn)))
282285 | Set_from_vec { tn; idcs; length; vec_unop = _ ; arg; debug = _ } ->
283- loop_scalar env arg;
286+ loop_scalar env ( Some (lookup env idcs)) arg;
284287 let traced : traced_array = get_node traced_store tn in
285288 (* Vector operations cannot be scalar constexpr *)
286289 traced.is_scalar_constexpr < - false ;
@@ -323,18 +326,23 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
323326 List. iter symbols ~f: (fun (_ , s ) ->
324327 let old_tn = Hashtbl. find_or_add reverse_node_map s ~default: (fun () -> tn) in
325328 assert (Tn. equal old_tn tn)))
326- | Set_local (_ , llsc ) -> loop_scalar env llsc
329+ | Set_local (_ , llsc ) -> loop_scalar env None llsc
327330 | Comment _ -> ()
328331 | Staged_compilation _ -> ()
329- and loop_scalar env llsc =
330- let loop = loop_scalar env in
332+ and loop_scalar env ( access_pos : int array option ) llsc =
333+ let loop = loop_scalar env access_pos in
331334 match llsc with
332335 | Constant _ -> ()
333336 | Get (ptr , indices ) ->
334337 let traced : traced_array = get_node traced_store ptr in
335338 let at_pos = lookup env indices in
336- Hashtbl. update traced.accesses at_pos
337- ~f: (visit ~is_assigned: (traced.zeroed_out || Hash_set. mem traced.assignments at_pos))
339+ if
340+ (not inline_complex_computations)
341+ || Option. value_map access_pos ~default: true ~f: (fun pos ->
342+ not ([% equal: int array ] pos at_pos))
343+ then
344+ Hashtbl. update traced.accesses at_pos
345+ ~f: (visit ~is_assigned: (traced.zeroed_out || Hash_set. mem traced.assignments at_pos))
338346 | Local_scope { body; _ } -> loop_proc ~first_visit: true env body
339347 | Get_local _ -> ()
340348 | Get_merge_buffer (source , _ ) ->
0 commit comments