Skip to content

Commit 11b3d1e

Browse files
committed
Transformer mask fix, by Claude
1 parent 352ae42 commit 11b3d1e

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

test/operations/transformer_test.ml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ let () =
99
let src_seq_len = 10 in
1010
let tgt_seq_len = 8 in
1111
let d_model = 64 in
12-
let num_heads = 6 in
12+
let num_heads = 4 in
1313
let d_ff = 128 in
1414
let src_vocab_size = 100 in
1515
let tgt_vocab_size = 100 in
@@ -36,11 +36,13 @@ let () =
3636
in
3737

3838
(* Create a causal mask for the decoder input (target sequence) *)
39+
(* Mask should be 0 for positions to mask out, 1 for positions to keep *)
40+
(* This creates a lower-triangular matrix of ones (1 where s >= t), so future positions (t > s) are masked *)
3941
let mask =
4042
NTDSL.init ~l:"mask" ~prec:Ir.Ops.single ~b:[ batch_size; tgt_seq_len ] ~i:[ tgt_seq_len ]
4143
~o:[ 1 ]
4244
~f:(function
43-
| [| _; s; _; t |] -> if s <= t then 1. else 0.
45+
| [| _; s; _; t |] -> if s >= t then 1. else 0.
4446
| idcs ->
4547
failwith @@ "Invalid indices length: expected [| _; s; _; t |], got "
4648
^ Sexp.to_string_hum ([%sexp_of: int array] idcs))

0 commit comments

Comments
 (0)