@@ -8,10 +8,30 @@ let _get_local_debug_runtime = Utils._get_local_debug_runtime
88[%% global_debug_log_level 9 ]
99[%% global_debug_log_level_from_env_var " OCANNL_LOG_LEVEL" ]
1010
11+ (* * A possible algorithm for deciding sharing within a single device:
12+ - If a tensor node is read-only for a context, and not otherwise recorded, it is stored as a
13+ cross-stream sharing candidate.
14+ - If a cross-stream sharing candidate is read-only for another context, whose parent does not
15+ have the corresponding array (i.e. it is a different stream), it is recorded as cross-stream
16+ shared, and the same array is reused.
17+ - If a tensor node is writable by a context, and it is not cross-stream shared, it is marked as
18+ non-cross-stream, the array is removed from cross-stream sharing candidates if present. If it
19+ is cross-stream shared, it is recorded as owned by the corresponding stream. It is an error if
20+ the node was already owned by a different stream.
21+
22+ If a tensor node is shared cross-stream, within-device copying is a NOOP as source and
23+ destination pointers are in that case identical. *)
24+ type sharing =
25+ | Unset
26+ | Per_stream (* * The tensor node has separate arrays for each stream. *)
27+ | Shared_cross_stream (* * The tensor node has a single array per device. *)
28+ [@@ deriving sexp , compare , equal ]
29+
1130type memory_type =
1231 | Constant (* * The tensor node does not change after initialization. *)
1332 | Nonconstant (* * One of: [Changed_on_devices], [Volatile]. *)
14- | Changed_on_devices (* * The tensor node will only change on host via a [to_host] call. *)
33+ | Changed_on_devices of sharing
34+ (* * The tensor node will only change on host via a [to_host] call. *)
1535 | Volatile
1636 (* * The tensor node will only change on any device via a [from_host] call possibly followed by
1737 [device_to_device]. *)
@@ -25,7 +45,7 @@ type memory_mode =
2545 (* * The full tensor node is cached for the duration of a computation but not persisted across
2646 calls to compiled functions. It is not available for merging across devices. *)
2747 | Device_only (* * One of: [Local], [On_device]. *)
28- | On_device
48+ | On_device of sharing
2949 (* * The tensor node is stored on the devices that compute with it and persisted across
3050 function calls. It is available for merging across devices (for devices that support
3151 merging / P2P), but not (directly) for visualization or storing to disk. *)
@@ -112,13 +132,17 @@ let log_debug_info ~from_log_level tn =
112132 | (lazy (Some nd )) -> Nd. log_debug_info ~from_log_level nd
113133 else [% log " <not-in-yet>" ]]]
114134
135+ (* * The one exception to "most local" is the sharing property: defaults to [Shared_cross_stream]. *)
115136let default_to_most_local tn provenance =
116137 match tn.memory_mode with
117138 | None | Some (Effectively_constant, _ ) -> tn.memory_mode < - Some (Virtual , provenance)
118139 | Some (Never_virtual, _ ) -> tn.memory_mode < - Some (Local , provenance)
119140 | Some (Device_only, _ ) -> tn.memory_mode < - Some (Local , provenance)
120- | Some (Materialized, _ ) -> tn.memory_mode < - Some (On_device , provenance)
121- | Some ((Virtual | Local | On_device | Hosted _ ), _ ) -> ()
141+ | Some (Materialized, _ ) -> tn.memory_mode < - Some (On_device Shared_cross_stream , provenance)
142+ | Some (On_device Unset, _ ) -> tn.memory_mode < - Some (On_device Shared_cross_stream , provenance)
143+ | Some (Hosted (Changed_on_devices Unset), _ ) ->
144+ tn.memory_mode < - Some (Hosted (Changed_on_devices Shared_cross_stream ), provenance)
145+ | Some ((Virtual | Local | On_device _ | Hosted _ ), _ ) -> ()
122146
123147let is_virtual_force tn provenance =
124148 default_to_most_local tn provenance;
@@ -128,7 +152,7 @@ let is_hosted_force ?specifically tn provenance =
128152 default_to_most_local tn provenance;
129153 match (tn.memory_mode, specifically) with
130154 | None , _ -> assert false
131- | Some ((Virtual | Local | Device_only | On_device ), _ ), _ -> false
155+ | Some ((Virtual | Local | Device_only | On_device _ ), _ ), _ -> false
132156 | Some (Hosted _ , _ ), None -> true
133157 | Some (Hosted memtyp , _ ), Some query -> equal_memory_type memtyp query
134158 | Some ((Never_virtual | Materialized | Effectively_constant ), _ ), _ -> assert false
@@ -138,7 +162,7 @@ let is_materialized_force tn provenance =
138162 match tn.memory_mode with
139163 | None -> assert false
140164 | Some ((Virtual | Local ), _ ) -> false
141- | Some ((On_device | Hosted _ | Materialized ), _ ) -> true
165+ | Some ((On_device _ | Hosted _ | Materialized ), _ ) -> true
142166 | Some ((Never_virtual | Device_only | Effectively_constant ), _ ) -> assert false
143167
144168let is_in_context_force tn provenance =
@@ -164,7 +188,7 @@ let known_non_virtual tn =
164188let known_not_param tn =
165189 match tn.memory_mode with
166190 | Some
167- ( ( Virtual | Local | Effectively_constant | Device_only | On_device
191+ ( ( Virtual | Local | Effectively_constant | Device_only | On_device _
168192 | Hosted (Constant | Volatile ) ),
169193 _ ) ->
170194 true
@@ -190,9 +214,9 @@ let update_memory_mode tn mode provenance =
190214 | Some (Effectively_constant , _ ), (Never_virtual | Materialized | Hosted Constant ) ->
191215 tn.memory_mode < - Some (Hosted Constant , provenance)
192216 | Some (Effectively_constant, _ ), Virtual -> tn.memory_mode < - Some (mode, provenance)
193- | Some (Hosted Nonconstant , _ ), Hosted (Changed_on_devices | Volatile ) ->
217+ | Some (Hosted Nonconstant , _ ), Hosted (Changed_on_devices _ | Volatile ) ->
194218 tn.memory_mode < - Some (mode, provenance)
195- | Some (Hosted (Changed_on_devices | Volatile ), _ ), Hosted Nonconstant -> ()
219+ | Some (Hosted (Changed_on_devices _ | Volatile ), _ ), Hosted Nonconstant -> ()
196220 | Some (Never_virtual, _ ), mode -> tn.memory_mode < - Some (mode, provenance)
197221 | Some (Virtual, prov2 ), Never_virtual ->
198222 raise
@@ -201,18 +225,44 @@ let update_memory_mode tn mode provenance =
201225 " Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} for %{debug_name \
202226 tn} is already virtual" ]
203227 | Some (_ , _ ), Never_virtual -> ()
204- | Some (Device_only , _ ), (Local | On_device ) -> tn.memory_mode < - Some (mode, provenance)
205- | Some (Materialized , _ ), (On_device | Hosted _ ) -> tn.memory_mode < - Some (mode, provenance)
206- | Some ((Local | On_device ), _ ), Device_only -> ()
207- | Some ((On_device | Hosted _ ), _ ), Materialized -> ()
228+ | Some (Device_only , _ ), (Local | On_device _ ) -> tn.memory_mode < - Some (mode, provenance)
229+ | Some (Materialized , _ ), (On_device _ | Hosted _ ) -> tn.memory_mode < - Some (mode, provenance)
230+ | Some ((Local | On_device _ ), _ ), Device_only -> ()
231+ | Some ((On_device _ | Hosted _ ), _ ), Materialized -> ()
208232 | Some (Device_only , _ ), Materialized | Some (Materialized, _ ), Device_only ->
209- tn.memory_mode < - Some (On_device , provenance)
233+ tn.memory_mode < - Some (On_device Unset , provenance)
210234 | Some (_ , prov2 ), _ ->
211235 invalid_arg
212236 [% string
213237 " Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} inconsistent for \
214238 %{debug_name tn}" ]
215239
240+ let update_memory_sharing tn sharing provenance =
241+ match (tn.memory_mode, sharing) with
242+ | None , _ -> tn.memory_mode < - Some (On_device sharing, provenance)
243+ | Some (On_device Per_stream, prov2 ), Shared_cross_stream ->
244+ raise
245+ @@ Utils. User_error
246+ [% string
247+ " Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \
248+ %{debug_name tn} -- change from non-shared to shared is currently not permitted" ]
249+ | Some ((On_device _ | Device_only | Materialized ), _ ), Shared_cross_stream ->
250+ tn.memory_mode < - Some (On_device sharing, provenance)
251+ | Some (Hosted (Changed_on_devices Per_stream), prov2 ), Shared_cross_stream ->
252+ raise
253+ @@ Utils. User_error
254+ [% string
255+ " Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \
256+ %{debug_name tn} (hosted) -- change from non-shared to shared is currently not \
257+ permitted" ]
258+ | Some (Hosted (Changed_on_devices _ ), _ ), _ ->
259+ tn.memory_mode < - Some (Hosted (Changed_on_devices sharing), provenance)
260+ | Some (_ , prov2 ), _ ->
261+ invalid_arg
262+ [% string
263+ " Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} inconsistent for \
264+ %{debug_name tn} -- not materialized on the devices" ]
265+
216266let update_prec ?only_if tn prec =
217267 let do_update =
218268 match only_if with
0 commit comments