Skip to content

Commit 0d30451

Browse files
committed
Fix: attention masks should have empty output dimensions to broadcast to multihead attentions
1 parent d0b8bdf commit 0d30451

File tree

6 files changed

+26
-28
lines changed

6 files changed

+26
-28
lines changed

test/operations/attention_test.expected

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,12 @@ Testing basic multi-head attention
44
Output shape:
55
((batch
66
((dims
7-
((Dim ((d 2) (label ()) (proj_id ((Proj_id 210)))))
8-
(Dim ((d 8) (label ()) (proj_id ((Proj_id 211)))))))
7+
((Dim ((d 2) (label ()) (proj_id ((Proj_id 208)))))
8+
(Dim ((d 8) (label ()) (proj_id ((Proj_id 209)))))))
99
(bcast Broadcastable) (prov (((sh_id 64) (kind Batch))))))
1010
(input ((dims ()) (bcast Broadcastable) (prov (((sh_id 64) (kind Input))))))
1111
(output
12-
((dims ((Dim ((d 100) (label ()) (proj_id ((Proj_id 212)))))))
12+
((dims ((Dim ((d 100) (label ()) (proj_id ((Proj_id 210)))))))
1313
(bcast Broadcastable) (prov (((sh_id 64) (kind Output))))))
1414
(batch_padding ()) (input_padding ()) (output_padding ()) (id 64)
1515
(debug_name output))

test/operations/attention_test.ml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ let () =
1919
in
2020

2121
(* Create input tensors *)
22-
2322
let seq =
2423
TDSL.range_of_shape ~label:[ "tgt" ] ~batch_dims:[ batch_size; tgt_seq_len ] ~input_dims:[]
2524
~output_dims:[ tgt_vocab_size ] ()
@@ -29,10 +28,9 @@ let () =
2928
(* Mask should be 0 for positions to mask out, 1 for positions to keep *)
3029
(* This creates an upper triangular matrix where future positions are masked *)
3130
let mask =
32-
NTDSL.init ~l:"mask" ~prec:Ir.Ops.single ~b:[ batch_size; tgt_seq_len ] ~i:[ tgt_seq_len ]
33-
~o:[ 1 ]
31+
NTDSL.init ~l:"mask" ~prec:Ir.Ops.single ~b:[ tgt_seq_len ] ~i:[ tgt_seq_len ] ~o:[]
3432
~f:(function
35-
| [| _; s; _; t |] -> if s >= t then 1. else 0.
33+
| [| s; t |] -> if s >= t then 1. else 0.
3634
| idcs ->
3735
failwith @@ "Invalid indices length: expected [| _; s; _; t |], got "
3836
^ Sexp.to_string_hum ([%sexp_of: int array] idcs))

test/operations/layer_norm_test.expected

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,13 @@ Testing basic mini decoder model
44
Output shape:
55
((batch
66
((dims
7-
((Dim ((d 2) (label ()) (proj_id ((Proj_id 424)))))
8-
(Dim ((d 8) (label ()) (proj_id ((Proj_id 425)))))))
7+
((Dim ((d 2) (label ()) (proj_id ((Proj_id 422)))))
8+
(Dim ((d 8) (label ()) (proj_id ((Proj_id 423)))))))
99
(bcast Broadcastable) (prov (((sh_id 154) (kind Batch))))))
1010
(input
1111
((dims ()) (bcast Broadcastable) (prov (((sh_id 154) (kind Input))))))
1212
(output
13-
((dims ((Dim ((d 100) (label ()) (proj_id ((Proj_id 426)))))))
13+
((dims ((Dim ((d 100) (label ()) (proj_id ((Proj_id 424)))))))
1414
(bcast Broadcastable) (prov (((sh_id 154) (kind Output))))))
1515
(batch_padding ()) (input_padding ()) (output_padding ()) (id 154)
1616
(debug_name layer_norm))

test/operations/layer_norm_test.ml

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,10 +42,9 @@ let () =
4242
(* Mask should be 0 for positions to mask out, 1 for positions to keep *)
4343
(* This creates an upper triangular matrix where future positions are masked *)
4444
let mask =
45-
NTDSL.init ~l:"mask" ~prec:Ir.Ops.single ~b:[ batch_size; tgt_seq_len ] ~i:[ tgt_seq_len ]
46-
~o:[ 1 ]
45+
NTDSL.init ~l:"mask" ~prec:Ir.Ops.single ~b:[ tgt_seq_len ] ~i:[ tgt_seq_len ] ~o:[]
4746
~f:(function
48-
| [| _; s; _; t |] -> if s >= t then 1. else 0.
47+
| [| s; t |] -> if s >= t then 1. else 0.
4948
| idcs ->
5049
failwith @@ "Invalid indices length: expected [| _; s; _; t |], got "
5150
^ Sexp.to_string_hum ([%sexp_of: int array] idcs))

test/operations/transformer_test.expected

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,20 +7,20 @@ Loss shape:
77
(input
88
((dims ()) (bcast Broadcastable) (prov (((sh_id 463) (kind Input))))))
99
(output
10-
((dims ((Dim ((d 1) (label ()) (proj_id ((Proj_id 1262)))))))
10+
((dims ((Dim ((d 1) (label ()) (proj_id ((Proj_id 1260)))))))
1111
(bcast Broadcastable) (prov (((sh_id 463) (kind Output))))))
1212
(batch_padding ()) (input_padding ()) (output_padding ()) (id 463)
1313
(debug_name loss))
1414
Logits shape:
1515
((batch
1616
((dims
17-
((Dim ((d 2) (label ()) (proj_id ((Proj_id 1212)))))
18-
(Dim ((d 7) (label ()) (proj_id ((Proj_id 1213)))))))
17+
((Dim ((d 2) (label ()) (proj_id ((Proj_id 1210)))))
18+
(Dim ((d 7) (label ()) (proj_id ((Proj_id 1211)))))))
1919
(bcast Broadcastable) (prov (((sh_id 437) (kind Batch))))))
2020
(input
2121
((dims ()) (bcast Broadcastable) (prov (((sh_id 437) (kind Input))))))
2222
(output
23-
((dims ((Dim ((d 100) (label ()) (proj_id ((Proj_id 1214)))))))
23+
((dims ((Dim ((d 100) (label ()) (proj_id ((Proj_id 1212)))))))
2424
(bcast Broadcastable) (prov (((sh_id 437) (kind Output))))))
2525
(batch_padding ()) (input_padding ()) (output_padding ()) (id 437)
2626
(debug_name transformer))

test/operations/transformer_test.ml

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -33,32 +33,34 @@ let () =
3333
(* For teacher forcing: create shifted versions of target sequence *)
3434
(* tgt_input: positions 0 to tgt_seq_len-2 (all but last) *)
3535
let tgt_input =
36-
TDSL.range_of_shape ~label:[ "tgt_input" ] ~batch_dims:[ batch_size; tgt_seq_len - 1 ]
36+
TDSL.range_of_shape ~label:[ "tgt_input" ]
37+
~batch_dims:[ batch_size; tgt_seq_len - 1 ]
3738
~input_dims:[] ~output_dims:[ tgt_vocab_size ] ()
3839
in
3940

4041
(* tgt_target: positions 1 to tgt_seq_len-1 (all but first) *)
4142
(* In practice, this would be shifted token IDs, here we use one-hot for simplicity *)
4243
let tgt_target =
43-
NTDSL.init ~l:"tgt_target" ~prec:Ir.Ops.single ~b:[ batch_size; tgt_seq_len - 1 ] ~i:[]
44-
~o:[ tgt_vocab_size ]
44+
NTDSL.init ~l:"tgt_target" ~prec:Ir.Ops.single
45+
~b:[ batch_size; tgt_seq_len - 1 ]
46+
~i:[] ~o:[ tgt_vocab_size ]
4547
~f:(function
4648
| [| _b; s; v |] ->
4749
(* Create a simple one-hot pattern for testing *)
4850
if v = Int.((s + 1) % tgt_vocab_size) then 1. else 0.
49-
| idcs ->
50-
failwith @@ "Invalid indices: "
51-
^ Sexp.to_string_hum ([%sexp_of: int array] idcs))
51+
| idcs -> failwith @@ "Invalid indices: " ^ Sexp.to_string_hum ([%sexp_of: int array] idcs))
5252
()
5353
in
5454

5555
(* Create a causal mask for the decoder input (shifted target sequence) *)
5656
(* Mask should be 0 for positions to mask out, 1 for positions to keep *)
5757
let mask =
58-
NTDSL.init ~l:"mask" ~prec:Ir.Ops.single ~b:[ batch_size; tgt_seq_len - 1 ]
59-
~i:[ tgt_seq_len - 1 ] ~o:[ 1 ]
58+
NTDSL.init ~l:"mask" ~prec:Ir.Ops.single
59+
~b:[ tgt_seq_len - 1 ]
60+
~i:[ tgt_seq_len - 1 ]
61+
~o:[]
6062
~f:(function
61-
| [| _; s; _; t |] -> if s >= t then 1. else 0.
63+
| [| s; t |] -> if s >= t then 1. else 0.
6264
| idcs ->
6365
failwith @@ "Invalid indices: expected [| _; s; _; t |], got "
6466
^ Sexp.to_string_hum ([%sexp_of: int array] idcs))
@@ -78,7 +80,6 @@ let () =
7880
let _ctx = Ocannl.Train.forward_once ~output_cd_file:false ~bindings ctx loss in
7981

8082
(* Verify shapes *)
81-
Stdio.printf "Loss shape:\n%s\n"
82-
(Sexp.to_string_hum ([%sexp_of: Shape.t] loss.Tensor.shape));
83+
Stdio.printf "Loss shape:\n%s\n" (Sexp.to_string_hum ([%sexp_of: Shape.t] loss.Tensor.shape));
8384
Stdio.printf "Logits shape:\n%s\n%!"
8485
(Sexp.to_string_hum ([%sexp_of: Shape.t] logits.Tensor.shape))

0 commit comments

Comments
 (0)