Skip to content

Commit 9afb61d

Browse files
committed
In progress / broken: Format -> PPrint migration first pass by Claude
Signed-off-by: Lukasz Stafiniak <lukstafi@gmail.com>
1 parent 657f596 commit 9afb61d

File tree

8 files changed

+956
-287
lines changed

8 files changed

+956
-287
lines changed

arrayjit/lib/assignments.ml

Lines changed: 153 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -282,70 +282,185 @@ let get_ident_within_code ?no_dots c =
282282
let fprint_hum ?name ?static_indices () ppf c =
283283
let ident = get_ident_within_code c in
284284
let buffer_ident = function Node tn -> ident tn | Merge_buffer tn -> ident tn ^ ".merge" in
285-
let open Stdlib.Format in
286-
let out_fetch_op ppf (op : fetch_op) =
285+
286+
let open PPrint in
287+
let doc_of_fetch_op (op : fetch_op) =
287288
match op with
288-
| Constant f -> fprintf ppf "%g" f
289-
| Imported (Ops.C_function c) -> fprintf ppf "%s()" c
289+
| Constant f -> string (Float.to_string f)
290+
| Imported (Ops.C_function c) -> string (c ^ "()")
290291
| Imported (Merge_buffer { source_node_id }) ->
291292
let tn = Option.value_exn ~here:[%here] @@ Tn.find ~id:source_node_id in
292-
fprintf ppf "%s.merge" (ident tn)
293+
string (ident tn ^ ".merge")
293294
| Imported (Ops.External_unsafe { ptr; prec; dims = _ }) ->
294-
fprintf ppf "%s" @@ Ops.ptr_to_string_hum ptr prec
295+
string (Ops.ptr_to_string_hum ptr prec)
295296
| Slice { batch_idx; sliced } ->
296-
fprintf ppf "%s @@| %s" (ident sliced) (Indexing.symbol_ident batch_idx.static_symbol)
297+
string (ident sliced ^ " @| " ^ Indexing.symbol_ident batch_idx.static_symbol)
297298
| Embed_symbol { static_symbol; static_range = _ } ->
298-
fprintf ppf "!@@%s" @@ Indexing.symbol_ident static_symbol
299+
string ("!@" ^ Indexing.symbol_ident static_symbol)
299300
in
300-
let rec loop = function
301-
| Noop -> ()
301+
302+
let rec doc_of_code = function
303+
| Noop -> empty
302304
| Seq (c1, c2) ->
303-
loop c1;
304-
loop c2
305-
| Block_comment (s, Noop) -> fprintf ppf "# \"%s\";@ " s
305+
doc_of_code c1 ^^ doc_of_code c2
306+
| Block_comment (s, Noop) -> string ("# \"" ^ s ^ "\";") ^^ break 1
306307
| Block_comment (s, c) ->
307-
fprintf ppf "# \"%s\";@ " s;
308-
loop c
308+
string ("# \"" ^ s ^ "\";") ^^ break 1 ^^ doc_of_code c
309309
| Accum_ternop { initialize_neutral; accum; op; lhs; rhs1; rhs2; rhs3; projections } ->
310310
let proj_spec =
311311
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec
312312
else "<not-in-yet>"
313313
in
314314
(* Uncurried syntax for ternary operations. *)
315-
fprintf ppf "%s %s %s(%s, %s, %s)%s;@ " (ident lhs)
316-
(Ops.assign_op_cd_syntax ~initialize_neutral accum)
317-
(Ops.ternop_cd_syntax op) (buffer_ident rhs1) (buffer_ident rhs2) (buffer_ident rhs3)
318-
(if not (String.equal proj_spec ".") then " ~logic:\"" ^ proj_spec ^ "\"" else "")
315+
string (ident lhs) ^^ space ^^
316+
string (Ops.assign_op_cd_syntax ~initialize_neutral accum) ^^ space ^^
317+
string (Ops.ternop_cd_syntax op) ^^
318+
string "(" ^^ string (buffer_ident rhs1) ^^ string ", " ^^
319+
string (buffer_ident rhs2) ^^ string ", " ^^
320+
string (buffer_ident rhs3) ^^ string ")" ^^
321+
(if not (String.equal proj_spec ".") then
322+
string (" ~logic:\"" ^ proj_spec ^ "\"")
323+
else empty) ^^
324+
string ";" ^^ break 1
319325
| Accum_binop { initialize_neutral; accum; op; lhs; rhs1; rhs2; projections } ->
320326
let proj_spec =
321327
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec
322328
else "<not-in-yet>"
323329
in
324-
fprintf ppf "%s %s %s %s %s%s;@ " (ident lhs)
325-
(Ops.assign_op_cd_syntax ~initialize_neutral accum)
326-
(buffer_ident rhs1) (Ops.binop_cd_syntax op) (buffer_ident rhs2)
327-
(if
328-
(not (String.equal proj_spec "."))
329-
|| List.mem ~equal:Ops.equal_binop Ops.[ Mul; Div ] op
330-
then " ~logic:\"" ^ proj_spec ^ "\""
331-
else "")
330+
string (ident lhs) ^^ space ^^
331+
string (Ops.assign_op_cd_syntax ~initialize_neutral accum) ^^ space ^^
332+
string (buffer_ident rhs1) ^^ space ^^
333+
string (Ops.binop_cd_syntax op) ^^ space ^^
334+
string (buffer_ident rhs2) ^^
335+
(if (not (String.equal proj_spec ".")) ||
336+
List.mem ~equal:Ops.equal_binop Ops.[ Mul; Div ] op
337+
then string (" ~logic:\"" ^ proj_spec ^ "\"")
338+
else empty) ^^
339+
string ";" ^^ break 1
332340
| Accum_unop { initialize_neutral; accum; op; lhs; rhs; projections } ->
333341
let proj_spec =
334342
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec
335343
else "<not-in-yet>"
336344
in
337-
fprintf ppf "%s %s %s%s%s;@ " (ident lhs)
338-
(Ops.assign_op_cd_syntax ~initialize_neutral accum)
339-
(if not @@ Ops.equal_unop op Ops.Identity then Ops.unop_cd_syntax op ^ " " else "")
340-
(buffer_ident rhs)
341-
(if not (String.equal proj_spec ".") then " ~logic:\"" ^ proj_spec ^ "\"" else "")
345+
string (ident lhs) ^^ space ^^
346+
string (Ops.assign_op_cd_syntax ~initialize_neutral accum) ^^ space ^^
347+
(if not @@ Ops.equal_unop op Ops.Identity then
348+
string (Ops.unop_cd_syntax op ^ " ")
349+
else empty) ^^
350+
string (buffer_ident rhs) ^^
351+
(if not (String.equal proj_spec ".") then
352+
string (" ~logic:\"" ^ proj_spec ^ "\"")
353+
else empty) ^^
354+
string ";" ^^ break 1
342355
| Fetch { array; fetch_op; dims = _ } ->
343-
fprintf ppf "%s := %a;@ " (ident array) out_fetch_op fetch_op
356+
string (ident array) ^^ string " := " ^^ doc_of_fetch_op fetch_op ^^ string ";" ^^ break 1
344357
in
345-
fprintf ppf "@,@[<v 2>";
346-
Low_level.fprint_function_header ?name ?static_indices () ppf;
347-
loop c;
348-
fprintf ppf "@]"
358+
359+
(* Create the header document using Low_level.fprint_function_header which will be converted later *)
360+
let header_doc =
361+
match name, static_indices with
362+
| Some n, Some si ->
363+
string (n ^ " (") ^^
364+
separate (comma ^^ space)
365+
(List.map si ~f:Indexing.Doc_helpers.pp_static_symbol) ^^
366+
string "):" ^^ space
367+
| Some n, None -> string (n ^ ":") ^^ space
368+
| _ -> empty
369+
in
370+
371+
let doc = header_doc ^^ nest 2 (doc_of_code c) in
372+
ToFormatter.pretty 1.0 80 ppf doc
373+
374+
let doc_hum ?name ?static_indices () c =
375+
let ident = get_ident_within_code c in
376+
let buffer_ident = function Node tn -> ident tn | Merge_buffer tn -> ident tn ^ ".merge" in
377+
378+
let open PPrint in
379+
let doc_of_fetch_op (op : fetch_op) =
380+
match op with
381+
| Constant f -> string (Float.to_string f)
382+
| Imported (Ops.C_function c) -> string (c ^ "()")
383+
| Imported (Merge_buffer { source_node_id }) ->
384+
let tn = Option.value_exn ~here:[%here] @@ Tn.find ~id:source_node_id in
385+
string (ident tn ^ ".merge")
386+
| Imported (Ops.External_unsafe { ptr; prec; dims = _ }) ->
387+
string (Ops.ptr_to_string_hum ptr prec)
388+
| Slice { batch_idx; sliced } ->
389+
string (ident sliced ^ " @| " ^ Indexing.symbol_ident batch_idx.static_symbol)
390+
| Embed_symbol { static_symbol; static_range = _ } ->
391+
string ("!@" ^ Indexing.symbol_ident static_symbol)
392+
in
393+
394+
let rec doc_of_code = function
395+
| Noop -> empty
396+
| Seq (c1, c2) ->
397+
doc_of_code c1 ^^ doc_of_code c2
398+
| Block_comment (s, Noop) -> string ("# \"" ^ s ^ "\";") ^^ break 1
399+
| Block_comment (s, c) ->
400+
string ("# \"" ^ s ^ "\";") ^^ break 1 ^^ doc_of_code c
401+
| Accum_ternop { initialize_neutral; accum; op; lhs; rhs1; rhs2; rhs3; projections } ->
402+
let proj_spec =
403+
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec
404+
else "<not-in-yet>"
405+
in
406+
(* Uncurried syntax for ternary operations. *)
407+
string (ident lhs) ^^ space ^^
408+
string (Ops.assign_op_cd_syntax ~initialize_neutral accum) ^^ space ^^
409+
string (Ops.ternop_cd_syntax op) ^^
410+
string "(" ^^ string (buffer_ident rhs1) ^^ string ", " ^^
411+
string (buffer_ident rhs2) ^^ string ", " ^^
412+
string (buffer_ident rhs3) ^^ string ")" ^^
413+
(if not (String.equal proj_spec ".") then
414+
string (" ~logic:\"" ^ proj_spec ^ "\"")
415+
else empty) ^^
416+
string ";" ^^ break 1
417+
| Accum_binop { initialize_neutral; accum; op; lhs; rhs1; rhs2; projections } ->
418+
let proj_spec =
419+
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec
420+
else "<not-in-yet>"
421+
in
422+
string (ident lhs) ^^ space ^^
423+
string (Ops.assign_op_cd_syntax ~initialize_neutral accum) ^^ space ^^
424+
string (buffer_ident rhs1) ^^ space ^^
425+
string (Ops.binop_cd_syntax op) ^^ space ^^
426+
string (buffer_ident rhs2) ^^
427+
(if (not (String.equal proj_spec ".")) ||
428+
List.mem ~equal:Ops.equal_binop Ops.[ Mul; Div ] op
429+
then string (" ~logic:\"" ^ proj_spec ^ "\"")
430+
else empty) ^^
431+
string ";" ^^ break 1
432+
| Accum_unop { initialize_neutral; accum; op; lhs; rhs; projections } ->
433+
let proj_spec =
434+
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec
435+
else "<not-in-yet>"
436+
in
437+
string (ident lhs) ^^ space ^^
438+
string (Ops.assign_op_cd_syntax ~initialize_neutral accum) ^^ space ^^
439+
(if not @@ Ops.equal_unop op Ops.Identity then
440+
string (Ops.unop_cd_syntax op ^ " ")
441+
else empty) ^^
442+
string (buffer_ident rhs) ^^
443+
(if not (String.equal proj_spec ".") then
444+
string (" ~logic:\"" ^ proj_spec ^ "\"")
445+
else empty) ^^
446+
string ";" ^^ break 1
447+
| Fetch { array; fetch_op; dims = _ } ->
448+
string (ident array) ^^ string " := " ^^ doc_of_fetch_op fetch_op ^^ string ";" ^^ break 1
449+
in
450+
451+
(* Create the header document *)
452+
let header_doc =
453+
match name, static_indices with
454+
| Some n, Some si ->
455+
string (n ^ " (") ^^
456+
separate (comma ^^ space)
457+
(List.map si ~f:Indexing.Doc_helpers.pp_static_symbol) ^^
458+
string "):" ^^ space
459+
| Some n, None -> string (n ^ ":") ^^ space
460+
| _ -> empty
461+
in
462+
463+
header_doc ^^ nest 2 (doc_of_code c)
349464

350465
let%track6_sexp lower ~unoptim_ll_source ~ll_source ~cd_source ~name static_indices (proc : t) :
351466
Low_level.optimized =
@@ -354,6 +469,5 @@ let%track6_sexp lower ~unoptim_ll_source ~ll_source ~cd_source ~name static_indi
354469
(match cd_source with
355470
| None -> ()
356471
| Some ppf ->
357-
fprint_hum ~name ~static_indices () ppf proc;
358-
Stdlib.Format.pp_print_flush ppf ());
472+
fprint_hum ~name ~static_indices () ppf proc);
359473
Low_level.optimize ~unoptim_ll_source ~ll_source ~name static_indices llc

arrayjit/lib/dune

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
(preprocessor_deps
55
(env_var OCANNL_LOG_LEVEL))
66
(modules utils)
7-
(libraries base stdio ppx_minidebug.runtime)
7+
(libraries base stdio pprint ppx_minidebug.runtime)
88
(preprocess
99
(pps
1010
ppx_compare

arrayjit/lib/indexing.ml

Lines changed: 18 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -198,21 +198,26 @@ let derive_index ~product_syms ~(projection : axis_index array) =
198198
fun ~product -> Array.map positions ~f:(function First p -> product.(p) | Second it -> it)
199199

200200
module Pp_helpers = struct
201-
let pp_comma ppf () = Stdlib.Format.fprintf ppf ",@ "
202-
let pp_symbol ppf sym = Stdlib.Format.fprintf ppf "%s" @@ symbol_ident sym
201+
open PPrint
202+
203+
let pp_comma () = comma ^^ space
204+
let pp_symbol sym = string (symbol_ident sym)
203205

204-
let pp_static_symbol ppf { static_symbol; static_range } =
206+
let pp_static_symbol { static_symbol; static_range } =
205207
match static_range with
206-
| None -> pp_symbol ppf static_symbol
207-
| Some range -> Stdlib.Format.fprintf ppf "%a : [0..%d]" pp_symbol static_symbol (range - 1)
208-
209-
let pp_axis_index ppf idx =
210-
match idx with
211-
| Iterator sym -> pp_symbol ppf sym
212-
| Fixed_idx i -> Stdlib.Format.fprintf ppf "%d" i
213-
214-
let pp_indices ppf idcs =
215-
Stdlib.Format.pp_print_list ~pp_sep:pp_comma pp_axis_index ppf @@ Array.to_list idcs
208+
| None -> pp_symbol static_symbol
209+
| Some range ->
210+
infix 4 1 colon (pp_symbol static_symbol)
211+
(brackets (string "0.." ^^ OCaml.int (range - 1)))
212+
213+
let pp_axis_index = function
214+
| Iterator sym -> pp_symbol sym
215+
| Fixed_idx i -> OCaml.int i
216+
217+
let pp_indices idcs =
218+
separate (pp_comma ()) (Array.to_list idcs |> List.map ~f:pp_axis_index)
219+
220+
let print ppf doc = ToFormatter.pretty 1.0 80 ppf doc
216221
end
217222

218223
module Doc_helpers = struct

0 commit comments

Comments
 (0)