File tree Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Expand file tree Collapse file tree 1 file changed +15
-0
lines changed Original file line number Diff line number Diff line change @@ -9,3 +9,18 @@ let mlp ~label ~hid_dims () =
99 mlp_layer ~label:((" L" ^ Int.to_string i) :: label) ~hid_dim ())
1010 in
1111 fun x -> List.fold layers ~init:x ~f:(fun x layer -> layer x)
12+
13+ let %op softmax x =
14+ let max_vals = x @^^ " ...|...t->... => ...|...0->..." in
15+ let exp_vals = exp (x - max_vals) in
16+ exp_vals /. (exp_vals ++ " ...|...t->... => ...|...0->..." )
17+
18+ let %op basic_multi_head_attention ~label ~num_heads () x =
19+ let q = { w_q } * x in
20+ let k = { w_k } * x in
21+ let v = { w_v } * x in
22+ let scores = q +* " ...s|h...; ...t|h... => ...|st->h" [ " h" ] k in
23+ Shape.set_dim h num_heads;
24+ let attn_weights = softmax scores in
25+ let attended = attn_weights +* " ...|st->h; ...t|h... => ...s|h..." v in
26+ { w_o } * attended
You can’t perform that action at this time.
0 commit comments