@@ -298,9 +298,11 @@ let update_memory_mode tn mode provenance =
298298 | Some (Effectively_constant , _ ), (Never_virtual | Materialized | Hosted Constant ) ->
299299 tn.memory_mode < - Some (Hosted Constant , provenance)
300300 | Some (Effectively_constant, _ ), Virtual -> tn.memory_mode < - Some (mode, provenance)
301- | Some (Hosted Nonconstant , _ ), Hosted (Changed_on_devices _ | Volatile ) ->
301+ | ( Some (Hosted (Nonconstant | Changed_on_devices Unset ), _),
302+ Hosted (Changed_on_devices _ | Volatile ) ) ->
302303 tn.memory_mode < - Some (mode, provenance)
303304 | Some (Hosted (Changed_on_devices _ | Volatile ), _ ), Hosted Nonconstant -> ()
305+ | Some (Hosted (Changed_on_devices _ ), _ ), Hosted (Changed_on_devices Unset) -> ()
304306 | Some (Never_virtual, _ ), mode -> tn.memory_mode < - Some (mode, provenance)
305307 | Some (Virtual, prov2 ), Never_virtual ->
306308 raise
@@ -310,6 +312,8 @@ let update_memory_mode tn mode provenance =
310312 tn} is already virtual" ]
311313 | Some (_ , _ ), Never_virtual -> ()
312314 | Some (Device_only , _ ), (Local | On_device _ ) -> tn.memory_mode < - Some (mode, provenance)
315+ | Some (On_device _ , _ ), On_device Unset -> ()
316+ | Some (On_device Unset, _ ), On_device _ -> tn.memory_mode < - Some (mode, provenance)
313317 | Some (Materialized , _ ), (On_device _ | Hosted _ ) -> tn.memory_mode < - Some (mode, provenance)
314318 | Some ((Local | On_device _ ), _ ), Device_only -> ()
315319 | Some ((On_device _ | Hosted _ ), _ ), Materialized -> ()
@@ -330,8 +334,8 @@ let update_memory_sharing tn sharing provenance =
330334 | Some (On_device Shared_cross_streams , _), Shared_cross_streams
331335 | Some (On_device Per_stream, _ ), Per_stream ->
332336 ()
333- | Some ((On_device Unset | Device_only | Materialized ), _ ), _ ->
334- tn.memory_mode < - Some (On_device sharing, provenance)
337+ | Some ((On_device Unset | Device_only | Materialized ), old_prov ), _ ->
338+ tn.memory_mode < - Some (On_device sharing, provenance + (old_prov * 1000 ) )
335339 | Some (Hosted (Constant | Volatile ), prov2 ), Per_stream ->
336340 raise
337341 @@ Utils. User_error
@@ -343,8 +347,8 @@ let update_memory_sharing tn sharing provenance =
343347 | Some (Hosted (Changed_on_devices Per_stream), _ ), Per_stream ->
344348 ()
345349 | Some (Hosted (Constant | Volatile ), _ ), Shared_cross_streams -> ()
346- | Some (Hosted (Nonconstant | Changed_on_devices Unset ), _ ), _ ->
347- tn.memory_mode < - Some (Hosted (Changed_on_devices sharing), provenance)
350+ | Some (Hosted (Nonconstant | Changed_on_devices Unset ), old_prov ), _ ->
351+ tn.memory_mode < - Some (Hosted (Changed_on_devices sharing), provenance + (old_prov * 1000 ) )
348352 | Some (_ , prov2 ), Unset ->
349353 invalid_arg
350354 [% string
0 commit comments