@@ -193,14 +193,16 @@ let known_not_param tn =
193193
194194let known_shared_cross_stream tn =
195195 match tn.memory_mode with
196- | Some ((On_device Shared_cross_stream | Hosted (Changed_on_devices Shared_cross_stream)), _ ) ->
196+ | Some
197+ ( ( On_device Shared_cross_stream
198+ | Hosted (Constant | Volatile | Changed_on_devices Shared_cross_stream ) ),
199+ _ ) ->
197200 true
198201 | _ -> false
199202
200203let known_non_cross_stream tn =
201204 match tn.memory_mode with
202- | Some ((On_device Per_stream | Hosted (Changed_on_devices Per_stream)), _ ) ->
203- true
205+ | Some ((On_device Per_stream | Hosted (Changed_on_devices Per_stream)), _ ) -> true
204206 | _ -> false
205207
206208let mode_is_unspecified tn =
@@ -246,6 +248,9 @@ let update_memory_mode tn mode provenance =
246248 " Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} inconsistent for \
247249 %{debug_name tn}" ]
248250
251+ (* * [update_memory_sharing tn sharing provenance] preserves the memory mode of [tn] while updating
252+ the cross-stream sharing property, except that [Hosted Nonconstant] is further specialized to
253+ [Hosted (Changed_on_devices sharing)]. *)
249254let update_memory_sharing tn sharing provenance =
250255 match (tn.memory_mode, sharing) with
251256 | None , _ -> tn.memory_mode < - Some (On_device sharing, provenance)
@@ -264,13 +269,27 @@ let update_memory_sharing tn sharing provenance =
264269 " Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \
265270 %{debug_name tn} (hosted) -- change from non-shared to shared is currently not \
266271 permitted" ]
267- | Some (Hosted (Changed_on_devices _ ), _ ), _ ->
272+ | Some (Hosted (Constant | Volatile ), prov2 ), Per_stream ->
273+ raise
274+ @@ Utils. User_error
275+ [% string
276+ " Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \
277+ %{debug_name tn} (hosted) -- currently hosted nodes not changed on devices must be \
278+ shared cross-stream" ]
279+ | Some (Hosted (Constant | Volatile ), _ ), Shared_cross_stream -> ()
280+ | Some (Hosted (Nonconstant | Changed_on_devices _ ), _ ), _ ->
268281 tn.memory_mode < - Some (Hosted (Changed_on_devices sharing), provenance)
269- | Some (_ , prov2 ), _ ->
282+ | Some (_ , prov2 ), Unset ->
283+ invalid_arg
284+ [% string
285+ " Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for %{debug_name \
286+ tn} -- currently unsetting of sharing not allowed" ]
287+ | Some (mem_mode , prov2 ), _ ->
270288 invalid_arg
271289 [% string
272290 " Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} inconsistent for \
273- %{debug_name tn} -- not materialized on the devices" ]
291+ %{debug_name tn} -- not materialized on the devices: %{Sexp.to_string_hum @@ \
292+ sexp_of_memory_mode mem_mode}" ]
274293
275294let update_prec ?only_if tn prec =
276295 let do_update =
0 commit comments