Skip to content

Commit ffbea70

Browse files
committed
Progress toward incorporating new ops
Comments, fma intro in simplification but no elim yet, introduce binary ops in cd_ppx.ml op lists. TODO: remove unary prefix ops from parsing, handle unary and ternary op applications in parsing.
1 parent fde7983 commit ffbea70

File tree

3 files changed

+39
-13
lines changed

3 files changed

+39
-13
lines changed

arrayjit/lib/low_level.ml

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -665,13 +665,6 @@ let simplify_llc llc =
665665
| Get_global _ -> llv
666666
| Embed_index (Fixed_idx i) -> Constant (Float.of_int i)
667667
| Embed_index (Iterator _) -> llv
668-
| Ternop (op, llv1, llv2, llv3) ->
669-
(* FIXME: NOT IMPLEMENTED YET *)
670-
let v1 = loop_float llv1 in
671-
let v2 = loop_float llv2 in
672-
let v3 = loop_float llv3 in
673-
let result = Ternop (op, v1, v2, v3) in
674-
if equal_float_t llv1 v1 && equal_float_t llv2 v2 then result else loop_float result
675668
| Binop (Arg1, llv1, _) -> loop_float llv1
676669
| Binop (Arg2, _, llv2) -> loop_float llv2
677670
| Binop (op, Constant c1, Constant c2) -> Constant (Ops.interpret_binop op c1 c2)
@@ -721,11 +714,20 @@ let simplify_llc llc =
721714
| Constant c when Float.is_integer c ->
722715
loop_float @@ unroll_pow ~base:v1 ~exp:(Float.to_int c)
723716
| _ -> result)
717+
| Binop (Add, Binop (Mul, llv1, llv2), llv3) | Binop (Add, llv3, Binop (Mul, llv1, llv2)) ->
718+
(* TODO: this is tentative. *)
719+
loop_float @@ Ternop (FMA, llv1, llv2, llv3)
724720
| Binop (op, llv1, llv2) ->
725721
let v1 = loop_float llv1 in
726722
let v2 = loop_float llv2 in
727723
let result = Binop (op, v1, v2) in
728724
if equal_float_t llv1 v1 && equal_float_t llv2 v2 then result else loop_float result
725+
| Ternop (op, llv1, llv2, llv3) ->
726+
let v1 = loop_float llv1 in
727+
let v2 = loop_float llv2 in
728+
let v3 = loop_float llv3 in
729+
let result = Ternop (op, v1, v2, v3) in
730+
if equal_float_t llv1 v1 && equal_float_t llv2 v2 then result else loop_float result
729731
| Unop (Identity, llv) -> loop_float llv
730732
| Unop (op, Constant c) -> Constant (Ops.interpret_unop op c)
731733
| Unop (op, llv) ->

arrayjit/lib/ops.ml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,9 @@ let interpret_ternop op v1 v2 v3 =
236236
let open Float in
237237
match op with Where -> if v1 <> 0. then v2 else v3 | FMA -> (v1 * v2) + v3
238238

239+
(** Note: currently the %cd syntax only supports infix binops as assignment ops. *)
239240
let is_binop_infix _ = true
241+
240242
let is_binop_nice_infix = function Arg1 | Arg2 | Relu_gate | Max | Min -> false | _ -> true
241243

242244
let binop_cd_syntax = function
@@ -258,6 +260,8 @@ let binop_cd_syntax = function
258260
(* | Shl -> "lsl" *)
259261
(* | Shr -> "lsr" *)
260262

263+
(** In the %cd syntax, we support uncurried notation for binary ops in addition to the infix
264+
notation. *)
261265
let binop_cd_fallback_syntax = function
262266
| Arg1 -> "fst"
263267
| Arg2 -> "snd"
@@ -397,6 +401,7 @@ let unop_c_syntax prec op =
397401
invalid_arg "Ops.unop_c_syntax: Tanh_approx not supported for byte/integer precisions"
398402
| Tanh_approx, _ -> ("tanhf(", ")")
399403

404+
(** In the %cd syntax, we use uncurried notation for ternary ops. *)
400405
let ternop_cd_syntax = function Where -> "where" | FMA -> "fma"
401406

402407
let ternop_c_syntax prec op =

lib/ppx_cd.ml

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -46,21 +46,29 @@ let assignment_op expr =
4646
| [%expr ( =/ )] -> (false, [%expr Arrayjit.Ops.Div])
4747
| [%expr ( =** )] -> (false, [%expr Arrayjit.Ops.ToPowOf])
4848
| [%expr ( =?/ )] -> (false, [%expr Arrayjit.Ops.Relu_gate])
49+
| [%expr ( =|| )] -> (false, [%expr Arrayjit.Ops.Or])
50+
| [%expr ( =&& )] -> (false, [%expr Arrayjit.Ops.And])
51+
| [%expr ( =@^ )] -> (false, [%expr Arrayjit.Ops.Max])
52+
| [%expr ( =^^ )] -> (false, [%expr Arrayjit.Ops.Min])
4953
| [%expr ( =:+ )] -> (true, [%expr Arrayjit.Ops.Add])
5054
| [%expr ( =:- )] -> (true, [%expr Arrayjit.Ops.Sub])
5155
| [%expr ( =:* )] -> (true, [%expr Arrayjit.Ops.Mul])
5256
| [%expr ( =:/ )] -> (true, [%expr Arrayjit.Ops.Div])
5357
| [%expr ( =:** )] -> (true, [%expr Arrayjit.Ops.ToPowOf])
5458
| [%expr ( =:?/ )] -> (true, [%expr Arrayjit.Ops.Relu_gate])
59+
| [%expr ( =:|| )] -> (true, [%expr Arrayjit.Ops.Or])
60+
| [%expr ( =:&& )] -> (true, [%expr Arrayjit.Ops.And])
61+
| [%expr ( =:@^ )] -> (true, [%expr Arrayjit.Ops.Max])
62+
| [%expr ( =:^^ )] -> (true, [%expr Arrayjit.Ops.Min])
5563
| _ ->
5664
( false,
5765
Ast_builder.Default.pexp_extension ~loc
5866
@@ Location.error_extensionf ~loc
5967
"ppx_ocannl %%cd: expected an assignment operator, one of: %s %s"
60-
"=+ (Add), =- (Sub), =* (Mul), =/ (Div), =** (ToPowOf), =?/ (Relu_gate), =: (Arg2), \
61-
=:+, =:-,"
62-
" =:*, =:/, =:**, =:?/ (same with initializing the tensor to the neutral value before \
63-
the start of the calculation)" )
68+
"=+ (Add), =- (Sub), =* (Mul), =/ (Div), =** (ToPowOf), =?/ (Relu_gate), =|| (Or), \
69+
=&& (And), =@^ (Max), =^^ (Min), =: (Arg2), =:+, =:-,"
70+
" =:*, =:/, =:**, =:?/, =:||, =:&&, =:@^, =:^^ (same with initializing the tensor to \
71+
the neutral value before the start of the calculation)" )
6472

6573
let binary_op expr =
6674
(* This and is_binary_op should stay in sync with Arrayjit.Ops.binop_cd_syntax. *)
@@ -84,14 +92,25 @@ let binary_op expr =
8492
| [%expr ( -?/ )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Relu_gate])
8593
| [%expr ( -/> )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Arg2])
8694
| [%expr ( -@> )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Arg1])
95+
| [%expr ( < )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmplt])
96+
| [%expr ( <> )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Cmpne])
97+
| [%expr ( || )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Or])
98+
| [%expr ( && )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.And])
99+
| [%expr ( % )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Mod])
100+
| [%expr ( @^ )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Max])
101+
| [%expr ( ^^ )] -> ([%expr Shape.Pointwise_bin], [%expr Arrayjit.Ops.Min])
87102
| _ ->
88103
( [%expr Shape.Pointwise_bin],
89104
Ast_builder.Default.pexp_extension ~loc
90105
@@ Location.error_extensionf ~loc "ppx_ocannl %%cd: expected a binary operator, one of: %s"
91-
"+ (Add), - (Sub), * (Mul), / (Div), ** (ToPowOf), -?/ (Relu_gate), -/> (Arg2)" )
106+
"+ (Add), - (Sub), * (Mul), / (Div), ** (ToPowOf), -?/ (Relu_gate), -/> (Arg2), < \
107+
(Cmplt), <> (Cmpne), || (Or), && (And), % (Mod), @^ (Max), ^^ (Min)" )
92108

93109
let is_binary_op ident =
94-
List.mem [ "+"; "-"; "*"; "/"; "**"; "-?/"; "-/>"; "-@>" ] ident ~equal:String.equal
110+
(* TODO: compile into a hashtable *)
111+
List.mem
112+
[ "+"; "-"; "*"; "/"; "**"; "-?/"; "-/>"; "-@>"; "<"; "<>"; "&&"; "%"; "@^"; "^^" ]
113+
ident ~equal:String.equal
95114

96115
let unary_op expr =
97116
(* This and is_unary_op should stay in sync with Arrayjit.Ops.unop_cd_syntax. *)

0 commit comments

Comments
 (0)