Skip to content

Commit 21e0243

Browse files
committed
Fixes #330 by failing informatively on a computations table miss; be more careful about threading optimize_ctx; debuggability tweaks
1 parent 7336e37 commit 21e0243

File tree

6 files changed

+30
-19
lines changed

6 files changed

+30
-19
lines changed

arrayjit/lib/backend_intf.ml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ type ('buffer_ptr, 'stream, 'optimize_ctx) context = {
186186
parent : ('buffer_ptr, 'stream, 'optimize_ctx) context option;
187187
ctx_arrays : 'buffer_ptr ctx_arrays;
188188
(** This map contains arrays used in this context or an ancestor context (they might be unique
189-
but might also be cross-stream shared. *)
189+
but might also be cross-stream shared). *)
190190
finalized : Utils.atomic_bool;
191191
optimize_ctx : 'optimize_ctx;
192192
}
@@ -212,8 +212,8 @@ module type Device = sig
212212

213213
val make_child : ?ctx_arrays:ctx_arrays -> ?optimize_ctx:optimize_ctx -> context -> context
214214
(** Returns a context with the same {!field:Backend_intf.context.stream}, and
215-
{!field:Backend_intf.context.ctx_arrays} if omitted, as the given context's, which is also the
216-
{!field:Backend_intf.context.parent}. *)
215+
{!field:Backend_intf.context.ctx_arrays}, {!field:Backend_intf.context.optimize_ctx} if
216+
omitted, as the given context's, which is also the {!field:Backend_intf.context.parent}. *)
217217

218218
val get_name : stream -> string
219219
end

arrayjit/lib/backends.ml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -529,8 +529,9 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
529529
let ctx_arrays =
530530
Hashtbl.fold code.lowered.traced_store ~init:context.ctx_arrays ~f:(alloc_if_needed context)
531531
in
532+
let optimize_ctx = code.lowered.optimize_ctx in
532533
let bindings, schedule = link context code.code ctx_arrays in
533-
let context = make_child ~ctx_arrays context in
534+
let context = make_child ~ctx_arrays ~optimize_ctx context in
534535
let schedule =
535536
Task.prepend schedule ~work:(fun () ->
536537
check_merge_buffer context.stream ~code_node:code.expected_merge_node)
@@ -553,7 +554,8 @@ module Raise_backend (Device : Lowered_backend) : Backend = struct
553554
| None -> (context, None)
554555
| Some schedule ->
555556
let ctx_arrays = Option.value_exn ctx_arrays.(i) in
556-
let context = make_child ~ctx_arrays context in
557+
let optimize_ctx = (Option.value_exn code_batch.lowereds.(i)).Low_level.optimize_ctx in
558+
let context = make_child ~ctx_arrays ~optimize_ctx context in
557559
let expected_merge_node = code_batch.expected_merge_nodes.(i) in
558560
let (inputs, outputs), merge_buffer_input =
559561
Low_level.input_and_output_nodes @@ Option.value_exn code_batch.lowereds.(i)

arrayjit/lib/low_level.ml

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -524,9 +524,10 @@ let%diagn2_sexp check_and_store_virtual computations_table traced static_indices
524524
Hashtbl.set computations_table ~key:traced.tn ~data:((!at_idcs, top_llc) :: current_computations)
525525
with Non_virtual i -> Tn.update_memory_mode traced.tn Never_virtual i
526526

527-
let%track7_sexp inline_computation ~id computations_table (traced : traced_array)
528-
(static_indices : Indexing.static_symbol list) (call_args : Indexing.axis_index array) :
529-
t option =
527+
let%track7_sexp inline_computation ~id
528+
(computations_table : (Tn.t, (Indexing.axis_index array option * t) list) Hashtbl.t)
529+
(traced : traced_array) (static_indices : Indexing.static_symbol list)
530+
(call_args : Indexing.axis_index array) : t option =
530531
let exception Non_virtual of int in
531532
let static_indices =
532533
Set.of_list (module Indexing.Symbol)
@@ -639,7 +640,15 @@ let%track7_sexp inline_computation ~id computations_table (traced : traced_array
639640
loop env def
640641
in
641642
try
642-
let computations = Hashtbl.find computations_table traced.tn |> Option.value ~default:[] in
643+
let computations =
644+
Hashtbl.find computations_table traced.tn
645+
|> Option.value_or_thunk ~default:(fun () ->
646+
raise
647+
@@ Utils.User_error
648+
[%string
649+
"Stale optimize_ctx: No computations found for #%{traced.tn.Tn.id#Int}: \
650+
%{Tn.debug_name traced.tn}"])
651+
in
643652
let body = List.rev_filter_map ~f:loop_proc computations in
644653
if List.is_empty body then raise @@ Non_virtual 14 else Some (unflat_lines body)
645654
with Non_virtual i ->

lib/train.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -399,7 +399,7 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
399399
let module NTDSL = Operation.NTDSL in
400400
let devices, streams = get_all_suggested_streams ?max_num_streams (module Backend) in
401401
let num_streams = Array.length streams in
402-
let contexts = Array.map streams ~f:(Backend.make_context ?ctx_arrays:None) in
402+
let contexts = Array.map streams ~f:(Backend.make_context ?ctx_arrays:None ?optimize_ctx:None) in
403403
let init_mem = Array.fold devices ~init:0 ~f:(fun acc dev -> acc + Backend.get_used_memory dev) in
404404
let minibatch_size = batch_size / num_streams in
405405
let n_minibatches = data_len / minibatch_size in

test/einsum/einsum_trivia_exec.ml

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -130,23 +130,23 @@ let () =
130130

131131
let hey = TDSL.range_of_shape ~batch_dims:[ 2 ] ~input_dims:[ 3 ] ~output_dims:[ 4 ] () in
132132
let%op ho = hey ++ "...|i->o => ...|o->i" in
133-
let ctx = Train.forward_once backend ho in
133+
let hey_ctx = Train.forward_once backend ho in
134134
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ho;
135135
let%op ho2 = hey ++ "b|...->o => o|...->b" in
136-
ignore (Train.forward_once backend ~ctx ho2);
136+
ignore (Train.forward_once backend ~ctx:hey_ctx ho2);
137137
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ho2;
138138

139139
let hey2 =
140140
TDSL.range_of_shape ~batch_dims:[ 2; 3 ] ~input_dims:[ 4; 5 ] ~output_dims:[ 6; 7 ] ()
141141
in
142142
let%op ho3 = hey2 ++ "...b|...i->...o => ...i|...o->...b" in
143-
let ctx = Train.forward_once backend ho3 in
143+
let hey2_ctx = Train.forward_once backend ho3 in
144144
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ho3;
145145
let%op ho4 = hey2 ++ "...b|...i->...o => i|o->b" in
146-
ignore (Train.forward_once backend ~ctx ho4);
146+
ignore (Train.forward_once backend ~ctx:hey2_ctx ho4);
147147
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ho4;
148148
let%op ho5 = hey ++ "...|...->...o => o" in
149-
ignore (Train.forward_once backend ho5);
149+
ignore (Train.forward_once backend ~ctx:hey_ctx ho5);
150150
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ho5;
151151
let hey3 = TDSL.range_of_shape ~output_dims:[ 3; 4 ] () in
152152
let%op ho6 = hey3 ++ "...|...->...o => o" in

test/operations/test_threefry4x32.ml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,14 @@ let%expect_test "uint4x32_to_prec_uniform different precisions" =
5454
let key = Ocannl.Tensor.get_random_seed () in
5555
let counter = TDSL.range 5 in
5656
let random_bits = O.threefry4x32 key counter in
57+
let ctx = ref None in
5758

5859
(* Test different target precisions *)
5960
let test_precision prec prec_name =
6061
let uniform = O.uint4x32_to_prec_uniform random_bits in
6162
Ir.Tnode.update_prec uniform.value prec;
6263
Ocannl.Train.set_hosted uniform.value;
63-
ignore (Ocannl.Train.forward_once (module Backend) uniform);
64+
ctx := Some (Ocannl.Train.forward_once (module Backend) ?ctx:!ctx uniform);
6465
let result = Ir.Tnode.get_values uniform.value in
6566
Stdio.printf "%s precision - first value: %f, second value: %f\n" prec_name result.(0)
6667
result.(1);
@@ -69,15 +70,14 @@ let%expect_test "uint4x32_to_prec_uniform different precisions" =
6970
in
7071

7172
test_precision Ir.Ops.single "Single";
72-
test_precision Ir.Ops.double "Double";
73+
(* Metal backend doesn't support double precision. *)
74+
(* test_precision Ir.Ops.double "Double"; *)
7375
test_precision Ir.Ops.half "Half";
7476

7577
[%expect
7678
{|
7779
Single precision - first value: 0.756113, second value: 0.758716
7880
All values in [0, 1) range: true
79-
Double precision - first value: 0.756113, second value: 0.758716
80-
All values in [0, 1) range: true
8181
Half precision - first value: 0.756113, second value: 0.758716
8282
All values in [0, 1) range: true
8383
|}]

0 commit comments

Comments
 (0)