Commit d68486f

Final transformer fixes: bug underspecifying attention w_o inputs
This will all be nicely caught by the fix proposed in: Implement shape errors "You forgot to specify the hidden dimension(s)" via a `Param` subset of `Terminal` #396
1 parent 60957d1 commit d68486f

File tree: 1 file changed (+2, -2 lines)

lib/nn_blocks.ml

Lines changed: 2 additions & 2 deletions
@@ -77,7 +77,7 @@ let%op multi_head_attention ~label ~num_heads ~d_k ~d_v ?temperature ?(dropout_r
   in
   let attn_weights = dropout ~rate:dropout_rate () ~train_step attn_weights in
   (* w_o output shape will automatically be set to the model dimension(s) by shape inference. *)
-  { w_o } * (attn_weights +* " ... s | t -> h; ... t | h e => ... s | h ..." [ "e" ] v)
+  { w_o } * (attn_weights +* " ... s | t -> h; ... t | h e => ... s | h e" [ "e" ] v)

 let%op layer_norm ~label ?(epsilon = 1e-5) () x =
   let mean = x ++ " ... | ..d.. => ... | 0 " [ "d" ] in
@@ -111,7 +111,7 @@ let%op cross_attention ~label ~num_heads ~d_k ~d_v ?temperature ?(dropout_rate =
   Shape.set_dim e d_v;
   let attn_weights = softmax ~spec:" ... | t -> ..." ?temperature () scores in
   let attn_weights = dropout ~rate:dropout_rate () ~train_step attn_weights in
-  { w_o } * (attn_weights +* " ... s | t -> h; ... t | h e => ... s | h ..." [ "e" ] v)
+  { w_o } * (attn_weights +* " ... s | t -> h; ... t | h e => ... s | h e" [ "e" ] v)

 let%op transformer_decoder_block ~label ~num_heads ~d_k ~d_v ~d_ff ?(epsilon = 1e-5) () =
   let masked_mha = multi_head_attention ~label:(label @ [ "masked_mha" ]) ~num_heads ~d_k ~d_v () in
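For readers less familiar with the einsum-style specs: both variants compute the usual attention read-out, where for each query position s and head h the value vectors (of size e) are summed over key positions t, weighted by attn_weights. Below is a minimal plain-OCaml sketch of that contraction, ignoring the leading "..." batch axes. It deliberately does not use the library's tensor API; the array layout and the dimension names s, t, h, e are assumptions read off the spec strings. The point of the fix is that naming e in the result spec (" ... s | h e") pins the head-by-value layout of the attention output, so w_o's input dimensions are no longer left underspecified, as they presumably were with the trailing ellipsis (" ... s | h ...").

(* Sketch only: the arithmetic described by " ... s | t -> h; ... t | h e => ... s | h e ",
   with the leading batch axes dropped.
   out.(s).(h).(e) = sum over t of attn_weights.(s).(h).(t) *. v.(t).(h).(e)
   Axis names follow the spec strings; this is not the library's API. *)
let attention_output ~attn_weights ~v =
  let num_s = Array.length attn_weights in
  let num_h = Array.length attn_weights.(0) in
  let num_t = Array.length attn_weights.(0).(0) in
  let num_e = Array.length v.(0).(0) in
  Array.init num_s (fun s ->
      Array.init num_h (fun h ->
          Array.init num_e (fun e ->
              let acc = ref 0.0 in
              for t = 0 to num_t - 1 do
                acc := !acc +. attn_weights.(s).(h).(t) *. v.(t).(h).(e)
              done;
              !acc)))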
