Skip to content

Commit f28a099

Browse files
committed
Merge branch 'master' of https://github.com/ahrefs/ocannl
2 parents 427b946 + 5d5c4c8 commit f28a099

File tree

6 files changed

+89
-11
lines changed

6 files changed

+89
-11
lines changed

arrayjit/lib/metal_backend.ml

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ module Device_config = struct
4040
type runner = {
4141
queue : Me.CommandQueue.t;
4242
event : Me.SharedEvent.t; (* Use SharedEvent for signalling *)
43-
counter : ullong; (* Next value to signal *)
43+
mutable counter : ullong; (* Next value to signal *)
4444
}
4545
[@@deriving sexp_of]
4646

@@ -212,6 +212,7 @@ end) : Ir.Backend_impl.Lowered_backend = struct
212212
let shared_event = stream.runner.event in
213213
let counter = stream.runner.counter in
214214
let next_value = Unsigned.ULLong.add counter Unsigned.ULLong.one in
215+
stream.runner.counter <- next_value;
215216
let command_buffer = Me.CommandBuffer.on_queue queue in
216217
Me.CommandBuffer.encode_signal_event command_buffer
217218
(Me.SharedEvent.super shared_event)
@@ -220,10 +221,10 @@ end) : Ir.Backend_impl.Lowered_backend = struct
220221
{ shared = shared_event; value = next_value }
221222

222223
let await stream =
223-
let queue = stream.runner.queue in
224-
let command_buffer = Me.CommandBuffer.on_queue queue in
225-
Me.CommandBuffer.commit command_buffer;
226-
Me.CommandBuffer.wait_until_completed command_buffer;
224+
(* Signal an event after all current work and wait for it.
225+
This ensures all previously submitted command buffers complete. *)
226+
let event = all_work stream in
227+
sync event;
227228
(* Process captured logs if any *)
228229
if Utils.debug_log_from_routines () then
229230
match Hashtbl.find stream_logs stream.stream_id with

test/operations/dune

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,6 +377,17 @@
377377
(preprocess
378378
(pps ppx_here ppx_ocannl ppx_expect)))
379379

380+
(test
381+
(name small_literal_tensor)
382+
(package neural_nets_lib)
383+
(deps
384+
ocannl_config
385+
(env_var OCANNL_BACKEND))
386+
(modules small_literal_tensor)
387+
(libraries base ocannl stdio)
388+
(preprocess
389+
(pps ppx_here ppx_ocannl ppx_expect)))
390+
380391
(test
381392
(name zero2hero_1of7_exec)
382393
(package neural_nets_lib)
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
Retrieving commandline, environment, or config file variable ocannl_log_level
2+
Found 0, in the config file
3+
HERE: test/operations/small_literal_tensor.ml:11:21
4+
[0]: 1,2,3,4,5..._hey shape 1:3->0:2 [
5+
1.00 , 2.00 , 3.00
6+
; 4.00 , 5.00 , 6.00
7+
]
8+
HERE: test/operations/small_literal_tensor.ml:14:21
9+
[1]: 1,2,3,4,5..._hoo shape 0:2|1:3 [|
10+
[ 1.00 ; 2.00 ; 3.00 ]
11+
; [ 4.00 ; 5.00 ; 6.00 ]
12+
|]
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
open! Base
2+
open! Ocannl
3+
open! Nn_blocks.DSL_modules
4+
5+
let () =
6+
Tensor.unsafe_reinitialize ();
7+
let ctx = Context.auto () in
8+
let%op hey = [ (1, 2, 3); (4, 5, 6) ] in
9+
let ctx = Train.forward_once ctx hey in
10+
(* ignore (failwith @@ Tn.debug_memory_mode hey.value.memory_mode); *)
11+
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ~style:`Inline hey;
12+
let%op hoo = [| [ 1; 2; 3 ]; [ 4; 5; 6 ] |] in
13+
let _ctx = Train.forward_once ctx hoo in
14+
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ~style:`Inline hoo

test/operations/test_param_shape_error.expected

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,18 @@
11
Retrieving commandline, environment, or config file variable ocannl_log_level
22
Found 0, in the config file
3-
Testing default parameter 1D -- corner case
3+
Testing default lone parameter -- corner case
4+
┌───────────────────────────────────────┐
5+
│[5]: w_o shape 0:4 │
6+
│┌┬────────────────────────────────────┐│
7+
│││axis 0 ││
8+
│├┼────────────────────────────────────┤│
9+
│││ 8.38e-1 1.38e-1 2.61e-1 7.74e-1 ││
10+
│└┴────────────────────────────────────┘│
11+
└───────────────────────────────────────┘
12+
grad_w_o <not-hosted>
13+
14+
ERROR: Should have raised an exception
15+
Testing lone parameter 1D -- corner case
416
┌───────────────────┐
517
│[5]: w_o shape 0:1 │
618
│┌┬─────────┐ │
@@ -25,3 +37,7 @@ Testing default affine operation with propagated dimensions
2537
grad_w <not-hosted>
2638
Testing default affine operation with unknown input dimensions
2739
Got expected error: You forgot to specify the hidden dimension(s) 1
40+
Testing default bias parameter
41+
Got expected error: You forgot to specify the hidden dimension(s) 4
42+
Testing bias parameter 1D
43+
Got expected error: You forgot to specify the hidden dimension(s) 1

test/operations/test_param_shape_error.ml

Lines changed: 29 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,8 @@ let default_lone_param () =
1313
Stdio.print_endline "\nERROR: Should have raised an exception"
1414
with Row.Shape_error (msg, _) -> Stdio.printf "Got acceptable error: %s\n" msg
1515

16-
let default_param_1d () =
17-
Stdio.printf "Testing default parameter 1D -- corner case\n";
16+
let lone_param_1d () =
17+
Stdio.printf "Testing lone parameter 1D -- corner case\n";
1818
Tensor.unsafe_reinitialize ();
1919
(* This should raise an error because we have a parameter with unspecified dimensions *)
2020
try
@@ -57,9 +57,33 @@ let default_affine_op_unknown_input () =
5757
()
5858
with Row.Shape_error (msg, _) -> Stdio.printf "Got expected error: %s\n" msg
5959

60+
let default_bias_param () =
61+
Stdio.printf "Testing default bias parameter\n";
62+
Tensor.unsafe_reinitialize ();
63+
(* This should raise an error because we have a parameter with unspecified dimensions *)
64+
try
65+
let%op w_o = { x } + { y } in
66+
let _ctx : Context.t = Train.init_params (Context.auto ()) Train.IDX.empty w_o in
67+
Train.printf w_o;
68+
Stdio.print_endline "\nERROR: Should have raised an exception"
69+
with Row.Shape_error (msg, _) -> Stdio.printf "Got expected error: %s\n" msg
70+
71+
let default_bias_param_1d () =
72+
Stdio.printf "Testing bias parameter 1D\n";
73+
Tensor.unsafe_reinitialize ();
74+
(* This should raise an error because we have a parameter with unspecified dimensions *)
75+
try
76+
let%op w_o = { x = uniform1 () } + { y = uniform1 () } in
77+
let _ctx : Context.t = Train.init_params (Context.auto ()) Train.IDX.empty w_o in
78+
Train.printf w_o;
79+
()
80+
with Row.Shape_error (msg, _) -> Stdio.printf "Got expected error: %s\n" msg
81+
6082
let () =
61-
ignore default_lone_param;
62-
default_param_1d ();
83+
default_lone_param ();
84+
lone_param_1d ();
6385
default_linear_op ();
6486
default_affine_op_propagated ();
65-
default_affine_op_unknown_input ()
87+
default_affine_op_unknown_input ();
88+
default_bias_param ();
89+
default_bias_param_1d ()

0 commit comments

Comments
 (0)