Skip to content

Commit 6d675f9

Browse files
committed
Broken: first transformer test, just the shape inference (which is broken)
1 parent 919d077 commit 6d675f9

File tree

3 files changed

+104
-0
lines changed

3 files changed

+104
-0
lines changed

test/operations/dune

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -303,6 +303,17 @@
303303
(preprocess
304304
(pps ppx_here ppx_ocannl)))
305305

306+
; Test stanza for transformer_test.ml, attached to the neural_nets_lib package.
(test
 (name transformer_test)
 (package neural_nets_lib)
 ; Declared dependencies: the ocannl_config file and the OCANNL_BACKEND environment variable.
 (deps ocannl_config (env_var OCANNL_BACKEND))
 (modules transformer_test)
 (libraries base ocannl)
 ; ppx_sexp_conv supplies the [%sexp_of: ...] extension that transformer_test.ml uses.
 (preprocess (pps ppx_here ppx_ocannl ppx_sexp_conv)))
316+
306317
(library
307318
(name operations_tutorials)
308319
(package neural_nets_lib)
Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
Testing basic transformer model
2+
Output shape: batch_dims=[|2;8|] input_dims=[||] output_dims=[|100|]
3+
4+
Testing transformer components
5+
MHA output shape: [|32|]
6+
Layer norm output shape: [|32|]
7+
FFN output shape: [|32|]
8+
9+
All tests completed successfully!
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
open! Base
2+
open Ocannl.Operation.DSL_modules
3+
4+
(* Basic transformer test: builds a small [Ocannl.Nn_blocks.transformer], applies it to
   range-initialized source and target tensors together with a mask, and prints the shape of
   the result. The component-level checks at the bottom are currently commented out. *)
let () =
  (* Instantiated before any tensor is created; the module is not referenced again in the
     active part of this test. *)
  let module Backend = (val Backends.fresh_backend ()) in
  let module Nn = Ocannl.Nn_blocks in
  (* Test configuration *)
  let batch_size = 2 in
  let src_seq_len, tgt_seq_len = (10, 8) in
  let d_model, num_heads, d_ff = (64, 4, 128) in
  let src_vocab_size, tgt_vocab_size = (100, 100) in
  let num_encoder_layers, num_decoder_layers = (2, 2) in

  Stdio.printf "Testing basic transformer model\n";

  (* The model under test. *)
  let model =
    Nn.transformer ~label:[ "test_transformer" ] ~num_encoder_layers ~num_decoder_layers ~num_heads
      ~d_model ~d_ff ()
  in

  (* Both inputs share one layout: batch dims [batch_size; seq_len], no input dims, and a
     single output dim of size [vocab_size]. *)
  let range_input ~name ~seq_len ~vocab_size =
    TDSL.range_of_shape ~label:[ name ] ~batch_dims:[ batch_size; seq_len ] ~input_dims:[]
      ~output_dims:[ vocab_size ] ()
  in
  let src = range_input ~name:"src" ~seq_len:src_seq_len ~vocab_size:src_vocab_size in
  let tgt = range_input ~name:"tgt" ~seq_len:tgt_seq_len ~vocab_size:tgt_vocab_size in

  (* Mask for the decoder input (target sequence). Each cell is addressed by four indices; the
     second ([s]) and fourth ([t]) decide the value: 1. when [s <= t], 0. otherwise.
     NOTE(review): whether [s <= t] or [s >= t] is the intended causal direction depends on how
     the transformer block consumes the mask -- confirm against Nn_blocks. *)
  let mask_entry = function
    | [| _; s; _; t |] -> if s > t then 0. else 1.
    | idcs ->
        failwith
          ("Invalid indices length: expected [| _; s; _; t |], got "
          ^ Sexp.to_string_hum ([%sexp_of: int array] idcs))
  in
  let mask =
    NTDSL.init ~l:"mask" ~prec:Ir.Ops.single ~b:[ batch_size; tgt_seq_len ] ~i:[ tgt_seq_len ]
      ~o:[ 1 ] ~f:mask_entry ()
  in

  (* Forward pass, then report the shape of the result as an s-expression. *)
  let result = model ~src ~tgt ~mask in
  result.Tensor.shape |> [%sexp_of: Shape.t] |> Sexp.to_string_hum
  |> Stdio.printf "Output shape:\n%s\n%!"

(* Disabled component-level checks. To re-enable them, terminate the printf above with [;]
   and move the code below out of this comment.

   (* Test transformer components *)
   Stdio.printf "\nTesting transformer components\n";

   let d_model = 32 in
   let num_heads = 2 in
   let seq_len = 5 in
   let batch_size = 1 in

   (* Test multi-head attention *)
   let%op mha = Ocannl.Nn_blocks.multi_head_attention ~label:[ "test_mha" ] ~num_heads () in

   let input =
     Tensor.ndarray ~label:[ "input" ] ~grad_spec:Tensor.Prohibit_grad
       [| Array.init batch_size ~f:(fun _ ->
              Array.init seq_len ~f:(fun _ -> Array.init d_model ~f:(fun _ -> Random.float 1.0)));
       |]
   in

   let mha_output = mha input in

   (* Test layer norm *)
   let%op ln = Nn_blocks.layer_norm ~label:[ "test_ln" ] () in
   let ln_output = ln input in

   (* Test feed forward *)
   let%op ffn = Nn_blocks.feed_forward ~label:[ "test_ffn" ] ~d_model ~d_ff:64 () in
   let ffn_output = ffn input in

   (* Verify shapes *)
   Stdio.printf "MHA output shape: %s\n"
     (Sexp.to_string_hum ([%sexp_of: int array] (Tensor.shape mha_output).output_dims));
   Stdio.printf "Layer norm output shape: %s\n"
     (Sexp.to_string_hum ([%sexp_of: int array] (Tensor.shape ln_output).output_dims));
   Stdio.printf "FFN output shape: %s\n"
     (Sexp.to_string_hum ([%sexp_of: int array] (Tensor.shape ffn_output).output_dims));

   Stdio.printf "\nAll tests completed successfully!\n" *)

0 commit comments

Comments
 (0)