Skip to content

Commit 418af7c

Browse files
committed
Yay, fix the scheduler bug: the old d.is_idle check in await ignored the queue
1 parent 1ddc17b commit 418af7c

File tree

3 files changed

+15
-14
lines changed

3 files changed

+15
-14
lines changed

arrayjit/lib/backends.ml

Lines changed: 10 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -180,7 +180,7 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
180180
mut : (Mut.t[@sexp.opaque]);
181181
host_wait_for_idle : (Stdlib.Condition.t[@sexp.opaque]);
182182
dev_wait_for_work : (Stdlib.Condition.t[@sexp.opaque]);
183-
mutable is_idle : bool;
183+
mutable is_ready : bool;
184184
mutable host_is_waiting : bool; (** The host is waiting for this specific device. *)
185185
}
186186
[@@deriving sexp_of]
@@ -206,17 +206,17 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
206206
let expected_merge_node (code : code) = Backend.expected_merge_node code
207207
let expected_merge_nodes (codes : code_batch) = Backend.expected_merge_nodes codes
208208
let is_dev_queue_empty state = Queue.size state.queue = 0
209-
let is_idle device = is_dev_queue_empty device.state && device.state.is_idle
209+
let is_idle device = is_dev_queue_empty device.state && device.state.is_ready
210210
let name = "multicore " ^ Backend.name
211211

212212
let%track3_l_sexp await device =
213213
assert (Domain.is_main_domain ());
214214
let d = device.state in
215-
if (not d.is_idle) && d.keep_spinning then (
215+
if (not @@ is_idle device) && d.keep_spinning then (
216216
Mut.lock d.mut;
217-
if (not d.is_idle) && d.keep_spinning then (
217+
if (not @@ is_idle device) && d.keep_spinning then (
218218
d.host_is_waiting <- true;
219-
while (not d.is_idle) && d.keep_spinning do
219+
while (not @@ is_idle device) && d.keep_spinning do
220220
Stdlib.Condition.wait d.host_wait_for_idle d.mut
221221
done;
222222
d.host_is_waiting <- false);
@@ -234,7 +234,7 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
234234
if not @@ Queue.try_push d.queue task then (
235235
await device;
236236
Queue.push_exn d.queue task);
237-
if d.is_idle then (
237+
if d.is_ready then (
238238
Mut.lock d.mut;
239239
Stdlib.Condition.broadcast d.dev_wait_for_work;
240240
Mut.unlock d.mut)
@@ -249,25 +249,26 @@ module Multicore_backend (Backend : No_device_backend) : Backend = struct
249249
device_error = None;
250250
queue = Queue.create ~size_exponent:12;
251251
mut = Mut.create ();
252-
is_idle = false;
252+
is_ready = false;
253253
host_wait_for_idle = Stdlib.Condition.create ();
254254
dev_wait_for_work = Stdlib.Condition.create ();
255255
host_is_waiting = false;
256256
}
257257
in
258258
let%track3_l_sexp worker (() : unit) : unit =
259+
assert (not @@ Domain.is_main_domain ());
259260
try
260261
while state.keep_spinning do
261262
match Queue.pop_opt state.queue with
262263
| None ->
263264
Mut.lock state.mut;
264265
if is_dev_queue_empty state && state.keep_spinning then (
265-
state.is_idle <- true;
266+
state.is_ready <- true;
266267
while is_dev_queue_empty state && state.keep_spinning do
267268
if state.host_is_waiting then Stdlib.Condition.broadcast state.host_wait_for_idle;
268269
Stdlib.Condition.wait state.dev_wait_for_work state.mut
269270
done;
270-
state.is_idle <- false);
271+
state.is_ready <- false);
271272
Mut.unlock state.mut
272273
| Some task -> Tnode.run task
273274
done

bin/moons_benchmark.ml

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -40,9 +40,9 @@ let classify_moons ~seed ~on_device ~inlining_cutoff ~num_devices ~batch_size ~b
4040
(* let hid_dim = 4 in *)
4141
let len = batch_size * 20 in
4242
(* let epochs = 100 in *)
43-
(* let epochs = 20 in *)
44-
(* let epochs = 5 in *)
45-
let epochs = 1 in
43+
let epochs = 20 in
44+
(* let epochs = 10 in *)
45+
(* let epochs = 1 in *)
4646
let init_lr = 0.1 in
4747
let noise () = Rand.float_range (-0.1) 0.1 in
4848
let moons_flat =
@@ -177,7 +177,7 @@ let _cpu_benchmarks =
177177

178178
let cuda_benchmarks =
179179
List.concat_map [ 0; (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
180-
List.concat_map [ 1; 2; 5; 8; 10; 16; 20; 30 (* *; 32; 40; 64 *) ] ~f:(fun num_devices ->
180+
List.concat_map [ 1; 2; 5; 8; 10 (* ; 16; 20; 30; 32; 40; 64 *) ] ~f:(fun num_devices ->
181181
List.concat_map [ 120; 160 (* ; 320; 640; 1280 *) ] ~f:(fun batch_size ->
182182
List.concat_map [ 0; 1 (* ; 2; 3; 4 *) ] ~f:(fun seed ->
183183
List.concat_map [ (* "gccjit" ; *) "cc" (* ; "cuda" *) ]

lib/train.ml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -508,7 +508,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
508508
assert (Backend.from_host routine.context infer.value);
509509
run routine;
510510
assert (Backend.to_host routine.context model_result.value);
511-
Backend.(await @@ get_ctx_device prior_contexts.(0));
511+
Backend.(await @@ get_ctx_device routine.context);
512512
Tensor.get_values model_result
513513
in
514514
(* Note: infer_callback is significantly less efficient than using the model via arrayjit. *)

0 commit comments

Comments
 (0)