@@ -164,8 +164,6 @@ type unop =
164164 | Log
165165 | Exp2
166166 | Log2
167- | Exp10
168- | Log10
169167 | Sin
170168 | Cos
171169 | Sqrt
@@ -228,8 +226,6 @@ let interpret_unop op v =
228226 | Log -> log v
229227 | Exp2 -> 2. ** v
230228 | Log2 -> log v / log 2.
231- | Exp10 -> 10. ** v
232- | Log10 -> log v / log 10.
233229 | Sin -> sin v
234230 | Cos -> cos v
235231 | Sqrt -> sqrt v
@@ -294,20 +290,15 @@ let binop_c_syntax prec v =
294290 | Mul , _ -> (" (" , " *" , " )" )
295291 | Div , _ -> (" (" , " /" , " )" )
296292 | ToPowOf , Double_prec _ -> (" pow(" , " ," , " )" )
297- | ToPowOf , Single_prec _ -> (" powf(" , " ," , " )" )
298- | ToPowOf , Half_prec _ -> (" powf(" , " ," , " )" )
299293 | ToPowOf , Byte_prec _ ->
300294 invalid_arg " Ops.binop_c_syntax: ToPowOf not supported for byte/integer precisions"
295+ | ToPowOf , _ -> (" powf(" , " ," , " )" )
301296 | Relu_gate , Byte_prec _ -> (" (" , " > 0 ?" , " : 0)" )
302297 | Relu_gate , _ -> (" (" , " > 0.0 ?" , " : 0.0)" )
303- | Max , Double_prec _ -> (" fmax(" , " ," , " )" )
304- | Max , Single_prec _ -> (" fmaxf(" , " ," , " )" )
305- | Max , Half_prec _ -> (" fmaxf(" , " ," , " )" )
306- | Max , Byte_prec _ -> (" fmax(" , " ," , " )" )
307- | Min , Double_prec _ -> (" fmin(" , " ," , " )" )
308- | Min , Single_prec _ -> (" fminf(" , " ," , " )" )
309- | Min , Half_prec _ -> (" fminf(" , " ," , " )" )
310- | Min , Byte_prec _ -> (" fmin(" , " ," , " )" )
298+ | Max , (Double_prec _ | Byte_prec _ ) -> (" fmax(" , " ," , " )" )
299+ | Max , _ -> (" fmaxf(" , " ," , " )" )
300+ | Min , (Double_prec _ | Byte_prec _ ) -> (" fmin(" , " ," , " )" )
301+ | Min , _ -> (" fminf(" , " ," , " )" )
311302 | Mod , _ -> (" (" , " %" , " )" )
312303 | Cmplt , _ -> (" (" , " <" , " )" )
313304 | Cmpne , _ -> (" (" , " !=" , " )" )
@@ -317,10 +308,9 @@ let binop_c_syntax prec v =
317308 (* | Shr, _ -> ("((", ") / exp2(", "))") *)
318309 | Or , _ -> (" (" , " ||" , " )" )
319310 | And , _ -> (" (" , " &&" , " )" )
320- | Threefry , Double_prec _ -> (" threefry(" , " ," , " )" )
321- | Threefry , Single_prec _ -> (" threefryf(" , " ," , " )" )
322- | Threefry , Half_prec _ -> (" threefryf(" , " ," , " )" )
323- | Threefry , Byte_prec _ -> (" threefryf(" , " ," , " )" )
311+ | Threefry , _ ->
312+ (* FIXME: NOT IMPLEMENTED YET *)
313+ failwith " Ops.binop_c_syntax: threefry NOT IMPLEMENTED YET"
324314
325315let is_assign_op = function
326316 | Arg1 | Mod (* | Shl | Shr * ) | Cmplt | Cmpne | Threefry -> false
@@ -372,8 +362,6 @@ let unop_cd_syntax = function
372362 | Log -> " log"
373363 | Exp2 -> " exp2"
374364 | Log2 -> " log2"
375- | Exp10 -> " exp10"
376- | Log10 -> " log10"
377365 | Sin -> " sin"
378366 | Cos -> " cos"
379367 | Sqrt -> " sqrt"
@@ -383,18 +371,49 @@ let unop_cd_syntax = function
383371 | Tanh_approx -> " tanh"
384372
385373let unop_c_syntax prec v =
374+ let fmax () =
375+ (* See: https://en.cppreference.com/w/c/numeric/math/fmax option (4) *)
376+ match prec with
377+ | Double_prec _ | Byte_prec _ -> " fmax"
378+ | _ -> " fmaxf"
379+ in
380+ let fmin () =
381+ (* See: https://en.cppreference.com/w/c/numeric/math/fmin option (4) *)
382+ match prec with
383+ | Double_prec _ | Byte_prec _ -> " fmax"
384+ | _ -> " fmaxf"
385+ in
386386 match (v, prec) with
387387 | Identity , _ -> (" " , " " )
388- | Relu , Single_prec _ -> (" fmaxf(0.0, " , " )" )
389388 | Relu , Byte_prec _ -> (" fmax(0, " , " )" )
390- | Relu , _ -> (" fmax(0.0, " , " )" )
391- | _ ->
392- (* FIXME: NOT IMPLEMENTED YET *)
393- failwith " NOT IMPLEMENTED YET"
394- (* | Satur01, _ -> ("", "") | Exp, _ -> ("", "") | Log, _ -> ("", "") | Exp2, _ -> ("", "") | Log2,
395- _ -> ("", "") | Exp10, _ -> ("", "") | Log10, _ -> ("", "") | Sin, _ -> ("", "") | Cos, _ -> ("",
396- "") | Sqrt, _ -> ("", "") | Recip, _ -> ("", "") | Recip_sqrt, _ -> ("", "") | Neg, _ -> ("", "")
397- | Tanh_approx, _ -> ("", "") *)
389+ | Relu , _ -> (fmax () ^ " (0.0, " , " )" )
390+ | Satur01 , Byte_prec _ -> (" fmax(0, fmin(1, " , " ))" )
391+ | Satur01 , _ -> (fmax () ^ " (0.0, " ^ fmin () ^ " (1.0, " , " ))" )
392+ | Exp , (Double_prec _ | Byte_prec _ ) -> (" exp(" , " )" )
393+ | Exp , _ -> (" expf(" , " )" )
394+ | Log , (Double_prec _ | Byte_prec _ ) -> (" log(" , " )" )
395+ | Log , _ -> (" logf(" , " )" )
396+ | Exp2 , (Double_prec _ | Byte_prec _ ) -> (" exp2(" , " )" )
397+ | Exp2 , _ -> (" exp2f(" , " )" )
398+ | Log2 , (Double_prec _ | Byte_prec _ ) -> (" log2(" , " )" )
399+ | Log2 , _ -> (" log2f(" , " )" )
400+ | Sin , (Double_prec _ | Byte_prec _ ) -> (" sin(" , " )" )
401+ | Sin , _ -> (" sinf(" , " )" )
402+ | Cos , (Double_prec _ | Byte_prec _ ) -> (" cos(" , " )" )
403+ | Cos , _ -> (" cosf(" , " )" )
404+ | Sqrt , (Double_prec _ | Byte_prec _ ) -> (" sqrt(" , " )" )
405+ | Sqrt , _ -> (" sqrtf(" , " )" )
406+ | Recip , Byte_prec _ ->
407+ invalid_arg " Ops.unop_c_syntax: Recip not supported for byte/integer precisions"
408+ | Recip , _ -> (" (1.0 / (" , " ))" )
409+ | Recip_sqrt , Byte_prec _ ->
410+ invalid_arg " Ops.unop_c_syntax: Recip_sqrt not supported for byte/integer precisions"
411+ | Recip_sqrt , Double_prec _ -> (" (1.0 / sqrt(" , " ))" )
412+ | Recip_sqrt , _ -> (" (1.0 / sqrtf(" , " ))" )
413+ | Neg , _ -> (" (-(" , " ))" )
414+ | Tanh_approx , Byte_prec _ ->
415+ invalid_arg " Ops.unop_c_syntax: Tanh_approx not supported for byte/integer precisions"
416+ | Tanh_approx , _ -> (" tanhf(" , " )" )
398417
399418let c_convert_precision ~from ~to_ =
400419 match (from, to_) with
0 commit comments