@@ -95,84 +95,7 @@ struct
9595 let arg_int_prefix = " const int "
9696 let extra_args = []
9797 let includes = [ " <stdio.h>" ; " <stdlib.h>" ; " <string.h>" ; " <math.h>" ]
98-
99- let extra_declarations =
100- [
101- (* BFloat16 conversion functions *)
102- " static inline float bfloat16_to_single(unsigned short bf16) {" ;
103- " unsigned int f32 = ((unsigned int)bf16) << 16;" ;
104- " return *((float*)&f32);" ;
105- " }" ;
106- " " ;
107- " static inline unsigned short single_to_bfloat16(float f) {" ;
108- " unsigned int f32 = *((unsigned int*)&f);" ;
109- " unsigned int rounded = f32 + 0x7FFF + ((f32 >> 16) & 1);" ;
110- " return (unsigned short)(rounded >> 16);" ;
111- " }" ;
112- " " ;
113- (* Half (Float16) support with zero-overhead abstraction *)
114- " #ifdef __FLT16_MAX__" ;
115- " #define HAS_NATIVE_FLOAT16 1" ;
116- " #define HALF_T _Float16" ;
117- " #define HALF_TO_FP(x) (x) /* Identity - already floating point */" ;
118- " #define FP_TO_HALF(x) (x) /* Identity - already half precision */" ;
119- " #define HALF_TO_FLOAT(x) ((float)(x))" ;
120- " #define FLOAT_TO_HALF(x) ((_Float16)(x))" ;
121- " #else" ;
122- " #define HAS_NATIVE_FLOAT16 0" ;
123- " #define HALF_T unsigned short" ;
124- " #define HALF_TO_FP(x) half_to_single(x) /* Convert to float for computation */" ;
125- " #define FP_TO_HALF(x) single_to_half(x) /* Convert back from float */" ;
126- " #define HALF_TO_FLOAT(x) half_to_single(x)" ;
127- " #define FLOAT_TO_HALF(x) single_to_half(x)" ;
128- " /* Conversion functions for emulation - provided by builtins.c */" ;
129- " extern float half_to_single(unsigned short h);" ;
130- " extern unsigned short single_to_half(float f);" ;
131- " #endif" ;
132- " " ;
133- (* FP8 E5M2 conversion functions *)
134- " static inline float fp8_to_single(unsigned char fp8) {" ;
135- " if (fp8 == 0) return 0.0f;" ;
136- " unsigned int sign = (fp8 >> 7) & 1;" ;
137- " unsigned int exp = (fp8 >> 2) & 0x1F;" ;
138- " unsigned int mant = fp8 & 0x3;" ;
139- " if (exp == 0x1F) {" ;
140- " if (mant == 0) return sign ? -INFINITY : INFINITY;" ;
141- " else return NAN;" ;
142- " }" ;
143- " if (exp == 0) {" ;
144- " float result = ldexpf((float)mant / 4.0f, -14);" ;
145- " if (sign) result = -result;" ;
146- " return result;" ;
147- " }" ;
148- " float result = (1.0f + (float)mant * 0.25f) * ldexpf(1.0f, (int)exp - 15);" ;
149- " if (sign) result = -result;" ;
150- " return result;" ;
151- " }" ;
152- " " ;
153- " static inline unsigned char single_to_fp8(float f) {" ;
154- " if (f == 0.0f) return 0;" ;
155- " unsigned int sign = (f < 0) ? 1 : 0;" ;
156- " f = fabsf(f);" ;
157- " if (isinf(f)) return (sign << 7) | 0x7C;" ;
158- " if (isnan(f)) return (sign << 7) | 0x7F;" ;
159- " int exp_val;" ;
160- " float mant_f = frexpf(f, &exp_val);" ;
161- " int exp = exp_val + 14;" ;
162- " if (exp < 0) return sign << 7;" ;
163- " if (exp > 30) return (sign << 7) | 0x7C;" ;
164- " if (exp == 0) {" ;
165- " float denorm_mant = f * ldexpf(1.0f, 14) * 4.0f;" ;
166- " unsigned int mant_bits = (unsigned int)(denorm_mant + 0.5f);" ;
167- " if (mant_bits > 3) mant_bits = 3;" ;
168- " return (sign << 7) | mant_bits;" ;
169- " }" ;
170- " mant_f = (mant_f - 0.5f) * 4.0f;" ;
171- " unsigned int mant_bits = (unsigned int)(mant_f + 0.5f);" ;
172- " if (mant_bits > 3) mant_bits = 3;" ;
173- " return (unsigned char)((sign << 7) | ((exp & 0x1F) << 2) | (mant_bits & 0x3));" ;
174- " }" ;
175- ]
98+ let extra_declarations = []
17699
177100 let typ_of_prec = Ops. c_typ_of_prec
178101 let vec_typ_of_prec = Ops. c_vec_typ_of_prec
@@ -251,243 +174,27 @@ struct
251174 Set. to_list ! functions
252175
253176 let ternop_syntax prec op v1 v2 v3 =
254- match prec with
255- | Ops. Bfloat16_prec _ ->
256- (* For BFloat16, perform operations in float precision *)
257- let float_v1 = PPrint. (string " bfloat16_to_single(" ^^ v1 ^^ string " )" ) in
258- let float_v2 = PPrint. (string " bfloat16_to_single(" ^^ v2 ^^ string " )" ) in
259- let float_v3 = PPrint. (string " bfloat16_to_single(" ^^ v3 ^^ string " )" ) in
260- let op_prefix, op_infix1, op_infix2, op_suffix = Ops. ternop_c_syntax Ops. single op in
261- let float_result =
262- PPrint. (
263- group
264- (string op_prefix ^^ float_v1 ^^ string op_infix1
265- ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
266- ^^ string op_infix2
267- ^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
268- ^^ string op_suffix))
269- in
270- PPrint. (string " single_to_bfloat16(" ^^ float_result ^^ string " )" )
271- | Ops. Half_prec _ ->
272- (* For Half, perform operations in float precision on non-native systems *)
273- let float_v1 = PPrint. (string " HALF_TO_FP(" ^^ v1 ^^ string " )" ) in
274- let float_v2 = PPrint. (string " HALF_TO_FP(" ^^ v2 ^^ string " )" ) in
275- let float_v3 = PPrint. (string " HALF_TO_FP(" ^^ v3 ^^ string " )" ) in
276- let op_prefix, op_infix1, op_infix2, op_suffix = Ops. ternop_c_syntax Ops. single op in
277- let float_result =
278- PPrint. (
279- group
280- (string op_prefix ^^ float_v1 ^^ string op_infix1
281- ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
282- ^^ string op_infix2
283- ^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
284- ^^ string op_suffix))
285- in
286- PPrint. (string " FP_TO_HALF(" ^^ float_result ^^ string " )" )
287- | Ops. Fp8_prec _ ->
288- (* For FP8, perform operations in float precision *)
289- let float_v1 = PPrint. (string " fp8_to_single(" ^^ v1 ^^ string " )" ) in
290- let float_v2 = PPrint. (string " fp8_to_single(" ^^ v2 ^^ string " )" ) in
291- let float_v3 = PPrint. (string " fp8_to_single(" ^^ v3 ^^ string " )" ) in
292- let op_prefix, op_infix1, op_infix2, op_suffix = Ops. ternop_c_syntax Ops. single op in
293- let float_result =
294- PPrint. (
295- group
296- (string op_prefix ^^ float_v1 ^^ string op_infix1
297- ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
298- ^^ string op_infix2
299- ^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
300- ^^ string op_suffix))
301- in
302- PPrint. (string " single_to_fp8(" ^^ float_result ^^ string " )" )
303- | _ ->
304- let op_prefix, op_infix1, op_infix2, op_suffix = Ops. ternop_c_syntax prec op in
305- let open PPrint in
306- group
307- (string op_prefix ^^ v1 ^^ string op_infix1
308- ^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
309- ^^ string op_infix2
310- ^^ ifflat (space ^^ v3) (nest 2 (break 1 ^^ v3))
311- ^^ string op_suffix)
177+ let op_prefix, op_infix1, op_infix2, op_suffix = Ops. ternop_c_syntax prec op in
178+ let open PPrint in
179+ group
180+ (string op_prefix ^^ v1 ^^ string op_infix1
181+ ^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
182+ ^^ string op_infix2
183+ ^^ ifflat (space ^^ v3) (nest 2 (break 1 ^^ v3))
184+ ^^ string op_suffix)
312185
313186 let binop_syntax prec op v1 v2 =
314- match op with
315- | Ops. Threefry4x32 -> (
316- match prec with
317- | Ops. Uint4x32_prec _ ->
318- let open PPrint in
319- group (string " arrayjit_threefry4x32(" ^^ v1 ^^ string " , " ^^ v2 ^^ string " )" )
320- | _ -> invalid_arg " Pure_C_config.binop_syntax: Threefry4x32 on non-uint4x32 precision" )
321- | Ops. Satur01_gate -> (
322- match prec with
323- | Ops. Byte_prec _ | Ops. Uint16_prec _ | Ops. Int32_prec _ | Ops. Int64_prec _
324- | Ops. Uint4x32_prec _ ->
325- let open PPrint in
326- group
327- (parens
328- (group
329- (parens
330- (string " (float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1
331- ^^ string " < 1.0f" ))
332- ^^ ifflat
333- (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
334- ^^ string " ("
335- ^^ string (typ_of_prec prec)
336- ^^ string " )0" )
337- (nest 2
338- (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
339- ^^ string " ("
340- ^^ string (typ_of_prec prec)
341- ^^ string " )0" ))))
342- | Ops. Fp8_prec _ ->
343- let open PPrint in
344- group
345- (parens
346- (group
347- (parens
348- (string " fp8_to_single(" ^^ v1
349- ^^ string " ) > 0.0f && fp8_to_single("
350- ^^ v1 ^^ string " ) < 1.0f" ))
351- ^^ ifflat
352- (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
353- ^^ string " single_to_fp8(0.0f)" )
354- (nest 2
355- (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
356- ^^ string " single_to_fp8(0.0f)" ))))
357- | Ops. Bfloat16_prec _ ->
358- let open PPrint in
359- group
360- (parens
361- (group
362- (parens
363- (string " bfloat16_to_single(" ^^ v1
364- ^^ string " ) > 0.0f && bfloat16_to_single("
365- ^^ v1 ^^ string " ) < 1.0f" ))
366- ^^ ifflat
367- (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
368- ^^ string " single_to_bfloat16(0.0f)" )
369- (nest 2
370- (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
371- ^^ string " single_to_bfloat16(0.0f)" ))))
372- | Ops. Half_prec _ ->
373- let open PPrint in
374- group
375- (parens
376- (group
377- (parens
378- (string " HALF_TO_FP(" ^^ v1
379- ^^ string " ) > 0.0f && HALF_TO_FP("
380- ^^ v1 ^^ string " ) < 1.0f" ))
381- ^^ ifflat
382- (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
383- ^^ string " FP_TO_HALF(0.0f)" )
384- (nest 2
385- (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
386- ^^ string " FP_TO_HALF(0.0f)" ))))
387- | Ops. Single_prec _ ->
388- let open PPrint in
389- group
390- (parens
391- (group (parens (v1 ^^ string " > 0.0f && " ^^ v1 ^^ string " < 1.0f" ))
392- ^^ ifflat
393- (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
394- ^^ string " 0.0f" )
395- (nest 2
396- (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
397- ^^ string " 0.0f" ))))
398- | Ops. Double_prec _ ->
399- let open PPrint in
400- group
401- (parens
402- (group (parens (v1 ^^ string " > 0.0 && " ^^ v1 ^^ string " < 1.0" ))
403- ^^ ifflat
404- (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
405- ^^ string " 0.0" )
406- (nest 2
407- (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
408- ^^ string " 0.0" ))))
409- | Ops. Void_prec -> invalid_arg " Pure_C_config.binop_syntax: Satur01_gate on Void_prec" )
410- | _ -> (
411- match prec with
412- | Ops. Bfloat16_prec _ -> (
413- (* For BFloat16, perform all operations in float precision *)
414- let float_v1 = PPrint. (string " bfloat16_to_single(" ^^ v1 ^^ string " )" ) in
415- let float_v2 = PPrint. (string " bfloat16_to_single(" ^^ v2 ^^ string " )" ) in
416- let op_prefix, op_infix, op_suffix = Ops. binop_c_syntax Ops. single op in
417- let float_result =
418- PPrint. (
419- group
420- (string op_prefix ^^ float_v1 ^^ string op_infix
421- ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
422- ^^ string op_suffix))
423- in
424- (* For comparison operations, return float result (0.0 or 1.0) converted to BFloat16 *)
425- match op with
426- | Ops. Cmplt | Ops. Cmpeq | Ops. Cmpne | Ops. Or | Ops. And ->
427- PPrint. (string " single_to_bfloat16(" ^^ float_result ^^ string " )" )
428- | _ -> PPrint. (string " single_to_bfloat16(" ^^ float_result ^^ string " )" ))
429- | Ops. Fp8_prec _ ->
430- (* For FP8, perform all operations in float precision *)
431- let float_v1 = PPrint. (string " fp8_to_single(" ^^ v1 ^^ string " )" ) in
432- let float_v2 = PPrint. (string " fp8_to_single(" ^^ v2 ^^ string " )" ) in
433- let op_prefix, op_infix, op_suffix = Ops. binop_c_syntax Ops. single op in
434- let float_result =
435- PPrint. (
436- group
437- (string op_prefix ^^ float_v1 ^^ string op_infix
438- ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
439- ^^ string op_suffix))
440- in
441- PPrint. (string " single_to_fp8(" ^^ float_result ^^ string " )" )
442- | Ops. Half_prec _ ->
443- (* For Half, perform all operations in float precision on non-native systems *)
444- let float_v1 = PPrint. (string " HALF_TO_FP(" ^^ v1 ^^ string " )" ) in
445- let float_v2 = PPrint. (string " HALF_TO_FP(" ^^ v2 ^^ string " )" ) in
446- let op_prefix, op_infix, op_suffix = Ops. binop_c_syntax Ops. single op in
447- let float_result =
448- PPrint. (
449- group
450- (string op_prefix ^^ float_v1 ^^ string op_infix
451- ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
452- ^^ string op_suffix))
453- in
454- (* For comparison operations, return float result (0.0 or 1.0) converted to Half *)
455- (match op with
456- | Ops. Cmplt | Ops. Cmpeq | Ops. Cmpne | Ops. Or | Ops. And ->
457- PPrint. (string " FP_TO_HALF(" ^^ float_result ^^ string " )" )
458- | _ -> PPrint. (string " FP_TO_HALF(" ^^ float_result ^^ string " )" ))
459- | _ ->
460- let op_prefix, op_infix, op_suffix = Ops. binop_c_syntax prec op in
461- let open PPrint in
462- group
463- (string op_prefix ^^ v1 ^^ string op_infix
464- ^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
465- ^^ string op_suffix))
187+ let op_prefix, op_infix, op_suffix = Ops. binop_c_syntax prec op in
188+ let open PPrint in
189+ group
190+ (string op_prefix ^^ v1 ^^ string op_infix
191+ ^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
192+ ^^ string op_suffix)
466193
467194 let unop_syntax prec op v =
468- match prec with
469- | Ops. Bfloat16_prec _ ->
470- (* For BFloat16, perform operations in float precision *)
471- let float_v = PPrint. (string " bfloat16_to_single(" ^^ v ^^ string " )" ) in
472- let op_prefix, op_suffix = Ops. unop_c_syntax Ops. single op in
473- let float_result = PPrint. (group (string op_prefix ^^ float_v ^^ string op_suffix)) in
474- PPrint. (string " single_to_bfloat16(" ^^ float_result ^^ string " )" )
475- | Ops. Fp8_prec _ ->
476- (* For FP8, perform operations in float precision *)
477- let float_v = PPrint. (string " fp8_to_single(" ^^ v ^^ string " )" ) in
478- let op_prefix, op_suffix = Ops. unop_c_syntax Ops. single op in
479- let float_result = PPrint. (group (string op_prefix ^^ float_v ^^ string op_suffix)) in
480- PPrint. (string " single_to_fp8(" ^^ float_result ^^ string " )" )
481- | Ops. Half_prec _ ->
482- (* For Half, perform operations in float precision on non-native systems *)
483- let float_v = PPrint. (string " HALF_TO_FP(" ^^ v ^^ string " )" ) in
484- let op_prefix, op_suffix = Ops. unop_c_syntax Ops. single op in
485- let float_result = PPrint. (group (string op_prefix ^^ float_v ^^ string op_suffix)) in
486- PPrint. (string " FP_TO_HALF(" ^^ float_result ^^ string " )" )
487- | _ ->
488- let op_prefix, op_suffix = Ops. unop_c_syntax prec op in
489- let open PPrint in
490- group (string op_prefix ^^ v ^^ string op_suffix)
195+ let op_prefix, op_suffix = Ops. unop_c_syntax prec op in
196+ let open PPrint in
197+ group (string op_prefix ^^ v ^^ string op_suffix)
491198
492199 let vec_unop_syntax prec op v =
493200 let op_prefix, op_suffix = Ops. vec_unop_c_syntax prec op in
0 commit comments