We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 721044b commit a2c3c37Copy full SHA for a2c3c37
lib/nn_blocks.ml
@@ -1,15 +1,12 @@
1
-(** Prior to OCANNL 0.5, this module is just a placeholder hinting at an intended design pattern for
2
- model components. *)
3
-
4
open! Base
5
module TDSL = Operation.TDSL
6
module NTDSL = Operation.NTDSL
7
8
let%op mlp_layer ~label ~hid_dim () x = relu (({ w = uniform () } * x) + { b = 0.; o = [ hid_dim ] })
9
10
-let mlp ~hid_dims =
+let mlp ~label ~hid_dims () =
11
let layers =
12
List.mapi hid_dims ~f:(fun i hid_dim ->
13
- mlp_layer ~label:[ "L" ^ Int.to_string i ] ~hid_dim ())
+ mlp_layer ~label:(("L" ^ Int.to_string i) :: label) ~hid_dim ())
14
in
15
fun x -> List.fold layers ~init:x ~f:(fun x layer -> layer x)
0 commit comments