
Commit d27b49a

Fix attention: it has a hidden dimension because of w_o
1 parent b93d4fa commit d27b49a
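
Note (a reader's gloss, not part of the commit message): the per-head query/key axis is now a single explicit dimension d, fixed to the new ~d_attention parameter via Shape.set_dim d d_attention, instead of a row variable ..d.. tied to the model axes; w_o still maps the per-head hidden dimension back to the model dimension(s) by shape inference. The transformer wrapper passes the common convention d_attention = d_model / num_heads, e.g. 512 / 8 = 64 as an illustrative value.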

File tree: 1 file changed (+25 −20)

lib/nn_blocks.ml — 25 additions & 20 deletions
@@ -58,22 +58,24 @@ let%op softmax ~spec ?(temperature = 1.0) () =
   let exp_vals = exp (x_scaled - max_vals) in
   exp_vals /. (exp_vals ++ spec)
 
-let%op multi_head_attention ~label ~num_heads ?temperature ?(dropout_rate = 0.0) () ~train_step
-    ?mask x =
+let%op multi_head_attention ~label ~num_heads ~d_attention ?temperature ?(dropout_rate = 0.0) ()
+    ~train_step ?mask x =
   let q = { w_q } * x in
   let k = { w_k } * x in
   let v = { w_v } * x in
-  (* Works with arbitrary number of model axes via `..d..` (row variable syntax). *)
   let scores =
-    (q +* " ... s | h ..d..; ... t | h ..d.. => ... s | t -> h " [ "h"; "d" ] k) /. sqrt (dim d)
+    (q +* " ... s | h d; ... t | h d => ... s | t -> h" [ "h"; "d" ] k) /. sqrt (dim d)
   in
   Shape.set_dim h num_heads;
+  (* NOTE: often d_attention = d_model / num_heads, but we allow for other values. *)
+  Shape.set_dim d d_attention;
   (* We don't need to lift [softmax ~spec ()] because it doesn't introduce any new params. *)
   let attn_weights =
     softmax ~spec:" ... | t -> ..." ?temperature ()
       (match mask with None -> scores | Some mask -> where mask scores !.(-1e9))
   in
   let attn_weights = dropout ~rate:dropout_rate () ~train_step attn_weights in
+  (* w_o output shape will automatically be set to the model dimension(s) by shape inference. *)
   { w_o } * (attn_weights +* " ... s | t -> h; ... t | h ... => ... s | h ... " v)
 
 let%op layer_norm ~label ?(epsilon = 1e-5) () x =
@@ -85,8 +87,8 @@ let%op layer_norm ~label ?(epsilon = 1e-5) () x =
   (* gamma and beta are learned, but initialized to good defaults *)
   ({ gamma = 1. } *. normalized) + { beta = 0. }
 
-let%op transformer_encoder_block ~label ~num_heads ~d_ff ?(epsilon = 1e-5) () =
-  let mha = multi_head_attention ~label:(label @ [ "mha" ]) ~num_heads () in
+let%op transformer_encoder_block ~label ~num_heads ~d_attention ~d_ff ?(epsilon = 1e-5) () =
+  let mha = multi_head_attention ~label:(label @ [ "mha" ]) ~num_heads ~d_attention () in
   (* Standard 2-layer FFN: expand to d_ff then contract back to d_model *)
   let ffn = mlp ~label:(label @ [ "ffn" ]) ~hid_dims:[ d_ff ] () in
   let ln1 = layer_norm ~label:(label @ [ "ln1" ]) ~epsilon () in
@@ -95,22 +97,25 @@ let%op transformer_encoder_block ~label ~num_heads ~d_ff ?(epsilon = 1e-5) () =
   let x1 = ln1 (input + mha ~train_step input) in
   ln2 (x1 + ffn x1)
 
-let%op cross_attention ~label ~num_heads ?temperature ?(dropout_rate = 0.0) () ~train_step x
-    ~enc_output =
+let%op cross_attention ~label ~num_heads ~d_attention ?temperature ?(dropout_rate = 0.0) ()
+    ~train_step x ~enc_output =
   let q = { w_q } * x in
   let k = { w_k } * enc_output in
   let v = { w_v } * enc_output in
   let scores =
-    (q +* " ... s | h ..d..; ... t | h ..d.. => ... | s t -> h " [ "h"; "d" ] k) /. sqrt (dim d)
+    (q +* " ... s | h d; ... t | h d => ... s | t -> h " [ "h"; "d" ] k) /. sqrt (dim d)
   in
   Shape.set_dim h num_heads;
-  let attn_weights = softmax ~spec:" ... | ... t -> ..." ?temperature () scores in
+  Shape.set_dim d d_attention;
+  let attn_weights = softmax ~spec:" ... | t -> ..." ?temperature () scores in
   let attn_weights = dropout ~rate:dropout_rate () ~train_step attn_weights in
-  { w_o } * (attn_weights +* " ... | s t -> h; ... t | h ... => ... s | h ... " v)
+  { w_o } * (attn_weights +* " ... s | t -> h; ... t | h ... => ... s | h ... " v)
 
-let%op transformer_decoder_block ~label ~num_heads ~d_ff ?(epsilon = 1e-5) () =
-  let masked_mha = multi_head_attention ~label:(label @ [ "masked_mha" ]) ~num_heads () in
-  let cross_mha = cross_attention ~label:(label @ [ "cross_mha" ]) ~num_heads () in
+let%op transformer_decoder_block ~label ~num_heads ~d_attention ~d_ff ?(epsilon = 1e-5) () =
+  let masked_mha =
+    multi_head_attention ~label:(label @ [ "masked_mha" ]) ~num_heads ~d_attention ()
+  in
+  let cross_mha = cross_attention ~label:(label @ [ "cross_mha" ]) ~num_heads ~d_attention () in
   (* Standard 2-layer FFN: expand to d_ff then contract back to d_model *)
   let ffn = mlp ~label:(label @ [ "ffn" ]) ~hid_dims:[ d_ff ] () in
   let ln1 = layer_norm ~label:(label @ [ "ln1" ]) ~epsilon () in
@@ -121,21 +126,21 @@ let%op transformer_decoder_block ~label ~num_heads ~d_ff ?(epsilon = 1e-5) () =
   let x2 = ln2 (x1 + cross_mha ~train_step x1 ~enc_output) in
   ln3 (x2 + ffn x2)
 
-let transformer_encoder ~label ~num_layers ~num_heads ~d_ff ?(epsilon = 1e-5) () =
+let transformer_encoder ~label ~num_layers ~num_heads ~d_attention ~d_ff ?(epsilon = 1e-5) () =
   let layers =
     List.init num_layers ~f:(fun i ->
         transformer_encoder_block
           ~label:(label @ [ "layer" ^ Int.to_string i ])
-          ~num_heads ~d_ff ~epsilon ())
+          ~num_heads ~d_attention ~d_ff ~epsilon ())
   in
   fun ~train_step x -> List.fold layers ~init:x ~f:(fun x layer -> layer ~train_step x)
 
-let transformer_decoder ~label ~num_layers ~num_heads ~d_ff ?(epsilon = 1e-5) () =
+let transformer_decoder ~label ~num_layers ~num_heads ~d_attention ~d_ff ?(epsilon = 1e-5) () =
   let layers =
     List.init num_layers ~f:(fun i ->
         transformer_decoder_block
           ~label:(label @ [ "layer" ^ Int.to_string i ])
-          ~num_heads ~d_ff ~epsilon ())
+          ~num_heads ~d_attention ~d_ff ~epsilon ())
   in
   fun ~train_step target ~enc_output ~mask ->
     List.fold layers ~init:target ~f:(fun x layer -> layer ~train_step x ~enc_output ~mask)
@@ -144,11 +149,11 @@ let%op transformer ~label ~num_encoder_layers ~num_decoder_layers ~num_heads ~d_
     ?(epsilon = 1e-5) () =
   let encoder =
     transformer_encoder ~label:(label @ [ "encoder" ]) ~num_layers:num_encoder_layers ~num_heads
-      ~d_ff ~epsilon ()
+      ~d_attention:(d_model / num_heads) ~d_ff ~epsilon ()
   in
   let decoder =
     transformer_decoder ~label:(label @ [ "decoder" ]) ~num_layers:num_decoder_layers ~num_heads
-      ~d_ff ~epsilon ()
+      ~d_attention:(d_model / num_heads) ~d_ff ~epsilon ()
   in
   (* All inline definitions, including for ds, dt, are lifted up to the unit parameter above. *)
   Shape.set_dim ds d_model;
