Skip to content

Commit 87ac57e

Browse files
committed
Readme update, bug fix in nn_blocks.ml layer_norm
1 parent 85aca1f commit 87ac57e

File tree

3 files changed

+18
-13
lines changed

3 files changed

+18
-13
lines changed

README.md

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,13 @@ A possible route to learning OCANNL:
3636

3737
1. Read [the introductory slides](https://ahrefs.github.io/ocannl/docs/basics_backprop_training_codegen.html).
3838
2. Read: [shapes and the generalized einsum beginner-to-advanced slides](https://ahrefs.github.io/ocannl/docs/shapes_and_einsum.html).
39-
3. Read [the migration guide](docs/migration_guide.md).
40-
4. Read the syntax extensions documentation [docs/syntax_extensions.md](docs/syntax_extensions.md).
41-
5. Read the NN building blocks file [lib/nn_blocks.ml](lib/nn_blocks.ml).
42-
6. Read the introductory part of the shape inference documentation [docs/shape_inference.md](docs/shape_inference.md).
43-
7. Skim the configuration documentation [ocannl_config.example](ocannl_config.example).
44-
8. Improve your understanding by reading or skimming: [lib/shape.mli](lib/shape.mli), [lib/tensor.mli](lib/tensor.mli), [lib/operation.ml](lib/operation.ml), [arrayjit/lib/backend_intf.ml](arrayjit/lib/backend_intf.ml), [lib/train.ml](lib/train.ml).
45-
9. Read [docs/anatomy_of_a_backend.md](arrayjit/lib/anatomy_of_a_backend.md).
39+
3. Upcoming in v0.7: slides about [`Context`](arrayjit/lib/context.mli).
40+
4. Read [the migration guide](docs/migration_guide.md).
41+
5. Read the syntax extensions documentation [docs/syntax_extensions.md](docs/syntax_extensions.md).
42+
6. Read the NN building blocks file [lib/nn_blocks.ml](lib/nn_blocks.ml).
43+
7. Read the introductory part of the shape inference documentation [docs/shape_inference.md](docs/shape_inference.md).
44+
8. Skim the configuration documentation [ocannl_config.example](ocannl_config.example).
45+
9. Improve your understanding by reading or skimming: [lib/shape.mli](lib/shape.mli), [lib/tensor.mli](lib/tensor.mli), [lib/operation.ml](lib/operation.ml), [arrayjit/lib/context.mli](arrayjit/lib/context.mli), [lib/train.ml](lib/train.ml).
4646
10. Read the implementation overview:
4747
1. The various tests.
4848
2. Shape inference details [docs/shape_inference.md](docs/shape_inference.md).
@@ -58,14 +58,17 @@ NOTE: debug logging from CUDA in complex settings is a bit tricky, it involves a
5858

5959
This is very tentative.
6060

61-
* **0.6.1: convolution NNs, transformers.**
61+
* **0.6.1: Syntax extension improvements, transformers.**
62+
* Heterogeneous precision operations.
6263
* Counter-based randomness via threefry, second pass (pointwise and weak-but-efficient variants); normal distribution operation.
63-
* Padding inference during shape inference.
6464
* New syntax for inline parameter definitions; record-based syntax instead of string-based.
65-
* Add convnet building blocks and corresponding examples starting with MNIST.
66-
* Add transformer building blocks.
65+
* Add transformer and convnet building blocks.
66+
* Better shape error messages.
67+
* **0.6.2: Shape inference improvements, convolution NNs, real-life transformers.**
68+
* Padding inference during shape inference.
69+
* Add convnet examples starting with MNIST.
70+
* Add a GPT-2 or Llama style example. Tokenization using a tokenizer extracted from llama.cpp.
6771
* **0.7: CPU-style performance and memory efficiency.**
68-
* Add a GPT-2 style example, ideally benchmarkable against [llm.c](https://github.com/karpathy/llm.c). Tokenization via Raven's library Sage.
6972
* Milestone phrasing: enhancements for inlining-related and simplification-related optimizations, memory management, and session management.
7073
* **0.7.1: HIP backend (AMD hardware) and WebGPU backend.**
7174
* **0.8: GPU-style performance -- low hanging fruit.**

lib/nn_blocks.ml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ let%op multi_head_attention ~label ~num_heads ?temperature ?(dropout_rate = 0.0)
7979
let%op layer_norm ~label ?(epsilon = 1e-5) () x =
8080
let mean = x ++ " ... | ..d.. => ... | 0 " [ "d" ] in
8181
let centered = (x - mean) /. dim d in
82-
let variance = (centered * centered) ++ " ... | ... => ... | 0 " in
82+
let variance = (centered *. centered) ++ " ... | ... => ... | 0 " in
8383
let std_dev = sqrt (variance + !.epsilon) in
8484
let normalized = centered /. std_dev in
8585
(* gamma and beta are learned, but initialized to good defaults *)

test/operations/transformer_test.ml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ let () =
5050
(* Forward pass *)
5151
let output = transformer_model ~train_step:None ~src ~tgt ~mask in
5252

53+
let _ctx = Ocannl.Train.forward_once ctx output in
54+
5355
(* Verify output shape *)
5456
Stdio.printf "Output shape:\n%s\n%!"
5557
(Sexp.to_string_hum ([%sexp_of: Shape.t] output.Tensor.shape))

0 commit comments

Comments
 (0)