Skip to content

Commit 5efbf29

Browse files
committed
Remove Exp10, Log10; fill out unop_c_syntax
1 parent 5a04c74 commit 5efbf29

File tree

1 file changed

+48
-29
lines changed

1 file changed

+48
-29
lines changed

arrayjit/lib/ops.ml

Lines changed: 48 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,6 @@ type unop =
164164
| Log
165165
| Exp2
166166
| Log2
167-
| Exp10
168-
| Log10
169167
| Sin
170168
| Cos
171169
| Sqrt
@@ -228,8 +226,6 @@ let interpret_unop op v =
228226
| Log -> log v
229227
| Exp2 -> 2. ** v
230228
| Log2 -> log v / log 2.
231-
| Exp10 -> 10. ** v
232-
| Log10 -> log v / log 10.
233229
| Sin -> sin v
234230
| Cos -> cos v
235231
| Sqrt -> sqrt v
@@ -294,20 +290,15 @@ let binop_c_syntax prec v =
294290
| Mul, _ -> ("(", " *", ")")
295291
| Div, _ -> ("(", " /", ")")
296292
| ToPowOf, Double_prec _ -> ("pow(", ",", ")")
297-
| ToPowOf, Single_prec _ -> ("powf(", ",", ")")
298-
| ToPowOf, Half_prec _ -> ("powf(", ",", ")")
299293
| ToPowOf, Byte_prec _ ->
300294
invalid_arg "Ops.binop_c_syntax: ToPowOf not supported for byte/integer precisions"
295+
| ToPowOf, _ -> ("powf(", ",", ")")
301296
| Relu_gate, Byte_prec _ -> ("(", " > 0 ?", " : 0)")
302297
| Relu_gate, _ -> ("(", " > 0.0 ?", " : 0.0)")
303-
| Max, Double_prec _ -> ("fmax(", ",", ")")
304-
| Max, Single_prec _ -> ("fmaxf(", ",", ")")
305-
| Max, Half_prec _ -> ("fmaxf(", ",", ")")
306-
| Max, Byte_prec _ -> ("fmax(", ",", ")")
307-
| Min, Double_prec _ -> ("fmin(", ",", ")")
308-
| Min, Single_prec _ -> ("fminf(", ",", ")")
309-
| Min, Half_prec _ -> ("fminf(", ",", ")")
310-
| Min, Byte_prec _ -> ("fmin(", ",", ")")
298+
| Max, (Double_prec _ | Byte_prec _) -> ("fmax(", ",", ")")
299+
| Max, _ -> ("fmaxf(", ",", ")")
300+
| Min, (Double_prec _ | Byte_prec _) -> ("fmin(", ",", ")")
301+
| Min, _ -> ("fminf(", ",", ")")
311302
| Mod, _ -> ("(", " %", ")")
312303
| Cmplt, _ -> ("(", " <", ")")
313304
| Cmpne, _ -> ("(", " !=", ")")
@@ -317,10 +308,9 @@ let binop_c_syntax prec v =
317308
(* | Shr, _ -> ("((", ") / exp2(", "))") *)
318309
| Or, _ -> ("(", " ||", ")")
319310
| And, _ -> ("(", " &&", ")")
320-
| Threefry, Double_prec _ -> ("threefry(", ",", ")")
321-
| Threefry, Single_prec _ -> ("threefryf(", ",", ")")
322-
| Threefry, Half_prec _ -> ("threefryf(", ",", ")")
323-
| Threefry, Byte_prec _ -> ("threefryf(", ",", ")")
311+
| Threefry, _ ->
312+
(* FIXME: NOT IMPLEMENTED YET *)
313+
failwith "Ops.binop_c_syntax: threefry NOT IMPLEMENTED YET"
324314

325315
let is_assign_op = function
326316
| Arg1 | Mod (* | Shl | Shr *) | Cmplt | Cmpne | Threefry -> false
@@ -372,8 +362,6 @@ let unop_cd_syntax = function
372362
| Log -> "log"
373363
| Exp2 -> "exp2"
374364
| Log2 -> "log2"
375-
| Exp10 -> "exp10"
376-
| Log10 -> "log10"
377365
| Sin -> "sin"
378366
| Cos -> "cos"
379367
| Sqrt -> "sqrt"
@@ -383,18 +371,49 @@ let unop_cd_syntax = function
383371
| Tanh_approx -> "tanh"
384372

385373
let unop_c_syntax prec v =
374+
let fmax () =
375+
(* See: https://en.cppreference.com/w/c/numeric/math/fmax option (4) *)
376+
match prec with
377+
| Double_prec _ | Byte_prec _ -> "fmax"
378+
| _ -> "fmaxf"
379+
in
380+
let fmin () =
381+
(* See: https://en.cppreference.com/w/c/numeric/math/fmin option (4) *)
382+
match prec with
383+
| Double_prec _ | Byte_prec _ -> "fmax"
384+
| _ -> "fmaxf"
385+
in
386386
match (v, prec) with
387387
| Identity, _ -> ("", "")
388-
| Relu, Single_prec _ -> ("fmaxf(0.0, ", ")")
389388
| Relu, Byte_prec _ -> ("fmax(0, ", ")")
390-
| Relu, _ -> ("fmax(0.0, ", ")")
391-
| _ ->
392-
(* FIXME: NOT IMPLEMENTED YET *)
393-
failwith "NOT IMPLEMENTED YET"
394-
(* | Satur01, _ -> ("", "") | Exp, _ -> ("", "") | Log, _ -> ("", "") | Exp2, _ -> ("", "") | Log2,
395-
_ -> ("", "") | Exp10, _ -> ("", "") | Log10, _ -> ("", "") | Sin, _ -> ("", "") | Cos, _ -> ("",
396-
"") | Sqrt, _ -> ("", "") | Recip, _ -> ("", "") | Recip_sqrt, _ -> ("", "") | Neg, _ -> ("", "")
397-
| Tanh_approx, _ -> ("", "") *)
389+
| Relu, _ -> (fmax () ^ "(0.0, ", ")")
390+
| Satur01, Byte_prec _ -> ("fmax(0, fmin(1, ", "))")
391+
| Satur01, _ -> (fmax () ^ "(0.0, " ^ fmin () ^ "(1.0, ", "))")
392+
| Exp, (Double_prec _ | Byte_prec _) -> ("exp(", ")")
393+
| Exp, _ -> ("expf(", ")")
394+
| Log, (Double_prec _ | Byte_prec _) -> ("log(", ")")
395+
| Log, _ -> ("logf(", ")")
396+
| Exp2, (Double_prec _ | Byte_prec _) -> ("exp2(", ")")
397+
| Exp2, _ -> ("exp2f(", ")")
398+
| Log2, (Double_prec _ | Byte_prec _) -> ("log2(", ")")
399+
| Log2, _ -> ("log2f(", ")")
400+
| Sin, (Double_prec _ | Byte_prec _) -> ("sin(", ")")
401+
| Sin, _ -> ("sinf(", ")")
402+
| Cos, (Double_prec _ | Byte_prec _) -> ("cos(", ")")
403+
| Cos, _ -> ("cosf(", ")")
404+
| Sqrt, (Double_prec _ | Byte_prec _) -> ("sqrt(", ")")
405+
| Sqrt, _ -> ("sqrtf(", ")")
406+
| Recip, Byte_prec _ ->
407+
invalid_arg "Ops.unop_c_syntax: Recip not supported for byte/integer precisions"
408+
| Recip, _ -> ("(1.0 / (", "))")
409+
| Recip_sqrt, Byte_prec _ ->
410+
invalid_arg "Ops.unop_c_syntax: Recip_sqrt not supported for byte/integer precisions"
411+
| Recip_sqrt, Double_prec _ -> ("(1.0 / sqrt(", "))")
412+
| Recip_sqrt, _ -> ("(1.0 / sqrtf(", "))")
413+
| Neg, _ -> ("(-(", "))")
414+
| Tanh_approx, Byte_prec _ ->
415+
invalid_arg "Ops.unop_c_syntax: Tanh_approx not supported for byte/integer precisions"
416+
| Tanh_approx, _ -> ("tanhf(", ")")
398417

399418
let c_convert_precision ~from ~to_ =
400419
match (from, to_) with

0 commit comments

Comments
 (0)