@@ -91,62 +91,65 @@ struct
9191 let arg_int_prefix = " const int "
9292 let extra_args = []
9393 let includes = [ " <stdio.h>" ; " <stdlib.h>" ; " <string.h>" ; " <math.h>" ]
94- let extra_declarations = [
95- (* BFloat16 conversion functions *)
96- " static inline float bfloat16_to_float(unsigned short bf16) {" ;
97- " unsigned int f32 = ((unsigned int)bf16) << 16;" ;
98- " return *((float*)&f32);" ;
99- " }" ;
100- " " ;
101- " static inline unsigned short float_to_bfloat16(float f) {" ;
102- " unsigned int f32 = *((unsigned int*)&f);" ;
103- " unsigned int rounded = f32 + 0x7FFF + ((f32 >> 16) & 1);" ;
104- " return (unsigned short)(rounded >> 16);" ;
105- " }" ;
106- " " ;
107- (* FP8 E5M2 conversion functions *)
108- " static inline float fp8_to_float(unsigned char fp8) {" ;
109- " if (fp8 == 0) return 0.0f;" ;
110- " unsigned int sign = (fp8 >> 7) & 1;" ;
111- " unsigned int exp = (fp8 >> 2) & 0x1F;" ;
112- " unsigned int mant = fp8 & 0x3;" ;
113- " if (exp == 0x1F) {" ;
114- " if (mant == 0) return sign ? -INFINITY : INFINITY;" ;
115- " else return NAN;" ;
116- " }" ;
117- " if (exp == 0) {" ;
118- " float result = ldexpf((float)mant / 4.0f, -14);" ;
119- " if (sign) result = -result;" ;
120- " return result;" ;
121- " }" ;
122- " float result = (1.0f + (float)mant * 0.25f) * ldexpf(1.0f, (int)exp - 15);" ;
123- " if (sign) result = -result;" ;
124- " return result;" ;
125- " }" ;
126- " " ;
127- " static inline unsigned char float_to_fp8(float f) {" ;
128- " if (f == 0.0f) return 0;" ;
129- " unsigned int sign = (f < 0) ? 1 : 0;" ;
130- " f = fabsf(f);" ;
131- " if (isinf(f)) return (sign << 7) | 0x7C;" ;
132- " if (isnan(f)) return (sign << 7) | 0x7F;" ;
133- " int exp_val;" ;
134- " float mant_f = frexpf(f, &exp_val);" ;
135- " int exp = exp_val + 14;" ;
136- " if (exp < 0) return sign << 7;" ;
137- " if (exp > 30) return (sign << 7) | 0x7C;" ;
138- " if (exp == 0) {" ;
139- " float denorm_mant = f * ldexpf(1.0f, 14) * 4.0f;" ;
140- " unsigned int mant_bits = (unsigned int)(denorm_mant + 0.5f);" ;
141- " if (mant_bits > 3) mant_bits = 3;" ;
142- " return (sign << 7) | mant_bits;" ;
143- " }" ;
144- " mant_f = (mant_f - 0.5f) * 4.0f;" ;
145- " unsigned int mant_bits = (unsigned int)(mant_f + 0.5f);" ;
146- " if (mant_bits > 3) mant_bits = 3;" ;
147- " return (unsigned char)((sign << 7) | ((exp & 0x1F) << 2) | (mant_bits & 0x3));" ;
148- " }" ;
149- ]
94+
95+ let extra_declarations =
96+ [
97+ (* BFloat16 conversion functions *)
98+ " static inline float bfloat16_to_float(unsigned short bf16) {" ;
99+ " unsigned int f32 = ((unsigned int)bf16) << 16;" ;
100+ " return *((float*)&f32);" ;
101+ " }" ;
102+ " " ;
103+ " static inline unsigned short float_to_bfloat16(float f) {" ;
104+ " unsigned int f32 = *((unsigned int*)&f);" ;
105+ " unsigned int rounded = f32 + 0x7FFF + ((f32 >> 16) & 1);" ;
106+ " return (unsigned short)(rounded >> 16);" ;
107+ " }" ;
108+ " " ;
109+ (* FP8 E5M2 conversion functions *)
110+ " static inline float fp8_to_float(unsigned char fp8) {" ;
111+ " if (fp8 == 0) return 0.0f;" ;
112+ " unsigned int sign = (fp8 >> 7) & 1;" ;
113+ " unsigned int exp = (fp8 >> 2) & 0x1F;" ;
114+ " unsigned int mant = fp8 & 0x3;" ;
115+ " if (exp == 0x1F) {" ;
116+ " if (mant == 0) return sign ? -INFINITY : INFINITY;" ;
117+ " else return NAN;" ;
118+ " }" ;
119+ " if (exp == 0) {" ;
120+ " float result = ldexpf((float)mant / 4.0f, -14);" ;
121+ " if (sign) result = -result;" ;
122+ " return result;" ;
123+ " }" ;
124+ " float result = (1.0f + (float)mant * 0.25f) * ldexpf(1.0f, (int)exp - 15);" ;
125+ " if (sign) result = -result;" ;
126+ " return result;" ;
127+ " }" ;
128+ " " ;
129+ " static inline unsigned char float_to_fp8(float f) {" ;
130+ " if (f == 0.0f) return 0;" ;
131+ " unsigned int sign = (f < 0) ? 1 : 0;" ;
132+ " f = fabsf(f);" ;
133+ " if (isinf(f)) return (sign << 7) | 0x7C;" ;
134+ " if (isnan(f)) return (sign << 7) | 0x7F;" ;
135+ " int exp_val;" ;
136+ " float mant_f = frexpf(f, &exp_val);" ;
137+ " int exp = exp_val + 14;" ;
138+ " if (exp < 0) return sign << 7;" ;
139+ " if (exp > 30) return (sign << 7) | 0x7C;" ;
140+ " if (exp == 0) {" ;
141+ " float denorm_mant = f * ldexpf(1.0f, 14) * 4.0f;" ;
142+ " unsigned int mant_bits = (unsigned int)(denorm_mant + 0.5f);" ;
143+ " if (mant_bits > 3) mant_bits = 3;" ;
144+ " return (sign << 7) | mant_bits;" ;
145+ " }" ;
146+ " mant_f = (mant_f - 0.5f) * 4.0f;" ;
147+ " unsigned int mant_bits = (unsigned int)(mant_f + 0.5f);" ;
148+ " if (mant_bits > 3) mant_bits = 3;" ;
149+ " return (unsigned char)((sign << 7) | ((exp & 0x1F) << 2) | (mant_bits & 0x3));" ;
150+ " }" ;
151+ ]
152+
150153 let typ_of_prec = Ops. c_typ_of_prec
151154 let float_log_style = if Input. full_printf_support then " %g" else " %de-3"
152155
@@ -221,28 +224,34 @@ struct
221224 | Ops. Bfloat16_prec _ ->
222225 (* For BFloat16, perform operations in float precision *)
223226 let float_v1 = PPrint. (string " bfloat16_to_float(" ^^ v1 ^^ string " )" ) in
224- let float_v2 = PPrint. (string " bfloat16_to_float(" ^^ v2 ^^ string " )" ) in
227+ let float_v2 = PPrint. (string " bfloat16_to_float(" ^^ v2 ^^ string " )" ) in
225228 let float_v3 = PPrint. (string " bfloat16_to_float(" ^^ v3 ^^ string " )" ) in
226229 let op_prefix, op_infix1, op_infix2, op_suffix = Ops. ternop_c_syntax Ops. single op in
227- let float_result = PPrint. (
228- group (string op_prefix ^^ float_v1 ^^ string op_infix1
229- ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
230- ^^ string op_infix2
231- ^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
232- ^^ string op_suffix)) in
230+ let float_result =
231+ PPrint. (
232+ group
233+ (string op_prefix ^^ float_v1 ^^ string op_infix1
234+ ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
235+ ^^ string op_infix2
236+ ^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
237+ ^^ string op_suffix))
238+ in
233239 PPrint. (string " float_to_bfloat16(" ^^ float_result ^^ string " )" )
234240 | Ops. Fp8_prec _ ->
235241 (* For FP8, perform operations in float precision *)
236242 let float_v1 = PPrint. (string " fp8_to_float(" ^^ v1 ^^ string " )" ) in
237- let float_v2 = PPrint. (string " fp8_to_float(" ^^ v2 ^^ string " )" ) in
243+ let float_v2 = PPrint. (string " fp8_to_float(" ^^ v2 ^^ string " )" ) in
238244 let float_v3 = PPrint. (string " fp8_to_float(" ^^ v3 ^^ string " )" ) in
239245 let op_prefix, op_infix1, op_infix2, op_suffix = Ops. ternop_c_syntax Ops. single op in
240- let float_result = PPrint. (
241- group (string op_prefix ^^ float_v1 ^^ string op_infix1
242- ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
243- ^^ string op_infix2
244- ^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
245- ^^ string op_suffix)) in
246+ let float_result =
247+ PPrint. (
248+ group
249+ (string op_prefix ^^ float_v1 ^^ string op_infix1
250+ ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
251+ ^^ string op_infix2
252+ ^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
253+ ^^ string op_suffix))
254+ in
246255 PPrint. (string " float_to_fp8(" ^^ float_result ^^ string " )" )
247256 | _ ->
248257 let op_prefix, op_infix1, op_infix2, op_suffix = Ops. ternop_c_syntax prec op in
@@ -268,18 +277,23 @@ struct
268277 ^^ string " < 1.0f" ))
269278 ^^ ifflat
270279 (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
271- ^^ string " (" ^^ string (typ_of_prec prec) ^^ string " )0" )
280+ ^^ string " ("
281+ ^^ string (typ_of_prec prec)
282+ ^^ string " )0" )
272283 (nest 2
273284 (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
274- ^^ string " (" ^^ string (typ_of_prec prec) ^^ string " )0" ))))
285+ ^^ string " ("
286+ ^^ string (typ_of_prec prec)
287+ ^^ string " )0" ))))
275288 | Ops. Fp8_prec _ ->
276289 let open PPrint in
277290 group
278291 (parens
279292 (group
280293 (parens
281- (string " fp8_to_float(" ^^ v1 ^^ string " ) > 0.0f && fp8_to_float("
282- ^^ v1 ^^ string " ) < 1.0f" ))
294+ (string " fp8_to_float(" ^^ v1
295+ ^^ string " ) > 0.0f && fp8_to_float("
296+ ^^ v1 ^^ string " ) < 1.0f" ))
283297 ^^ ifflat
284298 (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
285299 ^^ string " float_to_fp8(0.0f)" )
@@ -292,8 +306,9 @@ struct
292306 (parens
293307 (group
294308 (parens
295- (string " bfloat16_to_float(" ^^ v1 ^^ string " ) > 0.0f && bfloat16_to_float("
296- ^^ v1 ^^ string " ) < 1.0f" ))
309+ (string " bfloat16_to_float(" ^^ v1
310+ ^^ string " ) > 0.0f && bfloat16_to_float("
311+ ^^ v1 ^^ string " ) < 1.0f" ))
297312 ^^ ifflat
298313 (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
299314 ^^ string " float_to_bfloat16(0.0f)" )
@@ -334,40 +349,45 @@ struct
334349 (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
335350 ^^ string " 0.0" ))))
336351 | Ops. Void_prec -> invalid_arg " Pure_C_config.binop_syntax: Satur01_gate on Void_prec" )
337- | _ ->
352+ | _ -> (
338353 match prec with
339- | Ops. Bfloat16_prec _ ->
354+ | Ops. Bfloat16_prec _ -> (
340355 (* For BFloat16, perform all operations in float precision *)
341356 let float_v1 = PPrint. (string " bfloat16_to_float(" ^^ v1 ^^ string " )" ) in
342357 let float_v2 = PPrint. (string " bfloat16_to_float(" ^^ v2 ^^ string " )" ) in
343358 let op_prefix, op_infix, op_suffix = Ops. binop_c_syntax Ops. single op in
344- let float_result = PPrint. (
345- group (string op_prefix ^^ float_v1 ^^ string op_infix
346- ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
347- ^^ string op_suffix)) in
359+ let float_result =
360+ PPrint. (
361+ group
362+ (string op_prefix ^^ float_v1 ^^ string op_infix
363+ ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
364+ ^^ string op_suffix))
365+ in
348366 (* For comparison operations, return float result (0.0 or 1.0) converted to BFloat16 *)
349- ( match op with
367+ match op with
350368 | Ops. Cmplt | Ops. Cmpeq | Ops. Cmpne | Ops. Or | Ops. And ->
351- PPrint. (string " float_to_bfloat16(" ^^ float_result ^^ string " )" )
352- | _ ->
353- PPrint. (string " float_to_bfloat16(" ^^ float_result ^^ string " )" ))
369+ PPrint. (string " float_to_bfloat16(" ^^ float_result ^^ string " )" )
370+ | _ -> PPrint. (string " float_to_bfloat16(" ^^ float_result ^^ string " )" ))
354371 | Ops. Fp8_prec _ ->
355372 (* For FP8, perform all operations in float precision *)
356373 let float_v1 = PPrint. (string " fp8_to_float(" ^^ v1 ^^ string " )" ) in
357374 let float_v2 = PPrint. (string " fp8_to_float(" ^^ v2 ^^ string " )" ) in
358375 let op_prefix, op_infix, op_suffix = Ops. binop_c_syntax Ops. single op in
359- let float_result = PPrint. (
360- group (string op_prefix ^^ float_v1 ^^ string op_infix
361- ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
362- ^^ string op_suffix)) in
376+ let float_result =
377+ PPrint. (
378+ group
379+ (string op_prefix ^^ float_v1 ^^ string op_infix
380+ ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
381+ ^^ string op_suffix))
382+ in
363383 PPrint. (string " float_to_fp8(" ^^ float_result ^^ string " )" )
364384 | _ ->
365385 let op_prefix, op_infix, op_suffix = Ops. binop_c_syntax prec op in
366386 let open PPrint in
367387 group
368388 (string op_prefix ^^ v1 ^^ string op_infix
369389 ^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
370- ^^ string op_suffix)
390+ ^^ string op_suffix))
371391
372392 let unop_syntax prec op v =
373393 match prec with
0 commit comments