Skip to content

Commit 2858d24

Browse files
committed
Remove verification of merge buffer nodes inside device_to_device
1 parent e866289 commit 2858d24

File tree

5 files changed

+13
-40
lines changed

5 files changed

+13
-40
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
- Verifying that code is linked with the right contexts, by tracking `embedded_nodes` with assignments.
1111
- Renaming: (virtual) `device` -> `stream`, `physical_device` -> `device`.
1212
- New files: split out `backend_types.ml` from `backends.ml`; moved `Tnode.task` to `task.ml`; renamed `backend_utils.ml` to `c_syntax.ml`.
13+
- Removed half-static verification of merge buffer nodes inside `device_to_device`.
1314
- TODO: Moved the multicore backend from a `device = stream` model to a single device model.
1415
- TODO: Fixed #286: cross-stream-sharing incorporated into `Tnode.memory_mode`.
1516
- TODO: Built per-tensor-node stream-to-stream synchronization into device-to-device copying functions, removed obsolete blocking synchronizations.

arrayjit/lib/backend_types.ml

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,11 +156,8 @@ module type Backend = sig
156156
given node. If [into_merge_buffer=Streaming], remembers the buffer pointer of the source
157157
node to use for streaming, without blocking. If [into_merge_buffer=Copy], schedules copying
158158
from [src] to the merge buffer of [dst]'s stream.
159-
- If the [dst] context resulted from a compilation with [Streaming] or [Copy] specific merge
160-
buffer code, the [device_to_device] call should fail immediately if there's a mismatch with
161-
[into_merge_buffer].
162159
163-
NOTE: If [into_merge_buffer:Streaming], after scheduling the work on [dst] using the merge
160+
NOTE: If [into_merge_buffer=Streaming], after scheduling the work on [dst] using the merge
164161
buffer but before scheduling work on [src] that modifies [tn], execute
165162
[will_wait_for src (all_work (get_ctx_stream dst))]. *)
166163

@@ -287,7 +284,7 @@ module type Lowered_backend = sig
287284

288285
val device_to_device :
289286
Tnode.t -> into_merge_buffer:merge_buffer_use -> dst:context -> src:context -> bool
290-
(** If the tensor node is in both contexts, copies from [dst] to [src]. *)
287+
(** See {!Backend.device_to_device}. *)
291288

292289
type buffer_ptr [@@deriving sexp_of]
293290

arrayjit/lib/backends.ml

Lines changed: 0 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -280,15 +280,6 @@ struct
280280

281281
let device_to_device tn ~into_merge_buffer ~dst ~src =
282282
let dev = dst.stream in
283-
if
284-
(not (equal_merge_buffer_use into_merge_buffer No))
285-
&& not (Option.equal Tnode.equal (Some tn) dst.expected_merge_node)
286-
then
287-
raise
288-
@@ Utils.User_error
289-
("Multicore_backend.device_to_device: merge node mismatch, expected "
290-
^ Option.(value ~default:"none" @@ map ~f:Tnode.debug_name dst.expected_merge_node)
291-
^ ", actual " ^ Tnode.debug_name tn);
292283
let schedule dst =
293284
let work =
294285
(* TODO: log the operation if [Utils.settings.with_log_level > 0]. *)
@@ -479,15 +470,6 @@ module Sync_backend (Backend : Backend_types.No_device_backend) : Backend_types.
479470

480471
let device_to_device tn ~into_merge_buffer ~dst ~src =
481472
let dev = dst.stream in
482-
if
483-
(not (equal_merge_buffer_use into_merge_buffer No))
484-
&& not (Option.equal Tnode.equal (Some tn) dst.expected_merge_node)
485-
then
486-
raise
487-
@@ Utils.User_error
488-
("Multicore_backend.device_to_device: merge node mismatch, expected "
489-
^ Option.(value ~default:"none" @@ map ~f:Tnode.debug_name dst.expected_merge_node)
490-
^ ", actual " ^ Tnode.debug_name tn);
491473
(* TODO: log the operation if [Utils.settings.with_log_level > 0]. *)
492474
match (Backend.get_buffer tn dst.ctx, Backend.get_buffer tn src.ctx) with
493475
| None, _ | _, None -> false
@@ -855,15 +837,6 @@ module Lowered_backend (Device : Backend_types.Lowered_backend) : Backend_types.
855837
let to_host context tn = to_host context.ctx tn
856838

857839
let device_to_device tn ~into_merge_buffer ~dst ~src =
858-
if
859-
(not (equal_merge_buffer_use into_merge_buffer No))
860-
&& not (Option.equal Tnode.equal (Some tn) dst.expected_merge_node)
861-
then
862-
raise
863-
@@ Utils.User_error
864-
("Multicore_backend.device_to_device: merge node mismatch, expected "
865-
^ Option.(value ~default:"none" @@ map ~f:Tnode.debug_name dst.expected_merge_node)
866-
^ ", actual " ^ Tnode.debug_name tn);
867840
device_to_device tn ~into_merge_buffer ~dst:dst.ctx ~src:src.ctx
868841
end
869842

bin/moons_demo_parallel.ml

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,16 +13,16 @@ let experiment ~seed ~backend_name ~config () =
1313
(* Utils.set_log_level 3; *)
1414
(* Utils.settings.output_debug_files_in_build_directory <- true; *)
1515
(* Utils.settings.debug_log_from_routines <- true; *)
16-
(* let hid_dim = 16 in *)
17-
let hid_dim = 4 in
16+
let hid_dim = 16 in
17+
(* let hid_dim = 4 in *)
1818
(* let batch_size = 120 in *)
19-
(* let batch_size = 60 in *)
20-
let batch_size = 20 in
19+
let batch_size = 60 in
20+
(* let batch_size = 20 in *)
2121
let len = batch_size * 20 in
2222
let init_lr = 0.1 in
2323
(* let epochs = 10 in *)
24-
(* let epochs = 40 in *)
25-
let epochs = 1 in
24+
let epochs = 40 in
25+
(* let epochs = 1 in *)
2626
let noise () = Rand.float_range (-0.1) 0.1 in
2727
let moons_flat =
2828
Array.concat_map (Array.create ~len ())

lib/train.ml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -364,14 +364,16 @@ let%track3_sexp parallel_update (type context)
364364
let grad_merge =
365365
Option.value_exn ~here:[%here] ~message:(Tn.debug_name p.value) grad_merges_to.(to_).(i)
366366
in
367+
(* NOTE: we no longer have to pass [grad_merge.context] as [dst]. *)
367368
assert (
368369
Backend.device_to_device (Option.value_exn ~here:[%here] p.diff).grad ~into_merge_buffer
369-
~dst:grad_merge.context ~src:ctxs.(from));
370+
~dst:ctxs.(to_) ~src:ctxs.(from));
370371
(Task.run grad_merge.schedule : unit))
371372
in
372373
let merge_loss ~src =
374+
(* NOTE: we no longer have to pass [loss_merge.context] as [dst]. *)
373375
assert (
374-
Backend.device_to_device updaten.loss.value ~into_merge_buffer ~dst:loss_merge.context ~src);
376+
Backend.device_to_device updaten.loss.value ~into_merge_buffer ~dst:sgd_update.context ~src);
375377
Task.run loss_merge.schedule
376378
in
377379
(* FIXME: missing backcopy. *)

0 commit comments

Comments
 (0)