Skip to content

Commit 5a04c74

Browse files
committed
Fixes #303: major expansion of available operations, work in progress
1 parent 4ca8b16 commit 5a04c74

File tree

6 files changed

+308
-21
lines changed

6 files changed

+308
-21
lines changed

arrayjit/lib/low_level.ml

Lines changed: 63 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -847,7 +847,7 @@ let get_ident_within_code ?no_dots llcs =
847847
Tn.update_code_name tn ident;
848848
ident
849849

850-
let fprint_hum ?name ?static_indices () ppf llc =
850+
let fprint_cstyle ?name ?static_indices () ppf llc =
851851
let ident_label = get_ident_within_code [| llc |] in
852852
let open Stdlib.Format in
853853
pp_set_margin ppf !code_hum_margin;
@@ -899,7 +899,68 @@ let fprint_hum ?name ?static_indices () ppf llc =
899899
let prefix, infix, postfix = Ops.binop_c_syntax prec op in
900900
fprintf ppf "@[<1>%s%a%s@ %a@]%s" prefix (pp_float prec) v1 infix (pp_float prec) v2 postfix
901901
| Unop (Identity, v) -> (pp_float prec) ppf v
902-
| Unop (Relu, v) -> fprintf ppf "@[<1>relu(%a@])" (pp_float prec) v
902+
| Unop (op, v) ->
903+
let prefix, postfix = Ops.unop_c_syntax prec op in
904+
fprintf ppf "%s%a%s" prefix (pp_float prec) v postfix
905+
in
906+
fprintf ppf "@,@[<v 2>";
907+
fprint_function_header ?name ?static_indices () ppf;
908+
pp_ll ppf llc;
909+
fprintf ppf "@]"
910+
911+
let fprint_hum ?name ?static_indices () ppf llc =
912+
let ident_label = get_ident_within_code [| llc |] in
913+
let open Stdlib.Format in
914+
pp_set_margin ppf !code_hum_margin;
915+
let pp_ident ppf la = fprintf ppf "%s" @@ ident_label la in
916+
let pp_local ppf { tn; scope_id } = fprintf ppf "v%d_%a" scope_id pp_ident tn in
917+
let rec pp_ll ppf c : unit =
918+
match c with
919+
| Noop -> ()
920+
| Seq (c1, c2) ->
921+
fprintf ppf "@[<v 0>%a@]" (pp_print_list pp_ll)
922+
(List.filter [ c1; c2 ] ~f:(function Noop -> false | _ -> true))
923+
| For_loop { index = i; from_; to_; body; trace_it = _ } ->
924+
fprintf ppf "@[<v 2>for %a = %d to %d {@ %a@]@,}" pp_symbol i from_ to_ pp_ll body
925+
| Zero_out tn -> fprintf ppf "zero_out %a;" pp_ident tn
926+
| Set p ->
927+
p.debug <- asprintf "@[<2>%a[@,%a] :=@ %a;@]" pp_ident p.tn pp_indices p.idcs pp_float p.llv;
928+
fprintf ppf "@[<2>%a[@,%a] :=@ %a;@]" pp_ident p.tn pp_indices p.idcs pp_float p.llv
929+
| Comment message -> fprintf ppf "/* %s */" message
930+
| Staged_compilation _ -> fprintf ppf "STAGED_COMPILATION_CALLBACK()"
931+
| Set_local (id, llv) -> fprintf ppf "@[<2>%a :=@ %a;@]" pp_local id pp_float llv
932+
and pp_float ppf value =
933+
match value with
934+
| Local_scope { id; body; _ } -> fprintf ppf "@[<2>%a {@ %a@]@ }@," pp_local id pp_ll body
935+
| Get_local id -> pp_local ppf id
936+
| Get_global (Ops.C_function s, None) -> fprintf ppf "%s()" s
937+
| Get_global (Ops.C_function s, Some idcs) -> fprintf ppf "%s(%a)" s pp_indices idcs
938+
| Get_global (Ops.External_unsafe { ptr; prec; dims = _ }, None) ->
939+
fprintf ppf "%s" @@ Ops.ptr_to_string_hum ptr prec
940+
| Get_global (Ops.External_unsafe { ptr; prec; dims = _ }, Some idcs) ->
941+
fprintf ppf "%s[%a]" (Ops.ptr_to_string_hum ptr prec) pp_indices idcs
942+
| Get_global (Ops.Merge_buffer { source_node_id }, None) ->
943+
let tn = Option.value_exn ~here:[%here] @@ Tnode.find ~id:source_node_id in
944+
fprintf ppf "%a.merge" pp_ident tn
945+
| Get_global (Ops.Merge_buffer { source_node_id }, Some idcs) ->
946+
let tn = Option.value_exn ~here:[%here] @@ Tnode.find ~id:source_node_id in
947+
fprintf ppf "@[<2>%a.merge[@,%a]@]" pp_ident tn pp_indices idcs
948+
| Get (tn, idcs) -> fprintf ppf "@[<2>%a[@,%a]@]" pp_ident tn pp_indices idcs
949+
| Constant c -> fprintf ppf "%.16g" c
950+
| Embed_index idx -> pp_axis_index ppf idx
951+
| Binop (Arg1, v1, _v2) -> pp_float ppf v1
952+
| Binop (Arg2, _v1, v2) -> pp_float ppf v2
953+
| Binop (op, v1, v2) ->
954+
if Ops.is_binop_nice_infix op then
955+
let infix = Ops.binop_cd_syntax op in
956+
fprintf ppf "@[<1>(%a %s@ %a@])" pp_float v1 infix pp_float v2
957+
else
958+
let prefix = Ops.binop_cd_fallback_syntax op in
959+
fprintf ppf "@[<1>%s(%a,@ %a@])" prefix pp_float v1 pp_float v2
960+
| Unop (Identity, v) -> pp_float ppf v
961+
| Unop (op, v) ->
962+
let prefix = Ops.unop_cd_syntax op in
963+
fprintf ppf "%s(%a)" prefix pp_float v
903964
in
904965
fprintf ppf "@,@[<v 2>";
905966
fprint_function_header ?name ?static_indices () ppf;

arrayjit/lib/low_level.mli

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,20 @@ val fprint_function_header :
118118

119119
val get_ident_within_code : ?no_dots:bool -> t array -> Tnode.t -> string
120120

121+
val fprint_cstyle :
122+
?name:string ->
123+
?static_indices:Indexing.static_symbol list ->
124+
unit ->
125+
Stdlib.Format.formatter ->
126+
t ->
127+
unit
128+
(** Pretty-prints the low-level code to the given formatter. Adheres more to the C syntax, outputs
    implicit type casts: binary and unary operations are rendered with [Ops.binop_c_syntax] and
    [Ops.unop_c_syntax] at the precision of the value being printed. The optional [name] and
    [static_indices] are forwarded to {!fprint_function_header}. *)
129+
121130
val fprint_hum :
122131
?name:string ->
123132
?static_indices:Indexing.static_symbol list ->
124133
unit ->
125134
Stdlib.Format.formatter ->
126135
t ->
127136
unit
137+
(** Pretty-prints the low-level code to the given formatter. Adheres more to the %cd syntax, does
    not output implicit type casts: binary operations are rendered infix with
    [Ops.binop_cd_syntax] when [Ops.is_binop_nice_infix] holds and as a function call named by
    [Ops.binop_cd_fallback_syntax] otherwise; unary operations are rendered as a function call
    named by [Ops.unop_cd_syntax]. As a side effect, the rendering of each [Set] statement is also
    stored in its [debug] field. The optional [name] and [static_indices] are forwarded to
    {!fprint_function_header}. *)

arrayjit/lib/ops.ml

Lines changed: 187 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,14 @@ let hum_typ_of_prec = function
112112
| Double_prec _ -> "double"
113113
| Void_prec -> "void"
114114

115-
(** {2 *** Operations ***} *)
115+
(** {2 *** Operations ***}
116+
117+
See: {{https://github.com/tinygrad/tinygrad/blob/master/tinygrad/ops.py#L123} tinygrad ops},
118+
{{https://docs.nvidia.com/cuda/cuda-math-api/index.html} CUDA Math API} (intrinsics).
119+
120+
This is a redundant set of operations, aiming to expose hardware-supported "intrinsics",
121+
to reduce the need for backends to pattern-match and optimize. Also for convenience.
122+
*)
116123

117124
(** Initializes or resets an array by filling in the corresponding numbers, at the appropriate
118125
precision. *)
@@ -127,10 +134,49 @@ type init_op =
127134
| File_mapped of string * prec (** Reads the data using [Unix.openfile] and [Unix.map_file]. *)
128135
[@@deriving equal, sexp]
129136

130-
type binop = Add | Sub | Mul | Div | ToPowOf | Relu_gate | Arg2 | Arg1
137+
(** Binary operations. The per-constructor notes describe the host-side semantics implemented by
    [interpret_binop], where [v1] and [v2] are the first and second argument; comparison and
    logical operations return [1.] for true and [0.] for false. *)
type binop =
138+
| Add
139+
| Sub
140+
| Mul
141+
| Div
142+
| ToPowOf  (** [v1] raised to the power [v2]. *)
143+
| Relu_gate  (** [v2] if [v1 > 0], otherwise [0]. *)
144+
| Arg2  (** The second argument; the printers render only [v2]. *)
145+
| Arg1  (** The first argument; the printers render only [v1]. *)
146+
| Max  (** The larger of [v1] and [v2]. *)
147+
| Min  (** The smaller of [v1] and [v2]. *)
148+
| Mod  (** Remainder of [v1] by [v2]. *)
149+
| Cmplt  (** [1.] if [v1 < v2], otherwise [0.]. *)
150+
| Cmpne  (** [1.] if [v1] differs from [v2], otherwise [0.]. *)
151+
(* Waiting till we have a use-case to see how to sensibly introduce bitwise operations. *)
152+
(* | Shl *)
153+
(* | Shr *)
154+
| Or  (** [1.] if at least one argument is non-zero, otherwise [0.]. *)
155+
| And  (** [1.] if both arguments are non-zero, otherwise [0.]. *)
156+
| Threefry (** Counter-based random number generator. Not yet supported by [interpret_binop]. *)
131157
[@@deriving sexp, compare, equal]
132158

133-
type unop = Identity | Relu [@@deriving sexp, compare, equal]
159+
(** Unary operations. The per-constructor notes describe the host-side semantics implemented by
    [interpret_unop] on an argument [v]. *)
type unop =
160+
| Identity  (** Returns [v] unchanged. *)
161+
| Relu  (** [v] if [v >= 0], otherwise [0]. *)
162+
| Satur01 (** Saturate (truncate) to within the interval [[0; 1]]. *)
163+
| Exp  (** Natural exponential of [v]. *)
164+
| Log  (** Natural logarithm of [v]. *)
165+
| Exp2  (** [2] raised to the power [v]. *)
166+
| Log2  (** Base-2 logarithm of [v]. *)
167+
| Exp10  (** [10] raised to the power [v]. *)
168+
| Log10  (** Base-10 logarithm of [v]. *)
169+
| Sin
170+
| Cos
171+
| Sqrt
172+
| Recip  (** Reciprocal: [1 / v]. *)
173+
| Recip_sqrt  (** Reciprocal square root: [1 / sqrt v]. *)
174+
| Neg  (** Arithmetic negation of [v]. *)
175+
| Tanh_approx  (** Hyperbolic tangent; [interpret_unop] evaluates it with the exact [tanh]. *)
176+
[@@deriving sexp, compare, equal]
177+
178+
type ternop = Where (** Where(a,b,c): if a then b else c *) | FMA (** FMA(a,b,c): (a * b) + c *)
179+
[@@deriving sexp, compare, equal]
134180

135181
(** Either the left-neutral or right-neutral element of the operation. Unspecified if the operation
136182
does not have a neutral element. *)
@@ -139,8 +185,11 @@ let neutral_elem = function
139185
| Mul | Div -> 1.
140186
| ToPowOf -> 1.
141187
| Relu_gate -> 1.
142-
| Arg2 -> 0.
143-
| Arg1 -> 0.
188+
| Max -> Float.neg_infinity
189+
| Min -> Float.infinity
190+
| And -> 1.
191+
| Or -> 0.
192+
| Arg2 | Arg1 | Mod | Cmplt | Cmpne (* | Shl | Shr *) | Threefry -> 0.
144193

145194
let interpret_binop op v1 v2 =
146195
let open Float in
@@ -153,10 +202,47 @@ let interpret_binop op v1 v2 =
153202
| Div -> v1 / v2
154203
| ToPowOf -> if is_integer v2 then int_pow v1 @@ to_int v2 else v1 ** v2
155204
| Relu_gate -> if v1 > 0.0 then v2 else 0.0
205+
| Max -> max v1 v2
206+
| Min -> min v1 v2
207+
| Mod -> v1 % v2
208+
| Cmplt -> if v1 < v2 then 1. else 0.
209+
| Cmpne -> if v1 <> v2 then 1. else 0.
210+
(* | Shl -> v1 * (int_pow 2. @@ to_int v2) *)
211+
(* | Shr -> v1 / (int_pow 2. @@ to_int v2) *)
212+
| Or -> if v1 <> 0. || v2 <> 0. then 1. else 0.
213+
| And -> if v1 <> 0. && v2 <> 0. then 1. else 0.
214+
| Threefry ->
215+
(* FIXME: NOT IMPLEMENTED YET *)
216+
failwith "FIXME: NOT IMPLEMENTED YET"
156217

157218
(** [interpret_unop op v] evaluates the unary operation [op] on the float [v] using host
    floating-point arithmetic. *)
let interpret_unop op v =
  let open Float in
  (* Logarithm of [v] in the given base, via the change-of-base formula. *)
  let log_in_base base = log v / log base in
  match op with
  | Identity -> v
  | Relu -> if v >= 0. then v else 0.
  | Satur01 -> if v <= 0. then 0. else if v >= 1. then 1. else v
  | Exp -> exp v
  | Exp2 -> 2. ** v
  | Exp10 -> 10. ** v
  | Log -> log v
  | Log2 -> log_in_base 2.
  | Log10 -> log_in_base 10.
  | Sin -> sin v
  | Cos -> cos v
  | Sqrt -> sqrt v
  | Recip -> 1. / v
  | Recip_sqrt -> 1. / sqrt v
  | Neg -> ~-.v
  | Tanh_approx -> tanh v
240+
241+
(** [true] for every binary operation except [Threefry]. *)
let is_binop_infix op =
  match op with
  | Threefry -> false
  | Add | Sub | Mul | Div | ToPowOf | Relu_gate | Arg2 | Arg1 | Max | Min | Mod | Cmplt | Cmpne
  | Or | And ->
      true
242+
243+
(** Whether the operation reads well as an infix operator. [Low_level.fprint_hum] prints such
    operations infix using [binop_cd_syntax], and the remaining ones as a function call named by
    [binop_cd_fallback_syntax]. *)
let is_binop_nice_infix op =
  match op with
  | Add | Sub | Mul | Div | ToPowOf | Mod | Cmplt | Cmpne | Or | And -> true
  | Arg1 | Arg2 | Relu_gate | Max | Min | Threefry -> false
160246

161247
let binop_cd_syntax = function
162248
| Arg1 -> "-@>"
@@ -167,6 +253,36 @@ let binop_cd_syntax = function
167253
| Div -> "/"
168254
| ToPowOf -> "**"
169255
| Relu_gate -> "-?/"
256+
| Cmplt -> "<"
257+
| Cmpne -> "<>"
258+
| Or -> "||"
259+
| And -> "&&"
260+
| Mod -> "%"
261+
| Max -> "@^"
262+
| Min -> "^^"
263+
(* | Shl -> "lsl" *)
264+
(* | Shr -> "lsr" *)
265+
| Threefry -> "threefry"
266+
267+
(** Function-style name of a binary operation, used when the operation is printed as a call
    [name(v1, v2)] rather than infix (see [is_binop_nice_infix]). *)
let binop_cd_fallback_syntax = function
  | Arg1 -> "fst"
  | Arg2 -> "snd"
  | Add -> "add"
  | Sub -> "sub"
  | Mul -> "mul"
  | Div -> "div"
  | ToPowOf -> "pow"
  | Relu_gate -> "relu_gate"
  | Cmplt -> "lt"
  (* Not-equal: "ne". The previous name "le" reads as less-or-equal, which is a different
     comparison. *)
  | Cmpne -> "ne"
  | Or -> "orf"
  | And -> "andf"
  | Mod -> "modf"
  | Max -> "max"
  | Min -> "min"
  (* | Shl -> "shlf" *)
  (* | Shr -> "shrf" *)
  | Threefry -> "threefry"
170286

171287
let binop_c_syntax prec v =
172288
match (v, prec) with
@@ -184,22 +300,56 @@ let binop_c_syntax prec v =
184300
invalid_arg "Ops.binop_c_syntax: ToPowOf not supported for byte/integer precisions"
185301
| Relu_gate, Byte_prec _ -> ("(", " > 0 ?", " : 0)")
186302
| Relu_gate, _ -> ("(", " > 0.0 ?", " : 0.0)")
303+
| Max, Double_prec _ -> ("fmax(", ",", ")")
304+
| Max, Single_prec _ -> ("fmaxf(", ",", ")")
305+
| Max, Half_prec _ -> ("fmaxf(", ",", ")")
306+
| Max, Byte_prec _ -> ("fmax(", ",", ")")
307+
| Min, Double_prec _ -> ("fmin(", ",", ")")
308+
| Min, Single_prec _ -> ("fminf(", ",", ")")
309+
| Min, Half_prec _ -> ("fminf(", ",", ")")
310+
| Min, Byte_prec _ -> ("fmin(", ",", ")")
311+
| Mod, _ -> ("(", " %", ")")
312+
| Cmplt, _ -> ("(", " <", ")")
313+
| Cmpne, _ -> ("(", " !=", ")")
314+
(* | Shl, Byte_prec _ -> ("(", " <<", ")") *)
315+
(* | Shl, _ -> ("((", ") * exp2(", "))") *)
316+
(* | Shr, Byte_prec _ -> ("(", " >>", ")") *)
317+
(* | Shr, _ -> ("((", ") / exp2(", "))") *)
318+
| Or, _ -> ("(", " ||", ")")
319+
| And, _ -> ("(", " &&", ")")
320+
| Threefry, Double_prec _ -> ("threefry(", ",", ")")
321+
| Threefry, Single_prec _ -> ("threefryf(", ",", ")")
322+
| Threefry, Half_prec _ -> ("threefryf(", ",", ")")
323+
| Threefry, Byte_prec _ -> ("threefryf(", ",", ")")
324+
325+
(** Whether the binary operation can be used as an accumulating assignment operator, i.e. whether
    [assign_op_cd_syntax] returns a syntax for it instead of raising. *)
let is_assign_op op =
  match op with
  | Add | Sub | Mul | Div | ToPowOf | Relu_gate | Arg2 | Max | Min | Or | And -> true
  | Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpne | Threefry -> false
187328

188329
(** The %cd assignment operator for accumulating with the given binary operation. Accumulating
    operators are spelled [=:] followed by the operation symbol when [initialize_neutral] is set,
    and [=] followed by the symbol otherwise; [Arg2] is always the plain assignment [=:].

    @raise Invalid_argument if the operation is not an assignment operation (see [is_assign_op]). *)
let assign_op_cd_syntax ~initialize_neutral op =
  let accumulate symbol = (if initialize_neutral then "=:" else "=") ^ symbol in
  match op with
  | Arg2 -> "=:"
  | Add -> accumulate "+"
  | Sub -> accumulate "-"
  | Mul -> accumulate "*"
  | Div -> accumulate "/"
  | ToPowOf -> accumulate "**"
  | Relu_gate -> accumulate "?/"
  | Or -> accumulate "||"
  | And -> accumulate "&&"
  | Max -> accumulate "@^"
  | Min -> accumulate "^^"
  | Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpne | Threefry ->
      invalid_arg "Ops.assign_op_cd_syntax: not an assignment op"
203353

204354
let assign_op_c_syntax = function
205355
| Arg1 -> invalid_arg "Ops.assign_op_c_syntax: Arg1 is not a C assignment operator"
@@ -208,17 +358,43 @@ let assign_op_c_syntax = function
208358
| Sub -> "-="
209359
| Mul -> "*="
210360
| Div -> "/="
211-
| ToPowOf -> invalid_arg "Ops.assign_op_c_syntax: ToPowOf function is not a C assignment operator"
212-
| Relu_gate -> invalid_arg "Ops.assign_op_c_syntax: Relu_gate is not a C assignment operator"
213-
214-
let unop_cd_syntax = function Identity -> "~=" | Relu -> "?/"
361+
| Mod -> "%="
362+
(* | Shl -> "<<=" *)
363+
(* | Shr -> ">>=" *)
364+
| _ -> invalid_arg "Ops.assign_op_c_syntax: not a C assignment operator"
365+
366+
(** Function-style name of a unary operation in the %cd syntax; [Low_level.fprint_hum] prints a
    unary operation as [name(argument)]. Note: currently we do not support unary prefix symbols. *)
let unop_cd_syntax op =
  match op with
  (* Piecewise-linear operations. *)
  | Identity -> "id"
  | Relu -> "relu"
  | Satur01 -> "sat01"
  | Neg -> "neg"
  (* Exponentials and logarithms. *)
  | Exp -> "exp"
  | Exp2 -> "exp2"
  | Exp10 -> "exp10"
  | Log -> "log"
  | Log2 -> "log2"
  | Log10 -> "log10"
  (* Trigonometric and hyperbolic functions. *)
  | Sin -> "sin"
  | Cos -> "cos"
  | Tanh_approx -> "tanh"
  (* Roots and reciprocals. *)
  | Sqrt -> "sqrt"
  | Recip -> "recip"
  | Recip_sqrt -> "recip_sqrt"
215384

216385
let unop_c_syntax prec v =
217386
match (v, prec) with
218387
| Identity, _ -> ("", "")
219388
| Relu, Single_prec _ -> ("fmaxf(0.0, ", ")")
220389
| Relu, Byte_prec _ -> ("fmax(0, ", ")")
221390
| Relu, _ -> ("fmax(0.0, ", ")")
391+
| _ ->
392+
(* FIXME: NOT IMPLEMENTED YET *)
393+
failwith "NOT IMPLEMENTED YET"
394+
(* | Satur01, _ -> ("", "") | Exp, _ -> ("", "") | Log, _ -> ("", "") | Exp2, _ -> ("", "") | Log2,
395+
_ -> ("", "") | Exp10, _ -> ("", "") | Log10, _ -> ("", "") | Sin, _ -> ("", "") | Cos, _ -> ("",
396+
"") | Sqrt, _ -> ("", "") | Recip, _ -> ("", "") | Recip_sqrt, _ -> ("", "") | Neg, _ -> ("", "")
397+
| Tanh_approx, _ -> ("", "") *)
222398

223399
let c_convert_precision ~from ~to_ =
224400
match (from, to_) with

0 commit comments

Comments
 (0)