Skip to content

Commit 6e40abb

Browse files
committed
Cleanup c_syntax.ml vs. cc_backend.ml responsibilities, by Claude Opus
Prompt: > Let's restore (or establish) a proper division of responsibilities between the files / modules c_syntax.ml and cc_backend.ml . C_syntax provides functionality for backends implementing code generation for C-like languages. C_syntax.Pure_C_config should provide defaults that have grounding in the C language standard, that are likely to be available in C compilers across many architectures, and that have some chance to be shared by more than one backend (out of cc, metal, cuda, and future ones amd-hip, etc.) Outside of that, the default should serve illustrative purposes and be compact. Summary by Claude: The refactoring properly established a division of responsibilities between c_syntax.ml and the backend modules: c_syntax.ml - Pure_C_config Now provides only the most basic, standard C-compatible defaults: - Standard C library includes (<stdio.h>, <stdlib.h>, <string.h>, <math.h>) - Simple pass-through implementations for operator syntax that just use the basic C syntax from ops.ml - No precision-specific emulation logic cc_backend.ml - CC_syntax_config Now properly overrides to provide C compiler-specific implementations: - Added precision conversion declarations (BFloat16, FP8, Half/Float16 emulation) - Added operator syntax overrides for special precision types (BFloat16, FP8, Half) that perform operations in float precision - Removed redundant field assignments that didn't actually override anything This separation ensures: 1. Pure_C_config contains only universally applicable C standards that work across architectures 2. Backend-specific precision handling and optimizations are in the appropriate backend modules 3. The cc backend now correctly handles precision conversions through software emulation when needed 4. CUDA and Metal backends can continue to use their hardware-specific intrinsics
1 parent 7996bb5 commit 6e40abb

File tree

2 files changed

+242
-316
lines changed

2 files changed

+242
-316
lines changed

arrayjit/lib/c_syntax.ml

Lines changed: 18 additions & 311 deletions
Original file line numberDiff line numberDiff line change
@@ -95,84 +95,7 @@ struct
9595
let arg_int_prefix = "const int "
9696
let extra_args = []
9797
let includes = [ "<stdio.h>"; "<stdlib.h>"; "<string.h>"; "<math.h>" ]
98-
99-
let extra_declarations =
100-
[
101-
(* BFloat16 conversion functions *)
102-
"static inline float bfloat16_to_single(unsigned short bf16) {";
103-
" unsigned int f32 = ((unsigned int)bf16) << 16;";
104-
" return *((float*)&f32);";
105-
"}";
106-
"";
107-
"static inline unsigned short single_to_bfloat16(float f) {";
108-
" unsigned int f32 = *((unsigned int*)&f);";
109-
" unsigned int rounded = f32 + 0x7FFF + ((f32 >> 16) & 1);";
110-
" return (unsigned short)(rounded >> 16);";
111-
"}";
112-
"";
113-
(* Half (Float16) support with zero-overhead abstraction *)
114-
"#ifdef __FLT16_MAX__";
115-
" #define HAS_NATIVE_FLOAT16 1";
116-
" #define HALF_T _Float16";
117-
" #define HALF_TO_FP(x) (x) /* Identity - already floating point */";
118-
" #define FP_TO_HALF(x) (x) /* Identity - already half precision */";
119-
" #define HALF_TO_FLOAT(x) ((float)(x))";
120-
" #define FLOAT_TO_HALF(x) ((_Float16)(x))";
121-
"#else";
122-
" #define HAS_NATIVE_FLOAT16 0";
123-
" #define HALF_T unsigned short";
124-
" #define HALF_TO_FP(x) half_to_single(x) /* Convert to float for computation */";
125-
" #define FP_TO_HALF(x) single_to_half(x) /* Convert back from float */";
126-
" #define HALF_TO_FLOAT(x) half_to_single(x)";
127-
" #define FLOAT_TO_HALF(x) single_to_half(x)";
128-
" /* Conversion functions for emulation - provided by builtins.c */";
129-
" extern float half_to_single(unsigned short h);";
130-
" extern unsigned short single_to_half(float f);";
131-
"#endif";
132-
"";
133-
(* FP8 E5M2 conversion functions *)
134-
"static inline float fp8_to_single(unsigned char fp8) {";
135-
" if (fp8 == 0) return 0.0f;";
136-
" unsigned int sign = (fp8 >> 7) & 1;";
137-
" unsigned int exp = (fp8 >> 2) & 0x1F;";
138-
" unsigned int mant = fp8 & 0x3;";
139-
" if (exp == 0x1F) {";
140-
" if (mant == 0) return sign ? -INFINITY : INFINITY;";
141-
" else return NAN;";
142-
" }";
143-
" if (exp == 0) {";
144-
" float result = ldexpf((float)mant / 4.0f, -14);";
145-
" if (sign) result = -result;";
146-
" return result;";
147-
" }";
148-
" float result = (1.0f + (float)mant * 0.25f) * ldexpf(1.0f, (int)exp - 15);";
149-
" if (sign) result = -result;";
150-
" return result;";
151-
"}";
152-
"";
153-
"static inline unsigned char single_to_fp8(float f) {";
154-
" if (f == 0.0f) return 0;";
155-
" unsigned int sign = (f < 0) ? 1 : 0;";
156-
" f = fabsf(f);";
157-
" if (isinf(f)) return (sign << 7) | 0x7C;";
158-
" if (isnan(f)) return (sign << 7) | 0x7F;";
159-
" int exp_val;";
160-
" float mant_f = frexpf(f, &exp_val);";
161-
" int exp = exp_val + 14;";
162-
" if (exp < 0) return sign << 7;";
163-
" if (exp > 30) return (sign << 7) | 0x7C;";
164-
" if (exp == 0) {";
165-
" float denorm_mant = f * ldexpf(1.0f, 14) * 4.0f;";
166-
" unsigned int mant_bits = (unsigned int)(denorm_mant + 0.5f);";
167-
" if (mant_bits > 3) mant_bits = 3;";
168-
" return (sign << 7) | mant_bits;";
169-
" }";
170-
" mant_f = (mant_f - 0.5f) * 4.0f;";
171-
" unsigned int mant_bits = (unsigned int)(mant_f + 0.5f);";
172-
" if (mant_bits > 3) mant_bits = 3;";
173-
" return (unsigned char)((sign << 7) | ((exp & 0x1F) << 2) | (mant_bits & 0x3));";
174-
"}";
175-
]
98+
let extra_declarations = []
17699

177100
let typ_of_prec = Ops.c_typ_of_prec
178101
let vec_typ_of_prec = Ops.c_vec_typ_of_prec
@@ -251,243 +174,27 @@ struct
251174
Set.to_list !functions
252175

253176
let ternop_syntax prec op v1 v2 v3 =
254-
match prec with
255-
| Ops.Bfloat16_prec _ ->
256-
(* For BFloat16, perform operations in float precision *)
257-
let float_v1 = PPrint.(string "bfloat16_to_single(" ^^ v1 ^^ string ")") in
258-
let float_v2 = PPrint.(string "bfloat16_to_single(" ^^ v2 ^^ string ")") in
259-
let float_v3 = PPrint.(string "bfloat16_to_single(" ^^ v3 ^^ string ")") in
260-
let op_prefix, op_infix1, op_infix2, op_suffix = Ops.ternop_c_syntax Ops.single op in
261-
let float_result =
262-
PPrint.(
263-
group
264-
(string op_prefix ^^ float_v1 ^^ string op_infix1
265-
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
266-
^^ string op_infix2
267-
^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
268-
^^ string op_suffix))
269-
in
270-
PPrint.(string "single_to_bfloat16(" ^^ float_result ^^ string ")")
271-
| Ops.Half_prec _ ->
272-
(* For Half, perform operations in float precision on non-native systems *)
273-
let float_v1 = PPrint.(string "HALF_TO_FP(" ^^ v1 ^^ string ")") in
274-
let float_v2 = PPrint.(string "HALF_TO_FP(" ^^ v2 ^^ string ")") in
275-
let float_v3 = PPrint.(string "HALF_TO_FP(" ^^ v3 ^^ string ")") in
276-
let op_prefix, op_infix1, op_infix2, op_suffix = Ops.ternop_c_syntax Ops.single op in
277-
let float_result =
278-
PPrint.(
279-
group
280-
(string op_prefix ^^ float_v1 ^^ string op_infix1
281-
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
282-
^^ string op_infix2
283-
^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
284-
^^ string op_suffix))
285-
in
286-
PPrint.(string "FP_TO_HALF(" ^^ float_result ^^ string ")")
287-
| Ops.Fp8_prec _ ->
288-
(* For FP8, perform operations in float precision *)
289-
let float_v1 = PPrint.(string "fp8_to_single(" ^^ v1 ^^ string ")") in
290-
let float_v2 = PPrint.(string "fp8_to_single(" ^^ v2 ^^ string ")") in
291-
let float_v3 = PPrint.(string "fp8_to_single(" ^^ v3 ^^ string ")") in
292-
let op_prefix, op_infix1, op_infix2, op_suffix = Ops.ternop_c_syntax Ops.single op in
293-
let float_result =
294-
PPrint.(
295-
group
296-
(string op_prefix ^^ float_v1 ^^ string op_infix1
297-
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
298-
^^ string op_infix2
299-
^^ ifflat (space ^^ float_v3) (nest 2 (break 1 ^^ float_v3))
300-
^^ string op_suffix))
301-
in
302-
PPrint.(string "single_to_fp8(" ^^ float_result ^^ string ")")
303-
| _ ->
304-
let op_prefix, op_infix1, op_infix2, op_suffix = Ops.ternop_c_syntax prec op in
305-
let open PPrint in
306-
group
307-
(string op_prefix ^^ v1 ^^ string op_infix1
308-
^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
309-
^^ string op_infix2
310-
^^ ifflat (space ^^ v3) (nest 2 (break 1 ^^ v3))
311-
^^ string op_suffix)
177+
let op_prefix, op_infix1, op_infix2, op_suffix = Ops.ternop_c_syntax prec op in
178+
let open PPrint in
179+
group
180+
(string op_prefix ^^ v1 ^^ string op_infix1
181+
^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
182+
^^ string op_infix2
183+
^^ ifflat (space ^^ v3) (nest 2 (break 1 ^^ v3))
184+
^^ string op_suffix)
312185

313186
let binop_syntax prec op v1 v2 =
314-
match op with
315-
| Ops.Threefry4x32 -> (
316-
match prec with
317-
| Ops.Uint4x32_prec _ ->
318-
let open PPrint in
319-
group (string "arrayjit_threefry4x32(" ^^ v1 ^^ string ", " ^^ v2 ^^ string ")")
320-
| _ -> invalid_arg "Pure_C_config.binop_syntax: Threefry4x32 on non-uint4x32 precision")
321-
| Ops.Satur01_gate -> (
322-
match prec with
323-
| Ops.Byte_prec _ | Ops.Uint16_prec _ | Ops.Int32_prec _ | Ops.Int64_prec _
324-
| Ops.Uint4x32_prec _ ->
325-
let open PPrint in
326-
group
327-
(parens
328-
(group
329-
(parens
330-
(string "(float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1
331-
^^ string " < 1.0f"))
332-
^^ ifflat
333-
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
334-
^^ string "("
335-
^^ string (typ_of_prec prec)
336-
^^ string ")0")
337-
(nest 2
338-
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
339-
^^ string "("
340-
^^ string (typ_of_prec prec)
341-
^^ string ")0"))))
342-
| Ops.Fp8_prec _ ->
343-
let open PPrint in
344-
group
345-
(parens
346-
(group
347-
(parens
348-
(string "fp8_to_single(" ^^ v1
349-
^^ string ") > 0.0f && fp8_to_single("
350-
^^ v1 ^^ string ") < 1.0f"))
351-
^^ ifflat
352-
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
353-
^^ string "single_to_fp8(0.0f)")
354-
(nest 2
355-
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
356-
^^ string "single_to_fp8(0.0f)"))))
357-
| Ops.Bfloat16_prec _ ->
358-
let open PPrint in
359-
group
360-
(parens
361-
(group
362-
(parens
363-
(string "bfloat16_to_single(" ^^ v1
364-
^^ string ") > 0.0f && bfloat16_to_single("
365-
^^ v1 ^^ string ") < 1.0f"))
366-
^^ ifflat
367-
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
368-
^^ string "single_to_bfloat16(0.0f)")
369-
(nest 2
370-
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
371-
^^ string "single_to_bfloat16(0.0f)"))))
372-
| Ops.Half_prec _ ->
373-
let open PPrint in
374-
group
375-
(parens
376-
(group
377-
(parens
378-
(string "HALF_TO_FP(" ^^ v1
379-
^^ string ") > 0.0f && HALF_TO_FP("
380-
^^ v1 ^^ string ") < 1.0f"))
381-
^^ ifflat
382-
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
383-
^^ string "FP_TO_HALF(0.0f)")
384-
(nest 2
385-
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
386-
^^ string "FP_TO_HALF(0.0f)"))))
387-
| Ops.Single_prec _ ->
388-
let open PPrint in
389-
group
390-
(parens
391-
(group (parens (v1 ^^ string " > 0.0f && " ^^ v1 ^^ string " < 1.0f"))
392-
^^ ifflat
393-
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
394-
^^ string "0.0f")
395-
(nest 2
396-
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
397-
^^ string "0.0f"))))
398-
| Ops.Double_prec _ ->
399-
let open PPrint in
400-
group
401-
(parens
402-
(group (parens (v1 ^^ string " > 0.0 && " ^^ v1 ^^ string " < 1.0"))
403-
^^ ifflat
404-
(space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space
405-
^^ string "0.0")
406-
(nest 2
407-
(break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space
408-
^^ string "0.0"))))
409-
| Ops.Void_prec -> invalid_arg "Pure_C_config.binop_syntax: Satur01_gate on Void_prec")
410-
| _ -> (
411-
match prec with
412-
| Ops.Bfloat16_prec _ -> (
413-
(* For BFloat16, perform all operations in float precision *)
414-
let float_v1 = PPrint.(string "bfloat16_to_single(" ^^ v1 ^^ string ")") in
415-
let float_v2 = PPrint.(string "bfloat16_to_single(" ^^ v2 ^^ string ")") in
416-
let op_prefix, op_infix, op_suffix = Ops.binop_c_syntax Ops.single op in
417-
let float_result =
418-
PPrint.(
419-
group
420-
(string op_prefix ^^ float_v1 ^^ string op_infix
421-
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
422-
^^ string op_suffix))
423-
in
424-
(* For comparison operations, return float result (0.0 or 1.0) converted to BFloat16 *)
425-
match op with
426-
| Ops.Cmplt | Ops.Cmpeq | Ops.Cmpne | Ops.Or | Ops.And ->
427-
PPrint.(string "single_to_bfloat16(" ^^ float_result ^^ string ")")
428-
| _ -> PPrint.(string "single_to_bfloat16(" ^^ float_result ^^ string ")"))
429-
| Ops.Fp8_prec _ ->
430-
(* For FP8, perform all operations in float precision *)
431-
let float_v1 = PPrint.(string "fp8_to_single(" ^^ v1 ^^ string ")") in
432-
let float_v2 = PPrint.(string "fp8_to_single(" ^^ v2 ^^ string ")") in
433-
let op_prefix, op_infix, op_suffix = Ops.binop_c_syntax Ops.single op in
434-
let float_result =
435-
PPrint.(
436-
group
437-
(string op_prefix ^^ float_v1 ^^ string op_infix
438-
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
439-
^^ string op_suffix))
440-
in
441-
PPrint.(string "single_to_fp8(" ^^ float_result ^^ string ")")
442-
| Ops.Half_prec _ ->
443-
(* For Half, perform all operations in float precision on non-native systems *)
444-
let float_v1 = PPrint.(string "HALF_TO_FP(" ^^ v1 ^^ string ")") in
445-
let float_v2 = PPrint.(string "HALF_TO_FP(" ^^ v2 ^^ string ")") in
446-
let op_prefix, op_infix, op_suffix = Ops.binop_c_syntax Ops.single op in
447-
let float_result =
448-
PPrint.(
449-
group
450-
(string op_prefix ^^ float_v1 ^^ string op_infix
451-
^^ ifflat (space ^^ float_v2) (nest 2 (break 1 ^^ float_v2))
452-
^^ string op_suffix))
453-
in
454-
(* For comparison operations, return float result (0.0 or 1.0) converted to Half *)
455-
(match op with
456-
| Ops.Cmplt | Ops.Cmpeq | Ops.Cmpne | Ops.Or | Ops.And ->
457-
PPrint.(string "FP_TO_HALF(" ^^ float_result ^^ string ")")
458-
| _ -> PPrint.(string "FP_TO_HALF(" ^^ float_result ^^ string ")"))
459-
| _ ->
460-
let op_prefix, op_infix, op_suffix = Ops.binop_c_syntax prec op in
461-
let open PPrint in
462-
group
463-
(string op_prefix ^^ v1 ^^ string op_infix
464-
^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
465-
^^ string op_suffix))
187+
let op_prefix, op_infix, op_suffix = Ops.binop_c_syntax prec op in
188+
let open PPrint in
189+
group
190+
(string op_prefix ^^ v1 ^^ string op_infix
191+
^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
192+
^^ string op_suffix)
466193

467194
let unop_syntax prec op v =
468-
match prec with
469-
| Ops.Bfloat16_prec _ ->
470-
(* For BFloat16, perform operations in float precision *)
471-
let float_v = PPrint.(string "bfloat16_to_single(" ^^ v ^^ string ")") in
472-
let op_prefix, op_suffix = Ops.unop_c_syntax Ops.single op in
473-
let float_result = PPrint.(group (string op_prefix ^^ float_v ^^ string op_suffix)) in
474-
PPrint.(string "single_to_bfloat16(" ^^ float_result ^^ string ")")
475-
| Ops.Fp8_prec _ ->
476-
(* For FP8, perform operations in float precision *)
477-
let float_v = PPrint.(string "fp8_to_single(" ^^ v ^^ string ")") in
478-
let op_prefix, op_suffix = Ops.unop_c_syntax Ops.single op in
479-
let float_result = PPrint.(group (string op_prefix ^^ float_v ^^ string op_suffix)) in
480-
PPrint.(string "single_to_fp8(" ^^ float_result ^^ string ")")
481-
| Ops.Half_prec _ ->
482-
(* For Half, perform operations in float precision on non-native systems *)
483-
let float_v = PPrint.(string "HALF_TO_FP(" ^^ v ^^ string ")") in
484-
let op_prefix, op_suffix = Ops.unop_c_syntax Ops.single op in
485-
let float_result = PPrint.(group (string op_prefix ^^ float_v ^^ string op_suffix)) in
486-
PPrint.(string "FP_TO_HALF(" ^^ float_result ^^ string ")")
487-
| _ ->
488-
let op_prefix, op_suffix = Ops.unop_c_syntax prec op in
489-
let open PPrint in
490-
group (string op_prefix ^^ v ^^ string op_suffix)
195+
let op_prefix, op_suffix = Ops.unop_c_syntax prec op in
196+
let open PPrint in
197+
group (string op_prefix ^^ v ^^ string op_suffix)
491198

492199
let vec_unop_syntax prec op v =
493200
let op_prefix, op_suffix = Ops.vec_unop_c_syntax prec op in

0 commit comments

Comments
 (0)