Skip to content

Commit 6605dff

Browse files
committed
Backends: Remove now-redundant accessor functions, share get_name
1 parent 03c7989 commit 6605dff

File tree

5 files changed

+23
-19
lines changed

5 files changed

+23
-19
lines changed

arrayjit/lib/backend_types.ml

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ module type Device_config = sig
123123
(** An event tracks if a stream finished computing past a particular point in its schedue. These
124124
values are used internally for scheduling across streams of the backend, and can be used for
125125
explicit scheduling. *)
126+
127+
val name : string
126128
end
127129

128130
type ('buffer_ptr, 'dev, 'event) device = {
@@ -148,7 +150,9 @@ type ('buffer_ptr, 'dev, 'runner, 'event) stream = {
148150
merge_buffer : ('buffer_ptr * Tnode.t) option ref;
149151
stream_id : int;
150152
mutable allocated_buffer : 'buffer_ptr buffer option;
151-
queried_work_for : 'event option Hashtbl.M(Tnode).t; (* The completion event for updating the node via this stream. Only populated after the first time {!} *)
153+
queried_work_for : 'event option Hashtbl.M(Tnode).t;
154+
(* The completion event for updating the node via this stream. Only populated after the first
155+
time {!} *)
152156
}
153157
[@@deriving sexp_of]
154158

@@ -174,6 +178,7 @@ module type Device = sig
174178

175179
val make_device : dev -> ordinal:int -> device
176180
val make_stream : device -> runner -> stream_id:int -> stream
181+
val get_name : stream -> string
177182
end
178183

179184
module Device_types (Device_config : Device_config) = struct
@@ -211,6 +216,8 @@ struct
211216
allocated_buffer = None;
212217
queried_work_for = Hashtbl.create (module Tnode);
213218
}
219+
220+
let get_name stream = [%string "%{name}:%{stream.device.ordinal#Int}:%{stream.stream_id#Int}"]
214221
end
215222

216223
(** Parts shared by both assignments-level and lowered-level backend interfaces. *)
@@ -224,8 +231,6 @@ module type Backend_any_common = sig
224231
(** For backends derived via {!No_device_backend}, this is usually the backend name concatenated
225232
with the device or stream number. For {!Backend}, [init_info = stream]. *)
226233

227-
val name : string
228-
229234
val initialize : config -> unit
230235
(** Initializes a backend before first use. Typically does nothing if the backend is already
231236
initialized, but some backends can do some safe cleanups. *)
@@ -286,6 +291,8 @@ module type No_device_backend = sig
286291
include Backend_common with type init_info := string and type stream := unit
287292
include Backend_impl_common with type context := context and type buffer_ptr := buffer_ptr
288293

294+
val name : string
295+
289296
val link : merge_buffer:(buffer_ptr * Tnode.t) option ref -> context -> code -> context routine
290297
(** Returns the routine for the code's procedure, in a new context derived from the given context. *)
291298

@@ -350,9 +357,6 @@ module type Backend_device_common = sig
350357

351358
val new_stream : device -> stream
352359
val get_ctx_stream : context -> stream
353-
val get_stream_device : stream -> device
354-
val to_ordinal : device -> int
355-
val get_name : stream -> string
356360
end
357361

358362
module type With_buffer_retrieval_and_syncing = sig
@@ -427,6 +431,8 @@ module type Lowered_no_device_backend = sig
427431
and type init_info := string
428432
and type buffer_ptr := buffer_ptr
429433

434+
val name : string
435+
430436
type procedure [@@deriving sexp_of]
431437

432438
val compile :

arrayjit/lib/backends.ml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ module Add_buffer_retrieval_and_syncing (Backend : No_buffer_retrieval_or_syncin
4747

4848
let%diagn2_l_sexp device_to_device (tn : Tn.t) ~into_merge_buffer ~(dst : Backend.context)
4949
~(src : Backend.context) =
50-
let ordinal_of ctx = Backend.(to_ordinal @@ get_stream_device @@ get_ctx_stream ctx) in
50+
let ordinal_of ctx = (Backend.get_ctx_stream ctx).device.ordinal in
5151
let name_of ctx = Backend.(get_name @@ get_ctx_stream ctx) in
5252
let same_device = ordinal_of dst = ordinal_of src in
5353
if same_device && (Tn.known_shared_cross_stream tn || String.equal (name_of src) (name_of dst))
@@ -123,6 +123,8 @@ module Multicore_backend (Backend : No_device_backend) = struct
123123

124124
type runner = { state : stream_state; domain : domain } [@@deriving sexp_of]
125125
type event = Not_implemented_yet [@@deriving sexp_of]
126+
127+
let name = "multicore_" ^ Backend.name
126128
end
127129

128130
module Alloc_buffer = struct
@@ -289,9 +291,7 @@ module Multicore_backend (Backend : No_device_backend) = struct
289291
Stdlib.Gc.finalise cleanup_stream stream;
290292
stream
291293

292-
let get_stream_device stream = stream.device
293294
let get_ctx_stream { stream; _ } = stream
294-
let to_ordinal _ = 0
295295

296296
let from_host ~dst_ptr ~dst hosted =
297297
let work () = host_to_buffer hosted ~dst:dst_ptr in
@@ -350,6 +350,8 @@ module Sync_backend (Backend : No_device_backend) = struct
350350
type dev = CPU [@@deriving sexp_of]
351351
type runner = unit [@@deriving sexp_of]
352352
type event = unit [@@deriving sexp_of]
353+
354+
let name = "sync_" ^ Backend.name
353355
end
354356

355357
module Alloc_buffer = struct
@@ -370,7 +372,6 @@ module Sync_backend (Backend : No_device_backend) = struct
370372
Backend.alloc_buffer ?old_buffer ~size_in_bytes ()
371373

372374
let device : device = make_device CPU ~ordinal:0
373-
let to_ordinal device = device.ordinal
374375

375376
let get_device ~ordinal =
376377
if ordinal <> 0 then
@@ -399,7 +400,6 @@ module Sync_backend (Backend : No_device_backend) = struct
399400
type context = { stream : stream; ctx : Backend.context } [@@deriving sexp_of]
400401

401402
let get_ctx_stream context = context.stream
402-
let get_stream_device stream = stream.device
403403
let ctx_arrays context = ctx_arrays context.ctx
404404
let init stream = { stream; ctx = Backend.init name }
405405
let initialize = Backend.initialize

arrayjit/lib/cc_backend.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
1010
include Backend_types.No_device_buffer_and_copying
1111
open Backend_types
1212

13+
let name = "cc"
14+
1315
let optimization_level () =
1416
Int.of_string @@ Utils.get_global_arg ~default:"3" ~arg_name:"cc_backend_optimization_level"
1517

@@ -257,5 +259,3 @@ let%diagn_sexp link_compiled ~merge_buffer (prior_context : context) (code : pro
257259
work;
258260
},
259261
name )
260-
261-
let name = "cc"

arrayjit/lib/cuda_backend.cudajit.ml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@ module Device_config = struct
3636
type dev = { dev : Cu.Device.t; primary_context : Cu.Context.t } [@@deriving sexp_of]
3737
type runner = Cu.Stream.t [@@deriving sexp_of]
3838
type event = Cu.Delimited_event.t [@@deriving sexp_of]
39+
40+
let name = "cuda"
3941
end
4042

4143
module Device_stream = Device_types (Device_config)
@@ -185,10 +187,6 @@ let suggested_num_streams device =
185187
| Most_parallel_streams -> (cuda_properties device).multiprocessor_count
186188

187189
let get_ctx_stream { stream; _ } = stream
188-
let get_stream_device { device; _ } = device
189-
let to_ordinal { ordinal; _ } = ordinal
190-
let name = "cuda"
191-
let get_name stream = [%string "%{name}:%{stream.device.ordinal#Int}:%{stream.stream_id#Int}"]
192190

193191
let await stream : unit =
194192
set_ctx stream.device.dev.primary_context;

bin/compilation_speed.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ let benchmark_overhead backend () =
2626
(* Train.every_non_literal_on_host f; *)
2727
let stream = Backend.(new_stream @@ get_device ~ordinal:0) in
2828
let ctx = Backend.init stream in
29-
let init_mem = Backend.(get_used_memory @@ get_stream_device stream) in
29+
let init_mem = Backend.(get_used_memory stream.device) in
3030
let update_f = Train.grad_update f in
3131
(* Initialize the context with a mock update of x to ensure that it is not optimized as a
3232
constant. *)
@@ -63,7 +63,7 @@ let benchmark_overhead backend () =
6363
in
6464
let final_time = Time_now.nanoseconds_since_unix_epoch () in
6565
let time_in_sec = Int63.(to_float @@ (final_time - init_time)) /. 1000_000_000. in
66-
let mem_in_bytes = Backend.(get_used_memory @@ get_stream_device stream) - init_mem in
66+
let mem_in_bytes = Backend.(get_used_memory stream.device) - init_mem in
6767
let result =
6868
PrintBox_utils.Benchmark
6969
{

0 commit comments

Comments
 (0)