Skip to content

Commit 98e7ca7

Browse files
committed
Improved formatting for generated code
(by Claude) Signed-off-by: Lukasz Stafiniak <lukstafi@gmail.com>
1 parent 146ba70 commit 98e7ca7

File tree

6 files changed

+451
-148
lines changed

6 files changed

+451
-148
lines changed

CHANGES.md

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
## [0.5.3] -- 2025-05-19
1+
## [0.5.3] -- 2025-05-24
22

33
### Added
44

@@ -13,6 +13,8 @@
1313
- Removed `initialize` and `is_initialized` from the backend API; instead, backends should be initialized on functor application. The functors now take `config` as argument.
1414
- More descriptive identifier names in C-syntax code in case of name conflicts.
1515
- Changed the backend config name `cc` to `multicore_cc` for consistency.
16+
- Migrated out of `Stdlib.Format` to `PPrint` for all structured formatting.
17+
- Migrated stdout capture to thread-based (domain-based actually); for Windows compatibility but also much more robust for large logs.
1618

1719
### Fixed
1820

arrayjit/lib/c_syntax.ml

Lines changed: 88 additions & 86 deletions
Original file line numberDiff line numberDiff line change
@@ -164,40 +164,48 @@ struct
164164
let ternop_syntax prec op v1 v2 v3 =
165165
let op_prefix, op_infix1, op_infix2, op_suffix = Ops.ternop_c_syntax prec op in
166166
let open PPrint in
167-
group
168-
(string op_prefix ^^ v1 ^^ string op_infix1 ^^ space ^^ v2 ^^ string op_infix2 ^^ space ^^ v3
169-
^^ string op_suffix)
167+
group (string op_prefix ^^ v1 ^^ string op_infix1
168+
^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
169+
^^ string op_infix2
170+
^^ ifflat (space ^^ v3) (nest 2 (break 1 ^^ v3))
171+
^^ string op_suffix)
170172

171173
let binop_syntax prec op v1 v2 =
172174
match op with
173175
| Ops.Satur01_gate -> (
174176
match prec with
175177
| Ops.Byte_prec _ ->
176178
let open PPrint in
177-
parens
178-
(parens
179-
(string "(float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1 ^^ string " < 1.0f")
180-
^^ string " ? " ^^ v2 ^^ string " : (unsigned char)0")
179+
group (parens
180+
(group (parens
181+
(string "(float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1 ^^ string " < 1.0f"))
182+
^^ ifflat (space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space ^^ string "(unsigned char)0")
183+
(nest 2 (break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space ^^ string "(unsigned char)0"))))
181184
| Ops.Half_prec _ ->
182185
let open PPrint in
183-
parens
184-
(parens (v1 ^^ string " > 0.0f16 && " ^^ v1 ^^ string " < 1.0f16")
185-
^^ string " ? " ^^ v2 ^^ string " : 0.0f16")
186+
group (parens
187+
(group (parens (v1 ^^ string " > 0.0f16 && " ^^ v1 ^^ string " < 1.0f16"))
188+
^^ ifflat (space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space ^^ string "0.0f16")
189+
(nest 2 (break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space ^^ string "0.0f16"))))
186190
| Ops.Single_prec _ ->
187191
let open PPrint in
188-
parens
189-
(parens (v1 ^^ string " > 0.0f && " ^^ v1 ^^ string " < 1.0f")
190-
^^ string " ? " ^^ v2 ^^ string " : 0.0f")
192+
group (parens
193+
(group (parens (v1 ^^ string " > 0.0f && " ^^ v1 ^^ string " < 1.0f"))
194+
^^ ifflat (space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space ^^ string "0.0f")
195+
(nest 2 (break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space ^^ string "0.0f"))))
191196
| Ops.Double_prec _ ->
192197
let open PPrint in
193-
parens
194-
(parens (v1 ^^ string " > 0.0 && " ^^ v1 ^^ string " < 1.0")
195-
^^ string " ? " ^^ v2 ^^ string " : 0.0")
198+
group (parens
199+
(group (parens (v1 ^^ string " > 0.0 && " ^^ v1 ^^ string " < 1.0"))
200+
^^ ifflat (space ^^ string "?" ^^ space ^^ v2 ^^ space ^^ string ":" ^^ space ^^ string "0.0")
201+
(nest 2 (break 1 ^^ string "?" ^^ space ^^ v2 ^^ break 1 ^^ string ":" ^^ space ^^ string "0.0"))))
196202
| Ops.Void_prec -> invalid_arg "Pure_C_config.binop_syntax: Satur01_gate on Void_prec")
197203
| _ ->
198204
let op_prefix, op_infix, op_suffix = Ops.binop_c_syntax prec op in
199205
let open PPrint in
200-
group (string op_prefix ^^ v1 ^^ string op_infix ^^ space ^^ v2 ^^ string op_suffix)
206+
group (string op_prefix ^^ v1 ^^ string op_infix
207+
^^ ifflat (space ^^ v2) (nest 2 (break 1 ^^ v2))
208+
^^ string op_suffix)
201209

202210
let unop_syntax prec op v =
203211
let op_prefix, op_suffix = Ops.unop_c_syntax prec op in
@@ -227,12 +235,11 @@ struct
227235
else res
228236
in
229237
log_file_check
230-
^^ group
231-
(string "fprintf(log_file, "
232-
^^ dquotes (string base_message_literal)
233-
^^ (if List.is_empty args_docs then empty else comma ^^ space)
234-
^^ separate (PPrint.comma ^^ PPrint.space) args_docs
235-
^^ rparen ^^ semi)
238+
^^ string "fprintf(log_file, "
239+
^^ dquotes (string base_message_literal)
240+
^^ (if List.is_empty args_docs then empty else comma ^^ space)
241+
^^ separate (comma ^^ space) args_docs
242+
^^ rparen ^^ semi
236243
end
237244

238245
module C_syntax (B : C_syntax_config) = struct
@@ -267,7 +274,7 @@ module C_syntax (B : C_syntax_config) = struct
267274
let open PPrint in
268275
let includes = separate hardline (List.map B.includes ~f:pp_include) in
269276
let extras = separate hardline (List.map B.extra_declarations ~f:string) in
270-
group (includes ^^ hardline ^^ hardline ^^ extras ^^ hardline ^^ hardline)
277+
includes ^^ hardline ^^ extras ^^ hardline
271278

272279
let rec pp_ll (c : Low_level.t) : PPrint.document =
273280
let open PPrint in
@@ -279,12 +286,12 @@ module C_syntax (B : C_syntax_config) = struct
279286
(* Avoid extra hardlines if one side is empty *)
280287
if PPrint.is_empty d1 then d2
281288
else if PPrint.is_empty d2 then d1
282-
else group (d1 ^^ hardline ^^ d2)
289+
else d1 ^^ hardline ^^ d2
283290
| For_loop { index = i; from_; to_; body; trace_it = _ } ->
284291
let header =
285292
string "for (int " ^^ pp_symbol i ^^ string " = " ^^ PPrint.OCaml.int from_ ^^ semi
286-
^^ space ^^ pp_symbol i ^^ string " <= " ^^ PPrint.OCaml.int to_ ^^ semi ^^ string " ++"
287-
^^ pp_symbol i ^^ string ")"
293+
^^ space ^^ pp_symbol i ^^ string " <= " ^^ PPrint.OCaml.int to_ ^^ semi
294+
^^ space ^^ string "++" ^^ pp_symbol i ^^ string ")"
288295
in
289296
let body_doc = ref (pp_ll body) in
290297
(if Utils.debug_log_from_routines () then
@@ -296,7 +303,7 @@ module C_syntax (B : C_syntax_config) = struct
296303
~args_docs:[ pp_symbol i ]
297304
in
298305
body_doc := log_doc ^^ hardline ^^ !body_doc);
299-
surround 2 1 (header ^^ space ^^ lbrace) !body_doc rbrace
306+
group (header ^^ space ^^ lbrace ^^ nest 2 (hardline ^^ !body_doc) ^^ hardline ^^ rbrace)
300307
| Zero_out tn ->
301308
pp_ll
302309
(Low_level.loop_over_dims (Lazy.force tn.dims) ~body:(fun idcs ->
@@ -307,11 +314,14 @@ module C_syntax (B : C_syntax_config) = struct
307314
let prec = Lazy.force tn.prec in
308315
let local_defs, val_doc = pp_float prec llv in
309316
let offset_doc = pp_array_offset (idcs, dims) in
310-
let assignment = ident_doc ^^ brackets offset_doc ^^ string " = " ^^ val_doc ^^ semi in
317+
let assignment =
318+
group (ident_doc ^^ brackets offset_doc ^^ string " ="
319+
^^ ifflat (space ^^ val_doc) (nest 4 (hardline ^^ val_doc)) ^^ semi)
320+
in
311321
if Utils.debug_log_from_routines () then
312322
let num_typ = string (B.typ_of_prec prec) in
313323
let new_var = string "new_set_v" in
314-
let decl = group (num_typ ^^ space ^^ new_var ^^ string " = " ^^ val_doc ^^ semi) in
324+
let decl = num_typ ^^ space ^^ new_var ^^ string " = " ^^ val_doc ^^ semi in
315325
let debug_val_doc, debug_args_docs = debug_float prec llv in
316326
let debug_val_str = doc_to_string debug_val_doc in
317327
let pp_args_docs =
@@ -340,14 +350,19 @@ module C_syntax (B : C_syntax_config) = struct
340350
~base_message_literal:value_base_msg ~args_docs:log_args_for_printf
341351
in
342352
let flush_log =
343-
if B.log_involves_file_management then string "fflush(log_file);" ^^ semi else empty
353+
if B.log_involves_file_management then string "fflush(log_file);" else empty
344354
in
345355
comment_log ^^ hardline ^^ value_log ^^ hardline ^^ flush_log
346356
in
347357
let assignment' = ident_doc ^^ brackets offset_doc ^^ string " = " ^^ new_var ^^ semi in
348-
let block_content = decl ^^ hardline ^^ log_doc ^^ hardline ^^ assignment' in
349-
surround 2 1 lbrace (local_defs ^^ hardline ^^ block_content) rbrace
350-
else local_defs ^^ (if PPrint.is_empty local_defs then empty else hardline) ^^ assignment
358+
let block_content =
359+
if PPrint.is_empty local_defs then decl ^^ hardline ^^ log_doc ^^ hardline ^^ assignment'
360+
else local_defs ^^ hardline ^^ decl ^^ hardline ^^ log_doc ^^ hardline ^^ assignment'
361+
in
362+
lbrace ^^ nest 2 (hardline ^^ block_content) ^^ hardline ^^ rbrace
363+
else
364+
if PPrint.is_empty local_defs then assignment
365+
else local_defs ^^ hardline ^^ assignment
351366
| Comment message ->
352367
if Utils.debug_log_from_routines () then
353368
let base_message = "COMMENT: " ^ message ^ "\n" in
@@ -359,22 +374,19 @@ module C_syntax (B : C_syntax_config) = struct
359374
callback ()
360375
| Set_local ({ scope_id; tn = { prec; _ } }, value) ->
361376
let local_defs, value_doc = pp_float (Lazy.force prec) value in
362-
let assignment =
363-
group (string ("v" ^ Int.to_string scope_id) ^^ string " = " ^^ value_doc ^^ semi)
364-
in
365-
local_defs ^^ (if PPrint.is_empty local_defs then empty else hardline) ^^ assignment
377+
let assignment = string ("v" ^ Int.to_string scope_id) ^^ string " = " ^^ value_doc ^^ semi in
378+
if PPrint.is_empty local_defs then assignment
379+
else local_defs ^^ hardline ^^ assignment
366380

367381
and pp_float (prec : Ops.prec) (vcomp : Low_level.float_t) : PPrint.document * PPrint.document =
368382
(* Returns (local definitions, value expression) *)
369383
let open PPrint in
370384
match vcomp with
371385
| Local_scope { id = { scope_id; tn = { prec = scope_prec; _ } }; body; orig_indices = _ } ->
372386
let num_typ = string (B.typ_of_prec @@ Lazy.force scope_prec) in
373-
let decl =
374-
group (num_typ ^^ space ^^ string ("v" ^ Int.to_string scope_id) ^^ string " = 0" ^^ semi)
375-
in
387+
let decl = num_typ ^^ space ^^ string ("v" ^ Int.to_string scope_id) ^^ string " = 0" ^^ semi in
376388
let body_doc = pp_ll body in
377-
let defs = group (decl ^^ hardline ^^ body_doc) in
389+
let defs = decl ^^ hardline ^^ body_doc in
378390
let prefix, postfix = B.convert_precision ~from:(Lazy.force scope_prec) ~to_:prec in
379391
let expr = string prefix ^^ string ("v" ^ Int.to_string scope_id) ^^ string postfix in
380392
(defs, expr)
@@ -388,23 +400,20 @@ module C_syntax (B : C_syntax_config) = struct
388400
let from_prec = Lazy.force tn.prec in
389401
let prefix, postfix = B.convert_precision ~from:from_prec ~to_:prec in
390402
let offset_doc = pp_array_offset (idcs, Lazy.force tn.dims) in
391-
let expr =
392-
group (string prefix ^^ string "merge_buffer" ^^ brackets offset_doc ^^ string postfix)
393-
in
403+
let expr = string prefix ^^ string "merge_buffer" ^^ brackets offset_doc ^^ string postfix in
394404
(empty, expr)
395405
| Get_global _ -> failwith "C_syntax: Get_global / FFI NOT IMPLEMENTED YET"
396406
| Get (tn, idcs) ->
397407
let ident_doc = string (get_ident tn) in
398408
let from_prec = Lazy.force tn.prec in
399409
let prefix, postfix = B.convert_precision ~from:from_prec ~to_:prec in
400410
let offset_doc = pp_array_offset (idcs, Lazy.force tn.dims) in
401-
let expr = group (string prefix ^^ ident_doc ^^ brackets offset_doc ^^ string postfix) in
411+
let expr = string prefix ^^ ident_doc ^^ brackets offset_doc ^^ string postfix in
402412
(empty, expr)
403413
| Constant c ->
404414
let from_prec = Ops.double in
405415
let prefix, postfix = B.convert_precision ~from:from_prec ~to_:prec in
406416
let c_str = Printf.sprintf "%.16g" c in
407-
(* Use Printf for float formatting *)
408417
let expr =
409418
if String.is_empty prefix && Float.(c < 0.0) then
410419
string "(" ^^ string c_str ^^ string ")" ^^ string postfix
@@ -422,24 +431,24 @@ module C_syntax (B : C_syntax_config) = struct
422431
let d1, e1 = pp_float prec v1 in
423432
let d2, e2 = pp_float prec v2 in
424433
let d3, e3 = pp_float prec v3 in
425-
let defs =
426-
d1
427-
^^ (if PPrint.is_empty d1 then empty else hardline)
428-
^^ d2
429-
^^ (if PPrint.is_empty d2 then empty else hardline)
430-
^^ d3
434+
let defs =
435+
List.filter_map [ d1; d2; d3 ] ~f:(fun d -> if PPrint.is_empty d then None else Some d)
436+
|> separate hardline
431437
in
432-
let expr = B.ternop_syntax prec op e1 e2 e3 in
438+
let expr = group (B.ternop_syntax prec op e1 e2 e3) in
433439
(defs, expr)
434440
| Binop (op, v1, v2) ->
435441
let d1, e1 = pp_float prec v1 in
436442
let d2, e2 = pp_float prec v2 in
437-
let defs = d1 ^^ (if PPrint.is_empty d1 then empty else hardline) ^^ d2 in
438-
let expr = B.binop_syntax prec op e1 e2 in
443+
let defs =
444+
List.filter_map [ d1; d2 ] ~f:(fun d -> if PPrint.is_empty d then None else Some d)
445+
|> separate hardline
446+
in
447+
let expr = group (B.binop_syntax prec op e1 e2) in
439448
(defs, expr)
440449
| Unop (op, v) ->
441450
let defs, expr_v = pp_float prec v in
442-
let expr = B.unop_syntax prec op expr_v in
451+
let expr = group (B.unop_syntax prec op expr_v) in
443452
(defs, expr)
444453

445454
and debug_float (prec : Ops.prec) (value : Low_level.float_t) :
@@ -463,33 +472,27 @@ module C_syntax (B : C_syntax_config) = struct
463472
let dims = Lazy.force tn.dims in
464473
let prefix, postfix = B.convert_precision ~from:from_prec ~to_:prec in
465474
let offset_doc = pp_array_offset (idcs, dims) in
466-
let access_doc =
467-
group (string prefix ^^ string "merge_buffer" ^^ brackets offset_doc ^^ string postfix)
468-
in
475+
let access_doc = string prefix ^^ string "merge_buffer" ^^ brackets offset_doc ^^ string postfix in
469476
let expr_doc =
470-
group
471-
(string prefix ^^ string "merge_buffer"
472-
^^ brackets (string "%u")
473-
^^ string postfix
474-
^^ braces (string ("=" ^ B.float_log_style)))
477+
string prefix ^^ string "merge_buffer"
478+
^^ brackets (string "%u")
479+
^^ string postfix
480+
^^ braces (string ("=" ^ B.float_log_style))
475481
in
476482
(expr_doc, [ `Accessor (idcs, dims); `Value access_doc ])
477-
| Get_global _ -> failwith "Exec_as_cuda: Get_global / FFI NOT IMPLEMENTED YET"
483+
| Get_global _ -> failwith "C_syntax: Get_global / FFI NOT IMPLEMENTED YET"
478484
| Get (tn, idcs) ->
479485
let ident_doc = string (get_ident tn) in
480486
let from_prec = Lazy.force tn.prec in
481487
let dims = Lazy.force tn.dims in
482488
let prefix, postfix = B.convert_precision ~from:from_prec ~to_:prec in
483489
let offset_doc = pp_array_offset (idcs, dims) in
484-
let access_doc =
485-
group (string prefix ^^ ident_doc ^^ brackets offset_doc ^^ string postfix)
486-
in
490+
let access_doc = string prefix ^^ ident_doc ^^ brackets offset_doc ^^ string postfix in
487491
let expr_doc =
488-
group
489-
(string prefix ^^ ident_doc
490-
^^ brackets (string "%u")
491-
^^ string postfix
492-
^^ braces (string ("=" ^ B.float_log_style)))
492+
string prefix ^^ ident_doc
493+
^^ brackets (string "%u")
494+
^^ string postfix
495+
^^ braces (string ("=" ^ B.float_log_style))
493496
in
494497
(expr_doc, [ `Accessor (idcs, dims); `Value access_doc ])
495498
| Constant c ->
@@ -563,9 +566,8 @@ module C_syntax (B : C_syntax_config) = struct
563566
@ List.map B.extra_args ~f:string
564567
in
565568
let func_header =
566-
group
567-
(string B.main_kernel_prefix ^^ space ^^ string "void" ^^ space ^^ string name
568-
^^ parens (separate comma_sep args_docs))
569+
string B.main_kernel_prefix ^^ space ^^ string "void" ^^ space ^^ string name
570+
^^ nest 4 (lparen ^^ hardline ^^ separate (comma ^^ hardline) args_docs ^^ rparen)
569571
in
570572
let body = ref empty in
571573
if not (String.is_empty B.kernel_prep_line) then
@@ -579,12 +581,12 @@ module C_syntax (B : C_syntax_config) = struct
579581
in
580582
body :=
581583
!body ^^ string "FILE* log_file = NULL;" ^^ hardline
582-
^^ group
583-
(string ("if (" ^ log_file_var_name ^ ")")
584-
^^ space
585-
^^ braces (nest 2 (string ("log_file = fopen(" ^ log_file_var_name ^ ", \"w\");"))))
586-
^^ hardline
587-
else body := !body ^^ hardline;
584+
^^ string ("if (" ^ log_file_var_name ^ ") ")
585+
^^ lbrace ^^ nest 2 (hardline
586+
^^ string ("log_file = fopen(" ^ log_file_var_name ^ ", \"w\");"))
587+
^^ hardline ^^ rbrace ^^ hardline
588+
else
589+
body := !body ^^ hardline;
588590

589591
(if Utils.debug_log_from_routines () then
590592
let debug_init_doc =
@@ -625,14 +627,14 @@ module C_syntax (B : C_syntax_config) = struct
625627
let local_decls =
626628
string "/* Local declarations and initialization. */"
627629
^^ hardline
628-
^^ separate_map hardline
630+
^^ separate_map empty
629631
(fun (tn, node) ->
630632
if not (Tn.is_virtual_force tn 333 || Tn.is_materialized_force tn 336) then
631633
let typ_doc = string (B.typ_of_prec @@ Lazy.force tn.prec) in
632634
let ident_doc = string (get_ident tn) in
633635
let size_doc = OCaml.int (Tn.num_elems tn) in
634636
let init_doc = if node.Low_level.zero_initialized then string " = {0}" else empty in
635-
group (typ_doc ^^ space ^^ ident_doc ^^ brackets size_doc ^^ init_doc ^^ semi)
637+
typ_doc ^^ space ^^ ident_doc ^^ brackets size_doc ^^ init_doc ^^ semi ^^ hardline
636638
else empty)
637639
(Hashtbl.to_alist traced_store)
638640
in
@@ -647,6 +649,6 @@ module C_syntax (B : C_syntax_config) = struct
647649
^^ string "if (log_file) { fclose(log_file); log_file = NULL; }"
648650
^^ hardline;
649651

650-
let func_doc = surround 2 1 (func_header ^^ space ^^ lbrace) !body rbrace in
652+
let func_doc = func_header ^^ space ^^ lbrace ^^ nest 2 (hardline ^^ !body) ^^ hardline ^^ rbrace in
651653
(sorted_params, func_doc)
652654
end

0 commit comments

Comments
 (0)