Skip to content

Commit fa884ef

Browse files
committed
Fix false positives in update_memory_mode check (should be no-change)
1 parent 09d8a39 commit fa884ef

File tree

2 files changed

+13
-12
lines changed

2 files changed

+13
-12
lines changed

arrayjit/lib/tnode.ml

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -298,9 +298,11 @@ let update_memory_mode tn mode provenance =
298298
| Some (Effectively_constant, _), (Never_virtual | Materialized | Hosted Constant) ->
299299
tn.memory_mode <- Some (Hosted Constant, provenance)
300300
| Some (Effectively_constant, _), Virtual -> tn.memory_mode <- Some (mode, provenance)
301-
| Some (Hosted Nonconstant, _), Hosted (Changed_on_devices _ | Volatile) ->
301+
| ( Some (Hosted (Nonconstant | Changed_on_devices Unset), _),
302+
Hosted (Changed_on_devices _ | Volatile) ) ->
302303
tn.memory_mode <- Some (mode, provenance)
303304
| Some (Hosted (Changed_on_devices _ | Volatile), _), Hosted Nonconstant -> ()
305+
| Some (Hosted (Changed_on_devices _), _), Hosted (Changed_on_devices Unset) -> ()
304306
| Some (Never_virtual, _), mode -> tn.memory_mode <- Some (mode, provenance)
305307
| Some (Virtual, prov2), Never_virtual ->
306308
raise
@@ -310,6 +312,8 @@ let update_memory_mode tn mode provenance =
310312
tn} is already virtual"]
311313
| Some (_, _), Never_virtual -> ()
312314
| Some (Device_only, _), (Local | On_device _) -> tn.memory_mode <- Some (mode, provenance)
315+
| Some (On_device _, _), On_device Unset -> ()
316+
| Some (On_device Unset, _), On_device _ -> tn.memory_mode <- Some (mode, provenance)
313317
| Some (Materialized, _), (On_device _ | Hosted _) -> tn.memory_mode <- Some (mode, provenance)
314318
| Some ((Local | On_device _), _), Device_only -> ()
315319
| Some ((On_device _ | Hosted _), _), Materialized -> ()
@@ -330,8 +334,8 @@ let update_memory_sharing tn sharing provenance =
330334
| Some (On_device Shared_cross_streams, _), Shared_cross_streams
331335
| Some (On_device Per_stream, _), Per_stream ->
332336
()
333-
| Some ((On_device Unset | Device_only | Materialized), _), _ ->
334-
tn.memory_mode <- Some (On_device sharing, provenance)
337+
| Some ((On_device Unset | Device_only | Materialized), old_prov), _ ->
338+
tn.memory_mode <- Some (On_device sharing, provenance + (old_prov * 1000))
335339
| Some (Hosted (Constant | Volatile), prov2), Per_stream ->
336340
raise
337341
@@ Utils.User_error
@@ -343,8 +347,8 @@ let update_memory_sharing tn sharing provenance =
343347
| Some (Hosted (Changed_on_devices Per_stream), _), Per_stream ->
344348
()
345349
| Some (Hosted (Constant | Volatile), _), Shared_cross_streams -> ()
346-
| Some (Hosted (Nonconstant | Changed_on_devices Unset), _), _ ->
347-
tn.memory_mode <- Some (Hosted (Changed_on_devices sharing), provenance)
350+
| Some (Hosted (Nonconstant | Changed_on_devices Unset), old_prov), _ ->
351+
tn.memory_mode <- Some (Hosted (Changed_on_devices sharing), provenance + (old_prov * 1000))
348352
| Some (_, prov2), Unset ->
349353
invalid_arg
350354
[%string

bin/compilation_speed.ml

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -29,13 +29,10 @@ let benchmark_overhead backend () =
2929
let ctx = Backend.make_context stream in
3030
let init_mem = Backend.(get_used_memory stream.device) in
3131
let update_f = Train.grad_update f in
32-
(* Initialize the context with a mock update of x to ensure that it is not optimized as a
33-
constant. *)
34-
let%cd mock_update_x = x =: 42 in
35-
let init_assign_x =
36-
Train.to_routine (module Backend) ctx ~name:"init_assign_x" IDX.empty mock_update_x
32+
let init_x =
33+
Train.to_routine (module Backend) ctx ~name:"init_assign_x" IDX.empty @@ Tensor.init_params f
3734
in
38-
let f_routine = Train.to_routine (module Backend) init_assign_x.context IDX.empty update_f in
35+
let f_routine = Train.to_routine (module Backend) init_x.context IDX.empty update_f in
3936
Tensor.print_tree ~with_grad:true ~with_backend_info:true ~depth:9 f;
4037

4138
let xs = Array.init n_data ~f:Float.(fun i -> of_int i - (of_int n_data /. 2.)) in
@@ -78,7 +75,7 @@ let benchmarks =
7875
[
7976
(* benchmark_overhead (fresh_backend "gccjit" ()); *)
8077
benchmark_overhead (fresh_backend ~backend_name:"multicore_cc" ());
81-
benchmark_overhead (fresh_backend ~backend_name:"cuda" ());
78+
(* benchmark_overhead (fresh_backend ~backend_name:"cuda" ()); *)
8279
]
8380

8481
let () =

0 commit comments

Comments
 (0)