Commit 9f64173

Fixes #389 -- prefix block names in lib/nn_blocks.ml

1 parent d7f8a7a

2 files changed: +37 -36 lines

lib/nn_blocks.ml

Lines changed: 36 additions & 36 deletions
@@ -89,11 +89,11 @@ let%op layer_norm ~label ?(epsilon = 1e-5) () x =
   ({ gamma = 1. } *. normalized) + { beta = 0. }
 
 let%op transformer_encoder_block ~label ~num_heads ~d_k ~d_v ~d_ff ?(epsilon = 1e-5) () =
-  let mha = multi_head_attention ~label:(label @ [ "mha" ]) ~num_heads ~d_k ~d_v () in
+  let mha = multi_head_attention ~label:("mha" :: label) ~num_heads ~d_k ~d_v () in
   (* Standard 2-layer FFN: expand to d_ff then contract back to d_model *)
-  let ffn = mlp ~label:(label @ [ "ffn" ]) ~hid_dims:[ d_ff ] () in
-  let ln1 = layer_norm ~label:(label @ [ "ln1" ]) ~epsilon () in
-  let ln2 = layer_norm ~label:(label @ [ "ln2" ]) ~epsilon () in
+  let ffn = mlp ~label:("ffn" :: label) ~hid_dims:[ d_ff ] () in
+  let ln1 = layer_norm ~label:("ln1" :: label) ~epsilon () in
+  let ln2 = layer_norm ~label:("ln2" :: label) ~epsilon () in
   fun ~train_step input ->
     let x1 = ln1 (input + mha ~train_step input) in
     ln2 (x1 + ffn x1)
@@ -114,13 +114,13 @@ let%op cross_attention ~label ~num_heads ~d_k ~d_v ?temperature ?(dropout_rate =
   { w_o } * (attn_weights +* v " ... s | t -> h; ... t | h e => ... s | h e" [ "e" ])
 
 let%op transformer_decoder_block ~label ~num_heads ~d_k ~d_v ~d_ff ?(epsilon = 1e-5) () =
-  let masked_mha = multi_head_attention ~label:(label @ [ "masked_mha" ]) ~num_heads ~d_k ~d_v () in
-  let cross_mha = cross_attention ~label:(label @ [ "cross_mha" ]) ~num_heads ~d_k ~d_v () in
+  let masked_mha = multi_head_attention ~label:("masked_mha" :: label) ~num_heads ~d_k ~d_v () in
+  let cross_mha = cross_attention ~label:("cross_mha" :: label) ~num_heads ~d_k ~d_v () in
   (* Standard 2-layer FFN: expand to d_ff then contract back to d_model *)
-  let ffn = mlp ~label:(label @ [ "ffn" ]) ~hid_dims:[ d_ff ] () in
-  let ln1 = layer_norm ~label:(label @ [ "ln1" ]) ~epsilon () in
-  let ln2 = layer_norm ~label:(label @ [ "ln2" ]) ~epsilon () in
-  let ln3 = layer_norm ~label:(label @ [ "ln3" ]) ~epsilon () in
+  let ffn = mlp ~label:("ffn" :: label) ~hid_dims:[ d_ff ] () in
+  let ln1 = layer_norm ~label:("ln1" :: label) ~epsilon () in
+  let ln2 = layer_norm ~label:("ln2" :: label) ~epsilon () in
+  let ln3 = layer_norm ~label:("ln3" :: label) ~epsilon () in
   fun ~train_step target ~enc_output ~mask ->
     let x1 = ln1 (target + masked_mha ~train_step ~mask target) in
     let x2 = ln2 (x1 + cross_mha ~train_step x1 ~enc_output) in
@@ -130,7 +130,7 @@ let transformer_encoder ~label ~num_layers ~num_heads ~d_k ~d_v ~d_ff ?(epsilon
   let layers =
     List.init num_layers ~f:(fun i ->
         transformer_encoder_block
-          ~label:(label @ [ "layer" ^ Int.to_string i ])
+          ~label:(("layer" ^ Int.to_string i) :: label)
           ~num_heads ~d_k ~d_v ~d_ff ~epsilon ())
   in
   fun ~train_step x -> List.fold layers ~init:x ~f:(fun x layer -> layer ~train_step x)
@@ -139,7 +139,7 @@ let transformer_decoder ~label ~num_layers ~num_heads ~d_k ~d_v ~d_ff ?(epsilon
   let layers =
     List.init num_layers ~f:(fun i ->
         transformer_decoder_block
-          ~label:(label @ [ "layer" ^ Int.to_string i ])
+          ~label:(("layer" ^ Int.to_string i) :: label)
           ~num_heads ~d_k ~d_v ~d_ff ~epsilon ())
   in
   fun ~train_step target ~enc_output ~mask ->
@@ -150,11 +150,11 @@ let%op transformer ~label ~num_encoder_layers ~num_decoder_layers ~num_heads ~d_
   let enc_att = [%oc d_enc / num_heads] in
   let dec_att = [%oc d_dec / num_heads] in
   let encoder =
-    transformer_encoder ~label:(label @ [ "encoder" ]) ~num_layers:num_encoder_layers ~num_heads
+    transformer_encoder ~label:("encoder" :: label) ~num_layers:num_encoder_layers ~num_heads
       ~d_k:enc_att ~d_v:enc_att ~d_ff ~epsilon ()
   in
   let decoder =
-    transformer_decoder ~label:(label @ [ "decoder" ]) ~num_layers:num_decoder_layers ~num_heads
+    transformer_decoder ~label:("decoder" :: label) ~num_layers:num_decoder_layers ~num_heads
       ~d_k:dec_att ~d_v:dec_att ~d_ff ~epsilon ()
   in
   (* All inline definitions, including for ds, dt, are lifted up to the unit parameter above. *)
@@ -249,22 +249,22 @@ let%op batch_norm2d ~label ?(epsilon = 1e-5) ?(momentum = 0.9) () ~train_step x
 
 (** Conv block with conv -> batch norm -> activation pattern *)
 let%op conv_bn_relu ~label ?(kernel_size = 3) ?(stride = 1) () =
-  let conv = conv2d ~label:(label @ [ "conv" ]) ~kernel_size ~stride () in
-  let bn = batch_norm2d ~label:(label @ [ "bn" ]) () in
+  let conv = conv2d ~label:("conv" :: label) ~kernel_size ~stride () in
+  let bn = batch_norm2d ~label:("bn" :: label) () in
   fun ~train_step x -> relu (bn ~train_step (conv x))
 
 (** Residual block for ResNet-style architectures. Features skip connections that help with gradient
     flow in deep networks. *)
 let%op resnet_block ~label ?(stride = 1) () =
-  let conv1 = conv2d ~label:(label @ [ "conv1" ]) ~kernel_size:3 ~stride () in
-  let bn1 = batch_norm2d ~label:(label @ [ "bn1" ]) () in
-  let conv2 = conv2d ~label:(label @ [ "conv2" ]) ~kernel_size:3 ~stride:1 () in
-  let bn2 = batch_norm2d ~label:(label @ [ "bn2" ]) () in
+  let conv1 = conv2d ~label:("conv1" :: label) ~kernel_size:3 ~stride () in
+  let bn1 = batch_norm2d ~label:("bn1" :: label) () in
+  let conv2 = conv2d ~label:("conv2" :: label) ~kernel_size:3 ~stride:1 () in
+  let bn2 = batch_norm2d ~label:("bn2" :: label) () in
   let identity =
     if stride > 1 then
       (* Need to downsample the skip connection *)
-      let downsample_conv = conv2d ~label:(label @ [ "downsample" ]) ~kernel_size:1 ~stride () in
-      let downsample_bn = batch_norm2d ~label:(label @ [ "downsample_bn" ]) () in
+      let downsample_conv = conv2d ~label:("downsample" :: label) ~kernel_size:1 ~stride () in
+      let downsample_bn = batch_norm2d ~label:("downsample_bn" :: label) () in
       fun train_step x -> downsample_bn ~train_step (downsample_conv x)
     else fun _train_step x -> x
   in
@@ -275,12 +275,12 @@ let%op resnet_block ~label ?(stride = 1) () =
 (** LeNet-style architecture for simple image classification (e.g., MNIST). Classic architecture:
     conv -> pool -> conv -> pool -> fc layers *)
 let%op lenet ~label ?(num_classes = 10) () =
-  let conv1 = conv2d ~label:(label @ [ "conv1" ]) ~kernel_size:5 () in
+  let conv1 = conv2d ~label:("conv1" :: label) ~kernel_size:5 () in
   let pool1 = max_pool2d ~stride:2 () in
-  let conv2 = conv2d ~label:(label @ [ "conv2" ]) ~kernel_size:5 () in
+  let conv2 = conv2d ~label:("conv2" :: label) ~kernel_size:5 () in
   let pool2 = max_pool2d ~stride:2 () in
-  let fc1 = mlp_layer ~label:(label @ [ "fc1" ]) ~hid_dim:120 () in
-  let fc2 = mlp_layer ~label:(label @ [ "fc2" ]) ~hid_dim:84 () in
+  let fc1 = mlp_layer ~label:("fc1" :: label) ~hid_dim:120 () in
+  let fc2 = mlp_layer ~label:("fc2" :: label) ~hid_dim:84 () in
   fun ~train_step:_ x ->
     let x = conv1 x |> relu |> pool1 |> conv2 |> relu |> pool2 |> fc1 |> fc2 in
     (* Final classification layer *)
@@ -290,7 +290,7 @@ let%op lenet ~label ?(num_classes = 10) () =
 let%op vgg_block ~label ~num_convs ?(kernel_size = 3) () =
   let convs =
     List.init num_convs ~f:(fun i ->
-        conv_bn_relu ~label:(label @ [ Printf.sprintf "conv%d" i ]) ~kernel_size ())
+        conv_bn_relu ~label:(("conv" ^ Int.to_string i) :: label) ~kernel_size ())
   in
   let pool = max_pool2d ~stride:2 () in
   fun ~train_step x ->
@@ -301,9 +301,9 @@ let%op vgg_block ~label ~num_convs ?(kernel_size = 3) () =
     and outputs action logits. *)
 let%op sokoban_cnn ~label ?(num_actions = 4) () =
   (* Process spatial features with conv layers *)
-  let conv1 = conv_bn_relu ~label:(label @ [ "conv1" ]) ~kernel_size:3 () in
-  let conv2 = conv_bn_relu ~label:(label @ [ "conv2" ]) ~kernel_size:3 () in
-  let conv3 = conv_bn_relu ~label:(label @ [ "conv3" ]) ~kernel_size:3 () in
+  let conv1 = conv_bn_relu ~label:("conv1" :: label) ~kernel_size:3 () in
+  let conv2 = conv_bn_relu ~label:("conv2" :: label) ~kernel_size:3 () in
+  let conv3 = conv_bn_relu ~label:("conv3" :: label) ~kernel_size:3 () in
   fun ~train_step ~grid_state ->
     let x = conv1 ~train_step grid_state |> conv2 ~train_step |> conv3 ~train_step in
 
@@ -324,15 +324,15 @@ let%op mobile_cnn ~label ?(num_classes = 1000) ?(width_mult = 1.0) () =
   let _ = width_mult in
   (* TODO: implement channel width multiplier *)
   (* Initial standard conv *)
-  let conv_init = conv_bn_relu ~label:(label @ [ "conv_init" ]) ~kernel_size:3 ~stride:2 () in
+  let conv_init = conv_bn_relu ~label:("conv_init" :: label) ~kernel_size:3 ~stride:2 () in
 
   (* Depthwise separable blocks *)
-  let dw_block1 = depthwise_separable_conv2d ~label:(label @ [ "dw1" ]) ~stride:1 () in
-  let dw_block2 = depthwise_separable_conv2d ~label:(label @ [ "dw2" ]) ~stride:2 () in
-  let dw_block3 = depthwise_separable_conv2d ~label:(label @ [ "dw3" ]) ~stride:1 () in
-  let dw_block4 = depthwise_separable_conv2d ~label:(label @ [ "dw4" ]) ~stride:2 () in
+  let dw_block1 = depthwise_separable_conv2d ~label:("dw1" :: label) ~stride:1 () in
+  let dw_block2 = depthwise_separable_conv2d ~label:("dw2" :: label) ~stride:2 () in
+  let dw_block3 = depthwise_separable_conv2d ~label:("dw3" :: label) ~stride:1 () in
+  let dw_block4 = depthwise_separable_conv2d ~label:("dw4" :: label) ~stride:2 () in
 
-  let bn = batch_norm2d ~label:(label @ [ "bn_final" ]) () in
+  let bn = batch_norm2d ~label:("bn_final" :: label) () in
 
   fun ~train_step x ->
     let x =
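
Every change in this file follows the same pattern: a block's own name is now consed onto the front of the parent label list ("mha" :: label) instead of appended at the end (label @ [ "mha" ]). Consing is O(1) rather than O(n) in the label length, and it puts the most specific name first, matching the ordering documented in tensor/tensor.mli below. A minimal standalone sketch contrasting the two orderings; render is a hypothetical display helper, not part of the repository:

(* Standalone sketch of old append-style vs. new prepend-style labels.
   [render] is hypothetical, for display only. *)
let render label = String.concat "/" label

let () =
  let parent = [ "encoder"; "transformer" ] in
  (* Old: child name appended last; traverses the whole parent list. *)
  let old_label = parent @ [ "mha" ] in
  (* New: child name consed first; constant time, most specific first. *)
  let new_label = "mha" :: parent in
  print_endline (render old_label); (* encoder/transformer/mha *)
  print_endline (render new_label)  (* mha/encoder/transformer *)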

tensor/tensor.mli

Lines changed: 1 addition & 0 deletions
@@ -142,6 +142,7 @@ type op_fun =
   ?batch_dims:int list ->
   ?batch_axes:(string * int) list ->
   param_op_fun
+(** Labels are collected in tensor construction order, with more specific information first. *)
 
 val binop :
   ?compose_op:Shape.compose_type ->
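
Since each constructor conses its own name onto the label it receives, nesting composes into a most-specific-first list, which is what the new doc comment describes. A hypothetical trace using the block names from the diff above; the root label "my_transformer" is made up for illustration:

(* Hypothetical trace of label composition after this commit: each
   constructor conses its own name onto the label it receives. *)
let () =
  let root = [ "my_transformer" ] in
  let encoder = "encoder" :: root in                      (* transformer *)
  let layer0 = ("layer" ^ Int.to_string 0) :: encoder in  (* transformer_encoder *)
  let mha = "mha" :: layer0 in                            (* transformer_encoder_block *)
  assert (mha = [ "mha"; "layer0"; "encoder"; "my_transformer" ]);
  (* Reversed, the label reads root-to-leaf. *)
  print_endline (String.concat "." (List.rev mha))
  (* my_transformer.encoder.layer0.mha *)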
