@@ -24,6 +24,16 @@ and t =
2424 | Noop
2525 | Seq of t * t
2626 | Block_comment of string * t (* * Same as the given code, with a comment. *)
27+ | Accum_ternop of {
28+ initialize_neutral : bool ;
29+ accum : Ops .binop ;
30+ op : Ops .ternop ;
31+ lhs : Tn .t ;
32+ rhs1 : buffer ;
33+ rhs2 : buffer ;
34+ rhs3 : buffer ;
35+ projections : Indexing .projections Lazy .t ;
36+ }
2737 | Accum_binop of {
2838 initialize_neutral : bool ;
2939 accum : Ops .binop ;
@@ -93,6 +103,8 @@ let%debug3_sexp context_nodes ~(use_host_memory : 'a option) (asgns : t) : Tn.t_
93103 | Accum_unop { lhs; rhs; _ } -> Set. union (one lhs) (of_node rhs)
94104 | Accum_binop { lhs; rhs1; rhs2; _ } ->
95105 Set. union_list (module Tn ) [ one lhs; of_node rhs1; of_node rhs2 ]
106+ | Accum_ternop { lhs; rhs1; rhs2; rhs3; _ } ->
107+ Set. union_list (module Tn ) [ one lhs; of_node rhs1; of_node rhs2; of_node rhs3 ]
96108 | Fetch { array; _ } -> one array
97109 in
98110 loop asgns
@@ -139,98 +151,60 @@ let%diagn2_sexp to_low_level code =
139151 assert (Array. length idcs = Array. length (Lazy. force tn.Tn. dims));
140152 Low_level. Set { tn; idcs; llv; debug = " " }
141153 in
142- let rec loop code =
(* [loop_accum] lowers one accumulating assignment, whatever the arity of [op].

   It forces [projections], derives an index function for [lhs] and one for each entry of
   [projections.project_rhs], and wraps the innermost assignment in one [For_loop] per entry of
   [projections.product_space] (with [from_ = 0] and [to_ = d - 1], each loop using a fresh symbol
   from [Indexing.get_symbol]).

   The innermost assignment stores [op] applied to the reads of [rhses] into [lhs]. Unless
   [is_total ~initialize_neutral ~projections] holds, that value is first combined with the
   current [lhs] value through [accum].

   When [initialize_neutral] is set and [is_total] does not hold, the loops are preceded by a
   [Fetch] that fills [lhs] with the constant [Ops.neutral_elem accum]. *)
let rec loop_accum ~initialize_neutral ~accum ~op ~lhs ~rhses projections =
  let projections = Lazy.force projections in
  let lhs_idx =
    derive_index ~product_syms:projections.product_iterators ~projection:projections.project_lhs
  in
  (* One index-deriving function per right-hand-side projection. *)
  let rhs_idcs =
    Array.map projections.project_rhs ~f:(fun projection ->
        derive_index ~product_syms:projections.product_iterators ~projection)
  in
  let basecase rev_iters =
    (* [rev_iters] is collected innermost-first; restore outermost-first order. *)
    let product = Array.of_list_rev_map rev_iters ~f:(fun s -> Indexing.Iterator s) in
    let rhses_idcs = Array.map rhs_idcs ~f:(fun rhs_idx -> rhs_idx ~product) in
    let lhs_idcs = lhs_idx ~product in
    let open Low_level in
    let lhs_ll = get (Node lhs) lhs_idcs in
    (* NOTE(review): [rhses] is indexed by projection position, so [project_rhs] must not be
       longer than [rhses] -- confirm that callers always pass arrays of equal length. *)
    let rhses_ll = Array.mapi rhses_idcs ~f:(fun i rhs_idcs -> get rhses.(i) rhs_idcs) in
    let op_result = apply_op op rhses_ll in
    if is_total ~initialize_neutral ~projections then set lhs lhs_idcs op_result
    else set lhs lhs_idcs @@ apply_op (Ops.Binop accum) [| lhs_ll; op_result |]
  in
  let rec for_loop rev_iters = function
    | [] -> basecase rev_iters
    | d :: product ->
        let index = Indexing.get_symbol () in
        For_loop
          {
            index;
            from_ = 0;
            to_ = d - 1;
            body = for_loop (index :: rev_iters) product;
            trace_it = true;
          }
  in
  let for_loops =
    try for_loop [] (Array.to_list projections.product_space)
    with e ->
      (* Log the offending projections before propagating the failure unchanged. *)
      [%log " projections=", (projections : projections)];
      raise e
  in
  if initialize_neutral && not (is_total ~initialize_neutral ~projections) then
    let dims = lazy projections.lhs_dims in
    let fetch_op = Constant (Ops.neutral_elem accum) in
    Low_level.Seq (loop (Fetch { array = lhs; fetch_op; dims }), for_loops)
  else for_loops
198+ and loop code =
143199 match code with
200+ | Accum_ternop { initialize_neutral; accum; op; lhs; rhs1; rhs2; rhs3; projections } ->
201+ loop_accum ~initialize_neutral ~accum ~op: (Ops. Ternop op) ~lhs ~rhses: [| rhs1; rhs2; rhs3 |]
202+ projections
144203 | Accum_binop { initialize_neutral; accum; op; lhs; rhs1; rhs2; projections } ->
145- let projections = Lazy. force projections in
146- let lhs_idx =
147- derive_index ~product_syms: projections.product_iterators
148- ~projection: projections.project_lhs
149- in
150- let rhs1_idx =
151- derive_index ~product_syms: projections.product_iterators
152- ~projection: projections.project_rhs.(0 )
153- in
154- let rhs2_idx =
155- derive_index ~product_syms: projections.product_iterators
156- ~projection: projections.project_rhs.(1 )
157- in
158- let basecase rev_iters =
159- let product = Array. of_list_rev_map rev_iters ~f: (fun s -> Indexing. Iterator s) in
160- let rhs1_idcs = rhs1_idx ~product in
161- let rhs2_idcs = rhs2_idx ~product in
162- let lhs_idcs = lhs_idx ~product in
163- let open Low_level in
164- let lhs_ll = get (Node lhs) lhs_idcs in
165- let rhs1_ll = get rhs1 rhs1_idcs in
166- let rhs2_ll = get rhs2 rhs2_idcs in
167- let rhs2 = binop ~op ~rhs1: rhs1_ll ~rhs2: rhs2_ll in
168- if is_total ~initialize_neutral ~projections then set lhs lhs_idcs rhs2
169- else set lhs lhs_idcs @@ binop ~op: accum ~rhs1: lhs_ll ~rhs2
170- in
171- let rec for_loop rev_iters = function
172- | [] -> basecase rev_iters
173- | d :: product ->
174- let index = Indexing. get_symbol () in
175- For_loop
176- {
177- index;
178- from_ = 0 ;
179- to_ = d - 1 ;
180- body = for_loop (index :: rev_iters) product;
181- trace_it = true ;
182- }
183- in
184- let for_loops =
185- try for_loop [] (Array. to_list projections.product_space)
186- with e ->
187- [% log " projections=" , (projections : projections )];
188- raise e
189- in
190- if initialize_neutral && not (is_total ~initialize_neutral ~projections ) then
191- let dims = lazy projections.lhs_dims in
192- let fetch_op = Constant (Ops. neutral_elem accum) in
193- Low_level. Seq (loop (Fetch { array = lhs; fetch_op; dims }), for_loops)
194- else for_loops
204+ loop_accum ~initialize_neutral ~accum ~op: (Ops. Binop op) ~lhs ~rhses: [| rhs1; rhs2 |]
205+ projections
195206 | Accum_unop { initialize_neutral; accum; op; lhs; rhs; projections } ->
196- let projections = Lazy. force projections in
197- let lhs_idx =
198- derive_index ~product_syms: projections.product_iterators
199- ~projection: projections.project_lhs
200- in
201- let rhs_idx =
202- derive_index ~product_syms: projections.product_iterators
203- ~projection: projections.project_rhs.(0 )
204- in
205- let basecase rev_iters =
206- let product = Array. of_list_rev_map rev_iters ~f: (fun s -> Indexing. Iterator s) in
207- let lhs_idcs = lhs_idx ~product in
208- let open Low_level in
209- let lhs_ll = get (Node lhs) lhs_idcs in
210- let rhs_ll = get rhs @@ rhs_idx ~product in
211- let rhs2 = unop ~op ~rhs: rhs_ll in
212- if is_total ~initialize_neutral ~projections then set lhs lhs_idcs rhs2
213- else set lhs lhs_idcs @@ binop ~op: accum ~rhs1: lhs_ll ~rhs2
214- in
215- let rec for_loop rev_iters = function
216- | [] -> basecase rev_iters
217- | d :: product ->
218- let index = Indexing. get_symbol () in
219- For_loop
220- {
221- index;
222- from_ = 0 ;
223- to_ = d - 1 ;
224- body = for_loop (index :: rev_iters) product;
225- trace_it = true ;
226- }
227- in
228- let for_loops = for_loop [] (Array. to_list projections.product_space) in
229- if initialize_neutral && not (is_total ~initialize_neutral ~projections ) then
230- let dims = lazy projections.lhs_dims in
231- let fetch_op = Constant (Ops. neutral_elem accum) in
232- Low_level. Seq (loop (Fetch { array = lhs; fetch_op; dims }), for_loops)
233- else for_loops
207+ loop_accum ~initialize_neutral ~accum ~op: (Ops. Unop op) ~lhs ~rhses: [| rhs |] projections
234208 | Noop -> Low_level. Noop
235209 | Block_comment (s , c ) -> Low_level. unflat_lines [ Comment s; loop c; Comment " end" ]
236210 | Seq (c1 , c2 ) ->
@@ -251,15 +225,14 @@ let%diagn2_sexp to_low_level code =
251225 Low_level. loop_over_dims (Lazy. force dims) ~body: (fun idcs ->
252226 set array idcs @@ Get_global (global, Some idcs))
253227 in
254-
255228 loop code
256229
(** [flatten c] lists the leaf assignments of [c] in sequence order. [Noop] contributes nothing,
    [Seq (c1, c2)] contributes the items of [c1] followed by those of [c2], and
    [Block_comment (s, body)] contributes a [Block_comment (s, Noop)] marker followed by the
    flattened [body]. Accumulating assignments and [Fetch] are returned as they are. *)
let flatten c =
  (* [loop tail code] puts the flattened [code] in front of [tail]. Threading the tail keeps the
     output order of [loop c1 @ loop c2] without re-copying the left operand at every [Seq]. *)
  let rec loop tail = function
    | Noop -> tail
    | Seq (c1, c2) -> loop (loop tail c2) c1
    | Block_comment (s, c) -> Block_comment (s, Noop) :: loop tail c
    | (Accum_ternop _ | Accum_binop _ | Accum_unop _ | Fetch _) as c -> c :: tail
  in
  loop [] c
265238
@@ -286,6 +259,9 @@ let get_ident_within_code ?no_dots c =
286259 loop c1;
287260 loop c2
288261 | Block_comment (_ , c ) -> loop c
262+ | Accum_ternop
263+ { initialize_neutral = _; accum = _; op = _; lhs; rhs1; rhs2; rhs3; projections = _ } ->
264+ List. iter ~f: visit [ lhs; tn rhs1; tn rhs2; tn rhs3 ]
289265 | Accum_binop { initialize_neutral = _ ; accum = _ ; op = _ ; lhs; rhs1; rhs2; projections = _ } ->
290266 List. iter ~f: visit [ lhs; tn rhs1; tn rhs2 ]
291267 | Accum_unop { initialize_neutral = _ ; accum = _ ; op = _ ; lhs; rhs; projections = _ } ->
@@ -331,6 +307,16 @@ let fprint_hum ?name ?static_indices () ppf c =
331307 | Block_comment (s , c ) ->
332308 fprintf ppf " # \" %s\" ;@ " s;
333309 loop c
310+ | Accum_ternop { initialize_neutral; accum; op; lhs; rhs1; rhs2; rhs3; projections } ->
311+ let proj_spec =
312+ if Lazy. is_val projections then (Lazy. force projections).debug_info.spec
313+ else " <not-in-yet>"
314+ in
315+ (* Uncurried syntax for ternary operations. *)
316+ fprintf ppf " %s %s %s(%s, %s, %s)%s;@ " (ident lhs)
317+ (Ops. assign_op_cd_syntax ~initialize_neutral accum)
318+ (Ops. ternop_cd_syntax op) (buffer_ident rhs1) (buffer_ident rhs2) (buffer_ident rhs3)
319+ (if not (String. equal proj_spec " ." ) then " ~logic:\" " ^ proj_spec ^ " \" " else " " )
334320 | Accum_binop { initialize_neutral; accum; op; lhs; rhs1; rhs2; projections } ->
335321 let proj_spec =
336322 if Lazy. is_val projections then (Lazy. force projections).debug_info.spec
0 commit comments