Skip to content

Commit 2761d95

Browse files
committed
Revert constancy-tracking for Virtual; formatting
1 parent 65255f6 commit 2761d95

24 files changed

+272
-287
lines changed

CHANGES.md

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,7 @@
2020
- Also renamed the badly named `Get_global` to `Access`.
2121
- Initialization now needs to be handled via running the corresponding code explicitly. In particular `Tensor.init_params` will run the forward code of tensors from the `params` field.
2222
- Virtual nodes and inlining now also work across routines. This required changing the API to pass the `optimize_ctx` optimization context.
23-
- TODO: The modes can now be escalated from non-hosted to hosted. This means the `Tnode.array` field is no longer lazy, but mutable.
24-
- The virtual memory mode is now tagged with whether the node is constant.
25-
23+
- TODO: The memory modes now decide between non-hosted and hosted based on type inference (dependent types style). (FIXME: is this at all possible?)
2624

2725
## [0.5.3] -- 2025-05-24
2826

arrayjit/lib/anatomy_of_a_backend.md

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -92,8 +92,7 @@ type memory_type =
9292
9393
type memory_mode =
9494
| Effectively_constant (** Either [Hosted Constant], or a subset of [Virtual]. *)
95-
| Virtual of { is_constant:bool }
96-
(** The tensor node's computations are inlined on a per-scalar basis. *)
95+
| Virtual (** The tensor node's computations are inlined on a per-scalar basis. *)
9796
| Never_virtual (** One of: [Local], [On_device], [Hosted]. *)
9897
| Local
9998
(** The full tensor node is cached for the duration of a computation but not persisted across

arrayjit/lib/assignments.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -409,8 +409,8 @@ let to_doc ?name ?static_indices () c =
409409

410410
header_doc ^^ nest 2 (doc_of_code c)
411411

412-
let%track6_sexp lower optim_ctx ~unoptim_ll_source ~ll_source ~cd_source ~name static_indices (proc : t) :
413-
Low_level.optimized =
412+
let%track6_sexp lower optim_ctx ~unoptim_ll_source ~ll_source ~cd_source ~name static_indices
413+
(proc : t) : Low_level.optimized =
414414
let llc : Low_level.t = to_low_level proc in
415415
(* Generate the low-level code before outputting the assignments, to force projections. *)
416416
(match cd_source with

arrayjit/lib/backend_impl.ml

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -91,9 +91,12 @@ module Device_types (Device_config : Device_config) = struct
9191
type nonrec stream = (buffer_ptr, dev, runner, event) stream [@@deriving sexp_of]
9292
type nonrec context = (buffer_ptr, stream, optimize_ctx) context [@@deriving sexp_of]
9393
end
94+
9495
module Device_types_ll (Device_config : Device_config_common) = struct
9596
include Device_config
97+
9698
type optimize_ctx = Low_level.optimize_ctx [@@deriving sexp_of]
99+
97100
let empty_optimize_ctx = { Low_level.computations = Hashtbl.create (module Tnode) }
98101

99102
type nonrec device = (buffer_ptr, dev, runner, event) device [@@deriving sexp_of]
@@ -142,13 +145,20 @@ struct
142145

143146
let get_name stream = [%string "%{name}:%{stream.device.ordinal#Int}:%{stream.stream_id#Int}"]
144147

145-
let make_context ?(ctx_arrays = Map.empty (module Tnode)) ?(optimize_ctx = empty_optimize_ctx) stream =
148+
let make_context ?(ctx_arrays = Map.empty (module Tnode)) ?(optimize_ctx = empty_optimize_ctx)
149+
stream =
146150
{ stream; parent = None; ctx_arrays; finalized = Atomic.make false; optimize_ctx }
147151

148152
let make_child ?ctx_arrays ?optimize_ctx parent =
149153
let ctx_arrays = Option.value ctx_arrays ~default:parent.ctx_arrays in
150154
let optimize_ctx = Option.value optimize_ctx ~default:parent.optimize_ctx in
151-
{ stream = parent.stream; parent = Some parent; ctx_arrays; finalized = Atomic.make false; optimize_ctx }
155+
{
156+
stream = parent.stream;
157+
parent = Some parent;
158+
ctx_arrays;
159+
finalized = Atomic.make false;
160+
optimize_ctx;
161+
}
152162
end
153163

154164
(** Parts shared by backend implementations. *)

arrayjit/lib/backends.mli

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,10 @@ val finalize :
1111
and type event = 'event
1212
and type runner = 'runner
1313
and type optimize_ctx = 'optimize_ctx) ->
14-
('buffer_ptr, ('buffer_ptr, 'dev, 'runner, 'event) Ir.Backend_intf.stream, 'optimize_ctx) Ir.Backend_intf.context ->
14+
( 'buffer_ptr,
15+
('buffer_ptr, 'dev, 'runner, 'event) Ir.Backend_intf.stream,
16+
'optimize_ctx )
17+
Ir.Backend_intf.context ->
1518
unit
1619
(** Frees the arrays that are specific to the context -- not contained in the parent context. Note:
1720
use [finalize] to optimize memory, it is not obligatory because all arrays are freed when their

arrayjit/lib/c_syntax.ml

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -608,15 +608,22 @@ module C_syntax (B : C_syntax_config) = struct
608608
let dims_val = Lazy.force dims in
609609
let prefix, postfix = B.convert_precision ~from:source_prec ~to_:prec in
610610
let offset_doc = pp_array_offset (idcs, dims_val) in
611-
let ptr_str = Ops.c_rawptr_to_string (Ctypes.raw_address_of_ptr @@ Ctypes.to_voidp ptr) source_prec in
611+
let ptr_str =
612+
Ops.c_rawptr_to_string (Ctypes.raw_address_of_ptr @@ Ctypes.to_voidp ptr) source_prec
613+
in
612614
let expr =
613-
string prefix ^^ string ("(*(" ^ ptr_str ^ " + ") ^^ offset_doc ^^ string "))" ^^ string postfix
615+
string prefix
616+
^^ string ("(*(" ^ ptr_str ^ " + ")
617+
^^ offset_doc ^^ string "))" ^^ string postfix
614618
in
615619
(empty, expr)
616620
| Access (Low_level.File_mapped (file, source_prec), Some idcs) ->
617621
let prefix, postfix = B.convert_precision ~from:source_prec ~to_:prec in
618622
let expr =
619-
string prefix ^^ string ("file_mapped_data_" ^ file ^ "[") ^^ pp_array_offset (idcs, [||]) ^^ string "]" ^^ string postfix
623+
string prefix
624+
^^ string ("file_mapped_data_" ^ file ^ "[")
625+
^^ pp_array_offset (idcs, [||])
626+
^^ string "]" ^^ string postfix
620627
in
621628
(empty, expr)
622629
| Access (Low_level.Uint4x32_to_prec_uniform { source; prec = source_prec }, Some idcs) ->
@@ -625,7 +632,8 @@ module C_syntax (B : C_syntax_config) = struct
625632
let offset_doc = pp_array_offset (idcs, Lazy.force tn.dims) in
626633
let source_ident = string (get_ident tn) in
627634
let expr =
628-
string prefix ^^ string ("uint4x32_to_" ^ Ops.prec_string source_prec ^ "_uniform(")
635+
string prefix
636+
^^ string ("uint4x32_to_" ^ Ops.prec_string source_prec ^ "_uniform(")
629637
^^ source_ident ^^ brackets offset_doc ^^ string ")" ^^ string postfix
630638
in
631639
(empty, expr)
@@ -716,9 +724,13 @@ module C_syntax (B : C_syntax_config) = struct
716724
let dims_val = Lazy.force dims in
717725
let prefix, postfix = B.convert_precision ~from:source_prec ~to_:prec in
718726
let offset_doc = pp_array_offset (idcs, dims_val) in
719-
let ptr_str = Ops.c_rawptr_to_string (Ctypes.raw_address_of_ptr @@ Ctypes.to_voidp ptr) source_prec in
727+
let ptr_str =
728+
Ops.c_rawptr_to_string (Ctypes.raw_address_of_ptr @@ Ctypes.to_voidp ptr) source_prec
729+
in
720730
let access_doc =
721-
string prefix ^^ string ("(*(" ^ ptr_str ^ " + ") ^^ offset_doc ^^ string "))" ^^ string postfix
731+
string prefix
732+
^^ string ("(*(" ^ ptr_str ^ " + ")
733+
^^ offset_doc ^^ string "))" ^^ string postfix
722734
in
723735
let expr_doc =
724736
string prefix ^^ string ("external[%u]{=" ^ B.float_log_style ^ "}") ^^ string postfix
@@ -727,10 +739,15 @@ module C_syntax (B : C_syntax_config) = struct
727739
| Access (Low_level.File_mapped (file, source_prec), Some idcs) ->
728740
let prefix, postfix = B.convert_precision ~from:source_prec ~to_:prec in
729741
let access_doc =
730-
string prefix ^^ string ("file_mapped_data_" ^ file ^ "[") ^^ pp_array_offset (idcs, [||]) ^^ string "]" ^^ string postfix
742+
string prefix
743+
^^ string ("file_mapped_data_" ^ file ^ "[")
744+
^^ pp_array_offset (idcs, [||])
745+
^^ string "]" ^^ string postfix
731746
in
732747
let expr_doc =
733-
string prefix ^^ string ("file_mapped_" ^ file ^ "[%u]{=" ^ B.float_log_style ^ "}") ^^ string postfix
748+
string prefix
749+
^^ string ("file_mapped_" ^ file ^ "[%u]{=" ^ B.float_log_style ^ "}")
750+
^^ string postfix
734751
in
735752
(expr_doc, [ `Accessor (idcs, [||]); `Value access_doc ])
736753
| Access (Low_level.Uint4x32_to_prec_uniform { source; prec = source_prec }, Some idcs) ->
@@ -740,12 +757,16 @@ module C_syntax (B : C_syntax_config) = struct
740757
let offset_doc = pp_array_offset (idcs, dims) in
741758
let source_ident = string (get_ident tn) in
742759
let access_doc =
743-
string prefix ^^ string ("uint4x32_to_" ^ Ops.prec_string source_prec ^ "_uniform(")
760+
string prefix
761+
^^ string ("uint4x32_to_" ^ Ops.prec_string source_prec ^ "_uniform(")
744762
^^ source_ident ^^ brackets offset_doc ^^ string ")" ^^ string postfix
745763
in
746764
let expr_doc =
747-
string prefix ^^ string ("uint4x32_to_" ^ Ops.prec_string source_prec ^ "_uniform(")
748-
^^ source_ident ^^ brackets (string "%u") ^^ string "){=" ^^ string B.float_log_style ^^ string "}" ^^ string postfix
765+
string prefix
766+
^^ string ("uint4x32_to_" ^ Ops.prec_string source_prec ^ "_uniform(")
767+
^^ source_ident
768+
^^ brackets (string "%u")
769+
^^ string "){=" ^^ string B.float_log_style ^^ string "}" ^^ string postfix
749770
in
750771
(expr_doc, [ `Accessor (idcs, dims); `Value access_doc ])
751772
| Access _ -> failwith "C_syntax: Access cases with wrong indices / FFI NOT IMPLEMENTED YET"

arrayjit/lib/gcc_backend.ml

Lines changed: 14 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -233,10 +233,9 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node;
233233
| Iterator s -> Map.find_exn env s
234234
| Affine { symbols; offset } ->
235235
List.fold symbols ~init:(RValue.int ctx c_index offset) ~f:(fun acc (coeff, s) ->
236-
RValue.binary_op ctx Plus c_index acc
237-
(RValue.binary_op ctx Mult c_index
238-
(RValue.int ctx c_index coeff)
239-
(Map.find_exn env s))))
236+
RValue.binary_op ctx Plus c_index acc
237+
(RValue.binary_op ctx Mult c_index (RValue.int ctx c_index coeff)
238+
(Map.find_exn env s))))
240239
with e ->
241240
Stdlib.Format.eprintf
242241
"exec_as_gccjit: missing index from@ %a@ among environment keys:@ %a\n%!" Sexp.pp_hum
@@ -359,10 +358,14 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node;
359358
| Embed_index (Fixed_idx i) -> (Int.to_string i, [])
360359
| Embed_index (Iterator s) -> (Indexing.symbol_ident s ^ "{=%d}", [ Map.find_exn env s ])
361360
| Embed_index (Affine { symbols; offset }) ->
362-
let terms = List.map symbols ~f:(fun (coeff, s) ->
363-
if coeff = 1 then Indexing.symbol_ident s
364-
else Int.to_string coeff ^ "*" ^ Indexing.symbol_ident s) in
365-
let expr = String.concat ~sep:"+" (terms @ if offset = 0 then [] else [Int.to_string offset]) in
361+
let terms =
362+
List.map symbols ~f:(fun (coeff, s) ->
363+
if coeff = 1 then Indexing.symbol_ident s
364+
else Int.to_string coeff ^ "*" ^ Indexing.symbol_ident s)
365+
in
366+
let expr =
367+
String.concat ~sep:"+" (terms @ if offset = 0 then [] else [ Int.to_string offset ])
368+
in
366369
(expr, [])
367370
| Binop (Arg1, v1, _v2) -> loop v1
368371
| Binop (Arg2, _v1, v2) -> loop v2
@@ -524,10 +527,9 @@ let compile_main ~name ~log_functions ~env { ctx; nodes; get_ident; merge_node;
524527
raise e)
525528
| Embed_index (Affine { symbols; offset }) ->
526529
List.fold symbols ~init:(RValue.int ctx num_typ offset) ~f:(fun acc (coeff, s) ->
527-
RValue.binary_op ctx Plus num_typ acc
528-
(RValue.binary_op ctx Mult num_typ
529-
(RValue.int ctx num_typ coeff)
530-
(RValue.cast ctx (Map.find_exn env s) num_typ)))
530+
RValue.binary_op ctx Plus num_typ acc
531+
(RValue.binary_op ctx Mult num_typ (RValue.int ctx num_typ coeff)
532+
(RValue.cast ctx (Map.find_exn env s) num_typ)))
531533
| Binop (Arg2, _, c2) -> loop c2
532534
| Binop (Arg1, c1, _) -> loop c1
533535
| Binop (op, c1, c2) -> loop_binop op ~num_typ prec ~v1:(loop c1) ~v2:(loop c2)

arrayjit/lib/low_level.ml

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -269,7 +269,7 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
269269
if
270270
virtualize_settings.inline_scalar_constexprs && traced.is_scalar_constexpr
271271
&& not (Tn.known_non_virtual tn)
272-
then Tn.update_memory_mode tn (Virtual { is_constant = traced.is_scalar_constexpr }) 40;
272+
then Tn.update_memory_mode tn Virtual 40;
273273
if Option.is_none tn.memory_mode && Hashtbl.exists traced.accesses ~f:is_too_many then
274274
Tn.update_memory_mode tn Never_virtual 1
275275
(* The tensor node is read-only/recurrent for this computation, but maybe computed or
@@ -602,7 +602,7 @@ let cleanup_virtual_llc reverse_node_map ~static_indices (llc : t) : t =
602602
| Some a ->
603603
if not @@ Tn.known_non_virtual a then (
604604
(* FIXME(#296): *)
605-
Tn.update_memory_mode a (Virtual { is_constant = false }) 15;
605+
Tn.update_memory_mode a Virtual 15;
606606
None)
607607
else
608608
Option.map ~f:(fun body : t -> For_loop { for_config with body })
@@ -613,21 +613,21 @@ let cleanup_virtual_llc reverse_node_map ~static_indices (llc : t) : t =
613613
| Zero_out tn ->
614614
if not @@ Tn.known_non_virtual tn then (
615615
(* FIXME(#296): *)
616-
Tn.update_memory_mode tn (Virtual { is_constant = false }) 151;
616+
Tn.update_memory_mode tn Virtual 151;
617617
None)
618618
else Some llc
619619
| Set { tn; idcs; llv; debug } ->
620620
if not @@ Tn.known_non_virtual tn then (
621621
(* FIXME(#296): *)
622-
Tn.update_memory_mode tn (Virtual { is_constant = false }) 152;
622+
Tn.update_memory_mode tn Virtual 152;
623623
None)
624624
else (
625625
assert (
626626
Array.for_all idcs ~f:(function Indexing.Iterator s -> Set.mem env_dom s | _ -> true));
627627
Some (Set { tn; idcs; llv = loop_float ~balanced ~env_dom llv; debug }))
628628
| Set_local (id, llv) ->
629629
assert (not @@ Tn.known_non_virtual id.tn);
630-
Tn.update_memory_mode id.tn (Virtual { is_constant = false }) 16;
630+
Tn.update_memory_mode id.tn Virtual 16;
631631
Some (Set_local (id, loop_float ~balanced ~env_dom llv))
632632
| Comment _ -> Some llc
633633
| Staged_compilation _ -> Some llc
@@ -649,11 +649,11 @@ let cleanup_virtual_llc reverse_node_map ~static_indices (llc : t) : t =
649649
if Tn.known_non_virtual id.tn then Get (id.tn, orig_indices)
650650
else
651651
let body = Option.value_exn ~here:[%here] @@ loop_proc ~balanced ~env_dom body in
652-
Tn.update_memory_mode id.tn (Virtual { is_constant = false }) 18;
652+
Tn.update_memory_mode id.tn Virtual 18;
653653
Local_scope { id; orig_indices; body }
654654
| Get_local id ->
655655
assert (not @@ Tn.known_non_virtual id.tn);
656-
Tn.update_memory_mode id.tn (Virtual { is_constant = false }) 16;
656+
Tn.update_memory_mode id.tn Virtual 16;
657657
llv
658658
| Access _ -> llv
659659
| Embed_index (Fixed_idx _) -> llv
@@ -888,7 +888,9 @@ let%diagn2_sexp optimize_proc (input_ctx : optimize_ctx) static_indices llc =
888888
visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits:virtualize_settings.max_visits
889889
llc;
890890
[%log "optimizing"];
891-
let virtual_llc_result = virtual_llc input_ctx.computations traced_store reverse_node_map static_indices llc in
891+
let virtual_llc_result =
892+
virtual_llc input_ctx.computations traced_store reverse_node_map static_indices llc
893+
in
892894
let llc =
893895
simplify_llc @@ cleanup_virtual_llc reverse_node_map ~static_indices @@ virtual_llc_result
894896
in

0 commit comments

Comments
 (0)