
Commit be9a299

In progress toward #286: type Tnode.sharing

1 parent cdc7196 commit be9a299

7 files changed, +75 -40 lines

arrayjit/lib/cuda_backend.cudajit.ml
Lines changed: 0 additions & 19 deletions

@@ -1,22 +1,3 @@
-(** In the current design of the CUDA backend, unlike in the CPU backends, context arrays for
-    incomparable contexts do not need be disjoint, as long as they share a device. If a tensor node
-    is read-only for all contexts, its array will be shared even by incomparable contexts. The
-    particular design is as follows, within a single device:
-    - If a tensor node is read-only for a context, and not otherwise recorded, it is stored as a
-      cross-stream sharing candidate.
-    - If a cross-stream sharing candidate is read-only for another context, whose parent does not
-      have the corresponding array (i.e. it is a different stream), it is recorded as cross-stream
-      shared, and the same array is reused.
-    - If a tensor node is writable by a context, and it is not cross-stream shared, it is marked as
-      non-cross-stream, the array is removed from cross-stream sharing candidates if present. If it
-      is cross-stream shared, it is recorded as owned by the corresponding stream. It is an error if
-      the node was already owned by a different stream.
-
-    If a tensor node is cross-stream shared, within-device copying is a NOOP as source and
-    destination pointers are in that case identical.
-
-    FIXME(#286): this should be controllable via {!Tnode.memory_mode}. *)
-
 open Base
 module Tn = Tnode
 module Lazy = Utils.Lazy
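
The design described in the removed comment (relocated to tnode.ml below) boils down to a small per-device state machine. A minimal self-contained OCaml sketch of that bookkeeping; the names status, on_read and on_write and the integer stream ids are illustrative, not taken from the backend, and "a different stream's read" stands in for "a context whose parent lacks the array":

(* One tensor node's sharing status on one device; streams are ints here. *)
type status =
  | Candidate of int  (* read-only so far, seen from this stream only *)
  | Shared  (* read-only from two or more streams: a single shared array *)
  | Owned of int  (* cross-stream shared, but written to by this stream *)
  | Non_cross_stream  (* became writable before ever being shared *)

(* A second reading stream promotes a candidate to cross-stream shared,
   reusing the same array. *)
let on_read status ~stream =
  match status with
  | None -> Some (Candidate stream)
  | Some (Candidate s) when s <> stream -> Some Shared
  | s -> s

(* A write makes a never-shared node non-cross-stream; a shared node becomes
   owned by the writing stream; a second writing stream is an error. *)
let on_write status ~stream =
  match status with
  | None | Some (Candidate _) -> Some Non_cross_stream
  | Some Shared -> Some (Owned stream)
  | Some (Owned s) when s = stream -> status
  | Some (Owned _) -> invalid_arg "tnode already owned by a different stream"
  | Some Non_cross_stream -> status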

arrayjit/lib/low_level.ml
Lines changed: 2 additions & 1 deletion

@@ -271,7 +271,8 @@ let visit_llc traced_store ~merge_node_id reverse_node_map ~max_visits llc =
       else Tn.update_memory_mode tn Materialized 35);
     if Hashtbl.exists traced.accesses ~f:is_recurrent then (
       traced.read_before_write <- true;
-      if Tn.mode_is_unspecified tn then Tn.update_memory_mode tn (Hosted Changed_on_devices) 38
+      if Tn.mode_is_unspecified tn then
+        Tn.update_memory_mode tn (Hosted (Changed_on_devices Unset)) 38
       else Tn.update_memory_mode tn Materialized 36))

 let%diagn_sexp check_and_store_virtual traced static_indices top_llc =

arrayjit/lib/tnode.ml
Lines changed: 64 additions & 14 deletions

@@ -8,10 +8,30 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
 [%%global_debug_log_level 9]
 [%%global_debug_log_level_from_env_var "OCANNL_LOG_LEVEL"]

+(** A possible algorithm for deciding sharing within a single device:
+    - If a tensor node is read-only for a context, and not otherwise recorded, it is stored as a
+      cross-stream sharing candidate.
+    - If a cross-stream sharing candidate is read-only for another context, whose parent does not
+      have the corresponding array (i.e. it is a different stream), it is recorded as cross-stream
+      shared, and the same array is reused.
+    - If a tensor node is writable by a context, and it is not cross-stream shared, it is marked as
+      non-cross-stream, the array is removed from cross-stream sharing candidates if present. If it
+      is cross-stream shared, it is recorded as owned by the corresponding stream. It is an error if
+      the node was already owned by a different stream.
+
+    If a tensor node is shared cross-stream, within-device copying is a NOOP as source and
+    destination pointers are in that case identical. *)
+type sharing =
+  | Unset
+  | Per_stream  (** The tensor node has separate arrays for each stream. *)
+  | Shared_cross_stream  (** The tensor node has a single array per device. *)
+[@@deriving sexp, compare, equal]
+
 type memory_type =
   | Constant  (** The tensor node does not change after initialization. *)
   | Nonconstant  (** One of: [Changed_on_devices], [Volatile]. *)
-  | Changed_on_devices  (** The tensor node will only change on host via a [to_host] call. *)
+  | Changed_on_devices of sharing
+      (** The tensor node will only change on host via a [to_host] call. *)
   | Volatile
       (** The tensor node will only change on any device via a [from_host] call possibly followed by
           [device_to_device]. *)
@@ -25,7 +45,7 @@ type memory_mode =
       (** The full tensor node is cached for the duration of a computation but not persisted across
           calls to compiled functions. It is not available for merging across devices. *)
   | Device_only  (** One of: [Local], [On_device]. *)
-  | On_device
+  | On_device of sharing
       (** The tensor node is stored on the devices that compute with it and persisted across
           function calls. It is available for merging across devices (for devices that support
           merging / P2P), but not (directly) for visualization or storing to disk. *)
@@ -112,13 +132,17 @@ let log_debug_info ~from_log_level tn =
     | (lazy (Some nd)) -> Nd.log_debug_info ~from_log_level nd
     else [%log "<not-in-yet>"]]]

+(** The one exception to "most local" is the sharing property: defaults to [Shared_cross_stream]. *)
 let default_to_most_local tn provenance =
   match tn.memory_mode with
   | None | Some (Effectively_constant, _) -> tn.memory_mode <- Some (Virtual, provenance)
   | Some (Never_virtual, _) -> tn.memory_mode <- Some (Local, provenance)
   | Some (Device_only, _) -> tn.memory_mode <- Some (Local, provenance)
-  | Some (Materialized, _) -> tn.memory_mode <- Some (On_device, provenance)
-  | Some ((Virtual | Local | On_device | Hosted _), _) -> ()
+  | Some (Materialized, _) -> tn.memory_mode <- Some (On_device Shared_cross_stream, provenance)
+  | Some (On_device Unset, _) -> tn.memory_mode <- Some (On_device Shared_cross_stream, provenance)
+  | Some (Hosted (Changed_on_devices Unset), _) ->
+      tn.memory_mode <- Some (Hosted (Changed_on_devices Shared_cross_stream), provenance)
+  | Some ((Virtual | Local | On_device _ | Hosted _), _) -> ()

 let is_virtual_force tn provenance =
   default_to_most_local tn provenance;
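
As the new doc comment says, sharing is the one property that does not default to "most local": a leftover Unset resolves to Shared_cross_stream. This is also where the Hosted (Changed_on_devices Unset) assigned in low_level.ml above eventually gets resolved. A sketch of just this resolution step, using standalone stand-in types rather than the library's:

(* Stand-ins for the Tnode variants; Hosted_changed abbreviates
   Hosted (Changed_on_devices _). *)
type sharing = Unset | Per_stream | Shared_cross_stream

type mode =
  | Materialized
  | On_device of sharing
  | Hosted_changed of sharing

(* Mirrors the new cases of [default_to_most_local]: unresolved sharing
   defaults to cross-stream shared; anything else is left untouched. *)
let resolve_sharing = function
  | Materialized | On_device Unset -> On_device Shared_cross_stream
  | Hosted_changed Unset -> Hosted_changed Shared_cross_stream
  | m -> m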
@@ -128,7 +152,7 @@ let is_hosted_force ?specifically tn provenance =
   default_to_most_local tn provenance;
   match (tn.memory_mode, specifically) with
   | None, _ -> assert false
-  | Some ((Virtual | Local | Device_only | On_device), _), _ -> false
+  | Some ((Virtual | Local | Device_only | On_device _), _), _ -> false
   | Some (Hosted _, _), None -> true
   | Some (Hosted memtyp, _), Some query -> equal_memory_type memtyp query
   | Some ((Never_virtual | Materialized | Effectively_constant), _), _ -> assert false
@@ -138,7 +162,7 @@ let is_materialized_force tn provenance =
   match tn.memory_mode with
   | None -> assert false
   | Some ((Virtual | Local), _) -> false
-  | Some ((On_device | Hosted _ | Materialized), _) -> true
+  | Some ((On_device _ | Hosted _ | Materialized), _) -> true
   | Some ((Never_virtual | Device_only | Effectively_constant), _) -> assert false

 let is_in_context_force tn provenance =
@@ -164,7 +188,7 @@ let known_non_virtual tn =
 let known_not_param tn =
   match tn.memory_mode with
   | Some
-      ( ( Virtual | Local | Effectively_constant | Device_only | On_device
+      ( ( Virtual | Local | Effectively_constant | Device_only | On_device _
         | Hosted (Constant | Volatile) ),
         _ ) ->
       true
@@ -190,9 +214,9 @@ let update_memory_mode tn mode provenance =
   | Some (Effectively_constant, _), (Never_virtual | Materialized | Hosted Constant) ->
       tn.memory_mode <- Some (Hosted Constant, provenance)
   | Some (Effectively_constant, _), Virtual -> tn.memory_mode <- Some (mode, provenance)
-  | Some (Hosted Nonconstant, _), Hosted (Changed_on_devices | Volatile) ->
+  | Some (Hosted Nonconstant, _), Hosted (Changed_on_devices _ | Volatile) ->
       tn.memory_mode <- Some (mode, provenance)
-  | Some (Hosted (Changed_on_devices | Volatile), _), Hosted Nonconstant -> ()
+  | Some (Hosted (Changed_on_devices _ | Volatile), _), Hosted Nonconstant -> ()
   | Some (Never_virtual, _), mode -> tn.memory_mode <- Some (mode, provenance)
   | Some (Virtual, prov2), Never_virtual ->
       raise
@@ -201,18 +225,44 @@ let update_memory_mode tn mode provenance =
           "Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} for %{debug_name \
            tn} is already virtual"]
   | Some (_, _), Never_virtual -> ()
-  | Some (Device_only, _), (Local | On_device) -> tn.memory_mode <- Some (mode, provenance)
-  | Some (Materialized, _), (On_device | Hosted _) -> tn.memory_mode <- Some (mode, provenance)
-  | Some ((Local | On_device), _), Device_only -> ()
-  | Some ((On_device | Hosted _), _), Materialized -> ()
+  | Some (Device_only, _), (Local | On_device _) -> tn.memory_mode <- Some (mode, provenance)
+  | Some (Materialized, _), (On_device _ | Hosted _) -> tn.memory_mode <- Some (mode, provenance)
+  | Some ((Local | On_device _), _), Device_only -> ()
+  | Some ((On_device _ | Hosted _), _), Materialized -> ()
   | Some (Device_only, _), Materialized | Some (Materialized, _), Device_only ->
-      tn.memory_mode <- Some (On_device, provenance)
+      tn.memory_mode <- Some (On_device Unset, provenance)
   | Some (_, prov2), _ ->
       invalid_arg
         [%string
           "Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} inconsistent for \
            %{debug_name tn}"]

+let update_memory_sharing tn sharing provenance =
+  match (tn.memory_mode, sharing) with
+  | None, _ -> tn.memory_mode <- Some (On_device sharing, provenance)
+  | Some (On_device Per_stream, prov2), Shared_cross_stream ->
+      raise
+      @@ Utils.User_error
+           [%string
+             "Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \
+              %{debug_name tn} -- change from non-shared to shared is currently not permitted"]
+  | Some ((On_device _ | Device_only | Materialized), _), Shared_cross_stream ->
+      tn.memory_mode <- Some (On_device sharing, provenance)
+  | Some (Hosted (Changed_on_devices Per_stream), prov2), Shared_cross_stream ->
+      raise
+      @@ Utils.User_error
+           [%string
+             "Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \
+              %{debug_name tn} (hosted) -- change from non-shared to shared is currently not \
+              permitted"]
+  | Some (Hosted (Changed_on_devices _), _), _ ->
+      tn.memory_mode <- Some (Hosted (Changed_on_devices sharing), provenance)
+  | Some (_, prov2), _ ->
+      invalid_arg
+        [%string
+          "Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} inconsistent for \
+           %{debug_name tn} -- not materialized on the devices"]
+
 let update_prec ?only_if tn prec =
   let do_update =
     match only_if with
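
A hedged sketch of how a backend might call the new entry point (with module Tn = Tnode, as in the backend files); the provenance number 99 is an illustrative placeholder, not one of the commit's provenance codes:

(* On a node with no memory mode yet, this records [On_device Per_stream]. *)
let mark_per_stream tn = Tn.update_memory_sharing tn Tn.Per_stream 99

(* On a [Materialized] or [On_device _] node, this records
   [On_device Shared_cross_stream]; on a node already recorded as
   [Per_stream], it raises [Utils.User_error] ("change from non-shared to
   shared is currently not permitted"). *)
let mark_shared tn = Tn.update_memory_sharing tn Tn.Shared_cross_stream 99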

bin/micrograd_demo.ml
Lines changed: 1 addition & 1 deletion

@@ -122,7 +122,7 @@ let experiment seed ~no_batch_shape_inference ~use_builtin_weight_decay () =
   let classes = Tensor.value_1d_points ~xdim:0 moons_classes in
   let points1, points2 = Array.partitioni_tf points ~f:Float.(fun i _ -> classes.(i) > 0.) in
   let%op mlp_result = mlp "point" in
-  Train.set_on_host Changed_on_devices mlp_result.value;
+  Train.set_on_host mlp_result.value;
   (* By using jitted.context here, we don't need to copy the parameters back to the host. *)
   let result_routine =
     Train.to_routine

bin/moons_demo.ml
Lines changed: 1 addition & 1 deletion

@@ -103,7 +103,7 @@ let demo () =
   in

   let%op mlp_result = mlp "point" in
-  Train.set_on_host Changed_on_devices mlp_result.value;
+  Train.set_on_host mlp_result.value;
   let result_routine =
     Train.to_routine
       (module Backend)

lib/train.ml
Lines changed: 6 additions & 3 deletions

@@ -105,12 +105,15 @@ let restore_params t =
       let f arr = Npy.Npz.restore in_file name arr in
       Nd.map { f } @@ Option.value_exn ~here:[%here] @@ Lazy.force v.array)

-let set_on_host memtype (a : Tn.t) = Tn.update_memory_mode a (Hosted memtype) 27
+let set_on_host ?(from_device = true) (a : Tn.t) =
+  let memtype = if from_device then Tn.(Changed_on_devices Unset) else Volatile in
+  Tn.update_memory_mode a (Hosted memtype) 27
+
 let set_materialized (a : Tn.t) = Tn.update_memory_mode a Materialized 28

 let set_hosted (a : Tn.t) =
   if Tn.known_constant a then Tn.update_memory_mode a (Hosted Constant) 41
-  else Tn.update_memory_mode a (Hosted Changed_on_devices) 41
+  else Tn.update_memory_mode a (Hosted (Changed_on_devices Unset)) 41

 (** Sets the tensor's value as "fully on host", returns the tensor's forward code with a
     label-derived comment. *)

@@ -510,7 +513,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
     else Tensor.consume_forward_code model_result
   in
   if not disable_rootness_check then Tensor.remove_bprop_root model_result;
-  set_on_host Changed_on_devices model_result.Tensor.value;
+  set_on_host model_result.Tensor.value;
   (* By using sgd_update.context, maybe we don't need to copy the parameters back to the host. *)
   let routine =
     Backend.(
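
Call sites now choose between device-mutable and host-mutable values with a flag instead of passing a memory type; the demo and test updates in this commit all rely on the default. A usage sketch; configure_hosting, result and host_input are hypothetical names:

let configure_hosting ~result ~host_input =
  (* Default [~from_device:true]: the node changes on devices, and the host
     copy is refreshed via [to_host] -- Hosted (Changed_on_devices Unset). *)
  Train.set_on_host result;
  (* [~from_device:false]: the node changes on host and is pushed to devices
     via [from_host] -- Hosted Volatile. *)
  Train.set_on_host ~from_device:false host_input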

test/micrograd_demo.ml
Lines changed: 1 addition & 1 deletion

@@ -161,7 +161,7 @@ let%expect_test "Micrograd half-moons example" =
   let classes = Tensor.value_1d_points ~xdim:0 moons_classes in
   let points1, points2 = Array.partitioni_tf points ~f:Float.(fun i _ -> classes.(i) > 0.) in
   let%op mlp_result = mlp "point" in
-  Train.set_on_host Changed_on_devices mlp_result.value;
+  Train.set_on_host mlp_result.value;
   let result_routine =
     Train.to_routine
       (module Backend)
