
Commit 143719c

Untested: basic vanilla multi-head attention (no normalization, no dropout)
1 parent fddc88d commit 143719c

File tree

1 file changed: +15 -0 lines changed


lib/nn_blocks.ml

Lines changed: 15 additions & 0 deletions
@@ -9,3 +9,18 @@ let mlp ~label ~hid_dims () =
         mlp_layer ~label:(("L" ^ Int.to_string i) :: label) ~hid_dim ())
   in
   fun x -> List.fold layers ~init:x ~f:(fun x layer -> layer x)
+
+let%op softmax x =
+  let max_vals = x @^^ "...|...t->... => ...|...0->..." in
+  let exp_vals = exp (x - max_vals) in
+  exp_vals /. (exp_vals ++ "...|...t->... => ...|...0->...")
+
+let%op basic_multi_head_attention ~label ~num_heads () x =
+  let q = { w_q } * x in
+  let k = { w_k } * x in
+  let v = { w_v } * x in
+  let scores = q +* "...s|h...; ...t|h... => ...|st->h" [ "h" ] k in
+  Shape.set_dim h num_heads;
+  let attn_weights = softmax scores in
+  let attended = attn_weights +* "...|st->h; ...t|h... => ...s|h..." v in
+  { w_o } * attended
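
For reference, a sketch of the computation this hunk appears to implement, under one reading of the generalized einsum specs (s and t label the query/key token axes, h the head axis fixed to num_heads via Shape.set_dim, d the remaining per-head feature axes; these symbol names outside the specs are mine, not the commit's). The softmax reduces over the axis labeled t in its numerically stable form, and the attention is plain unscaled per-head dot-product attention, with no 1/sqrt(d_k) factor, normalization, or dropout, consistent with the commit message:

\[
\operatorname{softmax}(x)_t \;=\; \frac{\exp\!\left(x_t - \max_{u} x_u\right)}{\sum_{u} \exp\!\left(x_u - \max_{u} x_u\right)}
\]
\[
Q = W_Q\,x, \quad K = W_K\,x, \quad V = W_V\,x,
\qquad
\mathrm{scores}_{s,t,h} \;=\; \sum_{d} Q_{s,h,d}\, K_{t,h,d}
\]
\[
A_{s,\cdot,h} \;=\; \operatorname{softmax}_{t}\!\left(\mathrm{scores}_{s,\cdot,h}\right),
\qquad
\mathrm{attended}_{s,h,d} \;=\; \sum_{t} A_{s,t,h}\, V_{t,h,d},
\qquad
y \;=\; W_O\,\mathrm{attended}
\]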
