Skip to content

Commit 4f4f859

Browse files
committed
Better configurability for inline_complex_computations
1 parent 77773e8 commit 4f4f859

File tree

2 files changed

+8
-4
lines changed

2 files changed

+8
-4
lines changed

arrayjit/lib/low_level.ml

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,7 @@ type virtualize_settings = {
8787
mutable max_tracing_dim : int;
8888
mutable inline_scalar_constexprs : bool;
8989
mutable inline_simple_computations : bool;
90+
mutable inline_complex_computations : bool;
9091
}
9192

9293
let virtualize_settings =
@@ -103,12 +104,17 @@ let virtualize_settings =
103104
let inline_simple_computations =
104105
Utils.get_global_flag ~default:true ~arg_name:"inline_simple_computations"
105106
in
107+
let inline_complex_computations =
108+
(* TODO(#351): change to true once CSE is implemented *)
109+
Utils.get_global_flag ~default:false ~arg_name:"inline_complex_computations"
110+
in
106111
{
107112
enable_device_only;
108113
max_visits;
109114
max_tracing_dim;
110115
inline_scalar_constexprs;
111116
inline_simple_computations;
117+
inline_complex_computations;
112118
}
113119

114120
type visits = Visits of int | Recurrent [@@deriving sexp, equal, variants]
@@ -224,9 +230,6 @@ let is_scalar_dims tn = Array.for_all ~f:(( = ) 1) @@ Lazy.force tn.Tn.dims
224230

225231
let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
226232
(* FIXME(#351): avoid excessive inlining while CSE is not implemented *)
227-
let inline_complex_computations =
228-
Utils.get_global_flag ~default:false ~arg_name:"inline_complex_computations"
229-
in
230233
let is_too_many = function Visits i -> i > max_visits | Recurrent -> true in
231234
(* FIXME: migrate hashtable to use offsets instead of indices *)
232235
let lookup env indices =
@@ -341,7 +344,7 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
341344
let traced : traced_array = get_node traced_store ptr in
342345
let at_pos = lookup env indices in
343346
if
344-
(not inline_complex_computations)
347+
(not virtualize_settings.inline_complex_computations)
345348
|| Option.value_map access_pos ~default:true ~f:(fun pos ->
346349
not ([%equal: int array] pos at_pos))
347350
then

arrayjit/lib/low_level.mli

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ type virtualize_settings = {
6868
mutable max_tracing_dim : int;
6969
mutable inline_scalar_constexprs : bool;
7070
mutable inline_simple_computations : bool;
71+
mutable inline_complex_computations : bool;
7172
}
7273

7374
val virtualize_settings : virtualize_settings

0 commit comments

Comments
 (0)