Skip to content

Commit aad61e0

Browse files
committed
Add static_properties to backend interfaces for device descriptions; try lang version 3.0
Additionally, the `moons_demo_parallel_run.ml` test is updated to print the properties of devices.
1 parent e752ddb commit aad61e0

File tree

6 files changed

+121
-1
lines changed

6 files changed

+121
-1
lines changed

arrayjit/lib/backend_intf.ml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -254,6 +254,9 @@ module type Backend_device_common = sig
254254
NOTE: it should rarely be needed to call [will_wait_for] explicitly, because it should always
255255
be called internally when necessary. *)
256256

257+
val static_properties : Sexp.t
258+
(** Returns a sexp description of the properties of all devices. *)
259+
257260
val get_used_memory : device -> int
258261
(** Returns (an upper bound of) the memory used for arrays, in bytes. *)
259262

arrayjit/lib/cuda_backend.ml

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -574,6 +574,28 @@ end) : Ir.Backend_impl.Lowered_backend = struct
574574
Sexp.message "cuda_global_debug"
575575
[ ("live_streams", [%sexp_of: int] @@ Cu.Stream.get_total_live_streams ()) ]
576576

577+
let static_properties =
578+
let device_properties =
579+
Array.init (num_devices ()) ~f:(fun ordinal ->
580+
let dev = Cu.Device.get ~ordinal in
581+
let attributes = Cu.Device.get_attributes dev in
582+
let props = [
583+
("device_name", Sexp.Atom (Cu.Device.get_name dev));
584+
("device_ordinal", [%sexp_of: int] ordinal);
585+
("multiprocessor_count", [%sexp_of: int] attributes.multiprocessor_count);
586+
("total_global_memory", [%sexp_of: int] (Cu.Device.get_total_memory dev));
587+
("clock_rate", [%sexp_of: int] attributes.clock_rate);
588+
("async_engine_count", [%sexp_of: int] attributes.async_engine_count);
589+
("compute_capability_major", [%sexp_of: int] attributes.compute_capability_major);
590+
("compute_capability_minor", [%sexp_of: int] attributes.compute_capability_minor);
591+
("max_threads_per_block", [%sexp_of: int] attributes.max_threads_per_block);
592+
("unified_addressing", [%sexp_of: bool] attributes.unified_addressing);
593+
] in
594+
Sexp.List [Sexp.Atom "device"; Sexp.List props]
595+
)
596+
in
597+
Sexp.List (Sexp.Atom "cuda_devices" :: device_properties)
598+
577599
let get_debug_info (stream : stream) =
578600
let tot, unr, unf = Cu.Stream.total_unreleased_unfinished_delimited_events stream.runner in
579601
let i2s = [%sexp_of: int] in

arrayjit/lib/lowered_backend_missing.ml

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,15 @@ struct
9494
let get_global_debug_info () =
9595
failwith @@ "Backend " ^ Config.name ^ " missing -- install the corresponding library"
9696

97+
let static_properties =
98+
Sexp.List [
99+
Sexp.Atom (Config.name ^ "_missing");
100+
Sexp.List [
101+
Sexp.Atom "error";
102+
Sexp.Atom ("Backend " ^ Config.name ^ " missing -- install the corresponding library")
103+
]
104+
]
105+
97106
let get_debug_info _stream =
98107
failwith @@ "Backend " ^ Config.name ^ " missing -- install the corresponding library"
99108

arrayjit/lib/metal_backend.ml

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,65 @@ end) : Ir.Backend_impl.Lowered_backend = struct
201201
| Only_devices_parallel | For_parallel_copying | Most_parallel_streams -> 1
202202

203203
let get_used_memory _device = Atomic.get allocated_memory
204+
205+
let static_properties =
206+
let device_properties =
207+
Array.mapi metal_devices ~f:(fun ordinal device ->
208+
let attributes = Me.Device.get_attributes device in
209+
Sexp.List
210+
[
211+
Sexp.Atom "device";
212+
Sexp.List
213+
[
214+
Sexp.List [ Sexp.Atom "device_name"; Sexp.Atom attributes.name ];
215+
Sexp.List [ Sexp.Atom "device_ordinal"; Sexp.Atom (Int.to_string ordinal) ];
216+
Sexp.List
217+
[
218+
Sexp.Atom "registry_id";
219+
Sexp.Atom (Unsigned.ULLong.to_string attributes.registry_id);
220+
];
221+
Sexp.List
222+
[
223+
Sexp.Atom "max_buffer_length";
224+
Sexp.Atom (Unsigned.ULong.to_string attributes.max_buffer_length);
225+
];
226+
Sexp.List
227+
[
228+
Sexp.Atom "max_threadgroup_memory_length";
229+
Sexp.Atom (Unsigned.ULong.to_string attributes.max_threadgroup_memory_length);
230+
];
231+
Sexp.List
232+
[
233+
Sexp.Atom "recommended_max_working_set_size";
234+
Sexp.Atom
235+
(Unsigned.ULLong.to_string attributes.recommended_max_working_set_size);
236+
];
237+
Sexp.List
238+
[ Sexp.Atom "is_low_power"; Sexp.Atom (Bool.to_string attributes.is_low_power) ];
239+
Sexp.List
240+
[ Sexp.Atom "is_headless"; Sexp.Atom (Bool.to_string attributes.is_headless) ];
241+
Sexp.List
242+
[
243+
Sexp.Atom "has_unified_memory";
244+
Sexp.Atom (Bool.to_string attributes.has_unified_memory);
245+
];
246+
Sexp.List
247+
[
248+
Sexp.Atom "total_memory";
249+
Sexp.Atom (Int.to_string (Atomic.get allocated_memory));
250+
];
251+
Sexp.List
252+
[
253+
Sexp.Atom "supported_gpu_families";
254+
Sexp.List
255+
(List.map attributes.supported_gpu_families ~f:(fun gpu_family ->
256+
Me.Device.GPUFamily.sexp_of_t gpu_family));
257+
];
258+
];
259+
])
260+
in
261+
Sexp.List (Sexp.Atom "metal_devices" :: Array.to_list device_properties)
262+
204263
let get_global_debug_info () = Sexp.Atom "Metal global debug info NYI"
205264

206265
let get_debug_info stream =
@@ -414,7 +473,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
414473

415474
let%diagn_sexp compile_metal_source ~name ~source ~device =
416475
let options = Me.CompileOptions.init () in
417-
Me.CompileOptions.set_language_version options Me.CompileOptions.LanguageVersion.version_3_1;
476+
Me.CompileOptions.set_language_version options Me.CompileOptions.LanguageVersion.version_3_0;
418477
if Utils.debug_log_from_routines () then Me.CompileOptions.set_enable_logging options true;
419478

420479
if Utils.with_runtime_debug () then (

arrayjit/lib/schedulers.ml

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -187,6 +187,19 @@ module Multicore (Backend : For_add_scheduler) :
187187
let num_devices () = 1
188188
let suggested_num_streams _device = Domain.recommended_domain_count () - 1
189189

190+
let static_properties =
191+
Sexp.List [
192+
Sexp.Atom "multicore_devices";
193+
Sexp.List [
194+
Sexp.Atom "device";
195+
Sexp.List [
196+
Sexp.List [Sexp.Atom "device_name"; Sexp.Atom "CPU"];
197+
Sexp.List [Sexp.Atom "device_ordinal"; [%sexp_of: int] 0];
198+
Sexp.List [Sexp.Atom "num_domains"; [%sexp_of: int] (Domain.recommended_domain_count ())];
199+
]
200+
]
201+
]
202+
190203
let%track7_sexp cleanup_stream (stream : stream) : unit =
191204
(* Allow running in parallel. *)
192205
(* assert (Domain.is_main_domain ()); *)
@@ -260,6 +273,19 @@ module Sync (Backend : For_add_scheduler) = struct
260273
let is_idle _stream = true
261274
let await _stream = ()
262275

276+
let static_properties =
277+
Sexp.List [
278+
Sexp.Atom "sync_devices";
279+
Sexp.List [
280+
Sexp.Atom "device";
281+
Sexp.List [
282+
Sexp.List [Sexp.Atom "device_name"; Sexp.Atom "CPU"];
283+
Sexp.List [Sexp.Atom "device_ordinal"; Sexp.Atom "0"];
284+
Sexp.List [Sexp.Atom "threads"; Sexp.Atom "1"];
285+
]
286+
]
287+
]
288+
263289
(* let global_run_no = ref 0 *)
264290
let schedule_task _stream task = Ir.Task.run task
265291
let get_global_debug_info () = Sexp.message "global_debug" []

test/moons_demo_parallel_run.ml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ let main () =
4141
let weight_decay = 0.0002 in
4242
(* So that we can inspect them. *)
4343
let module Backend = (val Backends.fresh_backend ()) in
44+
Stdlib.Format.printf "Properties of devices:@ %a@\n@!" Sexp.pp_hum Backend.static_properties;
4445
let per_batch_callback ~at_batch ~at_step ~learning_rate ~batch_loss ~epoch_loss =
4546
if (at_batch + 1) % 20 = 0 then
4647
Stdio.printf "Batch=%d, step=%d, lr=%f, batch loss=%f, epoch loss=%f\n%!" at_batch at_step

0 commit comments

Comments
 (0)