Skip to content

Commit d4277b2

Browse files
committed
Fixes the memory model: on-host arrays can be in contexts
Double check: not possible it would trigger freeing host array. Still broken: cc backend tests hang.
1 parent 25c71e5 commit d4277b2

16 files changed

+196
-169
lines changed

arrayjit/lib/anatomy_of_a_backend.md

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ In the future, when we introduce program search, `compile` functions will return
5757
OCANNL classifies tensor nodes according to their memory properties:
5858

5959
```ocaml
60-
(** A possible algorithm for deciding sharing within a single device:
60+
(** A possible algorithm for deciding sharing within a single device:
6161
- If a tensor node is read-only for a context, and not otherwise recorded, it is stored as a
6262
cross-stream sharing candidate.
6363
- If a cross-stream sharing candidate is read-only for another context, whose parent does not
@@ -71,9 +71,14 @@ OCANNL classifies tensor nodes according to their memory properties:
7171
If a tensor node is shared cross-stream, within-device copying is a NOOP as source and
7272
destination pointers are in that case identical. *)
7373
type sharing =
74-
| Unset
74+
| Unset (** One of: [Per_stream], [Shared_cross_streams]. *)
7575
| Per_stream (** The tensor node has separate arrays for each stream. *)
76-
| Shared_cross_stream (** The tensor node has a single array per device. *)
76+
| Shared_cross_streams
77+
(** The tensor node has a single array per device that can appear in multiple contexts, except
78+
for backends with [Option.is_some use_host_memory] and nodes with memory mode already
79+
[Hosted (Changed_on_devices Shared_cross_streams)] before first linking on a device, where
80+
it only has the on-host array. In that case the on-host array is registered in the
81+
context, to avoid misleading behavior from `device_to_device`. *)
7782
7883
type memory_type =
7984
| Constant (** The tensor node does not change after initialization. *)
@@ -110,6 +115,8 @@ A backend can make more refined distinctions, for example a `Local` node in CUDA
110115

111116
Contexts track (or store) the on-device arrays corresponding to tensor nodes. Contexts form a hierarchy: linking takes a parent context and outputs a child context. Related contexts that use a tensor node must use the same on-device array for the tensor node. If two unrelated contexts are on the same device, i.e. have a common ancestor, and use the same tensor node that is not part of the most recent common ancestor, the behavior is undefined.
112117

118+
To avoid misleading behavior of `device_to_device` data movement, non-constant materialized tensor nodes are represented in contexts making use of them, even when the underlying array is on host. This way the logic remains the same regardless of whether a backend shares memory with the host. We are careful to not accidentally call `free_buffer` on hosted arrays.
119+
113120
## Typical details of a backend implementation
114121

115122
During the compilation process, the old context cannot be available when `compile` is handled. Currently, all backends generate context-and-device-independent kernels, that refer to context arrays via parameters.

arrayjit/lib/assignments.ml

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -77,16 +77,13 @@ let get_name_exn asgns =
7777
let is_total ~initialize_neutral ~projections =
7878
initialize_neutral && Indexing.is_bijective projections
7979

80-
(** Returns materialized nodes in the sense of {!Tnode.is_in_context}. NOTE: it should be called
81-
after compilation and ideally after linking with the relevant contexts; otherwise, it is an
82-
under-estimate. *)
83-
let%debug3_sexp context_nodes ~(use_host_memory : bool) (asgns : t) : Tn.t_set =
80+
(** Returns materialized nodes in the sense of {!Tnode.is_in_context_force}. NOTE: it must be called
81+
after compilation; otherwise, it will disrupt memory mode inference. *)
82+
let%debug3_sexp context_nodes ~(use_host_memory : 'a option) (asgns : t) : Tn.t_set =
8483
let open Utils.Set_O in
8584
let empty = Set.empty (module Tn) in
8685
let one tn =
87-
if Option.value ~default:false @@ Tnode.is_in_context ~use_host_memory tn then
88-
Set.singleton (module Tn) tn
89-
else empty
86+
if Tn.is_in_context_force ~use_host_memory tn 34 then Set.singleton (module Tn) tn else empty
9087
in
9188
let of_node = function Node rhs -> one rhs | Merge_buffer _ -> empty in
9289
let rec loop = function

arrayjit/lib/backend_impl.ml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ open Backend_intf
1616
module type No_device_buffer_and_copying = sig
1717
include Alloc_buffer with type stream := unit
1818

19+
val use_host_memory : (unit Ctypes.ptr -> buffer_ptr) option
20+
1921
val get_used_memory : unit -> int
2022
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
2123

@@ -28,6 +30,7 @@ module No_device_buffer_and_copying () :
2830
No_device_buffer_and_copying with type buffer_ptr = unit Ctypes.ptr = struct
2931
type buffer_ptr = unit Ctypes.ptr
3032

33+
let use_host_memory = Some Fn.id
3134
let sexp_of_buffer_ptr = Ops.sexp_of_voidptr
3235

3336
include Buffer_types (struct
@@ -70,8 +73,6 @@ module No_device_buffer_and_copying () :
7073
Ctypes_memory_stubs.memcpy
7174
~dst:(Ndarray.get_fatptr_not_managed dst)
7275
~src ~size:(Ndarray.size_in_bytes dst)
73-
74-
let c_ptr_to_string = Some Ops.c_ptr_to_string
7576
end
7677

7778
module Device_types (Device_config : Device_config) = struct
@@ -133,10 +134,11 @@ end
133134
module type Backend_impl_common = sig
134135
include Buffer
135136

136-
val use_host_memory : bool
137-
(** If true, the backend will read from and write to the host memory directly whenever possible.
137+
val use_host_memory : (unit Ctypes.ptr -> buffer_ptr) option
138+
(** If not [None], the backend will read from and write to the host memory directly whenever
139+
reasonable.
138140
139-
[use_host_memory] can only be true on unified memory devices, like CPU and Apple Metal. *)
141+
[use_host_memory] can only be [Some] on unified memory devices, like CPU and Apple Metal. *)
140142
end
141143

142144
(** An interface to adding schedulers for stream-agnostic (typically CPU) backend implementations.

arrayjit/lib/backend_intf.ml

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ end
1818
module type Buffer = sig
1919
type buffer_ptr [@@deriving sexp_of]
2020

21-
val c_ptr_to_string : (buffer_ptr -> Ops.prec -> string) option
22-
2321
include module type of Buffer_types (struct
2422
type nonrec buffer_ptr = buffer_ptr [@@deriving sexp_of]
2523
end)

arrayjit/lib/backends.ml

Lines changed: 21 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
9090
let ordinal_of ctx = ctx.stream.device.ordinal in
9191
let name_of ctx = Backend.(get_name ctx.stream) in
9292
let same_device = ordinal_of dst = ordinal_of src in
93-
if same_device && (Tn.known_shared_cross_stream tn || String.equal (name_of src) (name_of dst))
93+
if same_device && (Tn.known_shared_cross_streams tn || String.equal (name_of src) (name_of dst))
9494
then false
9595
else
9696
match Map.find src.ctx_arrays tn with
@@ -187,8 +187,7 @@ let lower_batch_assignments ?names ?occupancy bindings asgns_l =
187187
let%debug3_sexp verify_prior_context ~use_host_memory ~ctx_arrays ~from_prior_context : unit =
188188
Set.iter from_prior_context ~f:(fun tn ->
189189
if
190-
(* Err on the safe side. *)
191-
Option.value ~default:false (Tn.is_in_context ~use_host_memory tn)
190+
Tn.is_in_context_force ~use_host_memory tn 42
192191
&& not (Option.is_some @@ Map.find ctx_arrays tn)
193192
then raise @@ Utils.User_error ("The linked context lacks node " ^ Tnode.debug_name tn))
194193

@@ -349,27 +348,35 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
349348
}
350349

351350
let%track3_sexp alloc_if_needed (stream : stream) ~key ~data:node ctx_arrays =
352-
(* TODO: do we need this? *)
353-
(* Tn.default_to_most_local key 345; *)
354-
if
355-
Option.value ~default:true (Tnode.is_in_context ~use_host_memory key)
356-
&& not (Map.mem ctx_arrays key)
357-
then (
351+
if Tnode.is_in_context_force ~use_host_memory key 43 && not (Map.mem ctx_arrays key) then (
358352
[%log Tn.debug_name key];
359353
[%log (key : Tnode.t)];
360354
let default () =
361355
alloc_zero_init_array (Lazy.force key.prec) ~dims:(Lazy.force key.dims) stream
362356
in
363357
let add_new () = Map.add_exn ctx_arrays ~key ~data:(default ()) in
364358
let device = stream.device in
365-
if node.Low_level.read_only then
359+
if node.Low_level.read_only then (
366360
if Tn.known_non_cross_stream key then add_new ()
367-
else (
361+
else
362+
let data =
363+
match use_host_memory with
364+
| None -> Hashtbl.find_or_add device.cross_stream_candidates key ~default
365+
| Some get_buffer_ptr ->
366+
if
367+
(not (Hashtbl.mem device.cross_stream_candidates key))
368+
&& Tn.known_shared_cross_streams key && Tn.is_hosted_force key 44
369+
then
370+
Hashtbl.update_and_return device.cross_stream_candidates key ~f:(fun _ ->
371+
get_buffer_ptr @@ Ndarray.get_voidptr_not_managed
372+
@@ Option.value_exn ~here:[%here]
373+
@@ Lazy.force key.array)
374+
else Hashtbl.find_or_add device.cross_stream_candidates key ~default
375+
in
368376
if Hashtbl.mem device.cross_stream_candidates key then
369-
Tn.update_memory_sharing key Tn.Shared_cross_stream 39;
370-
let data = Hashtbl.find_or_add device.cross_stream_candidates key ~default in
377+
Tn.update_memory_sharing key Tn.Shared_cross_streams 39;
371378
Map.add_exn ctx_arrays ~key ~data)
372-
else if Tn.known_shared_cross_stream key then (
379+
else if Tn.known_shared_cross_streams key then (
373380
if Hashtbl.mem device.owner_stream key then (
374381
if not (equal_stream stream (Hashtbl.find_exn device.owner_stream key)) then
375382
raise

arrayjit/lib/c_syntax.ml

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ module C_syntax (B : sig
1515
(** The low-level prcedure to compile, and the arrays of the context it will be linked to if not
1616
shared and already known. *)
1717

18-
val use_host_memory : bool
18+
type buffer_ptr
19+
20+
val use_host_memory : (unit Ctypes.ptr -> buffer_ptr) option
1921
val logs_to_stdout : bool
2022
val main_kernel_prefix : string
2123
val kernel_prep_line : string
@@ -29,7 +31,7 @@ struct
2931
let get_ident =
3032
Low_level.get_ident_within_code ~no_dots:true @@ Array.map B.procs ~f:(fun l -> l.llc)
3133

32-
let in_ctx tn = B.(Tn.is_in_context ~use_host_memory tn)
34+
let in_ctx tn = B.(Tn.is_in_context_force ~use_host_memory tn 46)
3335

3436
let pp_zero_out ppf tn =
3537
Stdlib.Format.fprintf ppf "@[<2>memset(%s, 0, %d);@]@ " (get_ident tn) @@ Tn.size_in_bytes tn
@@ -268,8 +270,8 @@ struct
268270
(* A rough approximation to the type Gccjit_backend.mem_properties. *)
269271
let backend_info, is_param =
270272
if Tn.is_virtual_force tn 334 then ("Virt", false)
271-
else if Option.value ~default:false @@ in_ctx tn then ("Ctx", true)
272-
else if Tn.is_materialized_force tn 335 then ("Global or ctx", true)
273+
else if in_ctx tn then ("Ctx", true)
274+
else if Tn.is_materialized_force tn 335 then ("Global", true)
273275
else if Tn.known_not_materialized tn then ("Local", false)
274276
else assert false
275277
in

arrayjit/lib/cc_backend.ml

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@ type procedure = {
3333
}
3434
[@@deriving sexp_of]
3535

36-
let use_host_memory = true
37-
3836
let get_global_run_id =
3937
let next_id = ref 0 in
4038
fun () ->
@@ -76,6 +74,9 @@ module C_syntax_config (Input : sig
7674
end) =
7775
struct
7876
let procs = Input.procs
77+
78+
type nonrec buffer_ptr = buffer_ptr
79+
7980
let use_host_memory = use_host_memory
8081
let logs_to_stdout = false
8182
let main_kernel_prefix = ""
@@ -127,14 +128,6 @@ let%diagn_sexp compile_batch ~names bindings (lowereds : Low_level.optimized opt
127128

128129
let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : procedure) =
129130
let name : string = code.name in
130-
List.iter code.params ~f:(function
131-
| _, Param_ptr tn when not @@ Tn.known_shared_with_host ~use_host_memory tn ->
132-
if not (Map.mem ctx_arrays tn) then
133-
invalid_arg
134-
[%string
135-
"Cc_backend.link_compiled: node %{Tn.debug_name tn} missing from context: \
136-
%{Tn.debug_memory_mode tn.Tn.memory_mode}"]
137-
| _ -> ());
138131
let log_file_name = Utils.diagn_log_file [%string "debug-%{runner_label}-%{code.name}.log"] in
139132
let run_variadic =
140133
[%log_level
@@ -158,11 +151,16 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro
158151
Param_2f (get_ptr, merge_buffer, link bs ps Ctypes.(ptr void @-> cs))
159152
| bs, Param_ptr tn :: ps ->
160153
let c_ptr =
161-
if Tn.known_shared_with_host ~use_host_memory tn then
162-
Ndarray.get_voidptr_not_managed
163-
@@ Option.value_exn ~here:[%here]
164-
@@ Lazy.force tn.array
165-
else Map.find_exn ctx_arrays tn
154+
match Map.find ctx_arrays tn with
155+
| None ->
156+
Ndarray.get_voidptr_not_managed
157+
@@ Option.value_exn ~here:[%here]
158+
~message:
159+
[%string
160+
"Cc_backend.link_compiled: node %{Tn.debug_name tn} missing from \
161+
context: %{Tn.debug_memory_mode tn.Tn.memory_mode}"]
162+
@@ Lazy.force tn.array
163+
| Some arr -> arr
166164
in
167165
Param_2 (ref (Some c_ptr), link bs ps Ctypes.(ptr void @-> cs))
168166
in
@@ -174,6 +172,7 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro
174172
in
175173
let%diagn_l_sexp work () : unit =
176174
[%log_result name];
175+
(* Stdio.printf "launching %s\n" name; *)
177176
Indexing.apply run_variadic ();
178177
if Utils.debug_log_from_routines () then (
179178
Utils.log_trace_tree (Stdio.In_channel.read_lines log_file_name);

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,14 +26,13 @@ module Backend_buffer = struct
2626
type buffer_ptr = Cu.Deviceptr.t
2727

2828
let sexp_of_buffer_ptr ptr = Sexp.Atom (Cu.Deviceptr.string_of ptr)
29-
let c_ptr_to_string = None
3029

3130
include Buffer_types (struct
3231
type nonrec buffer_ptr = buffer_ptr [@@deriving sexp_of]
3332
end)
3433
end
3534

36-
let use_host_memory = false
35+
let use_host_memory = None
3736

3837
module Device_config = struct
3938
include Backend_buffer
@@ -156,15 +155,18 @@ let%track3_sexp new_stream (device : device) : stream =
156155

157156
let cuda_properties =
158157
let cache =
159-
lazy
160-
(Array.init (num_devices ()) ~f:(fun ordinal ->
161-
let dev = get_device ~ordinal in
162-
lazy (Cu.Device.get_attributes dev.dev.dev)))
158+
let%debug2_sexp f (ordinal : int) =
159+
let dev = get_device ~ordinal in
160+
lazy (Cu.Device.get_attributes dev.dev.dev)
161+
in
162+
lazy (Array.init (num_devices ()) ~f)
163163
in
164-
fun device ->
164+
let%debug2_sexp get_props (device : device) : Cu.Device.attributes =
165165
if not @@ is_initialized () then invalid_arg "cuda_properties: CUDA not initialized";
166166
let cache = Lazy.force cache in
167167
Lazy.force cache.(device.ordinal)
168+
in
169+
get_props
168170

169171
let suggested_num_streams device =
170172
match !global_config with
@@ -427,6 +429,7 @@ let link_proc ~prior_context ~name ~(params : (string * param_source) list) ~ctx
427429
(* Map.iteri ctx_arrays ~f:(fun ~key ~data:ptr -> if key.Low_level.zero_initialized then
428430
Cu.Stream.memset_d8 ptr Unsigned.UChar.zero ~length:(Tn.size_in_bytes key.Low_level.tn)); *)
429431
[%log "launching the kernel"];
432+
(* Stdio.printf "launching %s\n" name; *)
430433
(if Utils.debug_log_from_routines () then
431434
Utils.add_log_processor ~prefix:log_id_prefix @@ fun _output ->
432435
[%log_block

arrayjit/lib/gcc_backend.gccjit.ml

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ type procedure = {
6666
}
6767
[@@deriving sexp_of]
6868

69-
let use_host_memory = true
69+
(* let use_host_memory = true *)
7070

7171
let gcc_typ_of_prec =
7272
let open Gccjit in
@@ -114,14 +114,14 @@ let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~param_ptrs ini
114114
let ptr_typ = Type.pointer num_typ in
115115
let ident = get_ident tn in
116116
let hosted = Tn.is_hosted_force tn 344 in
117-
let in_ctx = Tn.is_in_context ~use_host_memory tn in
117+
let in_ctx = Tn.is_in_context_force ~use_host_memory tn 45 in
118118
let ptr =
119119
match (in_ctx, hosted) with
120-
| Some true, _ ->
120+
| true, _ ->
121121
let p = Param.create ctx ptr_typ ident in
122122
param_ptrs := (p, Param_ptr tn) :: !param_ptrs;
123123
Lazy.from_val (RValue.param p)
124-
| (Some false | None), true -> (
124+
| false, true -> (
125125
let addr arr =
126126
Lazy.from_val @@ get_c_ptr ctx num_typ @@ Ctypes.bigarray_start Ctypes_static.Genarray arr
127127
in
@@ -131,7 +131,7 @@ let prepare_node ~debug_log_zero_out ~get_ident ctx traced_store ~param_ptrs ini
131131
| Some (Single_nd arr) -> addr arr
132132
| Some (Double_nd arr) -> addr arr
133133
| None -> assert false)
134-
| (Some false | None), false ->
134+
| false, false ->
135135
let arr_typ = Type.array ctx num_typ size_in_elems in
136136
let v = ref None in
137137
let initialize _init_block func = v := Some (Function.local func arr_typ ident) in
@@ -644,7 +644,8 @@ let%diagn_sexp compile_batch ~(names : string option array) bindings
644644
let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : procedure) =
645645
let name : string = code.name in
646646
List.iter code.params ~f:(function
647-
| Param_ptr tn when not (Tn.known_shared_cross_stream tn) -> assert (Map.mem ctx_arrays tn)
647+
(* FIXME: see cc_backend.ml *)
648+
| Param_ptr tn when not (Tn.known_shared_cross_streams tn) -> assert (Map.mem ctx_arrays tn)
648649
| _ -> ());
649650
let log_file_name = Utils.diagn_log_file [%string "debug-%{runner_label}-%{code.name}.log"] in
650651
let run_variadic =
@@ -667,11 +668,16 @@ let%track3_sexp link_compiled ~merge_buffer ~runner_label ctx_arrays (code : pro
667668
Param_1 (ref (Some log_file_name), link bs ps Ctypes.(string @-> cs))
668669
| bs, Param_ptr tn :: ps ->
669670
let c_ptr =
670-
if Tn.known_shared_with_host ~use_host_memory tn then
671-
Ndarray.get_voidptr_not_managed
672-
@@ Option.value_exn ~here:[%here]
673-
@@ Lazy.force tn.array
674-
else Map.find_exn ctx_arrays tn
671+
match Map.find ctx_arrays tn with
672+
| None ->
673+
Ndarray.get_voidptr_not_managed
674+
@@ Option.value_exn ~here:[%here]
675+
~message:
676+
[%string
677+
"Gcc_backend.link_compiled: node %{Tn.debug_name tn} missing from \
678+
context: %{Tn.debug_memory_mode tn.Tn.memory_mode}"]
679+
@@ Lazy.force tn.array
680+
| Some arr -> arr
675681
in
676682
Param_2 (ref (Some c_ptr), link bs ps Ctypes.(ptr void @-> cs))
677683
| bs, Merge_buffer :: ps ->

arrayjit/lib/lowered_backend_missing.ml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ type dev
33
type runner
44
type event
55

6-
let use_host_memory = false
6+
let use_host_memory = None
77
let sexp_of_dev _dev = failwith "Backend missing -- install the corresponding library"
88
let sexp_of_runner _runner = failwith "Backend missing -- install the corresponding library"
99
let sexp_of_event _event = failwith "Backend missing -- install the corresponding library"
@@ -39,7 +39,6 @@ let make_child ?ctx_arrays:_ _context =
3939

4040
let get_name _stream = failwith "Backend missing -- install the corresponding library"
4141
let sexp_of_buffer_ptr _buffer_ptr = failwith "Backend missing -- install the corresponding library"
42-
let c_ptr_to_string = None
4342

4443
type nonrec buffer = buffer_ptr Backend_intf.buffer
4544

0 commit comments

Comments
 (0)