@@ -141,13 +141,15 @@ let%op attention ~num_heads () x =
141141(* Matrix multiplication on individual output axes *)
142142let%op matmul a b = a +* "ik; kj => ij" b
143143
144- (* Batch matrix multiply on individual output axes with broadcasting *)
144+ (* Batch matrix multiply on individual output axes
145+ with broadcasting *)
145146let%op batch_matmul a b =
146147 a +* "... | ik; ... | kj => ... | ij" b
147148
148149(* Full tensor multiplication, equivalent to [a * b] *)
149150let%op tensor_mul a b =
150- a +* "... | ..k..->..i..; ... | ..j..->..k.. => ... | ..j..->..i.." b
151+ a +* "... | ..k..->..i..; ... | ..j..->..k..
152+ => ... | ..j..->..i.." b
151153
152154(* Max pooling, requiring specifically 4 output axes *)
153155let%op max_reduce_ouput x = x @^^ "bhwc => b00c"
@@ -167,6 +169,8 @@ let%op max_reduce x = x @^^ "...|hwc => ...|00c"
167169> ```
168170> "stride*output + dilation*kernel"
169171> ```
172+ >
173+ > Within the syntax extensions, `stride` and `dilation` can be identifiers of `int` values, in addition to integer literals.
170174
171175{pause}
172176
@@ -253,12 +257,13 @@ You can programmatically create the spec for use with the dedicated syntaxes, bu
253257``` ocaml
254258(* Reduce last N output dimensions, PyTorch-style keepdim *)
255259let%op reduce_last_n ~n ?(keepdim = true) () =
256- let vars = List.init n ~f:(fun i ->
257- Char.to_string (Char.of_int_exn (97 + i))) in
260+ let vars =
261+ [%oc List.init n ~f:(fun i ->
262+ Char.to_string (Char.of_int_exn (97 + i)))] in
258263 let result_dims =
259- if keepdim then String.make n '0' else "" in
260- let spec = "... | ..." ^ String.concat "" vars ^
261- " => ... | ..." ^ result_dims in
264+ [%oc if keepdim then String.make n '0' else ""] in
265+ let spec = [%oc "... | ..." ^ String.concat "" vars ^
266+ " => ... | ..." ^ result_dims] in
262267 fun x -> x ++ spec
263268
264269(* Example: reduce_last_n ~n:3 ~keepdim:true ()
@@ -268,6 +273,10 @@ let%op reduce_last_n ~n ?(keepdim = true) () =
268273 generates: "... | ...abc => ... | ..." *)
269274```
270275
276+ {pause}
277+
278+ The ` [%oc ...] ` syntax allows embedding arbitrary OCaml code without ` %op ` attempting to interpret things as tensors.
279+
271280{pause up}
272281## Practical Patterns
273282
@@ -352,7 +361,7 @@ a +* "ij; jk => ik" b
352361## Tips and Tricks
353362
354363{#tips .block title="Best Practices"}
355- > 1 . ** Use ` | ` for axis kinds** when mixing batch/input/output
364+ > 1 . ** Use ` | ` , ` -> ` for axis kinds** when there's a meaningful batch/input/output split
356365> 2 . ** Add trailing comma** for multi-char mode: ` "input->output," `
357366> 3 . ** Avoid over-capturing** dimensions in einsum specs
358367> 4 . ** Remember tensor operators** : ` * ` (matmul) vs ` *. ` (pointwise)
@@ -363,12 +372,17 @@ a +* "ij; jk => ik" b
363372{#debugging .remark title="Debugging Shapes"}
364373> When shapes don't match:
365374>
366- > * Print tensor shapes: ` Tensor.print ~force:true tensor `
375+ > * Print tensor shapes: ` Tensor.print ~force:true tensor `
376+ > [ but not before all relevant tensor expressions are constructed] {.unrevealed #premature-inference-finalize}
367377> * Check axis kinds are correctly specified
368378> * Verify broadcasting assumptions
369379> * Use explicit dimension constraints when needed
370380
371- {pause up}
381+ {pause reveal=premature-inference-finalize}
382+
383+ {pause focus=premature-inference-finalize}
384+
385+ {pause unfocus up}
372386## Common Pitfalls
373387
374388** Tensor operators matter:**
@@ -475,4 +489,4 @@ Check [nn_blocks.ml](https://github.com/ahrefs/ocannl/blob/master/lib/nn_blocks.
475489* [doc/syntax_extensions.md](syntax_extensions.html) - Full `%op` and `%cd` syntax
476490* [lib/shape.mli](../dev/neural_nets_lib/Ocannl/Shape/index.html) - Shape inference internals
477491* [lib/nn_blocks.ml](https://github.com/ahrefs/ocannl/blob/master/lib/nn_blocks.ml#L68) - Production examples
478- * ` test/einsum_trivia.ml` - Einsum test cases
492+ * [ test/einsum_trivia.ml](https://github.com/ahrefs/ocannl/blob/master/test/einsum/einsum_trivia.ml) - Einsum test cases
0 commit comments