@@ -17,9 +17,11 @@ type ('ocaml, 'impl) precision =
1717 | Byte : (char , uint8_elt ) precision
1818 | Uint16 : (int , uint16_elt ) precision
1919 | Int32 : (int32 , int32_elt ) precision
20- | Uint32 : (int32 , int32_elt ) precision (* * Using int32_elt representation but treating as unsigned *)
20+ | Uint32 : (int32 , int32_elt ) precision
21+ (* * Using int32_elt representation but treating as unsigned *)
2122 | Int64 : (int64 , int64_elt ) precision
22- | Uint64 : (int64 , int64_elt ) precision (* * Using int64_elt representation but treating as unsigned *)
23+ | Uint64 : (int64 , int64_elt ) precision
24+ (* * Using int64_elt representation but treating as unsigned *)
2325 | Uint4x32 : (Stdlib.Complex .t , Bigarray .complex64_elt ) precision
2426 (* * A 128-bit value that corresponds to e.g. CUDA's uint4 type. Luckily, the OCaml Bigarray
2527 library supports complex64_elt which is a 128-bit value, so we avoid dims conversions. *)
@@ -563,21 +565,32 @@ let binop_c_syntax prec v =
563565 | Mul , _ -> (" (" , " *" , " )" )
564566 | Div , _ -> (" (" , " /" , " )" )
565567 | ToPowOf , Double_prec _ -> (" pow(" , " ," , " )" )
566- | ToPowOf , (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) ->
568+ | ( ToPowOf ,
569+ ( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
570+ | Fp8_prec _ ) ) ->
567571 invalid_arg " Ops.binop_c_syntax: ToPowOf not supported for integer precisions"
568572 | ToPowOf , _ -> (" powf(" , " ," , " )" )
569- | Relu_gate , (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) -> (" (" , " > 0 ?" , " : 0)" )
573+ | ( Relu_gate ,
574+ ( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
575+ | Fp8_prec _ ) ) ->
576+ (" (" , " > 0 ?" , " : 0)" )
570577 | Relu_gate , _ -> (" (" , " > 0.0 ?" , " : 0.0)" )
571- | Satur01_gate , (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) ->
578+ | ( Satur01_gate ,
579+ ( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
580+ | Fp8_prec _ ) ) ->
572581 (" (abs(" , " ) > 0 ? 0 : (" , " ))" )
573582 | Satur01_gate , Single_prec _ ->
574583 (* This disagrees at 0 with the semantics. *)
575584 (" (fabsf(floorf(" , " )) > 0.0 ? 0.0 : (" , " ))" )
576585 | Satur01_gate , _ -> (" (fabs(floor(" , " )) > 0.0 ? 0.0 : (" , " ))" )
577- | Max , (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) ->
586+ | ( Max ,
587+ ( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
588+ | Uint64_prec _ | Fp8_prec _ ) ) ->
578589 (" fmax(" , " ," , " )" )
579590 | Max , _ -> (" fmaxf(" , " ," , " )" )
580- | Min , (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) ->
591+ | ( Min ,
592+ ( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
593+ | Uint64_prec _ | Fp8_prec _ ) ) ->
581594 (" fmin(" , " ," , " )" )
582595 | Min , _ -> (" fminf(" , " ," , " )" )
583596 | Mod , _ -> (" (" , " %" , " )" )
@@ -654,43 +667,80 @@ let unop_c_syntax prec op =
654667 let fmax () =
655668 (* See: https://en.cppreference.com/w/c/numeric/math/fmax option (4) *)
656669 match prec with
657- | Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ -> " fmax"
670+ | Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
671+ | Uint64_prec _ | Fp8_prec _ ->
672+ " fmax"
658673 | _ -> " fmaxf"
659674 in
660675 let fmin () =
661676 match prec with
662- | Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ -> " fmin"
677+ | Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
678+ | Uint64_prec _ | Fp8_prec _ ->
679+ " fmin"
663680 | _ -> " fminf"
664681 in
665682 match (op, prec) with
666683 | Identity , _ -> (" " , " " )
667- | Relu , (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) -> (" fmax(0, " , " )" )
684+ | ( Relu ,
685+ ( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
686+ | Fp8_prec _ ) ) ->
687+ (" fmax(0, " , " )" )
668688 | Relu , _ -> (fmax () ^ " (0.0, " , " )" )
669- | Satur01 , (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) -> (" fmax(0, fmin(1, " , " ))" )
689+ | ( Satur01 ,
690+ ( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
691+ | Fp8_prec _ ) ) ->
692+ (" fmax(0, fmin(1, " , " ))" )
670693 | Satur01 , _ -> (fmax () ^ " (0.0, " ^ fmin () ^ " (1.0, " , " ))" )
671- | Exp , (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) -> (" exp(" , " )" )
694+ | ( Exp ,
695+ ( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
696+ | Uint64_prec _ | Fp8_prec _ ) ) ->
697+ (" exp(" , " )" )
672698 | Exp , _ -> (" expf(" , " )" )
673- | Log , (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) -> (" log(" , " )" )
699+ | ( Log ,
700+ ( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
701+ | Uint64_prec _ | Fp8_prec _ ) ) ->
702+ (" log(" , " )" )
674703 | Log , _ -> (" logf(" , " )" )
675- | Exp2 , (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) -> (" exp2(" , " )" )
704+ | ( Exp2 ,
705+ ( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
706+ | Uint64_prec _ | Fp8_prec _ ) ) ->
707+ (" exp2(" , " )" )
676708 | Exp2 , _ -> (" exp2f(" , " )" )
677- | Log2 , (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) -> (" log2(" , " )" )
709+ | ( Log2 ,
710+ ( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
711+ | Uint64_prec _ | Fp8_prec _ ) ) ->
712+ (" log2(" , " )" )
678713 | Log2 , _ -> (" log2f(" , " )" )
679- | Sin , (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) -> (" sin(" , " )" )
714+ | ( Sin ,
715+ ( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
716+ | Uint64_prec _ | Fp8_prec _ ) ) ->
717+ (" sin(" , " )" )
680718 | Sin , _ -> (" sinf(" , " )" )
681- | Cos , (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) -> (" cos(" , " )" )
719+ | ( Cos ,
720+ ( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
721+ | Uint64_prec _ | Fp8_prec _ ) ) ->
722+ (" cos(" , " )" )
682723 | Cos , _ -> (" cosf(" , " )" )
683- | Sqrt , (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) -> (" sqrt(" , " )" )
724+ | ( Sqrt ,
725+ ( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
726+ | Uint64_prec _ | Fp8_prec _ ) ) ->
727+ (" sqrt(" , " )" )
684728 | Sqrt , _ -> (" sqrtf(" , " )" )
685- | Recip , (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) ->
729+ | ( Recip ,
730+ ( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
731+ | Fp8_prec _ ) ) ->
686732 invalid_arg " Ops.unop_c_syntax: Recip not supported for integer precisions"
687733 | Recip , _ -> (" (1.0 / (" , " ))" )
688- | Recip_sqrt , (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) ->
734+ | ( Recip_sqrt ,
735+ ( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
736+ | Fp8_prec _ ) ) ->
689737 invalid_arg " Ops.unop_c_syntax: Recip_sqrt not supported for integer precisions"
690738 | Recip_sqrt , Double_prec _ -> (" (1.0 / sqrt(" , " ))" )
691739 | Recip_sqrt , _ -> (" (1.0 / sqrtf(" , " ))" )
692740 | Neg , _ -> (" (-(" , " ))" )
693- | Tanh_approx , (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) ->
741+ | ( Tanh_approx ,
742+ ( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
743+ | Fp8_prec _ ) ) ->
694744 invalid_arg " Ops.unop_c_syntax: Tanh_approx not supported for integer precisions"
695745 | Tanh_approx , _ -> (" tanhf(" , " )" )
696746 | Not , _ -> (" (" , " == 0.0 ? 1.0 : 0.0)" )
@@ -709,10 +759,14 @@ let ternop_cd_syntax = function Where -> "where" | FMA -> "fma"
709759
710760let ternop_c_syntax prec op =
711761 match (op, prec) with
712- | Where , (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) ->
762+ | ( Where ,
763+ ( Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _
764+ | Fp8_prec _ ) ) ->
713765 (" ((" , " ) != 0 ? (" , " ) : (" , " ))" )
714766 | Where , _ -> (" ((" , " ) != 0.0 ? (" , " ) : (" , " ))" )
715- | FMA , (Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ | Fp8_prec _ ) ->
767+ | ( FMA ,
768+ ( Double_prec _ | Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _
769+ | Uint64_prec _ | Fp8_prec _ ) ) ->
716770 (" fma(" , " ," , " ," , " )" )
717771 | FMA , _ -> (" fmaf(" , " ," , " ," , " )" )
718772
@@ -745,16 +799,22 @@ let c_convert_precision ~from ~to_ =
745799 (* Conversions involving BFloat16 and other types *)
746800 | Bfloat16_prec _ , Half_prec _ -> (" FLOAT_TO_HALF(bfloat16_to_single(" , " ))" )
747801 | Half_prec _ , Bfloat16_prec _ -> (" single_to_bfloat16(HALF_TO_FLOAT(" , " ))" )
748- | Bfloat16_prec _ , (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ ) ->
802+ | ( Bfloat16_prec _,
803+ (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _) )
804+ ->
749805 (" (" ^ c_typ_of_prec to_ ^ " )bfloat16_to_single(" , " )" )
750- | (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ ), Bfloat16_prec _ ->
806+ | ( (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _),
807+ Bfloat16_prec _ ) ->
751808 (" single_to_bfloat16((float)" , " )" )
752809 (* Conversions involving FP8 and other types *)
753810 | Fp8_prec _ , Half_prec _ -> (" FLOAT_TO_HALF(fp8_to_single(" , " ))" )
754811 | Half_prec _ , Fp8_prec _ -> (" single_to_fp8(HALF_TO_FLOAT(" , " ))" )
755- | Fp8_prec _ , (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ ) ->
812+ | ( Fp8_prec _,
813+ (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _) )
814+ ->
756815 (" (" ^ c_typ_of_prec to_ ^ " )fp8_to_single(" , " )" )
757- | (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ ), Fp8_prec _ ->
816+ | ( (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _),
817+ Fp8_prec _ ) ->
758818 (" single_to_fp8((float)" , " )" )
759819 (* BFloat16 <-> FP8 conversions *)
760820 | Bfloat16_prec _ , Fp8_prec _ -> (" single_to_fp8(bfloat16_to_single(" , " ))" )
@@ -764,9 +824,12 @@ let c_convert_precision ~from ~to_ =
764824 | Single_prec _ , Half_prec _ -> (" FLOAT_TO_HALF(" , " )" )
765825 | Half_prec _ , Double_prec _ -> (" (double)HALF_TO_FLOAT(" , " )" )
766826 | Double_prec _ , Half_prec _ -> (" FLOAT_TO_HALF((float)" , " )" )
767- | Half_prec _ , (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ ) ->
827+ | ( Half_prec _,
828+ (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _) )
829+ ->
768830 (" (" ^ c_typ_of_prec to_ ^ " )HALF_TO_FLOAT(" , " )" )
769- | (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _ ), Half_prec _ ->
831+ | ( (Byte_prec _ | Uint16_prec _ | Int32_prec _ | Uint32_prec _ | Int64_prec _ | Uint64_prec _),
832+ Half_prec _ ) ->
770833 (" FLOAT_TO_HALF((float)" , " )" )
771834 (* Uint4x32 conversions - special handling *)
772835 | Uint4x32_prec _ , _ -> (" uint4x32_to_" ^ prec_string to_ ^ " (" , " )" )
0 commit comments