Commit 1324aac

Untested: basic transformer and its building blocks, collab with Claude
Future work, by Claude:

1. Positional encoding options: the transformer function uses a learned { pos_encoding } but doesn't offer sinusoidal positional encoding (the original transformer approach). Could add a comment or helper; a host-side sketch follows this list.
2. Embedding initialization: the embedding matrices (src_embed, tgt_embed) use default initialization. Transformers often benefit from specific initialization scales.
3. Dropout locations: while attention dropout is covered, transformers typically also use embedding dropout (after embeddings + position) and residual dropout (after sublayers, before the residual add).
4. Missing GELU activation: modern transformers often use GELU instead of ReLU in the FFN. If OCANNL supports it, it could be worth adding.
5. Causal mask generation: for decoder self-attention, users need to create the causal mask themselves. A helper function might be useful; see the sketch after this list.
6. Output projection initialization: the final w_out in transformer projects to the vocabulary; it often benefits from weights tied with the embeddings or special initialization.
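
Hypothetical sketches for items 1 and 5 (not part of this commit): plain-OCaml, host-side helpers that build the sinusoidal table and a causal 0/1 mask as float arrays. The names sinusoidal_pos_encoding and causal_mask are made up, and how the arrays would be loaded into OCANNL tensors is left open.

open! Base

(* Sinusoidal positional encoding from "Attention Is All You Need":
   PE(pos, 2i) = sin(pos / 10000^(2i/d_model)), PE(pos, 2i+1) = cos(pos / 10000^(2i/d_model)). *)
let sinusoidal_pos_encoding ~max_len ~d_model =
  Array.init max_len ~f:(fun pos ->
      Array.init d_model ~f:(fun i ->
          let exponent = Float.of_int (2 * (i / 2)) /. Float.of_int d_model in
          let angle = Float.of_int pos /. Float.( ** ) 10000. exponent in
          if i % 2 = 0 then Float.sin angle else Float.cos angle))

(* Causal (lower-triangular) mask for decoder self-attention: 1.0 where query position [t]
   may attend to key position [s] (i.e. s <= t), 0.0 elsewhere. Assumes the convention of
   [where mask scores !.(-1e9)] in the diff below, i.e. that [where] keeps [scores] wherever
   the mask is nonzero. *)
let causal_mask ~seq_len =
  Array.init seq_len ~f:(fun t ->
      Array.init seq_len ~f:(fun s -> if s <= t then 1.0 else 0.0))
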
1 parent 48aecf5 commit 1324aac

lib/nn_blocks.ml

Lines changed: 147 additions & 10 deletions
@@ -1,26 +1,163 @@
+(** This file contains basic building blocks for neural networks, with limited functionality. Feel
+    free to copy-paste and modify as needed.
+
+    We follow "the principle of least commitment": where possible, we use row variables to remain
+    agnostic to the number of axes. This flexibility often remains unused, but it makes explicit the
+    architectural structure. *)
+
 open! Base
 open Operation.DSL_modules
+module Tn = Ir.Tnode
 
 let%op mlp_layer ~label ~hid_dim () x = relu (({ w = uniform () } * x) + { b = 0.; o = [ hid_dim ] })
 
-let mlp ~label ~hid_dims () =
+(** Set rate=0.0 during inference. *)
+let%op dropout ~rate () x =
+  if Float.(rate <= 0.0) then x
+  else
+    let keep_prob = 1.0 - !.rate in
+    let mask = !.rate < uniform () *. x in
+    (* Creates 0/1 mask *)
+    (* Scale by 1/keep_prob to maintain expected value *)
+    x *. mask /. keep_prob
+
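For reference (not part of the diff), the 1/keep_prob scaling noted in the comments is the usual inverted-dropout identity: with keep probability p = 1 - rate and a per-element 0/1 mask m,

\mathbb{E}\!\left[\frac{x \odot m}{p}\right] = \frac{x \odot \mathbb{E}[m]}{p} = x, \qquad m_i \sim \mathrm{Bernoulli}(p),\ p = 1 - \mathrm{rate}.
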
+(** Multi-layer perceptron of depth [List.length hid_dims + 1], with a linear output layer. *)
+let%op mlp ~label ~hid_dims () =
   let layers =
     List.mapi hid_dims ~f:(fun i hid_dim ->
         mlp_layer ~label:(("L" ^ Int.to_string i) :: label) ~hid_dim ())
   in
-  fun x -> List.fold layers ~init:x ~f:(fun x layer -> layer x)
+  fun x ->
+    let hidden = List.fold layers ~init:x ~f:(fun x layer -> layer x) in
+    { w_out } * hidden
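
For reference (not part of the diff), with hid_dims = [d_1; ...; d_L] the mlp block computes

h_0 = x, \qquad h_i = \mathrm{ReLU}(W_i h_{i-1} + b_i) \ (i = 1, \dots, L), \qquad \mathrm{mlp}(x) = W_{\mathrm{out}} h_L.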

-let%op softmax x =
-  let max_vals = x @^^ "...|...t->... => ...|...0->..." in
-  let exp_vals = exp (x - max_vals) in
-  exp_vals /. (exp_vals ++ "...|...t->... => ...|...0->...")
+let reduce_specified_axes spec =
+  let lhs =
+    if String.contains spec ',' then
+      Str.global_replace (Str.regexp "[A-Za-z][A-Za-z_0-9]*") "0" spec
+    else Str.global_replace (Str.regexp "[A-Za-z]") "0" spec
+  in
+  spec ^ " => " ^ lhs
+
+(** Softmax across specified axes. Does not support non-default row variables. *)
+let%op softmax ~spec ?(temperature = 1.0) () =
+  let spec = reduce_specified_axes spec in
+  fun x ->
+    let x_scaled = if Float.(temperature <> 1.0) then x /. !.temperature else x in
+    let max_vals = x_scaled @^^ spec in
+    let exp_vals = exp (x_scaled - max_vals) in
+    exp_vals /. (exp_vals ++ spec)
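
For reference (not part of the diff): reduce_specified_axes appends a reduced copy of the spec (e.g. " ... | ... t -> ..." becomes " ... | ... t -> ... => ... | ... 0 -> ...", if I read the Str calls correctly), and the closure then computes a numerically stabilized softmax with temperature T over the reduced axes:

\mathrm{softmax}_T(x)_i = \frac{\exp(x_i / T - m)}{\sum_j \exp(x_j / T - m)}, \qquad m = \max_j x_j / T.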

-let%op basic_multi_head_attention ~label ~num_heads () x =
+let%op multi_head_attention ~label ~num_heads ?temperature ?(dropout_rate = 0.0) () ?mask x =
   let q = { w_q } * x in
   let k = { w_k } * x in
   let v = { w_v } * x in
-  let scores = q +* "...s|h...; ...t|h... => ...|st->h" [ "h" ] k in
+  (* Works with arbitrary number of model axes via `..d..` (row variable syntax). *)
+  let scores =
+    (q +* " ... s | h ..d..; ... t | h ..d.. => ... | s t -> h " [ "h"; "d" ] k) /. sqrt (dim d)
+  in
+  Shape.set_dim h num_heads;
+  (* We don't need to lift [softmax ~spec ()] because it doesn't introduce any new params. *)
+  let attn_weights =
+    softmax ~spec:" ... | ... t -> ..." ?temperature ()
+      (match mask with None -> scores | Some mask -> where mask scores !.(-1e9))
+  in
+  let attn_weights =
+    if Float.(dropout_rate > 0.0) then dropout ~rate:dropout_rate () attn_weights else attn_weights
+  in
+  let attended = attn_weights +* " ... | s t -> h; ... t | h ... => ... s | h ... " v in
+  { w_o } * attended
+
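For reference (not part of the diff), per head this is scaled dot-product attention, where d is the per-head model dimension (the `..d..` axes) and positions rejected by the mask have their scores replaced by -1e9 before the softmax:

\mathrm{Attn}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d}}\right) V,

followed by the output projection w_o mixing the heads.
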
+let%op layer_norm ~label ?(epsilon = 1e-5) () x =
+  let mean = x ++ " ... | ..d.. => ... | 0 " [ "d" ] in
+  let centered = (x - mean) /. dim d in
+  let variance = (centered * centered) ++ " ... | ... => ... | 0 " in
+  let std_dev = sqrt (variance + !.epsilon) in
+  let normalized = centered /. std_dev in
+  (* gamma and beta are learned, but initialized to good defaults *)
+  ({ gamma = 1. } *. normalized) + { beta = 0. }
+
+let%op transformer_encoder_block ~label ~num_heads ~d_ff ?(epsilon = 1e-5) () =
+  let mha = multi_head_attention ~label:(label @ [ "mha" ]) ~num_heads () in
+  (* Standard 2-layer FFN: expand to d_ff then contract back to d_model *)
+  let ffn = mlp ~label:(label @ [ "ffn" ]) ~hid_dims:[ d_ff ] () in
+  let ln1 = layer_norm ~label:(label @ [ "ln1" ]) ~epsilon () in
+  let ln2 = layer_norm ~label:(label @ [ "ln2" ]) ~epsilon () in
+  fun input ->
+    let attn_output = mha input in
+    let x1 = ln1 (input + attn_output) in
+    let ffn_output = ffn x1 in
+    ln2 (x1 + ffn_output)
+
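For reference (not part of the diff), this is the post-norm residual layout of the original transformer encoder layer:

x_1 = \mathrm{LN}_1(x + \mathrm{MHA}(x)), \qquad \mathrm{out} = \mathrm{LN}_2(x_1 + \mathrm{FFN}(x_1)).
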
+let%op cross_attention ~label ~num_heads ?temperature ?(dropout_rate = 0.0) () x ~enc_output =
+  let q = { w_q } * x in
+  let k = { w_k } * enc_output in
+  let v = { w_v } * enc_output in
+  let scores =
+    (q +* " ... s | h ..d..; ... t | h ..d.. => ... | s t -> h " [ "h"; "d" ] k) /. sqrt (dim d)
+  in
   Shape.set_dim h num_heads;
-  let attn_weights = softmax scores in
-  let attended = attn_weights +* "...|st->h; ...t|h... => ...s|h..." v in
+  let attn_weights = softmax ~spec:" ... | ... t -> ..." ?temperature () scores in
+  let attn_weights =
+    if Float.(dropout_rate > 0.0) then dropout ~rate:dropout_rate () attn_weights else attn_weights
+  in
+  let attended = attn_weights +* " ... | s t -> h; ... t | h ... => ... s | h ... " v in
   { w_o } * attended
+
+let%op transformer_decoder_block ~label ~num_heads ~d_ff ?(epsilon = 1e-5) () =
+  let masked_mha = multi_head_attention ~label:(label @ [ "masked_mha" ]) ~num_heads () in
+  let cross_mha = cross_attention ~label:(label @ [ "cross_mha" ]) ~num_heads () in
+  (* Standard 2-layer FFN: expand to d_ff then contract back to d_model *)
+  let ffn = mlp ~label:(label @ [ "ffn" ]) ~hid_dims:[ d_ff ] () in
+  let ln1 = layer_norm ~label:(label @ [ "ln1" ]) ~epsilon () in
+  let ln2 = layer_norm ~label:(label @ [ "ln2" ]) ~epsilon () in
+  let ln3 = layer_norm ~label:(label @ [ "ln3" ]) ~epsilon () in
+  fun target ~enc_output ~mask ->
+    let self_attn_output = masked_mha ~mask target in
+    let x1 = ln1 (target + self_attn_output) in
+    let cross_attn_output = cross_mha x1 ~enc_output in
+    let x2 = ln2 (x1 + cross_attn_output) in
+    let ffn_output = ffn x2 in
+    ln3 (x2 + ffn_output)
+
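For reference (not part of the diff), the decoder layer wraps masked self-attention and encoder-decoder cross-attention in the same post-norm residual pattern:

x_1 = \mathrm{LN}_1(t + \mathrm{SelfAttn}_{\mathrm{mask}}(t)), \quad x_2 = \mathrm{LN}_2(x_1 + \mathrm{CrossAttn}(x_1, \mathrm{enc})), \quad \mathrm{out} = \mathrm{LN}_3(x_2 + \mathrm{FFN}(x_2)).
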
+let transformer_encoder ~label ~num_layers ~num_heads ~d_ff ?(epsilon = 1e-5) () =
+  let layers =
+    List.init num_layers ~f:(fun i ->
+        transformer_encoder_block
+          ~label:(label @ [ "layer" ^ Int.to_string i ])
+          ~num_heads ~d_ff ~epsilon ())
+  in
+  fun x -> List.fold layers ~init:x ~f:(fun x layer -> layer x)
+
+let transformer_decoder ~label ~num_layers ~num_heads ~d_ff ?(epsilon = 1e-5) () =
+  let layers =
+    List.init num_layers ~f:(fun i ->
+        transformer_decoder_block
+          ~label:(label @ [ "layer" ^ Int.to_string i ])
+          ~num_heads ~d_ff ~epsilon ())
+  in
+  fun target ~enc_output ~mask ->
+    List.fold layers ~init:target ~f:(fun x layer -> layer x ~enc_output ~mask)
+
+let%op transformer ~label ~num_encoder_layers ~num_decoder_layers ~num_heads ~d_model ~d_ff
+    ?(epsilon = 1e-5) () =
+  let encoder =
+    transformer_encoder ~label:(label @ [ "encoder" ]) ~num_layers:num_encoder_layers ~num_heads
+      ~d_ff ~epsilon ()
+  in
+  let decoder =
+    transformer_decoder ~label:(label @ [ "decoder" ]) ~num_layers:num_decoder_layers ~num_heads
+      ~d_ff ~epsilon ()
+  in
+  (* All inline definitions, including for d, are lifted up to the unit parameter above. *)
+  Shape.set_dim d d_model;
+  fun src tgt mask ->
+    (* Learned positional encoding *)
+    let enc_output =
+      encoder
+        (src +* " ... s | ..v.. ; ..v.. -> d => ... s | d " [ "d" ] { src_embed } + { pos_encoding })
+    in
+    let tgt_embedded =
+      tgt +* " ... t | ..v.. ; ..v.. -> d => ... t | d " { tgt_embed } + pos_encoding
+    in
+    decoder tgt_embedded ~enc_output ~mask +* " ... | d; d -> ..v.. => ... | ..v.. " { w_out }
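
For orientation (not part of the diff), a hypothetical instantiation of the new top-level entry point, assuming the [%op] extension keeps the labelled-argument signature as written; the hyperparameter values and the "demo" label are made up, and constructing the actual src/tgt/mask tensors is outside the scope of this file.

let demo_transformer =
  transformer ~label:[ "demo" ] ~num_encoder_layers:2 ~num_decoder_layers:2 ~num_heads:4
    ~d_model:64 ~d_ff:256 ()

(* [demo_transformer] is the [fun src tgt mask -> ...] closure from the definition above:
   applying it to source/target token tensors and a causal mask yields scores over the
   target vocabulary axes via the final [w_out] projection. *)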
