@@ -217,20 +217,48 @@ struct
217217 Set. to_list ! functions
218218
219219 let ternop_syntax prec op v1 v2 v3 =
220- let op_prefix, op_infix1, op_infix2, op_suffix = Ops. ternop_c_syntax prec op in
221- let open PPrint in
222- group
223- (string op_prefix ^^ v1 ^^ string op_infix1
224- ^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
225- ^^ string op_infix2
226- ^^ ifflat (space ^^ v3) (nest 2 (break 1 ^^ v3))
227- ^^ string op_suffix)
220+ match prec with
221+ | Ops. Bfloat16_prec _ ->
222+ (* For BFloat16, perform operations in float precision *)
223+ let float_v1 = PPrint. (string " bfloat16_to_float(" ^^ v1 ^^ string " )" ) in
224+ let float_v2 = PPrint. (string " bfloat16_to_float(" ^^ v2 ^^ string " )" ) in
225+ let float_v3 = PPrint. (string " bfloat16_to_float(" ^^ v3 ^^ string " )" ) in
226+ let op_prefix, op_infix1, op_infix2, op_suffix = Ops. ternop_c_syntax Ops. single op in
227+ let float_result = PPrint. (
228+ group (string op_prefix ^^ float_v1 ^^ string op_infix1
229+ ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
230+ ^^ string op_infix2
231+ ^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
232+ ^^ string op_suffix)) in
233+ PPrint. (string " float_to_bfloat16(" ^^ float_result ^^ string " )" )
234+ | Ops. Fp8_prec _ ->
235+ (* For FP8, perform operations in float precision *)
236+ let float_v1 = PPrint. (string " fp8_to_float(" ^^ v1 ^^ string " )" ) in
237+ let float_v2 = PPrint. (string " fp8_to_float(" ^^ v2 ^^ string " )" ) in
238+ let float_v3 = PPrint. (string " fp8_to_float(" ^^ v3 ^^ string " )" ) in
239+ let op_prefix, op_infix1, op_infix2, op_suffix = Ops. ternop_c_syntax Ops. single op in
240+ let float_result = PPrint. (
241+ group (string op_prefix ^^ float_v1 ^^ string op_infix1
242+ ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
243+ ^^ string op_infix2
244+ ^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
245+ ^^ string op_suffix)) in
246+ PPrint. (string " float_to_fp8(" ^^ float_result ^^ string " )" )
247+ | _ ->
248+ let op_prefix, op_infix1, op_infix2, op_suffix = Ops. ternop_c_syntax prec op in
249+ let open PPrint in
250+ group
251+ (string op_prefix ^^ v1 ^^ string op_infix1
252+ ^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
253+ ^^ string op_infix2
254+ ^^ ifflat (space ^^ v3) (nest 2 (break 1 ^^ v3))
255+ ^^ string op_suffix)
228256
229257 let binop_syntax prec op v1 v2 =
230258 match op with
231259 | Ops. Satur01_gate -> (
232260 match prec with
233- | Ops. Byte_prec _ | Ops. Uint16_prec _ | Ops. Int32_prec _ | Ops. Fp8_prec _ ->
261+ | Ops. Byte_prec _ | Ops. Uint16_prec _ | Ops. Int32_prec _ ->
234262 let open PPrint in
235263 group
236264 (parens
@@ -244,21 +272,34 @@ struct
244272 (nest 2
245273 (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
246274 ^^ string " (" ^^ string (typ_of_prec prec) ^^ string " )0" ))))
275+ | Ops. Fp8_prec _ ->
276+ let open PPrint in
277+ group
278+ (parens
279+ (group
280+ (parens
281+ (string " fp8_to_float(" ^^ v1 ^^ string " ) > 0.0f && fp8_to_float("
282+ ^^ v1 ^^ string " ) < 1.0f" ))
283+ ^^ ifflat
284+ (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
285+ ^^ string " float_to_fp8(0.0f)" )
286+ (nest 2
287+ (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
288+ ^^ string " float_to_fp8(0.0f)" ))))
247289 | Ops. Bfloat16_prec _ ->
248- (* For CC backend, convert to float for computation *)
249290 let open PPrint in
250291 group
251292 (parens
252293 (group
253294 (parens
254- (string " (float) " ^^ v1 ^^ string " > 0.0f && (float) " ^^ v1
255- ^^ string " < 1.0f" ))
295+ (string " bfloat16_to_float( " ^^ v1 ^^ string " ) > 0.0f && bfloat16_to_float( "
296+ ^^ v1 ^^ string " ) < 1.0f" ))
256297 ^^ ifflat
257298 (space ^^ string " ?" ^^ space ^^ v2 ^^ space ^^ string " :" ^^ space
258- ^^ string " (unsigned short)0 " )
299+ ^^ string " float_to_bfloat16(0.0f) " )
259300 (nest 2
260301 (break 1 ^^ string " ?" ^^ space ^^ v2 ^^ break 1 ^^ string " :" ^^ space
261- ^^ string " (unsigned short)0 " ))))
302+ ^^ string " float_to_bfloat16(0.0f) " ))))
262303 | Ops. Half_prec _ ->
263304 let open PPrint in
264305 group
@@ -294,17 +335,58 @@ struct
294335 ^^ string " 0.0" ))))
295336 | Ops. Void_prec -> invalid_arg " Pure_C_config.binop_syntax: Satur01_gate on Void_prec" )
296337 | _ ->
297- let op_prefix, op_infix, op_suffix = Ops. binop_c_syntax prec op in
298- let open PPrint in
299- group
300- (string op_prefix ^^ v1 ^^ string op_infix
301- ^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
302- ^^ string op_suffix)
338+ match prec with
339+ | Ops. Bfloat16_prec _ ->
340+ (* For BFloat16, perform all operations in float precision *)
341+ let float_v1 = PPrint. (string " bfloat16_to_float(" ^^ v1 ^^ string " )" ) in
342+ let float_v2 = PPrint. (string " bfloat16_to_float(" ^^ v2 ^^ string " )" ) in
343+ let op_prefix, op_infix, op_suffix = Ops. binop_c_syntax Ops. single op in
344+ let float_result = PPrint. (
345+ group (string op_prefix ^^ float_v1 ^^ string op_infix
346+ ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
347+ ^^ string op_suffix)) in
348+ (* For comparison operations, return float result (0.0 or 1.0) converted to BFloat16 *)
349+ (match op with
350+ | Ops. Cmplt | Ops. Cmpeq | Ops. Cmpne | Ops. Or | Ops. And ->
351+ PPrint. (string " float_to_bfloat16(" ^^ float_result ^^ string " )" )
352+ | _ ->
353+ PPrint. (string " float_to_bfloat16(" ^^ float_result ^^ string " )" ))
354+ | Ops. Fp8_prec _ ->
355+ (* For FP8, perform all operations in float precision *)
356+ let float_v1 = PPrint. (string " fp8_to_float(" ^^ v1 ^^ string " )" ) in
357+ let float_v2 = PPrint. (string " fp8_to_float(" ^^ v2 ^^ string " )" ) in
358+ let op_prefix, op_infix, op_suffix = Ops. binop_c_syntax Ops. single op in
359+ let float_result = PPrint. (
360+ group (string op_prefix ^^ float_v1 ^^ string op_infix
361+ ^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
362+ ^^ string op_suffix)) in
363+ PPrint. (string " float_to_fp8(" ^^ float_result ^^ string " )" )
364+ | _ ->
365+ let op_prefix, op_infix, op_suffix = Ops. binop_c_syntax prec op in
366+ let open PPrint in
367+ group
368+ (string op_prefix ^^ v1 ^^ string op_infix
369+ ^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
370+ ^^ string op_suffix)
303371
304372 let unop_syntax prec op v =
305- let op_prefix, op_suffix = Ops. unop_c_syntax prec op in
306- let open PPrint in
307- group (string op_prefix ^^ v ^^ string op_suffix)
373+ match prec with
374+ | Ops. Bfloat16_prec _ ->
375+ (* For BFloat16, perform operations in float precision *)
376+ let float_v = PPrint. (string " bfloat16_to_float(" ^^ v ^^ string " )" ) in
377+ let op_prefix, op_suffix = Ops. unop_c_syntax Ops. single op in
378+ let float_result = PPrint. (group (string op_prefix ^^ float_v ^^ string op_suffix)) in
379+ PPrint. (string " float_to_bfloat16(" ^^ float_result ^^ string " )" )
380+ | Ops. Fp8_prec _ ->
381+ (* For FP8, perform operations in float precision *)
382+ let float_v = PPrint. (string " fp8_to_float(" ^^ v ^^ string " )" ) in
383+ let op_prefix, op_suffix = Ops. unop_c_syntax Ops. single op in
384+ let float_result = PPrint. (group (string op_prefix ^^ float_v ^^ string op_suffix)) in
385+ PPrint. (string " float_to_fp8(" ^^ float_result ^^ string " )" )
386+ | _ ->
387+ let op_prefix, op_suffix = Ops. unop_c_syntax prec op in
388+ let open PPrint in
389+ group (string op_prefix ^^ v ^^ string op_suffix)
308390
309391 let convert_precision = Ops. c_convert_precision
310392 let kernel_log_param = Some (" const char*" , " log_file_name" )
0 commit comments