Skip to content

Commit c63a596

Browse files
committed
Refactoring of Assignments.t to avoid field duplication and to provide projections debug without forcing projections, by Claude Code
Prompt 1: I just started refactoring Assignments.t in arrayjit/lib/assignments.ml to avoid duplicating lots of fields, and I'm adding the field projections_debug that I want to contain the `debug_info` field of `Indexing.projections`, or its approximation, even before the projections are actually computed. This will require coordination with the syntax extension ppx_cd.ml . I attach the relevant files in the context. Update the code to take the changed type definitions into account, and set the `projections_debug` fields from the projections specs. Point out where finding `projections_debug` without calculating the projections is tricky, I can help. Summary 1, by Claude: ## Summary I've successfully updated the code to use your refactored `Assignments.t` type. Here's what was done: 1. **Updated `assignments.ml`** to use the new consolidated `Accum_op` constructor that contains an `accum_rhs` field (which can be `Unop`, `Binop`, or `Ternop`) instead of the separate `Accum_unop`, `Accum_binop`, and `Accum_ternop` constructors. 2. **Added the `projections_debug` field** to both `Accum_op` and `Set_vec_unop` constructors to store the projection specification string. 3. **Updated `tensor.ml`** to set the `projections_debug` field by using `Shape.logic_to_spec shape_logic` to convert the shape logic to its string representation. 4. **Updated `ppx_cd.ml`** to generate code using the new constructor. For cases where projections are computed lazily (no explicit `~logic:` parameter), I set `projections_debug` to `"<derived>"` as a placeholder. 5. **Exported `Shape.logic_to_spec`** in the `shape.mli` interface file so it can be used from `tensor.ml`. ### Tricky parts where finding `projections_debug` without calculating projections is challenging: 1. **In `ppx_cd.ml` when no explicit `~logic:` is provided**: In these cases, the projections are computed lazily based on shape inference. We don't have the actual projection spec string at assignment creation time, so I used `"<derived>"` as a placeholder. If you need the actual spec in these cases, you might need to either: - Extract it from the lazy projections when they're forced - Compute a preliminary spec based on the operator type (most operators default to "." for pointwise) - Store additional information in the assignment to compute the spec later 2. **For vector operations**: The `Set_vec_unop` operations also use projections, and similarly we set `projections_debug` to `"<derived>"` when created through ppx_cd without explicit logic. The code now compiles successfully with `dune build @check`. Prompt 2: How difficult would it be to change `type projections` in the Tensor module from `Ir.Indexing.Projections` to `string * Ir.Indexing.Projections Lazy.t` and then skip `Lazy.t` at use sites, so that we can pass the debug together with the delayed projections? The debug would be computed in the `op` function that starts at line 206 of tensor.ml , using `Shape.logic_to_spec` on the `logic` field of `List.hd_exn local_shape_updates`. Prompt 3: Let's clarify what we want so that we're on the same page. We want benefit (3). We don't want (1) and it's not really true: the passing of the pair is so that the content of `projections_debug` is readily available, in fact we can define: `type projections = {projections_debug : string; projections : Ir.Indexing.projections Lazy.t}` in the Tensor module. The question I hadn't considered is which type should the `op_asn` and `grad_asn` functions be using. It should be `projections` from Tensor rather than `Indexing.projections Lazy.t` because that's the only way that `ppx_cd.ml` can unpack the two fields to populate inside Accum_op and Set_vec_unop . Let's do this, so that the debug information becomes precise. Summary 2, by Claude: ## Summary I've successfully implemented the change to use a `projections` record type in the Tensor module that contains both the debug string and the lazy projections. Here's what was done: 1. **Defined a new projections type** in `tensor.ml` and `tensor.mli`: ```ocaml type projections = { projections_debug : string; projections : Ir.Indexing.projections Lazy.t; } ``` 2. **Updated projection creation sites** in `tensor.ml` to create the record with the debug string computed using `Shape.logic_to_spec`: - In `raw_binop`, `raw_ternop`, `raw_unop`, and `op` functions - The debug string is computed at projection creation time from the shape logic 3. **Updated op_asn and grad_asn function signatures** to use `Tensor.projections` instead of `Ir.Indexing.projections Lazy.t` 4. **Updated ppx_cd.ml** to handle the new structure: - When projections are passed explicitly (`~projections`), extract fields as `projections.Tensor.projections` and `projections.Tensor.projections_debug` - When creating lazy projections, reference `projections.Tensor.projections` in the lazy computation - Use the extracted fields when creating assignments 5. **Fixed usage sites** that needed to access the projections field The key benefit is that now the projection debug information is **always precise** - it's computed from the actual shape logic at the point where projections are created, rather than using placeholders like `"<derived>"`. This debug information travels together with the projections throughout the system, making it available wherever projections are used. ```ocaml type projections = { projections_debug : string; projections : Ir.Indexing.projections Lazy.t; } ```
1 parent 23ff973 commit c63a596

File tree

7 files changed

+176
-166
lines changed

7 files changed

+176
-166
lines changed

arrayjit/lib/assignments.ml

Lines changed: 92 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -36,42 +36,30 @@ type fetch_op =
3636
| Embed_self_id (** Embeds the id of the [array] field of the [Fetch] constructor. *)
3737
[@@deriving sexp_of, equal]
3838

39-
and t =
39+
type accum_rhs =
40+
| Ternop of { op : Ops.ternop; rhs1 : buffer; rhs2 : buffer; rhs3 : buffer }
41+
| Binop of { op : Ops.binop; rhs1 : buffer; rhs2 : buffer }
42+
| Unop of { op : Ops.unop; rhs : buffer }
43+
[@@deriving sexp_of, equal]
44+
45+
type t =
4046
| Noop
4147
| Seq of t * t
4248
| Block_comment of string * t (** Same as the given code, with a comment. *)
43-
| Accum_ternop of {
44-
initialize_neutral : bool;
45-
accum : Ops.binop;
46-
op : Ops.ternop;
47-
lhs : Tn.t;
48-
rhs1 : buffer;
49-
rhs2 : buffer;
50-
rhs3 : buffer;
51-
projections : Indexing.projections Lazy.t;
52-
}
53-
| Accum_binop of {
54-
initialize_neutral : bool;
55-
accum : Ops.binop;
56-
op : Ops.binop;
57-
lhs : Tn.t;
58-
rhs1 : buffer;
59-
rhs2 : buffer;
60-
projections : Indexing.projections Lazy.t;
61-
}
62-
| Accum_unop of {
49+
| Accum_op of {
6350
initialize_neutral : bool;
6451
accum : Ops.binop;
65-
op : Ops.unop;
6652
lhs : Tn.t;
67-
rhs : buffer;
53+
rhs : accum_rhs;
6854
projections : Indexing.projections Lazy.t;
55+
projections_debug : string;
6956
}
7057
| Set_vec_unop of {
7158
op : Ops.vec_unop;
7259
lhs : Tn.t;
7360
rhs : buffer;
7461
projections : Indexing.projections Lazy.t;
62+
projections_debug : string;
7563
}
7664
| Fetch of { array : Tn.t; fetch_op : fetch_op; dims : int array Lazy.t }
7765
[@@deriving sexp_of]
@@ -122,12 +110,14 @@ let%debug3_sexp context_nodes ~(use_host_memory : 'a option) (asgns : t) : Tn.t_
122110
| Noop -> empty
123111
| Seq (t1, t2) -> loop t1 + loop t2
124112
| Block_comment (_, t) -> loop t
125-
| Accum_unop { lhs; rhs; _ } -> Set.union (one lhs) (of_node rhs)
113+
| Accum_op { lhs; rhs; _ } ->
114+
let rhses = match rhs with
115+
| Unop { rhs; _ } -> [ of_node rhs ]
116+
| Binop { rhs1; rhs2; _ } -> [ of_node rhs1; of_node rhs2 ]
117+
| Ternop { rhs1; rhs2; rhs3; _ } -> [ of_node rhs1; of_node rhs2; of_node rhs3 ]
118+
in
119+
Set.union_list (module Tn) (one lhs :: rhses)
126120
| Set_vec_unop { lhs; rhs; _ } -> Set.union (one lhs) (of_node rhs)
127-
| Accum_binop { lhs; rhs1; rhs2; _ } ->
128-
Set.union_list (module Tn) [ one lhs; of_node rhs1; of_node rhs2 ]
129-
| Accum_ternop { lhs; rhs1; rhs2; rhs3; _ } ->
130-
Set.union_list (module Tn) [ one lhs; of_node rhs1; of_node rhs2; of_node rhs3 ]
131121
| Fetch { array; _ } -> one array
132122
in
133123
loop asgns
@@ -145,11 +135,14 @@ let%debug3_sexp guess_output_nodes (asgns : t) : Tn.t_set =
145135
let i2, o2 = loop t2 in
146136
(i1 + i2, o1 + o2 - (i1 + i2))
147137
| Block_comment (_, t) -> loop t
148-
| Accum_unop { lhs; rhs; _ } -> (of_node rhs, one lhs)
138+
| Accum_op { lhs; rhs; _ } ->
139+
let inputs = match rhs with
140+
| Unop { rhs; _ } -> of_node rhs
141+
| Binop { rhs1; rhs2; _ } -> of_node rhs1 + of_node rhs2
142+
| Ternop { rhs1; rhs2; rhs3; _ } -> of_node rhs1 + of_node rhs2 + of_node rhs3
143+
in
144+
(inputs, one lhs)
149145
| Set_vec_unop { lhs; rhs; _ } -> (of_node rhs, one lhs)
150-
| Accum_binop { lhs; rhs1; rhs2; _ } -> (of_node rhs1 + of_node rhs2, one lhs)
151-
| Accum_ternop { lhs; rhs1; rhs2; rhs3; _ } ->
152-
(of_node rhs1 + of_node rhs2 + of_node rhs3, one lhs)
153146
| Fetch { array; _ } -> (empty, one array)
154147
in
155148
snd @@ loop asgns
@@ -256,15 +249,14 @@ let%track4_sexp to_low_level code =
256249
else for_loops
257250
and loop (code : t) : Low_level.t =
258251
match code with
259-
| Accum_ternop { initialize_neutral; accum; op; lhs; rhs1; rhs2; rhs3; projections } ->
260-
loop_accum ~initialize_neutral ~accum ~op:(Ops.Ternop op) ~lhs ~rhses:[| rhs1; rhs2; rhs3 |]
261-
projections
262-
| Accum_binop { initialize_neutral; accum; op; lhs; rhs1; rhs2; projections } ->
263-
loop_accum ~initialize_neutral ~accum ~op:(Ops.Binop op) ~lhs ~rhses:[| rhs1; rhs2 |]
264-
projections
265-
| Accum_unop { initialize_neutral; accum; op; lhs; rhs; projections } ->
266-
loop_accum ~initialize_neutral ~accum ~op:(Ops.Unop op) ~lhs ~rhses:[| rhs |] projections
267-
| Set_vec_unop { op; lhs; rhs; projections } ->
252+
| Accum_op { initialize_neutral; accum; lhs; rhs; projections; _ } ->
253+
let op, rhses = match rhs with
254+
| Unop { op; rhs } -> (Ops.Unop op, [| rhs |])
255+
| Binop { op; rhs1; rhs2 } -> (Ops.Binop op, [| rhs1; rhs2 |])
256+
| Ternop { op; rhs1; rhs2; rhs3 } -> (Ops.Ternop op, [| rhs1; rhs2; rhs3 |])
257+
in
258+
loop_accum ~initialize_neutral ~accum ~op ~lhs ~rhses projections
259+
| Set_vec_unop { op; lhs; rhs; projections; _ } ->
268260
(* Handle vector unary operations *)
269261
let projections = Lazy.force projections in
270262
let basecase rev_iters =
@@ -370,7 +362,7 @@ let flatten c =
370362
| Noop -> []
371363
| Seq (c1, c2) -> loop c1 @ loop c2
372364
| Block_comment (s, c) -> Block_comment (s, Noop) :: loop c
373-
| (Accum_ternop _ | Accum_binop _ | Accum_unop _ | Set_vec_unop _ | Fetch _) as c -> [ c ]
365+
| (Accum_op _ | Set_vec_unop _ | Fetch _) as c -> [ c ]
374366
in
375367
loop c
376368

@@ -397,14 +389,14 @@ let get_ident_within_code ?no_dots c =
397389
loop c1;
398390
loop c2
399391
| Block_comment (_, c) -> loop c
400-
| Accum_ternop
401-
{ initialize_neutral = _; accum = _; op = _; lhs; rhs1; rhs2; rhs3; projections = _ } ->
402-
List.iter ~f:visit [ lhs; tn rhs1; tn rhs2; tn rhs3 ]
403-
| Accum_binop { initialize_neutral = _; accum = _; op = _; lhs; rhs1; rhs2; projections = _ } ->
404-
List.iter ~f:visit [ lhs; tn rhs1; tn rhs2 ]
405-
| Accum_unop { initialize_neutral = _; accum = _; op = _; lhs; rhs; projections = _ } ->
406-
List.iter ~f:visit [ lhs; tn rhs ]
407-
| Set_vec_unop { op = _; lhs; rhs; projections = _ } -> List.iter ~f:visit [ lhs; tn rhs ]
392+
| Accum_op { lhs; rhs; _ } ->
393+
let rhses = match rhs with
394+
| Unop { rhs; _ } -> [ tn rhs ]
395+
| Binop { rhs1; rhs2; _ } -> [ tn rhs1; tn rhs2 ]
396+
| Ternop { rhs1; rhs2; rhs3; _ } -> [ tn rhs1; tn rhs2; tn rhs3 ]
397+
in
398+
List.iter ~f:visit (lhs :: rhses)
399+
| Set_vec_unop { op = _; lhs; rhs; projections = _; projections_debug = _ } -> List.iter ~f:visit [ lhs; tn rhs ]
408400
| Fetch { array; fetch_op = _; dims = _ } -> visit array
409401
in
410402
loop c;
@@ -445,66 +437,55 @@ let to_doc ?name ?static_indices () c =
445437
| Seq (c1, c2) -> doc_of_code c1 ^^ doc_of_code c2
446438
| Block_comment (s, Noop) -> string ("# \"" ^ s ^ "\";") ^^ break 1
447439
| Block_comment (s, c) -> string ("# \"" ^ s ^ "\";") ^^ break 1 ^^ doc_of_code c
448-
| Accum_ternop { initialize_neutral; accum; op; lhs; rhs1; rhs2; rhs3; projections } ->
449-
let proj_spec =
450-
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec
451-
else "<not-in-yet>"
452-
in
453-
(* Uncurried syntax for ternary operations. *)
454-
string (ident lhs)
455-
^^ space
456-
^^ string (Ops.assign_op_cd_syntax ~initialize_neutral accum)
457-
^^ space
458-
^^ string (Ops.ternop_cd_syntax op)
459-
^^ string "("
460-
^^ string (buffer_ident rhs1)
461-
^^ string ", "
462-
^^ string (buffer_ident rhs2)
463-
^^ string ", "
464-
^^ string (buffer_ident rhs3)
465-
^^ string ")"
466-
^^ (if not (String.equal proj_spec ".") then string (" ~logic:\"" ^ proj_spec ^ "\"")
467-
else empty)
468-
^^ string ";" ^^ break 1
469-
| Accum_binop { initialize_neutral; accum; op; lhs; rhs1; rhs2; projections } ->
470-
let proj_spec =
471-
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec
472-
else "<not-in-yet>"
473-
in
474-
string (ident lhs)
475-
^^ space
476-
^^ string (Ops.assign_op_cd_syntax ~initialize_neutral accum)
477-
^^ space
478-
^^ string (buffer_ident rhs1)
479-
^^ space
480-
^^ string (Ops.binop_cd_syntax op)
481-
^^ space
482-
^^ string (buffer_ident rhs2)
483-
^^ (if
484-
(not (String.equal proj_spec "."))
485-
|| List.mem ~equal:Ops.equal_binop Ops.[ Mul; Div ] op
486-
then string (" ~logic:\"" ^ proj_spec ^ "\"")
487-
else empty)
488-
^^ string ";" ^^ break 1
489-
| Accum_unop { initialize_neutral; accum; op; lhs; rhs; projections } ->
490-
let proj_spec =
491-
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec
492-
else "<not-in-yet>"
493-
in
494-
string (ident lhs)
495-
^^ space
496-
^^ string (Ops.assign_op_cd_syntax ~initialize_neutral accum)
497-
^^ space
498-
^^ (if not @@ Ops.equal_unop op Ops.Identity then string (Ops.unop_cd_syntax op ^ " ")
499-
else empty)
500-
^^ string (buffer_ident rhs)
501-
^^ (if not (String.equal proj_spec ".") then string (" ~logic:\"" ^ proj_spec ^ "\"")
502-
else empty)
503-
^^ string ";" ^^ break 1
504-
| Set_vec_unop { op; lhs; rhs; projections } ->
505-
let proj_spec =
506-
if Lazy.is_val projections then (Lazy.force projections).debug_info.spec
507-
else "<not-in-yet>"
440+
| Accum_op { initialize_neutral; accum; lhs; rhs; projections_debug; _ } ->
441+
let proj_spec = projections_debug in
442+
(match rhs with
443+
| Ternop { op; rhs1; rhs2; rhs3 } ->
444+
(* Uncurried syntax for ternary operations. *)
445+
string (ident lhs)
446+
^^ space
447+
^^ string (Ops.assign_op_cd_syntax ~initialize_neutral accum)
448+
^^ space
449+
^^ string (Ops.ternop_cd_syntax op)
450+
^^ string "("
451+
^^ string (buffer_ident rhs1)
452+
^^ string ", "
453+
^^ string (buffer_ident rhs2)
454+
^^ string ", "
455+
^^ string (buffer_ident rhs3)
456+
^^ string ")"
457+
^^ (if not (String.equal proj_spec ".") then string (" ~logic:\"" ^ proj_spec ^ "\"")
458+
else empty)
459+
^^ string ";" ^^ break 1
460+
| Binop { op; rhs1; rhs2 } ->
461+
string (ident lhs)
462+
^^ space
463+
^^ string (Ops.assign_op_cd_syntax ~initialize_neutral accum)
464+
^^ space
465+
^^ string (buffer_ident rhs1)
466+
^^ space
467+
^^ string (Ops.binop_cd_syntax op)
468+
^^ space
469+
^^ string (buffer_ident rhs2)
470+
^^ (if
471+
(not (String.equal proj_spec "."))
472+
|| List.mem ~equal:Ops.equal_binop Ops.[ Mul; Div ] op
473+
then string (" ~logic:\"" ^ proj_spec ^ "\"")
474+
else empty)
475+
^^ string ";" ^^ break 1
476+
| Unop { op; rhs } ->
477+
string (ident lhs)
478+
^^ space
479+
^^ string (Ops.assign_op_cd_syntax ~initialize_neutral accum)
480+
^^ space
481+
^^ (if not @@ Ops.equal_unop op Ops.Identity then string (Ops.unop_cd_syntax op ^ " ")
482+
else empty)
483+
^^ string (buffer_ident rhs)
484+
^^ (if not (String.equal proj_spec ".") then string (" ~logic:\"" ^ proj_spec ^ "\"")
485+
else empty)
486+
^^ string ";" ^^ break 1)
487+
| Set_vec_unop { op; lhs; rhs; projections = _; projections_debug } ->
488+
let proj_spec = projections_debug
508489
in
509490
string (ident lhs)
510491
^^ space
@@ -535,9 +516,8 @@ let to_doc ?name ?static_indices () c =
535516

536517
let%track6_sexp lower optim_ctx ~unoptim_ll_source ~ll_source ~cd_source ~name static_indices
537518
(proc : t) : Low_level.optimized =
538-
let llc : Low_level.t = to_low_level proc in
539-
(* Generate the low-level code before outputting the assignments, to force projections. *)
540519
(match cd_source with
541520
| None -> ()
542521
| Some callback -> callback (to_doc ~name ~static_indices () proc));
522+
let llc : Low_level.t = to_low_level proc in
543523
Low_level.optimize optim_ctx ~unoptim_ll_source ~ll_source ~name static_indices llc

arrayjit/lib/utils.ml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -700,7 +700,10 @@ let doc_of_sexp sexp =
700700
let output_to_build_file ~fname =
701701
if settings.output_debug_files_in_build_directory then
702702
let f = Stdio.Out_channel.create @@ build_file fname in
703-
let print doc = PPrint.ToChannel.pretty 0.7 100 f doc in
703+
let print doc =
704+
PPrint.ToChannel.pretty 0.7 100 f doc;
705+
Stdio.Out_channel.flush f
706+
in
704707
Some print
705708
else None
706709

lib/operation.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ let slice (batch_idx : Idx.static_symbol) =
409409
{
410410
array = v;
411411
fetch_op = Slice { batch_idx; sliced = t1.Tensor.value };
412-
dims = lazy (Lazy.force projections).Idx.lhs_dims;
412+
dims = lazy (Lazy.force projections.Tensor.projections).Idx.lhs_dims;
413413
}
414414
in
415415
let%cd grad_asn ~t:_ ~g ~t1 ~projections = g1 =+ g in

0 commit comments

Comments
 (0)