Skip to content

Commit fde7983

Browse files
committed
Ternary primitive operations, in progress
1 parent 2bd4336 commit fde7983

File tree

7 files changed

+179
-120
lines changed

7 files changed

+179
-120
lines changed

arrayjit/lib/assignments.ml

Lines changed: 77 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,16 @@ and t =
2424
| Noop
2525
| Seq of t * t
2626
| Block_comment of string * t (** Same as the given code, with a comment. *)
27+
| Accum_ternop of {
28+
initialize_neutral : bool;
29+
accum : Ops.binop;
30+
op : Ops.ternop;
31+
lhs : Tn.t;
32+
rhs1 : buffer;
33+
rhs2 : buffer;
34+
rhs3 : buffer;
35+
projections : Indexing.projections Lazy.t;
36+
}
2737
| Accum_binop of {
2838
initialize_neutral : bool;
2939
accum : Ops.binop;
@@ -93,6 +103,8 @@ let%debug3_sexp context_nodes ~(use_host_memory : 'a option) (asgns : t) : Tn.t_
93103
| Accum_unop { lhs; rhs; _ } -> Set.union (one lhs) (of_node rhs)
94104
| Accum_binop { lhs; rhs1; rhs2; _ } ->
95105
Set.union_list (module Tn) [ one lhs; of_node rhs1; of_node rhs2 ]
106+
| Accum_ternop { lhs; rhs1; rhs2; rhs3; _ } ->
107+
Set.union_list (module Tn) [ one lhs; of_node rhs1; of_node rhs2; of_node rhs3 ]
96108
| Fetch { array; _ } -> one array
97109
in
98110
loop asgns
@@ -139,98 +151,60 @@ let%diagn2_sexp to_low_level code =
139151
assert (Array.length idcs = Array.length (Lazy.force tn.Tn.dims));
140152
Low_level.Set { tn; idcs; llv; debug = "" }
141153
in
142-
let rec loop code =
154+
let rec loop_accum ~initialize_neutral ~accum ~op ~lhs ~rhses projections =
155+
let projections = Lazy.force projections in
156+
let lhs_idx =
157+
derive_index ~product_syms:projections.product_iterators ~projection:projections.project_lhs
158+
in
159+
let rhs_idcs =
160+
Array.map projections.project_rhs ~f:(fun projection ->
161+
derive_index ~product_syms:projections.product_iterators ~projection)
162+
in
163+
let basecase rev_iters =
164+
let product = Array.of_list_rev_map rev_iters ~f:(fun s -> Indexing.Iterator s) in
165+
let rhses_idcs = Array.map rhs_idcs ~f:(fun rhs_idx -> rhs_idx ~product) in
166+
let lhs_idcs = lhs_idx ~product in
167+
let open Low_level in
168+
let lhs_ll = get (Node lhs) lhs_idcs in
169+
let rhses_ll = Array.mapi rhses_idcs ~f:(fun i rhs_idcs -> get rhses.(i) rhs_idcs) in
170+
let rhs2 = apply_op op rhses_ll in
171+
if is_total ~initialize_neutral ~projections then set lhs lhs_idcs rhs2
172+
else set lhs lhs_idcs @@ apply_op (Ops.Binop accum) [| lhs_ll; rhs2 |]
173+
in
174+
let rec for_loop rev_iters = function
175+
| [] -> basecase rev_iters
176+
| d :: product ->
177+
let index = Indexing.get_symbol () in
178+
For_loop
179+
{
180+
index;
181+
from_ = 0;
182+
to_ = d - 1;
183+
body = for_loop (index :: rev_iters) product;
184+
trace_it = true;
185+
}
186+
in
187+
let for_loops =
188+
try for_loop [] (Array.to_list projections.product_space)
189+
with e ->
190+
[%log "projections=", (projections : projections)];
191+
raise e
192+
in
193+
if initialize_neutral && not (is_total ~initialize_neutral ~projections) then
194+
let dims = lazy projections.lhs_dims in
195+
let fetch_op = Constant (Ops.neutral_elem accum) in
196+
Low_level.Seq (loop (Fetch { array = lhs; fetch_op; dims }), for_loops)
197+
else for_loops
198+
and loop code =
143199
match code with
200+
| Accum_ternop { initialize_neutral; accum; op; lhs; rhs1; rhs2; rhs3; projections } ->
201+
loop_accum ~initialize_neutral ~accum ~op:(Ops.Ternop op) ~lhs ~rhses:[| rhs1; rhs2; rhs3 |]
202+
projections
144203
| Accum_binop { initialize_neutral; accum; op; lhs; rhs1; rhs2; projections } ->
145-
let projections = Lazy.force projections in
146-
let lhs_idx =
147-
derive_index ~product_syms:projections.product_iterators
148-
~projection:projections.project_lhs
149-
in
150-
let rhs1_idx =
151-
derive_index ~product_syms:projections.product_iterators
152-
~projection:projections.project_rhs.(0)
153-
in
154-
let rhs2_idx =
155-
derive_index ~product_syms:projections.product_iterators
156-
~projection:projections.project_rhs.(1)
157-
in
158-
let basecase rev_iters =
159-
let product = Array.of_list_rev_map rev_iters ~f:(fun s -> Indexing.Iterator s) in
160-
let rhs1_idcs = rhs1_idx ~product in
161-
let rhs2_idcs = rhs2_idx ~product in
162-
let lhs_idcs = lhs_idx ~product in
163-
let open Low_level in
164-
let lhs_ll = get (Node lhs) lhs_idcs in
165-
let rhs1_ll = get rhs1 rhs1_idcs in
166-
let rhs2_ll = get rhs2 rhs2_idcs in
167-
let rhs2 = binop ~op ~rhs1:rhs1_ll ~rhs2:rhs2_ll in
168-
if is_total ~initialize_neutral ~projections then set lhs lhs_idcs rhs2
169-
else set lhs lhs_idcs @@ binop ~op:accum ~rhs1:lhs_ll ~rhs2
170-
in
171-
let rec for_loop rev_iters = function
172-
| [] -> basecase rev_iters
173-
| d :: product ->
174-
let index = Indexing.get_symbol () in
175-
For_loop
176-
{
177-
index;
178-
from_ = 0;
179-
to_ = d - 1;
180-
body = for_loop (index :: rev_iters) product;
181-
trace_it = true;
182-
}
183-
in
184-
let for_loops =
185-
try for_loop [] (Array.to_list projections.product_space)
186-
with e ->
187-
[%log "projections=", (projections : projections)];
188-
raise e
189-
in
190-
if initialize_neutral && not (is_total ~initialize_neutral ~projections) then
191-
let dims = lazy projections.lhs_dims in
192-
let fetch_op = Constant (Ops.neutral_elem accum) in
193-
Low_level.Seq (loop (Fetch { array = lhs; fetch_op; dims }), for_loops)
194-
else for_loops
204+
loop_accum ~initialize_neutral ~accum ~op:(Ops.Binop op) ~lhs ~rhses:[| rhs1; rhs2 |]
205+
projections
195206
| Accum_unop { initialize_neutral; accum; op; lhs; rhs; projections } ->
196-
let projections = Lazy.force projections in
197-
let lhs_idx =
198-
derive_index ~product_syms:projections.product_iterators
199-
~projection:projections.project_lhs
200-
in
201-
let rhs_idx =
202-
derive_index ~product_syms:projections.product_iterators
203-
~projection:projections.project_rhs.(0)
204-
in
205-
let basecase rev_iters =
206-
let product = Array.of_list_rev_map rev_iters ~f:(fun s -> Indexing.Iterator s) in
207-
let lhs_idcs = lhs_idx ~product in
208-
let open Low_level in
209-
let lhs_ll = get (Node lhs) lhs_idcs in
210-
let rhs_ll = get rhs @@ rhs_idx ~product in
211-
let rhs2 = unop ~op ~rhs:rhs_ll in
212-
if is_total ~initialize_neutral ~projections then set lhs lhs_idcs rhs2
213-
else set lhs lhs_idcs @@ binop ~op:accum ~rhs1:lhs_ll ~rhs2
214-
in
215-
let rec for_loop rev_iters = function
216-
| [] -> basecase rev_iters
217-
| d :: product ->
218-
let index = Indexing.get_symbol () in
219-
For_loop
220-
{
221-
index;
222-
from_ = 0;
223-
to_ = d - 1;
224-
body = for_loop (index :: rev_iters) product;
225-
trace_it = true;
226-
}
227-
in
228-
let for_loops = for_loop [] (Array.to_list projections.product_space) in
229-
if initialize_neutral && not (is_total ~initialize_neutral ~projections) then
230-
let dims = lazy projections.lhs_dims in
231-
let fetch_op = Constant (Ops.neutral_elem accum) in
232-
Low_level.Seq (loop (Fetch { array = lhs; fetch_op; dims }), for_loops)
233-
else for_loops
207+
loop_accum ~initialize_neutral ~accum ~op:(Ops.Unop op) ~lhs ~rhses:[| rhs |] projections
234208
| Noop -> Low_level.Noop
235209
| Block_comment (s, c) -> Low_level.unflat_lines [ Comment s; loop c; Comment "end" ]
236210
| Seq (c1, c2) ->
@@ -251,15 +225,14 @@ let%diagn2_sexp to_low_level code =
251225
Low_level.loop_over_dims (Lazy.force dims) ~body:(fun idcs ->
252226
set array idcs @@ Get_global (global, Some idcs))
253227
in
254-
255228
loop code
256229

257230
let flatten c =
258231
let rec loop = function
259232
| Noop -> []
260233
| Seq (c1, c2) -> loop c1 @ loop c2
261234
| Block_comment (s, c) -> Block_comment (s, Noop) :: loop c
262-
| (Accum_binop _ | Accum_unop _ | Fetch _) as c -> [ c ]
235+
| (Accum_ternop _ | Accum_binop _ | Accum_unop _ | Fetch _) as c -> [ c ]
263236
in
264237
loop c
265238

@@ -286,6 +259,9 @@ let get_ident_within_code ?no_dots c =
286259
loop c1;
287260
loop c2
288261
| Block_comment (_, c) -> loop c
262+
| Accum_ternop
263+
{ initialize_neutral = _; accum = _; op = _; lhs; rhs1; rhs2; rhs3; projections = _ } ->
264+
List.iter ~f:visit [ lhs; tn rhs1; tn rhs2; tn rhs3 ]
289265
| Accum_binop { initialize_neutral = _; accum = _; op = _; lhs; rhs1; rhs2; projections = _ } ->
290266
List.iter ~f:visit [ lhs; tn rhs1; tn rhs2 ]
291267
| Accum_unop { initialize_neutral = _; accum = _; op = _; lhs; rhs; projections = _ } ->
@@ -331,6 +307,16 @@ let fprint_hum ?name ?static_indices () ppf c =
331307
| Block_comment (s, c) ->
332308
fprintf ppf "# \"%s\";@ " s;
333309
loop c
310+
| Accum_ternop { initialize_neutral; accum; op; lhs; rhs1; rhs2; rhs3; projections } ->
311+
let proj_spec =
312+
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec
313+
else "<not-in-yet>"
314+
in
315+
(* Uncurried syntax for ternary operations. *)
316+
fprintf ppf "%s %s %s(%s, %s, %s)%s;@ " (ident lhs)
317+
(Ops.assign_op_cd_syntax ~initialize_neutral accum)
318+
(Ops.ternop_cd_syntax op) (buffer_ident rhs1) (buffer_ident rhs2) (buffer_ident rhs3)
319+
(if not (String.equal proj_spec ".") then " ~logic:\"" ^ proj_spec ^ "\"" else "")
334320
| Accum_binop { initialize_neutral; accum; op; lhs; rhs1; rhs2; projections } ->
335321
let proj_spec =
336322
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec

arrayjit/lib/c_syntax.ml

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ module C_syntax (B : sig
2323
val kernel_prep_line : string
2424
val includes : string list
2525
val typ_of_prec : Ops.prec -> string
26+
val ternop_syntax : Ops.prec -> Ops.ternop -> string * string * string * string
2627
val binop_syntax : Ops.prec -> Ops.binop -> string * string * string
2728
val unop_syntax : Ops.prec -> Ops.unop -> string * string
2829
val convert_precision : from:Ops.prec -> to_:Ops.prec -> string * string
@@ -166,6 +167,7 @@ struct
166167
| Get_local _ | Get_global _ | Get _ | Constant _ | Embed_index _ -> 0
167168
| Binop (Arg1, v1, _v2) -> pp_top_locals ppf v1
168169
| Binop (Arg2, _v1, v2) -> pp_top_locals ppf v2
170+
| Ternop (_, v1, v2, v3) -> pp_top_locals ppf v1 + pp_top_locals ppf v2 + pp_top_locals ppf v3
169171
| Binop (_, v1, v2) -> pp_top_locals ppf v1 + pp_top_locals ppf v2
170172
| Unop (_, v) -> pp_top_locals ppf v
171173
and pp_float (prec : Ops.prec) ppf value =
@@ -203,6 +205,10 @@ struct
203205
fprintf ppf "%s%a%s" prefix pp_axis_index idx postfix
204206
| Binop (Arg1, v1, _v2) -> loop ppf v1
205207
| Binop (Arg2, _v1, v2) -> loop ppf v2
208+
| Ternop (op, v1, v2, v3) ->
209+
let prefix, comma1, comma2, postfix = B.ternop_syntax prec op in
210+
fprintf ppf "@[<1>%s%a%s@ %a%s@ %a@]%s" prefix loop v1 comma1 loop v2 comma2 loop v3
211+
postfix
206212
| Binop (op, v1, v2) ->
207213
let prefix, infix, postfix = B.binop_syntax prec op in
208214
fprintf ppf "@[<1>%s%a%s@ %a@]%s" prefix loop v1 infix loop v2 postfix
@@ -250,6 +256,13 @@ struct
250256
| Embed_index (Iterator s) -> (Indexing.symbol_ident s, [])
251257
| Binop (Arg1, v1, _v2) -> loop v1
252258
| Binop (Arg2, _v1, v2) -> loop v2
259+
| Ternop (op, v1, v2, v3) ->
260+
let prefix, comma1, comma2, postfix = B.ternop_syntax prec op in
261+
let v1, idcs1 = loop v1 in
262+
let v2, idcs2 = loop v2 in
263+
let v3, idcs3 = loop v3 in
264+
( String.concat [ prefix; v1; comma1; " "; v2; comma2; " "; v3; postfix ],
265+
idcs1 @ idcs2 @ idcs3 )
253266
| Binop (op, v1, v2) ->
254267
let prefix, infix, postfix = B.binop_syntax prec op in
255268
let v1, idcs1 = loop v1 in

arrayjit/lib/cc_backend.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ struct
8484
let kernel_prep_line = ""
8585
let includes = [ "<stdio.h>"; "<stdlib.h>"; "<string.h>"; "<math.h>" ]
8686
let typ_of_prec = Ops.c_typ_of_prec
87+
let ternop_syntax = Ops.ternop_c_syntax
8788
let binop_syntax = Ops.binop_c_syntax
8889
let unop_syntax = Ops.unop_c_syntax
8990
let convert_precision = Ops.c_convert_precision

arrayjit/lib/gcc_backend.gccjit.ml

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,18 @@ let debug_log_index ctx log_functions =
209209
Block.eval block @@ RValue.call ctx ff [ lf ]
210210
| _ -> fun _block _i _index -> ()
211211

212+
let assign_op_c_syntax = function
213+
| Ops.Arg1 -> invalid_arg "Gcc_backend.assign_op_c_syntax: Arg1 is not a C assignment operator"
214+
| Arg2 -> "="
215+
| Add -> "+="
216+
| Sub -> "-="
217+
| Mul -> "*="
218+
| Div -> "/="
219+
| Mod -> "%="
220+
(* | Shl -> "<<=" *)
221+
(* | Shr -> ">>=" *)
222+
| _ -> invalid_arg "Gcc_backend.assign_op_c_syntax: not a C assignment operator"
223+
212224
let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node; _ } func
213225
initial_block (body : Low_level.t) =
214226
let open Gccjit in
@@ -335,7 +347,7 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node;
335347
@@ lf
336348
:: RValue.string_literal ctx
337349
[%string
338-
{|%{node_debug_name get_ident node}[%d]{=%g} %{Ops.assign_op_c_syntax accum_op} %g = %{v_format}
350+
{|%{node_debug_name get_ident node}[%d]{=%g} %{assign_op_c_syntax accum_op} %g = %{v_format}
339351
|}]
340352
:: (to_d @@ RValue.lvalue @@ LValue.access_array (Lazy.force node.ptr) offset)
341353
:: offset :: to_d value :: v_fillers;

0 commit comments

Comments
 (0)