Skip to content

Commit acb41d2

Browse files
committed
Auto-set hosted for Train.to_routine; fix ambiguous mem mode provenances
1 parent 309a89a commit acb41d2

File tree

3 files changed

+26
-4
lines changed

3 files changed

+26
-4
lines changed

arrayjit/lib/assignments.ml

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,27 @@ let%debug3_sexp context_nodes ~(use_host_memory : 'a option) (asgns : t) : Tn.t_
117117
in
118118
loop asgns
119119

120+
(** Returns the nodes that are not read from after being written to. *)
121+
let%debug3_sexp guess_output_nodes (asgns : t) : Tn.t_set =
122+
let open Utils.Set_O in
123+
let empty = Set.empty (module Tn) in
124+
let one = Set.singleton (module Tn) in
125+
let of_node = function Node rhs -> one rhs | Merge_buffer _ -> empty in
126+
let rec loop = function
127+
| Noop -> (empty, empty)
128+
| Seq (t1, t2) ->
129+
let i1, o1 = loop t1 in
130+
let i2, o2 = loop t2 in
131+
(i1 + i2, o1 + o2 - (i1 + i2))
132+
| Block_comment (_, t) -> loop t
133+
| Accum_unop { lhs; rhs; _ } -> (of_node rhs, one lhs)
134+
| Accum_binop { lhs; rhs1; rhs2; _ } -> (of_node rhs1 + of_node rhs2, one lhs)
135+
| Accum_ternop { lhs; rhs1; rhs2; rhs3; _ } ->
136+
(of_node rhs1 + of_node rhs2 + of_node rhs3, one lhs)
137+
| Fetch { array; _ } -> (empty, one array)
138+
in
139+
snd @@ loop asgns
140+
120141
let sequential l =
121142
Option.value ~default:Noop @@ List.reduce l ~f:(fun sts another_st -> Seq (sts, another_st))
122143

arrayjit/lib/backends.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -455,7 +455,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
455455
let data = Hashtbl.find_exn device.cross_stream_candidates key in
456456
Map.add_exn ctx_arrays ~key ~data)
457457
else (
458-
Tn.update_memory_sharing key Tn.Per_stream 41;
458+
Tn.update_memory_sharing key Tn.Per_stream 410;
459459
Hashtbl.remove device.cross_stream_candidates key;
460460
add_new ()))
461461
else ctx_arrays

lib/train.ml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,8 +69,8 @@ let set_on_host ?(from_device = true) (a : Tn.t) =
6969
let set_materialized (a : Tn.t) = Tn.update_memory_mode a Materialized 28
7070

7171
let set_hosted (a : Tn.t) =
72-
if Tn.known_constant a then Tn.update_memory_mode a (Hosted Constant) 41
73-
else Tn.update_memory_mode a (Hosted (Changed_on_devices Unset)) 41
72+
if Tn.known_constant a then Tn.update_memory_mode a (Hosted Constant) 411
73+
else Tn.update_memory_mode a (Hosted (Changed_on_devices Unset)) 412
7474

7575
(** Sets the tensor's value as "fully on host", returns the tensor's forward code with a
7676
label-derived comment. *)
@@ -346,7 +346,8 @@ let to_routine (type buffer_ptr dev runner event optimize_ctx)
346346
and type dev = dev
347347
and type runner = runner
348348
and type event = event
349-
and type optimize_ctx = optimize_ctx) (context : Backend.context) ?name bindings comp =
349+
and type optimize_ctx = optimize_ctx) (context : Backend.context) ?(hosted=true) ?name bindings comp =
350+
if hosted then Set.iter (Asgns.guess_output_nodes comp.Asgns.asgns) ~f:set_hosted;
350351
Backend.link context @@ Backend.compile context.optimize_ctx ?name bindings comp
351352

352353
type example_train_result = {

0 commit comments

Comments
 (0)