File tree Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Expand file tree Collapse file tree 1 file changed +4
-2
lines changed Original file line number Diff line number Diff line change 99 let src_seq_len = 10 in
1010 let tgt_seq_len = 8 in
1111 let d_model = 64 in
12- let num_heads = 6 in
12+ let num_heads = 4 in
1313 let d_ff = 128 in
1414 let src_vocab_size = 100 in
1515 let tgt_vocab_size = 100 in
@@ -36,11 +36,13 @@ let () =
3636 in
3737
3838 (* Create a causal mask for the decoder input (target sequence) *)
39+ (* Mask should be 0 for positions to mask out, 1 for positions to keep *)
40+ (* This creates an upper triangular matrix where future positions are masked *)
3941 let mask =
4042 NTDSL. init ~l: " mask" ~prec: Ir.Ops. single ~b: [ batch_size; tgt_seq_len ] ~i: [ tgt_seq_len ]
4143 ~o: [ 1 ]
4244 ~f: (function
43- | [| _; s; _; t |] -> if s < = t then 1. else 0.
45+ | [| _; s; _; t |] -> if s > = t then 1. else 0.
4446 | idcs ->
4547 failwith @@ " Invalid indices length: expected [| _; s; _; t |], got "
4648 ^ Sexp. to_string_hum ([% sexp_of: int array ] idcs))
You can’t perform that action at this time.
0 commit comments