Skip to content

Commit a09e2d7

Browse files
committed
get_used_memory depends on the device
1 parent 1953872 commit a09e2d7

File tree

8 files changed

+119
-90
lines changed

8 files changed

+119
-90
lines changed

arrayjit/lib/backend_types.ml

Lines changed: 20 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,16 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
77
[%%global_debug_log_level 9]
88
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
99

10+
(** Context-array types and helpers for backends whose context arrays are plain [Ndarray.t]
    values. In this commit the module is pulled into [cc_backend.ml] via [include]. *)
module No_device_types = struct
11+
(** A single context array: an alias of [Ndarray.t]. *)
type ctx_array = Ndarray.t [@@deriving sexp_of]
12+
13+
(** The arrays of a context: a map from tensor nodes to their arrays, bundled with a
    [used_memory] counter.
    NOTE(review): [used_memory] presumably tracks the bytes allocated for these arrays; nothing
    in this view updates it -- confirm against the allocation code. *)
type ctx_arrays = { used_memory : Utils.atomic_int; ctx_arrays : ctx_array Map.M(Tnode).t }
14+
[@@deriving sexp_of]
15+
16+
(** The empty collection: no arrays, counter initialised to [0].
    NOTE(review): this is a value, not a function, so [Atomic.make 0] is evaluated once and every
    use of [empty_ctx_arrays] shares the same counter cell -- confirm that sharing is intended. *)
let empty_ctx_arrays = { used_memory = Atomic.make 0; ctx_arrays = Map.empty (module Tnode) }
17+
(** [get_array arrays tn] looks up [tn] in [arrays.ctx_arrays]; the signatures elsewhere in this
    commit give the result type as [ctx_array option]. *)
let get_array arrays = Map.find arrays.ctx_arrays
18+
end
19+
1020
module Types = struct
1121
type 'context routine = {
1222
context : 'context;
@@ -168,6 +178,9 @@ module type Backend = sig
168178
val init : stream -> context
169179
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> stream -> buffer_ptr
170180

181+
val get_used_memory : device -> int
182+
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
183+
171184
val await : stream -> unit
172185
(** Blocks till the stream becomes idle, i.e. synchronizes the stream. *)
173186

@@ -198,11 +211,12 @@ module type Lowered_no_device_backend = sig
198211
type procedure [@@deriving sexp_of]
199212
type ctx_array [@@deriving sexp_of]
200213
type buffer_ptr [@@deriving sexp_of]
201-
type ctx_arrays = ctx_array Map.M(Tnode).t [@@deriving sexp_of]
214+
type ctx_arrays [@@deriving sexp_of]
202215

203216
val buffer_ptr : ctx_array -> buffer_ptr
204-
val ctx_arrays : context -> ctx_arrays
205217
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> unit -> buffer_ptr
218+
val ctx_arrays : context -> ctx_arrays
219+
val get_array : ctx_arrays -> Tnode.t -> ctx_array option
206220

207221
val is_in_context : Low_level.traced_array -> bool
208222
(** If true, the node is required to be in the contexts linked with code that uses it.
@@ -246,6 +260,7 @@ module type Lowered_backend = sig
246260
type code [@@deriving sexp_of]
247261
type code_batch [@@deriving sexp_of]
248262
type ctx_array [@@deriving sexp_of]
263+
type ctx_arrays [@@deriving sexp_of]
249264
type event
250265

251266
val sync : event -> unit
@@ -268,7 +283,8 @@ module type Lowered_backend = sig
268283
code_batch
269284

270285
val is_in_context : Low_level.traced_array -> bool
271-
val ctx_arrays : context -> ctx_array Map.M(Tnode).t
286+
val ctx_arrays : context -> ctx_arrays
287+
val get_array : ctx_arrays -> Tnode.t -> ctx_array option
272288
val link : context -> code -> context * Indexing.lowered_bindings * Task.t
273289

274290
val link_batch :
@@ -298,7 +314,7 @@ module type Lowered_backend = sig
298314

299315
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> stream -> buffer_ptr
300316

301-
val get_used_memory : unit -> int
317+
val get_used_memory : device -> int
302318
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
303319

304320
val init : stream -> context

arrayjit/lib/backends.ml

Lines changed: 13 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -76,7 +76,7 @@ struct
7676
let alloc_buffer ?old_buffer ~size_in_bytes _stream =
7777
Backend.alloc_buffer ?old_buffer ~size_in_bytes ()
7878

79-
let get_used_memory = Backend.get_used_memory
79+
let get_used_memory _device = Backend.get_used_memory ()
8080

8181
type device = stream [@@deriving sexp_of]
8282
type code = Backend.code [@@deriving sexp_of]
@@ -370,9 +370,10 @@ module Sync_backend (Backend : Backend_types.No_device_backend) : Backend_types.
370370
let alloc_buffer ?old_buffer ~size_in_bytes _stream =
371371
Backend.alloc_buffer ?old_buffer ~size_in_bytes ()
372372

373-
let get_used_memory = Backend.get_used_memory
374-
375373
type device = CPU [@@deriving sexp_of]
374+
375+
let get_used_memory CPU = Backend.get_used_memory ()
376+
376377
type code = Backend.code [@@deriving sexp_of]
377378
type code_batch = Backend.code_batch [@@deriving sexp_of]
378379

@@ -534,14 +535,14 @@ let lower_batch_assignments ?names ?occupancy bindings asgns_l =
534535
)
535536
else (None, None))
536537

537-
let verify_prior_context ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context traced_stores
538-
=
538+
let verify_prior_context ~get_array ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context
539+
traced_stores =
539540
let olds = ctx_arrays prior_context in
540541
Set.iter from_prior_context ~f:(fun tn ->
541542
let node = Array.find_map traced_stores ~f:(fun store -> Hashtbl.find store tn) in
542543
if
543544
Option.value_map node ~default:false ~f:(fun node ->
544-
is_in_context node && not (Map.mem olds tn))
545+
is_in_context node && not (Option.is_some @@ get_array olds tn))
545546
then raise @@ Utils.User_error ("The linked context lacks node " ^ Tnode.debug_name tn))
546547

547548
let from_prior_context_batch comps =
@@ -646,7 +647,8 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back
646647

647648
let link ~merge_buffer (prior_context : context) (code : code) =
648649
let verify from_prior_context =
649-
verify_prior_context ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context
650+
verify_prior_context ~get_array:Backend.get_array ~ctx_arrays ~is_in_context ~prior_context
651+
~from_prior_context
650652
[| get_traced_store code |]
651653
in
652654
let context, bindings, schedule, name =
@@ -673,7 +675,7 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back
673675

674676
let link_batch ~merge_buffer (prior_context : context) (code_batch : code_batch) =
675677
let verify from_prior_context =
676-
verify_prior_context ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context
678+
verify_prior_context ~get_array ~ctx_arrays ~is_in_context ~prior_context ~from_prior_context
677679
@@ get_traced_stores code_batch
678680
in
679681
let _opt_ctx_arrays, procs =
@@ -703,7 +705,7 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back
703705
| None -> (context, None))
704706

705707
let get_buffer tn context =
706-
Map.find (Backend.ctx_arrays context) tn |> Option.map ~f:Backend.buffer_ptr
708+
Backend.(ctx_arrays context |> Fn.flip get_array tn |> Option.map ~f:buffer_ptr)
707709

708710
let get_used_memory = Ndarray.get_used_memory
709711
end
@@ -776,7 +778,7 @@ module Lowered_backend (Device : Backend_types.Lowered_backend) : Backend_types.
776778
}
777779

778780
let link context (code : code) =
779-
verify_prior_context ~ctx_arrays ~is_in_context ~prior_context:context
781+
verify_prior_context ~get_array ~ctx_arrays ~is_in_context ~prior_context:context
780782
~from_prior_context:code.from_prior_context [| code.traced_store |];
781783
let context, bindings, schedule = link context code.code in
782784
let schedule =
@@ -788,7 +790,7 @@ module Lowered_backend (Device : Backend_types.Lowered_backend) : Backend_types.
788790
{ context; schedule; bindings; name }
789791

790792
let link_batch context code_batch =
791-
verify_prior_context ~ctx_arrays ~is_in_context ~prior_context:context
793+
verify_prior_context ~get_array ~ctx_arrays ~is_in_context ~prior_context:context
792794
~from_prior_context:code_batch.from_prior_context code_batch.traced_stores;
793795
let context, bindings, schedules = link_batch context code_batch.code_batch in
794796
( context,

arrayjit/lib/c_syntax.ml

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -13,8 +13,10 @@ module C_syntax (B : sig
1313
val for_lowereds : Low_level.optimized array
1414

1515
type ctx_array
16+
type ctx_arrays
1617

17-
val opt_ctx_arrays : ctx_array Map.M(Tnode).t option
18+
val opt_ctx_arrays : ctx_arrays option
19+
val get_array : ctx_arrays -> Tn.t -> ctx_array option
1820
val hardcoded_context_ptr : (ctx_array -> string) option
1921
val is_in_context : Low_level.traced_array -> bool
2022
val host_ptrs_for_readonly : bool
@@ -86,7 +88,7 @@ struct
8688
| true, Some get_ptr, Some ctx_arrays, _, _, _ ->
8789
let ident = get_ident node.tn in
8890
let ctx_array =
89-
Option.value_exn ~here:[%here] ~message:ident @@ Map.find ctx_arrays node.tn
91+
Option.value_exn ~here:[%here] ~message:ident @@ B.get_array ctx_arrays node.tn
9092
in
9193
fprintf ppf "#define %s (%s)@," ident @@ get_ptr ctx_array;
9294
Hash_set.add is_global node.tn

arrayjit/lib/cc_backend.ml

Lines changed: 38 additions & 31 deletions
Original file line number | Diff line number | Diff line change
@@ -7,6 +7,7 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
77
[%%global_debug_log_level 9]
88
[%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]
99

10+
include Backend_types.No_device_types
1011
open Backend_types.Types
1112

1213
let name = "cc"
@@ -18,8 +19,6 @@ let compiler_command () = Utils.get_global_arg ~default:"cc" ~arg_name:"cc_backe
1819

1920
module Tn = Tnode
2021

21-
type ctx_array = Ndarray.t [@@deriving sexp_of]
22-
type ctx_arrays = ctx_array Map.M(Tn).t [@@deriving sexp_of]
2322
type context = { label : string; arrays : ctx_arrays } [@@deriving sexp_of]
2423

2524
let ctx_arrays context = context.arrays
@@ -36,7 +35,7 @@ let alloc_buffer ?old_buffer ~size_in_bytes () =
3635
| None -> assert false
3736

3837
let to_buffer tn ~dst ~src =
39-
let src = Map.find_exn src.arrays tn in
38+
let src = Map.find_exn src.arrays.ctx_arrays tn in
4039
Ndarray.map2 { f2 = Ndarray.A.blit } src dst
4140

4241
let host_to_buffer src ~dst = Ndarray.map2 { f2 = Ndarray.A.blit } src dst
@@ -50,7 +49,9 @@ let is_initialized, initialize =
5049
let finalize _ctx = ()
5150

5251
let init ~label =
53-
let result = { label; arrays = Map.empty (module Tn) } in
52+
let result =
53+
{ label; arrays = { used_memory = Atomic.make 0; ctx_arrays = Map.empty (module Tn) } }
54+
in
5455
Stdlib.Gc.finalise finalize result;
5556
result
5657

@@ -61,7 +62,7 @@ type procedure = {
6162
name : string;
6263
result : library;
6364
params : (string * param_source) list;
64-
opt_ctx_arrays : Ndarray.t Map.M(Tn).t option;
65+
opt_ctx_arrays : ctx_arrays option;
6566
}
6667
[@@deriving sexp_of]
6768

@@ -105,13 +106,14 @@ let c_compile_and_load ~f_name =
105106

106107
module C_syntax_config (Input : sig
107108
val for_lowereds : Low_level.optimized array
108-
val opt_ctx_arrays : (Tn.t, buffer_ptr, Tn.comparator_witness) Base.Map.t option
109+
val opt_ctx_arrays : ctx_arrays option
109110
end) =
110111
struct
111-
let for_lowereds = Input.for_lowereds
112-
113112
type nonrec ctx_array = ctx_array
113+
type nonrec ctx_arrays = ctx_arrays
114114

115+
let get_array = get_array
116+
let for_lowereds = Input.for_lowereds
115117
let opt_ctx_arrays = Input.opt_ctx_arrays
116118
let hardcoded_context_ptr = Some Ndarray.c_ptr_to_string
117119
let is_in_context = is_in_context
@@ -133,15 +135,15 @@ let%diagn_sexp compile ~(name : string) ~opt_ctx_arrays bindings (lowered : Low_
133135
let opt_ctx_arrays =
134136
Option.map opt_ctx_arrays ~f:(fun ctx_arrays ->
135137
Hashtbl.fold lowered.traced_store ~init:ctx_arrays ~f:(fun ~key:tn ~data:node ctx_arrays ->
136-
match Map.find ctx_arrays tn with
138+
match Map.find ctx_arrays.ctx_arrays tn with
137139
| None ->
138140
if is_in_context node then
139141
let debug = "CC compile-time ctx array for " ^ Tn.debug_name tn in
140142
let data =
141143
Ndarray.create_array ~debug (Lazy.force tn.Tn.prec) ~dims:(Lazy.force tn.dims)
142144
@@ Constant_fill { values = [| 0. |]; strict = false }
143145
in
144-
Map.add_exn ctx_arrays ~key:tn ~data
146+
{ ctx_arrays with ctx_arrays = Map.add_exn ctx_arrays.ctx_arrays ~key:tn ~data }
145147
else ctx_arrays
146148
| Some _ -> ctx_arrays))
147149
in
@@ -162,22 +164,25 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
162164
(lowereds : Low_level.optimized option array) =
163165
let for_lowereds = Array.filter_map ~f:Fn.id lowereds in
164166
let opt_ctx_arrays =
165-
Option.map opt_ctx_arrays ~f:(fun ctx_arrays ->
166-
Array.fold for_lowereds ~init:ctx_arrays ~f:(fun ctx_arrays lowered ->
167-
Hashtbl.fold lowered.traced_store ~init:ctx_arrays
168-
~f:(fun ~key:tn ~data:node ctx_arrays ->
169-
match Map.find ctx_arrays tn with
170-
| None ->
171-
if is_in_context node then
172-
let debug = "CC compile-time ctx array for " ^ Tn.debug_name tn in
173-
let data =
174-
Ndarray.create_array ~debug (Lazy.force tn.Tn.prec)
175-
~dims:(Lazy.force tn.dims)
176-
@@ Constant_fill { values = [| 0. |]; strict = false }
177-
in
178-
Map.add_exn ctx_arrays ~key:tn ~data
179-
else ctx_arrays
180-
| Some _ -> ctx_arrays)))
167+
Option.map opt_ctx_arrays ~f:(fun arrays ->
168+
let ctx_arrays =
169+
Array.fold for_lowereds ~init:arrays.ctx_arrays ~f:(fun ctx_arrays lowered ->
170+
Hashtbl.fold lowered.traced_store ~init:ctx_arrays
171+
~f:(fun ~key:tn ~data:node ctx_arrays ->
172+
match Map.find ctx_arrays tn with
173+
| None ->
174+
if is_in_context node then
175+
let debug = "CC compile-time ctx array for " ^ Tn.debug_name tn in
176+
let data =
177+
Ndarray.create_array ~debug (Lazy.force tn.Tn.prec)
178+
~dims:(Lazy.force tn.dims)
179+
@@ Constant_fill { values = [| 0. |]; strict = false }
180+
in
181+
Map.add_exn ctx_arrays ~key:tn ~data
182+
else ctx_arrays
183+
| Some _ -> ctx_arrays))
184+
in
185+
{ arrays with ctx_arrays })
181186
in
182187
let module Syntax = C_syntax.C_syntax (C_syntax_config (struct
183188
let for_lowereds = for_lowereds
@@ -186,7 +191,7 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
186191
(* FIXME: do we really want all of them, or only the used ones? *)
187192
let idx_params = Indexing.bound_symbols bindings in
188193
let global_ctx_arrays =
189-
ref (match opt_ctx_arrays with Some ctx_arrays -> ctx_arrays | None -> Map.empty (module Tn))
194+
ref (match opt_ctx_arrays with Some ctx_arrays -> ctx_arrays | None -> empty_ctx_arrays)
190195
in
191196
let base_name =
192197
String.(
@@ -206,7 +211,7 @@ let%diagn_sexp compile_batch ~names ~opt_ctx_arrays bindings
206211
let opt_ctx_arrays = Option.map opt_ctx_arrays ~f:(fun _ -> !global_ctx_arrays) in
207212
( opt_ctx_arrays,
208213
Array.mapi params ~f:(fun i params ->
209-
Option.map names.(i) ~f:(fun name ->
214+
Option.map names.(i) ~f:(fun name ->
210215
{
211216
result;
212217
params = Option.value_exn ~here:[%here] params;
@@ -219,7 +224,7 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
219224
context * _ * _ * string =
220225
let label : string = prior_context.label in
221226
let name : string = code.name in
222-
let arrays : Ndarray.t Base.Map.M(Tn).t =
227+
let arrays =
223228
match code with
224229
| { opt_ctx_arrays = Some arrays; _ } -> arrays
225230
| { params; _ } ->
@@ -232,7 +237,7 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
232237
Ndarray.create_array ~debug (Lazy.force tn.Tn.prec) ~dims:(Lazy.force tn.dims)
233238
@@ Constant_fill { values = [| 0. |]; strict = false }
234239
in
235-
Map.update ctx_arrays tn ~f
240+
{ ctx_arrays with ctx_arrays = Map.update ctx_arrays.ctx_arrays tn ~f }
236241
| _ -> ctx_arrays)
237242
in
238243
let context = { label; arrays } in
@@ -258,7 +263,9 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
258263
let get_ptr (buffer, _) = Ndarray.get_voidptr_not_managed buffer in
259264
Param_2f (get_ptr, merge_buffer, link bs ps Ctypes.(ptr void @-> cs))
260265
| bs, Param_ptr tn :: ps ->
261-
let nd = match Map.find arrays tn with Some nd -> nd | None -> assert false in
266+
let nd =
267+
match get_array (ctx_arrays context) tn with Some nd -> nd | None -> assert false
268+
in
262269
let c_ptr = Ndarray.get_voidptr_not_managed nd in
263270
Param_2 (ref (Some c_ptr), link bs ps Ctypes.(ptr void @-> cs))
264271
in

0 commit comments

Comments
 (0)