@@ -83,16 +83,30 @@ let%track7_sexp c_compile_and_load ~f_name =
8383 Stdlib.Gc. finalise finalize result;
8484 result
8585
86- let % diagn_sexp compile ~(name : string ) bindings (lowered : Low_level.optimized ) : procedure =
87- let module Syntax = C_syntax. C_syntax (C_syntax. Pure_C_config (struct
(* Functor building the C-syntax configuration that both [compile] and
   [compile_batch] pass to [C_syntax.C_syntax].
   [Procs.procs] is the array of lowered procedures; it is forwarded unchanged
   as [procs] to [C_syntax.Pure_C_config], next to the enclosing scope's
   [buffer_ptr], [use_host_memory] and the computed [full_printf_support]. *)
86+ module CC_syntax_config (Procs : sig
87+ val procs : Low_level .optimized array
88+ end ) =
89+ struct
90+ include C_syntax. Pure_C_config (struct
8891 type nonrec buffer_ptr = buffer_ptr
8992
9093 let use_host_memory = use_host_memory
91- let procs = [| lowered |]
94+ let procs = Procs. procs
9295
(* [full_printf_support] is the negation of the boolean parsed from the value
   that [Utils.get_global_arg] returns for the prefer_backend_uniformity
   argument. *)
9396 let full_printf_support =
9497 not @@ Bool. of_string
9598 @@ Utils. get_global_arg ~default: " false" ~arg_name: " prefer_backend_uniformity"
99+ end )
100+
101+ (* Intended as the place to override the included defaults with custom type
   and conversion support.
   NOTE(review): each binding below rebinds a name to the identically named
   value already in scope at this point. If [C_syntax.Pure_C_config] exports
   these names, the three lines are no-ops that re-export the included
   defaults. The bfloat16/fp8 definitions that follow [link_compiled] in this
   file are inside a comment, so they are not what these names resolve to.
   TODO confirm which definitions are picked up here. *)
102+ let typ_of_prec = typ_of_prec
103+ let extra_declarations = extra_declarations (* Our bfloat16/fp8 conversion functions *)
104+ let convert_precision = convert_precision
105+ end
106+
107+ let % diagn_sexp compile ~(name : string ) bindings (lowered : Low_level.optimized ) : procedure =
108+ let module Syntax = C_syntax. C_syntax (CC_syntax_config (struct
109+ let procs = [| lowered |]
96110 end )) in
97111 (* FIXME: do we really want all of them, or only the used ones? *)
98112 let idx_params = Indexing. bound_symbols bindings in
@@ -110,15 +124,8 @@ let%diagn_sexp compile ~(name : string) bindings (lowered : Low_level.optimized)
110124
111125let % diagn_sexp compile_batch ~names bindings (lowereds : Low_level.optimized option array ) :
112126 procedure option array =
113- let module Syntax = C_syntax. C_syntax (C_syntax. Pure_C_config (struct
114- type nonrec buffer_ptr = buffer_ptr
115-
116- let use_host_memory = use_host_memory
127+ let module Syntax = C_syntax. C_syntax (CC_syntax_config (struct
117128 let procs = Array. filter_opt lowereds
118-
119- let full_printf_support =
120- not @@ Bool. of_string
121- @@ Utils. get_global_arg ~default: " false" ~arg_name: " prefer_backend_uniformity"
122129 end )) in
123130 (* FIXME: do we really want all of them, or only the used ones? *)
124131 let idx_params = Indexing. bound_symbols bindings in
@@ -203,3 +210,71 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro
203210 description = " executes " ^ code.name ^ " on " ^ runner_label;
204211 work;
205212 } )
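213+ (* NOTE(review): everything up to the closing comment marker is disabled
   draft code and is not compiled. It sketches custom typ_of_prec,
   extra_declarations and convert_precision definitions that store bfloat16
   as unsigned short and fp8 as unsigned char and emulate them through float.
   Before re-enabling, check two things in the emitted C helpers:
   - they call powf, floorf, log2f and fabsf, so the generated source
     presumably needs math.h - TODO confirm it is included;
   - the bfloat16 helpers reinterpret bits through pointer casts between
     float and unsigned int; a memcpy-based copy is the portable form.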
214+ let typ_of_prec = function
215+ | Ops.Byte_prec _ -> "unsigned char"
216+ | Ops.Uint16_prec _ -> "unsigned short"
217+ | Ops.Int32_prec _ -> "int"
218+ | Ops.Half_prec _ -> "_Float16"
219+ | Ops.Bfloat16_prec _ -> "unsigned short" (* Stored as uint16, emulated as float *)
220+ | Ops.Fp8_prec _ -> "unsigned char" (* Stored as uint8, emulated as float *)
221+ | Ops.Single_prec _ -> "float"
222+ | Ops.Double_prec _ -> "double"
223+ | Ops.Void_prec -> "void"
224+
225+ (* Helper functions for bfloat16 and fp8 conversions *)
226+ let extra_declarations =
227+ [
228+ "/* Emulation functions for special float types */";
229+ "static inline float bfloat16_to_float(unsigned short bf16) {";
230+ " unsigned int f32 = ((unsigned int)bf16) << 16;";
231+ " return *(float*) & f32;" ;
232+ " }" ;
233+ " " ;
234+ " static inline unsigned short float_to_bfloat16(float f) {" ;
235+ " unsigned int f32 = * (unsigned int * )& f;" ;
236+ " unsigned int rounded = f32 + 0x7FFF + ((f32 >> 16 ) & 1 );" ;
237+ " return (unsigned short)(rounded >> 16 );" ;
238+ " }" ;
239+ " " ;
240+ " /* Simplified FP8 E5M2 format emulation */ " ;
241+ " static inline float fp8_to_float(unsigned char fp8) {" ;
242+ " if (fp8 == 0 ) return 0.0 f;" ;
243+ " unsigned int sign = (fp8 >> 7 ) & 1 ;" ;
244+ " unsigned int exp = (fp8 >> 2 ) & 0x1F ;" ;
245+ " unsigned int mant = fp8 & 0x3 ;" ;
246+ " float result = (1.0 f + mant * 0.25 f) * powf(2.0 f, (float )exp - 15.0 f);" ;
247+ " return sign ? - result : result ; " ;
248+ " }" ;
249+ " " ;
250+ " static inline unsigned char float_to_fp8(float f) {" ;
251+ " if (f == 0.0 f) return 0 ;" ;
252+ " unsigned int sign = (f < 0 ) ? 1 : 0 ;" ;
253+ " f = fabsf(f);" ;
254+ " int exp = (int )floorf(log2f(f)) + 15 ;" ;
255+ " if (exp < 0 ) return 0 ;" ;
256+ " if (exp > 31 ) return sign ? 0xFF : 0x7F ;" ;
257+ " float mant = f / powf(2.0 f, (float )exp - 15.0 f) - 1.0 f;" ;
258+ " unsigned int mant_bits = (unsigned int )(mant * 4.0 f + 0.5 f);" ;
259+ " if (mant_bits > 3 ) mant_bits = 3 ;" ;
260+ " return (unsigned char )((sign << 7 ) | ((exp & 0x1F ) << 2 ) | (mant_bits & 0x3 ));" ;
261+ " }" ;
262+ ]
263+
264+ let convert_precision ~from ~to_ =
265+ match (from, to_) with
266+ | p1, p2 when Ops.equal_prec p1 p2 -> (" " , " " )
267+ | Ops.Bfloat16_prec _, Ops.Single_prec _ -> (" bfloat16_to_float(" , " )" )
268+ | Ops.Bfloat16_prec _, Ops.Double_prec _ -> (" (double)bfloat16_to_float(" , " )" )
269+ | Ops.Single_prec _, Ops.Bfloat16_prec _ -> (" float_to_bfloat16(" , " )" )
270+ | Ops.Double_prec _, Ops.Bfloat16_prec _ -> (" float_to_bfloat16((float )" , " )" )
271+ | Ops.Fp8_prec _, Ops.Single_prec _ -> (" fp8_to_float(" , " )" )
272+ | Ops.Fp8_prec _, Ops.Double_prec _ -> (" (double)fp8_to_float(" , " )" )
273+ | Ops.Single_prec _, Ops.Fp8_prec _ -> (" float_to_fp8(" , " )" )
274+ | Ops.Double_prec _, Ops.Fp8_prec _ -> (" float_to_fp8((float )" , " )" )
275+ | Ops.Bfloat16_prec _, _ -> (" (float )bfloat16_to_float(" , " )" ) (* Convert via float *)
276+ | _, Ops.Bfloat16_prec _ -> (" float_to_bfloat16((float )" , " )" )
277+ | Ops.Fp8_prec _, _ -> (" (float )fp8_to_float(" , " )" ) (* Convert via float *)
278+ | _, Ops.Fp8_prec _ -> (" float_to_fp8((float )" , " )" )
279+ | _ -> Ops.c_convert_precision ~from ~to_
280+ *)
0 commit comments