Skip to content

Commit c9c9a8e

Browse files
committed
Follow-up: cleaner c_syntax.ml and backend headers, by Claude Opus
Prompt: > Let's restore (or establish) a proper division of responsibilities between the files / modules c_syntax.ml and cc_backend.ml . C_syntax provides functionality for backends implementing code generation for C-like languages. C_syntax.Pure_C_config should provide defaults that have grounding in the C language standard, that are likely to be available in C compilers across many architectures, and that have some chance to be shared by more than one backend (out of cc, metal, cuda, and future ones amd-hip, etc.) Outside of that, the default should serve illustrative purposes and be compact. Summary by Claude: I've successfully refactored the code to eliminate extra_declarations and print_declarations as you suggested. The changes include: 1. Removed from c_syntax.ml: - includes field from C_syntax_config interface - extra_declarations field from C_syntax_config interface - pp_include helper function - print_declarations function - Removed the default includes from Pure_C_config 2. Updated cc_backend.ml: - Moved all headers and precision conversion declarations into the builtins_header raw string - Removed the redundant CC_syntax_config overrides that weren't actually changing anything - Kept the precision-specific operator syntax overrides that the cc backend needs 3. Updated metal_backend.ml: - Removed includes and extra_declarations - Metal headers are already in Builtins_metal.source 4. Updated cuda_backend.ml: - Removed extra_declarations - CUDA headers are already in Builtins_cuda_small.source The refactoring makes the code cleaner and more straightforward - each backend now manages its complete preamble (headers + declarations) as a single cohesive string, rather than having it split across multiple fields and functions. This is much more maintainable and easier to understand.
1 parent 6e40abb commit c9c9a8e

File tree

4 files changed

+93
-121
lines changed

4 files changed

+93
-121
lines changed

arrayjit/lib/c_syntax.ml

Lines changed: 0 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,6 @@ module type C_syntax_config = sig
2727
val buffer_suffix : pos:int -> string
2828
val arg_int_prefix : string
2929
val extra_args : string list
30-
val includes : string list
31-
val extra_declarations : string list
3230
val typ_of_prec : Ops.prec -> string
3331
val vec_typ_of_prec : length:int -> Ops.prec -> string
3432
val ident_blacklist : string list
@@ -94,8 +92,6 @@ struct
9492
let buffer_suffix = fun ~pos:_ -> ""
9593
let arg_int_prefix = "const int "
9694
let extra_args = []
97-
let includes = [ "<stdio.h>"; "<stdlib.h>"; "<string.h>"; "<math.h>" ]
98-
let extra_declarations = []
9995

10096
let typ_of_prec = Ops.c_typ_of_prec
10197
let vec_typ_of_prec = Ops.c_vec_typ_of_prec
@@ -236,8 +232,6 @@ module C_syntax (B : C_syntax_config) = struct
236232
@@ Array.map B.procs ~f:(fun l -> l.llc)
237233

238234
let in_ctx tn = B.(Tn.is_in_context_force ~use_host_memory tn 46)
239-
let pp_include s = PPrint.(string "#include " ^^ string s)
240-
241235
open Indexing
242236
open Doc_helpers
243237

@@ -262,12 +256,6 @@ module C_syntax (B : C_syntax_config) = struct
262256

263257
let array_offset_to_string (idcs, dims) = doc_to_string @@ pp_array_offset (idcs, dims)
264258

265-
let print_declarations () =
266-
let open PPrint in
267-
let includes = separate hardline (List.map B.includes ~f:pp_include) in
268-
let extras = separate hardline (List.map B.extra_declarations ~f:string) in
269-
includes ^^ hardline ^^ extras ^^ hardline
270-
271259
let pp_local_defs (local_defs : (int * PPrint.document) list) =
272260
let open PPrint in
273261
List.dedup_and_sort local_defs ~compare:(fun (a, _) (b, _) -> Int.compare a b)

arrayjit/lib/cc_backend.ml

Lines changed: 83 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,14 @@ open Backend_intf
1414

1515
let name = "cc"
1616

17-
(* Header declarations for arrayjit builtins *)
17+
(* Complete header with includes and declarations for arrayjit builtins *)
1818
let builtins_header =
1919
{|
20-
/* ArrayJIT builtins declarations */
20+
/* Standard C library headers */
21+
#include <stdio.h>
22+
#include <stdlib.h>
23+
#include <string.h>
24+
#include <math.h>
2125
#include <stdint.h>
2226

2327
typedef struct {
@@ -78,6 +82,81 @@ extern uint4x32_t bfloat16_to_uint4x32(uint16_t x);
7882
extern uint4x32_t half_to_uint4x32(uint16_t x);
7983
extern uint4x32_t fp8_to_uint4x32(uint8_t x);
8084

85+
/* BFloat16 conversion functions */
86+
static inline float bfloat16_to_single(unsigned short bf16) {
87+
unsigned int f32 = ((unsigned int)bf16) << 16;
88+
return *((float*)&f32);
89+
}
90+
91+
static inline unsigned short single_to_bfloat16(float f) {
92+
unsigned int f32 = *((unsigned int*)&f);
93+
unsigned int rounded = f32 + 0x7FFF + ((f32 >> 16) & 1);
94+
return (unsigned short)(rounded >> 16);
95+
}
96+
97+
/* Half (Float16) support with zero-overhead abstraction */
98+
#ifdef __FLT16_MAX__
99+
#define HAS_NATIVE_FLOAT16 1
100+
#define HALF_T _Float16
101+
#define HALF_TO_FP(x) (x) /* Identity - already floating point */
102+
#define FP_TO_HALF(x) (x) /* Identity - already half precision */
103+
#define HALF_TO_FLOAT(x) ((float)(x))
104+
#define FLOAT_TO_HALF(x) ((_Float16)(x))
105+
#else
106+
#define HAS_NATIVE_FLOAT16 0
107+
#define HALF_T unsigned short
108+
#define HALF_TO_FP(x) half_to_single(x) /* Convert to float for computation */
109+
#define FP_TO_HALF(x) single_to_half(x) /* Convert back from float */
110+
#define HALF_TO_FLOAT(x) half_to_single(x)
111+
#define FLOAT_TO_HALF(x) single_to_half(x)
112+
/* Conversion functions for emulation - provided by builtins.c */
113+
extern float half_to_single(unsigned short h);
114+
extern unsigned short single_to_half(float f);
115+
#endif
116+
117+
/* FP8 E5M2 conversion functions */
118+
static inline float fp8_to_single(unsigned char fp8) {
119+
if (fp8 == 0) return 0.0f;
120+
unsigned int sign = (fp8 >> 7) & 1;
121+
unsigned int exp = (fp8 >> 2) & 0x1F;
122+
unsigned int mant = fp8 & 0x3;
123+
if (exp == 0x1F) {
124+
if (mant == 0) return sign ? -INFINITY : INFINITY;
125+
else return NAN;
126+
}
127+
if (exp == 0) {
128+
float result = ldexpf((float)mant / 4.0f, -14);
129+
if (sign) result = -result;
130+
return result;
131+
}
132+
float result = (1.0f + (float)mant * 0.25f) * ldexpf(1.0f, (int)exp - 15);
133+
if (sign) result = -result;
134+
return result;
135+
}
136+
137+
static inline unsigned char single_to_fp8(float f) {
138+
if (f == 0.0f) return 0;
139+
unsigned int sign = (f < 0) ? 1 : 0;
140+
f = fabsf(f);
141+
if (isinf(f)) return (sign << 7) | 0x7C;
142+
if (isnan(f)) return (sign << 7) | 0x7F;
143+
int exp_val;
144+
float mant_f = frexpf(f, &exp_val);
145+
int exp = exp_val + 14;
146+
if (exp < 0) return sign << 7;
147+
if (exp > 30) return (sign << 7) | 0x7C;
148+
if (exp == 0) {
149+
float denorm_mant = f * ldexpf(1.0f, 14) * 4.0f;
150+
unsigned int mant_bits = (unsigned int)(denorm_mant + 0.5f);
151+
if (mant_bits > 3) mant_bits = 3;
152+
return (sign << 7) | mant_bits;
153+
}
154+
mant_f = (mant_f - 0.5f) * 4.0f;
155+
unsigned int mant_bits = (unsigned int)(mant_f + 0.5f);
156+
if (mant_bits > 3) mant_bits = 3;
157+
return (unsigned char)((sign << 7) | ((exp & 0x1F) << 2) | (mant_bits & 0x3));
158+
}
159+
81160
|}
82161

83162
let optimization_level () =
@@ -215,85 +294,6 @@ struct
215294
not @@ Utils.get_global_flag ~default:false ~arg_name:"prefer_backend_uniformity"
216295
end)
217296

218-
(* Add declarations for precision conversions that standard C compilers can use *)
219-
let extra_declarations =
220-
[
221-
(* BFloat16 conversion functions *)
222-
"static inline float bfloat16_to_single(unsigned short bf16) {";
223-
" unsigned int f32 = ((unsigned int)bf16) << 16;";
224-
" return *((float*)&f32);";
225-
"}";
226-
"";
227-
"static inline unsigned short single_to_bfloat16(float f) {";
228-
" unsigned int f32 = *((unsigned int*)&f);";
229-
" unsigned int rounded = f32 + 0x7FFF + ((f32 >> 16) & 1);";
230-
" return (unsigned short)(rounded >> 16);";
231-
"}";
232-
"";
233-
(* Half (Float16) support with zero-overhead abstraction *)
234-
"#ifdef __FLT16_MAX__";
235-
" #define HAS_NATIVE_FLOAT16 1";
236-
" #define HALF_T _Float16";
237-
" #define HALF_TO_FP(x) (x) /* Identity - already floating point */";
238-
" #define FP_TO_HALF(x) (x) /* Identity - already half precision */";
239-
" #define HALF_TO_FLOAT(x) ((float)(x))";
240-
" #define FLOAT_TO_HALF(x) ((_Float16)(x))";
241-
"#else";
242-
" #define HAS_NATIVE_FLOAT16 0";
243-
" #define HALF_T unsigned short";
244-
" #define HALF_TO_FP(x) half_to_single(x) /* Convert to float for computation */";
245-
" #define FP_TO_HALF(x) single_to_half(x) /* Convert back from float */";
246-
" #define HALF_TO_FLOAT(x) half_to_single(x)";
247-
" #define FLOAT_TO_HALF(x) single_to_half(x)";
248-
" /* Conversion functions for emulation - provided by builtins.c */";
249-
" extern float half_to_single(unsigned short h);";
250-
" extern unsigned short single_to_half(float f);";
251-
"#endif";
252-
"";
253-
(* FP8 E5M2 conversion functions *)
254-
"static inline float fp8_to_single(unsigned char fp8) {";
255-
" if (fp8 == 0) return 0.0f;";
256-
" unsigned int sign = (fp8 >> 7) & 1;";
257-
" unsigned int exp = (fp8 >> 2) & 0x1F;";
258-
" unsigned int mant = fp8 & 0x3;";
259-
" if (exp == 0x1F) {";
260-
" if (mant == 0) return sign ? -INFINITY : INFINITY;";
261-
" else return NAN;";
262-
" }";
263-
" if (exp == 0) {";
264-
" float result = ldexpf((float)mant / 4.0f, -14);";
265-
" if (sign) result = -result;";
266-
" return result;";
267-
" }";
268-
" float result = (1.0f + (float)mant * 0.25f) * ldexpf(1.0f, (int)exp - 15);";
269-
" if (sign) result = -result;";
270-
" return result;";
271-
"}";
272-
"";
273-
"static inline unsigned char single_to_fp8(float f) {";
274-
" if (f == 0.0f) return 0;";
275-
" unsigned int sign = (f < 0) ? 1 : 0;";
276-
" f = fabsf(f);";
277-
" if (isinf(f)) return (sign << 7) | 0x7C;";
278-
" if (isnan(f)) return (sign << 7) | 0x7F;";
279-
" int exp_val;";
280-
" float mant_f = frexpf(f, &exp_val);";
281-
" int exp = exp_val + 14;";
282-
" if (exp < 0) return sign << 7;";
283-
" if (exp > 30) return (sign << 7) | 0x7C;";
284-
" if (exp == 0) {";
285-
" float denorm_mant = f * ldexpf(1.0f, 14) * 4.0f;";
286-
" unsigned int mant_bits = (unsigned int)(denorm_mant + 0.5f);";
287-
" if (mant_bits > 3) mant_bits = 3;";
288-
" return (sign << 7) | mant_bits;";
289-
" }";
290-
" mant_f = (mant_f - 0.5f) * 4.0f;";
291-
" unsigned int mant_bits = (unsigned int)(mant_f + 0.5f);";
292-
" if (mant_bits > 3) mant_bits = 3;";
293-
" return (unsigned char)((sign << 7) | ((exp & 0x1F) << 2) | (mant_bits & 0x3));";
294-
"}";
295-
]
296-
297297
(* Override operation syntax to handle special precision types *)
298298
let ternop_syntax prec op v1 v2 v3 =
299299
match prec with
@@ -448,10 +448,9 @@ let%diagn_sexp compile ~(name : string) bindings (lowered : Low_level.optimized)
448448
(* FIXME: do we really want all of them, or only the used ones? *)
449449
let idx_params = Indexing.bound_symbols bindings in
450450
let build_file = Utils.open_build_file ~base_name:name ~extension:".c" in
451-
let declarations_doc = Syntax.print_declarations () in
452451
let params, proc_doc = Syntax.compile_proc ~name idx_params lowered in
453452
let header_doc = PPrint.string builtins_header in
454-
let final_doc = PPrint.(header_doc ^^ declarations_doc ^^ proc_doc) in
453+
let final_doc = PPrint.(header_doc ^^ proc_doc) in
455454
(* Use ribbon = 1.0 for usual code formatting, width 110 *)
456455
PPrint.ToChannel.pretty 1.0 110 build_file.oc final_doc;
457456
build_file.finalize ();
@@ -473,15 +472,14 @@ let%diagn_sexp compile_batch ~names bindings (lowereds : Low_level.optimized opt
473472
@@ common_prefix (Array.to_list @@ Array.concat_map ~f:Option.to_array names))
474473
in
475474
let build_file = Utils.open_build_file ~base_name ~extension:".c" in
476-
let declarations_doc = Syntax.print_declarations () in
477475
let params_and_docs =
478476
Array.map2_exn names lowereds ~f:(fun name_opt lowered_opt ->
479477
Option.map2 name_opt lowered_opt ~f:(fun name lowered ->
480478
Syntax.compile_proc ~name idx_params lowered))
481479
in
482480
let all_proc_docs = List.filter_map (Array.to_list params_and_docs) ~f:(Option.map ~f:snd) in
483481
let header_doc = PPrint.string builtins_header in
484-
let final_doc = PPrint.(header_doc ^^ declarations_doc ^^ separate hardline all_proc_docs) in
482+
let final_doc = PPrint.(header_doc ^^ separate hardline all_proc_docs) in
485483
PPrint.ToChannel.pretty 1.0 110 build_file.oc final_doc;
486484
build_file.finalize ();
487485
let result_library = c_compile_and_load ~f_path:build_file.f_path in

arrayjit/lib/cuda_backend.ml

Lines changed: 6 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -682,7 +682,6 @@ end) : Ir.Backend_impl.Lowered_backend = struct
682682
| FMA, Ops.Single_prec _ -> func "fmaf"
683683
| FMA, _ -> func "fma"
684684

685-
let extra_declarations = []
686685

687686
let convert_precision ~from ~to_ =
688687
match (from, to_) with
@@ -761,15 +760,10 @@ end) : Ir.Backend_impl.Lowered_backend = struct
761760
end)) in
762761
let idx_params = Indexing.bound_symbols bindings in
763762
let b = Buffer.create 4096 in
764-
let declarations_doc = Syntax.print_declarations () in
765-
let params, proc_doc = Syntax.compile_proc ~name idx_params lowered in
766-
let final_doc = PPrint.(declarations_doc ^^ proc_doc) in
767-
PPrint.ToBuffer.pretty 1.0 110 b final_doc;
768-
(* Prepend builtins after syntax generation to preserve include order *)
769-
let full_source = Buffer.contents b in
770-
Buffer.clear b;
763+
(* Prepend builtins first *)
771764
prepend_builtins b;
772-
Buffer.add_string b full_source;
765+
let params, proc_doc = Syntax.compile_proc ~name idx_params lowered in
766+
PPrint.ToBuffer.pretty 1.0 110 b proc_doc;
773767
let ptx = cuda_to_ptx ~name (Buffer.contents b) in
774768
{ traced_store; ptx; params; bindings; name }
775769

@@ -779,7 +773,8 @@ end) : Ir.Backend_impl.Lowered_backend = struct
779773
end)) in
780774
let idx_params = Indexing.bound_symbols bindings in
781775
let b = Buffer.create 4096 in
782-
let declarations_doc = Syntax.print_declarations () in
776+
(* Prepend builtins first *)
777+
prepend_builtins b;
783778
let params_and_docs =
784779
Array.map2_exn names lowereds
785780
~f:
@@ -788,13 +783,8 @@ end) : Ir.Backend_impl.Lowered_backend = struct
788783
((params, name), doc)))
789784
in
790785
let all_proc_docs = List.filter_map (Array.to_list params_and_docs) ~f:(Option.map ~f:snd) in
791-
let final_doc = PPrint.(declarations_doc ^^ separate hardline all_proc_docs) in
786+
let final_doc = PPrint.(separate hardline all_proc_docs) in
792787
PPrint.ToBuffer.pretty 1.0 110 b final_doc;
793-
(* Prepend builtins after syntax generation to preserve include order *)
794-
let full_source = Buffer.contents b in
795-
Buffer.clear b;
796-
prepend_builtins b;
797-
Buffer.add_string b full_source;
798788

799789
let name : string =
800790
String.(

arrayjit/lib/metal_backend.ml

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -433,11 +433,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
433433
"uint3 gid [[threadgroup_position_in_grid]]"; "uint3 lid [[thread_position_in_threadgroup]]";
434434
]
435435

436-
let includes =
437-
[ "<metal_stdlib>"; "<metal_math>"; "<metal_logging>"; "<metal_compute>"; "<metal_atomic>" ]
438-
439436
let metal_log_object_name = "os_log_default"
440-
let extra_declarations = [ "using namespace metal;" ]
441437

442438
let typ_of_prec = function
443439
| Ops.Byte_prec _ -> "uchar"
@@ -658,10 +654,9 @@ end) : Ir.Backend_impl.Lowered_backend = struct
658654
let b = Buffer.create 4096 in
659655
Buffer.add_string b Builtins_metal.source;
660656
Buffer.add_string b "\n";
661-
let declarations_doc = Syntax.print_declarations () in
662657
(* Add Metal address space qualifiers *)
663658
let params, proc_doc = Syntax.compile_proc ~name idx_params lowered in
664-
let final_doc = PPrint.(declarations_doc ^^ proc_doc) in
659+
let final_doc = proc_doc in
665660
PPrint.ToBuffer.pretty 1.0 110 b final_doc;
666661
let source = Buffer.contents b in
667662
{
@@ -681,7 +676,8 @@ end) : Ir.Backend_impl.Lowered_backend = struct
681676
let idx_params = Indexing.bound_symbols bindings in
682677
let b = Buffer.create 4096 in
683678
(* Read and prepend the Metal builtins file *)
684-
let declarations_doc = Syntax.print_declarations () in
679+
Buffer.add_string b Builtins_metal.source;
680+
Buffer.add_string b "\n";
685681
let funcs_and_docs =
686682
Array.map2_exn names lowereds
687683
~f:
@@ -690,7 +686,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
690686
((name, params), doc)))
691687
in
692688
let all_proc_docs = List.filter_map (Array.to_list funcs_and_docs) ~f:(Option.map ~f:snd) in
693-
let final_doc = PPrint.(declarations_doc ^^ separate hardline all_proc_docs) in
689+
let final_doc = PPrint.(separate hardline all_proc_docs) in
694690
PPrint.ToBuffer.pretty 1.0 110 b final_doc;
695691
let source = Buffer.contents b in
696692
let traced_stores = Array.map lowereds ~f:(Option.map ~f:(fun l -> l.Low_level.traced_store)) in

0 commit comments

Comments
 (0)