Skip to content

Commit 657f596

Browse files
committed
Migrate Staged_compilation to PPrint.document
Signed-off-by: Lukasz Stafiniak <lukstafi@gmail.com>
1 parent dbcdf2e commit 657f596

File tree

5 files changed

+12
-12
lines changed

5 files changed

+12
-12
lines changed

arrayjit/lib/c_syntax.ml

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -356,11 +356,7 @@ module C_syntax (B : C_syntax_config) = struct
356356
~args_docs:[]
357357
else string "/* " ^^ string message ^^ string " */"
358358
| Staged_compilation callback ->
359-
(* This is tricky. PPrint needs to generate the document synchronously. We might need to
360-
change how Staged_compilation works if it needs to produce dynamic documents. For now,
361-
assume it produces no output. *)
362-
callback ();
363-
empty
359+
callback ()
364360
| Set_local ({ scope_id; tn = { prec; _ } }, value) ->
365361
let local_defs, value_doc = pp_float (Lazy.force prec) value in
366362
let assignment =

arrayjit/lib/gcc_backend.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -439,7 +439,7 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node;
439439
let value = loop_float ~name ~env ~num_typ local_prec llv in
440440
Block.assign !current_block lhs value
441441
| Comment c -> log_comment c
442-
| Staged_compilation exp -> exp ()
442+
| Staged_compilation _exp -> failwith "gccjit_backend: Staged_compilation not implemented"
443443
and loop_float ~name ~env ~num_typ prec v_code =
444444
let loop = loop_float ~name ~env ~num_typ prec in
445445
match v_code with

arrayjit/lib/low_level.ml

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ let get_scope =
3131
type t =
3232
| Noop
3333
| Comment of string
34-
| Staged_compilation of ((unit -> unit)[@equal.ignore] [@compare.ignore])
34+
| Staged_compilation of ((unit -> PPrint.document)[@equal.ignore] [@compare.ignore])
3535
| Seq of t * t
3636
| For_loop of { index : Indexing.symbol; from_ : int; to_ : int; body : t; trace_it : bool }
3737
| Zero_out of Tn.t
@@ -917,7 +917,9 @@ let fprint_cstyle ?name ?static_indices () ppf llc =
917917
(pp_float @@ Lazy.force p.tn.prec)
918918
p.llv
919919
| Comment message -> fprintf ppf "/* %s */" message
920-
| Staged_compilation _ -> fprintf ppf "STAGED_COMPILATION_CALLBACK()"
920+
| Staged_compilation callback ->
921+
let doc = callback () in
922+
PPrint.ToFormatter.pretty 1.0 100 ppf doc
921923
| Set_local (id, llv) ->
922924
fprintf ppf "@[<2>%a :=@ %a;@]" pp_local id (pp_float @@ Lazy.force id.tn.prec) llv
923925
and pp_float prec ppf value =
@@ -977,7 +979,9 @@ let fprint_hum ?name ?static_indices () ppf llc =
977979
p.debug <- asprintf "@[<2>%a[@,%a] :=@ %a;@]" pp_ident p.tn pp_indices p.idcs pp_float p.llv;
978980
fprintf ppf "@[<2>%a[@,%a] :=@ %a;@]" pp_ident p.tn pp_indices p.idcs pp_float p.llv
979981
| Comment message -> fprintf ppf "/* %s */" message
980-
| Staged_compilation _ -> fprintf ppf "STAGED_COMPILATION_CALLBACK()"
982+
| Staged_compilation callback ->
983+
let doc = callback () in
984+
PPrint.ToFormatter.pretty 1.0 100 ppf doc
981985
| Set_local (id, llv) -> fprintf ppf "@[<2>%a :=@ %a;@]" pp_local id pp_float llv
982986
and pp_float ppf value =
983987
match value with

arrayjit/lib/low_level.mli

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ type scope_id = Scope_id.t = { tn : Tnode.t; scope_id : int }
1818
type t =
1919
| Noop
2020
| Comment of string
21-
| Staged_compilation of (unit -> unit)
21+
| Staged_compilation of (unit -> PPrint.document)
2222
| Seq of t * t
2323
| For_loop of { index : Indexing.symbol; from_ : int; to_ : int; body : t; trace_it : bool }
2424
| Zero_out of Tnode.t

arrayjit/lib/lowering_and_inlining.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ The low-level representation is a C-like mini-language operating on scalars.
4747
type t =
4848
| Noop
4949
| Comment of string
50-
| Staged_compilation of (unit -> unit)
50+
| Staged_compilation of (unit -> PPrint.document)
5151
| Seq of t * t
5252
| For_loop of { index : Indexing.symbol; from_ : int; to_ : int; body : t; trace_it : bool }
5353
| Zero_out of Tnode.t
@@ -70,7 +70,7 @@ and float_t =
7070
| Embed_index of Indexing.axis_index
7171
```
7272

73-
The odd part is the `Staged_compilation` element. Backends with an imperative compilation procedure, e.g. using the `Stdlib.Format` module, can use `Staged_compilation` to embed some emitted code within on-the-fly generated `Low_level.t` code.
73+
The odd part is the `Staged_compilation` element. Backends can use `Staged_compilation` to embed some emitted code within on-the-fly generated `Low_level.t` code. Currently this works only for `PPrint.document` based backends like `C_syntax` derivatives, but this covers almost all backends.
7474

7575
TODO: flesh out explanation.
7676

0 commit comments

Comments
 (0)