Skip to content

Commit 6e11ff9

Browse files
committed
Fixes sharing update: Hosted Nonconstant -> Hosted (Changed_on_devices ...) if sharing specified
1 parent e2780a6 commit 6e11ff9

File tree

2 files changed

+27
-8
lines changed

2 files changed

+27
-8
lines changed

arrayjit/lib/tnode.ml

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -193,14 +193,16 @@ let known_not_param tn =
193193

194194
let known_shared_cross_stream tn =
195195
match tn.memory_mode with
196-
| Some ((On_device Shared_cross_stream | Hosted (Changed_on_devices Shared_cross_stream)), _) ->
196+
| Some
197+
( ( On_device Shared_cross_stream
198+
| Hosted (Constant | Volatile | Changed_on_devices Shared_cross_stream) ),
199+
_ ) ->
197200
true
198201
| _ -> false
199202

200203
let known_non_cross_stream tn =
201204
match tn.memory_mode with
202-
| Some ((On_device Per_stream | Hosted (Changed_on_devices Per_stream)), _) ->
203-
true
205+
| Some ((On_device Per_stream | Hosted (Changed_on_devices Per_stream)), _) -> true
204206
| _ -> false
205207

206208
let mode_is_unspecified tn =
@@ -246,6 +248,9 @@ let update_memory_mode tn mode provenance =
246248
"Tnode.update_memory_mode: update %{prov2#Int} -> %{provenance#Int} inconsistent for \
247249
%{debug_name tn}"]
248250

251+
(** [update_memory_sharing tn sharing provenance] preserves the memory mode of [tn] while updating
252+
the cross-stream sharing property, except that [Hosted Nonconstant] is further specialized to
253+
[Hosted (Changed_on_devices sharing)]. *)
249254
let update_memory_sharing tn sharing provenance =
250255
match (tn.memory_mode, sharing) with
251256
| None, _ -> tn.memory_mode <- Some (On_device sharing, provenance)
@@ -264,13 +269,27 @@ let update_memory_sharing tn sharing provenance =
264269
"Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \
265270
%{debug_name tn} (hosted) -- change from non-shared to shared is currently not \
266271
permitted"]
267-
| Some (Hosted (Changed_on_devices _), _), _ ->
272+
| Some (Hosted (Constant | Volatile), prov2), Per_stream ->
273+
raise
274+
@@ Utils.User_error
275+
[%string
276+
"Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for \
277+
%{debug_name tn} (hosted) -- currently hosted nodes not changed on devices must be \
278+
shared cross-stream"]
279+
| Some (Hosted (Constant | Volatile), _), Shared_cross_stream -> ()
280+
| Some (Hosted (Nonconstant | Changed_on_devices _), _), _ ->
268281
tn.memory_mode <- Some (Hosted (Changed_on_devices sharing), provenance)
269-
| Some (_, prov2), _ ->
282+
| Some (_, prov2), Unset ->
283+
invalid_arg
284+
[%string
285+
"Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} for %{debug_name \
286+
tn} -- currently unsetting of sharing not allowed"]
287+
| Some (mem_mode, prov2), _ ->
270288
invalid_arg
271289
[%string
272290
"Tnode.update_memory_sharing: update %{prov2#Int} -> %{provenance#Int} inconsistent for \
273-
%{debug_name tn} -- not materialized on the devices"]
291+
%{debug_name tn} -- not materialized on the devices: %{Sexp.to_string_hum @@ \
292+
sexp_of_memory_mode mem_mode}"]
274293

275294
let update_prec ?only_if tn prec =
276295
let do_update =

bin/moons_benchmark.ml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -217,8 +217,8 @@ let _mem_benchmarks =
217217
~f:(fun batch_size ->
218218
List.concat_map [ 0; (* 1; 2; *) 3 ] ~f:(fun inlining_cutoff ->
219219
List.concat_map [ (* 1; 3; *) 7 (* *) ] ~f:(fun seed ->
220-
List.concat_map [ (* "gccjit" ; *) "cc" (* ; "cuda" *) ] ~f:(fun backend_name ->
221-
List.concat_map [ CDSL.double; CDSL.single (* ; CDSL.half *) ]
220+
List.concat_map [ (* "gccjit" ; "cc"; *) "cuda" ] ~f:(fun backend_name ->
221+
List.concat_map [ (* CDSL.double; *) CDSL.single (* ; CDSL.half *) ]
222222
~f:(fun value_prec ->
223223
[
224224
classify_moons ~seed ~on_device:true ~inlining_cutoff ~num_streams

0 commit comments

Comments
 (0)