Skip to content

Commit efb5c81

Browse files
committed
The Nn_blocks placeholder hinting at intended design of model components
1 parent 150cef7 commit efb5c81

File tree

3 files changed

+21
-1
lines changed

3 files changed

+21
-1
lines changed

CHANGES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
- We check FP16 constants for overflow.
88
- We output half precision specific code from the CUDA backend.
99
- Finally proper support for mixed precision! Lazy precision defaults and delayed precision setting via `Tnode.update_prec`.
10+
- A placeholder `nn_blocks.ml` hinting at an intended design pattern for model components.
1011

1112
### Changed
1213

lib/dune

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
arrayjit)
1717
(preprocess
1818
(pps ppx_jane ppx_ocannl ppx_minidebug))
19-
(modules PrintBox_utils row shape tensor operation train)
19+
(modules PrintBox_utils row shape tensor operation train nn_blocks)
2020
(modes byte native)
2121
(c_library_flags -pthread))
2222

lib/nn_blocks.ml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
(** Prior to OCANNL 0.5, this module is just a placeholder hinting at an intended design pattern
2+
for model components. *)
3+
4+
open! Base
5+
module TDSL = Operation.TDSL
6+
module NTDSL = Operation.NTDSL
7+
8+
type mlp_layer_config = { label : string list; hid_dim : int }
9+
10+
let%op mlp_layer ~config x = ?/(("w" * x) + "b" config.hid_dim)
11+
12+
type mlp_config = { label : string list; hid_dims : int list }
13+
14+
let mlp ~config =
15+
let layers =
16+
List.mapi config.hid_dims ~f:(fun i hid_dim ->
17+
mlp_layer ~config:{ label = [ "L" ^ Int.to_string i ]; hid_dim })
18+
in
19+
fun x -> List.fold layers ~init:x ~f:(fun x layer -> layer x)

0 commit comments

Comments
 (0)