@@ -309,10 +309,25 @@ module Fresh () = struct
309309 | Satur01_gate , Byte_prec _ -> (" (abs(" , " ) > 0 ? 0 : (" , " )" )
310310 | Satur01_gate , Half_prec _ ->
311311 ( " (__hgt(__habs(htrunc(" ,
312- " )), __ushort_as_half((unsigned short)0x0000U)) ? __ushort_as_half((unsigned short)0x0000U) : (" ,
312+ " )), __ushort_as_half((unsigned short)0x0000U)) ? __ushort_as_half((unsigned \
313+ short)0x0000U) : (" ,
313314 " ))" )
314315 | Satur01_gate , Double_prec _ -> (" (fabs(trunc(" , " )) > 0.0 ? 0.0 : (" , " ))" )
315316 | Satur01_gate , Single_prec _ -> (" (fabsf(truncf(" , " )) > 0.0 ? 0.0 : (" , " ))" )
317+ | Max , Byte_prec _ -> (" max(" , " , " , " )" )
318+ | Max , Half_prec _ -> (" __hmax(" , " , " , " )" )
319+ | Max , Double_prec _ -> (" fmax(" , " , " , " )" )
320+ | Max , Single_prec _ -> (" fmaxf(" , " , " , " )" )
321+ | Min , Byte_prec _ -> (" min(" , " , " , " )" )
322+ | Min , Half_prec _ -> (" __hmin(" , " , " , " )" )
323+ | Min , Double_prec _ -> (" fmin(" , " , " , " )" )
324+ | Min , Single_prec _ -> (" fminf(" , " , " , " )" )
325+ | Mod , Byte_prec _ -> (" (" , " % " , " )" )
326+ | Mod , _ -> (" fmod(" , " , " , " )" )
327+ | Cmplt , _ -> (" (" , " < " , " )" )
328+ | Cmpeq , _ -> (" (" , " == " , " )" )
329+ | Or , _ -> (" (" , " || " , " )" )
330+ | And , _ -> (" (" , " && " , " )" )
316331
317332 let unop_syntax prec v =
318333 match (v, prec) with
@@ -321,6 +336,59 @@ module Fresh () = struct
321336 | Relu , Ops. Half_prec _ -> (" __hmax_nan(__ushort_as_half((unsigned short)0x0000U), " , " )" )
322337 | Relu , Ops. Byte_prec _ -> (" fmax(0, " , " )" )
323338 | Relu , _ -> (" fmax(0.0, " , " )" )
339+ | Satur01 , Byte_prec _ -> (" fmax(0, fmin(1, " , " ))" )
340+ | Satur01 , Half_prec _ ->
341+ ( " __hmax_nan(__ushort_as_half((unsigned short)0x0000U), \
342+ __hmin_nan(__ushort_as_half((unsigned short)0x3C00U), " ,
343+ " ))" )
344+ | Satur01 , Single_prec _ -> (" fmaxf(0.0f, fminf(1.0f, " , " ))" )
345+ | Satur01 , _ -> (" fmax(0.0, fmin(1.0, " , " ))" )
346+ | Exp , Half_prec _ -> (" hexp(" , " )" )
347+ | Exp , Double_prec _ -> (" exp(" , " )" )
348+ | Exp , _ -> (" expf(" , " )" )
349+ | Log , Half_prec _ -> (" hlog(" , " )" )
350+ | Log , Double_prec _ -> (" log(" , " )" )
351+ | Log , _ -> (" logf(" , " )" )
352+ | Exp2 , Half_prec _ -> (" hexp2(" , " )" )
353+ | Exp2 , Double_prec _ -> (" exp2(" , " )" )
354+ | Exp2 , _ -> (" exp2f(" , " )" )
355+ | Log2 , Half_prec _ -> (" hlog2(" , " )" )
356+ | Log2 , Double_prec _ -> (" log2(" , " )" )
357+ | Log2 , _ -> (" log2f(" , " )" )
358+ | Sin , Half_prec _ -> (" hsin(" , " )" )
359+ | Sin , Double_prec _ -> (" sin(" , " )" )
360+ | Sin , _ -> (" sinf(" , " )" )
361+ | Cos , Half_prec _ -> (" hcos(" , " )" )
362+ | Cos , Double_prec _ -> (" cos(" , " )" )
363+ | Cos , _ -> (" cosf(" , " )" )
364+ | Sqrt , Half_prec _ -> (" hsqrt(" , " )" )
365+ | Sqrt , Double_prec _ -> (" sqrt(" , " )" )
366+ | Sqrt , _ -> (" sqrtf(" , " )" )
367+ | Recip , Byte_prec _ ->
368+ invalid_arg " Cuda_backend.unop_syntax: Recip not supported for byte/integer precisions"
369+ | Recip , Half_prec _ -> (" hrcp(" , " )" )
370+ | Recip , _ -> (" (1.0 / (" , " ))" )
371+ | Recip_sqrt , Byte_prec _ ->
372+ invalid_arg
373+ " Cuda_backend.unop_syntax: Recip_sqrt not supported for byte/integer precisions"
374+ | Recip_sqrt , Half_prec _ -> (" hrsqrt(" , " )" )
375+ | Recip_sqrt , Double_prec _ -> (" (1.0 / sqrt(" , " ))" )
376+ | Recip_sqrt , _ -> (" (1.0 / sqrtf(" , " ))" )
377+ | Neg , _ -> (" (-(" , " ))" )
378+ | Tanh_approx , Byte_prec _ ->
379+ invalid_arg
380+ " Cuda_backend.unop_syntax: Tanh_approx not supported for byte/integer precisions"
381+ | Tanh_approx , Half_prec _ -> (" htanh_approx(" , " )" )
382+ | Tanh_approx , Single_prec _ -> (" __tanhf(" , " )" )
383+ | Tanh_approx , _ -> (" tanh(" , " )" )
384+ | Not , _ -> (" (" , " == 0.0 ? 1.0 : 0.0)" )
385+
(** [ternop_syntax prec v] selects, from the ternary operation [v] and the
    precision [prec], a 4-tuple of string fragments.
    NOTE(review): presumably the fragments are emitted as
    prefix, first separator, second separator, suffix around the three
    operand expressions - confirm against the caller.
    NOTE(review): the spacing next to the quotes in the literals below may be
    an artifact of how this text was extracted - confirm against the
    repository before relying on the exact strings. *)
386+ let ternop_syntax prec v =
387+ match (v, prec) with
388+ | Ops. Where , _ -> (" (" , " ? " , " : " , " )" ) (* same fragments for every precision *)
389+ | FMA , Ops. Half_prec _ -> (" __hfma(" , " , " , " , " , " )" )
390+ | FMA , Ops. Single_prec _ -> (" fmaf(" , " , " , " , " , " )" )
391+ | FMA , _ -> (" fma(" , " , " , " , " , " )" ) (* catch-all: every precision other than Half_prec and Single_prec uses plain fma *)
324392
325393 let convert_precision ~from ~to_ =
326394 match (from, to_) with
0 commit comments