Skip to content

Commit 1953872

Browse files
committed
Untested: a quick approx. get_used_memory
Progress toward #245.
1 parent bd0dc98 commit 1953872

File tree

4 files changed

+32
-4
lines changed

4 files changed

+32
-4
lines changed

arrayjit/lib/backend_types.ml

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ module type No_device_backend = sig
5555

5656
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> unit -> buffer_ptr
5757

58+
val get_used_memory : unit -> int
59+
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
60+
5861
val compile : ?shared:bool -> ?name:string -> Indexing.unit_bindings -> Assignments.comp -> code
5962
(** If [~shared:true] (default [false]), the backend should prefer to do more compile work in a
6063
device-and-stream-agnostic way. If [~shared:false], the backend can opt to postpone compiling
@@ -294,14 +297,18 @@ module type Lowered_backend = sig
294297
type stream [@@deriving sexp_of]
295298

296299
val alloc_buffer : ?old_buffer:buffer_ptr * int -> size_in_bytes:int -> stream -> buffer_ptr
300+
301+
val get_used_memory : unit -> int
302+
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
303+
297304
val init : stream -> context
298305
val await : stream -> unit
299306
val is_idle : stream -> bool
300307
val all_work : stream -> event
301308

302309
val scheduled_merge_node : stream -> Tnode.t option
303-
(** [scheduled_merge_node stream] is the tensor node that would be in the [stream]'s merge
304-
buffer right after [await stream]. *)
310+
(** [scheduled_merge_node stream] is the tensor node that would be in the [stream]'s merge buffer
311+
right after [await stream]. *)
305312

306313
val num_devices : unit -> int
307314
val suggested_num_streams : device -> int

arrayjit/lib/backends.ml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ struct
7676
let alloc_buffer ?old_buffer ~size_in_bytes _stream =
7777
Backend.alloc_buffer ?old_buffer ~size_in_bytes ()
7878

79+
let get_used_memory = Backend.get_used_memory
80+
7981
type device = stream [@@deriving sexp_of]
8082
type code = Backend.code [@@deriving sexp_of]
8183
type code_batch = Backend.code_batch [@@deriving sexp_of]
@@ -368,6 +370,8 @@ module Sync_backend (Backend : Backend_types.No_device_backend) : Backend_types.
368370
let alloc_buffer ?old_buffer ~size_in_bytes _stream =
369371
Backend.alloc_buffer ?old_buffer ~size_in_bytes ()
370372

373+
let get_used_memory = Backend.get_used_memory
374+
371375
type device = CPU [@@deriving sexp_of]
372376
type code = Backend.code [@@deriving sexp_of]
373377
type code_batch = Backend.code_batch [@@deriving sexp_of]
@@ -700,6 +704,8 @@ module Lowered_no_device_backend (Backend : Backend_types.Lowered_no_device_back
700704

701705
let get_buffer tn context =
702706
Map.find (Backend.ctx_arrays context) tn |> Option.map ~f:Backend.buffer_ptr
707+
708+
let get_used_memory = Ndarray.get_used_memory
703709
end
704710

705711
module C_device : Backend_types.No_device_backend = Lowered_no_device_backend ((

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,10 @@ let alloc_buffer ?old_buffer ~size_in_bytes stream =
103103
set_ctx stream.device.primary_context;
104104
Cu.Deviceptr.mem_alloc ~size_in_bytes
105105

106+
let get_used_memory () =
107+
let free, total = Cudajit.Device.get_free_and_total_mem () in
108+
total - free
109+
106110
let opt_alloc_merge_buffer ~size_in_bytes phys_dev =
107111
if phys_dev.copy_merge_buffer_capacity < size_in_bytes then (
108112
set_ctx phys_dev.primary_context;

arrayjit/lib/ndarray.ml

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -378,20 +378,31 @@ let ptr_to_string_hum nd =
378378

379379
(** {2 *** Creating ***} *)
380380

381+
let used_memory = Atomic.make 0
382+
381383
let create_array ~debug:_debug prec ~dims init_op =
384+
let size_in_bytes =
385+
(if Array.length dims = 0 then 0 else Array.reduce_exn dims ~f:( * )) * Ops.prec_in_bytes prec
386+
in
387+
let%diagn2_sexp finalizer _result =
388+
let _ : int = Atomic.fetch_and_add used_memory size_in_bytes in
389+
[%log "Deleting", _debug, ptr_to_string_hum _result]
390+
in
382391
let f prec = as_array prec @@ create_bigarray prec ~dims init_op in
383392
let result = Ops.map_prec { f } prec in
393+
Stdlib.Gc.finalise finalizer result;
394+
let _ : int = Atomic.fetch_and_add used_memory size_in_bytes in
384395
[%debug2_sexp
385396
[%log_block
386397
"create_array";
387398
[%log _debug, ptr_to_string_hum result]]];
388-
let%debug2_sexp debug_finalizer _result = [%log "Deleting", _debug, ptr_to_string_hum _result] in
389-
if Utils.settings.log_level > 1 then Stdlib.Gc.finalise debug_finalizer result;
390399
result
391400

392401
let empty_array prec =
393402
create_array prec ~dims:[||] (Constant_fill { values = [| 0.0 |]; strict = false })
394403

404+
let get_used_memory () = Atomic.get used_memory
405+
395406
(** {2 *** Printing ***} *)
396407

397408
(** Dimensions to string, ["x"]-separated, e.g. 1x2x3 for batch dims 1, input dims 3, output dims 2.

0 commit comments

Comments
 (0)