Skip to content

Commit d61cc09

Browse files
committed
Gemini's take on Format -> PPrintf
1 parent 3442385 commit d61cc09

File tree

7 files changed

+664
-392
lines changed

7 files changed

+664
-392
lines changed

arrayjit/lib/c_syntax.ml

Lines changed: 484 additions & 256 deletions
Large diffs are not rendered by default.

arrayjit/lib/cc_backend.ml

Lines changed: 34 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ let%track7_sexp c_compile_and_load ~f_name =
8383
Stdlib.Gc.finalise finalize result;
8484
result
8585

86-
let%diagn_sexp compile ~(name : string) bindings (lowered : Low_level.optimized) =
86+
let%diagn_sexp compile ~(name : string) bindings (lowered : Low_level.optimized) : procedure =
8787
let module Syntax = C_syntax.C_syntax (C_syntax.Pure_C_config (struct
8888
type nonrec buffer_ptr = buffer_ptr
8989

@@ -92,14 +92,20 @@ let%diagn_sexp compile ~(name : string) bindings (lowered : Low_level.optimized)
9292
end)) in
9393
(* FIXME: do we really want all of them, or only the used ones? *)
9494
let idx_params = Indexing.bound_symbols bindings in
95-
let pp_file = Utils.pp_file ~base_name:name ~extension:".c" in
96-
Syntax.print_declarations pp_file.ppf;
97-
let params = Syntax.compile_proc ~name pp_file.ppf idx_params lowered in
98-
pp_file.finalize ();
99-
let result = c_compile_and_load ~f_name:pp_file.f_name in
100-
{ result; params; bindings; name }
101-
102-
let%diagn_sexp compile_batch ~names bindings (lowereds : Low_level.optimized option array) =
95+
let build_file = Utils.open_build_file ~base_name:name ~extension:".c" in
96+
let declarations_doc = Syntax.print_declarations () in
97+
let params, proc_doc = Syntax.compile_proc ~name idx_params lowered in
98+
let final_doc = PPrint.(declarations_doc ^^ proc_doc) in
99+
(* Use ribbon = 1.0 for usual code formatting, width 110 *)
100+
PPrint.ToChannel.pretty 1.0 110 build_file.oc final_doc;
101+
build_file.finalize ();
102+
(* let result = c_compile_and_load ~f_name:pp_file.f_name in *)
103+
104+
let result_library = c_compile_and_load ~f_name:build_file.f_name in
105+
{ result = result_library; params; bindings; name }
106+
107+
let%diagn_sexp compile_batch ~names bindings (lowereds : Low_level.optimized option array) :
108+
procedure option array =
103109
let module Syntax = C_syntax.C_syntax (C_syntax.Pure_C_config (struct
104110
type nonrec buffer_ptr = buffer_ptr
105111

@@ -113,19 +119,28 @@ let%diagn_sexp compile_batch ~names bindings (lowereds : Low_level.optimized opt
113119
strip ~drop:(equal_char '_')
114120
@@ common_prefix (Array.to_list @@ Array.concat_map ~f:Option.to_array names))
115121
in
116-
let pp_file = Utils.pp_file ~base_name ~extension:".c" in
117-
Syntax.print_declarations pp_file.ppf;
118-
let params =
119-
Array.mapi lowereds ~f:(fun i lowered ->
120-
Option.map2 names.(i) lowered ~f:(fun name lowered ->
121-
Syntax.compile_proc ~name pp_file.ppf idx_params lowered))
122+
let build_file = Utils.open_build_file ~base_name ~extension:".c" in
123+
let declarations_doc = Syntax.print_declarations () in
124+
let params_and_docs =
125+
Array.map2_exn names lowereds ~f:(fun name_opt lowered_opt ->
126+
Option.map2 name_opt lowered_opt ~f:(fun name lowered ->
127+
Syntax.compile_proc ~name idx_params lowered))
122128
in
123-
pp_file.finalize ();
124-
let result = c_compile_and_load ~f_name:pp_file.f_name in
129+
let all_proc_docs =
130+
List.filter_map (Array.to_list params_and_docs) ~f:(Option.map ~f:snd)
131+
in
132+
let final_doc = PPrint.(declarations_doc ^^ separate hardline all_proc_docs) in
133+
PPrint.ToChannel.pretty 1.0 110 build_file.oc final_doc;
134+
build_file.finalize ();
135+
let result_library = c_compile_and_load ~f_name:build_file.f_name in
125136
(* Note: for simplicity, we share ctx_arrays across all contexts. *)
126-
Array.mapi params ~f:(fun i params ->
137+
Array.mapi params_and_docs ~f:(fun i opt_params_and_doc ->
138+
Option.bind opt_params_and_doc ~f:(fun (params, _doc) ->
127139
Option.map names.(i) ~f:(fun name ->
128-
{ result; params = Option.value_exn ~here:[%here] params; bindings; name }))
140+
{ result = result_library; params; bindings; name }
141+
)
142+
)
143+
)
129144

130145
let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : procedure) =
131146
let name : string = code.name in

arrayjit/lib/cuda_backend.ml

Lines changed: 45 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -230,10 +230,9 @@ end) : Ir.Backend_impl.Lowered_backend = struct
230230
let%diagn2_sexp cuda_to_ptx ~name cu_src =
231231
let name_cu = name ^ ".cu" in
232232
if Utils.settings.output_debug_files_in_build_directory then (
233-
let oc = Out_channel.open_text @@ Utils.build_file name_cu in
234-
Stdio.Out_channel.output_string oc cu_src;
235-
Stdio.Out_channel.flush oc;
236-
Stdio.Out_channel.close oc);
233+
let build_file = Utils.open_build_file ~base_name:name ~extension:".cu" in
234+
Stdio.Out_channel.output_string build_file.oc cu_src; (* Keep direct string output for source *)
235+
build_file.finalize ());
237236
[%log "compiling to PTX"];
238237
let with_debug =
239238
Utils.settings.output_debug_files_in_build_directory || Utils.settings.log_level > 0
@@ -256,7 +255,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
256255
Stdio.Out_channel.close oc);
257256
ptx
258257

259-
module C_syntax_config (Input : sig
258+
module Cuda_syntax_config (Input : sig
260259
val procs : Low_level.optimized array
261260
end) =
262261
struct
@@ -283,8 +282,9 @@ end) : Ir.Backend_impl.Lowered_backend = struct
283282
| Void_prec -> "void"
284283

285284
let binop_syntax prec v =
286-
let f op_str = C_syntax.binop_adapter ("(", " " ^ op_str, ")") in
287-
let func fn = C_syntax.binop_adapter (fn ^ "(", ",", ")") in
285+
let open PPrint in
286+
let f op_str v1 v2 = group (lparen ^^ v1 ^^ space ^^ string op_str ^^ space ^^ v2 ^^ rparen) in
287+
let func fn v1 v2 = group (string fn ^^ parens (separate comma_sep [ v1; v2 ])) in
288288
match (v, prec) with
289289
| Ops.Arg1, _ -> invalid_arg "Cuda_backend.binop_syntax: Arg1 is not an operator"
290290
| Arg2, _ -> invalid_arg "Cuda_backend.binop_syntax: Arg2 is not an operator"
@@ -302,31 +302,17 @@ end) : Ir.Backend_impl.Lowered_backend = struct
302302
| ToPowOf, Half_prec _ -> C_syntax.binop_adapter ("hexp2(hlog2(", "),", ")")
303303
| ToPowOf, Byte_prec _ ->
304304
invalid_arg "Cuda_backend.binop_syntax: ToPowOf not supported for byte/integer precisions"
305-
| Relu_gate, Byte_prec _ -> C_syntax.binop_adapter ("(", " > 0 ?", " : 0)")
306-
| Relu_gate, Half_prec _ ->
307-
C_syntax.binop_adapter
308-
( "(__hgt(",
309-
", __ushort_as_half((unsigned short)0x0000U)) ?",
310-
" : __ushort_as_half((unsigned short)0x0000U))" )
311-
| Relu_gate, _ -> C_syntax.binop_adapter ("(", " > 0.0 ?", " : 0.0)")
312-
| Satur01_gate, Byte_prec _ ->
313-
fun ppf pp1 v1 pp2 v2 ->
314-
Stdlib.Format.fprintf ppf
315-
"(((float)%a > 0.0f && (float)%a < 1.0f) ? %a : (unsigned char)0)" pp1 v1 pp1 v1 pp2
316-
v2
317-
| Satur01_gate, Half_prec _ ->
318-
fun ppf pp1 v1 pp2 v2 ->
319-
Stdlib.Format.fprintf ppf
320-
"((__hgt(%a, __ushort_as_half((unsigned short)0x0000U)) && __hlt(%a, \
321-
__ushort_as_half((unsigned short)0x3C00U))) ? %a : __ushort_as_half((unsigned \
322-
short)0x0000U))"
323-
pp1 v1 pp1 v1 pp2 v2
324-
| Satur01_gate, Single_prec _ ->
325-
fun ppf pp1 v1 pp2 v2 ->
326-
Stdlib.Format.fprintf ppf "((%a > 0.0f && %a < 1.0f) ? %a : 0.0f)" pp1 v1 pp1 v1 pp2 v2
327-
| Satur01_gate, Double_prec _ ->
328-
fun ppf pp1 v1 pp2 v2 ->
329-
Stdlib.Format.fprintf ppf "((%a > 0.0 && %a < 1.0) ? %a : 0.0)" pp1 v1 pp1 v1 pp2 v2
305+
| Relu_gate, Byte_prec _ -> fun v1 v2 -> group (parens (v1 ^^ string " > 0") ^^ string " ? " ^^ v2 ^^ string " : 0")
306+
| Relu_gate, Half_prec _ -> fun v1 v2 -> group (parens (string "__hgt(" ^^ v1 ^^ comma ^^ string " __ushort_as_half((unsigned short)0x0000U))") ^^ string " ? " ^^ v2 ^^ string " : __ushort_as_half((unsigned short)0x0000U)")
307+
| Relu_gate, _ -> fun v1 v2 -> group (parens (v1 ^^ string " > 0.0") ^^ string " ? " ^^ v2 ^^ string " : 0.0")
308+
| Satur01_gate, Byte_prec _ -> fun v1 v2 ->
309+
parens (parens (string "(float)" ^^ v1 ^^ string " > 0.0f && (float)" ^^ v1 ^^ string " < 1.0f") ^^ string " ? " ^^ v2 ^^ string " : (unsigned char)0")
310+
| Satur01_gate, Half_prec _ -> fun v1 v2 ->
311+
parens (parens (string "__hgt(" ^^ v1 ^^ comma ^^ string " __ushort_as_half((unsigned short)0x0000U)) && __hlt(" ^^ v1 ^^ comma ^^ string " __ushort_as_half((unsigned short)0x3C00U)))") ^^ string " ? " ^^ v2 ^^ string " : __ushort_as_half((unsigned short)0x0000U)")
312+
| Satur01_gate, Single_prec _ -> fun v1 v2 ->
313+
parens (parens (v1 ^^ string " > 0.0f && " ^^ v1 ^^ string " < 1.0f") ^^ string " ? " ^^ v2 ^^ string " : 0.0f")
314+
| Satur01_gate, Double_prec _ -> fun v1 v2 ->
315+
parens (parens (v1 ^^ string " > 0.0 && " ^^ v1 ^^ string " < 1.0") ^^ string " ? " ^^ v2 ^^ string " : 0.0")
330316
| Max, Byte_prec _ -> func "max"
331317
| Max, Half_prec _ -> func "__hmax"
332318
| Max, Double_prec _ -> func "fmax"
@@ -344,8 +330,9 @@ end) : Ir.Backend_impl.Lowered_backend = struct
344330
| And, _ -> f "&&"
345331

346332
let unop_syntax prec v =
347-
let f prefix suffix = C_syntax.unop_adapter (prefix, suffix) in
348-
let func fn = C_syntax.unop_adapter (fn ^ "(", ")") in
333+
let open PPrint in
334+
let f prefix suffix expr = group (string prefix ^^ expr ^^ string suffix) in
335+
let func fn expr = group (string fn ^^ parens expr) in
349336
match (v, prec) with
350337
| Ops.Identity, _ -> f "" ""
351338
| Relu, Ops.Single_prec _ -> f "fmaxf(0.0, " ")"
@@ -392,22 +379,22 @@ end) : Ir.Backend_impl.Lowered_backend = struct
392379
| Recip_sqrt, Double_prec _ -> f "(1.0 / sqrt(" "))"
393380
| Recip_sqrt, _ -> f "(1.0 / sqrtf(" "))"
394381
| Neg, _ -> f "(-(" "))"
395-
| Tanh_approx, Byte_prec _ ->
396-
invalid_arg
397-
"Cuda_backend.unop_syntax: Tanh_approx not supported for byte/integer precisions"
382+
| Tanh_approx, Byte_prec _ -> invalid_arg "Cuda_backend.unop_syntax: Tanh_approx not supported for byte/integer precisions"
398383
| Tanh_approx, Half_prec _ -> func "htanh_approx"
399384
| Tanh_approx, Single_prec _ -> func "__tanhf"
400385
| Tanh_approx, _ -> func "tanh"
401386
| Not, _ -> f "(" " == 0.0 ? 1.0 : 0.0)"
402387

403388
let ternop_syntax prec v =
389+
let open PPrint in
404390
match (v, prec) with
405-
| Ops.Where, _ -> C_syntax.ternop_adapter ("(", " ?", " :", ")")
391+
| Ops.Where, _ -> fun v1 v2 v3 -> group (parens v1 ^^ string " ? " ^^ v2 ^^ string " : " ^^ v3)
406392
| FMA, Ops.Half_prec _ -> C_syntax.ternop_adapter ("__hfma(", ",", ",", ")")
407393
| FMA, Ops.Single_prec _ -> C_syntax.ternop_adapter ("fmaf(", ",", ",", ")")
408394
| FMA, _ -> C_syntax.ternop_adapter ("fma(", ",", ",", ")")
409395

410396
let convert_precision ~from ~to_ =
397+
let open PPrint in
411398
match (from, to_) with
412399
| Ops.Double_prec _, Ops.Double_prec _
413400
| Single_prec _, Single_prec _
@@ -421,43 +408,49 @@ end) : Ir.Backend_impl.Lowered_backend = struct
421408
| _ -> ("(" ^ typ_of_prec to_ ^ ")(", ")")
422409
end
423410

424-
let compile ~name bindings ({ Low_level.traced_store; _ } as lowered) =
411+
let%diagn2_sexp compile ~name bindings ({ Low_level.traced_store; _ } as lowered) =
425412
(* TODO: The following link seems to claim it's better to expand into loops than use memset.
426413
https://stackoverflow.com/questions/23712558/how-do-i-best-initialize-a-local-memory-array-to-0 *)
427-
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
414+
let module Syntax = C_syntax.C_syntax (Cuda_syntax_config (struct
428415
let procs = [| lowered |]
429416
end)) in
430417
let idx_params = Indexing.bound_symbols bindings in
431418
let b = Buffer.create 4096 in
432-
let ppf = Stdlib.Format.formatter_of_buffer b in
433419
if Utils.debug_log_from_routines () then
434-
Stdlib.Format.fprintf ppf "@,__device__ int printf (const char * format, ... );@,";
435-
Syntax.print_declarations ppf;
436-
let params = Syntax.compile_proc ~name ppf idx_params lowered in
437-
let ptx = cuda_to_ptx ~name @@ Buffer.contents b in
420+
Buffer.add_string b "__device__ int printf (const char * format, ... );\n";
421+
let declarations_doc = Syntax.print_declarations () in
422+
let params, proc_doc = Syntax.compile_proc ~name idx_params lowered in
423+
let final_doc = PPrint.(declarations_doc ^^ proc_doc) in
424+
PPrint.ToBuffer.pretty 1.0 110 b final_doc; (* Use ToBuffer *)
425+
let ptx = cuda_to_ptx ~name (Buffer.contents b) in
438426
{ traced_store; ptx; params; bindings; name }
439427

440-
let compile_batch ~names bindings lowereds =
441-
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
428+
let%diagn2_sexp compile_batch ~names bindings lowereds =
429+
let module Syntax = C_syntax.C_syntax (Cuda_syntax_config (struct
442430
let procs = Array.filter_opt lowereds
443431
end)) in
444432
let idx_params = Indexing.bound_symbols bindings in
445433
let b = Buffer.create 4096 in
446-
let ppf = Stdlib.Format.formatter_of_buffer b in
447-
Syntax.print_declarations ppf;
448-
let params_and_names =
434+
let declarations_doc = Syntax.print_declarations () in
435+
let params_and_docs =
449436
Array.map2_exn names lowereds
450437
~f:
451438
(Option.map2 ~f:(fun name lowered ->
452-
(Syntax.compile_proc ~name ppf idx_params lowered, name)))
439+
let params, doc = Syntax.compile_proc ~name idx_params lowered in
440+
((params, name), doc)))
453441
in
442+
let all_proc_docs = List.filter_map (Array.to_list params_and_docs) ~f:(Option.map ~f:snd) in
443+
let final_doc = PPrint.(declarations_doc ^^ separate hardline all_proc_docs) in
444+
PPrint.ToBuffer.pretty 1.0 110 b final_doc;
445+
454446
let name : string =
455447
String.(
456448
strip ~drop:(equal_char '_')
457449
@@ common_prefix (Array.to_list names |> List.concat_map ~f:Option.to_list))
458450
in
459-
let ptx = cuda_to_ptx ~name @@ Buffer.contents b in
451+
let ptx = cuda_to_ptx ~name (Buffer.contents b) in
460452
let traced_stores = Array.map lowereds ~f:(Option.map ~f:(fun l -> l.Low_level.traced_store)) in
453+
let params_and_names = Array.map params_and_docs ~f:(Option.map ~f:fst) in
461454
{ traced_stores; ptx; params_and_names; bindings }
462455

463456
let get_global_run_id =

arrayjit/lib/indexing.ml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,3 +214,24 @@ module Pp_helpers = struct
214214
let pp_indices ppf idcs =
215215
Stdlib.Format.pp_print_list ~pp_sep:pp_comma pp_axis_index ppf @@ Array.to_list idcs
216216
end
217+
218+
module Doc_helpers = struct
219+
let ( ^^ ) = PPrint.( ^^ )
220+
let ( !^ ) = PPrint.( !^ )
221+
let int = PPrint.OCaml.int
222+
let comma_sep = PPrint.(comma ^^ space)
223+
let pp_comma () = comma_sep
224+
let pp_symbol sym = PPrint.string @@ symbol_ident sym
225+
226+
let pp_static_symbol { static_symbol; static_range } =
227+
match static_range with
228+
| None -> pp_symbol static_symbol
229+
| Some range ->
230+
PPrint.infix 4 1 PPrint.colon (pp_symbol static_symbol)
231+
(PPrint.brackets (PPrint.string "0.." ^^ int (range - 1)))
232+
233+
let pp_axis_index idx =
234+
match idx with Iterator sym -> pp_symbol sym | Fixed_idx i -> PPrint.OCaml.int i
235+
236+
let pp_indices idcs = PPrint.separate (pp_comma ()) (Array.to_list idcs |> List.map ~f:pp_axis_index)
237+
end

0 commit comments

Comments
 (0)