(** This file contains basic building blocks for neural networks, with limited functionality. Feel
    free to copy-paste and modify as needed.

    We follow "the principle of least commitment": where possible, we use row variables to remain
    agnostic to the number of axes. This flexibility often remains unused, but it makes the
    architectural structure explicit. *)

open! Base
open Operation.DSL_modules
module Tn = Ir.Tnode

let%op mlp_layer ~label ~hid_dim () x = relu (({ w = uniform () } * x) + { b = 0.; o = [ hid_dim ] })
6 | | -let mlp ~label ~hid_dims () = |

(** Set rate=0.0 during inference. *)
let%op dropout ~rate () x =
  if Float.(rate <= 0.0) then x
  else
    let keep_prob = 1.0 - !.rate in
    (* Creates 0/1 mask *)
    let mask = !.rate < uniform () *. x in
    (* Scale by 1/keep_prob to maintain expected value *)
    x *. mask /. keep_prob
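
(* Usage sketch (the 0.1 below is an arbitrary example rate): during training apply
   [dropout ~rate:0.1 () activations]; at inference use [~rate:0.0], which is the identity. *)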

(** Multi-layer perceptron of depth [List.length hid_dims + 1], with a linear output layer. *)
let%op mlp ~label ~hid_dims () =
  let layers =
    List.mapi hid_dims ~f:(fun i hid_dim ->
        mlp_layer ~label:(("L" ^ Int.to_string i) :: label) ~hid_dim ())
  in
  fun x ->
    let hidden = List.fold layers ~init:x ~f:(fun x layer -> layer x) in
    { w_out } * hidden
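
(* Usage sketch (layer widths are arbitrary example values):
   [let my_mlp = mlp ~label:[ "my_mlp" ] ~hid_dims:[ 64; 32 ] () in my_mlp x] builds and applies a
   3-layer network: two hidden layers plus the linear output layer. *)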

let reduce_specified_axes spec =
  let lhs =
    if String.contains spec ',' then
      Str.global_replace (Str.regexp "[A-Za-z][A-Za-z_0-9]*") "0" spec
    else Str.global_replace (Str.regexp "[A-Za-z]") "0" spec
  in
  spec ^ " => " ^ lhs
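
(* Worked example (derived from the definition above):
   [reduce_specified_axes " ... | ... t -> ..."] yields, up to spacing,
   [" ... | ... t -> ... => ... | ... 0 -> ..."], i.e. a spec that reduces over the [t] axis. *)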

(** Softmax across specified axes. Does not support non-default row variables. *)
let%op softmax ~spec ?(temperature = 1.0) () =
  let spec = reduce_specified_axes spec in
  fun x ->
    let x_scaled = if Float.(temperature <> 1.0) then x /. !.temperature else x in
    let max_vals = x_scaled @^^ spec in
    let exp_vals = exp (x_scaled - max_vals) in
    exp_vals /. (exp_vals ++ spec)
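
(* Usage sketch: [softmax ~spec:" ... | ... t -> ..." () scores] normalizes over the [t] axis;
   this is exactly how the attention blocks below use it (optionally with [?temperature]). *)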

(** Multi-head self-attention with an optional attention mask and optional attention dropout. *)
let%op multi_head_attention ~label ~num_heads ?temperature ?(dropout_rate = 0.0) () ?mask x =
  let q = { w_q } * x in
  let k = { w_k } * x in
  let v = { w_v } * x in
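  (* Axis conventions assumed by the specs below: [s] indexes query positions, [t] key/value
     positions, [h] heads, and [..d..] the remaining model axes. *)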
  (* Works with an arbitrary number of model axes via `..d..` (row variable syntax). *)
  let scores =
    (q +* " ... s | h ..d..; ... t | h ..d.. => ... | s t -> h " [ "h"; "d" ] k) /. sqrt (dim d)
  in
  Shape.set_dim h num_heads;
  (* We don't need to lift [softmax ~spec ()] because it doesn't introduce any new params. *)
  let attn_weights =
    softmax ~spec:" ... | ... t -> ..." ?temperature ()
      (match mask with None -> scores | Some mask -> where mask scores !.(-1e9))
  in
  let attn_weights =
    if Float.(dropout_rate > 0.0) then dropout ~rate:dropout_rate () attn_weights else attn_weights
  in
  let attended = attn_weights +* " ... | s t -> h; ... t | h ... => ... s | h ... " v in
  { w_o } * attended

let%op layer_norm ~label ?(epsilon = 1e-5) () x =
  (* Mean and variance are taken over the output axes [..d..]. *)
  let mean = (x ++ " ... | ..d.. => ... | 0 " [ "d" ]) /. dim d in
  let centered = x - mean in
  let variance = ((centered *. centered) ++ " ... | ... => ... | 0 ") /. dim d in
  let std_dev = sqrt (variance + !.epsilon) in
  let normalized = centered /. std_dev in
  (* gamma and beta are learned, but initialized to good defaults *)
  ({ gamma = 1. } *. normalized) + { beta = 0. }
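
(* Usage sketch: [layer_norm ~label:[ "ln" ] () activations] normalizes each position across its
   output axes; [epsilon] keeps the division numerically safe. *)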

(** A single transformer encoder block: self-attention followed by a feed-forward network, each
    wrapped in a residual connection and layer normalization. *)
let%op transformer_encoder_block ~label ~num_heads ~d_ff ?(epsilon = 1e-5) () =
  let mha = multi_head_attention ~label:(label @ [ "mha" ]) ~num_heads () in
  (* Standard 2-layer FFN: expand to d_ff then contract back to d_model *)
  let ffn = mlp ~label:(label @ [ "ffn" ]) ~hid_dims:[ d_ff ] () in
  let ln1 = layer_norm ~label:(label @ [ "ln1" ]) ~epsilon () in
  let ln2 = layer_norm ~label:(label @ [ "ln2" ]) ~epsilon () in
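  (* Post-LN arrangement: layer normalization is applied after each residual connection, as in the
     original Transformer. *)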
  fun input ->
    let attn_output = mha input in
    let x1 = ln1 (input + attn_output) in
    let ffn_output = ffn x1 in
    ln2 (x1 + ffn_output)

(** Like [multi_head_attention], but keys and values are computed from [enc_output] while queries
    come from [x]. *)
let%op cross_attention ~label ~num_heads ?temperature ?(dropout_rate = 0.0) () x ~enc_output =
  let q = { w_q } * x in
  let k = { w_k } * enc_output in
  let v = { w_v } * enc_output in
  let scores =
    (q +* " ... s | h ..d..; ... t | h ..d.. => ... | s t -> h " [ "h"; "d" ] k) /. sqrt (dim d)
  in
  Shape.set_dim h num_heads;
  let attn_weights = softmax ~spec:" ... | ... t -> ..." ?temperature () scores in
  let attn_weights =
    if Float.(dropout_rate > 0.0) then dropout ~rate:dropout_rate () attn_weights else attn_weights
  in
  let attended = attn_weights +* " ... | s t -> h; ... t | h ... => ... s | h ... " v in
  { w_o } * attended

(** A single transformer decoder block: masked self-attention, cross-attention over the encoder
    output, and a feed-forward network, each with a residual connection and layer normalization. *)
let%op transformer_decoder_block ~label ~num_heads ~d_ff ?(epsilon = 1e-5) () =
  let masked_mha = multi_head_attention ~label:(label @ [ "masked_mha" ]) ~num_heads () in
  let cross_mha = cross_attention ~label:(label @ [ "cross_mha" ]) ~num_heads () in
  (* Standard 2-layer FFN: expand to d_ff then contract back to d_model *)
  let ffn = mlp ~label:(label @ [ "ffn" ]) ~hid_dims:[ d_ff ] () in
  let ln1 = layer_norm ~label:(label @ [ "ln1" ]) ~epsilon () in
  let ln2 = layer_norm ~label:(label @ [ "ln2" ]) ~epsilon () in
  let ln3 = layer_norm ~label:(label @ [ "ln3" ]) ~epsilon () in
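  (* [mask] is typically a causal (look-ahead) mask that blocks attention to future target
     positions. *)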
  fun target ~enc_output ~mask ->
    let self_attn_output = masked_mha ~mask target in
    let x1 = ln1 (target + self_attn_output) in
    let cross_attn_output = cross_mha x1 ~enc_output in
    let x2 = ln2 (x1 + cross_attn_output) in
    let ffn_output = ffn x2 in
    ln3 (x2 + ffn_output)

(** A stack of [num_layers] transformer encoder blocks. *)
let transformer_encoder ~label ~num_layers ~num_heads ~d_ff ?(epsilon = 1e-5) () =
  let layers =
    List.init num_layers ~f:(fun i ->
        transformer_encoder_block
          ~label:(label @ [ "layer" ^ Int.to_string i ])
          ~num_heads ~d_ff ~epsilon ())
  in
  fun x -> List.fold layers ~init:x ~f:(fun x layer -> layer x)

(** A stack of [num_layers] transformer decoder blocks. *)
let transformer_decoder ~label ~num_layers ~num_heads ~d_ff ?(epsilon = 1e-5) () =
  let layers =
    List.init num_layers ~f:(fun i ->
        transformer_decoder_block
          ~label:(label @ [ "layer" ^ Int.to_string i ])
          ~num_heads ~d_ff ~epsilon ())
  in
  fun target ~enc_output ~mask ->
    List.fold layers ~init:target ~f:(fun x layer -> layer x ~enc_output ~mask)

(** A full encoder-decoder transformer with learned input embeddings and learned positional
    encodings; the inputs and the output share the [..v..] (e.g. vocabulary) axes. *)
let%op transformer ~label ~num_encoder_layers ~num_decoder_layers ~num_heads ~d_model ~d_ff
    ?(epsilon = 1e-5) () =
  let encoder =
    transformer_encoder ~label:(label @ [ "encoder" ]) ~num_layers:num_encoder_layers ~num_heads
      ~d_ff ~epsilon ()
  in
  let decoder =
    transformer_decoder ~label:(label @ [ "decoder" ]) ~num_layers:num_decoder_layers ~num_heads
      ~d_ff ~epsilon ()
  in
  (* All inline definitions, including the one for [d], are lifted up to the unit parameter above. *)
  Shape.set_dim d d_model;
  fun src tgt mask ->
    (* Learned positional encoding, shared between the encoder and decoder inputs *)
    let enc_output =
      encoder
        (src +* " ... s | ..v.. ; ..v.. -> d => ... s | d " [ "d" ] { src_embed } + { pos_encoding })
    in
    let tgt_embedded =
      tgt +* " ... t | ..v.. ; ..v.. -> d => ... t | d " { tgt_embed } + pos_encoding
    in
    decoder tgt_embedded ~enc_output ~mask +* " ... | d; d -> ..v.. => ... | ..v.. " { w_out }
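
(* Instantiation sketch (all hyperparameters are arbitrary example values):
   [transformer ~label:[ "xf" ] ~num_encoder_layers:2 ~num_decoder_layers:2 ~num_heads:4
      ~d_model:64 ~d_ff:256 ()]
   yields a function [fun src tgt mask -> ...] whose output shares the [..v..] axes of the
   inputs. *)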