|
| 1 | +open! Base |
| 2 | +open Ocannl.Operation.DSL_modules |
| 3 | + |
| 4 | +let () = |
| 5 | + (* Basic transformer test *) |
| 6 | + let module Backend = (val Backends.fresh_backend ()) in |
| 7 | + (* Test configuration *) |
| 8 | + let batch_size = 2 in |
| 9 | + let src_seq_len = 10 in |
| 10 | + let tgt_seq_len = 8 in |
| 11 | + let d_model = 64 in |
| 12 | + let num_heads = 4 in |
| 13 | + let d_ff = 128 in |
| 14 | + let src_vocab_size = 100 in |
| 15 | + let tgt_vocab_size = 100 in |
| 16 | + let num_encoder_layers = 2 in |
| 17 | + let num_decoder_layers = 2 in |
| 18 | + |
| 19 | + Stdio.printf "Testing basic transformer model\n"; |
| 20 | + |
| 21 | + (* Create a simple transformer model *) |
| 22 | + let transformer_model = |
| 23 | + Ocannl.Nn_blocks.transformer ~label:[ "test_transformer" ] ~num_encoder_layers |
| 24 | + ~num_decoder_layers ~num_heads ~d_model ~d_ff () |
| 25 | + in |
| 26 | + |
| 27 | + (* Create input tensors *) |
| 28 | + let src = |
| 29 | + TDSL.range_of_shape ~label:[ "src" ] ~batch_dims:[ batch_size; src_seq_len ] ~input_dims:[] |
| 30 | + ~output_dims:[ src_vocab_size ] () |
| 31 | + in |
| 32 | + |
| 33 | + let tgt = |
| 34 | + TDSL.range_of_shape ~label:[ "tgt" ] ~batch_dims:[ batch_size; tgt_seq_len ] ~input_dims:[] |
| 35 | + ~output_dims:[ tgt_vocab_size ] () |
| 36 | + in |
| 37 | + |
| 38 | + (* Create a causal mask for the decoder input (target sequence) *) |
| 39 | + let mask = |
| 40 | + NTDSL.init ~l:"mask" ~prec:Ir.Ops.single ~b:[ batch_size; tgt_seq_len ] ~i:[ tgt_seq_len ] |
| 41 | + ~o:[ 1 ] |
| 42 | + ~f:(function |
| 43 | + | [| _; s; _; t |] -> if s <= t then 1. else 0. |
| 44 | + | idcs -> |
| 45 | + failwith @@ "Invalid indices length: expected [| _; s; _; t |], got " |
| 46 | + ^ Sexp.to_string_hum ([%sexp_of: int array] idcs)) |
| 47 | + () |
| 48 | + in |
| 49 | + |
| 50 | + (* Forward pass *) |
| 51 | + let output = transformer_model ~src ~tgt ~mask in |
| 52 | + |
| 53 | + (* Verify output shape *) |
| 54 | + Stdio.printf "Output shape:\n%s\n%!" |
| 55 | + (Sexp.to_string_hum ([%sexp_of: Shape.t] output.Tensor.shape)) |
| 56 | + |
| 57 | +(* ; |
| 58 | +
|
| 59 | + (* Test transformer components *) Stdio.printf "\nTesting transformer components\n"; |
| 60 | +
|
| 61 | + let d_model = 32 in let num_heads = 2 in let seq_len = 5 in let batch_size = 1 in |
| 62 | +
|
| 63 | + (* Test multi-head attention *) let%op mha = Ocannl.Nn_blocks.multi_head_attention ~label:[ |
| 64 | + "test_mha" ] ~num_heads () in |
| 65 | +
|
| 66 | + let input = Tensor.ndarray ~label:[ "input" ] ~grad_spec:Tensor.Prohibit_grad [| Array.init |
| 67 | + batch_size ~f:(fun _ -> Array.init seq_len ~f:(fun _ -> Array.init d_model ~f:(fun _ -> |
| 68 | + Random.float 1.0))); |] in |
| 69 | +
|
| 70 | + let mha_output = mha input in |
| 71 | +
|
| 72 | + (* Test layer norm *) let%op ln = Nn_blocks.layer_norm ~label:[ "test_ln" ] () in let ln_output = |
| 73 | + ln input in |
| 74 | +
|
| 75 | + (* Test feed forward *) let%op ffn = Nn_blocks.feed_forward ~label:[ "test_ffn" ] ~d_model |
| 76 | + ~d_ff:64 () in let ffn_output = ffn input in |
| 77 | +
|
| 78 | + (* Verify shapes *) Stdio.printf "MHA output shape: %s\n" (Sexp.to_string_hum ([%sexp_of: int |
| 79 | + array] (Tensor.shape mha_output).output_dims)); Stdio.printf "Layer norm output shape: %s\n" |
| 80 | + (Sexp.to_string_hum ([%sexp_of: int array] (Tensor.shape ln_output).output_dims)); Stdio.printf |
| 81 | + "FFN output shape: %s\n" (Sexp.to_string_hum ([%sexp_of: int array] (Tensor.shape |
| 82 | + ffn_output).output_dims)); |
| 83 | +
|
| 84 | + Stdio.printf "\nAll tests completed successfully!\n" *) |
0 commit comments