File tree Expand file tree Collapse file tree 3 files changed +26
-4
lines changed Expand file tree Collapse file tree 3 files changed +26
-4
lines changed Original file line number Diff line number Diff line change @@ -117,6 +117,27 @@ let%debug3_sexp context_nodes ~(use_host_memory : 'a option) (asgns : t) : Tn.t_
117117 in
118118 loop asgns
119119
120+ (* * Returns the nodes that are not read from after being written to. *)
121+ let % debug3_sexp guess_output_nodes (asgns : t ) : Tn. t_set =
122+ let open Utils.Set_O in
123+ let empty = Set. empty (module Tn ) in
124+ let one = Set. singleton (module Tn ) in
125+ let of_node = function Node rhs -> one rhs | Merge_buffer _ -> empty in
126+ let rec loop = function
127+ | Noop -> (empty, empty)
128+ | Seq (t1 , t2 ) ->
129+ let i1, o1 = loop t1 in
130+ let i2, o2 = loop t2 in
131+ (i1 + i2, o1 + o2 - (i1 + i2))
132+ | Block_comment (_ , t ) -> loop t
133+ | Accum_unop { lhs; rhs; _ } -> (of_node rhs, one lhs)
134+ | Accum_binop { lhs; rhs1; rhs2; _ } -> (of_node rhs1 + of_node rhs2, one lhs)
135+ | Accum_ternop { lhs; rhs1; rhs2; rhs3; _ } ->
136+ (of_node rhs1 + of_node rhs2 + of_node rhs3, one lhs)
137+ | Fetch { array; _ } -> (empty, one array )
138+ in
139+ snd @@ loop asgns
140+
120141let sequential l =
121142 Option. value ~default: Noop @@ List. reduce l ~f: (fun sts another_st -> Seq (sts, another_st))
122143
Original file line number Diff line number Diff line change @@ -455,7 +455,7 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
455455 let data = Hashtbl. find_exn device.cross_stream_candidates key in
456456 Map. add_exn ctx_arrays ~key ~data )
457457 else (
458- Tn. update_memory_sharing key Tn. Per_stream 41 ;
458+ Tn. update_memory_sharing key Tn. Per_stream 410 ;
459459 Hashtbl. remove device.cross_stream_candidates key;
460460 add_new () ))
461461 else ctx_arrays
Original file line number Diff line number Diff line change @@ -69,8 +69,8 @@ let set_on_host ?(from_device = true) (a : Tn.t) =
6969let set_materialized (a : Tn.t ) = Tn. update_memory_mode a Materialized 28
7070
7171let set_hosted (a : Tn.t ) =
72- if Tn. known_constant a then Tn. update_memory_mode a (Hosted Constant ) 41
73- else Tn. update_memory_mode a (Hosted (Changed_on_devices Unset )) 41
72+ if Tn. known_constant a then Tn. update_memory_mode a (Hosted Constant ) 411
73+ else Tn. update_memory_mode a (Hosted (Changed_on_devices Unset )) 412
7474
7575(* * Sets the tensor's value as "fully on host", returns the tensor's forward code with a
7676 label-derived comment. *)
@@ -346,7 +346,8 @@ let to_routine (type buffer_ptr dev runner event optimize_ctx)
346346 and type dev = dev
347347 and type runner = runner
348348 and type event = event
349- and type optimize_ctx = optimize_ctx ) (context : Backend.context ) ?name bindings comp =
349+ and type optimize_ctx = optimize_ctx ) (context : Backend.context ) ?(hosted =true ) ?name bindings comp =
350+ if hosted then Set. iter (Asgns. guess_output_nodes comp.Asgns. asgns) ~f: set_hosted;
350351 Backend. link context @@ Backend. compile context.optimize_ctx ?name bindings comp
351352
352353type example_train_result = {
You can’t perform that action at this time.
0 commit comments