@@ -58,22 +58,24 @@ let%op softmax ~spec ?(temperature = 1.0) () =
   let exp_vals = exp (x_scaled - max_vals) in
   exp_vals /. (exp_vals ++ spec)
 
-let%op multi_head_attention ~label ~num_heads ?temperature ?(dropout_rate = 0.0) () ~train_step
-    ?mask x =
+let%op multi_head_attention ~label ~num_heads ~d_attention ?temperature ?(dropout_rate = 0.0) ()
+    ~train_step ?mask x =
   let q = { w_q } * x in
   let k = { w_k } * x in
   let v = { w_v } * x in
-  (* Works with arbitrary number of model axes via `..d..` (row variable syntax). *)
   let scores =
-    (q +* "... s | h ..d..; ... t | h ..d.. => ... s | t -> h" [ "h"; "d" ] k) /. sqrt (dim d)
+    (q +* "... s | h d; ... t | h d => ... s | t -> h" [ "h"; "d" ] k) /. sqrt (dim d)
   in
   Shape.set_dim h num_heads;
+  (* NOTE: often d_attention = d_model / num_heads, but we allow for other values. *)
+  Shape.set_dim d d_attention;
   (* We don't need to lift [softmax ~spec ()] because it doesn't introduce any new params. *)
   let attn_weights =
     softmax ~spec:"... | t -> ..." ?temperature ()
       (match mask with None -> scores | Some mask -> where mask scores !.(-1e9))
   in
   let attn_weights = dropout ~rate:dropout_rate () ~train_step attn_weights in
+  (* w_o output shape will automatically be set to the model dimension(s) by shape inference. *)
   { w_o } * (attn_weights +* "... s | t -> h; ... t | h ... => ... s | h ..." v)
 
 let%op layer_norm ~label ?(epsilon = 1e-5) () x =
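For reference, a minimal usage sketch of the updated `multi_head_attention` signature, mirroring how `transformer_encoder_block` applies it further down; the label and the values num_heads = 8 and d_attention = 64 are illustrative assumptions, not taken from this commit:

  (* Illustrative only; assumes x, train_step, and the %op scaffolding of this file. *)
  let mha = multi_head_attention ~label:[ "demo"; "mha" ] ~num_heads:8 ~d_attention:64 () in
  (* Residual connection around the attention output, as in the encoder block. *)
  x + mha ~train_step x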
@@ -85,8 +87,8 @@ let%op layer_norm ~label ?(epsilon = 1e-5) () x =
   (* gamma and beta are learned, but initialized to good defaults *)
   ({ gamma = 1. } *. normalized) + { beta = 0. }
 
-let%op transformer_encoder_block ~label ~num_heads ~d_ff ?(epsilon = 1e-5) () =
-  let mha = multi_head_attention ~label:(label @ [ "mha" ]) ~num_heads () in
+let%op transformer_encoder_block ~label ~num_heads ~d_attention ~d_ff ?(epsilon = 1e-5) () =
+  let mha = multi_head_attention ~label:(label @ [ "mha" ]) ~num_heads ~d_attention () in
   (* Standard 2-layer FFN: expand to d_ff then contract back to d_model *)
   let ffn = mlp ~label:(label @ [ "ffn" ]) ~hid_dims:[ d_ff ] () in
   let ln1 = layer_norm ~label:(label @ [ "ln1" ]) ~epsilon () in
@@ -95,22 +97,25 @@ let%op transformer_encoder_block ~label ~num_heads ~d_ff ?(epsilon = 1e-5) () =
     let x1 = ln1 (input + mha ~train_step input) in
     ln2 (x1 + ffn x1)
 
-let%op cross_attention ~label ~num_heads ?temperature ?(dropout_rate = 0.0) () ~train_step x
-    ~enc_output =
+let%op cross_attention ~label ~num_heads ~d_attention ?temperature ?(dropout_rate = 0.0) ()
+    ~train_step x ~enc_output =
   let q = { w_q } * x in
   let k = { w_k } * enc_output in
   let v = { w_v } * enc_output in
   let scores =
-    (q +* "... s | h ..d..; ... t | h ..d.. => ... | s t -> h" [ "h"; "d" ] k) /. sqrt (dim d)
+    (q +* "... s | h d; ... t | h d => ... s | t -> h" [ "h"; "d" ] k) /. sqrt (dim d)
   in
   Shape.set_dim h num_heads;
-  let attn_weights = softmax ~spec:"... | ... t -> ..." ?temperature () scores in
+  Shape.set_dim d d_attention;
+  let attn_weights = softmax ~spec:"... | t -> ..." ?temperature () scores in
   let attn_weights = dropout ~rate:dropout_rate () ~train_step attn_weights in
-  { w_o } * (attn_weights +* "... | s t -> h; ... t | h ... => ... s | h ..." v)
+  { w_o } * (attn_weights +* "... s | t -> h; ... t | h ... => ... s | h ..." v)
 
-let%op transformer_decoder_block ~label ~num_heads ~d_ff ?(epsilon = 1e-5) () =
-  let masked_mha = multi_head_attention ~label:(label @ [ "masked_mha" ]) ~num_heads () in
-  let cross_mha = cross_attention ~label:(label @ [ "cross_mha" ]) ~num_heads () in
+let%op transformer_decoder_block ~label ~num_heads ~d_attention ~d_ff ?(epsilon = 1e-5) () =
+  let masked_mha =
+    multi_head_attention ~label:(label @ [ "masked_mha" ]) ~num_heads ~d_attention ()
+  in
+  let cross_mha = cross_attention ~label:(label @ [ "cross_mha" ]) ~num_heads ~d_attention () in
   (* Standard 2-layer FFN: expand to d_ff then contract back to d_model *)
   let ffn = mlp ~label:(label @ [ "ffn" ]) ~hid_dims:[ d_ff ] () in
   let ln1 = layer_norm ~label:(label @ [ "ln1" ]) ~epsilon () in
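A hypothetical standalone instantiation of the decoder block with the new parameter, analogous to how `transformer_decoder` constructs its layers below; the numeric values are illustrative assumptions:

  (* Illustrative values: 8 heads of width 64 and a 2048-wide feed-forward expansion. *)
  let decoder_block =
    transformer_decoder_block ~label:[ "demo"; "dec0" ] ~num_heads:8 ~d_attention:64 ~d_ff:2048 ()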
@@ -121,21 +126,21 @@ let%op transformer_decoder_block ~label ~num_heads ~d_ff ?(epsilon = 1e-5) () =
     let x2 = ln2 (x1 + cross_mha ~train_step x1 ~enc_output) in
     ln3 (x2 + ffn x2)
 
-let transformer_encoder ~label ~num_layers ~num_heads ~d_ff ?(epsilon = 1e-5) () =
+let transformer_encoder ~label ~num_layers ~num_heads ~d_attention ~d_ff ?(epsilon = 1e-5) () =
   let layers =
     List.init num_layers ~f:(fun i ->
         transformer_encoder_block
           ~label:(label @ [ "layer" ^ Int.to_string i ])
-          ~num_heads ~d_ff ~epsilon ())
+          ~num_heads ~d_attention ~d_ff ~epsilon ())
   in
   fun ~train_step x -> List.fold layers ~init:x ~f:(fun x layer -> layer ~train_step x)
 
-let transformer_decoder ~label ~num_layers ~num_heads ~d_ff ?(epsilon = 1e-5) () =
+let transformer_decoder ~label ~num_layers ~num_heads ~d_attention ~d_ff ?(epsilon = 1e-5) () =
   let layers =
     List.init num_layers ~f:(fun i ->
         transformer_decoder_block
           ~label:(label @ [ "layer" ^ Int.to_string i ])
-          ~num_heads ~d_ff ~epsilon ())
+          ~num_heads ~d_attention ~d_ff ~epsilon ())
   in
   fun ~train_step target ~enc_output ~mask ->
     List.fold layers ~init:target ~f:(fun x layer -> layer ~train_step x ~enc_output ~mask)
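The stacking helpers simply thread `~d_attention` through to every block; as a hedged sketch with assumed values, a two-layer encoder built this way folds the input through layer 0 and then layer 1:

  (* Illustrative values, not from this commit. *)
  let encoder =
    transformer_encoder ~label:[ "demo"; "enc" ] ~num_layers:2 ~num_heads:8 ~d_attention:64
      ~d_ff:2048 ()
  (* [encoder ~train_step x] then applies both blocks in order via List.fold. *)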
@@ -144,11 +149,11 @@ let%op transformer ~label ~num_encoder_layers ~num_decoder_layers ~num_heads ~d_
     ?(epsilon = 1e-5) () =
   let encoder =
     transformer_encoder ~label:(label @ [ "encoder" ]) ~num_layers:num_encoder_layers ~num_heads
-      ~d_ff ~epsilon ()
+      ~d_attention:(d_model / num_heads) ~d_ff ~epsilon ()
   in
   let decoder =
     transformer_decoder ~label:(label @ [ "decoder" ]) ~num_layers:num_decoder_layers ~num_heads
-      ~d_ff ~epsilon ()
+      ~d_attention:(d_model / num_heads) ~d_ff ~epsilon ()
   in
   (* All inline definitions, including for ds, dt, are lifted up to the unit parameter above. *)
   Shape.set_dim ds d_model;
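The top-level `transformer` keeps the conventional per-head width by passing `~d_attention:(d_model / num_heads)`. A worked sketch of that arithmetic with assumed values; note that OCaml's `(/)` is integer division, so a d_model not divisible by num_heads silently rounds the head width down:

  (* Illustrative numbers, not from this commit. *)
  let d_model = 512
  let num_heads = 8
  let d_attention = d_model / num_heads (* = 64; 513 / 8 would also give 64 since (/) truncates *)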