@@ -112,7 +112,14 @@ let hum_typ_of_prec = function
112112 | Double_prec _ -> " double"
113113 | Void_prec -> " void"
114114
115- (* * {2 *** Operations ***} *)
115+ (* * {2 *** Operations ***}
116+
117+ See: {{https://github.com/tinygrad/tinygrad/blob/master/tinygrad/ops.py#L123} tinygrad ops},
118+ {{https://docs.nvidia.com/cuda/cuda-math-api/index.html} CUDA Math API} (intrinsics).
119+
120+ This is a redundant set of operations, aiming to expose hardware-supported "intrinsics",
121+ to reduce the need for backends to pattern-match and optimize. Also for convenience.
122+ *)
116123
117124(* * Initializes or resets a array by filling in the corresponding numbers, at the appropriate
118125 precision. *)
@@ -127,10 +134,49 @@ type init_op =
127134 | File_mapped of string * prec (* * Reads the data using [Unix.openfile] and [Unix.map_file]. *)
128135[@@ deriving equal , sexp ]
129136
130- type binop = Add | Sub | Mul | Div | ToPowOf | Relu_gate | Arg2 | Arg1
137+ type binop =
138+ | Add
139+ | Sub
140+ | Mul
141+ | Div
142+ | ToPowOf
143+ | Relu_gate
144+ | Arg2
145+ | Arg1
146+ | Max
147+ | Min
148+ | Mod
149+ | Cmplt
150+ | Cmpne
151+ (* Waiting till we have a use-case to see how to sensibly introduce bitwise operations. *)
152+ (* | Shl *)
153+ (* | Shr *)
154+ | Or
155+ | And
156+ | Threefry (* * Counter-based random number generator. *)
131157[@@ deriving sexp , compare , equal ]
132158
133- type unop = Identity | Relu [@@ deriving sexp , compare , equal ]
159+ type unop =
160+ | Identity
161+ | Relu
162+ | Satur01 (* * Saturate (truncate) to within the interval [[0; 1]]. *)
163+ | Exp
164+ | Log
165+ | Exp2
166+ | Log2
167+ | Exp10
168+ | Log10
169+ | Sin
170+ | Cos
171+ | Sqrt
172+ | Recip
173+ | Recip_sqrt
174+ | Neg
175+ | Tanh_approx
176+ [@@ deriving sexp , compare , equal ]
177+
178+ type ternop = Where (* * Where(a,b,c): if a then b else c *) | FMA (* * FMA(a,b,c): (a * b) + c *)
179+ [@@ deriving sexp , compare , equal ]
134180
135181(* * Either the left-neutral or right-neutral element of the operation. Unspecified if the operation
136182 does not have a neutral element. *)
@@ -139,8 +185,11 @@ let neutral_elem = function
139185 | Mul | Div -> 1.
140186 | ToPowOf -> 1.
141187 | Relu_gate -> 1.
142- | Arg2 -> 0.
143- | Arg1 -> 0.
188+ | Max -> Float. neg_infinity
189+ | Min -> Float. infinity
190+ | And -> 1.
191+ | Or -> 0.
192+ | Arg2 | Arg1 | Mod | Cmplt | Cmpne (* | Shl | Shr * ) | Threefry -> 0.
144193
145194let interpret_binop op v1 v2 =
146195 let open Float in
@@ -153,10 +202,47 @@ let interpret_binop op v1 v2 =
153202 | Div -> v1 / v2
154203 | ToPowOf -> if is_integer v2 then int_pow v1 @@ to_int v2 else v1 ** v2
155204 | Relu_gate -> if v1 > 0.0 then v2 else 0.0
205+ | Max -> max v1 v2
206+ | Min -> min v1 v2
207+ | Mod -> v1 % v2
208+ | Cmplt -> if v1 < v2 then 1. else 0.
209+ | Cmpne -> if v1 <> v2 then 1. else 0.
210+ (* | Shl -> v1 * (int_pow 2. @@ to_int v2) *)
211+ (* | Shr -> v1 / (int_pow 2. @@ to_int v2) *)
212+ | Or -> if v1 <> 0. || v2 <> 0. then 1. else 0.
213+ | And -> if v1 <> 0. && v2 <> 0. then 1. else 0.
214+ | Threefry ->
215+ (* FIXME: NOT IMPLEMENTED YET *)
216+ failwith " FIXME: NOT IMPLEMENTED YET"
156217
157218let interpret_unop op v =
158219 let open Float in
159- match op with Identity -> v | Relu when v > = 0. -> v | Relu -> 0.
220+ match op with
221+ | Identity -> v
222+ | Relu when v > = 0. -> v
223+ | Relu -> 0.
224+ | Satur01 when v < = 0. -> 0.
225+ | Satur01 when v > = 1. -> 1.
226+ | Satur01 -> v
227+ | Exp -> exp v
228+ | Log -> log v
229+ | Exp2 -> 2. ** v
230+ | Log2 -> log v / log 2.
231+ | Exp10 -> 10. ** v
232+ | Log10 -> log v / log 10.
233+ | Sin -> sin v
234+ | Cos -> cos v
235+ | Sqrt -> sqrt v
236+ | Recip -> 1. / v
237+ | Recip_sqrt -> 1. / sqrt v
238+ | Neg -> ~-. v
239+ | Tanh_approx -> tanh v
240+
241+ let is_binop_infix = function Threefry -> false | _ -> true
242+
243+ let is_binop_nice_infix = function
244+ | Arg1 | Arg2 | Relu_gate | Max | Min | Threefry -> false
245+ | _ -> true
160246
161247let binop_cd_syntax = function
162248 | Arg1 -> " -@>"
@@ -167,6 +253,36 @@ let binop_cd_syntax = function
167253 | Div -> " /"
168254 | ToPowOf -> " **"
169255 | Relu_gate -> " -?/"
256+ | Cmplt -> " <"
257+ | Cmpne -> " <>"
258+ | Or -> " ||"
259+ | And -> " &&"
260+ | Mod -> " %"
261+ | Max -> " @^"
262+ | Min -> " ^^"
263+ (* | Shl -> "lsl" *)
264+ (* | Shr -> "lsr" *)
265+ | Threefry -> " threefry"
266+
267+ let binop_cd_fallback_syntax = function
268+ | Arg1 -> " fst"
269+ | Arg2 -> " snd"
270+ | Add -> " add"
271+ | Sub -> " sub"
272+ | Mul -> " mul"
273+ | Div -> " div"
274+ | ToPowOf -> " pow"
275+ | Relu_gate -> " relu_gate"
276+ | Cmplt -> " lt"
277+ | Cmpne -> " le"
278+ | Or -> " orf"
279+ | And -> " andf"
280+ | Mod -> " modf"
281+ | Max -> " max"
282+ | Min -> " min"
283+ (* | Shl -> "shlf" *)
284+ (* | Shr -> "shrf" *)
285+ | Threefry -> " threefry"
170286
171287let binop_c_syntax prec v =
172288 match (v, prec) with
@@ -184,22 +300,56 @@ let binop_c_syntax prec v =
184300 invalid_arg " Ops.binop_c_syntax: ToPowOf not supported for byte/integer precisions"
185301 | Relu_gate , Byte_prec _ -> (" (" , " > 0 ?" , " : 0)" )
186302 | Relu_gate , _ -> (" (" , " > 0.0 ?" , " : 0.0)" )
303+ | Max , Double_prec _ -> (" fmax(" , " ," , " )" )
304+ | Max , Single_prec _ -> (" fmaxf(" , " ," , " )" )
305+ | Max , Half_prec _ -> (" fmaxf(" , " ," , " )" )
306+ | Max , Byte_prec _ -> (" fmax(" , " ," , " )" )
307+ | Min , Double_prec _ -> (" fmin(" , " ," , " )" )
308+ | Min , Single_prec _ -> (" fminf(" , " ," , " )" )
309+ | Min , Half_prec _ -> (" fminf(" , " ," , " )" )
310+ | Min , Byte_prec _ -> (" fmin(" , " ," , " )" )
311+ | Mod , _ -> (" (" , " %" , " )" )
312+ | Cmplt , _ -> (" (" , " <" , " )" )
313+ | Cmpne , _ -> (" (" , " !=" , " )" )
314+ (* | Shl, Byte_prec _ -> ("(", " <<", ")") *)
315+ (* | Shl, _ -> ("((", ") * exp2(", "))") *)
316+ (* | Shr, Byte_prec _ -> ("(", " >>", ")") *)
317+ (* | Shr, _ -> ("((", ") / exp2(", "))") *)
318+ | Or , _ -> (" (" , " ||" , " )" )
319+ | And , _ -> (" (" , " &&" , " )" )
320+ | Threefry , Double_prec _ -> (" threefry(" , " ," , " )" )
321+ | Threefry , Single_prec _ -> (" threefryf(" , " ," , " )" )
322+ | Threefry , Half_prec _ -> (" threefryf(" , " ," , " )" )
323+ | Threefry , Byte_prec _ -> (" threefryf(" , " ," , " )" )
324+
325+ let is_assign_op = function
326+ | Arg1 | Mod (* | Shl | Shr * ) | Cmplt | Cmpne | Threefry -> false
327+ | Add | Sub | Mul | Div | ToPowOf | Relu_gate | Arg2 | Max | Min | Or | And -> true
187328
188329let assign_op_cd_syntax ~initialize_neutral = function
189- | Arg1 -> invalid_arg " Ops.assign_op_cd_syntax: Arg1 is not a %cd assignment operator"
190330 | Arg2 -> " =:"
191331 | Add when initialize_neutral -> " =:+"
192332 | Sub when initialize_neutral -> " =:-"
193333 | Mul when initialize_neutral -> " =:*"
194334 | Div when initialize_neutral -> " =:/"
195335 | ToPowOf when initialize_neutral -> " =:**"
196336 | Relu_gate when initialize_neutral -> " =:?/"
337+ | Or when initialize_neutral -> " =:||"
338+ | And when initialize_neutral -> " =:&&"
339+ | Max when initialize_neutral -> " =:@^"
340+ | Min when initialize_neutral -> " =:^^"
197341 | Add -> " =+"
198342 | Sub -> " =-"
199343 | Mul -> " =*"
200344 | Div -> " =/"
201345 | ToPowOf -> " =**"
202346 | Relu_gate -> " =?/"
347+ | Max -> " =@^"
348+ | Min -> " =^^"
349+ | Or -> " =||"
350+ | And -> " =&&"
351+ | Arg1 | Mod (* | Shl | Shr * ) | Cmplt | Cmpne | Threefry ->
352+ invalid_arg " Ops.assign_op_cd_syntax: not an assignment op"
203353
204354let assign_op_c_syntax = function
205355 | Arg1 -> invalid_arg " Ops.assign_op_c_syntax: Arg1 is not a C assignment operator"
@@ -208,17 +358,43 @@ let assign_op_c_syntax = function
208358 | Sub -> " -="
209359 | Mul -> " *="
210360 | Div -> " /="
211- | ToPowOf -> invalid_arg " Ops.assign_op_c_syntax: ToPowOf function is not a C assignment operator"
212- | Relu_gate -> invalid_arg " Ops.assign_op_c_syntax: Relu_gate is not a C assignment operator"
213-
214- let unop_cd_syntax = function Identity -> " ~=" | Relu -> " ?/"
361+ | Mod -> " %="
362+ (* | Shl -> "<<=" *)
363+ (* | Shr -> ">>=" *)
364+ | _ -> invalid_arg " Ops.assign_op_c_syntax: not a C assignment operator"
365+
366+ (* * Note: currently we do not support unary prefix symbols. *)
367+ let unop_cd_syntax = function
368+ | Identity -> " id"
369+ | Relu -> " relu"
370+ | Satur01 -> " sat01"
371+ | Exp -> " exp"
372+ | Log -> " log"
373+ | Exp2 -> " exp2"
374+ | Log2 -> " log2"
375+ | Exp10 -> " exp10"
376+ | Log10 -> " log10"
377+ | Sin -> " sin"
378+ | Cos -> " cos"
379+ | Sqrt -> " sqrt"
380+ | Recip -> " recip"
381+ | Recip_sqrt -> " recip_sqrt"
382+ | Neg -> " neg"
383+ | Tanh_approx -> " tanh"
215384
216385let unop_c_syntax prec v =
217386 match (v, prec) with
218387 | Identity , _ -> (" " , " " )
219388 | Relu , Single_prec _ -> (" fmaxf(0.0, " , " )" )
220389 | Relu , Byte_prec _ -> (" fmax(0, " , " )" )
221390 | Relu , _ -> (" fmax(0.0, " , " )" )
391+ | _ ->
392+ (* FIXME: NOT IMPLEMENTED YET *)
393+ failwith " NOT IMPLEMENTED YET"
394+ (* | Satur01, _ -> ("", "") | Exp, _ -> ("", "") | Log, _ -> ("", "") | Exp2, _ -> ("", "") | Log2,
395+ _ -> ("", "") | Exp10, _ -> ("", "") | Log10, _ -> ("", "") | Sin, _ -> ("", "") | Cos, _ -> ("",
396+ "") | Sqrt, _ -> ("", "") | Recip, _ -> ("", "") | Recip_sqrt, _ -> ("", "") | Neg, _ -> ("", "")
397+ | Tanh_approx, _ -> ("", "") *)
222398
223399let c_convert_precision ~from ~to_ =
224400 match (from, to_) with
0 commit comments