Skip to content

Commit b4fa5c3

Browse files
committed
Untested: missing new primitive ops for optional backends CUDA, GCC
1 parent ae845d0 commit b4fa5c3

File tree

2 files changed

+156
-8
lines changed

2 files changed

+156
-8
lines changed

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 69 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -309,10 +309,25 @@ module Fresh () = struct
309309
| Satur01_gate, Byte_prec _ -> ("(abs(", ") > 0 ? 0 : (", ")")
310310
| Satur01_gate, Half_prec _ ->
311311
( "(__hgt(__habs(htrunc(",
312-
")), __ushort_as_half((unsigned short)0x0000U)) ? __ushort_as_half((unsigned short)0x0000U) : (",
312+
")), __ushort_as_half((unsigned short)0x0000U)) ? __ushort_as_half((unsigned \
313+
short)0x0000U) : (",
313314
"))" )
314315
| Satur01_gate, Double_prec _ -> ("(fabs(trunc(", ")) > 0.0 ? 0.0 : (", "))")
315316
| Satur01_gate, Single_prec _ -> ("(fabsf(truncf(", ")) > 0.0 ? 0.0 : (", "))")
317+
| Max, Byte_prec _ -> ("max(", ", ", ")")
318+
| Max, Half_prec _ -> ("__hmax(", ", ", ")")
319+
| Max, Double_prec _ -> ("fmax(", ", ", ")")
320+
| Max, Single_prec _ -> ("fmaxf(", ", ", ")")
321+
| Min, Byte_prec _ -> ("min(", ", ", ")")
322+
| Min, Half_prec _ -> ("__hmin(", ", ", ")")
323+
| Min, Double_prec _ -> ("fmin(", ", ", ")")
324+
| Min, Single_prec _ -> ("fminf(", ", ", ")")
325+
| Mod, Byte_prec _ -> ("(", " % ", ")")
326+
| Mod, _ -> ("fmod(", ", ", ")")
327+
| Cmplt, _ -> ("(", " < ", ")")
328+
| Cmpeq, _ -> ("(", " == ", ")")
329+
| Or, _ -> ("(", " || ", ")")
330+
| And, _ -> ("(", " && ", ")")
316331

317332
(** [unop_syntax prec v] returns a [(prefix, postfix)] pair of C source
    fragments for the unary op [v] at precision [prec]. Presumably the caller
    places the rendered operand between the two strings; the call site is not
    visible in this listing, so confirm there.

    For [Exp], [Log], [Exp2], [Log2], [Sin], [Cos] and [Sqrt] the arms pick an
    h-prefixed name for [Half_prec] (hexp, hlog, ...), the unsuffixed name for
    [Double_prec] (exp, log, ...), and the f-suffixed name (expf, logf, ...)
    for every remaining precision.

    @raise Invalid_argument
      for [Recip], [Recip_sqrt] and [Tanh_approx] at [Byte_prec].

    NOTE(review): the catch-all arms of the Exp..Sqrt family also match
    [Byte_prec], so e.g. [Exp] at [Byte_prec] yields the expf prefix; confirm
    that is intended.
    NOTE(review): this is a diff listing. The hunk header that follows the
    [match] line elides some arms of the function, and the numeric markers
    between lines are diff line numbers, not code. *)
let unop_syntax prec v =
318333
match (v, prec) with
@@ -321,6 +336,59 @@ module Fresh () = struct
321336
| Relu, Ops.Half_prec _ -> ("__hmax_nan(__ushort_as_half((unsigned short)0x0000U), ", ")")
322337
| Relu, Ops.Byte_prec _ -> ("fmax(0, ", ")")
323338
| Relu, _ -> ("fmax(0.0, ", ")")
339+
| Satur01, Byte_prec _ -> ("fmax(0, fmin(1, ", "))")
340+
| Satur01, Half_prec _ ->
341+
( "__hmax_nan(__ushort_as_half((unsigned short)0x0000U), \
342+
__hmin_nan(__ushort_as_half((unsigned short)0x3C00U), ",
343+
"))" )
344+
| Satur01, Single_prec _ -> ("fmaxf(0.0f, fminf(1.0f, ", "))")
345+
| Satur01, _ -> ("fmax(0.0, fmin(1.0, ", "))")
346+
| Exp, Half_prec _ -> ("hexp(", ")")
347+
| Exp, Double_prec _ -> ("exp(", ")")
348+
| Exp, _ -> ("expf(", ")")
349+
| Log, Half_prec _ -> ("hlog(", ")")
350+
| Log, Double_prec _ -> ("log(", ")")
351+
| Log, _ -> ("logf(", ")")
352+
| Exp2, Half_prec _ -> ("hexp2(", ")")
353+
| Exp2, Double_prec _ -> ("exp2(", ")")
354+
| Exp2, _ -> ("exp2f(", ")")
355+
| Log2, Half_prec _ -> ("hlog2(", ")")
356+
| Log2, Double_prec _ -> ("log2(", ")")
357+
| Log2, _ -> ("log2f(", ")")
358+
| Sin, Half_prec _ -> ("hsin(", ")")
359+
| Sin, Double_prec _ -> ("sin(", ")")
360+
| Sin, _ -> ("sinf(", ")")
361+
| Cos, Half_prec _ -> ("hcos(", ")")
362+
| Cos, Double_prec _ -> ("cos(", ")")
363+
| Cos, _ -> ("cosf(", ")")
364+
| Sqrt, Half_prec _ -> ("hsqrt(", ")")
365+
| Sqrt, Double_prec _ -> ("sqrt(", ")")
366+
| Sqrt, _ -> ("sqrtf(", ")")
367+
| Recip, Byte_prec _ ->
368+
invalid_arg "Cuda_backend.unop_syntax: Recip not supported for byte/integer precisions"
369+
| Recip, Half_prec _ -> ("hrcp(", ")")
370+
| Recip, _ -> ("(1.0 / (", "))")
371+
| Recip_sqrt, Byte_prec _ ->
372+
invalid_arg
373+
"Cuda_backend.unop_syntax: Recip_sqrt not supported for byte/integer precisions"
374+
| Recip_sqrt, Half_prec _ -> ("hrsqrt(", ")")
375+
| Recip_sqrt, Double_prec _ -> ("(1.0 / sqrt(", "))")
376+
| Recip_sqrt, _ -> ("(1.0 / sqrtf(", "))")
377+
| Neg, _ -> ("(-(", "))")
378+
| Tanh_approx, Byte_prec _ ->
379+
invalid_arg
380+
"Cuda_backend.unop_syntax: Tanh_approx not supported for byte/integer precisions"
381+
| Tanh_approx, Half_prec _ -> ("htanh_approx(", ")")
382+
| Tanh_approx, Single_prec _ -> ("__tanhf(", ")")
383+
| Tanh_approx, _ -> ("tanh(", ")")
384+
| Not, _ -> ("(", " == 0.0 ? 1.0 : 0.0)")
385+
386+
(** [ternop_syntax prec v] returns a 4-tuple of C source fragments
    [(prefix, infix1, infix2, postfix)] for the ternary op [v] at precision
    [prec]. Presumably the three rendered operands are interleaved with these
    strings in order; the call site is not visible in this listing, so confirm
    there.

    - [Where] yields the pieces of a parenthesized C conditional expression,
      for any precision.
    - [FMA] yields a three-argument call: [__hfma] for [Half_prec], [fmaf] for
      [Single_prec], and [fma] for every remaining precision.

    NOTE(review): the numeric markers between lines are diff line numbers from
    the commit view, not code. *)
let ternop_syntax prec v =
387+
match (v, prec) with
388+
| Ops.Where, _ -> ("(", " ? ", " : ", ")")
389+
| FMA, Ops.Half_prec _ -> ("__hfma(", ", ", ", ", ")")
390+
| FMA, Ops.Single_prec _ -> ("fmaf(", ", ", ", ", ")")
391+
| FMA, _ -> ("fma(", ", ", ", ", ")")
324392

325393
let convert_precision ~from ~to_ =
326394
match (from, to_) with

arrayjit/lib/gcc_backend.gccjit.ml

Lines changed: 87 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -166,14 +166,15 @@ let prec_to_kind prec =
166166

167167
(** [is_builtin_op op] is [true] exactly for [Add], [Sub], [Mul] and [Div],
    the ops that [builtin_op] maps to a [Gccjit] binary operator, and [false]
    for every other op listed here.

    NOTE(review): this is a diff listing. The arm after the [169-] marker is
    the pre-commit version; the arm after the [169+] and [170+] markers is its
    replacement, which adds [Max], [Min], [Mod], [Cmplt], [Cmpeq], [Or] and
    [And] to the non-builtin set. *)
let is_builtin_op = function
168168
| Ops.Add | Sub | Mul | Div -> true
169-
| ToPowOf | Relu_gate | Satur01_gate | Arg2 | Arg1 -> false
169+
| ToPowOf | Relu_gate | Satur01_gate | Arg2 | Arg1 | Max | Min | Mod | Cmplt | Cmpeq | Or | And ->
170+
false
170171

171172
(** [builtin_op op] translates an arithmetic op into the matching [Gccjit]
    binary operator: [Add] to [Plus], [Sub] to [Minus], [Mul] to [Mult] and
    [Div] to [Divide].

    @raise Invalid_argument
      for every other op listed here; these are the ops for which
      [is_builtin_op] returns [false].

    NOTE(review): this is a diff listing. The arm after the [176-] marker is
    the pre-commit version; the arm after the [177+] marker is its replacement
    and shares the [invalid_arg] body that follows. *)
let builtin_op = function
172173
| Ops.Add -> Gccjit.Plus
173174
| Sub -> Gccjit.Minus
174175
| Mul -> Gccjit.Mult
175176
| Div -> Gccjit.Divide
176-
| ToPowOf | Relu_gate | Satur01_gate | Arg2 | Arg1 ->
177+
| ToPowOf | Relu_gate | Satur01_gate | Arg2 | Arg1 | Max | Min | Mod | Cmplt | Cmpeq | Or | And ->
177178
invalid_arg "Exec_as_gccjit.builtin_op: not a builtin"
178179

179180
(** [node_debug_name get_ident node] applies [get_ident] to the [tn] field of
    [node] and returns the result. *)
let node_debug_name get_ident node =
  let tn = node.tn in
  get_ident tn
@@ -278,13 +279,34 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node;
278279
| Satur01_gate, _ ->
279280
let cmp =
280281
cast_bool num_typ
281-
@@ RValue.binary_op ctx And
282+
@@ RValue.binary_op ctx Logical_and (Type.get ctx Type.Bool)
282283
(RValue.comparison ctx Lt (RValue.zero ctx num_typ) v1)
283284
(RValue.comparison ctx Lt v1 (RValue.one ctx num_typ))
284285
in
285286
RValue.binary_op ctx Mult num_typ cmp @@ v2
286287
| Arg2, _ -> v2
287288
| Arg1, _ -> v1
289+
| Max, Double_prec _ ->
290+
RValue.cast ctx (RValue.call ctx (Function.builtin ctx "fmax") [ to_d v1; to_d v2 ]) num_typ
291+
| Max, Single_prec _ ->
292+
RValue.cast ctx (RValue.call ctx (Function.builtin ctx "fmaxf") [ v1; v2 ]) num_typ
293+
| Max, Half_prec _ ->
294+
RValue.cast ctx (RValue.call ctx (Function.builtin ctx "fmaxf") [ v1; v2 ]) num_typ
295+
| Max, Byte_prec _ ->
296+
RValue.cast ctx (RValue.call ctx (Function.builtin ctx "max") [ v1; v2 ]) num_typ
297+
| Min, Double_prec _ ->
298+
RValue.cast ctx (RValue.call ctx (Function.builtin ctx "fmin") [ to_d v1; to_d v2 ]) num_typ
299+
| Min, Single_prec _ ->
300+
RValue.cast ctx (RValue.call ctx (Function.builtin ctx "fminf") [ v1; v2 ]) num_typ
301+
| Min, Half_prec _ ->
302+
RValue.cast ctx (RValue.call ctx (Function.builtin ctx "fminf") [ v1; v2 ]) num_typ
303+
| Min, Byte_prec _ ->
304+
RValue.cast ctx (RValue.call ctx (Function.builtin ctx "min") [ v1; v2 ]) num_typ
305+
| Mod, _ -> RValue.binary_op ctx Modulo num_typ v1 v2
306+
| Cmplt, _ -> RValue.comparison ctx Lt v1 v2
307+
| Cmpeq, _ -> RValue.comparison ctx Eq v1 v2
308+
| Or, _ -> RValue.binary_op ctx Logical_or num_typ v1 v2
309+
| And, _ -> RValue.binary_op ctx Logical_and num_typ v1 v2
288310
in
289311
let log_comment c =
290312
match log_functions with
@@ -343,6 +365,17 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node;
343365
| Unop (Relu, v) ->
344366
let v, fillers = loop v in
345367
(String.concat [ "("; v; " > 0.0 ? "; v; " : 0.0)" ], fillers @ fillers)
368+
| Unop (op, v) ->
369+
let prefix, postfix = Ops.unop_c_syntax prec op in
370+
let v, fillers = loop v in
371+
(String.concat [ prefix; v; postfix ], fillers)
372+
| Ternop (op, cond_v, then_v, else_v) ->
373+
let prefix, infix1, infix2, postfix = Ops.ternop_c_syntax prec op in
374+
let cond, fillers1 = loop cond_v in
375+
let then_, fillers2 = loop then_v in
376+
let else_, fillers3 = loop else_v in
377+
( String.concat [ prefix; cond; infix1; then_; infix2; else_; postfix ],
378+
fillers1 @ fillers2 @ fillers3 )
346379
in
347380
let debug_log_assignment ~env debug idcs node accum_op value v_code =
348381
match log_functions with
@@ -484,12 +517,59 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node;
484517
| Binop (op, c1, c2) -> loop_binop op ~num_typ prec ~v1:(loop c1) ~v2:(loop c2)
485518
| Unop (Identity, c) -> loop c
486519
| Unop (Relu, c) ->
487-
(* FIXME: don't recompute c *)
488-
let cmp =
489-
cast_bool num_typ @@ RValue.comparison ctx Lt (RValue.zero ctx num_typ) @@ loop c
520+
let v = loop c in
521+
let cmp = cast_bool num_typ @@ RValue.comparison ctx Lt (RValue.zero ctx num_typ) v in
522+
RValue.binary_op ctx Mult num_typ cmp v
523+
| Unop (Satur01, c) ->
524+
let v = loop c in
525+
let zero = RValue.zero ctx num_typ in
526+
let one = RValue.double ctx num_typ 1.0 in
527+
let min =
528+
RValue.binary_op ctx Plus num_typ zero
529+
(RValue.binary_op ctx Mult num_typ (RValue.comparison ctx Lt v zero)
530+
(RValue.binary_op ctx Minus num_typ zero v))
490531
in
491-
RValue.binary_op ctx Mult num_typ cmp @@ loop c
532+
RValue.binary_op ctx Plus num_typ min
533+
(RValue.binary_op ctx Mult num_typ (RValue.comparison ctx Gt v one)
534+
(RValue.binary_op ctx Minus num_typ one v))
535+
| Unop (((Exp | Log | Exp2 | Log2 | Sin | Cos | Sqrt | Tanh_approx) as op), c) ->
536+
let prefix, suffix = Ops.unop_c_syntax prec op in
537+
assert (
538+
String.is_suffix ~suffix:"(" prefix
539+
&& String.equal suffix ")"
540+
&& not (String.is_prefix ~prefix:"(" prefix));
541+
let f = Function.builtin ctx (String.drop_suffix prefix 1) in
542+
RValue.call ctx f [ loop c ]
543+
| Ternop (FMA as op, c1, c2, c3) ->
544+
let prefix, _, _, _ = Ops.ternop_c_syntax prec op in
545+
let f = Function.builtin ctx (String.drop_suffix prefix 1) in
546+
RValue.call ctx f [ loop c1; loop c2; loop c3 ]
547+
| Ternop (Where, c1, c2, c3) ->
548+
let cond = loop c1 in
549+
let zero = RValue.zero ctx num_typ in
550+
let cmp = RValue.comparison ctx Eq cond zero in
551+
let v1 = loop c2 in
552+
let v2 = loop c3 in
553+
RValue.binary_op ctx Plus num_typ
554+
(RValue.binary_op ctx Mult num_typ cmp v2)
555+
(RValue.binary_op ctx Mult num_typ
556+
(RValue.binary_op ctx Minus num_typ (RValue.one ctx num_typ) cmp)
557+
v1)
492558
| Constant v -> RValue.double ctx num_typ v
559+
| Unop (Recip, c) ->
560+
let v = loop c in
561+
RValue.binary_op ctx Divide num_typ (RValue.one ctx num_typ) v
562+
| Unop (Recip_sqrt, c) ->
563+
let v = loop c in
564+
RValue.binary_op ctx Divide num_typ (RValue.one ctx num_typ)
565+
(RValue.call ctx (Function.builtin ctx "sqrtf") [ v ])
566+
| Unop (Neg, c) ->
567+
let v = loop c in
568+
RValue.unary_op ctx Negate num_typ v
569+
| Unop (Not, c) ->
570+
let v = loop c in
571+
cast_bool num_typ @@ RValue.unary_op ctx Logical_negate (Type.get ctx Type.Bool)
572+
(RValue.comparison ctx Eq v (RValue.zero ctx num_typ))
493573
and loop_for_loop ~toplevel ~env key ~from_ ~to_ body =
494574
let open Gccjit in
495575
let i = Indexing.symbol_ident key in

0 commit comments

Comments
 (0)