@@ -89,11 +89,11 @@ let%op layer_norm ~label ?(epsilon = 1e-5) () x =
   ({ gamma = 1. } *. normalized) + { beta = 0. }
 
 let%op transformer_encoder_block ~label ~num_heads ~d_k ~d_v ~d_ff ?(epsilon = 1e-5) () =
-  let mha = multi_head_attention ~label:(label @ [ "mha" ]) ~num_heads ~d_k ~d_v () in
+  let mha = multi_head_attention ~label:("mha" :: label) ~num_heads ~d_k ~d_v () in
   (* Standard 2-layer FFN: expand to d_ff then contract back to d_model *)
-  let ffn = mlp ~label:(label @ [ "ffn" ]) ~hid_dims:[ d_ff ] () in
-  let ln1 = layer_norm ~label:(label @ [ "ln1" ]) ~epsilon () in
-  let ln2 = layer_norm ~label:(label @ [ "ln2" ]) ~epsilon () in
+  let ffn = mlp ~label:("ffn" :: label) ~hid_dims:[ d_ff ] () in
+  let ln1 = layer_norm ~label:("ln1" :: label) ~epsilon () in
+  let ln2 = layer_norm ~label:("ln2" :: label) ~epsilon () in
   fun ~train_step input ->
     let x1 = ln1 (input + mha ~train_step input) in
     ln2 (x1 + ffn x1)
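
A minimal sketch of what the append-to-cons change above does to the label lists (the root label here is hypothetical; how OCANNL renders these lists is not shown in this diff):

    (* Old convention: list append, O(n) copy per nesting level, root-first paths. *)
    let old_path = [ "block0"; "model" ] @ [ "mha" ]   (* [ "block0"; "model"; "mha" ] *)
    (* New convention: cons, O(1) per nesting level, leaf-first paths. *)
    let new_path = "mha" :: [ "block0"; "model" ]      (* [ "mha"; "block0"; "model" ] *)
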
@@ -114,13 +114,13 @@ let%op cross_attention ~label ~num_heads ~d_k ~d_v ?temperature ?(dropout_rate =
   { w_o } * (attn_weights +* v "... s | t -> h; ... t | h e => ... s | h e" [ "e" ])
 
 let%op transformer_decoder_block ~label ~num_heads ~d_k ~d_v ~d_ff ?(epsilon = 1e-5) () =
-  let masked_mha = multi_head_attention ~label:(label @ [ "masked_mha" ]) ~num_heads ~d_k ~d_v () in
-  let cross_mha = cross_attention ~label:(label @ [ "cross_mha" ]) ~num_heads ~d_k ~d_v () in
+  let masked_mha = multi_head_attention ~label:("masked_mha" :: label) ~num_heads ~d_k ~d_v () in
+  let cross_mha = cross_attention ~label:("cross_mha" :: label) ~num_heads ~d_k ~d_v () in
   (* Standard 2-layer FFN: expand to d_ff then contract back to d_model *)
-  let ffn = mlp ~label:(label @ [ "ffn" ]) ~hid_dims:[ d_ff ] () in
-  let ln1 = layer_norm ~label:(label @ [ "ln1" ]) ~epsilon () in
-  let ln2 = layer_norm ~label:(label @ [ "ln2" ]) ~epsilon () in
-  let ln3 = layer_norm ~label:(label @ [ "ln3" ]) ~epsilon () in
+  let ffn = mlp ~label:("ffn" :: label) ~hid_dims:[ d_ff ] () in
+  let ln1 = layer_norm ~label:("ln1" :: label) ~epsilon () in
+  let ln2 = layer_norm ~label:("ln2" :: label) ~epsilon () in
+  let ln3 = layer_norm ~label:("ln3" :: label) ~epsilon () in
   fun ~train_step target ~enc_output ~mask ->
     let x1 = ln1 (target + masked_mha ~train_step ~mask target) in
     let x2 = ln2 (x1 + cross_mha ~train_step x1 ~enc_output) in
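
A hypothetical wiring of one decoder block, to make the shape of the returned closure concrete (the dimension values are made up):

    let block = transformer_decoder_block ~label:[ "dec0" ] ~num_heads:8 ~d_k:64 ~d_v:64 ~d_ff:2048 ()
    (* block ~train_step target ~enc_output ~mask runs masked self-attention,
       cross-attention over enc_output, and the FFN, each with residual + layer norm. *)
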
@@ -130,7 +130,7 @@ let transformer_encoder ~label ~num_layers ~num_heads ~d_k ~d_v ~d_ff ?(epsilon
   let layers =
     List.init num_layers ~f:(fun i ->
         transformer_encoder_block
-          ~label:(label @ [ "layer" ^ Int.to_string i ])
+          ~label:(("layer" ^ Int.to_string i) :: label)
           ~num_heads ~d_k ~d_v ~d_ff ~epsilon ())
   in
   fun ~train_step x -> List.fold layers ~init:x ~f:(fun x layer -> layer ~train_step x)
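
For num_layers = 2 the fold above amounts to plain composition; a sketch, where layer0 and layer1 stand for the closures built by List.init:

    (* Equivalent unrolled form for two encoder layers: *)
    let two_layer ~train_step x = layer1 ~train_step (layer0 ~train_step x)
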
@@ -139,7 +139,7 @@ let transformer_decoder ~label ~num_layers ~num_heads ~d_k ~d_v ~d_ff ?(epsilon
   let layers =
     List.init num_layers ~f:(fun i ->
         transformer_decoder_block
-          ~label:(label @ [ "layer" ^ Int.to_string i ])
+          ~label:(("layer" ^ Int.to_string i) :: label)
           ~num_heads ~d_k ~d_v ~d_ff ~epsilon ())
   in
   fun ~train_step target ~enc_output ~mask ->
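
The fold body is cut off by this hunk; by analogy with transformer_encoder it presumably threads enc_output and mask through every layer, along these lines:

    (* Presumed tail of transformer_decoder (not shown in the hunk): *)
    List.fold layers ~init:target ~f:(fun t layer -> layer ~train_step t ~enc_output ~mask)
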
@@ -150,11 +150,11 @@ let%op transformer ~label ~num_encoder_layers ~num_decoder_layers ~num_heads ~d_
   let enc_att = [%oc d_enc / num_heads] in
   let dec_att = [%oc d_dec / num_heads] in
   let encoder =
-    transformer_encoder ~label:(label @ [ "encoder" ]) ~num_layers:num_encoder_layers ~num_heads
+    transformer_encoder ~label:("encoder" :: label) ~num_layers:num_encoder_layers ~num_heads
       ~d_k:enc_att ~d_v:enc_att ~d_ff ~epsilon ()
   in
   let decoder =
-    transformer_decoder ~label:(label @ [ "decoder" ]) ~num_layers:num_decoder_layers ~num_heads
+    transformer_decoder ~label:("decoder" :: label) ~num_layers:num_decoder_layers ~num_heads
       ~d_k:dec_att ~d_v:dec_att ~d_ff ~epsilon ()
   in
   (* All inline definitions, including for ds, dt, are lifted up to the unit parameter above. *)
@@ -249,22 +249,22 @@ let%op batch_norm2d ~label ?(epsilon = 1e-5) ?(momentum = 0.9) () ~train_step x
 
 (** Conv block with conv -> batch norm -> activation pattern *)
 let%op conv_bn_relu ~label ?(kernel_size = 3) ?(stride = 1) () =
-  let conv = conv2d ~label:(label @ [ "conv" ]) ~kernel_size ~stride () in
-  let bn = batch_norm2d ~label:(label @ [ "bn" ]) () in
+  let conv = conv2d ~label:("conv" :: label) ~kernel_size ~stride () in
+  let bn = batch_norm2d ~label:("bn" :: label) () in
   fun ~train_step x -> relu (bn ~train_step (conv x))
 
 (** Residual block for ResNet-style architectures. Features skip connections that help with gradient
     flow in deep networks. *)
 let%op resnet_block ~label ?(stride = 1) () =
-  let conv1 = conv2d ~label:(label @ [ "conv1" ]) ~kernel_size:3 ~stride () in
-  let bn1 = batch_norm2d ~label:(label @ [ "bn1" ]) () in
-  let conv2 = conv2d ~label:(label @ [ "conv2" ]) ~kernel_size:3 ~stride:1 () in
-  let bn2 = batch_norm2d ~label:(label @ [ "bn2" ]) () in
+  let conv1 = conv2d ~label:("conv1" :: label) ~kernel_size:3 ~stride () in
+  let bn1 = batch_norm2d ~label:("bn1" :: label) () in
+  let conv2 = conv2d ~label:("conv2" :: label) ~kernel_size:3 ~stride:1 () in
+  let bn2 = batch_norm2d ~label:("bn2" :: label) () in
   let identity =
     if stride > 1 then
       (* Need to downsample the skip connection *)
-      let downsample_conv = conv2d ~label:(label @ [ "downsample" ]) ~kernel_size:1 ~stride () in
-      let downsample_bn = batch_norm2d ~label:(label @ [ "downsample_bn" ]) () in
+      let downsample_conv = conv2d ~label:("downsample" :: label) ~kernel_size:1 ~stride () in
+      let downsample_bn = batch_norm2d ~label:("downsample_bn" :: label) () in
       fun train_step x -> downsample_bn ~train_step (downsample_conv x)
     else fun _train_step x -> x
   in
 
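
The residual tail of resnet_block falls outside the hunk; a sketch of the standard pattern it presumably implements, with identity applied positionally as defined above:

    (* Presumed returned closure: conv path plus skip path, then activation. *)
    fun ~train_step x ->
      let out = bn2 ~train_step (conv2 (relu (bn1 ~train_step (conv1 x)))) in
      relu (out + identity train_step x)
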
@@ -275,12 +275,12 @@ let%op resnet_block ~label ?(stride = 1) () =
 (** LeNet-style architecture for simple image classification (e.g., MNIST). Classic architecture:
     conv -> pool -> conv -> pool -> fc layers *)
 let%op lenet ~label ?(num_classes = 10) () =
-  let conv1 = conv2d ~label:(label @ [ "conv1" ]) ~kernel_size:5 () in
+  let conv1 = conv2d ~label:("conv1" :: label) ~kernel_size:5 () in
   let pool1 = max_pool2d ~stride:2 () in
-  let conv2 = conv2d ~label:(label @ [ "conv2" ]) ~kernel_size:5 () in
+  let conv2 = conv2d ~label:("conv2" :: label) ~kernel_size:5 () in
   let pool2 = max_pool2d ~stride:2 () in
-  let fc1 = mlp_layer ~label:(label @ [ "fc1" ]) ~hid_dim:120 () in
-  let fc2 = mlp_layer ~label:(label @ [ "fc2" ]) ~hid_dim:84 () in
+  let fc1 = mlp_layer ~label:("fc1" :: label) ~hid_dim:120 () in
+  let fc2 = mlp_layer ~label:("fc2" :: label) ~hid_dim:84 () in
   fun ~train_step:_ x ->
     let x = conv1 x |> relu |> pool1 |> conv2 |> relu |> pool2 |> fc1 |> fc2 in
     (* Final classification layer *)
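
A hypothetical instantiation of the LeNet above; note that the ~train_step:_ pattern means the returned closure ignores train_step (there is no batch norm in this model):

    let model = lenet ~label:[ "lenet" ] ~num_classes:10 ()
    (* model ~train_step x : logits over num_classes. *)
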
@@ -290,7 +290,7 @@ let%op lenet ~label ?(num_classes = 10) () =
 let%op vgg_block ~label ~num_convs ?(kernel_size = 3) () =
   let convs =
     List.init num_convs ~f:(fun i ->
-        conv_bn_relu ~label:(label @ [ Printf.sprintf "conv%d" i ]) ~kernel_size ())
+        conv_bn_relu ~label:(("conv" ^ Int.to_string i) :: label) ~kernel_size ())
   in
   let pool = max_pool2d ~stride:2 () in
   fun ~train_step x ->
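
The body after fun ~train_step x -> is outside the hunk; by analogy with transformer_encoder it presumably folds the convs and then pools, roughly:

    (* Presumed tail of vgg_block (not shown in the hunk): *)
    pool (List.fold convs ~init:x ~f:(fun x conv -> conv ~train_step x))
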
@@ -301,9 +301,9 @@ let%op vgg_block ~label ~num_convs ?(kernel_size = 3) () =
     and outputs action logits. *)
 let%op sokoban_cnn ~label ?(num_actions = 4) () =
   (* Process spatial features with conv layers *)
-  let conv1 = conv_bn_relu ~label:(label @ [ "conv1" ]) ~kernel_size:3 () in
-  let conv2 = conv_bn_relu ~label:(label @ [ "conv2" ]) ~kernel_size:3 () in
-  let conv3 = conv_bn_relu ~label:(label @ [ "conv3" ]) ~kernel_size:3 () in
+  let conv1 = conv_bn_relu ~label:("conv1" :: label) ~kernel_size:3 () in
+  let conv2 = conv_bn_relu ~label:("conv2" :: label) ~kernel_size:3 () in
+  let conv3 = conv_bn_relu ~label:("conv3" :: label) ~kernel_size:3 () in
   fun ~train_step ~grid_state ->
     let x = conv1 ~train_step grid_state |> conv2 ~train_step |> conv3 ~train_step in
 
@@ -324,15 +324,15 @@ let%op mobile_cnn ~label ?(num_classes = 1000) ?(width_mult = 1.0) () =
   let _ = width_mult in
   (* TODO: implement channel width multiplier *)
   (* Initial standard conv *)
-  let conv_init = conv_bn_relu ~label:(label @ [ "conv_init" ]) ~kernel_size:3 ~stride:2 () in
+  let conv_init = conv_bn_relu ~label:("conv_init" :: label) ~kernel_size:3 ~stride:2 () in
 
   (* Depthwise separable blocks *)
-  let dw_block1 = depthwise_separable_conv2d ~label:(label @ [ "dw1" ]) ~stride:1 () in
-  let dw_block2 = depthwise_separable_conv2d ~label:(label @ [ "dw2" ]) ~stride:2 () in
-  let dw_block3 = depthwise_separable_conv2d ~label:(label @ [ "dw3" ]) ~stride:1 () in
-  let dw_block4 = depthwise_separable_conv2d ~label:(label @ [ "dw4" ]) ~stride:2 () in
+  let dw_block1 = depthwise_separable_conv2d ~label:("dw1" :: label) ~stride:1 () in
+  let dw_block2 = depthwise_separable_conv2d ~label:("dw2" :: label) ~stride:2 () in
+  let dw_block3 = depthwise_separable_conv2d ~label:("dw3" :: label) ~stride:1 () in
+  let dw_block4 = depthwise_separable_conv2d ~label:("dw4" :: label) ~stride:2 () in
 
-  let bn = batch_norm2d ~label:(label @ [ "bn_final" ]) () in
+  let bn = batch_norm2d ~label:("bn_final" :: label) () in
 
   fun ~train_step x ->
     let x =