Skip to content

Commit 1f80072

Browse files
committed
Update test expectations
1 parent e25b8e6 commit 1f80072

File tree

6 files changed

+49
-32
lines changed

6 files changed

+49
-32
lines changed
Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
Testing mode detection:
2+
✓ 'abc' -> single-char (expected single-char)
3+
✓ 'a,b,c' -> multichar (expected multichar)
4+
✓ '2*a+b' -> multichar (expected multichar)
5+
✓ 'a+b' -> multichar (expected multichar)
6+
✓ 'a*b' -> multichar (expected multichar)
7+
✓ 'a^b' -> multichar (expected multichar)
8+
✓ 'a&b' -> multichar (expected multichar)
9+
✓ 'a|b->c' -> single-char (expected single-char)
10+
✓ '...a..b' -> single-char (expected single-char)
11+
12+
Testing single-char mode:
13+
'abc' -> 3 output axes
14+
'b|i->o' -> batch:1 input:1 output:1
15+
'ij;jk=>ik' -> (0,2);(0,2)=>(0,2)
16+
17+
Testing multichar mode:
18+
'a, b, c' -> 3 output axes
19+
'a, b,' -> 2 output axes
20+
'2*o+k' -> 1 output axes
21+
'2*o+3*k, x' -> 2 output axes
22+
23+
All tests passed!

test/einsum/test_interleave.ml

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,19 @@
1-
open Base
2-
open Ocannl
3-
open Nn_blocks.DSL_modules
1+
open! Base
2+
open! Ocannl
3+
open! Nn_blocks.DSL_modules
44

55
let () =
6+
(* FIXME: NOT IMPLEMENTED YET *)
67
Tensor.unsafe_reinitialize ();
7-
let%op t1 = [ 1.0; 2.0; 3.0 ] in
8+
9+
(* let%op t1 = [ 1.0; 2.0; 3.0 ] in
810
let%op t2 = [ 4.0; 5.0; 6.0 ] in
9-
let t3 = Operation.interleave t1 t2 () in
11+
let t3 = Operation.interleave t1 t2 () in *)
12+
(* let%op t3 = interleave [ 1.0; 2.0; 3.0 ] [ 4.0; 5.0; 6.0 ] in *)
1013

1114
(* t3 should be [1.0; 4.0; 2.0; 5.0; 3.0; 6.0] *)
12-
let ctx = Context.auto () in
15+
(* let ctx = Context.auto () in *)
1316

14-
try
15-
let _ctx = Train.forward_once ctx t3 in
16-
Stdio.printf "Test failed! Expected error was not raised.\n";
17-
Stdlib.exit 1
18-
with Utils.User_error msg ->
19-
if
20-
String.equal msg
21-
"Defined_by_cd_logic: use explicit ~logic annotations when defining this operation"
22-
then Stdio.printf "Test passed! Caught expected error: %s\n" msg
23-
else (
24-
Stdio.printf "Test failed! Caught unexpected error: %s\n" msg;
25-
Stdlib.exit 1)
17+
(* let _ctx = Train.forward_once ctx t3 in *)
18+
(* Train.printf ~here:[%here] ~with_code:false ~with_grad:false t3 *)
19+
()

test/operations/attention_test.expected

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ Testing basic multi-head attention
44
Output shape:
55
((batch
66
((dims
7-
((Dim ((d 2) (label ()) (proj_id ((Proj_id 204)))))
8-
(Dim ((d 8) (label ()) (proj_id ((Proj_id 205)))))))
7+
((Dim ((d 2) (label ()) (proj_id ((Proj_id 195)))))
8+
(Dim ((d 8) (label ()) (proj_id ((Proj_id 196)))))))
99
(bcast Broadcastable) (prov (((sh_id 64) (kind Batch))))))
1010
(input ((dims ()) (bcast Broadcastable) (prov (((sh_id 64) (kind Input))))))
1111
(output
12-
((dims ((Dim ((d 100) (label ()) (proj_id ((Proj_id 206)))))))
12+
((dims ((Dim ((d 100) (label ()) (proj_id ((Proj_id 197)))))))
1313
(bcast Broadcastable) (prov (((sh_id 64) (kind Output))))))
1414
(batch_padding ()) (input_padding ()) (output_padding ()) (id 64)
1515
(debug_name output))

test/operations/layer_norm_test.expected

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ Testing basic mini decoder model
44
Output shape:
55
((batch
66
((dims
7-
((Dim ((d 2) (label ()) (proj_id ((Proj_id 418)))))
8-
(Dim ((d 8) (label ()) (proj_id ((Proj_id 419)))))))
7+
((Dim ((d 2) (label ()) (proj_id ((Proj_id 409)))))
8+
(Dim ((d 8) (label ()) (proj_id ((Proj_id 410)))))))
99
(bcast Broadcastable) (prov (((sh_id 154) (kind Batch))))))
1010
(input
1111
((dims ()) (bcast Broadcastable) (prov (((sh_id 154) (kind Input))))))
1212
(output
13-
((dims ((Dim ((d 100) (label ()) (proj_id ((Proj_id 420)))))))
13+
((dims ((Dim ((d 100) (label ()) (proj_id ((Proj_id 411)))))))
1414
(bcast Broadcastable) (prov (((sh_id 154) (kind Output))))))
1515
(batch_padding ()) (input_padding ()) (output_padding ()) (id 154)
1616
(debug_name layer_norm))

test/operations/transformer_test.expected

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,20 @@ Loss shape:
77
(input
88
((dims ()) (bcast Broadcastable) (prov (((sh_id 463) (kind Input))))))
99
(output
10-
((dims ((Dim ((d 1) (label ()) (proj_id ((Proj_id 1252)))))))
10+
((dims ((Dim ((d 1) (label ()) (proj_id ((Proj_id 1221)))))))
1111
(bcast Broadcastable) (prov (((sh_id 463) (kind Output))))))
1212
(batch_padding ()) (input_padding ()) (output_padding ()) (id 463)
1313
(debug_name loss))
1414
Logits shape:
1515
((batch
1616
((dims
17-
((Dim ((d 2) (label ()) (proj_id ((Proj_id 1202)))))
18-
(Dim ((d 7) (label ()) (proj_id ((Proj_id 1203)))))))
17+
((Dim ((d 2) (label ()) (proj_id ((Proj_id 1171)))))
18+
(Dim ((d 7) (label ()) (proj_id ((Proj_id 1172)))))))
1919
(bcast Broadcastable) (prov (((sh_id 437) (kind Batch))))))
2020
(input
2121
((dims ()) (bcast Broadcastable) (prov (((sh_id 437) (kind Input))))))
2222
(output
23-
((dims ((Dim ((d 100) (label ()) (proj_id ((Proj_id 1204)))))))
23+
((dims ((Dim ((d 100) (label ()) (proj_id ((Proj_id 1173)))))))
2424
(bcast Broadcastable) (prov (((sh_id 437) (kind Output))))))
2525
(batch_padding ()) (input_padding ()) (output_padding ()) (id 437)
2626
(debug_name transformer))

test/operations/zero2hero_1of7_exec.ml

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,9 +68,9 @@ let graph_drawing_fetch () =
6868
let ys, dys =
6969
Array.unzip
7070
@@ Array.mapi xs ~f:(fun i _ ->
71-
step_ref := i;
72-
Train.run ctx fx_routine;
73-
(fx.@[0], x.@%[0]))
71+
step_ref := i;
72+
Train.run ctx fx_routine;
73+
(fx.@[0], x.@%[0]))
7474
in
7575
(* It is fine to loop around the data: it's "next epoch". We redo the work though. *)
7676
let plot_box =
@@ -178,4 +178,4 @@ let main () =
178178
let () = two_d_neuron_virtual () in
179179
()
180180

181-
let () = main ()
181+
let () = main ()

0 commit comments

Comments
 (0)